parquet_format_async_temp/thrift/protocol/
compact.rs1use std::convert::{From, TryFrom};
18use std::io;
19use std::io::{Read, Write};
20
21use integer_encoding::{VarIntReader, VarIntWriter};
22
23use super::super::{Error, ProtocolError, ProtocolErrorKind, Result};
24use super::{
25 TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
26 TMessageType,
27};
28use super::{TOutputProtocol, TSetIdentifier, TStructIdentifier, TType};
29
30pub(super) const COMPACT_PROTOCOL_ID: u8 = 0x82;
31pub(super) const COMPACT_VERSION: u8 = 0x01;
32pub(super) const COMPACT_VERSION_MASK: u8 = 0x1F;
33
/// Reading half of the compact protocol: decodes values from a wrapped `Read` transport.
#[derive(Debug)]
pub struct TCompactInputProtocol<T: Read> {
    /// Id of the field most recently read in the struct currently being decoded.
    /// Short-form field headers carry only a delta that is added to this value.
    last_read_field_id: i16,
    /// Saved `last_read_field_id` values of enclosing structs; pushed when a struct
    /// begins and popped when it ends.
    read_field_id_stack: Vec<i16>,
    /// Boolean value carried inside a field header, held until the next `read_bool` call.
    pending_read_bool_value: Option<bool>,
    /// Underlying byte source.
    transport: T,
}
68
69impl<T> TCompactInputProtocol<T>
70where
71 T: Read,
72{
73 pub fn new(transport: T) -> TCompactInputProtocol<T> {
75 TCompactInputProtocol {
76 last_read_field_id: 0,
77 read_field_id_stack: Vec::new(),
78 pending_read_bool_value: None,
79 transport,
80 }
81 }
82
83 fn read_list_set_begin(&mut self) -> Result<(TType, i32)> {
84 let header = self.read_byte()?;
85 let element_type = collection_u8_to_type(header & 0x0F)?;
86
87 let possible_element_count = (header & 0xF0) >> 4;
88 let element_count = if possible_element_count != 15 {
89 possible_element_count as i32
91 } else {
92 self.transport.read_varint::<u32>()? as i32
93 };
94
95 Ok((element_type, element_count))
96 }
97}
98
99impl<T> TInputProtocol for TCompactInputProtocol<T>
100where
101 T: Read,
102{
103 fn read_message_begin(&mut self) -> Result<TMessageIdentifier> {
104 let compact_id = self.read_byte()?;
105 if compact_id != COMPACT_PROTOCOL_ID {
106 Err(Error::Protocol(ProtocolError {
107 kind: ProtocolErrorKind::BadVersion,
108 message: format!("invalid compact protocol header {:?}", compact_id),
109 }))
110 } else {
111 Ok(())
112 }?;
113
114 let type_and_byte = self.read_byte()?;
115 let received_version = type_and_byte & COMPACT_VERSION_MASK;
116 if received_version != COMPACT_VERSION {
117 Err(Error::Protocol(ProtocolError {
118 kind: ProtocolErrorKind::BadVersion,
119 message: format!(
120 "cannot process compact protocol version {:?}",
121 received_version
122 ),
123 }))
124 } else {
125 Ok(())
126 }?;
127
128 let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?;
130 let sequence_number = self.transport.read_varint::<u32>()? as i32;
132 let service_call_name = self.read_string()?;
133
134 self.last_read_field_id = 0;
135
136 Ok(TMessageIdentifier::new(
137 service_call_name,
138 message_type,
139 sequence_number,
140 ))
141 }
142
143 fn read_message_end(&mut self) -> Result<()> {
144 Ok(())
145 }
146
147 fn read_struct_begin(&mut self) -> Result<Option<TStructIdentifier>> {
148 self.read_field_id_stack.push(self.last_read_field_id);
149 self.last_read_field_id = 0;
150 Ok(None)
151 }
152
153 fn read_struct_end(&mut self) -> Result<()> {
154 self.last_read_field_id = self
155 .read_field_id_stack
156 .pop()
157 .expect("should have previous field ids");
158 Ok(())
159 }
160
161 fn read_field_begin(&mut self) -> Result<TFieldIdentifier> {
162 let field_type = self.read_byte()?;
166 let field_delta = (field_type & 0xF0) >> 4;
167 let field_type = match field_type & 0x0F {
168 0x01 => {
169 self.pending_read_bool_value = Some(true);
170 Ok(TType::Bool)
171 }
172 0x02 => {
173 self.pending_read_bool_value = Some(false);
174 Ok(TType::Bool)
175 }
176 ttu8 => u8_to_type(ttu8),
177 }?;
178
179 match field_type {
180 TType::Stop => Ok(
181 TFieldIdentifier::new::<Option<String>, String, Option<i16>>(
182 None,
183 TType::Stop,
184 None,
185 ),
186 ),
187 _ => {
188 if field_delta != 0 {
189 self.last_read_field_id += field_delta as i16;
190 } else {
191 self.last_read_field_id = self.read_i16()?;
192 };
193
194 Ok(TFieldIdentifier {
195 name: None,
196 field_type,
197 id: Some(self.last_read_field_id),
198 })
199 }
200 }
201 }
202
203 fn read_field_end(&mut self) -> Result<()> {
204 Ok(())
205 }
206
207 fn read_bool(&mut self) -> Result<bool> {
208 match self.pending_read_bool_value.take() {
209 Some(b) => Ok(b),
210 None => {
211 let b = self.read_byte()?;
212 match b {
213 0x01 => Ok(true),
214 0x02 => Ok(false),
215 unkn => Err(Error::Protocol(ProtocolError {
216 kind: ProtocolErrorKind::InvalidData,
217 message: format!("cannot convert {} into bool", unkn),
218 })),
219 }
220 }
221 }
222 }
223
224 fn read_bytes(&mut self) -> Result<Vec<u8>> {
225 let len = self.transport.read_varint::<u32>()?;
226 let mut buf = vec![0u8; len as usize];
227 self.transport
228 .read_exact(&mut buf)
229 .map_err(From::from)
230 .map(|_| buf)
231 }
232
233 fn read_i8(&mut self) -> Result<i8> {
234 self.read_byte().map(|i| i as i8)
235 }
236
237 fn read_i16(&mut self) -> Result<i16> {
238 self.transport.read_varint::<i16>().map_err(From::from)
239 }
240
241 fn read_i32(&mut self) -> Result<i32> {
242 self.transport.read_varint::<i32>().map_err(From::from)
243 }
244
245 fn read_i64(&mut self) -> Result<i64> {
246 self.transport.read_varint::<i64>().map_err(From::from)
247 }
248
249 fn read_double(&mut self) -> Result<f64> {
250 let mut data = [0u8; 8];
251 self.transport.read_exact(&mut data)?;
252 Ok(f64::from_le_bytes(data))
253 }
254
255 fn read_string(&mut self) -> Result<String> {
256 let bytes = self.read_bytes()?;
257 String::from_utf8(bytes).map_err(From::from)
258 }
259
260 fn read_list_begin(&mut self) -> Result<TListIdentifier> {
261 let (element_type, element_count) = self.read_list_set_begin()?;
262 Ok(TListIdentifier::new(element_type, element_count))
263 }
264
265 fn read_list_end(&mut self) -> Result<()> {
266 Ok(())
267 }
268
269 fn read_set_begin(&mut self) -> Result<TSetIdentifier> {
270 let (element_type, element_count) = self.read_list_set_begin()?;
271 Ok(TSetIdentifier::new(element_type, element_count))
272 }
273
274 fn read_set_end(&mut self) -> Result<()> {
275 Ok(())
276 }
277
278 fn read_map_begin(&mut self) -> Result<TMapIdentifier> {
279 let element_count = self.transport.read_varint::<u32>()? as i32;
280 if element_count == 0 {
281 Ok(TMapIdentifier::new(None, None, 0))
282 } else {
283 let type_header = self.read_byte()?;
284 let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?;
285 let val_type = collection_u8_to_type(type_header & 0x0F)?;
286 Ok(TMapIdentifier::new(key_type, val_type, element_count))
287 }
288 }
289
290 fn read_map_end(&mut self) -> Result<()> {
291 Ok(())
292 }
293
294 fn read_byte(&mut self) -> Result<u8> {
298 let mut buf = [0u8; 1];
299 self.transport
300 .read_exact(&mut buf)
301 .map_err(From::from)
302 .map(|_| buf[0])
303 }
304}
305
306impl<T> io::Seek for TCompactInputProtocol<T>
307where
308 T: io::Seek + Read,
309{
310 fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
311 self.transport.seek(pos)
312 }
313}
314
315#[derive(Debug)]
334pub struct TCompactOutputProtocol<T>
335where
336 T: Write,
337{
338 last_write_field_id: i16,
340 write_field_id_stack: Vec<i16>,
342 pending_write_bool_field_identifier: Option<TFieldIdentifier>,
345 transport: T,
347}
348
349impl<T> TCompactOutputProtocol<T>
350where
351 T: Write,
352{
353 pub fn new(transport: T) -> TCompactOutputProtocol<T> {
355 TCompactOutputProtocol {
356 last_write_field_id: 0,
357 write_field_id_stack: Vec::new(),
358 pending_write_bool_field_identifier: None,
359 transport,
360 }
361 }
362
363 fn write_field_header(&mut self, field_type: u8, field_id: i16) -> Result<usize> {
365 let mut written = 0;
366
367 let field_delta = field_id - self.last_write_field_id;
368 if field_delta > 0 && field_delta < 15 {
369 written += self.write_byte(((field_delta as u8) << 4) | field_type)?;
370 } else {
371 written += self.write_byte(field_type)?;
372 written += self.write_i16(field_id)?;
373 }
374 self.last_write_field_id = field_id;
375 Ok(written)
376 }
377
378 fn write_list_set_begin(&mut self, element_type: TType, element_count: i32) -> Result<usize> {
379 let mut written = 0;
380
381 let elem_identifier = collection_type_to_u8(element_type);
382 if element_count <= 14 {
383 let header = (element_count as u8) << 4 | elem_identifier;
384 written += self.write_byte(header)?;
385 } else {
386 let header = 0xF0 | elem_identifier;
387 written += self.write_byte(header)?;
388 written += self.transport.write_varint(element_count as u32)?;
391 }
392 Ok(written)
393 }
394
395 fn assert_no_pending_bool_write(&self) {
396 if let Some(ref f) = self.pending_write_bool_field_identifier {
397 panic!("pending bool field {:?} not written", f)
398 }
399 }
400}
401
402impl<T> TOutputProtocol for TCompactOutputProtocol<T>
403where
404 T: Write,
405{
406 fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> Result<usize> {
407 let mut written = 0;
408 written += self.write_byte(COMPACT_PROTOCOL_ID)?;
409 written += self.write_byte((u8::from(identifier.message_type) << 5) | COMPACT_VERSION)?;
410 written += self
412 .transport
413 .write_varint(identifier.sequence_number as u32)?;
414 written += self.write_string(&identifier.name)?;
415 Ok(written)
416 }
417
418 fn write_message_end(&mut self) -> Result<usize> {
419 self.assert_no_pending_bool_write();
420 Ok(0)
421 }
422
423 fn write_struct_begin(&mut self, _: &TStructIdentifier) -> Result<usize> {
424 self.write_field_id_stack.push(self.last_write_field_id);
425 self.last_write_field_id = 0;
426 Ok(0)
427 }
428
429 fn write_struct_end(&mut self) -> Result<usize> {
430 self.assert_no_pending_bool_write();
431 self.last_write_field_id = self
432 .write_field_id_stack
433 .pop()
434 .expect("should have previous field ids");
435 Ok(0)
436 }
437
438 fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> Result<usize> {
439 match identifier.field_type {
440 TType::Bool => {
441 if self.pending_write_bool_field_identifier.is_some() {
442 panic!(
443 "should not have a pending bool while writing another bool with id: \
444 {:?}",
445 identifier
446 )
447 }
448 self.pending_write_bool_field_identifier = Some(identifier.clone());
449 Ok(0)
450 }
451 _ => {
452 let field_type = type_to_u8(identifier.field_type);
453 let field_id = identifier.id.expect("non-stop field should have field id");
454 self.write_field_header(field_type, field_id)
455 }
456 }
457 }
458
459 fn write_field_end(&mut self) -> Result<usize> {
460 self.assert_no_pending_bool_write();
461 Ok(0)
462 }
463
464 fn write_field_stop(&mut self) -> Result<usize> {
465 self.assert_no_pending_bool_write();
466 self.write_byte(type_to_u8(TType::Stop))
467 }
468
469 fn write_bool(&mut self, b: bool) -> Result<usize> {
470 match self.pending_write_bool_field_identifier.take() {
471 Some(pending) => {
472 let field_id = pending.id.expect("bool field should have a field id");
473 let field_type_as_u8 = if b { 0x01 } else { 0x02 };
474 self.write_field_header(field_type_as_u8, field_id)
475 }
476 None => {
477 if b {
478 self.write_byte(0x01)
479 } else {
480 self.write_byte(0x02)
481 }
482 }
483 }
484 }
485
486 fn write_bytes(&mut self, b: &[u8]) -> Result<usize> {
487 let mut written = 0;
488 written += self.transport.write_varint(b.len() as u32)?;
491 self.transport.write_all(b)?;
492 written += b.len();
493 Ok(written)
494 }
495
496 fn write_i8(&mut self, i: i8) -> Result<usize> {
497 self.write_byte(i as u8)
498 }
499
500 fn write_i16(&mut self, i: i16) -> Result<usize> {
501 self.transport.write_varint(i).map_err(From::from)
502 }
503
504 fn write_i32(&mut self, i: i32) -> Result<usize> {
505 self.transport.write_varint(i).map_err(From::from)
506 }
507
508 fn write_i64(&mut self, i: i64) -> Result<usize> {
509 self.transport.write_varint(i).map_err(From::from)
510 }
511
512 fn write_double(&mut self, d: f64) -> Result<usize> {
513 let bytes = d.to_le_bytes();
514 self.transport.write_all(&bytes)?;
515 Ok(8)
516 }
517
518 fn write_string(&mut self, s: &str) -> Result<usize> {
519 self.write_bytes(s.as_bytes())
520 }
521
522 fn write_list_begin(&mut self, identifier: &TListIdentifier) -> Result<usize> {
523 self.write_list_set_begin(identifier.element_type, identifier.size)
524 }
525
526 fn write_list_end(&mut self) -> Result<usize> {
527 Ok(0)
528 }
529
530 fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> Result<usize> {
531 self.write_list_set_begin(identifier.element_type, identifier.size)
532 }
533
534 fn write_set_end(&mut self) -> Result<usize> {
535 Ok(0)
536 }
537
538 fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> Result<usize> {
539 if identifier.size == 0 {
540 self.write_byte(0)
541 } else {
542 let mut written = 0;
543 written += self.transport.write_varint(identifier.size as u32)?;
546
547 let key_type = identifier
548 .key_type
549 .expect("map identifier to write should contain key type");
550 let key_type_byte = collection_type_to_u8(key_type) << 4;
551
552 let val_type = identifier
553 .value_type
554 .expect("map identifier to write should contain value type");
555 let val_type_byte = collection_type_to_u8(val_type);
556
557 let map_type_header = key_type_byte | val_type_byte;
558 written += self.write_byte(map_type_header)?;
559 Ok(written)
560 }
561 }
562
563 fn write_map_end(&mut self) -> Result<usize> {
564 Ok(0)
565 }
566
567 fn flush(&mut self) -> Result<()> {
568 self.transport.flush().map_err(From::from)
569 }
570
571 fn write_byte(&mut self, b: u8) -> Result<usize> {
575 self.transport.write(&[b]).map_err(From::from)
576 }
577}
578
579pub(super) fn collection_type_to_u8(field_type: TType) -> u8 {
580 match field_type {
581 TType::Bool => 0x01,
582 f => type_to_u8(f),
583 }
584}
585
586pub(super) fn type_to_u8(field_type: TType) -> u8 {
587 match field_type {
588 TType::Stop => 0x00,
589 TType::I08 => 0x03, TType::I16 => 0x04,
591 TType::I32 => 0x05,
592 TType::I64 => 0x06,
593 TType::Double => 0x07,
594 TType::String => 0x08,
595 TType::List => 0x09,
596 TType::Set => 0x0A,
597 TType::Map => 0x0B,
598 TType::Struct => 0x0C,
599 _ => panic!("should not have attempted to convert {} to u8", field_type),
600 }
601}
602
603pub(super) fn collection_u8_to_type(b: u8) -> Result<TType> {
604 match b {
605 0x01 => Ok(TType::Bool),
606 o => u8_to_type(o),
607 }
608}
609
610pub(super) fn u8_to_type(b: u8) -> Result<TType> {
611 match b {
612 0x00 => Ok(TType::Stop),
613 0x03 => Ok(TType::I08), 0x04 => Ok(TType::I16),
615 0x05 => Ok(TType::I32),
616 0x06 => Ok(TType::I64),
617 0x07 => Ok(TType::Double),
618 0x08 => Ok(TType::String),
619 0x09 => Ok(TType::List),
620 0x0A => Ok(TType::Set),
621 0x0B => Ok(TType::Map),
622 0x0C => Ok(TType::Struct),
623 unkn => Err(Error::Protocol(ProtocolError {
624 kind: ProtocolErrorKind::InvalidData,
625 message: format!("cannot convert {} into TType", unkn),
626 })),
627 }
628}