parquet_format_safe/thrift/protocol/
compact.rs1use std::convert::{From, TryFrom, TryInto};
2use std::io;
3use std::io::Read;
4
5use super::super::varint::VarIntReader;
6
7use super::super::{Error, ProtocolError, ProtocolErrorKind, Result};
8use super::{
9 TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
10 TMessageType,
11};
12use super::{TSetIdentifier, TStructIdentifier, TType};
13
14pub(super) const COMPACT_PROTOCOL_ID: u8 = 0x82;
15pub(super) const COMPACT_VERSION: u8 = 0x01;
16pub(super) const COMPACT_VERSION_MASK: u8 = 0x1F;
17
/// Thrift compact-protocol decoder reading from any [`Read`] source.
///
/// Besides the reader it keeps the small amount of state the compact encoding
/// needs (field ids are delta-encoded against the previous field of the same
/// struct, and bool field values travel inside the field header) plus a byte
/// budget that caps how much the decoder may allocate on behalf of the input.
#[derive(Debug)]
pub struct TCompactInputProtocol<R: Read> {
    /// Id of the last field read in the struct currently being decoded.
    last_read_field_id: i16,
    /// `last_read_field_id` of each enclosing struct, restored when a nested struct ends.
    read_field_id_stack: Vec<i16>,
    /// Bool value taken from a field header, handed out by the next `read_bool`.
    pending_read_bool_value: Option<bool>,
    /// Byte source being decoded.
    reader: R,
    /// Bytes the decoder may still charge against its allocation budget.
    remaining: usize,
}
37
38impl<R> TCompactInputProtocol<R>
39where
40 R: Read,
41{
42 pub fn new(reader: R, max_bytes: usize) -> Self {
45 Self {
46 last_read_field_id: 0,
47 read_field_id_stack: Vec::new(),
48 pending_read_bool_value: None,
49 remaining: max_bytes,
50 reader,
51 }
52 }
53
54 fn update_remaining<T>(&mut self, element: usize) -> Result<()> {
55 self.remaining = self
56 .remaining
57 .checked_sub((element).saturating_mul(std::mem::size_of::<T>()))
58 .ok_or_else(|| {
59 Error::Protocol(ProtocolError {
60 kind: ProtocolErrorKind::SizeLimit,
61 message: "The thrift file would allocate more bytes than allowed".to_string(),
62 })
63 })?;
64 Ok(())
65 }
66
67 fn read_list_set_begin(&mut self) -> Result<(TType, u32)> {
68 let header = self.read_byte()?;
69 let element_type = collection_u8_to_type(header & 0x0F)?;
70
71 let possible_element_count = (header & 0xF0) >> 4;
72 let element_count = if possible_element_count != 15 {
73 possible_element_count.into()
75 } else {
76 self.reader.read_varint::<u32>()?
77 };
78 self.update_remaining::<usize>(element_count as usize)?;
79
80 Ok((element_type, element_count))
81 }
82}
83
84impl<R> TInputProtocol for TCompactInputProtocol<R>
85where
86 R: Read,
87{
88 fn read_message_begin(&mut self) -> Result<TMessageIdentifier> {
89 let compact_id = self.read_byte()?;
90 if compact_id != COMPACT_PROTOCOL_ID {
91 Err(Error::Protocol(ProtocolError {
92 kind: ProtocolErrorKind::BadVersion,
93 message: format!("invalid compact protocol header {:?}", compact_id),
94 }))
95 } else {
96 Ok(())
97 }?;
98
99 let type_and_byte = self.read_byte()?;
100 let received_version = type_and_byte & COMPACT_VERSION_MASK;
101 if received_version != COMPACT_VERSION {
102 Err(Error::Protocol(ProtocolError {
103 kind: ProtocolErrorKind::BadVersion,
104 message: format!(
105 "cannot process compact protocol version {:?}",
106 received_version
107 ),
108 }))
109 } else {
110 Ok(())
111 }?;
112
113 let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?;
115 let sequence_number = self.reader.read_varint::<u32>()?;
116 let service_call_name = self.read_string()?;
117
118 self.last_read_field_id = 0;
119
120 Ok(TMessageIdentifier::new(
121 service_call_name,
122 message_type,
123 sequence_number,
124 ))
125 }
126
127 fn read_message_end(&mut self) -> Result<()> {
128 Ok(())
129 }
130
131 fn read_struct_begin(&mut self) -> Result<Option<TStructIdentifier>> {
132 self.update_remaining::<i16>(1)?;
133 self.read_field_id_stack.push(self.last_read_field_id);
134 self.last_read_field_id = 0;
135 Ok(None)
136 }
137
138 fn read_struct_end(&mut self) -> Result<()> {
139 self.last_read_field_id = self
140 .read_field_id_stack
141 .pop()
142 .expect("should have previous field ids");
143 Ok(())
144 }
145
146 fn read_field_begin(&mut self) -> Result<TFieldIdentifier> {
147 let field_type = self.read_byte()?;
151 let field_delta = (field_type & 0xF0) >> 4;
152 let field_type = match field_type & 0x0F {
153 0x01 => {
154 self.pending_read_bool_value = Some(true);
155 Ok(TType::Bool)
156 }
157 0x02 => {
158 self.pending_read_bool_value = Some(false);
159 Ok(TType::Bool)
160 }
161 ttu8 => u8_to_type(ttu8),
162 }?;
163
164 match field_type {
165 TType::Stop => Ok(
166 TFieldIdentifier::new::<Option<String>, String, Option<i16>>(
167 None,
168 TType::Stop,
169 None,
170 ),
171 ),
172 _ => {
173 if field_delta != 0 {
174 self.last_read_field_id = self
175 .last_read_field_id
176 .checked_add(field_delta as i16)
177 .ok_or(Error::Protocol(ProtocolError {
178 kind: ProtocolErrorKind::DepthLimit,
179 message: String::new(),
180 }))?;
181 } else {
182 self.last_read_field_id = self.read_i16()?;
183 };
184
185 Ok(TFieldIdentifier {
186 name: None,
187 field_type,
188 id: Some(self.last_read_field_id),
189 })
190 }
191 }
192 }
193
194 fn read_field_end(&mut self) -> Result<()> {
195 Ok(())
196 }
197
198 fn read_bool(&mut self) -> Result<bool> {
199 match self.pending_read_bool_value.take() {
200 Some(b) => Ok(b),
201 None => {
202 let b = self.read_byte()?;
203 match b {
204 0x01 => Ok(true),
205 0x02 => Ok(false),
206 unkn => Err(Error::Protocol(ProtocolError {
207 kind: ProtocolErrorKind::InvalidData,
208 message: format!("cannot convert {} into bool", unkn),
209 })),
210 }
211 }
212 }
213 }
214
215 fn read_bytes(&mut self) -> Result<Vec<u8>> {
216 let len = self.reader.read_varint::<u32>()?;
217
218 self.update_remaining::<u8>(len.try_into()?)?;
219
220 let mut buf = vec![];
221 buf.try_reserve(len.try_into()?)?;
222 self.reader
223 .by_ref()
224 .take(len.try_into()?)
225 .read_to_end(&mut buf)?;
226 Ok(buf)
227 }
228
229 fn read_i8(&mut self) -> Result<i8> {
230 self.read_byte().map(|i| i as i8)
231 }
232
233 fn read_i16(&mut self) -> Result<i16> {
234 self.reader.read_varint::<i16>().map_err(From::from)
235 }
236
237 fn read_i32(&mut self) -> Result<i32> {
238 self.reader.read_varint::<i32>().map_err(From::from)
239 }
240
241 fn read_i64(&mut self) -> Result<i64> {
242 self.reader.read_varint::<i64>().map_err(From::from)
243 }
244
245 fn read_double(&mut self) -> Result<f64> {
246 let mut data = [0u8; 8];
247 self.reader.read_exact(&mut data)?;
248 Ok(f64::from_le_bytes(data))
249 }
250
251 fn read_string(&mut self) -> Result<String> {
252 let bytes = self.read_bytes()?;
253 String::from_utf8(bytes).map_err(From::from)
254 }
255
256 fn read_list_begin(&mut self) -> Result<TListIdentifier> {
257 let (element_type, element_count) = self.read_list_set_begin()?;
258 Ok(TListIdentifier::new(element_type, element_count))
259 }
260
261 fn read_list_end(&mut self) -> Result<()> {
262 Ok(())
263 }
264
265 fn read_set_begin(&mut self) -> Result<TSetIdentifier> {
266 let (element_type, element_count) = self.read_list_set_begin()?;
267 Ok(TSetIdentifier::new(element_type, element_count))
268 }
269
270 fn read_set_end(&mut self) -> Result<()> {
271 Ok(())
272 }
273
274 fn read_map_begin(&mut self) -> Result<TMapIdentifier> {
275 let element_count = self.reader.read_varint::<u32>()?;
276 if element_count == 0 {
277 Ok(TMapIdentifier::new(None, None, 0))
278 } else {
279 let type_header = self.read_byte()?;
280 let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?;
281 let val_type = collection_u8_to_type(type_header & 0x0F)?;
282 self.update_remaining::<usize>(element_count.try_into()?)?;
283 Ok(TMapIdentifier::new(key_type, val_type, element_count))
284 }
285 }
286
287 fn read_map_end(&mut self) -> Result<()> {
288 Ok(())
289 }
290
291 fn read_byte(&mut self) -> Result<u8> {
295 let mut buf = [0u8; 1];
296 self.reader
297 .read_exact(&mut buf)
298 .map_err(From::from)
299 .map(|_| buf[0])
300 }
301}
302
303impl<R> io::Seek for TCompactInputProtocol<R>
304where
305 R: io::Seek + Read,
306{
307 fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
308 self.reader.seek(pos)
309 }
310}
311
312pub(super) fn collection_u8_to_type(b: u8) -> Result<TType> {
313 match b {
314 0x01 => Ok(TType::Bool),
315 o => u8_to_type(o),
316 }
317}
318
319pub(super) fn u8_to_type(b: u8) -> Result<TType> {
320 match b {
321 0x00 => Ok(TType::Stop),
322 0x03 => Ok(TType::I08), 0x04 => Ok(TType::I16),
324 0x05 => Ok(TType::I32),
325 0x06 => Ok(TType::I64),
326 0x07 => Ok(TType::Double),
327 0x08 => Ok(TType::String),
328 0x09 => Ok(TType::List),
329 0x0A => Ok(TType::Set),
330 0x0B => Ok(TType::Map),
331 0x0C => Ok(TType::Struct),
332 unkn => Err(Error::Protocol(ProtocolError {
333 kind: ProtocolErrorKind::InvalidData,
334 message: format!("cannot convert {} into TType", unkn),
335 })),
336 }
337}