Skip to main content

qusql_mysql/
package_parser.rs

//! Contains the parser used to decode Mysql/Mariadb packages
2use std::num::TryFromIntError;
3
4use bytes::Buf;
5use thiserror::Error;
6
7/// Error returned by the [PackageParser]
8#[derive(Error, Debug)]
9pub enum DecodeError {
10    /// The package was shorter than expected
11    #[error("End of package")]
12    EndOfPackage,
13    /// You are decoding more columns than there is in the response
14    #[error("End of columns")]
15    EndOfColumns,
16    /// A string in a package was not utf-8 as expected
17    #[error("Utf-8 error at {valid_up_to}")]
18    Utf8Error {
19        /// The string is valid utf-8 until this many bytes
20        valid_up_to: u32,
21        /// The length of the error
22        error_len: Option<u8>,
23    },
24    /// We expected a non-null value, but found null
25    #[error("Unexpected null value")]
26    Null,
27    /// We expected an unsigned integer but we got aa signed integer
28    #[error("Expected unsigned got signed")]
29    ExpectedUnsigned,
30    /// We expected an signed integer but we got aa unsigned integer
31    #[error("Expected signed got unsigned")]
32    ExpectedSigned,
33    /// The field we are decoding has a different type than expected
34    #[error("Type error")]
35    TypeError {
36        /// The field has this type as defined in [crate::constants::type_]
37        got: u8,
38        /// We expected this type
39        expected: u8,
40    },
41    /// A variable length encoded field has an unexpected size
42    #[error("Invalid size {0}")]
43    InvalidSize(u8),
44    /// The value could not be converted to the target type
45    #[error("Invalid value")]
46    InvalidValue,
47    /// Error converting between integer types
48    #[error("Invalid integer cast")]
49    TryFromInt,
50}
51
52const _: () = {
53    assert!(size_of::<DecodeError>() <= 8);
54};
55
56impl From<bytes::TryGetError> for DecodeError {
57    fn from(_value: bytes::TryGetError) -> Self {
58        DecodeError::EndOfPackage
59    }
60}
61
62impl From<std::str::Utf8Error> for DecodeError {
63    fn from(value: std::str::Utf8Error) -> Self {
64        DecodeError::Utf8Error {
65            valid_up_to: value.valid_up_to().try_into().unwrap_or(u32::MAX),
66            error_len: value.error_len().map(|v| v.try_into().unwrap_or(0xFF)),
67        }
68    }
69}
70
71impl From<TryFromIntError> for DecodeError {
72    fn from(_value: TryFromIntError) -> Self {
73        DecodeError::TryFromInt
74    }
75}
76
77/// Result returned by [PackageParser]
78pub type DecodeResult<T> = std::result::Result<T, DecodeError>;
79
/// Parse a Mysql/Mariadb package
///
/// Wraps the not-yet-consumed bytes of the package; every successful read
/// moves the start of the slice forward. The parser is `Copy`, so a copy can
/// be taken to read ahead without disturbing the original.
#[derive(Copy, Clone)]
pub struct PackageParser<'a>(&'a [u8]);
83
84impl<'a> PackageParser<'a> {
85    /// Construct a new [PackageParser] for the given package
86    pub(crate) fn new(package: &'a [u8]) -> Self {
87        Self(package)
88    }
89
90    /// Read a u8 from the package
91    #[inline]
92    pub fn get_u8(&mut self) -> DecodeResult<u8> {
93        Ok(self.0.try_get_u8()?)
94    }
95
96    /// Read a i8 from the package
97    #[inline]
98    pub fn get_i8(&mut self) -> DecodeResult<i8> {
99        Ok(self.0.try_get_i8()?)
100    }
101
102    /// Read a u16 from the package
103    #[inline]
104    pub fn get_u16(&mut self) -> DecodeResult<u16> {
105        Ok(self.0.try_get_u16_le()?)
106    }
107
108    /// Read a i16 from the package
109    #[inline]
110    pub fn get_i16(&mut self) -> DecodeResult<i16> {
111        Ok(self.0.try_get_i16_le()?)
112    }
113
114    /// Read a u32 from the package
115    #[inline]
116    pub fn get_u32(&mut self) -> DecodeResult<u32> {
117        Ok(self.0.try_get_u32_le()?)
118    }
119
120    /// Read a i32 from the package
121    #[inline]
122    pub fn get_i32(&mut self) -> DecodeResult<i32> {
123        Ok(self.0.try_get_i32_le()?)
124    }
125
126    /// Read a u64 from the package
127    #[inline]
128    pub fn get_u64(&mut self) -> DecodeResult<u64> {
129        Ok(self.0.try_get_u64_le()?)
130    }
131
132    /// Read a i64 from the package
133    #[inline]
134    pub fn get_i64(&mut self) -> DecodeResult<i64> {
135        Ok(self.0.try_get_i64_le()?)
136    }
137
138    /// Read a f32 from the package
139    #[inline]
140    pub fn get_f32(&mut self) -> DecodeResult<f32> {
141        Ok(self.0.try_get_f32_le()?)
142    }
143
144    /// Read a f64 from the package
145    #[inline]
146    pub fn get_f64(&mut self) -> DecodeResult<f64> {
147        Ok(self.0.try_get_f64_le()?)
148    }
149
150    /// Read a u64 from the package
151    #[inline]
152    pub fn get_u24(&mut self) -> DecodeResult<u32> {
153        let a: u32 = self.get_u8()?.into();
154        let b: u32 = self.get_u8()?.into();
155        let c: u32 = self.get_u8()?.into();
156        Ok(a | (b << 8) | (c << 16))
157    }
158
159    /// Read a variable encoded length
160    ///
161    /// See <https://mariadb.com/docs/server/reference/clientserver-protocol/protocol-data-types#length-encoded-integers>
162    #[inline]
163    pub fn get_lenenc(&mut self) -> DecodeResult<u64> {
164        let v = self.get_u8()?;
165        Ok(match v {
166            0xFC => self.get_u16()?.into(),
167            0xFD => self.get_u24()?.into(),
168            0xFE => self.get_u64()?,
169            v => v.into(),
170        })
171    }
172
173    /// Read a variable encoded blob
174    #[inline]
175    pub fn get_lenenc_blob(&mut self) -> DecodeResult<&'a [u8]> {
176        let len = self.get_lenenc()?;
177        self.get_bytes(len as usize)
178    }
179
180    /// Read a variable encoded utf8-string
181    #[inline]
182    pub fn get_lenenc_str(&mut self) -> DecodeResult<&'a str> {
183        let len = self.get_lenenc()?;
184        let v = self.get_bytes(len as usize)?;
185        Ok(str::from_utf8(v)?)
186    }
187
188    /// Skip past a variable encoded string or blob
189    #[inline]
190    pub fn skip_lenenc_str(&mut self) -> DecodeResult<()> {
191        let l = self.get_lenenc()?;
192        self.0.advance(l as usize);
193        Ok(())
194    }
195
196    /// Skip a given number of bytes
197    #[inline]
198    pub fn skip_bytes(&mut self, len: usize) {
199        self.0.advance(len);
200    }
201
202    /// Read a null-terminated string
203    #[inline]
204    pub fn get_null_str(&mut self) -> DecodeResult<&'a str> {
205        match std::ffi::CStr::from_bytes_until_nul(self.0) {
206            Ok(v) => {
207                let v = v.to_str()?;
208                self.0.advance(v.len() + 1);
209                Ok(v)
210            }
211            Err(_) => Err(DecodeError::EndOfPackage),
212        }
213    }
214
215    /// Skip past a null-terminated string
216    #[inline]
217    pub fn skip_null_str(&mut self) -> DecodeResult<()> {
218        match std::ffi::CStr::from_bytes_until_nul(self.0) {
219            Ok(v) => {
220                self.0.advance(v.count_bytes() + 1);
221                Ok(())
222            }
223            Err(_) => Err(DecodeError::EndOfPackage),
224        }
225    }
226
227    /// Read the rest of the package as a utf-8 string
228    #[inline]
229    pub fn get_eof_str(&mut self) -> DecodeResult<&'a str> {
230        let v = str::from_utf8(self.0)?;
231        self.0.advance(v.len());
232        Ok(v)
233    }
234
235    /// Read some bytes from the package
236    #[inline]
237    pub fn get_bytes(&mut self, len: usize) -> DecodeResult<&'a [u8]> {
238        match self.0.get(..len) {
239            Some(v) => {
240                self.0.advance(len);
241                Ok(v)
242            }
243            None => Err(DecodeError::EndOfPackage),
244        }
245    }
246}