1use std::ops::{Range, RangeFrom, RangeTo};
4use std::{hash, io};
5
6use byteorder::{BigEndian, WriteBytesExt};
7use nom::bytes::streaming::take_while1;
8use nom::character::is_alphanumeric;
9use nom::character::streaming::line_ending;
10use nom::combinator::map;
11use nom::multi::many0;
12use nom::number::streaming::{be_u32, be_u8};
13use nom::sequence::preceded;
14use nom::{error_position, Err, InputIter, InputLength, Slice};
15
16use crate::errors::{self, IResult};
17
/// Widens a `u8` to `usize`; this conversion is always lossless.
#[inline]
pub fn u8_as_usize(a: u8) -> usize {
    usize::from(a)
}
22
/// Widens a `u16` to `usize`; this conversion is always lossless.
#[inline]
pub fn u16_as_usize(a: u16) -> usize {
    usize::from(a)
}
27
/// Converts a `u32` to `usize` with a plain numeric cast.
///
/// The cast is lossless on 32- and 64-bit targets. On a 16-bit target it
/// would truncate; this helper does not guard against that.
#[inline]
pub fn u32_as_usize(a: u32) -> usize {
    a as usize
}
32
/// Returns `true` if `c` may appear inside a base64 body.
///
/// Accepted bytes are ASCII letters and digits, the base64 symbols `+`, `/`
/// and the `=` padding character. CR and LF are also accepted so that
/// line-wrapped armor is scanned as one token.
#[inline]
pub fn is_base64_token(c: u8) -> bool {
    c.is_ascii_alphanumeric() || matches!(c, b'+' | b'/' | b'=' | b'\r' | b'\n')
}
37
38pub fn prefixed(input: &[u8]) -> IResult<&[u8], &[u8]> {
39 preceded(many0(line_ending), take_while1(is_base64_token))(input)
40}
41
42pub fn base64_token(input: &[u8]) -> nom::IResult<&[u8], &[u8]> {
44 let input_length = input.input_len();
45 if input_length == 0 {
46 return Err(Err::Incomplete(nom::Needed::Unknown));
47 }
48
49 for (idx, item) in input.iter_indices() {
50 if !is_base64_token(item) {
51 if idx == 0 {
52 return Err(Err::Error(error_position!(
53 input,
54 nom::error::ErrorKind::AlphaNumeric
55 )));
56 } else {
57 return Ok((input.slice(idx..), input.slice(0..idx)));
58 }
59 }
60 }
61 Ok((input.slice(input_length..), input))
62}
63
/// Returns the number of significant bits in `val`, read as a big-endian
/// unsigned integer.
///
/// The count starts at the most significant *non-zero* bit, as required for
/// the OpenPGP MPI length prefix. Leading zero bytes therefore do not
/// contribute. Empty or all-zero input yields `0`.
#[inline]
pub fn bit_size(val: &[u8]) -> usize {
    match val.iter().position(|&b| b != 0) {
        // Bits in every byte from the first non-zero one onward, minus that
        // byte's own leading zero bits.
        Some(first) => (val.len() - first) * 8 - val[first].leading_zeros() as usize,
        None => 0,
    }
}
73
/// Returns the sub-slice of `bytes` that starts at the first non-zero byte.
///
/// An empty or all-zero input yields an empty slice.
#[inline]
pub fn strip_leading_zeros(bytes: &[u8]) -> &[u8] {
    let zero_count = bytes.iter().take_while(|&&byte| byte == 0).count();
    &bytes[zero_count..]
}
81
/// Removes leading zero bytes from `bytes` in place.
///
/// This is the in-place counterpart of `strip_leading_zeros`, with the same
/// semantics: an all-zero vector becomes empty. Previously an all-zero vector
/// was left unchanged, so the two helpers disagreed.
#[inline]
pub fn strip_leading_zeros_vec(bytes: &mut Vec<u8>) {
    match bytes.iter().position(|&b| b != 0) {
        Some(offset) => {
            bytes.drain(..offset);
        }
        // No non-zero byte at all: nothing significant remains.
        None => bytes.clear(),
    }
}
88
/// Builds a fixed-size container `A` (typically an array) by cloning each
/// element of `slice` into a default-initialized `A`.
///
/// # Panics
/// Panics if `slice.len()` differs from the length of `A`'s `AsMut<[T]>` view
/// (this is the behavior of `clone_from_slice`).
pub fn clone_into_array<A, T>(slice: &[T]) -> A
where
    A: Sized + Default + AsMut<[T]>,
    T: Clone,
{
    let mut target = A::default();
    AsMut::<[T]>::as_mut(&mut target).clone_from_slice(slice);
    target
}
99
100pub(crate) fn packet_length(i: &[u8]) -> IResult<&[u8], usize> {
102 let (i, olen) = be_u8(i)?;
103 match olen {
104 0..=191 => Ok((i, olen as usize)),
106 192..=254 => map(be_u8, |a| ((olen as usize - 192) << 8) + 192 + a as usize)(i),
108 255 => map(be_u32, u32_as_usize)(i),
110 }
111}
112
113pub fn write_packet_length(len: usize, writer: &mut impl io::Write) -> errors::Result<()> {
115 if len < 192 {
116 writer.write_u8(len.try_into()?)?;
117 } else if len < 8384 {
118 writer.write_u8((((len - 192) / 256) + 192) as u8)?;
119 writer.write_u8(((len - 192) % 256) as u8)?;
120 } else {
121 writer.write_u8(0xFF)?;
122 writer.write_u32::<BigEndian>(len as u32)?;
123 }
124
125 Ok(())
126}
127
128#[inline]
131pub fn rest_len<T>(input: T) -> IResult<T, usize>
132where
133 T: Slice<Range<usize>> + Slice<RangeFrom<usize>> + Slice<RangeTo<usize>>,
134 T: InputLength,
135{
136 let len = input.input_len();
137 Ok((input, len))
138}
139
/// Generates conversions between an enum and the payload types of its
/// single-field tuple variants.
///
/// Usage: `impl_try_from_into!(Enum, Variant => PayloadType, ...)`, where
/// `Enum` must be a single identifier (not a path). For each
/// `Variant => PayloadType` pair this emits:
/// - `impl TryFrom<Enum> for PayloadType`: yields the payload when the value is
///   `Enum::Variant`, otherwise an error `$crate::errors::Error` built from
///   the message "invalid packet type: {:?}";
/// - `impl From<PayloadType> for Enum`: wraps the payload in `Enum::Variant`.
///
/// The error message formats the enum with `{:?}`, so `Enum` must implement
/// `Debug`.
/// NOTE(review): `format_err!` is invoked unqualified, so it presumably has to
/// be in scope wherever this macro is expanded. Consider `$crate::`-qualifying
/// it. TODO confirm where `format_err!` is defined.
#[macro_export]
macro_rules! impl_try_from_into {
    ($enum_name:ident, $( $name:ident => $variant_type:ty ),*) => {
        $(
            impl std::convert::TryFrom<$enum_name> for $variant_type {
                type Error = $crate::errors::Error;

                fn try_from(other: $enum_name) -> ::std::result::Result<$variant_type, Self::Error> {
                    if let $enum_name::$name(value) = other {
                        Ok(value)
                    } else {
                        // Wrong variant: report the value actually received.
                        Err(format_err!("invalid packet type: {:?}", other))
                    }
                }
            }

            impl From<$variant_type> for $enum_name {
                fn from(other: $variant_type) -> $enum_name {
                    $enum_name::$name(other)
                }
            }
        )*
    }
}
165
166pub struct TeeWriter<'a, A, B> {
167 a: &'a mut A,
168 b: &'a mut B,
169}
170
171impl<'a, A, B> TeeWriter<'a, A, B> {
172 pub fn new(a: &'a mut A, b: &'a mut B) -> Self {
173 TeeWriter { a, b }
174 }
175}
176
177impl<'a, A: hash::Hasher, B: io::Write> io::Write for TeeWriter<'a, A, B> {
178 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
179 self.a.write(buf);
180 write_all(&mut self.b, buf)?;
181
182 Ok(buf.len())
183 }
184
185 fn flush(&mut self) -> io::Result<()> {
186 self.b.flush()?;
187
188 Ok(())
189 }
190}
191
/// Writes the whole of `buf` to `writer`, retrying on short writes and on
/// `ErrorKind::Interrupted`.
///
/// # Errors
/// - Returns `ErrorKind::WriteZero` if the writer reports `Ok(0)` while data
///   remains. This matches `std::io::Write::write_all`; previously this case was
///   retried forever.
/// - Propagates any other error from `writer.write`.
pub fn write_all(writer: &mut impl io::Write, mut buf: &[u8]) -> io::Result<()> {
    while !buf.is_empty() {
        match writer.write(buf) {
            Ok(0) => {
                // The writer can accept no more data; retrying would spin forever.
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }
            Ok(n) => buf = &buf[n..],
            Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
205
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    // `unwrap` is acceptable here: any failure should simply fail the test.

    use super::*;

    #[test]
    fn test_write_packet_len() {
        // Two-octet form: 1173 - 192 = 981 = 3 * 256 + 213 -> 0xC3 (192 + 3), 0xD5.
        let mut res = Vec::new();
        write_packet_length(1173, &mut res).unwrap();
        assert_eq!(hex::encode(res), "c3d5");
    }

    #[test]
    fn test_write_packet_length() {
        // 12870 >= 8384, so the five-octet form is used: 0xFF then 0x00003246 big-endian.
        let mut res = Vec::new();
        write_packet_length(12870, &mut res).unwrap();
        assert_eq!(hex::encode(res), "ff00003246");
    }

    #[test]
    fn test_strip_leading_zeros_with_all_zeros() {
        // An all-zero input strips down to an empty slice.
        let buf = [0, 0, 0];
        let stripped = strip_leading_zeros(&buf);
        assert_eq!(stripped, &[]);
    }

    #[test]
    fn test_strip_leading_zeros_vec() {
        // Only the leading zeros are removed; the remaining bytes are kept in order.
        let mut vec = vec![0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        strip_leading_zeros_vec(&mut vec);
        assert_eq!(vec, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
}