use crate::{CheckedSum, Decode, Encode, Error, Reader, Result, Writer};
use alloc::{boxed::Box, vec::Vec};
use core::fmt;
#[cfg(feature = "bigint")]
use crate::Uint;
#[cfg(feature = "ctutils")]
use ctutils::{Choice, CtEq};
#[cfg(any(feature = "bigint", feature = "zeroize"))]
use zeroize::Zeroize;
#[cfg(feature = "bigint")]
use zeroize::Zeroizing;
/// Multiple precision integer kept as its encoded, big-endian byte string.
///
/// The first byte carries the sign information: a leading byte of `0x80` or
/// above marks a value that is not positive, and a positive magnitude whose
/// top bit is set is stored behind a single `0x00` pad byte. Zero is the
/// empty byte string.
///
/// `Clone` is always derived. With the `ctutils` feature enabled the type
/// also derives `Ord` and `PartialOrd`; being derived, that ordering compares
/// the stored bytes lexicographically rather than by numeric value.
#[derive(Clone)]
#[cfg_attr(feature = "ctutils", derive(Ord, PartialOrd))]
pub struct Mpint {
    /// Encoded bytes, most significant byte first.
    inner: Box<[u8]>,
}
impl Mpint {
    /// Creates an `Mpint` from bytes that are already in encoded form.
    ///
    /// The slice is handed to the `TryFrom<&[u8]>` conversion, which copies
    /// and validates it.
    ///
    /// # Errors
    /// Propagates the error of that conversion when the encoding is rejected.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        Self::try_from(bytes)
    }

    /// Encodes the big-endian magnitude of a non-negative integer.
    ///
    /// Every leading `0x00` byte of `bytes` is discarded. If the first
    /// remaining byte has its top bit set, one `0x00` byte is placed in front
    /// of it so the result is not read as a negative value. Input that is
    /// empty or consists solely of zero bytes produces the empty encoding.
    #[must_use]
    pub fn from_positive_bytes(bytes: &[u8]) -> Self {
        let leading_zeroes = bytes.iter().take_while(|&&byte| byte == 0).count();
        let magnitude = &bytes[leading_zeroes..];
        let needs_pad = matches!(magnitude.first(), Some(&msb) if msb >= 0x80);

        // Reserve the exact final size up front: magnitude plus optional pad byte.
        let mut encoded =
            Vec::with_capacity(magnitude.len().saturating_add(usize::from(needs_pad)));
        if needs_pad {
            encoded.push(0);
        }
        encoded.extend_from_slice(magnitude);

        Self {
            inner: encoded.into_boxed_slice(),
        }
    }

    /// Borrows the encoded bytes exactly as stored.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        &self.inner[..]
    }

    /// Returns the magnitude bytes of a positive value.
    ///
    /// - If the encoding starts with `0x00`, that byte is stripped and the
    ///   remainder is returned.
    /// - If the first byte is below `0x80`, the whole encoding is returned.
    /// - Otherwise `None` is returned. This covers encodings whose first byte
    ///   is `0x80` or above as well as the empty encoding (zero).
    #[must_use]
    pub fn as_positive_bytes(&self) -> Option<&[u8]> {
        let encoded = self.as_bytes();
        match encoded.split_first() {
            Some((&0x00, magnitude)) => Some(magnitude),
            Some((&msb, _)) if msb < 0x80 => Some(encoded),
            _ => None,
        }
    }

    /// Reports whether [`Mpint::as_positive_bytes`] would yield a slice.
    ///
    /// The empty encoding (zero) therefore reports `false`.
    #[must_use]
    pub fn is_positive(&self) -> bool {
        self.as_positive_bytes().is_some()
    }
}
/// Borrows the encoded bytes, the same view that [`Mpint::as_bytes`] returns.
impl AsRef<[u8]> for Mpint {
    fn as_ref(&self) -> &[u8] {
        &self.inner
    }
}
/// Equality check over the encoded bytes, delegated to the `CtEq`
/// implementation for byte slices. Only available with the `ctutils` feature.
#[cfg(feature = "ctutils")]
impl CtEq for Mpint {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs.ct_eq(rhs)
    }
}
// `Eq` marker for `Mpint`. Like the `PartialEq` impl it accompanies, it is
// only compiled when the `ctutils` feature is enabled.
#[cfg(feature = "ctutils")]
impl Eq for Mpint {}
/// `==` for `Mpint`, implemented on top of [`CtEq::ct_eq`] so that the
/// comparison goes through the same byte-wise check.
#[cfg(feature = "ctutils")]
impl PartialEq for Mpint {
    fn eq(&self, other: &Self) -> bool {
        // Collapse the `Choice` produced by the `CtEq` impl into a plain `bool`.
        CtEq::ct_eq(self, other).into()
    }
}
/// Decodes an `Mpint` by first decoding a byte vector from the reader and then
/// validating it through `TryFrom<Box<[u8]>>`.
impl Decode for Mpint {
    type Error = Error;

    /// # Errors
    /// Fails if the byte vector cannot be decoded from `reader`, or if the
    /// decoded bytes are rejected by the `TryFrom<Box<[u8]>>` validation.
    fn decode(reader: &mut impl Reader) -> Result<Self> {
        let encoded: Vec<u8> = Vec::decode(reader)?;
        Self::try_from(encoded.into_boxed_slice())
    }
}
/// Encodes an `Mpint` by delegating to the `Encode` implementation of its
/// byte slice.
impl Encode for Mpint {
    /// Length of the encoded form: the stored byte count plus 4.
    ///
    /// NOTE(review): the constant 4 presumably accounts for a `uint32` length
    /// prefix written by the byte-slice encoder; confirm against that impl.
    ///
    /// # Errors
    /// Returns the error produced by `checked_sum` if the addition overflows.
    fn encoded_len(&self) -> Result<usize> {
        let payload_len = self.as_bytes().len();
        [4, payload_len].checked_sum()
    }

    /// Writes the stored bytes to `writer` using the byte-slice encoder.
    ///
    /// # Errors
    /// Propagates any error reported by that encoder.
    fn encode(&self, writer: &mut impl Writer) -> Result<()> {
        self.as_bytes().encode(writer)
    }
}
impl TryFrom<&[u8]> for Mpint {
type Error = Error;
fn try_from(bytes: &[u8]) -> Result<Self> {
Vec::from(bytes).into_boxed_slice().try_into()
}
}
impl TryFrom<Box<[u8]>> for Mpint {
type Error = Error;
fn try_from(bytes: Box<[u8]>) -> Result<Self> {
match &*bytes {
[0x00] => Err(Error::MpintEncoding),
[0x00, n, ..] if *n < 0x80 => Err(Error::MpintEncoding),
_ => Ok(Self { inner: bytes }),
}
}
}
// Zeroization support for `Mpint`, compiled only with the `zeroize` feature.
#[cfg(feature = "zeroize")]
impl Zeroize for Mpint {
// Delegates to the `Zeroize` implementation of the inner byte buffer.
// NOTE(review): if that implementation keeps the buffer length, a non-empty
// value ends up as all-zero bytes, which is not an encoding that
// `TryFrom<Box<[u8]>>` would accept; presumably fine because the value is
// meant to be discarded afterwards. Confirm.
fn zeroize(&mut self) {
self.inner.zeroize();
}
}
/// Debug output of the form `Mpint(<HEX>)`, where `<HEX>` is produced by the
/// type's `UpperHex` implementation.
impl fmt::Debug for Mpint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Mpint({self:X})"))
    }
}
/// Displays the value through its `UpperHex` implementation, i.e. as
/// upper-case hexadecimal.
impl fmt::Display for Mpint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{self:X}"))
    }
}
/// Lower-case hexadecimal: every encoded byte is written as two hex digits,
/// in storage order, with no separator or prefix.
impl fmt::LowerHex for Mpint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stops at, and returns, the first write error.
        self.as_bytes()
            .iter()
            .try_for_each(|byte| write!(f, "{byte:02x}"))
    }
}
/// Upper-case hexadecimal: every encoded byte is written as two hex digits,
/// in storage order, with no separator or prefix.
impl fmt::UpperHex for Mpint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stops at, and returns, the first write error.
        self.as_bytes()
            .iter()
            .try_for_each(|byte| write!(f, "{byte:02X}"))
    }
}
/// Converts a borrowed `Uint` into its `Mpint` encoding (`bigint` feature).
///
/// The integer's big-endian bytes are held in a `Zeroizing` wrapper while
/// [`Mpint::from_positive_bytes`] builds the encoding from them.
#[cfg(feature = "bigint")]
impl From<&Uint> for Mpint {
    fn from(uint: &Uint) -> Self {
        let magnitude = Zeroizing::new(uint.to_be_bytes());
        Self::from_positive_bytes(&magnitude)
    }
}
/// Converts an owned `Uint` by forwarding to the `From<&Uint>` conversion
/// (`bigint` feature).
#[cfg(feature = "bigint")]
impl From<Uint> for Mpint {
    fn from(uint: Uint) -> Self {
        Self::from(&uint)
    }
}
/// Converts an owned `Mpint` by forwarding to the `TryFrom<&Mpint>`
/// conversion (`bigint` feature).
#[cfg(feature = "bigint")]
impl TryFrom<Mpint> for Uint {
    type Error = Error;

    /// # Errors
    /// Returns the error of the by-reference conversion.
    fn try_from(mpint: Mpint) -> Result<Self> {
        Self::try_from(&mpint)
    }
}
/// Builds a `Uint` from the positive magnitude of an `Mpint` (`bigint`
/// feature).
#[cfg(feature = "bigint")]
impl TryFrom<&Mpint> for Uint {
    type Error = Error;

    /// # Errors
    /// Returns [`Error::MpintEncoding`] when [`Mpint::as_positive_bytes`]
    /// yields `None`, which is the case for encodings whose first byte is
    /// `0x80` or above and for the empty encoding (zero).
    fn try_from(mpint: &Mpint) -> Result<Self> {
        match mpint.as_positive_bytes() {
            Some(magnitude) => Ok(Uint::from_be_slice_vartime(magnitude)),
            None => Err(Error::MpintEncoding),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::Mpint;
    use hex_literal::hex;

    /// The empty byte string is accepted and stored unchanged.
    #[test]
    fn decode_0() {
        let zero = Mpint::from_bytes(b"").unwrap();
        assert!(zero.as_bytes().is_empty());
    }

    /// Encodings with a needless leading `0x00` are refused.
    #[test]
    fn reject_extra_leading_zeroes() {
        let lone_zero = hex!("00");
        let double_zero = hex!("00 00");
        let padded_one = hex!("00 01");

        for encoding in [&lone_zero[..], &double_zero[..], &padded_one[..]] {
            assert!(Mpint::from_bytes(encoding).is_err());
        }
    }

    /// A positive value whose first byte is below `0x80` needs no padding.
    #[test]
    fn decode_9a378f9b2e332a7() {
        let encoding = hex!("09 a3 78 f9 b2 e3 32 a7");
        assert!(Mpint::from_bytes(&encoding).is_ok());
    }

    /// The `0x00` pad in front of `0x80` is accepted and stripped again by
    /// `as_positive_bytes`.
    #[test]
    fn decode_80() {
        let padded = Mpint::from_bytes(&hex!("00 80")).unwrap();
        assert_eq!(padded.as_positive_bytes(), Some(&hex!("80")[..]));
    }

    /// `from_positive_bytes` drops every leading zero byte of its input.
    #[test]
    fn from_positive_bytes_strips_leading_zeroes() {
        let one_zero = Mpint::from_positive_bytes(&hex!("00"));
        assert_eq!(one_zero.as_ref(), b"");

        let two_zeroes = Mpint::from_positive_bytes(&hex!("00 00"));
        assert_eq!(two_zeroes.as_ref(), b"");

        let zero_then_one = Mpint::from_positive_bytes(&hex!("00 01"));
        assert_eq!(zero_then_one.as_ref(), b"\x01");
    }

    /// A first byte of `0x80` or above is accepted but has no positive view.
    #[test]
    fn decode_neg_1234() {
        let negative = Mpint::from_bytes(&hex!("ed cc")).unwrap();
        assert!(negative.as_positive_bytes().is_none());
    }

    /// A leading `0xFF` followed by a byte below `0x80` is accepted and has no
    /// positive view.
    #[test]
    fn decode_neg_deadbeef() {
        let negative = Mpint::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert!(negative.as_positive_bytes().is_none());
    }
}