1use std::borrow::Cow;
2
3use std::io::ErrorKind;
4use std::io::Read;
5use std::io::Write;
6use std::marker::PhantomData;
7use std::ops::Range;
8use std::slice::from_raw_parts;
9use std::str::from_utf8_unchecked;
10use std::str::from_utf8;
11
12use crate::ThriftError;
13use crate::uleb::*;
14
/// Largest binary/string value, in bytes, that the readers and writers accept (16 MiB).
pub const MAX_BINARY_LEN: usize = 16 << 20;
/// Largest element count accepted when reading a list, set or map header.
pub const MAX_COLLECTION_LEN: usize = 10 * 1000 * 1000;
17
18#[inline(never)] #[cold]
20fn read_full_field_id<'i, I: CompactThriftInput<'i> + ?Sized>(input: &mut I) -> Result<i16, ThriftError> {
21 input.read_i16()
22}
23
/// Maps a zigzag-encoded `u16` back to its signed value (0→0, 1→-1, 2→1, 3→-2, …).
#[inline(always)]
fn zigzag_decode16(i: u16) -> i16 {
    // For odd inputs `(i & 1).wrapping_neg()` is all ones, which flips every bit.
    ((i >> 1) ^ (i & 1).wrapping_neg()) as i16
}
28
/// Maps a zigzag-encoded `u32` back to its signed value (0→0, 1→-1, 2→1, 3→-2, …).
#[inline(always)]
fn zigzag_decode32(i: u32) -> i32 {
    // For odd inputs `(i & 1).wrapping_neg()` is all ones, which flips every bit.
    ((i >> 1) ^ (i & 1).wrapping_neg()) as i32
}
33
/// Maps a zigzag-encoded `u64` back to its signed value (0→0, 1→-1, 2→1, 3→-2, …).
#[inline(always)]
fn zigzag_decode64(i: u64) -> i64 {
    // For odd inputs `(i & 1).wrapping_neg()` is all ones, which flips every bit.
    ((i >> 1) ^ (i & 1).wrapping_neg()) as i64
}
38
/// Zigzag-encodes an `i16` so small magnitudes of either sign become small unsigned values.
#[inline(always)]
fn zigzag_encode16(i: i16) -> u16 {
    // The arithmetic shift smears the sign bit across the whole word.
    let sign_mask = (i >> 15) as u16;
    ((i as u16) << 1) ^ sign_mask
}
43
/// Zigzag-encodes an `i32` so small magnitudes of either sign become small unsigned values.
#[inline(always)]
fn zigzag_encode32(i: i32) -> u32 {
    // The arithmetic shift smears the sign bit across the whole word.
    let sign_mask = (i >> 31) as u32;
    ((i as u32) << 1) ^ sign_mask
}
48
/// Zigzag-encodes an `i64` so small magnitudes of either sign become small unsigned values.
#[inline(always)]
fn zigzag_encode64(i: i64) -> u64 {
    // The arithmetic shift smears the sign bit across the whole word.
    let sign_mask = (i >> 63) as u64;
    ((i as u64) << 1) ^ sign_mask
}
53
54pub trait CompactThriftInput<'i> {
55 fn read_byte(&mut self) -> Result<u8, ThriftError>;
56 fn read_len(&mut self) -> Result<usize, ThriftError> {
57 let len = decode_uleb(self)?;
58 Ok(len as _)
59 }
60 fn read_i16(&mut self) -> Result<i16, ThriftError> {
61 let i = decode_uleb(self)?;
62 Ok(zigzag_decode16(i as _))
63 }
64 fn read_i32(&mut self) -> Result<i32, ThriftError> {
65 let i = decode_uleb(self)?;
66 Ok(zigzag_decode32(i as _))
67 }
68 fn read_i64(&mut self) -> Result<i64, ThriftError> {
69 let i = decode_uleb(self)?;
70 Ok(zigzag_decode64(i as _))
71 }
72 fn read_double(&mut self) -> Result<f64, ThriftError>;
73 fn read_binary(&mut self) -> Result<Cow<'i, [u8]>, ThriftError>;
74 fn read_string(&mut self) -> Result<Cow<'i, str>, ThriftError> {
75 let binary = self.read_binary()?;
76 let _ = from_utf8(binary.as_ref()).map_err(|_| ThriftError::InvalidString)?;
77 unsafe {
79 match binary {
80 Cow::Owned(v) => Ok(Cow::Owned(String::from_utf8_unchecked(v))),
81 Cow::Borrowed(v) => Ok(Cow::Borrowed(from_utf8_unchecked(v))),
82 }
83 }
84 }
85 fn skip_integer(&mut self) -> Result<(), ThriftError> {
86 let _ = self.read_i64()?;
87 Ok(())
88 }
89 fn skip_binary(&mut self) -> Result<(), ThriftError> {
90 self.read_binary()?;
91 Ok(())
92 }
93 fn skip_field(&mut self, field_type: u8) -> Result<(), ThriftError> {
94 skip_field(self, field_type, false)
95 }
96 fn read_field_header(&mut self, last_field_id: &mut i16) -> Result<u8, ThriftError> {
97 let field_header = self.read_byte()?;
98
99 if field_header == 0 {
100 return Ok(0)
101 }
102
103 let field_type = field_header & 0x0F;
104 let field_delta = field_header >> 4;
105 if field_delta != 0 {
106 *last_field_id += field_delta as i16;
107 } else {
108 *last_field_id = read_full_field_id(self)?;
109 }
110
111 Ok(field_type)
112 }
113}
114
115pub fn read_collection_len_and_type<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(u32, u8), ThriftError> {
116 let header = input.read_byte()?;
117 let field_type = header & 0x0F;
118 let maybe_len = (header & 0xF0) >> 4;
119 let len = if maybe_len != 0x0F {
120 maybe_len as usize
122 } else {
123 input.read_len()?
124 };
125
126 if len > MAX_COLLECTION_LEN {
127 return Err(ThriftError::InvalidCollectionLen)
128 }
129
130 Ok((len as u32, field_type))
131}
132
133pub fn read_map_len_and_types<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(u32, u8, u8), ThriftError> {
134 let len = input.read_len()?;
135 if len == 0 {
136 return Ok((0, 0, 0))
137 }
138 let entry_type = input.read_byte()?;
139 let key_type = entry_type >> 4;
141 let val_type = entry_type & 0x0F;
142
143 if len > MAX_COLLECTION_LEN {
144 return Err(ThriftError::InvalidCollectionLen)
145 }
146
147 Ok((len as u32, key_type, val_type))
148}
149
150fn skip_collection<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(), ThriftError> {
151 let (len, element_type) = read_collection_len_and_type(input)?;
152 match element_type {
153 1..=3 => {
154 for _ in 0..len {
156 let _ = input.read_byte()?;
157 }
158 }
159 4..=6 => {
160 for _ in 0..len {
163 input.skip_integer()?;
164 }
165 }
166 7 => {
167 for _ in 0..len {
168 input.read_double()?;
169 }
170 }
171 8 => {
172 for _ in 0..len {
175 input.skip_binary()?;
176 }
177 }
178 9 | 10 => {
179 for _ in 0..len {
181 skip_collection(input)?;
182 }
183 }
184 11 => {
185 for _ in 0..len {
187 skip_map(input)?;
188 }
189 }
190 12 => {
191 for _ in 0..len {
192 skip_field(input, 12, false)?;
193 }
194 }
195 _ => {
196 return Err(ThriftError::InvalidType)
197 }
198 }
199 Ok(())
200}
201
202fn skip_map<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(), ThriftError> {
203 let (len, key_type, val_type) = read_map_len_and_types(input)?;
204 for _ in 0..len {
205 skip_field(input, key_type, true)?;
206 skip_field(input, val_type, true)?;
207 }
208 Ok(())
209}
210
211pub(crate) fn skip_field<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T, field_type: u8, inside_collection: bool) -> Result<(), ThriftError> {
212 match field_type {
213 1..=2 => {
214 if inside_collection {
216 input.read_byte()?;
217 }
218 }
219 3 => {
220 input.read_byte()?;
221 }
222 4..=6 => {
223 input.skip_integer()?;
226 }
227 7 => {
228 input.read_double()?;
229 }
230 8 => {
231 input.skip_binary()?;
234 }
235 9 | 10 => {
236 skip_collection(input)?;
238 }
239 11 => {
240 skip_map(input)?;
242 }
243 12 => {
244 let mut last_field_id = 0_i16;
246 loop {
247 let field_type = input.read_field_header(&mut last_field_id)?;
248 if field_type == 0 {
249 break;
250 }
251 skip_field(input, field_type, false)?;
252 }
253 }
254 _ => {
255 return Err(ThriftError::InvalidType)
256 }
257 }
258 Ok(())
259}
260
261#[inline]
262pub(crate) fn write_field_header<T: CompactThriftOutput>(output: &mut T, field_type: u8, field_id: i16, last_field_id: &mut i16) -> Result<(), ThriftError> {
263 let field_delta = field_id.wrapping_sub(*last_field_id);
264
265 if field_delta > 15 {
266 output.write_byte(field_type)?;
267 output.write_i16(field_delta)?
268 } else {
269 output.write_byte(field_type | ((field_delta as u8) << 4))?;
270 }
271 *last_field_id = field_id;
272 Ok(())
273}
274
275
276impl<R: Read + ?Sized> CompactThriftInput<'static> for R {
277 #[inline]
278 fn read_byte(&mut self) -> Result<u8, ThriftError> {
279 let mut buf = [0_u8; 1];
280 self.read_exact(&mut buf)?;
281 Ok(buf[0])
282 }
283
284 fn read_double(&mut self) -> Result<f64, ThriftError> {
285 let mut buf = [0_u8; 8];
286 self.read_exact(&mut buf)?;
287 Ok(f64::from_le_bytes(buf))
288 }
289
290 #[expect(clippy::uninit_vec)]
291 fn read_binary(&mut self) -> Result<Cow<'static, [u8]>, ThriftError> {
292 let len = self.read_len()?;
293 if len > MAX_BINARY_LEN {
294 return Err(ThriftError::InvalidBinaryLen(len));
295 }
296 let mut buf = Vec::with_capacity(len);
297 unsafe {
300 buf.set_len(len);
301 }
302 self.read_exact(buf.as_mut_slice())?;
303 Ok(buf.into())
304 }
305
306}
307
/// Zero-copy compact-protocol input over a borrowed byte slice.
///
/// The unread part of the slice is kept as a raw pointer range, and reads
/// advance `range.start`. `PhantomData` ties the value to the slice lifetime `'a`.
#[derive(Clone)]
pub struct CompactThriftInputSlice<'a> {
    range: Range<*const u8>,
    phantom: PhantomData<&'a [u8]>,
}

impl<'a> CompactThriftInputSlice<'a> {
    /// Creates an input positioned at the first byte of `slice`.
    #[inline]
    pub fn new(slice: &'a [u8]) -> Self {
        let range = slice.as_ptr_range();
        Self { range, phantom: PhantomData }
    }

    /// Returns the bytes that have not been consumed yet.
    #[inline]
    pub fn as_slice(&self) -> &'a [u8] {
        // SAFETY: `range` is a sub-range of the `&'a [u8]` given to `new`. The
        // reader methods in this module only advance `start` after checking the
        // remaining length. So it covers `self.len()` initialised bytes valid
        // for `'a`.
        unsafe { from_raw_parts(self.range.start, self.len()) }
    }

    /// Number of unread bytes.
    #[inline]
    fn len(&self) -> usize {
        // Both pointers come from the same slice and `start <= end`. For `u8`
        // the address difference is exactly the element count.
        self.range.end.addr() - self.range.start.addr()
    }
}
331
332impl <'a> From<&'a [u8]> for CompactThriftInputSlice<'a> {
333 fn from(slice: &'a [u8]) -> Self {
334 Self::new(slice)
335 }
336}
337
338impl <'i> CompactThriftInput<'i> for CompactThriftInputSlice<'i> {
339 #[inline]
340 fn read_byte(&mut self) -> Result<u8, ThriftError> {
341 if self.range.is_empty() {
342 Err(ThriftError::from(ErrorKind::UnexpectedEof))
343 } else {
344 let byte = unsafe { self.range.start.read() };
346 self.range.start = unsafe { self.range.start.add(1) };
347 Ok(byte)
348 }
349 }
350
351 #[inline]
352 fn read_double(&mut self) -> Result<f64, ThriftError> {
353 if self.len() < 8 {
354 return Err(ThriftError::from(ErrorKind::UnexpectedEof))
355 }
356 let value = unsafe { self.range.start.cast::<f64>().read_unaligned() };
357 self.range.start = unsafe { self.range.start.add(8) };
358 Ok(value)
359 }
360
361 #[inline]
362 fn read_binary(&mut self) -> Result<Cow<'i, [u8]>, ThriftError> {
363 let len = self.read_len()?;
364 if len > MAX_BINARY_LEN {
365 return Err(ThriftError::InvalidBinaryLen(len));
366 }
367 if self.len() < len {
368 return Err(ThriftError::from(ErrorKind::UnexpectedEof))
369 }
370 let slice = unsafe { from_raw_parts(self.range.start, len) };
371 self.range.start = unsafe { self.range.start.add(len) };
372 Ok(Cow::Borrowed(slice))
373 }
374
375 fn skip_binary(&mut self) -> Result<(), ThriftError> {
376 let len = self.read_len()?;
377 if len > MAX_BINARY_LEN {
378 return Err(ThriftError::InvalidBinaryLen(len));
379 }
380 if self.len() < len {
381 return Err(ThriftError::from(ErrorKind::UnexpectedEof))
382 }
383 self.range.start = unsafe { self.range.start.add(len) };
384 Ok(())
385 }
386}
387
388pub trait CompactThriftOutput {
389 fn write_byte(&mut self, value: u8) -> Result<(), ThriftError>;
390 fn write_len(&mut self, value: usize) -> Result<(), ThriftError>;
391 fn write_i16(&mut self, value: i16) -> Result<(), ThriftError>;
392 fn write_i32(&mut self, value: i32) -> Result<(), ThriftError>;
393 fn write_i64(&mut self, value: i64) -> Result<(), ThriftError>;
394 fn write_double(&mut self, value: f64) -> Result<(), ThriftError>;
395 fn write_binary(&mut self, value: &[u8]) -> Result<(), ThriftError>;
396 fn write_string(&mut self, value: &str) -> Result<(), ThriftError> {
397 self.write_binary(value.as_bytes())
398 }
399}
400
401impl <W: Write> CompactThriftOutput for W {
402 fn write_byte(&mut self, value: u8) -> Result<(), ThriftError> {
403 self.write_all(&[value])?;
404 Ok(())
405 }
406
407 fn write_len(&mut self, value: usize) -> Result<(), ThriftError> {
408 encode_uleb(self, value as _)
409 }
410
411 fn write_i16(&mut self, value: i16) -> Result<(), ThriftError> {
412 encode_uleb(self, zigzag_encode16(value) as _)
413 }
414
415 fn write_i32(&mut self, value: i32) -> Result<(), ThriftError> {
416 encode_uleb(self, zigzag_encode32(value) as _)
417 }
418
419 fn write_i64(&mut self, value: i64) -> Result<(), ThriftError> {
420 encode_uleb(self, zigzag_encode64(value) as _)
421 }
422
423 fn write_double(&mut self, value: f64) -> Result<(), ThriftError> {
424 self.write_all(&value.to_le_bytes())?;
425 Ok(())
426 }
427
428 fn write_binary(&mut self, value: &[u8]) -> Result<(), ThriftError> {
429 if value.len() > MAX_BINARY_LEN {
430 return Err(ThriftError::InvalidBinaryLen(value.len()));
431 }
432 self.write_len(value.len())?;
433 self.write_all(value)?;
434 Ok(())
435 }
436}
437
438pub trait CompactThriftProtocol<'i> {
439 const FIELD_TYPE: u8;
443
444 fn read_thrift<T: CompactThriftInput<'i>>(input: &mut T) -> Result<Self, ThriftError> where Self: Default{
445 let mut result = Self::default();
446 Self::fill_thrift(&mut result, input)?;
447 Ok(result)
448 }
449 fn fill_thrift<T: CompactThriftInput<'i>>(&mut self, input: &mut T) -> Result<(), ThriftError>;
450 #[inline]
451 fn fill_thrift_field<T: CompactThriftInput<'i>>(&mut self, input: &mut T, field_type: u8) -> Result<(), ThriftError> {
452 if field_type != Self::FIELD_TYPE {
453 return Err(ThriftError::InvalidType)
454 }
455 self.fill_thrift(input)
456 }
457 fn write_thrift<T: CompactThriftOutput>(&self, output: &mut T) -> Result<(), ThriftError>;
458 #[inline]
459 fn write_thrift_field<T: CompactThriftOutput>(&self, output: &mut T, field_id: i16, last_field_id: &mut i16) -> Result<(), ThriftError> {
460 write_field_header(output, Self::FIELD_TYPE, field_id, last_field_id)?;
461 self.write_thrift(output)?;
462 Ok(())
463 }
464}