mysql_connector/connection/parse_buf.rs

1use {super::Deserialize, crate::error::ProtocolError, std::io};
2
/// Read cursor over a borrowed byte slice that is consumed from the front.
///
/// The wrapped slice is always the not-yet-consumed tail of the input; every reader on this
/// type advances it past the bytes it returns.
#[derive(Clone, Debug)]
pub(crate) struct ParseBuf<'a>(pub(crate) &'a [u8]);
5
6impl io::Read for ParseBuf<'_> {
7    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
8        let count = self.0.len().min(buf.len());
9        (buf[..count]).copy_from_slice(&self.0[..count]);
10        self.0 = &self.0[count..];
11        Ok(count)
12    }
13}
14
15impl<'a> ParseBuf<'a> {
16    #[inline]
17    pub fn is_empty(&self) -> bool {
18        self.len() == 0
19    }
20
21    #[inline]
22    pub fn len(&self) -> usize {
23        self.0.len()
24    }
25
26    #[inline]
27    pub fn skip(&mut self, cnt: usize) {
28        self.0 = &self.0[usize::min(cnt, self.0.len())..];
29    }
30
31    #[inline]
32    pub fn check_len(&self, len: usize) -> Result<(), ProtocolError> {
33        if self.len() < len {
34            Err(ProtocolError::eof())
35        } else {
36            Ok(())
37        }
38    }
39
40    /// Eats n bytes
41    ///
42    /// # Panic
43    ///
44    /// Will panic if `n > self.len()`
45    #[inline]
46    pub fn eat(&mut self, n: usize) -> &'a [u8] {
47        let (left, right) = self.0.split_at(n);
48        self.0 = right;
49        left
50    }
51
52    #[inline]
53    pub fn checked_eat(&mut self, n: usize) -> Result<&'a [u8], ProtocolError> {
54        if self.len() >= n {
55            Ok(self.eat(n))
56        } else {
57            Err(ProtocolError::eof())
58        }
59    }
60
61    #[inline]
62    pub fn eat_all(&mut self) -> &'a [u8] {
63        let value = self.0;
64        self.0 = &[];
65        value
66    }
67
68    #[inline]
69    pub fn parse_unchecked<T>(&mut self, ctx: T::Ctx) -> Result<T, ProtocolError>
70    where
71        T: Deserialize<'a>,
72    {
73        T::deserialize(self, ctx)
74    }
75
76    #[inline]
77    pub fn parse<T>(&mut self, ctx: T::Ctx) -> Result<T, ProtocolError>
78    where
79        T: Deserialize<'a>,
80    {
81        if let Some(size) = T::SIZE {
82            if self.len() < size {
83                return Err(ProtocolError::eof());
84            }
85        }
86        self.parse_unchecked(ctx)
87    }
88}
89
/// Generates `eat_*` / `checked_eat_*` little-endian number readers inside an `impl ParseBuf`.
///
/// * `parse_num!(T)` reads exactly `size_of::<T>()` bytes as a `T`.
/// * `parse_num!(T, name, N)` reads an `N`-byte integer into the wider type `T`. The value is
///   sign-extended when `T` is signed and zero-extended when `T` is unsigned.
/// * `parse_num!(...)` accepts a comma-separated list mixing both forms, the narrow form
///   written as `{T, name, N}`.
macro_rules! parse_num {
    ($t:ty) => {
        paste::paste! {
            #[allow(dead_code)]
            pub fn [< eat_ $t >](&mut self) -> $t {
                const SIZE: usize = std::mem::size_of::<$t>();
                // `eat` guarantees exactly SIZE bytes, so this copy cannot fail; it replaces
                // an `unsafe` slice-to-array pointer cast with an equivalent safe memcpy.
                let mut bytes = [0u8; SIZE];
                bytes.copy_from_slice(self.eat(SIZE));
                <$t>::from_le_bytes(bytes)
            }

            #[allow(dead_code)]
            pub fn [< checked_eat_ $t >](&mut self) -> Result<$t, ProtocolError> {
                const SIZE: usize = std::mem::size_of::<$t>();
                let mut bytes = [0u8; SIZE];
                bytes.copy_from_slice(self.checked_eat(SIZE)?);
                Ok(<$t>::from_le_bytes(bytes))
            }
        }
    };
    ($t:ty, $name:ident, $size:literal) => {
        paste::paste! {
            #[allow(dead_code)]
            pub fn [< eat_ $name >](&mut self) -> $t {
                const FULL: usize = std::mem::size_of::<$t>();
                const SHIFT: usize = (FULL - $size) * 8;
                // Put the N little-endian bytes in the *high* end of the wider value, then
                // shift them back down. `>>` is arithmetic for signed types (sign-extends)
                // and logical for unsigned ones (zero-extends). Copying into the low end
                // instead would zero-extend negative values, e.g. i24 `ff ff ff` -> 16777215.
                let mut bytes = [0u8; FULL];
                bytes[FULL - $size..].copy_from_slice(self.eat($size));
                <$t>::from_le_bytes(bytes) >> SHIFT
            }

            #[allow(dead_code)]
            pub fn [< checked_eat_ $name >](&mut self) -> Result<$t, ProtocolError> {
                const FULL: usize = std::mem::size_of::<$t>();
                const SHIFT: usize = (FULL - $size) * 8;
                // Same high-end placement + shift as the unchecked variant above.
                let mut bytes = [0u8; FULL];
                bytes[FULL - $size..].copy_from_slice(self.checked_eat($size)?);
                Ok(<$t>::from_le_bytes(bytes) >> SHIFT)
            }
        }
    };
    ($($($t:ty)? $({$t2:ty, $n:ident, $s:literal})?),* $(,)?) => {
        $(
            $(parse_num!($t);)?
            $(parse_num!($t2, $n, $s);)?
        )*
    };
}
136
137impl ParseBuf<'_> {
138    parse_num!(u8, u16, {u32,u24,3}, u32, {u64,u40,5}, {u64,u48,6}, {u64,u56,7}, u64, u128);
139    parse_num!(i8, i16, {i32,i24,3}, i32, {i64,i40,5}, {i64,i48,6}, {i64,i56,7}, i64, i128);
140    parse_num!(f32, f64);
141}
142
143#[allow(dead_code)]
144impl<'a> ParseBuf<'a> {
145    /// Consumes MySql length-encoded integer.
146    ///
147    /// Returns `0` if integer is malformed (starts with 0xff or 0xfb).
148    pub fn eat_lenenc_int(&mut self) -> u64 {
149        match self.eat_u8() {
150            x @ 0..=0xfa => x as u64,
151            0xfc => self.eat_u16() as u64,
152            0xfd => self.eat_u24() as u64,
153            0xfe => self.eat_u64(),
154            0xfb | 0xff => 0,
155        }
156    }
157
158    pub fn checked_eat_lenenc_int(&mut self) -> Result<u64, ProtocolError> {
159        match self.checked_eat_u8()? {
160            x @ 0..=0xfa => Ok(x as u64),
161            0xfc => self.checked_eat_u16().map(|x| x as u64),
162            0xfd => self.checked_eat_u24().map(|x| x as u64),
163            0xfe => self.checked_eat_u64(),
164            0xfb | 0xff => Ok(0),
165        }
166    }
167
168    /// Returns an empty slice if length is malformed (starts with 0xff or 0xfb).
169    pub fn eat_lenenc_slice(&mut self) -> &'a [u8] {
170        let len: u64 = self.eat_lenenc_int();
171        self.eat(len as usize)
172    }
173
174    /// Returns an empty string if length is malformed (starts with 0xff or 0xfb).
175    pub fn eat_lenenc_str(&mut self) -> Result<&'a str, ProtocolError> {
176        std::str::from_utf8(self.eat_lenenc_slice()).map_err(Into::into)
177    }
178
179    pub fn checked_eat_lenenc_slice(&mut self) -> Result<&'a [u8], ProtocolError> {
180        let len = self.checked_eat_lenenc_int()?;
181        self.checked_eat(len as usize)
182    }
183
184    pub fn checked_eat_lenenc_str(&mut self) -> Result<&'a str, ProtocolError> {
185        std::str::from_utf8(self.checked_eat_lenenc_slice()?).map_err(Into::into)
186    }
187
188    pub fn eat_u8_slice(&mut self) -> &'a [u8] {
189        let len = self.eat_u8();
190        self.eat(len as usize)
191    }
192
193    pub fn eat_u8_str(&mut self) -> Result<&'a str, ProtocolError> {
194        std::str::from_utf8(self.eat_u8_slice()).map_err(Into::into)
195    }
196
197    pub fn checked_eat_u8_slice(&mut self) -> Result<&'a [u8], ProtocolError> {
198        let len = self.checked_eat_u8()?;
199        self.checked_eat(len as usize)
200    }
201
202    pub fn checked_eat_u8_str(&mut self) -> Result<&'a str, ProtocolError> {
203        std::str::from_utf8(self.checked_eat_u8_slice()?).map_err(Into::into)
204    }
205
206    /// Consumes whole buffer if there is no `0`-byte.
207    pub fn eat_null_slice(&mut self) -> &'a [u8] {
208        let pos = self
209            .0
210            .iter()
211            .position(|x| *x == 0)
212            .map(|x| x + 1)
213            .unwrap_or_else(|| self.len());
214        match self.eat(pos) {
215            [head @ .., 0_u8] => head,
216            x => x,
217        }
218    }
219
220    pub fn eat_null_str(&mut self) -> Result<&'a str, ProtocolError> {
221        std::str::from_utf8(self.eat_null_slice()).map_err(Into::into)
222    }
223}