parquet_format_async_temp/thrift/protocol/
compact_stream.rs1use std::convert::{From, TryFrom};
19use std::io;
20
21use async_trait::async_trait;
22use futures::{AsyncRead, AsyncReadExt, AsyncSeek};
23use integer_encoding::VarIntAsyncReader;
24
25use super::compact::{
26 collection_u8_to_type, u8_to_type, COMPACT_PROTOCOL_ID, COMPACT_VERSION, COMPACT_VERSION_MASK,
27};
28use super::{
29 TFieldIdentifier, TInputStreamProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
30 TMessageType,
31};
32use super::{TSetIdentifier, TStructIdentifier, TType};
33use crate::thrift::{Error, ProtocolError, ProtocolErrorKind, Result};
34
/// State for reading values encoded with the Thrift compact protocol from an
/// asynchronous transport.
#[derive(Debug)]
pub struct TCompactInputStreamProtocol<T>
where
    T: Send,
{
    /// Id of the most recently read field of the struct currently being read.
    /// A field header may carry only a delta that is added to this value.
    last_read_field_id: i16,
    /// Values of `last_read_field_id` saved by `read_struct_begin` and
    /// restored by `read_struct_end`, one entry per struct being read.
    read_field_id_stack: Vec<i16>,
    /// Boolean decoded from a field header by `read_field_begin`; the next
    /// `read_bool` call returns it instead of reading from the transport.
    pending_read_bool_value: Option<bool>,
    /// Source the protocol reads bytes and varints from.
    transport: T,
}
48
49impl<T: VarIntAsyncReader + AsyncRead + Unpin + Send> TCompactInputStreamProtocol<T> {
50 pub fn new(transport: T) -> Self {
52 Self {
53 last_read_field_id: 0,
54 read_field_id_stack: Vec::new(),
55 pending_read_bool_value: None,
56 transport,
57 }
58 }
59
60 async fn read_list_set_begin(&mut self) -> Result<(TType, i32)> {
61 let header = self.read_byte().await?;
62 let element_type = collection_u8_to_type(header & 0x0F)?;
63
64 let possible_element_count = (header & 0xF0) >> 4;
65 let element_count = if possible_element_count != 15 {
66 possible_element_count as i32
68 } else {
69 self.transport.read_varint_async::<u32>().await? as i32
70 };
71
72 Ok((element_type, element_count))
73 }
74}
75
76#[async_trait]
77impl<T: VarIntAsyncReader + AsyncRead + Unpin + Send> TInputStreamProtocol
78 for TCompactInputStreamProtocol<T>
79{
80 async fn read_message_begin(&mut self) -> Result<TMessageIdentifier> {
81 let compact_id = self.read_byte().await?;
82 if compact_id != COMPACT_PROTOCOL_ID {
83 Err(Error::Protocol(ProtocolError {
84 kind: ProtocolErrorKind::BadVersion,
85 message: format!("invalid compact protocol header {:?}", compact_id),
86 }))
87 } else {
88 Ok(())
89 }?;
90
91 let type_and_byte = self.read_byte().await?;
92 let received_version = type_and_byte & COMPACT_VERSION_MASK;
93 if received_version != COMPACT_VERSION {
94 Err(Error::Protocol(ProtocolError {
95 kind: ProtocolErrorKind::BadVersion,
96 message: format!(
97 "cannot process compact protocol version {:?}",
98 received_version
99 ),
100 }))
101 } else {
102 Ok(())
103 }?;
104
105 let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?;
107 let sequence_number = self.transport.read_varint_async::<u32>().await? as i32;
109 let service_call_name = self.read_string().await?;
110
111 self.last_read_field_id = 0;
112
113 Ok(TMessageIdentifier::new(
114 service_call_name,
115 message_type,
116 sequence_number,
117 ))
118 }
119
120 async fn read_message_end(&mut self) -> Result<()> {
121 Ok(())
122 }
123
124 async fn read_struct_begin(&mut self) -> Result<Option<TStructIdentifier>> {
125 self.read_field_id_stack.push(self.last_read_field_id);
126 self.last_read_field_id = 0;
127 Ok(None)
128 }
129
130 async fn read_struct_end(&mut self) -> Result<()> {
131 self.last_read_field_id = self
132 .read_field_id_stack
133 .pop()
134 .expect("should have previous field ids");
135 Ok(())
136 }
137
138 async fn read_field_begin(&mut self) -> Result<TFieldIdentifier> {
139 let field_type = self.read_byte().await?;
143 let field_delta = (field_type & 0xF0) >> 4;
144 let field_type = match field_type & 0x0F {
145 0x01 => {
146 self.pending_read_bool_value = Some(true);
147 Ok(TType::Bool)
148 }
149 0x02 => {
150 self.pending_read_bool_value = Some(false);
151 Ok(TType::Bool)
152 }
153 ttu8 => u8_to_type(ttu8),
154 }?;
155
156 match field_type {
157 TType::Stop => Ok(
158 TFieldIdentifier::new::<Option<String>, String, Option<i16>>(
159 None,
160 TType::Stop,
161 None,
162 ),
163 ),
164 _ => {
165 if field_delta != 0 {
166 self.last_read_field_id += field_delta as i16;
167 } else {
168 self.last_read_field_id = self.read_i16().await?;
169 };
170
171 Ok(TFieldIdentifier {
172 name: None,
173 field_type,
174 id: Some(self.last_read_field_id),
175 })
176 }
177 }
178 }
179
180 async fn read_field_end(&mut self) -> Result<()> {
181 Ok(())
182 }
183
184 async fn read_bool(&mut self) -> Result<bool> {
185 match self.pending_read_bool_value.take() {
186 Some(b) => Ok(b),
187 None => {
188 let b = self.read_byte().await?;
189 match b {
190 0x01 => Ok(true),
191 0x02 => Ok(false),
192 unkn => Err(Error::Protocol(ProtocolError {
193 kind: ProtocolErrorKind::InvalidData,
194 message: format!("cannot convert {} into bool", unkn),
195 })),
196 }
197 }
198 }
199 }
200
201 async fn read_bytes(&mut self) -> Result<Vec<u8>> {
202 let len = self.transport.read_varint_async::<u32>().await?;
203 let mut buf = vec![0u8; len as usize];
204 self.transport
205 .read_exact(&mut buf)
206 .await
207 .map_err(From::from)
208 .map(|_| buf)
209 }
210
211 async fn read_i8(&mut self) -> Result<i8> {
212 self.read_byte().await.map(|i| i as i8)
213 }
214
215 async fn read_i16(&mut self) -> Result<i16> {
216 self.transport
217 .read_varint_async::<i16>()
218 .await
219 .map_err(From::from)
220 }
221
222 async fn read_i32(&mut self) -> Result<i32> {
223 self.transport
224 .read_varint_async::<i32>()
225 .await
226 .map_err(From::from)
227 }
228
229 async fn read_i64(&mut self) -> Result<i64> {
230 self.transport
231 .read_varint_async::<i64>()
232 .await
233 .map_err(From::from)
234 }
235
236 async fn read_double(&mut self) -> Result<f64> {
237 let mut buf = [0; 8];
238 self.transport.read_exact(&mut buf).await?;
239 let r = f64::from_le_bytes(buf);
240 Ok(r)
241 }
242
243 async fn read_string(&mut self) -> Result<String> {
244 let bytes = self.read_bytes().await?;
245 String::from_utf8(bytes).map_err(From::from)
246 }
247
248 async fn read_list_begin(&mut self) -> Result<TListIdentifier> {
249 let (element_type, element_count) = self.read_list_set_begin().await?;
250 Ok(TListIdentifier::new(element_type, element_count))
251 }
252
253 async fn read_list_end(&mut self) -> Result<()> {
254 Ok(())
255 }
256
257 async fn read_set_begin(&mut self) -> Result<TSetIdentifier> {
258 let (element_type, element_count) = self.read_list_set_begin().await?;
259 Ok(TSetIdentifier::new(element_type, element_count))
260 }
261
262 async fn read_set_end(&mut self) -> Result<()> {
263 Ok(())
264 }
265
266 async fn read_map_begin(&mut self) -> Result<TMapIdentifier> {
267 let element_count = self.transport.read_varint_async::<u32>().await? as i32;
268 if element_count == 0 {
269 Ok(TMapIdentifier::new(None, None, 0))
270 } else {
271 let type_header = self.read_byte().await?;
272 let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?;
273 let val_type = collection_u8_to_type(type_header & 0x0F)?;
274 Ok(TMapIdentifier::new(key_type, val_type, element_count))
275 }
276 }
277
278 async fn read_map_end(&mut self) -> Result<()> {
279 Ok(())
280 }
281
282 async fn read_byte(&mut self) -> Result<u8> {
286 let mut buf = [0u8; 1];
287 self.transport
288 .read_exact(&mut buf)
289 .await
290 .map_err(From::from)
291 .map(|_| buf[0])
292 }
293}
294
295impl<T> AsyncSeek for TCompactInputStreamProtocol<T>
296where
297 T: AsyncSeek + Unpin + Send,
298{
299 fn poll_seek(
300 mut self: std::pin::Pin<&mut Self>,
301 cx: &mut std::task::Context<'_>,
302 pos: io::SeekFrom,
303 ) -> std::task::Poll<io::Result<u64>> {
304 std::pin::Pin::new(&mut self.transport).poll_seek(cx, pos)
305 }
306}