parquet_format_async_temp/thrift/protocol/
compact_stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::convert::{From, TryFrom};
19use std::io;
20
21use async_trait::async_trait;
22use futures::{AsyncRead, AsyncReadExt, AsyncSeek};
23use integer_encoding::VarIntAsyncReader;
24
25use super::compact::{
26    collection_u8_to_type, u8_to_type, COMPACT_PROTOCOL_ID, COMPACT_VERSION, COMPACT_VERSION_MASK,
27};
28use super::{
29    TFieldIdentifier, TInputStreamProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
30    TMessageType,
31};
32use super::{TSetIdentifier, TStructIdentifier, TType};
33use crate::thrift::{Error, ProtocolError, ProtocolErrorKind, Result};
34
/// Input side of the Thrift compact protocol, reading from an asynchronous
/// transport `T`.
///
/// The struct itself places no bounds on `T`; each `impl` block states the
/// bounds it actually needs.
#[derive(Debug)]
pub struct TCompactInputStreamProtocol<T> {
    /// Identifier of the last field deserialized for a struct.
    last_read_field_id: i16,
    /// Stack of the last read field ids (a new entry is added each time a
    /// nested struct is read).
    read_field_id_stack: Vec<i16>,
    /// Boolean value for a field.
    ///
    /// Saved because boolean fields and their value are encoded in a single
    /// byte, and reading the field only occurs after the field id is read.
    pending_read_bool_value: Option<bool>,
    /// Underlying transport used for byte-level operations.
    transport: T,
}
48
49impl<T: VarIntAsyncReader + AsyncRead + Unpin + Send> TCompactInputStreamProtocol<T> {
50    /// Create a `TCompactInputProtocol` that reads bytes from `transport`.
51    pub fn new(transport: T) -> Self {
52        Self {
53            last_read_field_id: 0,
54            read_field_id_stack: Vec::new(),
55            pending_read_bool_value: None,
56            transport,
57        }
58    }
59
60    async fn read_list_set_begin(&mut self) -> Result<(TType, i32)> {
61        let header = self.read_byte().await?;
62        let element_type = collection_u8_to_type(header & 0x0F)?;
63
64        let possible_element_count = (header & 0xF0) >> 4;
65        let element_count = if possible_element_count != 15 {
66            // high bits set high if count and type encoded separately
67            possible_element_count as i32
68        } else {
69            self.transport.read_varint_async::<u32>().await? as i32
70        };
71
72        Ok((element_type, element_count))
73    }
74}
75
76#[async_trait]
77impl<T: VarIntAsyncReader + AsyncRead + Unpin + Send> TInputStreamProtocol
78    for TCompactInputStreamProtocol<T>
79{
80    async fn read_message_begin(&mut self) -> Result<TMessageIdentifier> {
81        let compact_id = self.read_byte().await?;
82        if compact_id != COMPACT_PROTOCOL_ID {
83            Err(Error::Protocol(ProtocolError {
84                kind: ProtocolErrorKind::BadVersion,
85                message: format!("invalid compact protocol header {:?}", compact_id),
86            }))
87        } else {
88            Ok(())
89        }?;
90
91        let type_and_byte = self.read_byte().await?;
92        let received_version = type_and_byte & COMPACT_VERSION_MASK;
93        if received_version != COMPACT_VERSION {
94            Err(Error::Protocol(ProtocolError {
95                kind: ProtocolErrorKind::BadVersion,
96                message: format!(
97                    "cannot process compact protocol version {:?}",
98                    received_version
99                ),
100            }))
101        } else {
102            Ok(())
103        }?;
104
105        // NOTE: unsigned right shift will pad with 0s
106        let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?;
107        // writing side wrote signed sequence number as u32 to avoid zigzag encoding
108        let sequence_number = self.transport.read_varint_async::<u32>().await? as i32;
109        let service_call_name = self.read_string().await?;
110
111        self.last_read_field_id = 0;
112
113        Ok(TMessageIdentifier::new(
114            service_call_name,
115            message_type,
116            sequence_number,
117        ))
118    }
119
120    async fn read_message_end(&mut self) -> Result<()> {
121        Ok(())
122    }
123
124    async fn read_struct_begin(&mut self) -> Result<Option<TStructIdentifier>> {
125        self.read_field_id_stack.push(self.last_read_field_id);
126        self.last_read_field_id = 0;
127        Ok(None)
128    }
129
130    async fn read_struct_end(&mut self) -> Result<()> {
131        self.last_read_field_id = self
132            .read_field_id_stack
133            .pop()
134            .expect("should have previous field ids");
135        Ok(())
136    }
137
138    async fn read_field_begin(&mut self) -> Result<TFieldIdentifier> {
139        // we can read at least one byte, which is:
140        // - the type
141        // - the field delta and the type
142        let field_type = self.read_byte().await?;
143        let field_delta = (field_type & 0xF0) >> 4;
144        let field_type = match field_type & 0x0F {
145            0x01 => {
146                self.pending_read_bool_value = Some(true);
147                Ok(TType::Bool)
148            }
149            0x02 => {
150                self.pending_read_bool_value = Some(false);
151                Ok(TType::Bool)
152            }
153            ttu8 => u8_to_type(ttu8),
154        }?;
155
156        match field_type {
157            TType::Stop => Ok(
158                TFieldIdentifier::new::<Option<String>, String, Option<i16>>(
159                    None,
160                    TType::Stop,
161                    None,
162                ),
163            ),
164            _ => {
165                if field_delta != 0 {
166                    self.last_read_field_id += field_delta as i16;
167                } else {
168                    self.last_read_field_id = self.read_i16().await?;
169                };
170
171                Ok(TFieldIdentifier {
172                    name: None,
173                    field_type,
174                    id: Some(self.last_read_field_id),
175                })
176            }
177        }
178    }
179
180    async fn read_field_end(&mut self) -> Result<()> {
181        Ok(())
182    }
183
184    async fn read_bool(&mut self) -> Result<bool> {
185        match self.pending_read_bool_value.take() {
186            Some(b) => Ok(b),
187            None => {
188                let b = self.read_byte().await?;
189                match b {
190                    0x01 => Ok(true),
191                    0x02 => Ok(false),
192                    unkn => Err(Error::Protocol(ProtocolError {
193                        kind: ProtocolErrorKind::InvalidData,
194                        message: format!("cannot convert {} into bool", unkn),
195                    })),
196                }
197            }
198        }
199    }
200
201    async fn read_bytes(&mut self) -> Result<Vec<u8>> {
202        let len = self.transport.read_varint_async::<u32>().await?;
203        let mut buf = vec![0u8; len as usize];
204        self.transport
205            .read_exact(&mut buf)
206            .await
207            .map_err(From::from)
208            .map(|_| buf)
209    }
210
211    async fn read_i8(&mut self) -> Result<i8> {
212        self.read_byte().await.map(|i| i as i8)
213    }
214
215    async fn read_i16(&mut self) -> Result<i16> {
216        self.transport
217            .read_varint_async::<i16>()
218            .await
219            .map_err(From::from)
220    }
221
222    async fn read_i32(&mut self) -> Result<i32> {
223        self.transport
224            .read_varint_async::<i32>()
225            .await
226            .map_err(From::from)
227    }
228
229    async fn read_i64(&mut self) -> Result<i64> {
230        self.transport
231            .read_varint_async::<i64>()
232            .await
233            .map_err(From::from)
234    }
235
236    async fn read_double(&mut self) -> Result<f64> {
237        let mut buf = [0; 8];
238        self.transport.read_exact(&mut buf).await?;
239        let r = f64::from_le_bytes(buf);
240        Ok(r)
241    }
242
243    async fn read_string(&mut self) -> Result<String> {
244        let bytes = self.read_bytes().await?;
245        String::from_utf8(bytes).map_err(From::from)
246    }
247
248    async fn read_list_begin(&mut self) -> Result<TListIdentifier> {
249        let (element_type, element_count) = self.read_list_set_begin().await?;
250        Ok(TListIdentifier::new(element_type, element_count))
251    }
252
253    async fn read_list_end(&mut self) -> Result<()> {
254        Ok(())
255    }
256
257    async fn read_set_begin(&mut self) -> Result<TSetIdentifier> {
258        let (element_type, element_count) = self.read_list_set_begin().await?;
259        Ok(TSetIdentifier::new(element_type, element_count))
260    }
261
262    async fn read_set_end(&mut self) -> Result<()> {
263        Ok(())
264    }
265
266    async fn read_map_begin(&mut self) -> Result<TMapIdentifier> {
267        let element_count = self.transport.read_varint_async::<u32>().await? as i32;
268        if element_count == 0 {
269            Ok(TMapIdentifier::new(None, None, 0))
270        } else {
271            let type_header = self.read_byte().await?;
272            let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?;
273            let val_type = collection_u8_to_type(type_header & 0x0F)?;
274            Ok(TMapIdentifier::new(key_type, val_type, element_count))
275        }
276    }
277
278    async fn read_map_end(&mut self) -> Result<()> {
279        Ok(())
280    }
281
282    // utility
283    //
284
285    async fn read_byte(&mut self) -> Result<u8> {
286        let mut buf = [0u8; 1];
287        self.transport
288            .read_exact(&mut buf)
289            .await
290            .map_err(From::from)
291            .map(|_| buf[0])
292    }
293}
294
295impl<T> AsyncSeek for TCompactInputStreamProtocol<T>
296where
297    T: AsyncSeek + Unpin + Send,
298{
299    fn poll_seek(
300        mut self: std::pin::Pin<&mut Self>,
301        cx: &mut std::task::Context<'_>,
302        pos: io::SeekFrom,
303    ) -> std::task::Poll<io::Result<u64>> {
304        std::pin::Pin::new(&mut self.transport).poll_seek(cx, pos)
305    }
306}