parquet_format_async_temp/thrift/protocol/compact.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use std::convert::{From, TryFrom};
18use std::io;
19use std::io::{Read, Write};
20
21use integer_encoding::{VarIntReader, VarIntWriter};
22
23use super::super::{Error, ProtocolError, ProtocolErrorKind, Result};
24use super::{
25    TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
26    TMessageType,
27};
28use super::{TOutputProtocol, TSetIdentifier, TStructIdentifier, TType};
29
/// First byte of every compact-protocol message header; `read_message_begin`
/// rejects any other value with `BadVersion`.
pub(super) const COMPACT_PROTOCOL_ID: u8 = 0x82;
/// Protocol version stored in the low bits of the second header byte
/// (the top 3 bits of that byte hold the message type).
pub(super) const COMPACT_VERSION: u8 = 0x01;
/// Mask that extracts the version bits from the second header byte.
pub(super) const COMPACT_VERSION_MASK: u8 = 0x1F;
33
// NOTE(review): the example below imports from the upstream `thrift` crate
// (`thrift::protocol`, `thrift::transport::TTcpChannel`); confirm those paths
// exist for this crate, otherwise the `no_run` doctest will not compile.
/// Read messages encoded in the Thrift compact protocol.
///
/// Field ids are delta-encoded relative to the previous field of the same
/// struct, so the reader tracks the last id it read and keeps a stack of
/// those ids across nested structs.
///
/// # Examples
///
/// Create and use a `TCompactInputProtocol`.
///
/// ```no_run
/// use thrift::protocol::{TCompactInputProtocol, TInputProtocol};
/// use thrift::transport::TTcpChannel;
///
/// let mut channel = TTcpChannel::new();
/// channel.open("localhost:9090").unwrap();
///
/// let mut protocol = TCompactInputProtocol::new(channel);
///
/// let recvd_bool = protocol.read_bool().unwrap();
/// let recvd_string = protocol.read_string().unwrap();
/// ```
#[derive(Debug)]
pub struct TCompactInputProtocol<T>
where
    T: Read,
{
    // Identifier of the last field deserialized for the struct currently being
    // read; short-form field headers carry only a delta from this value.
    last_read_field_id: i16,
    // Saved `last_read_field_id` of each enclosing struct: pushed by
    // `read_struct_begin` and popped by `read_struct_end`.
    read_field_id_stack: Vec<i16>,
    // Value of a boolean field whose header has been read but whose value has
    // not been consumed yet. Bool fields store their value in the field
    // header's type nibble, so `read_field_begin` records it here and
    // `read_bool` returns (and clears) it instead of reading another byte.
    pending_read_bool_value: Option<bool>,
    // Underlying transport used for byte-level operations.
    transport: T,
}
68
69impl<T> TCompactInputProtocol<T>
70where
71    T: Read,
72{
73    /// Create a `TCompactInputProtocol` that reads bytes from `transport`.
74    pub fn new(transport: T) -> TCompactInputProtocol<T> {
75        TCompactInputProtocol {
76            last_read_field_id: 0,
77            read_field_id_stack: Vec::new(),
78            pending_read_bool_value: None,
79            transport,
80        }
81    }
82
83    fn read_list_set_begin(&mut self) -> Result<(TType, i32)> {
84        let header = self.read_byte()?;
85        let element_type = collection_u8_to_type(header & 0x0F)?;
86
87        let possible_element_count = (header & 0xF0) >> 4;
88        let element_count = if possible_element_count != 15 {
89            // high bits set high if count and type encoded separately
90            possible_element_count as i32
91        } else {
92            self.transport.read_varint::<u32>()? as i32
93        };
94
95        Ok((element_type, element_count))
96    }
97}
98
99impl<T> TInputProtocol for TCompactInputProtocol<T>
100where
101    T: Read,
102{
103    fn read_message_begin(&mut self) -> Result<TMessageIdentifier> {
104        let compact_id = self.read_byte()?;
105        if compact_id != COMPACT_PROTOCOL_ID {
106            Err(Error::Protocol(ProtocolError {
107                kind: ProtocolErrorKind::BadVersion,
108                message: format!("invalid compact protocol header {:?}", compact_id),
109            }))
110        } else {
111            Ok(())
112        }?;
113
114        let type_and_byte = self.read_byte()?;
115        let received_version = type_and_byte & COMPACT_VERSION_MASK;
116        if received_version != COMPACT_VERSION {
117            Err(Error::Protocol(ProtocolError {
118                kind: ProtocolErrorKind::BadVersion,
119                message: format!(
120                    "cannot process compact protocol version {:?}",
121                    received_version
122                ),
123            }))
124        } else {
125            Ok(())
126        }?;
127
128        // NOTE: unsigned right shift will pad with 0s
129        let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?;
130        // writing side wrote signed sequence number as u32 to avoid zigzag encoding
131        let sequence_number = self.transport.read_varint::<u32>()? as i32;
132        let service_call_name = self.read_string()?;
133
134        self.last_read_field_id = 0;
135
136        Ok(TMessageIdentifier::new(
137            service_call_name,
138            message_type,
139            sequence_number,
140        ))
141    }
142
143    fn read_message_end(&mut self) -> Result<()> {
144        Ok(())
145    }
146
147    fn read_struct_begin(&mut self) -> Result<Option<TStructIdentifier>> {
148        self.read_field_id_stack.push(self.last_read_field_id);
149        self.last_read_field_id = 0;
150        Ok(None)
151    }
152
153    fn read_struct_end(&mut self) -> Result<()> {
154        self.last_read_field_id = self
155            .read_field_id_stack
156            .pop()
157            .expect("should have previous field ids");
158        Ok(())
159    }
160
161    fn read_field_begin(&mut self) -> Result<TFieldIdentifier> {
162        // we can read at least one byte, which is:
163        // - the type
164        // - the field delta and the type
165        let field_type = self.read_byte()?;
166        let field_delta = (field_type & 0xF0) >> 4;
167        let field_type = match field_type & 0x0F {
168            0x01 => {
169                self.pending_read_bool_value = Some(true);
170                Ok(TType::Bool)
171            }
172            0x02 => {
173                self.pending_read_bool_value = Some(false);
174                Ok(TType::Bool)
175            }
176            ttu8 => u8_to_type(ttu8),
177        }?;
178
179        match field_type {
180            TType::Stop => Ok(
181                TFieldIdentifier::new::<Option<String>, String, Option<i16>>(
182                    None,
183                    TType::Stop,
184                    None,
185                ),
186            ),
187            _ => {
188                if field_delta != 0 {
189                    self.last_read_field_id += field_delta as i16;
190                } else {
191                    self.last_read_field_id = self.read_i16()?;
192                };
193
194                Ok(TFieldIdentifier {
195                    name: None,
196                    field_type,
197                    id: Some(self.last_read_field_id),
198                })
199            }
200        }
201    }
202
203    fn read_field_end(&mut self) -> Result<()> {
204        Ok(())
205    }
206
207    fn read_bool(&mut self) -> Result<bool> {
208        match self.pending_read_bool_value.take() {
209            Some(b) => Ok(b),
210            None => {
211                let b = self.read_byte()?;
212                match b {
213                    0x01 => Ok(true),
214                    0x02 => Ok(false),
215                    unkn => Err(Error::Protocol(ProtocolError {
216                        kind: ProtocolErrorKind::InvalidData,
217                        message: format!("cannot convert {} into bool", unkn),
218                    })),
219                }
220            }
221        }
222    }
223
224    fn read_bytes(&mut self) -> Result<Vec<u8>> {
225        let len = self.transport.read_varint::<u32>()?;
226        let mut buf = vec![0u8; len as usize];
227        self.transport
228            .read_exact(&mut buf)
229            .map_err(From::from)
230            .map(|_| buf)
231    }
232
233    fn read_i8(&mut self) -> Result<i8> {
234        self.read_byte().map(|i| i as i8)
235    }
236
237    fn read_i16(&mut self) -> Result<i16> {
238        self.transport.read_varint::<i16>().map_err(From::from)
239    }
240
241    fn read_i32(&mut self) -> Result<i32> {
242        self.transport.read_varint::<i32>().map_err(From::from)
243    }
244
245    fn read_i64(&mut self) -> Result<i64> {
246        self.transport.read_varint::<i64>().map_err(From::from)
247    }
248
249    fn read_double(&mut self) -> Result<f64> {
250        let mut data = [0u8; 8];
251        self.transport.read_exact(&mut data)?;
252        Ok(f64::from_le_bytes(data))
253    }
254
255    fn read_string(&mut self) -> Result<String> {
256        let bytes = self.read_bytes()?;
257        String::from_utf8(bytes).map_err(From::from)
258    }
259
260    fn read_list_begin(&mut self) -> Result<TListIdentifier> {
261        let (element_type, element_count) = self.read_list_set_begin()?;
262        Ok(TListIdentifier::new(element_type, element_count))
263    }
264
265    fn read_list_end(&mut self) -> Result<()> {
266        Ok(())
267    }
268
269    fn read_set_begin(&mut self) -> Result<TSetIdentifier> {
270        let (element_type, element_count) = self.read_list_set_begin()?;
271        Ok(TSetIdentifier::new(element_type, element_count))
272    }
273
274    fn read_set_end(&mut self) -> Result<()> {
275        Ok(())
276    }
277
278    fn read_map_begin(&mut self) -> Result<TMapIdentifier> {
279        let element_count = self.transport.read_varint::<u32>()? as i32;
280        if element_count == 0 {
281            Ok(TMapIdentifier::new(None, None, 0))
282        } else {
283            let type_header = self.read_byte()?;
284            let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?;
285            let val_type = collection_u8_to_type(type_header & 0x0F)?;
286            Ok(TMapIdentifier::new(key_type, val_type, element_count))
287        }
288    }
289
290    fn read_map_end(&mut self) -> Result<()> {
291        Ok(())
292    }
293
294    // utility
295    //
296
297    fn read_byte(&mut self) -> Result<u8> {
298        let mut buf = [0u8; 1];
299        self.transport
300            .read_exact(&mut buf)
301            .map_err(From::from)
302            .map(|_| buf[0])
303    }
304}
305
306impl<T> io::Seek for TCompactInputProtocol<T>
307where
308    T: io::Seek + Read,
309{
310    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
311        self.transport.seek(pos)
312    }
313}
314
// NOTE(review): the example below imports from the upstream `thrift` crate
// (`thrift::protocol`, `thrift::transport::TTcpChannel`); confirm those paths
// exist for this crate, otherwise the `no_run` doctest will not compile.
/// Write messages using the Thrift compact protocol.
///
/// Field ids are written as deltas from the previous field of the same struct
/// when the delta is in `1..=14`, and as a full i16 otherwise. The writer
/// tracks the last id written and keeps a stack of those ids across nested
/// structs.
///
/// # Examples
///
/// Create and use a `TCompactOutputProtocol`.
///
/// ```no_run
/// use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol};
/// use thrift::transport::TTcpChannel;
///
/// let mut channel = TTcpChannel::new();
/// channel.open("localhost:9090").unwrap();
///
/// let mut protocol = TCompactOutputProtocol::new(channel);
///
/// protocol.write_bool(true).unwrap();
/// protocol.write_string("test_string").unwrap();
/// ```
#[derive(Debug)]
pub struct TCompactOutputProtocol<T>
where
    T: Write,
{
    // Identifier of the last field serialized for the struct currently being
    // written; the next field header is encoded relative to it.
    last_write_field_id: i16,
    // Saved `last_write_field_id` of each enclosing struct: pushed by
    // `write_struct_begin` and popped by `write_struct_end`.
    write_field_id_stack: Vec<i16>,
    // Field identifier of a bool field whose header has not been written yet.
    // A bool field's value is encoded in its header's type nibble, so
    // `write_field_begin` only records the identifier here and `write_bool`
    // writes the combined header once the value is known.
    pending_write_bool_field_identifier: Option<TFieldIdentifier>,
    // Underlying transport used for byte-level operations.
    transport: T,
}
348
349impl<T> TCompactOutputProtocol<T>
350where
351    T: Write,
352{
353    /// Create a `TCompactOutputProtocol` that writes bytes to `transport`.
354    pub fn new(transport: T) -> TCompactOutputProtocol<T> {
355        TCompactOutputProtocol {
356            last_write_field_id: 0,
357            write_field_id_stack: Vec::new(),
358            pending_write_bool_field_identifier: None,
359            transport,
360        }
361    }
362
363    // FIXME: field_type as unconstrained u8 is bad
364    fn write_field_header(&mut self, field_type: u8, field_id: i16) -> Result<usize> {
365        let mut written = 0;
366
367        let field_delta = field_id - self.last_write_field_id;
368        if field_delta > 0 && field_delta < 15 {
369            written += self.write_byte(((field_delta as u8) << 4) | field_type)?;
370        } else {
371            written += self.write_byte(field_type)?;
372            written += self.write_i16(field_id)?;
373        }
374        self.last_write_field_id = field_id;
375        Ok(written)
376    }
377
378    fn write_list_set_begin(&mut self, element_type: TType, element_count: i32) -> Result<usize> {
379        let mut written = 0;
380
381        let elem_identifier = collection_type_to_u8(element_type);
382        if element_count <= 14 {
383            let header = (element_count as u8) << 4 | elem_identifier;
384            written += self.write_byte(header)?;
385        } else {
386            let header = 0xF0 | elem_identifier;
387            written += self.write_byte(header)?;
388            // element count is strictly positive as per the spec, so
389            // cast i32 as u32 so that varint writing won't use zigzag encoding
390            written += self.transport.write_varint(element_count as u32)?;
391        }
392        Ok(written)
393    }
394
395    fn assert_no_pending_bool_write(&self) {
396        if let Some(ref f) = self.pending_write_bool_field_identifier {
397            panic!("pending bool field {:?} not written", f)
398        }
399    }
400}
401
402impl<T> TOutputProtocol for TCompactOutputProtocol<T>
403where
404    T: Write,
405{
406    fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> Result<usize> {
407        let mut written = 0;
408        written += self.write_byte(COMPACT_PROTOCOL_ID)?;
409        written += self.write_byte((u8::from(identifier.message_type) << 5) | COMPACT_VERSION)?;
410        // cast i32 as u32 so that varint writing won't use zigzag encoding
411        written += self
412            .transport
413            .write_varint(identifier.sequence_number as u32)?;
414        written += self.write_string(&identifier.name)?;
415        Ok(written)
416    }
417
418    fn write_message_end(&mut self) -> Result<usize> {
419        self.assert_no_pending_bool_write();
420        Ok(0)
421    }
422
423    fn write_struct_begin(&mut self, _: &TStructIdentifier) -> Result<usize> {
424        self.write_field_id_stack.push(self.last_write_field_id);
425        self.last_write_field_id = 0;
426        Ok(0)
427    }
428
429    fn write_struct_end(&mut self) -> Result<usize> {
430        self.assert_no_pending_bool_write();
431        self.last_write_field_id = self
432            .write_field_id_stack
433            .pop()
434            .expect("should have previous field ids");
435        Ok(0)
436    }
437
438    fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> Result<usize> {
439        match identifier.field_type {
440            TType::Bool => {
441                if self.pending_write_bool_field_identifier.is_some() {
442                    panic!(
443                        "should not have a pending bool while writing another bool with id: \
444                         {:?}",
445                        identifier
446                    )
447                }
448                self.pending_write_bool_field_identifier = Some(identifier.clone());
449                Ok(0)
450            }
451            _ => {
452                let field_type = type_to_u8(identifier.field_type);
453                let field_id = identifier.id.expect("non-stop field should have field id");
454                self.write_field_header(field_type, field_id)
455            }
456        }
457    }
458
459    fn write_field_end(&mut self) -> Result<usize> {
460        self.assert_no_pending_bool_write();
461        Ok(0)
462    }
463
464    fn write_field_stop(&mut self) -> Result<usize> {
465        self.assert_no_pending_bool_write();
466        self.write_byte(type_to_u8(TType::Stop))
467    }
468
469    fn write_bool(&mut self, b: bool) -> Result<usize> {
470        match self.pending_write_bool_field_identifier.take() {
471            Some(pending) => {
472                let field_id = pending.id.expect("bool field should have a field id");
473                let field_type_as_u8 = if b { 0x01 } else { 0x02 };
474                self.write_field_header(field_type_as_u8, field_id)
475            }
476            None => {
477                if b {
478                    self.write_byte(0x01)
479                } else {
480                    self.write_byte(0x02)
481                }
482            }
483        }
484    }
485
486    fn write_bytes(&mut self, b: &[u8]) -> Result<usize> {
487        let mut written = 0;
488        // length is strictly positive as per the spec, so
489        // cast i32 as u32 so that varint writing won't use zigzag encoding
490        written += self.transport.write_varint(b.len() as u32)?;
491        self.transport.write_all(b)?;
492        written += b.len();
493        Ok(written)
494    }
495
496    fn write_i8(&mut self, i: i8) -> Result<usize> {
497        self.write_byte(i as u8)
498    }
499
500    fn write_i16(&mut self, i: i16) -> Result<usize> {
501        self.transport.write_varint(i).map_err(From::from)
502    }
503
504    fn write_i32(&mut self, i: i32) -> Result<usize> {
505        self.transport.write_varint(i).map_err(From::from)
506    }
507
508    fn write_i64(&mut self, i: i64) -> Result<usize> {
509        self.transport.write_varint(i).map_err(From::from)
510    }
511
512    fn write_double(&mut self, d: f64) -> Result<usize> {
513        let bytes = d.to_le_bytes();
514        self.transport.write_all(&bytes)?;
515        Ok(8)
516    }
517
518    fn write_string(&mut self, s: &str) -> Result<usize> {
519        self.write_bytes(s.as_bytes())
520    }
521
522    fn write_list_begin(&mut self, identifier: &TListIdentifier) -> Result<usize> {
523        self.write_list_set_begin(identifier.element_type, identifier.size)
524    }
525
526    fn write_list_end(&mut self) -> Result<usize> {
527        Ok(0)
528    }
529
530    fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> Result<usize> {
531        self.write_list_set_begin(identifier.element_type, identifier.size)
532    }
533
534    fn write_set_end(&mut self) -> Result<usize> {
535        Ok(0)
536    }
537
538    fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> Result<usize> {
539        if identifier.size == 0 {
540            self.write_byte(0)
541        } else {
542            let mut written = 0;
543            // element count is strictly positive as per the spec, so
544            // cast i32 as u32 so that varint writing won't use zigzag encoding
545            written += self.transport.write_varint(identifier.size as u32)?;
546
547            let key_type = identifier
548                .key_type
549                .expect("map identifier to write should contain key type");
550            let key_type_byte = collection_type_to_u8(key_type) << 4;
551
552            let val_type = identifier
553                .value_type
554                .expect("map identifier to write should contain value type");
555            let val_type_byte = collection_type_to_u8(val_type);
556
557            let map_type_header = key_type_byte | val_type_byte;
558            written += self.write_byte(map_type_header)?;
559            Ok(written)
560        }
561    }
562
563    fn write_map_end(&mut self) -> Result<usize> {
564        Ok(0)
565    }
566
567    fn flush(&mut self) -> Result<()> {
568        self.transport.flush().map_err(From::from)
569    }
570
571    // utility
572    //
573
574    fn write_byte(&mut self, b: u8) -> Result<usize> {
575        self.transport.write(&[b]).map_err(From::from)
576    }
577}
578
579pub(super) fn collection_type_to_u8(field_type: TType) -> u8 {
580    match field_type {
581        TType::Bool => 0x01,
582        f => type_to_u8(f),
583    }
584}
585
586pub(super) fn type_to_u8(field_type: TType) -> u8 {
587    match field_type {
588        TType::Stop => 0x00,
589        TType::I08 => 0x03, // equivalent to TType::Byte
590        TType::I16 => 0x04,
591        TType::I32 => 0x05,
592        TType::I64 => 0x06,
593        TType::Double => 0x07,
594        TType::String => 0x08,
595        TType::List => 0x09,
596        TType::Set => 0x0A,
597        TType::Map => 0x0B,
598        TType::Struct => 0x0C,
599        _ => panic!("should not have attempted to convert {} to u8", field_type),
600    }
601}
602
603pub(super) fn collection_u8_to_type(b: u8) -> Result<TType> {
604    match b {
605        0x01 => Ok(TType::Bool),
606        o => u8_to_type(o),
607    }
608}
609
610pub(super) fn u8_to_type(b: u8) -> Result<TType> {
611    match b {
612        0x00 => Ok(TType::Stop),
613        0x03 => Ok(TType::I08), // equivalent to TType::Byte
614        0x04 => Ok(TType::I16),
615        0x05 => Ok(TType::I32),
616        0x06 => Ok(TType::I64),
617        0x07 => Ok(TType::Double),
618        0x08 => Ok(TType::String),
619        0x09 => Ok(TType::List),
620        0x0A => Ok(TType::Set),
621        0x0B => Ok(TType::Map),
622        0x0C => Ok(TType::Struct),
623        unkn => Err(Error::Protocol(ProtocolError {
624            kind: ProtocolErrorKind::InvalidData,
625            message: format!("cannot convert {} into TType", unkn),
626        })),
627    }
628}