#![forbid(unsafe_code)]
pub use dashu_base::Sign;
pub use dashu_base::{
Abs, AbsOrd, BitTest, CubicRoot, DivEuclid, DivRem, DivRemAssign, DivRemEuclid, EstimatedLog2,
ExtendedGcd, Gcd, Inverse, PowerOfTwo, RemEuclid, Signed, SquareRoot, UnsignedAbs,
};
/// Shorthand for a [`Result`] whose error type is [`OxiNumError`].
pub type OxiNumResult<T> = Result<T, OxiNumError>;
/// Error type carried by [`OxiNumResult`].
///
/// The message-bearing variants hold a `Cow<'static, str>`, so either a
/// string literal or an owned `String` can be stored in them.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum OxiNumError {
/// Rendered by `Display` as `parse error: <message>`.
Parse(std::borrow::Cow<'static, str>),
/// Rendered by `Display` as `precision error: <message>`.
Precision(std::borrow::Cow<'static, str>),
/// Rendered by `Display` as `division by zero`; carries no payload.
DivByZero,
/// Rendered by `Display` as `overflow: <message>`.
Overflow(std::borrow::Cow<'static, str>),
/// Holds the offending radix; rendered as `invalid radix: <n> (must be 2..=36)`.
InvalidRadix(u32),
/// Rendered by `Display` as `domain error: <message>`.
Domain(std::borrow::Cow<'static, str>),
}
impl OxiNumError {
    /// Prefixes the message of a message-carrying variant with `ctx`.
    ///
    /// For `Parse`, `Precision`, `Overflow` and `Domain` the stored message
    /// becomes `"{ctx}: {message}"` and the variant is kept. `DivByZero` and
    /// `InvalidRadix` hold no message, so they are returned untouched and
    /// `ctx` is ignored.
    #[must_use]
    pub fn context(self, ctx: impl AsRef<str>) -> Self {
        let prefix = ctx.as_ref();
        // Single place that defines the "<ctx>: <message>" layout.
        let prefixed = |message: std::borrow::Cow<'static, str>| {
            std::borrow::Cow::<'static, str>::Owned(format!("{prefix}: {message}"))
        };
        match self {
            Self::Parse(message) => Self::Parse(prefixed(message)),
            Self::Precision(message) => Self::Precision(prefixed(message)),
            Self::Overflow(message) => Self::Overflow(prefixed(message)),
            Self::Domain(message) => Self::Domain(prefixed(message)),
            // Variants without a textual payload have nothing to prefix.
            unchanged @ (Self::DivByZero | Self::InvalidRadix(_)) => unchanged,
        }
    }
}
impl std::fmt::Display for OxiNumError {
    /// Writes a one-line, human-readable description of the error.
    ///
    /// Message-carrying variants render as `<label>: <message>`; `DivByZero`
    /// and `InvalidRadix` have fixed layouts of their own.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail): (&str, &str) = match self {
            // The two variants without a free-form message are written directly.
            Self::DivByZero => return out.write_str("division by zero"),
            Self::InvalidRadix(radix) => {
                return write!(out, "invalid radix: {radix} (must be 2..=36)");
            }
            Self::Parse(msg) => ("parse error", msg.as_ref()),
            Self::Precision(msg) => ("precision error", msg.as_ref()),
            Self::Overflow(msg) => ("overflow", msg.as_ref()),
            Self::Domain(msg) => ("domain error", msg.as_ref()),
        };
        write!(out, "{label}: {detail}")
    }
}
// Empty impl: every `Error` method keeps its default, so `source()` returns `None`.
impl std::error::Error for OxiNumError {}
impl From<OxiNumError> for std::io::Error {
    /// Wraps the error in a [`std::io::Error`] of kind
    /// [`std::io::ErrorKind::Other`].
    ///
    /// The `OxiNumError` value itself is stored as the payload instead of its
    /// rendered string. The `Display` output of the resulting `io::Error` is
    /// the same text as `e.to_string()`, and callers can additionally recover
    /// the original variant with
    /// `io_err.get_ref().and_then(|inner| inner.downcast_ref::<OxiNumError>())`.
    fn from(e: OxiNumError) -> Self {
        std::io::Error::other(e)
    }
}
/// Parse failure annotated with a line/column position.
///
/// Converts into [`OxiNumError::Parse`] through the `From` impl in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ParseNumberError {
/// Description of what went wrong.
pub message: String,
/// Line the error refers to.
/// NOTE(review): the tests in this file only use values >= 1, but the type does not enforce a base — confirm.
pub line: usize,
/// Column the error refers to (same caveat as `line`).
pub column: usize,
}
impl ParseNumberError {
pub fn new(message: impl Into<String>, line: usize, column: usize) -> Self {
Self {
message: message.into(),
line,
column,
}
}
}
impl std::fmt::Display for ParseNumberError {
    /// Renders as `parse error at line <line>, column <column>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure once so the format string can name the fields directly.
        let Self {
            message,
            line,
            column,
        } = self;
        write!(f, "parse error at line {line}, column {column}: {message}")
    }
}
// Empty impl: every `Error` method keeps its default, so `source()` returns `None`.
impl std::error::Error for ParseNumberError {}
impl From<ParseNumberError> for OxiNumError {
    /// Folds a positioned parse error into [`OxiNumError::Parse`].
    ///
    /// The position is appended to the message as ` (line <line>, col <column>)`.
    fn from(e: ParseNumberError) -> Self {
        let text = format!("{} (line {}, col {})", e.message, e.line, e.column);
        OxiNumError::Parse(std::borrow::Cow::Owned(text))
    }
}
/// Identifier for a rounding strategy.
///
/// NOTE(review): nothing in this file consumes the enum, so the per-variant
/// notes below are the conventional meanings suggested by the names (they
/// match `java.math.RoundingMode`) — confirm against the code that
/// interprets these values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum RoundingMode {
/// Presumably: round away from zero — confirm.
Up,
/// Presumably: round toward zero — confirm.
Down,
/// Presumably: round toward positive infinity — confirm.
Ceiling,
/// Presumably: round toward negative infinity — confirm.
Floor,
/// Presumably: round to nearest, ties away from zero — confirm.
HalfUp,
/// Presumably: round to nearest, ties toward zero — confirm.
HalfDown,
/// Presumably: round to nearest, ties to the even neighbour — confirm.
HalfEven,
/// Presumably: assert that no rounding is required — confirm how violations are reported.
Unnecessary,
}
impl std::fmt::Display for RoundingMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let name = match self {
Self::Up => "Up",
Self::Down => "Down",
Self::Ceiling => "Ceiling",
Self::Floor => "Floor",
Self::HalfUp => "HalfUp",
Self::HalfDown => "HalfDown",
Self::HalfEven => "HalfEven",
Self::Unnecessary => "Unnecessary",
};
f.write_str(name)
}
}
/// Base trait for number types: requires `Display`, `Debug`, `Clone` and
/// `PartialEq`, plus tests for the two identity values.
pub trait OxiNum: std::fmt::Display + std::fmt::Debug + Clone + PartialEq {
/// Returns `true` when `self` is zero.
fn is_zero(&self) -> bool;
/// Returns `true` when `self` is one.
fn is_one(&self) -> bool;
}
/// Extension of [`OxiNum`] for number types that carry a sign.
pub trait OxiSigned: OxiNum {
    /// Returns the [`Sign`] of `self`.
    fn signum(&self) -> Sign;

    /// Returns the absolute value of `self`.
    fn abs(&self) -> Self;

    /// Returns `true` when [`signum`](Self::signum) reports [`Sign::Negative`].
    fn is_negative(&self) -> bool {
        matches!(self.signum(), Sign::Negative)
    }

    /// Returns `true` when `self` is non-zero and [`signum`](Self::signum)
    /// reports [`Sign::Positive`].
    fn is_positive(&self) -> bool {
        // `Sign` has no zero state, so zero is excluded before the sign is consulted.
        if self.is_zero() {
            return false;
        }
        matches!(self.signum(), Sign::Positive)
    }
}
/// Marker trait (no methods) layered on [`OxiNum`] for unsigned number types.
pub trait OxiUnsigned: OxiNum {}
/// Construction of a value from its textual representation in a given radix.
pub trait FromRadix: Sized {
/// Parses `src` as a number written in base `radix`.
///
/// Failures are reported through [`OxiNumError`].
/// NOTE(review): presumably `OxiNumError::InvalidRadix` for a radix outside `2..=36`
/// and `OxiNumError::Parse` for malformed input — confirm against implementors.
fn from_radix(src: &str, radix: u32) -> OxiNumResult<Self>;
}
/// Rendering of a value as text in a given radix.
pub trait ToRadix {
/// Formats `self` in base `radix`.
///
/// Failures are reported through [`OxiNumError`].
/// NOTE(review): presumably `OxiNumError::InvalidRadix` for a radix outside `2..=36` — confirm against implementors.
fn to_radix(&self, radix: u32) -> OxiNumResult<String>;
}
/// Exponentiation with an exponent of type `Exp`.
pub trait Pow<Exp> {
/// Result type of the exponentiation.
type Output;
/// Raises `self` to the power `exp`; `self` is borrowed, `exp` is taken by value.
fn pow(&self, exp: Exp) -> Self::Output;
}
/// Root extraction.
///
/// NOTE(review): every method returns `Self` with no error channel, so rounding
/// of inexact results and the handling of invalid input (for example a negative
/// radicand, or `n == 0`) are left to implementors — confirm per type.
pub trait Roots {
/// Square root of `self`.
fn sqrt(&self) -> Self;
/// Cube root of `self`.
fn cbrt(&self) -> Self;
/// `n`-th root of `self`.
fn nth_root(&self, n: u32) -> Self;
}
/// Arithmetic reduced by a modulus; all operands and results share the `Self` type.
///
/// NOTE(review): the signatures have no error channel, so behaviour for a zero
/// `modulus` is implementor-defined — confirm per type.
pub trait ModularArithmetic {
/// `self + rhs`, reduced by `modulus`.
fn mod_add(&self, rhs: &Self, modulus: &Self) -> Self;
/// `self - rhs`, reduced by `modulus`.
fn mod_sub(&self, rhs: &Self, modulus: &Self) -> Self;
/// `self * rhs`, reduced by `modulus`.
fn mod_mul(&self, rhs: &Self, modulus: &Self) -> Self;
/// `self` raised to `exp`, reduced by `modulus`.
fn mod_pow(&self, exp: &Self, modulus: &Self) -> Self;
}
/// Primality-related queries.
pub trait Primality {
/// Probabilistic primality test.
/// NOTE(review): `witnesses` presumably sets the number of test rounds/bases — confirm against implementors.
fn is_probably_prime(&self, witnesses: u32) -> bool;
/// Returns a prime following `self`.
/// NOTE(review): presumably the smallest prime strictly greater than `self` — confirm.
fn next_prime(&self) -> Self;
}
// Unit tests for the error types, their conversions and `RoundingMode`.
#[cfg(test)]
mod tests {
use super::*;
// --- `Display` output of each `OxiNumError` variant ---
#[test]
fn error_display_parse() {
let e = OxiNumError::Parse("bad input".into());
assert!(e.to_string().contains("bad input"));
}
#[test]
fn error_display_precision() {
let e = OxiNumError::Precision("too low".into());
assert!(e.to_string().contains("too low"));
}
#[test]
fn error_display_div_by_zero() {
let e = OxiNumError::DivByZero;
assert_eq!(e.to_string(), "division by zero");
}
#[test]
fn error_display_overflow() {
let e = OxiNumError::Overflow("u64 max exceeded".into());
assert!(e.to_string().contains("u64 max exceeded"));
}
#[test]
fn error_display_invalid_radix() {
let e = OxiNumError::InvalidRadix(42);
assert!(e.to_string().contains("42"));
assert!(e.to_string().contains("must be 2..=36"));
}
#[test]
fn error_display_domain() {
let e = OxiNumError::Domain("sqrt of negative is undefined for real BigFloat".into());
assert_eq!(
e.to_string(),
"domain error: sqrt of negative is undefined for real BigFloat"
);
}
// `context` must rewrite the stored message and keep the variant.
#[test]
fn context_prefixes_domain_message() {
let e = OxiNumError::Domain("sqrt of negative".into()).context("BigFloat::sqrt");
match &e {
OxiNumError::Domain(s) => assert_eq!(s, "BigFloat::sqrt: sqrt of negative"),
other => panic!("expected Domain, got {other:?}"),
}
assert!(e.to_string().contains("domain error:"));
assert!(e.to_string().contains("BigFloat::sqrt:"));
}
// Conversion into `std::io::Error` keeps the message and uses `ErrorKind::Other`.
#[test]
fn error_into_io_error() {
let e = OxiNumError::DivByZero;
let io_err: std::io::Error = e.into();
assert_eq!(io_err.kind(), std::io::ErrorKind::Other);
assert!(io_err.to_string().contains("division by zero"));
}
// Smoke test that the re-exported `Sign` is usable from this crate.
#[test]
fn sign_positive() {
let s = Sign::Positive;
assert_eq!(s, Sign::Positive);
}
#[test]
fn rounding_mode_display() {
assert_eq!(RoundingMode::HalfEven.to_string(), "HalfEven");
assert_eq!(RoundingMode::Up.to_string(), "Up");
assert_eq!(RoundingMode::Unnecessary.to_string(), "Unnecessary");
}
#[test]
fn rounding_mode_equality() {
assert_eq!(RoundingMode::Floor, RoundingMode::Floor);
assert_ne!(RoundingMode::Up, RoundingMode::Down);
}
// Compile-time check: the call only type-checks if `OxiNumError: Send + Sync`.
#[test]
fn error_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<OxiNumError>();
}
// Guards the in-memory size of the error type (limit: 32 bytes).
#[test]
fn error_size_is_small() {
let size = std::mem::size_of::<OxiNumError>();
assert!(size <= 32, "OxiNumError is {size} bytes, expected <= 32");
}
// The error must be usable both by reference and as a boxed trait object.
#[test]
fn oxinumerror_implements_std_error() {
fn assert_error<E: std::error::Error>(_: &E) {}
let e = OxiNumError::DivByZero;
assert_error(&e);
let _boxed: Box<dyn std::error::Error> = Box::new(OxiNumError::DivByZero);
}
#[test]
fn error_is_std_error() {
let e: Box<dyn std::error::Error> = Box::new(OxiNumError::Parse("test".into()));
assert!(e.to_string().contains("test"));
}
#[test]
fn oxi_num_result_alias() {
let ok: OxiNumResult<u32> = Ok(42);
assert_eq!(ok, Ok(42));
let err: OxiNumResult<u32> = Err(OxiNumError::DivByZero);
assert!(err.is_err());
}
// Covers construction, `Display`, and conversion into `OxiNumError::Parse`.
#[test]
fn parse_number_error_constructs_and_displays() {
let pe = ParseNumberError::new("bad digit", 3, 7);
assert_eq!(pe.message, "bad digit");
assert_eq!(pe.line, 3);
assert_eq!(pe.column, 7);
let pe_disp = pe.to_string();
assert!(pe_disp.contains("line 3"), "got {pe_disp}");
assert!(pe_disp.contains("column 7"), "got {pe_disp}");
assert!(pe_disp.contains("bad digit"), "got {pe_disp}");
let oe: OxiNumError = pe.into();
match &oe {
OxiNumError::Parse(_) => {}
other => panic!("expected Parse, got {other:?}"),
}
let oe_disp = oe.to_string();
assert!(oe_disp.contains("bad digit"), "got {oe_disp}");
assert!(oe_disp.contains("line 3"), "got {oe_disp}");
// The converted message abbreviates "column" to "col".
assert!(oe_disp.contains("col 7"), "got {oe_disp}");
}
#[test]
fn parse_number_error_is_std_error() {
let boxed: Box<dyn std::error::Error> = Box::new(ParseNumberError::new("oops", 1, 1));
assert!(boxed.to_string().contains("oops"));
}
// `context` on the remaining message-carrying variants (Domain is covered above).
#[test]
fn context_prefixes_message_variants() {
let parse = OxiNumError::Parse("x".into()).context("at A");
match parse {
OxiNumError::Parse(ref s) => assert_eq!(s, "at A: x"),
other => panic!("expected Parse, got {other:?}"),
}
assert!(parse.to_string().contains("at A:"));
let precision = OxiNumError::Precision("y".into()).context("at B");
match precision {
OxiNumError::Precision(ref s) => assert_eq!(s, "at B: y"),
other => panic!("expected Precision, got {other:?}"),
}
let overflow = OxiNumError::Overflow("z".into()).context("at C");
match overflow {
OxiNumError::Overflow(ref s) => assert_eq!(s, "at C: z"),
other => panic!("expected Overflow, got {other:?}"),
}
}
// Variants without a message must pass through `context` untouched.
#[test]
fn context_leaves_kindful_variants_unchanged() {
assert_eq!(
OxiNumError::DivByZero.context("ignored"),
OxiNumError::DivByZero,
);
assert_eq!(
OxiNumError::InvalidRadix(37).context("ignored"),
OxiNumError::InvalidRadix(37),
);
}
// NOTE(review): performs the same check as `error_size_is_small` above — confirm both are wanted.
#[test]
fn existing_size_of_oxinumerror_unchanged() {
let size = std::mem::size_of::<OxiNumError>();
assert!(size <= 32, "OxiNumError is {size} bytes, expected <= 32");
}
}
// Property tests: for arbitrary payload strings, the `Display` output must
// contain the payload verbatim.
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
proptest! {
#[test]
fn parse_display_roundtrip(s in any::<String>()) {
let e = OxiNumError::Parse(s.clone().into());
prop_assert!(e.to_string().contains(&s));
}
#[test]
fn precision_display_roundtrip(s in any::<String>()) {
let e = OxiNumError::Precision(s.clone().into());
prop_assert!(e.to_string().contains(&s));
}
#[test]
fn overflow_display_roundtrip(s in any::<String>()) {
let e = OxiNumError::Overflow(s.clone().into());
prop_assert!(e.to_string().contains(&s));
}
#[test]
fn domain_display_roundtrip(s in any::<String>()) {
let e = OxiNumError::Domain(s.clone().into());
prop_assert!(e.to_string().contains(&s));
}
// Message, line and column must all appear somewhere in the rendered text.
#[test]
fn parse_number_error_display_roundtrip(
msg in any::<String>(),
line in 1usize..=10_000,
column in 1usize..=10_000,
) {
let pe = ParseNumberError::new(msg.clone(), line, column);
let disp = pe.to_string();
prop_assert!(disp.contains(&msg));
prop_assert!(disp.contains(&line.to_string()));
prop_assert!(disp.contains(&column.to_string()));
}
}
}
// JSON round-trip tests; compiled only for test builds with the `serde` feature enabled.
#[cfg(all(test, feature = "serde"))]
mod serde_tests {
use super::*;
// Serializes `original` to JSON, deserializes it again and asserts equality.
fn roundtrip_oxi(original: OxiNumError) {
let json = serde_json::to_string(&original).expect("serialize OxiNumError");
let back: OxiNumError = serde_json::from_str(&json).expect("deserialize OxiNumError");
assert_eq!(back, original, "round-trip mismatch for {original:?}");
}
// Same round-trip check for `RoundingMode`.
fn roundtrip_rounding(original: RoundingMode) {
let json = serde_json::to_string(&original).expect("serialize RoundingMode");
let back: RoundingMode = serde_json::from_str(&json).expect("deserialize RoundingMode");
assert_eq!(back, original, "round-trip mismatch for {original:?}");
}
// One sample per `OxiNumError` variant.
#[test]
fn oxinum_error_json_roundtrip_all_variants() {
roundtrip_oxi(OxiNumError::Parse("e".into()));
roundtrip_oxi(OxiNumError::Precision("p".into()));
roundtrip_oxi(OxiNumError::DivByZero);
roundtrip_oxi(OxiNumError::Overflow("o".into()));
roundtrip_oxi(OxiNumError::InvalidRadix(3));
roundtrip_oxi(OxiNumError::Domain("sqrt of negative".into()));
}
// Every `RoundingMode` variant.
#[test]
fn rounding_mode_json_roundtrip_all_variants() {
roundtrip_rounding(RoundingMode::Up);
roundtrip_rounding(RoundingMode::Down);
roundtrip_rounding(RoundingMode::Ceiling);
roundtrip_rounding(RoundingMode::Floor);
roundtrip_rounding(RoundingMode::HalfUp);
roundtrip_rounding(RoundingMode::HalfDown);
roundtrip_rounding(RoundingMode::HalfEven);
roundtrip_rounding(RoundingMode::Unnecessary);
}
#[test]
fn parse_number_error_json_roundtrip() {
let pe = ParseNumberError::new("bad", 4, 9);
let json = serde_json::to_string(&pe).expect("serialize ParseNumberError");
let back: ParseNumberError =
serde_json::from_str(&json).expect("deserialize ParseNumberError");
assert_eq!(back, pe);
}
}