1use std::borrow::Cow;
2
3use std::io::ErrorKind;
4use std::io::Read;
5use std::io::Write;
6
7use std::str::from_utf8_unchecked;
8use std::str::from_utf8;
9
10use crate::ThriftError;
11use crate::uleb::*;
12
13pub const MAX_BINARY_LEN: usize = 1024*1024;
14pub const MAX_COLLECTION_LEN: usize = 10_000_000;
15
16#[inline(never)] fn read_full_field_id<'i, I: CompactThriftInput<'i> + ?Sized>(input: &mut I) -> Result<i16, ThriftError> {
18 let field_id = decode_uleb(input)? as u16;
19 Ok(zigzag_decode16(field_id))
20}
21
/// Undoes 16-bit zigzag encoding: 0, 1, 2, 3, ... map back to 0, -1, 1, -2, ...
#[inline(always)]
fn zigzag_decode16(i: u16) -> i16 {
    // Low bit set => negative: XOR the magnitude with an all-ones mask.
    let sign_mask = (i & 1).wrapping_neg();
    ((i >> 1) ^ sign_mask) as i16
}
26
/// Undoes 32-bit zigzag encoding: 0, 1, 2, 3, ... map back to 0, -1, 1, -2, ...
#[inline(always)]
fn zigzag_decode32(i: u32) -> i32 {
    // Low bit set => negative: XOR the magnitude with an all-ones mask.
    let sign_mask = (i & 1).wrapping_neg();
    ((i >> 1) ^ sign_mask) as i32
}
31
/// Undoes 64-bit zigzag encoding: 0, 1, 2, 3, ... map back to 0, -1, 1, -2, ...
#[inline(always)]
fn zigzag_decode64(i: u64) -> i64 {
    // Low bit set => negative: XOR the magnitude with an all-ones mask.
    let sign_mask = (i & 1).wrapping_neg();
    ((i >> 1) ^ sign_mask) as i64
}
36
/// Zigzag-encodes an `i16` so small magnitudes of either sign become small unsigned values:
/// 0, -1, 1, -2, ... map to 0, 1, 2, 3, ...
#[inline(always)]
fn zigzag_encode16(i: i16) -> u16 {
    // Arithmetic shift spreads the sign bit into a mask of all zeros or all ones.
    let sign_mask = (i >> 15) as u16;
    ((i as u16) << 1) ^ sign_mask
}
41
/// Zigzag-encodes an `i32` so small magnitudes of either sign become small unsigned values:
/// 0, -1, 1, -2, ... map to 0, 1, 2, 3, ...
#[inline(always)]
fn zigzag_encode32(i: i32) -> u32 {
    // Arithmetic shift spreads the sign bit into a mask of all zeros or all ones.
    let sign_mask = (i >> 31) as u32;
    ((i as u32) << 1) ^ sign_mask
}
46
/// Zigzag-encodes an `i64` so small magnitudes of either sign become small unsigned values:
/// 0, -1, 1, -2, ... map to 0, 1, 2, 3, ...
#[inline(always)]
fn zigzag_encode64(i: i64) -> u64 {
    // Arithmetic shift spreads the sign bit into a mask of all zeros or all ones.
    let sign_mask = (i >> 63) as u64;
    ((i as u64) << 1) ^ sign_mask
}
51
52pub trait CompactThriftInput<'i> {
53 fn read_byte(&mut self) -> Result<u8, ThriftError>;
54 fn read_len(&mut self) -> Result<usize, ThriftError> {
55 let len = decode_uleb(self)?;
56 Ok(len as _)
57 }
58 fn read_i16(&mut self) -> Result<i16, ThriftError> {
59 let i = decode_uleb(self)?;
60 Ok(zigzag_decode16(i as _))
61 }
62 fn read_i32(&mut self) -> Result<i32, ThriftError> {
63 let i = decode_uleb(self)?;
64 Ok(zigzag_decode32(i as _))
65 }
66 fn read_i64(&mut self) -> Result<i64, ThriftError> {
67 let i = decode_uleb(self)?;
68 Ok(zigzag_decode64(i as _))
69 }
70 fn read_double(&mut self) -> Result<f64, ThriftError>;
71 fn read_binary(&mut self) -> Result<Cow<'i, [u8]>, ThriftError>;
72 fn read_string(&mut self) -> Result<Cow<'i, str>, ThriftError> {
73 let binary = self.read_binary()?;
74 let _ = from_utf8(binary.as_ref()).map_err(|_| ThriftError::InvalidString)?;
75 unsafe {
77 match binary {
78 Cow::Owned(v) => Ok(Cow::Owned(String::from_utf8_unchecked(v))),
79 Cow::Borrowed(v) => Ok(Cow::Borrowed(from_utf8_unchecked(v))),
80 }
81 }
82 }
83 fn skip_integer(&mut self) -> Result<(), ThriftError> {
84 let _ = self.read_i64()?;
85 Ok(())
86 }
87 fn skip_binary(&mut self) -> Result<(), ThriftError> {
88 self.read_binary()?;
89 Ok(())
90 }
91 fn skip_field(&mut self, field_type: u8) -> Result<(), ThriftError> {
92 skip_field(self, field_type, false)
93 }
94 fn read_field_header(&mut self, last_field_id: &mut i16) -> Result<u8, ThriftError> {
95 let field_header = self.read_byte()?;
96
97 if field_header == 0 {
98 return Ok(0)
99 }
100
101 let field_type = field_header & 0x0F;
102 let field_delta = field_header >> 4;
103 if field_delta != 0 {
104 *last_field_id += field_delta as i16;
105 } else {
106 *last_field_id = read_full_field_id(self)?;
107 }
108
109 Ok(field_type)
110 }
111}
112
113pub(crate) fn read_collection_len_and_type<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(u32, u8), ThriftError> {
114 let header = input.read_byte()?;
115 let field_type = header & 0x0F;
116 let maybe_len = (header & 0xF0) >> 4;
117 let len = if maybe_len != 0x0F {
118 maybe_len as usize
120 } else {
121 input.read_len()?
122 };
123
124 if len > MAX_COLLECTION_LEN {
125 return Err(ThriftError::InvalidCollectionLen)
126 }
127
128 Ok((len as u32, field_type))
129}
130
131pub(crate) fn read_map_len_and_types<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(u32, u8, u8), ThriftError> {
132 let len = input.read_len()?;
133 if len == 0 {
134 return Ok((0, 0, 0))
135 }
136 let entry_type = input.read_byte()?;
137 let key_type = entry_type >> 4;
139 let val_type = entry_type & 0x0F;
140
141 if len > MAX_COLLECTION_LEN {
142 return Err(ThriftError::InvalidCollectionLen)
143 }
144
145 Ok((len as u32, key_type, val_type))
146}
147
148fn skip_collection<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(), ThriftError> {
149 let (len, element_type) = read_collection_len_and_type(input)?;
150 match element_type {
151 1..=3 => {
152 for _ in 0..len {
154 let _ = input.read_byte()?;
155 }
156 }
157 4..=6 => {
158 for _ in 0..len {
161 input.skip_integer()?;
162 }
163 }
164 7 => {
165 for _ in 0..len {
166 input.read_double()?;
167 }
168 }
169 8 => {
170 for _ in 0..len {
173 input.skip_binary()?;
174 }
175 }
176 9 | 10 => {
177 for _ in 0..len {
179 skip_collection(input)?;
180 }
181 }
182 11 => {
183 for _ in 0..len {
185 skip_map(input)?;
186 }
187 }
188 12 => {
189 for _ in 0..len {
190 skip_field(input, 12, false)?;
191 }
192 }
193 _ => {
194 return Err(ThriftError::InvalidType)
195 }
196 }
197 Ok(())
198}
199
200fn skip_map<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T) -> Result<(), ThriftError> {
201 let (len, key_type, val_type) = read_map_len_and_types(input)?;
202 for _ in 0..len {
203 skip_field(input, key_type, true)?;
204 skip_field(input, val_type, true)?;
205 }
206 Ok(())
207}
208
209pub(crate) fn skip_field<'i, T: CompactThriftInput<'i> + ?Sized>(input: &mut T, field_type: u8, inside_collection: bool) -> Result<(), ThriftError> {
210 match field_type {
211 1..=2 => {
212 if inside_collection {
214 input.read_byte()?;
215 }
216 }
217 3 => {
218 input.read_byte()?;
219 }
220 4..=6 => {
221 input.skip_integer()?;
224 }
225 7 => {
226 input.read_double()?;
227 }
228 8 => {
229 input.skip_binary()?;
232 }
233 9 | 10 => {
234 skip_collection(input)?;
236 }
237 11 => {
238 skip_map(input)?;
240 }
241 12 => {
242 let mut last_field_id = 0_i16;
244 loop {
245 let field_type = input.read_field_header(&mut last_field_id)?;
246 if field_type == 0 {
247 break;
248 }
249 skip_field(input, field_type, false)?;
250 }
251 }
252 _ => {
253 return Err(ThriftError::InvalidType)
254 }
255 }
256 Ok(())
257}
258
259#[inline]
260pub(crate) fn write_field_header<T: CompactThriftOutput>(output: &mut T, field_type: u8, field_id: i16, last_field_id: &mut i16) -> Result<(), ThriftError> {
261 let field_delta = field_id.wrapping_sub(*last_field_id);
262
263 if field_delta > 15 {
264 output.write_byte(field_type)?;
265 output.write_i16(field_delta)?
266 } else {
267 output.write_byte(field_type | ((field_delta as u8) << 4))?;
268 }
269 *last_field_id = field_id;
270 Ok(())
271}
272
273
274impl<R: Read + ?Sized> CompactThriftInput<'static> for R {
275 #[inline]
276 fn read_byte(&mut self) -> Result<u8, ThriftError> {
277 let mut buf = [0_u8; 1];
278 self.read_exact(&mut buf)?;
279 Ok(buf[0])
280 }
281
282 fn read_double(&mut self) -> Result<f64, ThriftError> {
283 let mut buf = [0_u8; 8];
284 self.read_exact(&mut buf)?;
285 Ok(f64::from_le_bytes(buf))
286 }
287
288 #[expect(clippy::uninit_vec)]
289 fn read_binary(&mut self) -> Result<Cow<'static, [u8]>, ThriftError> {
290 let len = self.read_len()?;
291 if len > MAX_BINARY_LEN {
292 return Err(ThriftError::InvalidBinaryLen);
293 }
294 let mut buf = Vec::with_capacity(len);
295 unsafe {
298 buf.set_len(len);
299 }
300 self.read_exact(buf.as_mut_slice())?;
301 Ok(buf.into())
302 }
303
304}
305
/// In-memory compact-protocol input over a byte slice.
///
/// Reads advance the wrapped slice in place. Binary and string reads borrow from it instead
/// of copying.
#[derive(Clone)]
pub struct CompactThriftInputSlice<'data>(&'data [u8]);
308
309impl <'a> CompactThriftInputSlice<'a> {
310 pub fn new(slice: &'a [u8]) -> Self {
311 Self(slice)
312 }
313
314 pub fn as_slice(&self) -> &'a [u8] {
315 self.0
316 }
317}
318
319impl <'a> From<&'a [u8]> for CompactThriftInputSlice<'a> {
320 fn from(slice: &'a [u8]) -> Self {
321 Self(slice)
322 }
323}
324
325impl <'i> CompactThriftInput<'i> for CompactThriftInputSlice<'i> {
326 #[inline]
327 fn read_byte(&mut self) -> Result<u8, ThriftError> {
328 if let [first, rest @ ..] = self.0 {
329 self.0 = rest;
330 Ok(*first)
331 } else {
332 Err(ThriftError::from(ErrorKind::UnexpectedEof))
333 }
334 }
335
336 #[inline]
337 fn read_double(&mut self) -> Result<f64, ThriftError> {
338 if self.0.len() < 8 {
339 return Err(ThriftError::from(ErrorKind::UnexpectedEof))
340 }
341 let value = f64::from_le_bytes(self.0[..8].try_into().unwrap());
342 self.0 = &self.0[8..];
343 Ok(value)
344 }
345
346 #[inline]
347 fn read_binary(&mut self) -> Result<Cow<'i, [u8]>, ThriftError> {
348 let len = self.read_len()?;
349 if len > MAX_BINARY_LEN {
350 return Err(ThriftError::InvalidBinaryLen);
351 }
352 if self.0.len() < len {
353 return Err(ThriftError::from(ErrorKind::UnexpectedEof))
354 }
355 let (first, rest) = std::mem::take(&mut self.0).split_at(len);
356 self.0 = rest;
357 Ok(Cow::Borrowed(first))
358 }
359
360 #[inline]
361 #[cfg(target_feature = "sse2")]
362 fn skip_integer(&mut self) -> Result<(), ThriftError> {
363 if self.0.len() >= 16 {
364 self.0 = unsafe { skip_uleb_sse2(self.0) };
365 Ok(())
366 } else {
367 self.0 = skip_uleb_fallback(self.0);
368 Ok(())
369 }
370 }
371
372 fn skip_binary(&mut self) -> Result<(), ThriftError> {
373 let len = self.read_len()?;
374 if len > MAX_BINARY_LEN {
375 return Err(ThriftError::InvalidBinaryLen);
376 }
377 if self.0.len() < len {
378 return Err(ThriftError::from(ErrorKind::UnexpectedEof))
379 }
380 self.0 = &self.0[len..];
381 Ok(())
382 }
383}
384
385pub trait CompactThriftOutput {
386 fn write_byte(&mut self, value: u8) -> Result<(), ThriftError>;
387 fn write_len(&mut self, value: usize) -> Result<(), ThriftError>;
388 fn write_i16(&mut self, value: i16) -> Result<(), ThriftError>;
389 fn write_i32(&mut self, value: i32) -> Result<(), ThriftError>;
390 fn write_i64(&mut self, value: i64) -> Result<(), ThriftError>;
391 fn write_double(&mut self, value: f64) -> Result<(), ThriftError>;
392 fn write_binary(&mut self, value: &[u8]) -> Result<(), ThriftError>;
393 fn write_string(&mut self, value: &str) -> Result<(), ThriftError> {
394 self.write_binary(value.as_bytes())
395 }
396}
397
398impl <W: Write> CompactThriftOutput for W {
399 fn write_byte(&mut self, value: u8) -> Result<(), ThriftError> {
400 self.write_all(&[value])?;
401 Ok(())
402 }
403
404 fn write_len(&mut self, value: usize) -> Result<(), ThriftError> {
405 encode_uleb(self, value as _)
406 }
407
408 fn write_i16(&mut self, value: i16) -> Result<(), ThriftError> {
409 encode_uleb(self, zigzag_encode16(value) as _)
410 }
411
412 fn write_i32(&mut self, value: i32) -> Result<(), ThriftError> {
413 encode_uleb(self, zigzag_encode32(value) as _)
414 }
415
416 fn write_i64(&mut self, value: i64) -> Result<(), ThriftError> {
417 encode_uleb(self, zigzag_encode64(value) as _)
418 }
419
420 fn write_double(&mut self, value: f64) -> Result<(), ThriftError> {
421 self.write_all(&value.to_le_bytes())?;
422 Ok(())
423 }
424
425 fn write_binary(&mut self, value: &[u8]) -> Result<(), ThriftError> {
426 if value.len() > MAX_BINARY_LEN {
427 return Err(ThriftError::InvalidBinaryLen);
428 }
429 self.write_len(value.len())?;
430 self.write_all(value)?;
431 Ok(())
432 }
433}
434
435pub trait CompactThriftProtocol<'i> {
436 const FIELD_TYPE: u8;
440
441 fn read_thrift<T: CompactThriftInput<'i>>(input: &mut T) -> Result<Self, ThriftError> where Self: Default{
442 let mut result = Self::default();
443 Self::fill_thrift(&mut result, input)?;
444 Ok(result)
445 }
446 fn fill_thrift<T: CompactThriftInput<'i>>(&mut self, input: &mut T) -> Result<(), ThriftError>;
447 #[inline]
448 fn fill_thrift_field<T: CompactThriftInput<'i>>(&mut self, input: &mut T, field_type: u8) -> Result<(), ThriftError> {
449 if field_type != Self::FIELD_TYPE {
450 return Err(ThriftError::InvalidType)
451 }
452 self.fill_thrift(input)
453 }
454 fn write_thrift<T: CompactThriftOutput>(&self, output: &mut T) -> Result<(), ThriftError>;
455 #[inline]
456 fn write_thrift_field<T: CompactThriftOutput>(&self, output: &mut T, field_id: i16, last_field_id: &mut i16) -> Result<(), ThriftError> {
457 write_field_header(output, Self::FIELD_TYPE, field_id, last_field_id)?;
458 self.write_thrift(output)?;
459 Ok(())
460 }
461}