cmail_rpgp/
util.rs

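//! # Utilities
//!
//! Small shared helpers: integer widening, nom parsers for base64 bodies and
//! packet lengths, packet-length serialization, and I/O adapters.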
2
3use std::ops::{Range, RangeFrom, RangeTo};
4use std::{hash, io};
5
6use byteorder::{BigEndian, WriteBytesExt};
7use nom::bytes::streaming::take_while1;
8use nom::character::is_alphanumeric;
9use nom::character::streaming::line_ending;
10use nom::combinator::map;
11use nom::multi::many0;
12use nom::number::streaming::{be_u32, be_u8};
13use nom::sequence::preceded;
14use nom::{error_position, Err, InputIter, InputLength, Slice};
15
16use crate::errors::{self, IResult};
17
/// Widens a `u8` to `usize`; lossless on every target.
#[inline]
pub fn u8_as_usize(a: u8) -> usize {
    usize::from(a)
}
22
/// Widens a `u16` to `usize`; lossless on every target.
#[inline]
pub fn u16_as_usize(a: u16) -> usize {
    usize::from(a)
}
27
/// Converts a `u32` to `usize` with an `as` cast.
///
/// Lossless on 32- and 64-bit targets. The standard library offers no
/// `From<u32> for usize` because the value could truncate on 16-bit targets,
/// which is why a cast is used here.
#[inline]
pub fn u32_as_usize(a: u32) -> usize {
    a as _
}
32
33#[inline]
34pub fn is_base64_token(c: u8) -> bool {
35    is_alphanumeric(c) || c == b'/' || c == b'+' || c == b'=' || c == b'\n' || c == b'\r'
36}
37
38pub fn prefixed(input: &[u8]) -> IResult<&[u8], &[u8]> {
39    preceded(many0(line_ending), take_while1(is_base64_token))(input)
40}
41
42/// Recognizes one or more body tokens
43pub fn base64_token(input: &[u8]) -> nom::IResult<&[u8], &[u8]> {
44    let input_length = input.input_len();
45    if input_length == 0 {
46        return Err(Err::Incomplete(nom::Needed::Unknown));
47    }
48
49    for (idx, item) in input.iter_indices() {
50        if !is_base64_token(item) {
51            if idx == 0 {
52                return Err(Err::Error(error_position!(
53                    input,
54                    nom::error::ErrorKind::AlphaNumeric
55                )));
56            } else {
57                return Ok((input.slice(idx..), input.slice(0..idx)));
58            }
59        }
60    }
61    Ok((input.slice(input_length..), input))
62}
63
/// Returns the bit length of a big-endian unsigned integer stored in `val`.
///
/// The result is `8 * val.len()` minus the leading zero bits of the first byte.
/// An empty slice has length 0.
///
/// Only the first byte is inspected. Leading zero *bytes* beyond it are still
/// counted as 8 bits each, so callers should strip them first (e.g. with
/// [`strip_leading_zeros`]).
#[inline]
pub fn bit_size(val: &[u8]) -> usize {
    match val.first() {
        None => 0,
        Some(&most_significant) => val.len() * 8 - most_significant.leading_zeros() as usize,
    }
}
73
/// Returns the sub-slice of `bytes` that starts at its first non-zero byte.
///
/// Empty input, or input made only of zero bytes, yields an empty slice.
#[inline]
pub fn strip_leading_zeros(bytes: &[u8]) -> &[u8] {
    let zero_prefix = bytes.iter().take_while(|&&b| b == 0).count();
    &bytes[zero_prefix..]
}
81
/// Removes all leading zero bytes from `bytes` in place.
///
/// Matches [`strip_leading_zeros`]: interior and trailing zeros are kept, and a
/// vector made only of zero bytes becomes empty.
#[inline]
pub fn strip_leading_zeros_vec(bytes: &mut Vec<u8>) {
    // When no non-zero byte exists, drain everything. This keeps the all-zero
    // case consistent with the empty slice `strip_leading_zeros` returns.
    let offset = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len());
    bytes.drain(..offset);
}
88
/// Clones the elements of `slice` into a freshly defaulted container `A`,
/// typically a fixed-size array such as `[u8; 32]`.
///
/// # Panics
///
/// Panics if `slice.len()` differs from the length of `A`'s `AsMut<[T]>` view,
/// because `clone_from_slice` requires equal lengths.
pub fn clone_into_array<A, T>(slice: &[T]) -> A
where
    A: Sized + Default + AsMut<[T]>,
    T: Clone,
{
    let mut out = A::default();
    let dest: &mut [T] = out.as_mut();
    dest.clone_from_slice(slice);
    out
}
99
100// Parse a packet length.
101pub(crate) fn packet_length(i: &[u8]) -> IResult<&[u8], usize> {
102    let (i, olen) = be_u8(i)?;
103    match olen {
104        // One-Octet Lengths
105        0..=191 => Ok((i, olen as usize)),
106        // Two-Octet Lengths
107        192..=254 => map(be_u8, |a| ((olen as usize - 192) << 8) + 192 + a as usize)(i),
108        // Five-Octet Lengths
109        255 => map(be_u32, u32_as_usize)(i),
110    }
111}
112
113/// Write packet length, including the prefix for lengths larger or equal than 8384.
114pub fn write_packet_length(len: usize, writer: &mut impl io::Write) -> errors::Result<()> {
115    if len < 192 {
116        writer.write_u8(len.try_into()?)?;
117    } else if len < 8384 {
118        writer.write_u8((((len - 192) / 256) + 192) as u8)?;
119        writer.write_u8(((len - 192) % 256) as u8)?;
120    } else {
121        writer.write_u8(0xFF)?;
122        writer.write_u32::<BigEndian>(len as u32)?;
123    }
124
125    Ok(())
126}
127
128/// Return the length of the remaining input.
129// Adapted from https://github.com/Geal/nom/pull/684
130#[inline]
131pub fn rest_len<T>(input: T) -> IResult<T, usize>
132where
133    T: Slice<Range<usize>> + Slice<RangeFrom<usize>> + Slice<RangeTo<usize>>,
134    T: InputLength,
135{
136    let len = input.input_len();
137    Ok((input, len))
138}
139
/// Generates conversions between an enum and the payload types of its
/// single-field tuple variants.
///
/// Usage: `impl_try_from_into!(Enum, Variant => Payload, ...)`. For each pair
/// this emits:
/// - `impl TryFrom<Enum> for Payload`, which yields the payload when the value
///   is `Enum::Variant` and otherwise an `invalid packet type: {:?}` error of
///   type `$crate::errors::Error`;
/// - `impl From<Payload> for Enum`, which wraps the payload in `Enum::Variant`.
///
/// NOTE(review): `format_err!` is invoked unqualified, so a `format_err!` macro
/// must be in scope wherever this macro expands. The error message also
/// requires `Enum: Debug`.
#[macro_export]
macro_rules! impl_try_from_into {
    ($enum_name:ident, $( $name:ident => $variant_type:ty ),*) => {
        $(
            impl std::convert::TryFrom<$enum_name> for $variant_type {
                // TODO: Proper error
                type Error = $crate::errors::Error;

                fn try_from(
                    other: $enum_name,
                ) -> ::std::result::Result<$variant_type, Self::Error> {
                    match other {
                        $enum_name::$name(value) => Ok(value),
                        other => Err(format_err!("invalid packet type: {:?}", other)),
                    }
                }
            }

            impl From<$variant_type> for $enum_name {
                fn from(other: $variant_type) -> $enum_name {
                    $enum_name::$name(other)
                }
            }
        )*
    };
}
165
166pub struct TeeWriter<'a, A, B> {
167    a: &'a mut A,
168    b: &'a mut B,
169}
170
171impl<'a, A, B> TeeWriter<'a, A, B> {
172    pub fn new(a: &'a mut A, b: &'a mut B) -> Self {
173        TeeWriter { a, b }
174    }
175}
176
177impl<'a, A: hash::Hasher, B: io::Write> io::Write for TeeWriter<'a, A, B> {
178    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
179        self.a.write(buf);
180        write_all(&mut self.b, buf)?;
181
182        Ok(buf.len())
183    }
184
185    fn flush(&mut self) -> io::Result<()> {
186        self.b.flush()?;
187
188        Ok(())
189    }
190}
191
/// Writes all of `buf` to `writer`, like `io::Write::write_all`, except that a
/// write reporting `Ok(0)` is retried instead of failing with `WriteZero`.
///
/// This deviation exists for compatibility with rust-base64 writers.
/// `Interrupted` errors are also retried, and any other error is returned
/// immediately.
///
/// NOTE(review): a writer that keeps returning `Ok(0)` (for example a full
/// fixed-size buffer) makes this loop spin forever.
pub fn write_all(writer: &mut impl io::Write, buf: &[u8]) -> io::Result<()> {
    let mut remaining = buf;
    loop {
        if remaining.is_empty() {
            return Ok(());
        }
        match writer.write(remaining) {
            // `written == 0` leaves `remaining` unchanged and simply retries.
            Ok(written) => remaining = &remaining[written..],
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        }
    }
}
205
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Two-octet encoding of a mid-range length.
    #[test]
    fn test_write_packet_len() {
        let mut res = Vec::new();
        write_packet_length(1173, &mut res).unwrap();
        assert_eq!(hex::encode(res), "c3d5");
    }

    /// Five-octet encoding of a length above the two-octet range.
    #[test]
    fn test_write_packet_length() {
        let mut res = Vec::new();
        write_packet_length(12870, &mut res).unwrap();
        assert_eq!(hex::encode(res), "ff00003246");
    }

    /// Encodings on both sides of each one/two/five-octet boundary.
    #[test]
    fn test_write_packet_length_boundaries() {
        for &(len, expected) in &[
            (0, "00"),
            (191, "bf"),
            (192, "c000"),
            (8383, "dfff"),
            (8384, "ff000020c0"),
        ] {
            let mut res = Vec::new();
            write_packet_length(len, &mut res).unwrap();
            assert_eq!(hex::encode(&res), expected, "len = {}", len);
        }
    }

    /// Every length written by `write_packet_length` parses back unchanged.
    #[test]
    fn test_packet_length_roundtrip() {
        for &len in &[0usize, 1, 191, 192, 1173, 8383, 8384, 12870, 0xFFFF_FFFF] {
            let mut buf = Vec::new();
            write_packet_length(len, &mut buf).unwrap();
            let (rest, parsed) = packet_length(&buf).unwrap();
            assert!(rest.is_empty(), "len = {}", len);
            assert_eq!(parsed, len);
        }
    }

    #[test]
    fn test_bit_size() {
        assert_eq!(bit_size(&[]), 0);
        assert_eq!(bit_size(&[0x01]), 1);
        assert_eq!(bit_size(&[0x80, 0x00]), 16);
        assert_eq!(bit_size(&[0x01, 0xFF, 0xFF]), 17);
    }

    #[test]
    fn test_strip_leading_zeros_with_all_zeros() {
        let buf = [0, 0, 0];
        let stripped = strip_leading_zeros(&buf);
        assert_eq!(stripped, &[]);
    }

    #[test]
    fn test_strip_leading_zeros_keeps_inner_zeros() {
        let buf = [0u8, 0, 1, 0, 2];
        assert_eq!(strip_leading_zeros(&buf), &[1u8, 0, 2][..]);
    }

    #[test]
    fn test_strip_leading_zeros_vec() {
        let mut vec = vec![0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        strip_leading_zeros_vec(&mut vec);
        assert_eq!(vec, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
}