ssh_encoding/
mpint.rs

1//! Multiple precision integer
2
3use crate::{CheckedSum, Decode, Encode, Error, Reader, Result, Writer};
4use alloc::{boxed::Box, vec::Vec};
5use core::fmt;
6
7#[cfg(feature = "bigint")]
8use crate::{NonZeroUint, OddUint, Uint};
9
10#[cfg(feature = "subtle")]
11use subtle::{Choice, ConstantTimeEq};
12
13#[cfg(any(feature = "bigint", feature = "zeroize"))]
14use zeroize::Zeroize;
15#[cfg(feature = "bigint")]
16use zeroize::Zeroizing;
17
/// Multiple precision integer, a.k.a. `mpint`.
///
/// Described in [RFC4251 § 5](https://datatracker.ietf.org/doc/html/rfc4251#section-5):
///
/// > Represents multiple precision integers in two's complement format,
/// > stored as a string, 8 bits per byte, MSB first.  Negative numbers
/// > have the value 1 as the most significant bit of the first byte of
/// > the data partition.  If the most significant bit would be set for
/// > a positive number, the number MUST be preceded by a zero byte.
/// > Unnecessary leading bytes with the value 0 or 255 MUST NOT be
/// > included.  The value zero MUST be stored as a string with zero
/// > bytes of data.
/// >
/// > By convention, a number that is used in modular computations in
/// > Z_n SHOULD be represented in the range 0 <= x < n.
///
/// ## Examples
///
/// The representation column shows the full wire encoding, including the
/// 4-byte big-endian length prefix.
///
/// | value (hex)     | representation (hex)                  |
/// |-----------------|---------------------------------------|
/// | 0               | `00 00 00 00`                         |
/// | 9a378f9b2e332a7 | `00 00 00 08 09 a3 78 f9 b2 e3 32 a7` |
/// | 80              | `00 00 00 02 00 80`                   |
/// | -1234           | `00 00 00 02 ed cc`                   |
/// | -deadbeef       | `00 00 00 05 ff 21 52 41 11`          |
///
/// ## Ordering
///
/// With the `subtle` feature enabled, `Ord`/`PartialOrd` are derived. Values
/// therefore compare lexicographically by their encoded bytes, which does not in
/// general match numeric order. For example, `01` sorts after `00 80` even though
/// 1 < 128. The comparison is also not constant time.
#[cfg_attr(not(feature = "subtle"), derive(Clone))]
#[cfg_attr(feature = "subtle", derive(Clone, Ord, PartialOrd))] // TODO: constant time (Partial)`Ord`?
pub struct Mpint {
    /// Inner big endian-serialized integer value (without the 4-byte length prefix).
    inner: Box<[u8]>,
}
49
50impl Mpint {
51    /// Create a new multiple precision integer from the given
52    /// big endian-encoded byte slice.
53    ///
54    /// Note that this method expects a leading zero on positive integers whose
55    /// MSB is set, but does *NOT* expect a 4-byte length prefix.
56    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
57        bytes.try_into()
58    }
59
60    /// Create a new multiple precision integer from the given big endian
61    /// encoded byte slice representing a positive integer.
62    ///
63    /// The input may begin with leading zeros, which will be stripped when
64    /// converted to [`Mpint`] encoding.
65    pub fn from_positive_bytes(mut bytes: &[u8]) -> Result<Self> {
66        let mut inner = Vec::with_capacity(bytes.len());
67
68        while bytes.first().copied() == Some(0) {
69            bytes = &bytes[1..];
70        }
71
72        match bytes.first().copied() {
73            Some(n) if n >= 0x80 => inner.push(0),
74            _ => (),
75        }
76
77        inner.extend_from_slice(bytes);
78        inner.into_boxed_slice().try_into()
79    }
80
81    /// Get the big integer data encoded as big endian bytes.
82    ///
83    /// This slice will contain a leading zero if the value is positive but the
84    /// MSB is also set. Use [`Mpint::as_positive_bytes`] to ensure the number
85    /// is positive and strip the leading zero byte if it exists.
86    pub fn as_bytes(&self) -> &[u8] {
87        &self.inner
88    }
89
90    /// Get the bytes of a positive integer.
91    ///
92    /// # Returns
93    /// - `Some(bytes)` if the number is positive. The leading zero byte will be stripped.
94    /// - `None` if the value is negative
95    pub fn as_positive_bytes(&self) -> Option<&[u8]> {
96        match self.as_bytes() {
97            [0x00, rest @ ..] => Some(rest),
98            [byte, ..] if *byte < 0x80 => Some(self.as_bytes()),
99            _ => None,
100        }
101    }
102
103    /// Is this [`Mpint`] positive?
104    pub fn is_positive(&self) -> bool {
105        self.as_positive_bytes().is_some()
106    }
107}
108
109impl AsRef<[u8]> for Mpint {
110    fn as_ref(&self) -> &[u8] {
111        self.as_bytes()
112    }
113}
114
#[cfg(feature = "subtle")]
impl ConstantTimeEq for Mpint {
    /// Compare the encoded byte strings using `subtle`'s slice comparison.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs.ct_eq(rhs)
    }
}
121
#[cfg(feature = "subtle")]
impl PartialEq for Mpint {
    /// Equality is defined via [`ConstantTimeEq::ct_eq`] on the encoded bytes.
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}

#[cfg(feature = "subtle")]
impl Eq for Mpint {}
131
132impl Decode for Mpint {
133    type Error = Error;
134
135    fn decode(reader: &mut impl Reader) -> Result<Self> {
136        Vec::decode(reader)?.into_boxed_slice().try_into()
137    }
138}
139
140impl Encode for Mpint {
141    fn encoded_len(&self) -> Result<usize> {
142        [4, self.as_bytes().len()].checked_sum()
143    }
144
145    fn encode(&self, writer: &mut impl Writer) -> Result<()> {
146        self.as_bytes().encode(writer)?;
147        Ok(())
148    }
149}
150
151impl TryFrom<&[u8]> for Mpint {
152    type Error = Error;
153
154    fn try_from(bytes: &[u8]) -> Result<Self> {
155        Vec::from(bytes).into_boxed_slice().try_into()
156    }
157}
158
159impl TryFrom<Box<[u8]>> for Mpint {
160    type Error = Error;
161
162    fn try_from(bytes: Box<[u8]>) -> Result<Self> {
163        match &*bytes {
164            // Unnecessary leading 0
165            [0x00] => Err(Error::MpintEncoding),
166            // Unnecessary leading 0
167            [0x00, n, ..] if *n < 0x80 => Err(Error::MpintEncoding),
168            _ => Ok(Self { inner: bytes }),
169        }
170    }
171}
172
#[cfg(feature = "zeroize")]
impl Zeroize for Mpint {
    /// Overwrite the encoded bytes in place with zeroes.
    fn zeroize(&mut self) {
        let bytes: &mut [u8] = &mut self.inner;
        bytes.zeroize();
    }
}
179
180impl fmt::Debug for Mpint {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        write!(f, "Mpint({self:X})")
183    }
184}
185
186impl fmt::Display for Mpint {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        write!(f, "{self:X}")
189    }
190}
191
192impl fmt::LowerHex for Mpint {
193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194        for byte in self.as_bytes() {
195            write!(f, "{byte:02x}")?;
196        }
197        Ok(())
198    }
199}
200
201impl fmt::UpperHex for Mpint {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        for byte in self.as_bytes() {
204            write!(f, "{byte:02X}")?;
205        }
206        Ok(())
207    }
208}
209
#[cfg(feature = "bigint")]
impl TryFrom<NonZeroUint> for Mpint {
    type Error = Error;

    /// Convert an owned non-zero integer by delegating to the by-reference conversion.
    fn try_from(uint: NonZeroUint) -> Result<Mpint> {
        (&uint).try_into()
    }
}

#[cfg(feature = "bigint")]
impl TryFrom<&NonZeroUint> for Mpint {
    type Error = Error;

    /// Encode the wrapped integer as a positive `mpint` via the inner value's conversion.
    fn try_from(uint: &NonZeroUint) -> Result<Mpint> {
        Mpint::try_from(uint.as_ref())
    }
}
227
#[cfg(feature = "bigint")]
impl TryFrom<OddUint> for Mpint {
    type Error = Error;

    /// Convert an owned odd integer by delegating to the by-reference conversion.
    fn try_from(uint: OddUint) -> Result<Mpint> {
        (&uint).try_into()
    }
}

#[cfg(feature = "bigint")]
impl TryFrom<&OddUint> for Mpint {
    type Error = Error;

    /// Encode the wrapped integer as a positive `mpint` via the inner value's conversion.
    fn try_from(uint: &OddUint) -> Result<Mpint> {
        Mpint::try_from(uint.as_ref())
    }
}
245
#[cfg(feature = "bigint")]
impl TryFrom<Uint> for Mpint {
    type Error = Error;

    /// Convert an owned integer by delegating to the by-reference conversion.
    fn try_from(uint: Uint) -> Result<Mpint> {
        (&uint).try_into()
    }
}

#[cfg(feature = "bigint")]
impl TryFrom<&Uint> for Mpint {
    type Error = Error;

    /// Encode `uint` as a positive `mpint` using [`Mpint::from_positive_bytes`].
    ///
    /// The temporary big-endian byte buffer is held in [`Zeroizing`] so that it
    /// is wiped when it goes out of scope.
    fn try_from(uint: &Uint) -> Result<Mpint> {
        let be_bytes = Zeroizing::new(uint.to_be_bytes());
        Self::from_positive_bytes(&be_bytes)
    }
}
263}
264
#[cfg(feature = "bigint")]
impl TryFrom<Mpint> for NonZeroUint {
    type Error = Error;

    /// Convert an owned `Mpint` by delegating to the by-reference conversion.
    fn try_from(mpint: Mpint) -> Result<NonZeroUint> {
        (&mpint).try_into()
    }
}

#[cfg(feature = "bigint")]
impl TryFrom<&Mpint> for NonZeroUint {
    type Error = Error;

    /// Decode a strictly positive integer.
    ///
    /// # Errors
    /// Propagates any error from the `Uint` conversion (negative or zero values,
    /// among others). Returns [`Error::MpintEncoding`] if the decoded value is zero.
    fn try_from(mpint: &Mpint) -> Result<NonZeroUint> {
        let uint: Uint = mpint.try_into()?;
        let non_zero = NonZeroUint::new(uint).into_option();
        non_zero.ok_or(Error::MpintEncoding)
    }
}
285
#[cfg(feature = "bigint")]
impl TryFrom<Mpint> for OddUint {
    type Error = Error;

    /// Convert an owned `Mpint` by delegating to the by-reference conversion.
    fn try_from(mpint: Mpint) -> Result<OddUint> {
        (&mpint).try_into()
    }
}

#[cfg(feature = "bigint")]
impl TryFrom<&Mpint> for OddUint {
    type Error = Error;

    /// Decode a positive, odd integer.
    ///
    /// # Errors
    /// Propagates any error from the `Uint` conversion. Returns
    /// [`Error::MpintEncoding`] if `OddUint::new` rejects the decoded value.
    fn try_from(mpint: &Mpint) -> Result<OddUint> {
        let uint: Uint = mpint.try_into()?;
        let odd = OddUint::new(uint).into_option();
        odd.ok_or(Error::MpintEncoding)
    }
}
304
#[cfg(feature = "bigint")]
impl TryFrom<Mpint> for Uint {
    type Error = Error;

    /// Convert an owned `Mpint` by delegating to the by-reference conversion.
    fn try_from(mpint: Mpint) -> Result<Uint> {
        (&mpint).try_into()
    }
}

#[cfg(feature = "bigint")]
impl TryFrom<&Mpint> for Uint {
    type Error = Error;

    /// Decode a positive `mpint` into an unsigned integer whose precision is exactly
    /// the number of encoded bits (after stripping the zero pad byte).
    ///
    /// # Errors
    /// - [`Error::MpintEncoding`] if the value is negative or zero, because
    ///   [`Mpint::as_positive_bytes`] returns `None` for both.
    /// - [`Error::MpintEncoding`] if the bit length does not fit in a `u32`.
    /// - Any error from `Uint::from_be_slice`, converted into [`Error`].
    fn try_from(mpint: &Mpint) -> Result<Uint> {
        let bytes = match mpint.as_positive_bytes() {
            Some(bytes) => bytes,
            None => return Err(Error::MpintEncoding),
        };

        // Bytes -> bits, guarding both the multiplication and the `usize` -> `u32` narrowing
        let bits_precision = match bytes.len().checked_mul(8).map(u32::try_from) {
            Some(Ok(bits)) => bits,
            _ => return Err(Error::MpintEncoding),
        };

        Uint::from_be_slice(bytes, bits_precision).map_err(Error::from)
    }
}
329
#[cfg(test)]
mod tests {
    use super::Mpint;
    use hex_literal::hex;

    #[test]
    fn decode_0() {
        let zero = Mpint::from_bytes(b"").unwrap();
        assert_eq!(zero.as_bytes(), b"");
    }

    #[test]
    fn reject_extra_leading_zeroes() {
        let non_canonical: [&[u8]; 3] = [&hex!("00"), &hex!("00 00"), &hex!("00 01")];
        for bytes in non_canonical {
            assert!(Mpint::from_bytes(bytes).is_err());
        }
    }

    #[test]
    fn decode_9a378f9b2e332a7() {
        let decoded = Mpint::from_bytes(&hex!("09 a3 78 f9 b2 e3 32 a7"));
        assert!(decoded.is_ok());
    }

    #[test]
    fn decode_80() {
        let n = Mpint::from_bytes(&hex!("00 80")).unwrap();

        // The zero pad byte is dropped from the positive view
        let positive = n.as_positive_bytes().unwrap();
        assert_eq!(positive, &hex!("80"));
    }

    #[test]
    fn from_positive_bytes_strips_leading_zeroes() {
        let cases: [(&[u8], &[u8]); 3] = [
            (&hex!("00"), b""),
            (&hex!("00 00"), b""),
            (&hex!("00 01"), b"\x01"),
        ];

        for (input, expected) in cases {
            let n = Mpint::from_positive_bytes(input).unwrap();
            assert_eq!(n.as_bytes(), expected);
        }
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_1234() {
        let n = Mpint::from_bytes(&hex!("ed cc")).unwrap();
        assert_eq!(n.as_positive_bytes(), None);
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_deadbeef() {
        let n = Mpint::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert_eq!(n.as_positive_bytes(), None);
    }
}