compact_thrift_runtime/
protocol.rs

1use std::borrow::Cow;
2
3use std::io::ErrorKind;
4use std::io::Read;
5use std::io::Write;
6
7use std::str::from_utf8_unchecked;
8use std::str::from_utf8;
9
10use crate::ThriftError;
11use crate::uleb::*;
12
/// Largest binary/string payload, in bytes, that is read or written (1 MiB).
pub const MAX_BINARY_LEN: usize = 1 << 20;
/// Largest element count accepted in a list, set or map header.
pub const MAX_COLLECTION_LEN: usize = 10_000_000;
15
16#[inline(never)] // full field ids are uncommon and inlining this bloats the code
17fn read_full_field_id<'i, I: CompactThriftInput<'i> + ?Sized>(input: &mut I) -> Result<i16, ThriftError> {
18    let field_id = decode_uleb(input)? as u16;
19    Ok(zigzag_decode16(field_id))
20}
21
/// Inverse of zigzag encoding for 16-bit values: 0→0, 1→-1, 2→1, 3→-2, …
#[inline(always)]
fn zigzag_decode16(i: u16) -> i16 {
    let magnitude = (i >> 1) as i16;
    // all ones when the low (sign) bit is set, zero otherwise
    let sign_mask = ((i & 1) as i16).wrapping_neg();
    magnitude ^ sign_mask
}

/// Inverse of zigzag encoding for 32-bit values.
#[inline(always)]
fn zigzag_decode32(i: u32) -> i32 {
    let magnitude = (i >> 1) as i32;
    let sign_mask = ((i & 1) as i32).wrapping_neg();
    magnitude ^ sign_mask
}

/// Inverse of zigzag encoding for 64-bit values.
#[inline(always)]
fn zigzag_decode64(i: u64) -> i64 {
    let magnitude = (i >> 1) as i64;
    let sign_mask = ((i & 1) as i64).wrapping_neg();
    magnitude ^ sign_mask
}
36
/// Zigzag-encodes a 16-bit value so small magnitudes of either sign map to small unsigned
/// numbers: 0→0, -1→1, 1→2, -2→3, …
#[inline(always)]
fn zigzag_encode16(i: i16) -> u16 {
    let shifted = (i as u16) << 1;
    // arithmetic shift replicates the sign bit: all ones for negatives, zero otherwise
    let sign_mask = (i >> 15) as u16;
    shifted ^ sign_mask
}

/// Zigzag-encodes a 32-bit value.
#[inline(always)]
fn zigzag_encode32(i: i32) -> u32 {
    let shifted = (i as u32) << 1;
    let sign_mask = (i >> 31) as u32;
    shifted ^ sign_mask
}

/// Zigzag-encodes a 64-bit value.
#[inline(always)]
fn zigzag_encode64(i: i64) -> u64 {
    let shifted = (i as u64) << 1;
    let sign_mask = (i >> 63) as u64;
    shifted ^ sign_mask
}
51
52pub trait CompactThriftInput<'i> {
53    fn read_byte(&mut self) -> Result<u8, ThriftError>;
54    fn read_len(&mut self) -> Result<usize, ThriftError> {
55        let len = decode_uleb(self)?;
56        Ok(len as _)
57    }
58    fn read_i16(&mut self) -> Result<i16, ThriftError> {
59        let i = decode_uleb(self)?;
60        Ok(zigzag_decode16(i as _))
61    }
62    fn read_i32(&mut self) -> Result<i32, ThriftError> {
63        let i = decode_uleb(self)?;
64        Ok(zigzag_decode32(i as _))
65    }
66    fn read_i64(&mut self) -> Result<i64, ThriftError> {
67        let i = decode_uleb(self)?;
68        Ok(zigzag_decode64(i as _))
69    }
70    fn read_double(&mut self) -> Result<f64, ThriftError>;
71    fn read_binary(&mut self) -> Result<Cow<'i, [u8]>, ThriftError>;
72    fn read_string(&mut self) -> Result<Cow<'i, str>, ThriftError> {
73        let binary = self.read_binary()?;
74        let _ = from_utf8(binary.as_ref()).map_err(|_| ThriftError::InvalidString)?;
75        // Safety: just checked for valid utf8
76        unsafe {
77            match binary {
78                Cow::Owned(v) => Ok(Cow::Owned(String::from_utf8_unchecked(v))),
79                Cow::Borrowed(v) => Ok(Cow::Borrowed(from_utf8_unchecked(v))),
80            }
81        }
82    }
83    fn skip_integer(&mut self) -> Result<(), ThriftError> {
84        let _ = self.read_i64()?;
85        Ok(())
86    }
87    fn skip_binary(&mut self) -> Result<(), ThriftError> {
88        self.read_binary()?;
89        Ok(())
90    }
91    fn skip_field(&mut self, field_type: u8) -> Result<(), ThriftError> {
92        skip_field(self, field_type, false)
93    }
94    fn read_field_header(&mut self, last_field_id: &mut i16) -> Result<u8, ThriftError> {
95        let field_header = self.read_byte()?;
96
97        if field_header == 0 {
98            return Ok(0)
99        }
100
101        let field_type = field_header & 0x0F;
102        let field_delta = field_header >> 4;
103        if field_delta != 0 {
104            *last_field_id += field_delta as i16;
105        } else {
106            *last_field_id = read_full_field_id(self)?;
107        }
108
109        Ok(field_type)
110    }
111}
112
113pub(crate) fn read_collection_len_and_type<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(u32, u8), ThriftError> {
114    let header = input.read_byte()?;
115    let field_type = header & 0x0F;
116    let maybe_len = (header & 0xF0) >> 4;
117    let len = if maybe_len != 0x0F {
118        // high bits set high if count and type encoded separately
119        maybe_len as usize
120    } else {
121        input.read_len()?
122    };
123
124    if len > MAX_COLLECTION_LEN {
125        return Err(ThriftError::InvalidCollectionLen)
126    }
127
128    Ok((len as u32, field_type))
129}
130
131pub(crate) fn read_map_len_and_types<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(u32, u8, u8), ThriftError> {
132    let len = input.read_len()?;
133    if len == 0 {
134        return Ok((0, 0, 0))
135    }
136    let entry_type = input.read_byte()?;
137    // TODO: check order of nibbles
138    let key_type = entry_type >> 4;
139    let val_type = entry_type & 0x0F;
140
141    if len > MAX_COLLECTION_LEN {
142        return Err(ThriftError::InvalidCollectionLen)
143    }
144
145    Ok((len as u32, key_type, val_type))
146}
147
148fn skip_collection<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(), ThriftError> {
149    let (len, element_type) = read_collection_len_and_type(input)?;
150    match element_type {
151        1..=3 => {
152            // TRUE, FALSE, i8 stored as single byte
153            for _ in 0..len {
154                let _ = input.read_byte()?;
155            }
156        }
157        4..=6 => {
158            // since we do not error on overlong sequences,
159            // skipping for all integer types works the same.
160            for _ in 0..len {
161                input.skip_integer()?;
162            }
163        }
164        7 => {
165            for _ in 0..len {
166                input.read_double()?;
167            }
168        }
169        8 => {
170            // thrift does not distinguish binary and string types in field_type,
171            // consequently there is no utf8 validation for skipped strings.
172            for _ in 0..len {
173                input.skip_binary()?;
174            }
175        }
176        9 | 10 => {
177            // list | set
178            for _ in 0..len {
179                skip_collection(input)?;
180            }
181        }
182        11 => {
183            // map
184            for _ in 0..len {
185                skip_map(input)?;
186            }
187        }
188        12 => {
189            for _ in 0..len {
190                skip_field(input, 12, false)?;
191            }
192        }
193        _ => {
194            return Err(ThriftError::InvalidType)
195        }
196    }
197    Ok(())
198}
199
200fn skip_map<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(), ThriftError> {
201    let (len, key_type, val_type) = read_map_len_and_types(input)?;
202    for _ in 0..len {
203        skip_field(input, key_type, true)?;
204        skip_field(input, val_type, true)?;
205    }
206    Ok(())
207}
208
209pub(crate) fn skip_field<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T, field_type: u8, inside_collection: bool) -> Result<(), ThriftError> {
210    match field_type {
211        1..=2 => {
212            // boolean stored inside the field header outside of collections
213            if inside_collection {
214                input.read_byte()?;
215            }
216        }
217        3 => {
218            input.read_byte()?;
219        }
220        4..=6 => {
221            // since we do not error on overlong sequences,
222            // skipping for all integer types works the same.
223            input.skip_integer()?;
224        }
225        7 => {
226            input.read_double()?;
227        }
228        8 => {
229            // thrift does not distinguish binary and string types in field_type,
230            // consequently there is no utf8 validation for skipped strings.
231            input.skip_binary()?;
232        }
233        9 | 10 => {
234            // list | set
235            skip_collection(input)?;
236        }
237        11 => {
238            // map
239            skip_map(input)?;
240        }
241        12 => {
242            // struct | union
243            let mut last_field_id = 0_i16;
244            loop {
245                let field_type = input.read_field_header(&mut last_field_id)?;
246                if field_type == 0 {
247                    break;
248                }
249                skip_field(input, field_type, false)?;
250            }
251        }
252        _ => {
253            return Err(ThriftError::InvalidType)
254        }
255    }
256    Ok(())
257}
258
259#[inline]
260pub(crate) fn write_field_header<T: CompactThriftOutput>(output: &mut T, field_type: u8, field_id: i16, last_field_id: &mut i16) -> Result<(), ThriftError> {
261    let field_delta = field_id.wrapping_sub(*last_field_id);
262
263    if field_delta > 15 {
264        output.write_byte(field_type)?;
265        output.write_i16(field_delta)?
266    } else {
267        output.write_byte(field_type | ((field_delta as u8) << 4))?;
268    }
269    *last_field_id = field_id;
270    Ok(())
271}
272
273
274impl<R: Read + ?Sized> CompactThriftInput<'static> for R {
275    #[inline]
276    fn read_byte(&mut self) -> Result<u8, ThriftError> {
277        let mut buf = [0_u8; 1];
278        self.read_exact(&mut buf)?;
279        Ok(buf[0])
280    }
281
282    fn read_double(&mut self) -> Result<f64, ThriftError> {
283        let mut buf = [0_u8; 8];
284        self.read_exact(&mut buf)?;
285        Ok(f64::from_le_bytes(buf))
286    }
287
288    #[expect(clippy::uninit_vec)]
289    fn read_binary(&mut self) -> Result<Cow<'static, [u8]>, ThriftError> {
290        let len = self.read_len()?;
291        if len > MAX_BINARY_LEN {
292            return Err(ThriftError::InvalidBinaryLen);
293        }
294        let mut buf = Vec::with_capacity(len);
295        // Safety: we trust the Read implementation to only write into buf,
296        // and not to look at uninitialized bytes
297        unsafe {
298            buf.set_len(len);
299        }
300        self.read_exact(buf.as_mut_slice())?;
301        Ok(buf.into())
302    }
303
304}
305
/// Compact-protocol input that reads directly from a borrowed byte slice.
///
/// Binary and string values read through this type borrow from the slice instead of being
/// copied.
#[derive(Clone)]
pub struct CompactThriftInputSlice<'a>(&'a [u8]);

impl<'a> CompactThriftInputSlice<'a> {
    /// Creates an input positioned at the start of `slice`.
    pub fn new(slice: &'a [u8]) -> Self {
        CompactThriftInputSlice(slice)
    }

    /// Returns the bytes that have not been consumed yet.
    pub fn as_slice(&self) -> &'a [u8] {
        let remaining = self.0;
        remaining
    }
}

impl<'a> From<&'a [u8]> for CompactThriftInputSlice<'a> {
    fn from(slice: &'a [u8]) -> Self {
        Self::new(slice)
    }
}
324
325impl <'i> CompactThriftInput<'i> for CompactThriftInputSlice<'i> {
326    #[inline]
327    fn read_byte(&mut self) -> Result<u8, ThriftError> {
328        if let [first, rest @ ..] = self.0 {
329            self.0 = rest;
330            Ok(*first)
331        } else {
332            Err(ThriftError::from(ErrorKind::UnexpectedEof))
333        }
334    }
335
336    #[inline]
337    fn read_double(&mut self) -> Result<f64, ThriftError> {
338        if self.0.len() < 8 {
339            return Err(ThriftError::from(ErrorKind::UnexpectedEof))
340        }
341        let value = f64::from_le_bytes(self.0[..8].try_into().unwrap());
342        self.0 = &self.0[8..];
343        Ok(value)
344    }
345
346    #[inline]
347    fn read_binary(&mut self) -> Result<Cow<'i, [u8]>, ThriftError> {
348        let len = self.read_len()?;
349        if len > MAX_BINARY_LEN {
350            return Err(ThriftError::InvalidBinaryLen);
351        }
352        if self.0.len() < len {
353            return Err(ThriftError::from(ErrorKind::UnexpectedEof))
354        }
355        let (first, rest) = std::mem::take(&mut self.0).split_at(len);
356        self.0 = rest;
357        Ok(Cow::Borrowed(first))
358    }
359
360    #[inline]
361    #[cfg(target_feature = "sse2")]
362    fn skip_integer(&mut self) -> Result<(), ThriftError> {
363        if self.0.len() >= 16 {
364            self.0 = unsafe { skip_uleb_sse2(self.0) };
365            Ok(())
366        } else {
367            self.0 = skip_uleb_fallback(self.0);
368            Ok(())
369        }
370    }
371
372    fn skip_binary(&mut self) -> Result<(), ThriftError> {
373        let len = self.read_len()?;
374        if len > MAX_BINARY_LEN {
375            return Err(ThriftError::InvalidBinaryLen);
376        }
377        if self.0.len() < len {
378            return Err(ThriftError::from(ErrorKind::UnexpectedEof))
379        }
380        self.0 = &self.0[len..];
381        Ok(())
382    }
383}
384
385pub trait CompactThriftOutput {
386    fn write_byte(&mut self, value: u8) -> Result<(), ThriftError>;
387    fn write_len(&mut self, value: usize) -> Result<(), ThriftError>;
388    fn write_i16(&mut self, value: i16) -> Result<(), ThriftError>;
389    fn write_i32(&mut self, value: i32) -> Result<(), ThriftError>;
390    fn write_i64(&mut self, value: i64) -> Result<(), ThriftError>;
391    fn write_double(&mut self, value: f64) -> Result<(), ThriftError>;
392    fn write_binary(&mut self, value: &[u8]) -> Result<(), ThriftError>;
393    fn write_string(&mut self, value: &str) -> Result<(), ThriftError> {
394        self.write_binary(value.as_bytes())
395    }
396}
397
398impl <W: Write> CompactThriftOutput for W {
399    fn write_byte(&mut self, value: u8) -> Result<(), ThriftError> {
400        self.write_all(&[value])?;
401        Ok(())
402    }
403
404    fn write_len(&mut self, value: usize) -> Result<(), ThriftError> {
405        encode_uleb(self, value as _)
406    }
407
408    fn write_i16(&mut self, value: i16) -> Result<(), ThriftError> {
409        encode_uleb(self, zigzag_encode16(value) as _)
410    }
411
412    fn write_i32(&mut self, value: i32) -> Result<(), ThriftError> {
413        encode_uleb(self, zigzag_encode32(value) as _)
414    }
415
416    fn write_i64(&mut self, value: i64) -> Result<(), ThriftError> {
417        encode_uleb(self, zigzag_encode64(value) as _)
418    }
419
420    fn write_double(&mut self, value: f64) -> Result<(), ThriftError> {
421        self.write_all(&value.to_le_bytes())?;
422        Ok(())
423    }
424
425    fn write_binary(&mut self, value: &[u8]) -> Result<(), ThriftError> {
426        if value.len() > MAX_BINARY_LEN {
427            return Err(ThriftError::InvalidBinaryLen);
428        }
429        self.write_len(value.len())?;
430        self.write_all(value)?;
431        Ok(())
432    }
433}
434
435pub trait CompactThriftProtocol<'i> {
436    // In the compact protocol the tags for field and element types are currently the same.
437    // The documentation states "there is _no guarantee_ that this will
438    // remain true after new types are added".
439    const FIELD_TYPE: u8;
440
441    fn read_thrift<T: CompactThriftInput<'i>>(input: &mut T) -> Result<Self, ThriftError> where Self: Default{
442        let mut result = Self::default();
443        Self::fill_thrift(&mut result, input)?;
444        Ok(result)
445    }
446    fn fill_thrift<T: CompactThriftInput<'i>>(&mut self, input: &mut T) -> Result<(), ThriftError>;
447    #[inline]
448    fn fill_thrift_field<T: CompactThriftInput<'i>>(&mut self, input: &mut T, field_type: u8) -> Result<(), ThriftError> {
449        if field_type != Self::FIELD_TYPE {
450            return Err(ThriftError::InvalidType)
451        }
452        self.fill_thrift(input)
453    }
454    fn write_thrift<T: CompactThriftOutput>(&self, output: &mut T) -> Result<(), ThriftError>;
455    #[inline]
456    fn write_thrift_field<T: CompactThriftOutput>(&self, output: &mut T, field_id: i16, last_field_id: &mut i16) -> Result<(), ThriftError> {
457        write_field_header(output, Self::FIELD_TYPE, field_id, last_field_id)?;
458        self.write_thrift(output)?;
459        Ok(())
460    }
461}