import_stdlib!();
pub use num_bigint::{BigInt, BigUint, Sign};
use crate::{
CBOR, CBORCase, Error, Result, TAG_NEGATIVE_BIGNUM, TAG_POSITIVE_BIGNUM,
Tag,
};
fn validate_bignum_magnitude(bytes: &[u8], is_negative: bool) -> Result<()> {
if is_negative {
if bytes.is_empty() {
return Err(Error::NonCanonicalNumeric);
}
if bytes.len() > 1 && bytes[0] == 0 {
return Err(Error::NonCanonicalNumeric);
}
} else {
if !bytes.is_empty() && bytes[0] == 0 {
return Err(Error::NonCanonicalNumeric);
}
}
Ok(())
}
/// Returns the suffix of `bytes` that starts at its first non-zero byte.
///
/// An input made only of zero bytes (or an empty input) yields an empty slice.
fn strip_leading_zeros(bytes: &[u8]) -> &[u8] {
    let mut rest = bytes;
    // Peel zero bytes off the front one at a time.
    while let [0, tail @ ..] = rest {
        rest = tail;
    }
    rest
}
impl From<BigUint> for CBOR {
    /// Encodes a [`BigUint`] as a positive bignum (tag 2).
    ///
    /// Delegates to the by-reference conversion so both forms share one
    /// implementation.
    fn from(value: BigUint) -> Self { CBOR::from(&value) }
}

impl From<&BigUint> for CBOR {
    /// Encodes a borrowed [`BigUint`] as a positive bignum (tag 2).
    ///
    /// The byte string holds the big-endian magnitude with leading zero bytes
    /// removed, so zero is encoded as an empty byte string (`2(h'')`).
    fn from(value: &BigUint) -> Self {
        // `to_bytes_be` only borrows, so the value never needs to be cloned.
        let bytes = value.to_bytes_be();
        let byte_string = CBOR::to_byte_string(strip_leading_zeros(&bytes));
        CBOR::to_tagged_value(Tag::with_value(TAG_POSITIVE_BIGNUM), byte_string)
    }
}
impl From<BigInt> for CBOR {
    /// Encodes a [`BigInt`] as a bignum (tag 2 or tag 3).
    ///
    /// Delegates to the by-reference conversion so both forms share one
    /// implementation.
    fn from(value: BigInt) -> Self { CBOR::from(&value) }
}

impl From<&BigInt> for CBOR {
    /// Encodes a borrowed [`BigInt`] as a bignum.
    ///
    /// Non-negative values become a positive bignum (tag 2) holding the
    /// magnitude. A negative value `v` becomes a negative bignum (tag 3)
    /// holding `n = -1 - v` (that is, `|v| - 1`) in big-endian order with no
    /// leading zero bytes. `n = 0` (the value `-1`) is written as the single
    /// byte `0x00`, giving `3(h'00')`.
    fn from(value: &BigInt) -> Self {
        match value.sign() {
            Sign::NoSign | Sign::Plus => CBOR::from(value.magnitude()),
            Sign::Minus => {
                // num-bigint normalizes zero to `NoSign`, so a `Minus` value has
                // magnitude >= 1 and this subtraction cannot underflow.
                let n = value.magnitude() - 1u32;
                let bytes = n.to_bytes_be();
                let stripped = strip_leading_zeros(&bytes);
                // Keep one zero byte for n == 0: the tag-3 validator rejects an
                // empty byte string.
                let content: &[u8] = if stripped.is_empty() { &[0] } else { stripped };
                CBOR::to_tagged_value(
                    Tag::with_value(TAG_NEGATIVE_BIGNUM),
                    CBOR::to_byte_string(content),
                )
            }
        }
    }
}
/// Decodes the content of a positive bignum (tag 2) whose tag has already
/// been removed.
///
/// `cbor` must be a byte string holding a big-endian magnitude with no
/// leading zero bytes. The empty byte string decodes to zero.
///
/// # Errors
///
/// Returns [`Error::NonCanonicalNumeric`] if the magnitude starts with a zero
/// byte. If `cbor` is not a byte string, returns whatever error
/// `try_into_byte_string` reports.
pub fn biguint_from_untagged_cbor(cbor: CBOR) -> Result<BigUint> {
    let magnitude = cbor.try_into_byte_string()?;
    validate_bignum_magnitude(&magnitude, false)
        .map(|()| BigUint::from_bytes_be(&magnitude))
}
/// Decodes the content of a negative bignum (tag 3) whose tag has already
/// been removed.
///
/// The byte string holds `n` in big-endian order, and the decoded value is
/// `-1 - n`. `n` must be non-empty with no leading zero bytes, except that
/// `n = 0` is the single byte `0x00`.
///
/// # Errors
///
/// Returns [`Error::NonCanonicalNumeric`] if the content is not canonical.
/// If `cbor` is not a byte string, returns whatever error
/// `try_into_byte_string` reports.
pub fn bigint_from_negative_untagged_cbor(cbor: CBOR) -> Result<BigInt> {
    let encoded = cbor.try_into_byte_string()?;
    validate_bignum_magnitude(&encoded, true)?;
    // -1 - n == -(n + 1); n + 1 is never zero, so the sign is always Minus.
    let n = BigUint::from_bytes_be(&encoded);
    Ok(BigInt::from_biguint(Sign::Minus, n + 1u32))
}
impl TryFrom<CBOR> for BigUint {
type Error = Error;
fn try_from(cbor: CBOR) -> Result<Self> {
match cbor.into_case() {
CBORCase::Unsigned(n) => Ok(BigUint::from(n)),
CBORCase::Negative(_) => Err(Error::OutOfRange),
CBORCase::Tagged(tag, inner) => {
let tag_value = tag.value();
if tag_value == TAG_POSITIVE_BIGNUM {
let bytes = inner.try_into_byte_string()?;
validate_bignum_magnitude(&bytes, false)?;
Ok(BigUint::from_bytes_be(&bytes))
} else if tag_value == TAG_NEGATIVE_BIGNUM {
Err(Error::OutOfRange)
} else {
Err(Error::WrongType)
}
}
CBORCase::Simple(_) => {
Err(Error::WrongType)
}
_ => Err(Error::WrongType),
}
}
}
impl TryFrom<CBOR> for BigInt {
type Error = Error;
fn try_from(cbor: CBOR) -> Result<Self> {
match cbor.into_case() {
CBORCase::Unsigned(n) => Ok(BigInt::from(n)),
CBORCase::Negative(n) => {
let magnitude = BigUint::from(n) + 1u32;
Ok(BigInt::from_biguint(Sign::Minus, magnitude))
}
CBORCase::Tagged(tag, inner) => {
let tag_value = tag.value();
if tag_value == TAG_POSITIVE_BIGNUM {
let bytes = inner.try_into_byte_string()?;
validate_bignum_magnitude(&bytes, false)?;
let magnitude = BigUint::from_bytes_be(&bytes);
if magnitude == BigUint::ZERO {
Ok(BigInt::from(0))
} else {
Ok(BigInt::from_biguint(Sign::Plus, magnitude))
}
} else if tag_value == TAG_NEGATIVE_BIGNUM {
let bytes = inner.try_into_byte_string()?;
validate_bignum_magnitude(&bytes, true)?;
let n = BigUint::from_bytes_be(&bytes);
let magnitude = n + 1u32;
Ok(BigInt::from_biguint(Sign::Minus, magnitude))
} else {
Err(Error::WrongType)
}
}
CBORCase::Simple(_) => {
Err(Error::WrongType)
}
_ => Err(Error::WrongType),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `tag(h'…')` directly, bypassing the encoder, so decoders can be
    /// fed hand-crafted (possibly non-canonical) content.
    fn tagged_bytes(tag: Tag, bytes: &[u8]) -> CBOR {
        CBOR::to_tagged_value(tag, CBOR::to_byte_string(bytes))
    }

    #[test]
    fn test_biguint_zero() {
        let big = BigUint::from(0u32);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "2(h'')");
        let decoded: BigUint = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_biguint_one() {
        let big = BigUint::from(1u32);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "2(h'01')");
        let decoded: BigUint = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_biguint_255() {
        let big = BigUint::from(255u32);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "2(h'ff')");
        let decoded: BigUint = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_biguint_256() {
        let big = BigUint::from(256u32);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "2(h'0100')");
        let decoded: BigUint = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_biguint_beyond_u64() {
        let big = BigUint::from(1u128 << 64);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "2(h'010000000000000000')");
        let decoded: BigUint = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_bigint_zero() {
        let big = BigInt::from(0);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "2(h'')");
        let decoded: BigInt = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_bigint_positive() {
        let big = BigInt::from(256);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "2(h'0100')");
        let decoded: BigInt = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_bigint_negative_one() {
        let big = BigInt::from(-1);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "3(h'00')");
        let decoded: BigInt = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_bigint_negative_two() {
        let big = BigInt::from(-2);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "3(h'01')");
        let decoded: BigInt = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_bigint_negative_256() {
        let big = BigInt::from(-256);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "3(h'ff')");
        let decoded: BigInt = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_bigint_negative_257() {
        let big = BigInt::from(-257);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "3(h'0100')");
        let decoded: BigInt = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_bigint_negative_beyond_u64() {
        // -(2^64) - 1 is encoded with n = 2^64.
        let big = BigInt::from(-(1i128 << 64) - 1);
        let cbor = CBOR::from(big.clone());
        assert_eq!(cbor.diagnostic(), "3(h'010000000000000000')");
        let decoded: BigInt = cbor.try_into().unwrap();
        assert_eq!(decoded, big);
    }

    #[test]
    fn test_reference_conversions_match_owned() {
        let unsigned = BigUint::from(256u32);
        assert_eq!(
            CBOR::from(&unsigned).diagnostic(),
            CBOR::from(unsigned.clone()).diagnostic()
        );
        for value in [-257, -2, -1, 0, 1, 256] {
            let big = BigInt::from(value);
            assert_eq!(CBOR::from(&big).diagnostic(), CBOR::from(big.clone()).diagnostic());
        }
    }

    #[test]
    fn test_decode_plain_unsigned_to_biguint() {
        let cbor = CBOR::from(12345u64);
        let big: BigUint = cbor.try_into().unwrap();
        assert_eq!(big, BigUint::from(12345u64));
    }

    #[test]
    fn test_decode_plain_unsigned_to_bigint() {
        let cbor = CBOR::from(12345u64);
        let big: BigInt = cbor.try_into().unwrap();
        assert_eq!(big, BigInt::from(12345));
    }

    #[test]
    fn test_decode_plain_negative_to_bigint() {
        let cbor = CBOR::from(-12345i64);
        let big: BigInt = cbor.try_into().unwrap();
        assert_eq!(big, BigInt::from(-12345));
    }

    #[test]
    fn test_decode_plain_negative_to_biguint_fails() {
        let cbor = CBOR::from(-1i64);
        let result: Result<BigUint> = cbor.try_into();
        assert!(matches!(result, Err(Error::OutOfRange)));
    }

    #[test]
    fn test_decode_tag3_to_biguint_fails() {
        let big = BigInt::from(-1);
        let cbor = CBOR::from(big);
        let result: Result<BigUint> = cbor.try_into();
        assert!(matches!(result, Err(Error::OutOfRange)));
    }

    #[test]
    fn test_reject_non_canonical_positive_bignum() {
        for bytes in [&[0u8][..], &[0u8, 1][..]] {
            let cbor = tagged_bytes(Tag::with_value(TAG_POSITIVE_BIGNUM), bytes);
            let as_uint: Result<BigUint> = cbor.clone().try_into();
            assert!(matches!(as_uint, Err(Error::NonCanonicalNumeric)));
            let as_int: Result<BigInt> = cbor.try_into();
            assert!(matches!(as_int, Err(Error::NonCanonicalNumeric)));
        }
    }

    #[test]
    fn test_reject_non_canonical_negative_bignum() {
        for bytes in [&[][..], &[0u8, 1][..], &[0u8, 0][..]] {
            let cbor = tagged_bytes(Tag::with_value(TAG_NEGATIVE_BIGNUM), bytes);
            let result: Result<BigInt> = cbor.try_into();
            assert!(matches!(result, Err(Error::NonCanonicalNumeric)));
        }
    }

    #[test]
    fn test_decode_unrelated_tag_fails() {
        let cbor = CBOR::to_tagged_value(Tag::with_value(100), CBOR::from(1u64));
        let as_uint: Result<BigUint> = cbor.clone().try_into();
        assert!(matches!(as_uint, Err(Error::WrongType)));
        let as_int: Result<BigInt> = cbor.try_into();
        assert!(matches!(as_int, Err(Error::WrongType)));
    }

    #[test]
    fn test_decode_non_integer_fails() {
        let cbor = CBOR::from(true);
        let as_uint: Result<BigUint> = cbor.clone().try_into();
        assert!(matches!(as_uint, Err(Error::WrongType)));
        let as_int: Result<BigInt> = cbor.try_into();
        assert!(matches!(as_int, Err(Error::WrongType)));
    }

    #[test]
    fn test_untagged_helpers() {
        let uint = biguint_from_untagged_cbor(CBOR::to_byte_string([1u8, 0])).unwrap();
        assert_eq!(uint, BigUint::from(256u32));
        let zero = biguint_from_untagged_cbor(CBOR::to_byte_string([0u8; 0])).unwrap();
        assert_eq!(zero, BigUint::from(0u32));
        let minus_one =
            bigint_from_negative_untagged_cbor(CBOR::to_byte_string([0u8])).unwrap();
        assert_eq!(minus_one, BigInt::from(-1));
        let minus_257 =
            bigint_from_negative_untagged_cbor(CBOR::to_byte_string([1u8, 0])).unwrap();
        assert_eq!(minus_257, BigInt::from(-257));
    }
}