nintypes 0.2.11

Nintondo shared types
Documentation
use std::str::FromStr;

use primitive_types::U256;

use super::{fixed128::Fixed128, FixedParseErr};

/// Unsigned fixed-point number stored as a raw [`U256`] (work in progress).
///
/// The raw value is the represented number scaled by `Self::ONE`: parsing computes
/// `int * ONE + frac` and formatting splits the raw value with `/ ONE` and `% ONE`.
/// `PRECISION` is the number of decimal digits used for the fractional part when
/// parsing and printing.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct Fixed256<const PRECISION: u8>(pub U256);

impl<const PRECISION: u8> Fixed256<PRECISION> {
    /// The value `1`: [`Fixed128::ONE`] widened to 256 bits.
    pub const ONE: Self = Self::from_x128(Fixed128::<PRECISION>::ONE);
    /// Largest representable value (every raw bit set).
    pub const MAX: Self = Self(U256::MAX);
    /// The value `0`.
    pub const ZERO: Self = Self(U256::zero());
    /// [`Fixed128::MAX`] widened to 256 bits; the largest value that narrows losslessly.
    pub const MAX_X128: Self = Self::from_x128(Fixed128::<PRECISION>::MAX);

    /// Converts a whole number into fixed-point form by scaling it by [`Self::ONE`].
    ///
    /// # Panics
    ///
    /// Panics if `v * ONE` does not fit into 256 bits, because `U256` multiplication
    /// panics on overflow.
    pub fn from_int(v: U256) -> Self {
        Self(Self::ONE.0 * v)
    }

    /// Widens a [`Fixed128`] with the same precision; the raw value is preserved.
    pub const fn from_x128(v: Fixed128<PRECISION>) -> Self {
        let raw = v.into_raw();
        // Low 64 bits go into limb 0, the next 64 bits into limb 1; the upper limbs stay zero.
        // The `as u64` casts truncate on purpose to pick out each limb.
        Self(U256([raw as u64, (raw >> 64) as u64, 0, 0]))
    }

    /// Narrows to [`Fixed128`], returning `None` when the value exceeds [`Self::MAX_X128`].
    pub fn checked_x128(self) -> Option<Fixed128<PRECISION>> {
        if self > Self::MAX_X128 {
            return None;
        }
        Some(Fixed128::from_raw(self.0.as_u128()))
    }

    /// Narrows to [`Fixed128`], clamping values above [`Self::MAX_X128`] to that maximum.
    pub fn saturating_x128(self) -> Fixed128<PRECISION> {
        Fixed128::from_raw(self.0.min(Self::MAX_X128.0).as_u128())
    }

    /// Returns the whole-number part (raw value divided by `ONE`, rounded down).
    pub fn int_part(self) -> U256 {
        self.0 / Self::ONE.0
    }

    /// Returns the fractional part as a raw value in `0..ONE`.
    pub fn frac_part(self) -> U256 {
        self.0 % Self::ONE.0
    }

    /// Parses a decimal string, dropping fractional digits beyond `PRECISION`.
    ///
    /// The strict `FromStr` implementation returns `Err(FixedParseErr::Loss)` for such
    /// input instead. Characters after the first `PRECISION` fractional bytes are
    /// discarded without being validated.
    ///
    /// # Errors
    ///
    /// * [`FixedParseErr::InvalidChars`] if a retained part cannot be parsed as a
    ///   decimal `U256`.
    /// * [`FixedParseErr::TooLarge`] if the scaled value does not fit into 256 bits.
    pub fn from_str_lossy(s: &str) -> Result<Self, FixedParseErr> {
        let digits = PRECISION as usize;

        // An empty integer or fractional part counts as zero (e.g. ".5" or "1.").
        let parse = |text: &str| -> Result<U256, FixedParseErr> {
            if text.is_empty() {
                return Ok(U256::zero());
            }
            U256::from_str_radix(text, 10).map_err(|_| FixedParseErr::InvalidChars)
        };

        let (int, frac) = match s.split_once('.') {
            None => {
                let int = U256::from_str_radix(s, 10).map_err(|_| FixedParseErr::InvalidChars)?;
                (int, U256::zero())
            }
            Some((int, frac)) => {
                // `str::get` returns `None` instead of panicking when the cut would land
                // inside a multi-byte character; such input cannot be a digit string.
                let kept = if frac.len() > digits {
                    frac.get(..digits).ok_or(FixedParseErr::InvalidChars)?
                } else {
                    frac
                };

                // Right-pad with zeros so the fraction is expressed in units of 1/ONE.
                let mut padded = String::with_capacity(digits);
                padded.push_str(kept);
                padded.push_str(&"0".repeat(digits - kept.len()));

                (parse(int)?, parse(&padded)?)
            }
        };

        // Checked arithmetic: `U256`'s `*` and `+` panic on overflow, and the addition
        // can overflow even when `int * ONE` alone still fits.
        Self::ONE
            .0
            .checked_mul(int)
            .and_then(|scaled| scaled.checked_add(frac))
            .map(Self)
            .ok_or(FixedParseErr::TooLarge)
    }
}

#[cfg(feature = "schema")]
impl<const PRECISION: u8> schemars::JsonSchema for Fixed256<PRECISION> {
    /// Schema name shared by every `PRECISION` instantiation.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("Fixed256")
    }

    /// Reuses the root schema generated for `rust_decimal::Decimal`.
    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        // Work on a copy: only a `&mut` borrow is available here, and the caller's
        // generator is left untouched.
        let detached = generator.clone();
        detached.into_root_schema_for::<rust_decimal::Decimal>()
    }
}

impl<const PRECISION: u8> std::fmt::Display for Fixed256<PRECISION> {
    /// Writes the value as a decimal number without trailing fractional zeros; the
    /// decimal point is omitted when the fractional part is zero.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let whole = self.int_part();
        let digits = self.frac_part().to_string();
        // Left-pad to `PRECISION` digits so leading fractional zeros are kept.
        let padded = format!("{digits:0>width$}", width = PRECISION as usize);

        match padded.trim_end_matches('0') {
            "" => write!(f, "{whole}"),
            fraction => write!(f, "{whole}.{fraction}"),
        }
    }
}

impl<const PRECISION: u8> std::fmt::Debug for Fixed256<PRECISION> {
    /// Debug output is the same decimal rendering as `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <Self as std::fmt::Display>::fmt(self, f)
    }
}

impl<const PRECISION: u8> FromStr for Fixed256<PRECISION> {
    type Err = FixedParseErr;

    #[allow(clippy::comparison_chain)]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Some((int, frac)) = s.split_once('.') {
            let mut frac = frac.to_owned();

            let precision_diff = (PRECISION as isize) - (frac.len() as isize);
            if precision_diff > 0 {
                frac.push_str(&"0".repeat(precision_diff as usize));
            } else if precision_diff < 0 {
                return Err(FixedParseErr::Loss);
            }

            let int = if int.is_empty() {
                U256::zero()
            } else {
                U256::from_str_radix(int, 10).map_err(|_| FixedParseErr::InvalidChars)?
            };
            let frac = if frac.is_empty() {
                U256::zero()
            } else {
                U256::from_str_radix(&frac, 10).map_err(|_| FixedParseErr::InvalidChars)?
            };

            if int > U256::MAX / Self::ONE.0 {
                return Err(FixedParseErr::TooLarge);
            }

            return Ok(Self(int * Self::ONE.0 + frac));
        }

        let int = U256::from_str_radix(s, 10).map_err(|_| FixedParseErr::InvalidChars)?;

        if int > U256::MAX / Self::ONE.0 {
            return Err(FixedParseErr::TooLarge);
        }

        Ok(Self::from_int(int))
    }
}

impl<const PRECISION: u8> serde::Serialize for Fixed256<PRECISION> {
    /// Serializes the value as its `Display` string.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = self.to_string();
        serializer.serialize_str(rendered.as_str())
    }
}

impl<'de, const PRECISION: u8> serde::Deserialize<'de> for Fixed256<PRECISION> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        String::deserialize(deserializer)?.parse::<Self>().map_err(serde::de::Error::custom)
    }
}