parquet_format_safe/thrift/protocol/
compact.rs

1use std::convert::{From, TryFrom, TryInto};
2use std::io;
3use std::io::Read;
4
5use super::super::varint::VarIntReader;
6
7use super::super::{Error, ProtocolError, ProtocolErrorKind, Result};
8use super::{
9    TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
10    TMessageType,
11};
12use super::{TSetIdentifier, TStructIdentifier, TType};
13
/// Value the first byte of a message header must have; `read_message_begin`
/// rejects any other value with a `BadVersion` protocol error.
pub(super) const COMPACT_PROTOCOL_ID: u8 = 0x82;
/// Compact protocol version accepted by `read_message_begin`.
pub(super) const COMPACT_VERSION: u8 = 0x01;
/// Mask selecting the version bits (the low five bits) of the second header byte.
pub(super) const COMPACT_VERSION_MASK: u8 = 0x1F;
17
/// Reader for messages encoded in the Thrift compact protocol.
///
/// Wraps any [`Read`] implementation and keeps the small amount of state the
/// compact encoding needs between calls, plus a byte budget used to refuse
/// inputs that would require more memory than the caller allowed.
#[derive(Debug)]
pub struct TCompactInputProtocol<R: Read> {
    /// Id of the most recently deserialized field of the current struct.
    last_read_field_id: i16,
    /// Saved `last_read_field_id` values of the enclosing structs; one entry
    /// is pushed every time a nested struct starts being read.
    read_field_id_stack: Vec<i16>,
    /// Value of a boolean field, if one is waiting to be consumed.
    ///
    /// A boolean field and its value share a single byte, and the value is
    /// only requested after the field header has been read, so it has to be
    /// remembered in between.
    pending_read_bool_value: Option<bool>,
    /// Underlying reader used for all byte-level operations.
    reader: R,
    /// Number of bytes that may still be accounted for before further reads
    /// are refused.
    remaining: usize,
}
37
38impl<R> TCompactInputProtocol<R>
39where
40    R: Read,
41{
42    /// Create a [`TCompactInputProtocol`] that reads bytes from `reader`.
43
44    pub fn new(reader: R, max_bytes: usize) -> Self {
45        Self {
46            last_read_field_id: 0,
47            read_field_id_stack: Vec::new(),
48            pending_read_bool_value: None,
49            remaining: max_bytes,
50            reader,
51        }
52    }
53
54    fn update_remaining<T>(&mut self, element: usize) -> Result<()> {
55        self.remaining = self
56            .remaining
57            .checked_sub((element).saturating_mul(std::mem::size_of::<T>()))
58            .ok_or_else(|| {
59                Error::Protocol(ProtocolError {
60                    kind: ProtocolErrorKind::SizeLimit,
61                    message: "The thrift file would allocate more bytes than allowed".to_string(),
62                })
63            })?;
64        Ok(())
65    }
66
67    fn read_list_set_begin(&mut self) -> Result<(TType, u32)> {
68        let header = self.read_byte()?;
69        let element_type = collection_u8_to_type(header & 0x0F)?;
70
71        let possible_element_count = (header & 0xF0) >> 4;
72        let element_count = if possible_element_count != 15 {
73            // high bits set high if count and type encoded separately
74            possible_element_count.into()
75        } else {
76            self.reader.read_varint::<u32>()?
77        };
78        self.update_remaining::<usize>(element_count as usize)?;
79
80        Ok((element_type, element_count))
81    }
82}
83
84impl<R> TInputProtocol for TCompactInputProtocol<R>
85where
86    R: Read,
87{
88    fn read_message_begin(&mut self) -> Result<TMessageIdentifier> {
89        let compact_id = self.read_byte()?;
90        if compact_id != COMPACT_PROTOCOL_ID {
91            Err(Error::Protocol(ProtocolError {
92                kind: ProtocolErrorKind::BadVersion,
93                message: format!("invalid compact protocol header {:?}", compact_id),
94            }))
95        } else {
96            Ok(())
97        }?;
98
99        let type_and_byte = self.read_byte()?;
100        let received_version = type_and_byte & COMPACT_VERSION_MASK;
101        if received_version != COMPACT_VERSION {
102            Err(Error::Protocol(ProtocolError {
103                kind: ProtocolErrorKind::BadVersion,
104                message: format!(
105                    "cannot process compact protocol version {:?}",
106                    received_version
107                ),
108            }))
109        } else {
110            Ok(())
111        }?;
112
113        // NOTE: unsigned right shift will pad with 0s
114        let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?;
115        let sequence_number = self.reader.read_varint::<u32>()?;
116        let service_call_name = self.read_string()?;
117
118        self.last_read_field_id = 0;
119
120        Ok(TMessageIdentifier::new(
121            service_call_name,
122            message_type,
123            sequence_number,
124        ))
125    }
126
127    fn read_message_end(&mut self) -> Result<()> {
128        Ok(())
129    }
130
131    fn read_struct_begin(&mut self) -> Result<Option<TStructIdentifier>> {
132        self.update_remaining::<i16>(1)?;
133        self.read_field_id_stack.push(self.last_read_field_id);
134        self.last_read_field_id = 0;
135        Ok(None)
136    }
137
138    fn read_struct_end(&mut self) -> Result<()> {
139        self.last_read_field_id = self
140            .read_field_id_stack
141            .pop()
142            .expect("should have previous field ids");
143        Ok(())
144    }
145
146    fn read_field_begin(&mut self) -> Result<TFieldIdentifier> {
147        // we can read at least one byte, which is:
148        // - the type
149        // - the field delta and the type
150        let field_type = self.read_byte()?;
151        let field_delta = (field_type & 0xF0) >> 4;
152        let field_type = match field_type & 0x0F {
153            0x01 => {
154                self.pending_read_bool_value = Some(true);
155                Ok(TType::Bool)
156            }
157            0x02 => {
158                self.pending_read_bool_value = Some(false);
159                Ok(TType::Bool)
160            }
161            ttu8 => u8_to_type(ttu8),
162        }?;
163
164        match field_type {
165            TType::Stop => Ok(
166                TFieldIdentifier::new::<Option<String>, String, Option<i16>>(
167                    None,
168                    TType::Stop,
169                    None,
170                ),
171            ),
172            _ => {
173                if field_delta != 0 {
174                    self.last_read_field_id = self
175                        .last_read_field_id
176                        .checked_add(field_delta as i16)
177                        .ok_or(Error::Protocol(ProtocolError {
178                            kind: ProtocolErrorKind::DepthLimit,
179                            message: String::new(),
180                        }))?;
181                } else {
182                    self.last_read_field_id = self.read_i16()?;
183                };
184
185                Ok(TFieldIdentifier {
186                    name: None,
187                    field_type,
188                    id: Some(self.last_read_field_id),
189                })
190            }
191        }
192    }
193
194    fn read_field_end(&mut self) -> Result<()> {
195        Ok(())
196    }
197
198    fn read_bool(&mut self) -> Result<bool> {
199        match self.pending_read_bool_value.take() {
200            Some(b) => Ok(b),
201            None => {
202                let b = self.read_byte()?;
203                match b {
204                    0x01 => Ok(true),
205                    0x02 => Ok(false),
206                    unkn => Err(Error::Protocol(ProtocolError {
207                        kind: ProtocolErrorKind::InvalidData,
208                        message: format!("cannot convert {} into bool", unkn),
209                    })),
210                }
211            }
212        }
213    }
214
215    fn read_bytes(&mut self) -> Result<Vec<u8>> {
216        let len = self.reader.read_varint::<u32>()?;
217
218        self.update_remaining::<u8>(len.try_into()?)?;
219
220        let mut buf = vec![];
221        buf.try_reserve(len.try_into()?)?;
222        self.reader
223            .by_ref()
224            .take(len.try_into()?)
225            .read_to_end(&mut buf)?;
226        Ok(buf)
227    }
228
229    fn read_i8(&mut self) -> Result<i8> {
230        self.read_byte().map(|i| i as i8)
231    }
232
233    fn read_i16(&mut self) -> Result<i16> {
234        self.reader.read_varint::<i16>().map_err(From::from)
235    }
236
237    fn read_i32(&mut self) -> Result<i32> {
238        self.reader.read_varint::<i32>().map_err(From::from)
239    }
240
241    fn read_i64(&mut self) -> Result<i64> {
242        self.reader.read_varint::<i64>().map_err(From::from)
243    }
244
245    fn read_double(&mut self) -> Result<f64> {
246        let mut data = [0u8; 8];
247        self.reader.read_exact(&mut data)?;
248        Ok(f64::from_le_bytes(data))
249    }
250
251    fn read_string(&mut self) -> Result<String> {
252        let bytes = self.read_bytes()?;
253        String::from_utf8(bytes).map_err(From::from)
254    }
255
256    fn read_list_begin(&mut self) -> Result<TListIdentifier> {
257        let (element_type, element_count) = self.read_list_set_begin()?;
258        Ok(TListIdentifier::new(element_type, element_count))
259    }
260
261    fn read_list_end(&mut self) -> Result<()> {
262        Ok(())
263    }
264
265    fn read_set_begin(&mut self) -> Result<TSetIdentifier> {
266        let (element_type, element_count) = self.read_list_set_begin()?;
267        Ok(TSetIdentifier::new(element_type, element_count))
268    }
269
270    fn read_set_end(&mut self) -> Result<()> {
271        Ok(())
272    }
273
274    fn read_map_begin(&mut self) -> Result<TMapIdentifier> {
275        let element_count = self.reader.read_varint::<u32>()?;
276        if element_count == 0 {
277            Ok(TMapIdentifier::new(None, None, 0))
278        } else {
279            let type_header = self.read_byte()?;
280            let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?;
281            let val_type = collection_u8_to_type(type_header & 0x0F)?;
282            self.update_remaining::<usize>(element_count.try_into()?)?;
283            Ok(TMapIdentifier::new(key_type, val_type, element_count))
284        }
285    }
286
287    fn read_map_end(&mut self) -> Result<()> {
288        Ok(())
289    }
290
291    // utility
292    //
293
294    fn read_byte(&mut self) -> Result<u8> {
295        let mut buf = [0u8; 1];
296        self.reader
297            .read_exact(&mut buf)
298            .map_err(From::from)
299            .map(|_| buf[0])
300    }
301}
302
303impl<R> io::Seek for TCompactInputProtocol<R>
304where
305    R: io::Seek + Read,
306{
307    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
308        self.reader.seek(pos)
309    }
310}
311
312pub(super) fn collection_u8_to_type(b: u8) -> Result<TType> {
313    match b {
314        0x01 => Ok(TType::Bool),
315        o => u8_to_type(o),
316    }
317}
318
319pub(super) fn u8_to_type(b: u8) -> Result<TType> {
320    match b {
321        0x00 => Ok(TType::Stop),
322        0x03 => Ok(TType::I08), // equivalent to TType::Byte
323        0x04 => Ok(TType::I16),
324        0x05 => Ok(TType::I32),
325        0x06 => Ok(TType::I64),
326        0x07 => Ok(TType::Double),
327        0x08 => Ok(TType::String),
328        0x09 => Ok(TType::List),
329        0x0A => Ok(TType::Set),
330        0x0B => Ok(TType::Map),
331        0x0C => Ok(TType::Struct),
332        unkn => Err(Error::Protocol(ProtocolError {
333            kind: ProtocolErrorKind::InvalidData,
334            message: format!("cannot convert {} into TType", unkn),
335        })),
336    }
337}