compact_thrift_runtime/
protocol.rs

1use std::borrow::Cow;
2
3use std::io::ErrorKind;
4use std::io::Read;
5use std::io::Write;
6use std::marker::PhantomData;
7use std::ops::Range;
8use std::slice::from_raw_parts;
9use std::str::from_utf8_unchecked;
10use std::str::from_utf8;
11
12use crate::ThriftError;
13use crate::uleb::*;
14
/// Largest binary/string payload, in bytes, accepted when reading or writing (16 MiB).
pub const MAX_BINARY_LEN: usize = 16 * 1024 * 1024;
/// Largest element count accepted in a list, set or map header.
pub const MAX_COLLECTION_LEN: usize = 10_000_000;
17
18#[inline(never)] // full field ids are uncommon and inlining this bloats the code
19#[cold]
20fn read_full_field_id<'i, I: CompactThriftInput<'i> + ?Sized>(input: &mut I) -> Result<i16, ThriftError> {
21    input.read_i16()
22}
23
/// Maps a zigzag-encoded `u16` back to the signed value it represents.
///
/// The lowest bit carries the sign and the remaining bits the magnitude, so
/// `0, 1, 2, 3, ...` decode to `0, -1, 1, -2, ...`.
#[inline(always)]
fn zigzag_decode16(i: u16) -> i16 {
    let magnitude = (i >> 1) as i16;
    // 0 for non-negative values, all ones (-1) for negative values.
    let sign_mask = ((i & 1) as i16).wrapping_neg();
    magnitude ^ sign_mask
}
28
/// Maps a zigzag-encoded `u32` back to the signed value it represents.
///
/// Even inputs decode to non-negative numbers, odd inputs to negative ones.
#[inline(always)]
fn zigzag_decode32(i: u32) -> i32 {
    let magnitude = (i >> 1) as i32;
    // 0 for non-negative values, all ones (-1) for negative values.
    let sign_mask = ((i & 1) as i32).wrapping_neg();
    magnitude ^ sign_mask
}
33
/// Maps a zigzag-encoded `u64` back to the signed value it represents.
///
/// Even inputs decode to non-negative numbers, odd inputs to negative ones.
#[inline(always)]
fn zigzag_decode64(i: u64) -> i64 {
    let magnitude = (i >> 1) as i64;
    // 0 for non-negative values, all ones (-1) for negative values.
    let sign_mask = ((i & 1) as i64).wrapping_neg();
    magnitude ^ sign_mask
}
38
/// Zigzag-encodes an `i16` so that values of small magnitude, positive or
/// negative, become small unsigned numbers (`0, -1, 1, -2` map to `0, 1, 2, 3`).
#[inline(always)]
fn zigzag_encode16(i: i16) -> u16 {
    // Magnitude moves up one bit, leaving bit 0 free for the sign.
    let doubled = (i as u16) << 1;
    // The arithmetic shift replicates the sign bit across the whole word.
    let sign_fill = (i >> 15) as u16;
    doubled ^ sign_fill
}
43
/// Zigzag-encodes an `i32` so that values of small magnitude, positive or
/// negative, become small unsigned numbers.
#[inline(always)]
fn zigzag_encode32(i: i32) -> u32 {
    // Magnitude moves up one bit, leaving bit 0 free for the sign.
    let doubled = (i as u32) << 1;
    // The arithmetic shift replicates the sign bit across the whole word.
    let sign_fill = (i >> 31) as u32;
    doubled ^ sign_fill
}
48
/// Zigzag-encodes an `i64` so that values of small magnitude, positive or
/// negative, become small unsigned numbers.
#[inline(always)]
fn zigzag_encode64(i: i64) -> u64 {
    // Magnitude moves up one bit, leaving bit 0 free for the sign.
    let doubled = (i as u64) << 1;
    // The arithmetic shift replicates the sign bit across the whole word.
    let sign_fill = (i >> 63) as u64;
    doubled ^ sign_fill
}
53
54pub trait CompactThriftInput<'i> {
55    fn read_byte(&mut self) -> Result<u8, ThriftError>;
56    fn read_len(&mut self) -> Result<usize, ThriftError> {
57        let len = decode_uleb(self)?;
58        Ok(len as _)
59    }
60    fn read_i16(&mut self) -> Result<i16, ThriftError> {
61        let i = decode_uleb(self)?;
62        Ok(zigzag_decode16(i as _))
63    }
64    fn read_i32(&mut self) -> Result<i32, ThriftError> {
65        let i = decode_uleb(self)?;
66        Ok(zigzag_decode32(i as _))
67    }
68    fn read_i64(&mut self) -> Result<i64, ThriftError> {
69        let i = decode_uleb(self)?;
70        Ok(zigzag_decode64(i as _))
71    }
72    fn read_double(&mut self) -> Result<f64, ThriftError>;
73    fn read_binary(&mut self) -> Result<Cow<'i, [u8]>, ThriftError>;
74    fn read_string(&mut self) -> Result<Cow<'i, str>, ThriftError> {
75        let binary = self.read_binary()?;
76        let _ = from_utf8(binary.as_ref()).map_err(|_| ThriftError::InvalidString)?;
77        // Safety: just checked for valid utf8
78        unsafe {
79            match binary {
80                Cow::Owned(v) => Ok(Cow::Owned(String::from_utf8_unchecked(v))),
81                Cow::Borrowed(v) => Ok(Cow::Borrowed(from_utf8_unchecked(v))),
82            }
83        }
84    }
85    fn skip_integer(&mut self) -> Result<(), ThriftError> {
86        let _ = self.read_i64()?;
87        Ok(())
88    }
89    fn skip_binary(&mut self) -> Result<(), ThriftError> {
90        self.read_binary()?;
91        Ok(())
92    }
93    fn skip_field(&mut self, field_type: u8) -> Result<(), ThriftError> {
94        skip_field(self, field_type, false)
95    }
96    fn read_field_header(&mut self, last_field_id: &mut i16) -> Result<u8, ThriftError> {
97        let field_header = self.read_byte()?;
98
99        if field_header == 0 {
100            return Ok(0)
101        }
102
103        let field_type = field_header & 0x0F;
104        let field_delta = field_header >> 4;
105        if field_delta != 0 {
106            *last_field_id += field_delta as i16;
107        } else {
108            *last_field_id = read_full_field_id(self)?;
109        }
110
111        Ok(field_type)
112    }
113}
114
115pub fn read_collection_len_and_type<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(u32, u8), ThriftError> {
116    let header = input.read_byte()?;
117    let field_type = header & 0x0F;
118    let maybe_len = (header & 0xF0) >> 4;
119    let len = if maybe_len != 0x0F {
120        // high bits set high if count and type encoded separately
121        maybe_len as usize
122    } else {
123        input.read_len()?
124    };
125
126    if len > MAX_COLLECTION_LEN {
127        return Err(ThriftError::InvalidCollectionLen)
128    }
129
130    Ok((len as u32, field_type))
131}
132
133pub fn read_map_len_and_types<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(u32, u8, u8), ThriftError> {
134    let len = input.read_len()?;
135    if len == 0 {
136        return Ok((0, 0, 0))
137    }
138    let entry_type = input.read_byte()?;
139    // TODO: check order of nibbles
140    let key_type = entry_type >> 4;
141    let val_type = entry_type & 0x0F;
142
143    if len > MAX_COLLECTION_LEN {
144        return Err(ThriftError::InvalidCollectionLen)
145    }
146
147    Ok((len as u32, key_type, val_type))
148}
149
150fn skip_collection<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(), ThriftError> {
151    let (len, element_type) = read_collection_len_and_type(input)?;
152    match element_type {
153        1..=3 => {
154            // TRUE, FALSE, i8 stored as single byte
155            for _ in 0..len {
156                let _ = input.read_byte()?;
157            }
158        }
159        4..=6 => {
160            // since we do not error on overlong sequences,
161            // skipping for all integer types works the same.
162            for _ in 0..len {
163                input.skip_integer()?;
164            }
165        }
166        7 => {
167            for _ in 0..len {
168                input.read_double()?;
169            }
170        }
171        8 => {
172            // thrift does not distinguish binary and string types in field_type,
173            // consequently there is no utf8 validation for skipped strings.
174            for _ in 0..len {
175                input.skip_binary()?;
176            }
177        }
178        9 | 10 => {
179            // list | set
180            for _ in 0..len {
181                skip_collection(input)?;
182            }
183        }
184        11 => {
185            // map
186            for _ in 0..len {
187                skip_map(input)?;
188            }
189        }
190        12 => {
191            for _ in 0..len {
192                skip_field(input, 12, false)?;
193            }
194        }
195        _ => {
196            return Err(ThriftError::InvalidType)
197        }
198    }
199    Ok(())
200}
201
202fn skip_map<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(), ThriftError> {
203    let (len, key_type, val_type) = read_map_len_and_types(input)?;
204    for _ in 0..len {
205        skip_field(input, key_type, true)?;
206        skip_field(input, val_type, true)?;
207    }
208    Ok(())
209}
210
211pub(crate) fn skip_field<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T, field_type: u8, inside_collection: bool) -> Result<(), ThriftError> {
212    match field_type {
213        1..=2 => {
214            // boolean stored inside the field header outside of collections
215            if inside_collection {
216                input.read_byte()?;
217            }
218        }
219        3 => {
220            input.read_byte()?;
221        }
222        4..=6 => {
223            // since we do not error on overlong sequences,
224            // skipping for all integer types works the same.
225            input.skip_integer()?;
226        }
227        7 => {
228            input.read_double()?;
229        }
230        8 => {
231            // thrift does not distinguish binary and string types in field_type,
232            // consequently there is no utf8 validation for skipped strings.
233            input.skip_binary()?;
234        }
235        9 | 10 => {
236            // list | set
237            skip_collection(input)?;
238        }
239        11 => {
240            // map
241            skip_map(input)?;
242        }
243        12 => {
244            // struct | union
245            let mut last_field_id = 0_i16;
246            loop {
247                let field_type = input.read_field_header(&mut last_field_id)?;
248                if field_type == 0 {
249                    break;
250                }
251                skip_field(input, field_type, false)?;
252            }
253        }
254        _ => {
255            return Err(ThriftError::InvalidType)
256        }
257    }
258    Ok(())
259}
260
261#[inline]
262pub(crate) fn write_field_header<T: CompactThriftOutput>(output: &mut T, field_type: u8, field_id: i16, last_field_id: &mut i16) -> Result<(), ThriftError> {
263    let field_delta = field_id.wrapping_sub(*last_field_id);
264
265    if field_delta > 15 {
266        output.write_byte(field_type)?;
267        output.write_i16(field_delta)?
268    } else {
269        output.write_byte(field_type | ((field_delta as u8) << 4))?;
270    }
271    *last_field_id = field_id;
272    Ok(())
273}
274
275
276impl<R: Read + ?Sized> CompactThriftInput<'static> for R {
277    #[inline]
278    fn read_byte(&mut self) -> Result<u8, ThriftError> {
279        let mut buf = [0_u8; 1];
280        self.read_exact(&mut buf)?;
281        Ok(buf[0])
282    }
283
284    fn read_double(&mut self) -> Result<f64, ThriftError> {
285        let mut buf = [0_u8; 8];
286        self.read_exact(&mut buf)?;
287        Ok(f64::from_le_bytes(buf))
288    }
289
290    #[expect(clippy::uninit_vec)]
291    fn read_binary(&mut self) -> Result<Cow<'static, [u8]>, ThriftError> {
292        let len = self.read_len()?;
293        if len > MAX_BINARY_LEN {
294            return Err(ThriftError::InvalidBinaryLen(len));
295        }
296        let mut buf = Vec::with_capacity(len);
297        // Safety: we trust the Read implementation to only write into buf,
298        // and not to look at uninitialized bytes
299        unsafe {
300            buf.set_len(len);
301        }
302        self.read_exact(buf.as_mut_slice())?;
303        Ok(buf.into())
304    }
305
306}
307
/// Input over a borrowed byte slice, tracked as a pair of raw pointers.
///
/// `range.start` is the read position and `range.end` is one past the last
/// byte. `phantom` ties both pointers to the lifetime of the slice they were
/// created from.
#[derive(Clone)]
pub struct CompactThriftInputSlice<'a> {
    range: Range<*const u8>,
    phantom: PhantomData<&'a [u8]>,
}

impl <'a> CompactThriftInputSlice<'a> {
    /// Creates an input positioned at the first byte of `slice`.
    #[inline]
    pub fn new(slice: &'a [u8]) -> Self {
        let range = slice.as_ptr_range();
        Self { range, phantom: PhantomData }
    }

    /// Returns the bytes between the current position and the end of the input.
    #[inline]
    pub fn as_slice(&self) -> &'a [u8] {
        let remaining = self.len();
        // SAFETY: both pointers come from the same `&'a [u8]` (see `new`), so the
        // `remaining` bytes starting at `start` lie within that slice, provided
        // readers only ever move `start` forward and never past `end`.
        unsafe { from_raw_parts(self.range.start, remaining) }
    }

    /// Number of bytes between the current position and the end of the input.
    #[inline]
    fn len(&self) -> usize {
        let Range { start, end } = self.range;
        // SAFETY: both pointers come from the same slice (see `new`) and `start`
        // is not expected to move past `end`.
        unsafe { end.offset_from_unsigned(start) }
    }
}

impl <'a> From<&'a [u8]> for CompactThriftInputSlice<'a> {
    /// Equivalent to [`CompactThriftInputSlice::new`].
    fn from(slice: &'a [u8]) -> Self {
        CompactThriftInputSlice::new(slice)
    }
}
337
338impl <'i> CompactThriftInput<'i> for CompactThriftInputSlice<'i> {
339    #[inline]
340    fn read_byte(&mut self) -> Result<u8, ThriftError> {
341        if self.range.is_empty() {
342            Err(ThriftError::from(ErrorKind::UnexpectedEof))
343        } else {
344            // Safety: Range is not exhausted
345            let byte = unsafe { self.range.start.read() };
346            self.range.start = unsafe { self.range.start.add(1) };
347            Ok(byte)
348        }
349    }
350
351    #[inline]
352    fn read_double(&mut self) -> Result<f64, ThriftError> {
353        if self.len() < 8 {
354            return Err(ThriftError::from(ErrorKind::UnexpectedEof))
355        }
356        let value = unsafe { self.range.start.cast::<f64>().read_unaligned() };
357        self.range.start = unsafe { self.range.start.add(8) };
358        Ok(value)
359    }
360
361    #[inline]
362    fn read_binary(&mut self) -> Result<Cow<'i, [u8]>, ThriftError> {
363        let len = self.read_len()?;
364        if len > MAX_BINARY_LEN {
365            return Err(ThriftError::InvalidBinaryLen(len));
366        }
367        if self.len() < len {
368            return Err(ThriftError::from(ErrorKind::UnexpectedEof))
369        }
370        let slice = unsafe { from_raw_parts(self.range.start, len) };
371        self.range.start = unsafe { self.range.start.add(len) };
372        Ok(Cow::Borrowed(slice))
373    }
374
375    fn skip_binary(&mut self) -> Result<(), ThriftError> {
376        let len = self.read_len()?;
377        if len > MAX_BINARY_LEN {
378            return Err(ThriftError::InvalidBinaryLen(len));
379        }
380        if self.len() < len {
381            return Err(ThriftError::from(ErrorKind::UnexpectedEof))
382        }
383        self.range.start = unsafe { self.range.start.add(len) };
384        Ok(())
385    }
386}
387
388pub trait CompactThriftOutput {
389    fn write_byte(&mut self, value: u8) -> Result<(), ThriftError>;
390    fn write_len(&mut self, value: usize) -> Result<(), ThriftError>;
391    fn write_i16(&mut self, value: i16) -> Result<(), ThriftError>;
392    fn write_i32(&mut self, value: i32) -> Result<(), ThriftError>;
393    fn write_i64(&mut self, value: i64) -> Result<(), ThriftError>;
394    fn write_double(&mut self, value: f64) -> Result<(), ThriftError>;
395    fn write_binary(&mut self, value: &[u8]) -> Result<(), ThriftError>;
396    fn write_string(&mut self, value: &str) -> Result<(), ThriftError> {
397        self.write_binary(value.as_bytes())
398    }
399}
400
401impl <W: Write> CompactThriftOutput for W {
402    fn write_byte(&mut self, value: u8) -> Result<(), ThriftError> {
403        self.write_all(&[value])?;
404        Ok(())
405    }
406
407    fn write_len(&mut self, value: usize) -> Result<(), ThriftError> {
408        encode_uleb(self, value as _)
409    }
410
411    fn write_i16(&mut self, value: i16) -> Result<(), ThriftError> {
412        encode_uleb(self, zigzag_encode16(value) as _)
413    }
414
415    fn write_i32(&mut self, value: i32) -> Result<(), ThriftError> {
416        encode_uleb(self, zigzag_encode32(value) as _)
417    }
418
419    fn write_i64(&mut self, value: i64) -> Result<(), ThriftError> {
420        encode_uleb(self, zigzag_encode64(value) as _)
421    }
422
423    fn write_double(&mut self, value: f64) -> Result<(), ThriftError> {
424        self.write_all(&value.to_le_bytes())?;
425        Ok(())
426    }
427
428    fn write_binary(&mut self, value: &[u8]) -> Result<(), ThriftError> {
429        if value.len() > MAX_BINARY_LEN {
430            return Err(ThriftError::InvalidBinaryLen(value.len()));
431        }
432        self.write_len(value.len())?;
433        self.write_all(value)?;
434        Ok(())
435    }
436}
437
438pub trait CompactThriftProtocol<'i> {
439    // In the compact protocol the tags for field and element types are currently the same.
440    // The documentation states "there is _no guarantee_ that this will
441    // remain true after new types are added".
442    const FIELD_TYPE: u8;
443
444    fn read_thrift<T: CompactThriftInput<'i>>(input: &mut T) -> Result<Self, ThriftError> where Self: Default{
445        let mut result = Self::default();
446        Self::fill_thrift(&mut result, input)?;
447        Ok(result)
448    }
449    fn fill_thrift<T: CompactThriftInput<'i>>(&mut self, input: &mut T) -> Result<(), ThriftError>;
450    #[inline]
451    fn fill_thrift_field<T: CompactThriftInput<'i>>(&mut self, input: &mut T, field_type: u8) -> Result<(), ThriftError> {
452        if field_type != Self::FIELD_TYPE {
453            return Err(ThriftError::InvalidType)
454        }
455        self.fill_thrift(input)
456    }
457    fn write_thrift<T: CompactThriftOutput>(&self, output: &mut T) -> Result<(), ThriftError>;
458    #[inline]
459    fn write_thrift_field<T: CompactThriftOutput>(&self, output: &mut T, field_id: i16, last_field_id: &mut i16) -> Result<(), ThriftError> {
460        write_field_header(output, Self::FIELD_TYPE, field_id, last_field_id)?;
461        self.write_thrift(output)?;
462        Ok(())
463    }
464}