1use std::io::Cursor;
18
19use anyhow::anyhow;
20use anyhow::Result;
21use bufsize::SizeCounter;
22use bytes::Bytes;
23use bytes::BytesMut;
24use ghost::phantom;
25
26use crate::binary_type::CopyFromBuf;
27use crate::bufext::BufExt;
28use crate::bufext::BufMutExt;
29use crate::bufext::DeserializeSource;
30use crate::deserialize::Deserialize;
31use crate::errors::ProtocolError;
32use crate::framing::Framing;
33use crate::protocol::Field;
34use crate::protocol::Protocol;
35use crate::protocol::ProtocolReader;
36use crate::protocol::ProtocolWriter;
37use crate::serialize::Serialize;
38use crate::thrift_protocol::MessageType;
39use crate::thrift_protocol::ProtocolID;
40use crate::ttype::TType;
41use crate::varint;
42
/// Protocol version stored in the low five bits of the version/type header byte.
const COMPACT_PROTOCOL_VERSION: u8 = 0x02;
/// First byte of every compact-protocol message header.
const PROTOCOL_ID: u8 = 0x82;
/// Selects the message-type bits (top three) of the version/type header byte.
const TYPE_MASK: u8 = 0b1110_0000;
/// Left shift that moves a message type into the `TYPE_MASK` bits.
const TYPE_SHIFT_AMOUNT: usize = 5;
/// Selects the version bits (low five) of the version/type header byte.
const VERSION_MASK: u8 = 0b0001_1111;
48
/// Type tags used on the wire by the compact protocol.
///
/// Each discriminant is the 4-bit code written into field and container
/// headers. A boolean struct field has no separate value byte: the
/// `BoolTrue`/`BoolFalse` tag in its header carries the value.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CType {
    /// Terminates a struct's field list.
    Stop = 0x0,
    BoolTrue = 0x1,
    BoolFalse = 0x2,
    Byte = 0x3,
    I16 = 0x4,
    I32 = 0x5,
    I64 = 0x6,
    Double = 0x7,
    /// Length-prefixed bytes; used for both `string` and `binary`.
    Binary = 0x8,
    List = 0x9,
    Set = 0xA,
    Map = 0xB,
    Struct = 0xC,
    Float = 0xD,
}
67
68impl TryFrom<i8> for CType {
69 type Error = anyhow::Error;
70
71 fn try_from(v: i8) -> Result<Self> {
72 let ret = match v {
73 0 => CType::Stop,
74 1 => CType::BoolTrue,
75 2 => CType::BoolFalse,
76 3 => CType::Byte,
77 4 => CType::I16,
78 5 => CType::I32,
79 6 => CType::I64,
80 7 => CType::Double,
81 8 => CType::Binary,
82 9 => CType::List,
83 10 => CType::Set,
84 11 => CType::Map,
85 12 => CType::Struct,
86 13 => CType::Float,
87 _ => bail_err!(ProtocolError::InvalidTypeTag),
88 };
89 Ok(ret)
90 }
91}
92
93impl From<CType> for TType {
94 fn from(cty: CType) -> TType {
95 match cty {
96 CType::Stop => TType::Stop,
97 CType::BoolFalse | CType::BoolTrue => TType::Bool,
98 CType::Byte => TType::Byte,
99 CType::Double => TType::Double,
100 CType::I16 => TType::I16,
101 CType::I32 => TType::I32,
102 CType::I64 => TType::I64,
103 CType::Binary => TType::String,
104 CType::Struct => TType::Struct,
105 CType::Map => TType::Map,
106 CType::Set => TType::Set,
107 CType::List => TType::List,
108 CType::Float => TType::Float,
109 }
110 }
111}
112
113impl From<TType> for CType {
114 fn from(tty: TType) -> CType {
115 match tty {
116 TType::Stop => CType::Stop,
117 TType::Bool => CType::BoolTrue,
118 TType::Byte => CType::Byte,
119 TType::Double => CType::Double,
120 TType::I16 => CType::I16,
121 TType::I32 => CType::I32,
122 TType::I64 => CType::I64,
123 TType::String => CType::Binary,
124 TType::Struct => CType::Struct,
125 TType::Map => CType::Map,
126 TType::Set => CType::Set,
127 TType::List => CType::List,
128 TType::Float => CType::Float,
129 bad => panic!("Don't know how to convert TType {:?} to CType", bad),
130 }
131 }
132}
133
134impl From<bool> for CType {
135 fn from(b: bool) -> CType {
136 if b { CType::BoolTrue } else { CType::BoolFalse }
137 }
138}
139
140#[phantom]
161#[derive(Copy, Clone)]
162pub struct CompactProtocol<F = Bytes>;
163
164#[derive(Debug, Clone)]
165struct EncState {
166 idxstack: Vec<i16>,
167 lastidx: i16,
168 field: Option<(TType, i16)>,
169}
170
171impl EncState {
172 fn new() -> Self {
173 EncState {
174 idxstack: Vec::new(),
175 lastidx: 0,
176 field: None,
177 }
178 }
179
180 fn struct_begin(&mut self) {
181 self.idxstack.push(self.lastidx);
182 self.lastidx = 0;
183 }
184
185 fn struct_end(&mut self) {
186 self.lastidx = self.idxstack.pop().expect("struct stack underrun");
187 }
188
189 fn field_begin(&mut self, tty: TType, idx: i16) {
190 self.field = Some((tty, idx));
191 }
192
193 fn field_get(&mut self) -> Option<(TType, i16)> {
194 self.field.take()
195 }
196
197 fn field_end(&mut self) {
198 debug_assert!(
199 self.field.is_none() ||
201 matches!(self.field, Some((TType::Void, _)))
203 )
204 }
205
206 fn in_field(&self) -> bool {
207 self.field.is_some()
208 }
209}
210
211pub struct CompactProtocolSerializer<B> {
212 state: EncState,
213 buffer: B,
214 string_limit: Option<usize>,
215 container_limit: Option<usize>,
216}
217
218pub struct CompactProtocolDeserializer<B> {
219 state: EncState,
220 boolfield: Option<CType>,
221 buffer: B,
222 string_limit: Option<usize>,
223 container_limit: Option<usize>,
224}
225
226impl<F> Protocol for CompactProtocol<F>
227where
228 F: Framing + 'static,
229{
230 type Frame = F;
231 type Sizer = CompactProtocolSerializer<SizeCounter>;
232 type Serializer = CompactProtocolSerializer<F::EncBuf>;
233 type Deserializer = CompactProtocolDeserializer<F::DecBuf>;
234
235 const PROTOCOL_ID: ProtocolID = ProtocolID::CompactProtocol;
236
237 fn serializer<SZ, SER>(size: SZ, ser: SER) -> <Self::Serializer as ProtocolWriter>::Final
238 where
239 SZ: FnOnce(&mut Self::Sizer),
240 SER: FnOnce(&mut Self::Serializer),
241 {
242 let mut sizer = CompactProtocolSerializer {
244 state: EncState::new(),
245 buffer: SizeCounter::new(),
246 container_limit: None,
247 string_limit: None,
248 };
249 size(&mut sizer);
250 let sz = sizer.finish();
251
252 let mut buf = CompactProtocolSerializer {
254 state: EncState::new(),
255 buffer: F::enc_with_capacity(sz),
256 container_limit: None,
257 string_limit: None,
258 };
259 ser(&mut buf);
260
261 buf.finish()
263 }
264
265 fn deserializer(buf: F::DecBuf) -> Self::Deserializer {
266 CompactProtocolDeserializer::new(buf)
267 }
268
269 fn into_buffer(deser: Self::Deserializer) -> F::DecBuf {
270 deser.into_inner()
271 }
272}
273
274impl<B> CompactProtocolSerializer<B>
275where
276 B: BufMutExt,
277{
278 pub fn with_buffer(buffer: B) -> Self {
279 Self {
280 state: EncState::new(),
281 buffer,
282 string_limit: None,
283 container_limit: None,
284 }
285 }
286
287 pub fn into_inner(self) -> B {
288 self.buffer
289 }
290
291 #[inline]
292 fn write_varint_i64(&mut self, v: i64) {
293 self.buffer.put_varint_i64(v)
294 }
295
296 #[inline]
297 fn write_varint_i32(&mut self, v: i32) {
298 self.write_varint_i64(v as i64);
299 }
300
301 #[inline]
302 fn write_varint_i16(&mut self, v: i16) {
303 self.write_varint_i64(v as i64);
304 }
305
306 fn write_field_id(&mut self, cty: CType) {
310 match self.state.field_get() {
311 None => {}
312 Some((tty, idx)) => {
313 debug_assert_eq!(tty, TType::from(cty));
314
315 let delta = idx - self.state.lastidx;
316 self.state.lastidx = idx;
317
318 if delta <= 0 || delta > 15 {
319 self.buffer.put_u8(cty as u8);
320 self.write_varint_i16(idx);
321 } else {
322 self.buffer.put_u8(((delta as u8) << 4) | (cty as u8));
323 }
324 }
325 }
326 }
327
328 fn write_sequence(&mut self, elem_type: TType, size: usize) {
330 assert!(
331 self.container_limit.map_or(true, |lim| size < lim),
332 "container too large {}, lim {:?}",
333 size,
334 self.container_limit
335 );
336
337 let cty = CType::from(elem_type) as u8;
338
339 if size < 0x0f {
341 self.buffer.put_u8((size as u8) << 4 | cty);
342 } else {
343 self.buffer.put_u8((0x0f << 4) | cty);
344 self.buffer.put_varint_u64(size as u64);
345 }
346 }
347}
348
349impl<B: BufMutExt> ProtocolWriter for CompactProtocolSerializer<B> {
350 type Final = B::Final; fn write_message_begin(&mut self, name: &str, msgtype: MessageType, seqid: u32) {
353 let msgtype = msgtype as u8;
354 self.buffer.put_u8(PROTOCOL_ID);
355 self.buffer
356 .put_u8(COMPACT_PROTOCOL_VERSION | ((msgtype << TYPE_SHIFT_AMOUNT) & TYPE_MASK));
357 self.buffer.put_varint_u64(seqid as u64);
358 self.write_string(name);
359 }
360
361 #[inline]
362 fn write_message_end(&mut self) {}
363
364 #[inline]
365 fn write_struct_begin(&mut self, _name: &str) {
366 self.write_field_id(CType::Struct);
367 self.state.struct_begin();
368 }
369
370 #[inline]
371 fn write_struct_end(&mut self) {
372 self.state.struct_end();
373 }
374
375 #[inline]
376 fn write_field_begin(&mut self, _name: &str, type_id: TType, id: i16) {
377 self.state.field_begin(type_id, id);
379 }
380
381 #[inline]
382 fn write_field_end(&mut self) {
383 self.state.field_end()
384 }
385
386 #[inline]
387 fn write_field_stop(&mut self) {
388 self.buffer.put_u8(CType::Stop as u8)
389 }
390
391 fn write_map_begin(&mut self, key_type: TType, value_type: TType, size: usize) {
392 assert!(
393 self.container_limit.map_or(true, |lim| size < lim),
394 "map too large {}, lim {:?}",
395 size,
396 self.container_limit
397 );
398
399 self.write_field_id(CType::Map);
400 self.buffer.put_varint_u64(size as u64);
401 if size > 0 {
402 let ckty = CType::from(key_type);
403 let cvty = CType::from(value_type);
404
405 self.buffer.put_u8((ckty as u8) << 4 | (cvty as u8));
406 }
407 }
408
409 #[inline]
410 fn write_map_key_begin(&mut self) {}
411
412 #[inline]
413 fn write_map_value_begin(&mut self) {}
414
415 #[inline]
416 fn write_map_end(&mut self) {}
417
418 fn write_list_begin(&mut self, elem_type: TType, size: usize) {
419 assert!(self.container_limit.map_or(true, |lim| size < lim));
420 self.write_field_id(CType::List);
421 self.write_sequence(elem_type, size);
422 }
423
424 #[inline]
425 fn write_list_value_begin(&mut self) {}
426
427 #[inline]
428 fn write_list_end(&mut self) {}
429
430 fn write_set_begin(&mut self, elem_type: TType, size: usize) {
431 self.write_field_id(CType::Set);
432 self.write_sequence(elem_type, size);
433 }
434
435 #[inline]
436 fn write_set_value_begin(&mut self) {}
437
438 fn write_set_end(&mut self) {}
439
440 fn write_bool(&mut self, value: bool) {
441 if self.state.in_field() {
445 self.write_field_id(CType::from(value))
446 } else if value {
447 self.write_byte(CType::BoolTrue as i8)
448 } else {
449 self.write_byte(CType::BoolFalse as i8)
450 }
451 }
452
453 fn write_byte(&mut self, value: i8) {
454 self.write_field_id(CType::Byte);
455 self.buffer.put_i8(value)
456 }
457
458 fn write_i16(&mut self, value: i16) {
459 self.write_field_id(CType::I16);
460 self.write_varint_i16(value)
461 }
462
463 fn write_i32(&mut self, value: i32) {
464 self.write_field_id(CType::I32);
465 self.write_varint_i32(value)
466 }
467
468 fn write_i64(&mut self, value: i64) {
469 self.write_field_id(CType::I64);
470 self.write_varint_i64(value)
471 }
472
473 fn write_double(&mut self, value: f64) {
474 self.write_field_id(CType::Double);
475 self.buffer.put_f64(value)
476 }
477
478 fn write_float(&mut self, value: f32) {
479 self.write_field_id(CType::Float);
480 self.buffer.put_f32(value)
481 }
482
483 #[inline]
484 fn write_string(&mut self, value: &str) {
485 self.write_binary(value.as_bytes());
486 }
487
488 fn write_binary(&mut self, value: &[u8]) {
489 let size = value.len();
490 assert!(
491 self.string_limit.map_or(true, |lim| size < lim),
492 "string too large {}, lim {:?}",
493 size,
494 self.string_limit
495 );
496
497 self.write_field_id(CType::Binary);
498 self.buffer.put_varint_u64(size as u64);
499 self.buffer.put_slice(value)
500 }
501
502 fn finish(self) -> B::Final {
503 self.buffer.finalize()
504 }
505}
506
507impl<B: BufExt> CompactProtocolDeserializer<B> {
508 pub fn new(buffer: B) -> Self {
509 CompactProtocolDeserializer {
510 state: EncState::new(),
511 boolfield: None,
512 buffer,
513 string_limit: None,
514 container_limit: None,
515 }
516 }
517
518 pub fn into_inner(self) -> B {
519 self.buffer
520 }
521
522 fn peek_bytes(&self, len: usize) -> Option<&[u8]> {
523 if self.buffer.chunk().len() >= len {
524 Some(&self.buffer.chunk()[..len])
525 } else {
526 None
527 }
528 }
529
530 fn read_varint_u64(&mut self) -> Result<u64> {
531 varint::read_u64(&mut self.buffer)
532 }
533
534 fn read_varint_i64(&mut self) -> Result<i64> {
535 self.read_varint_u64().map(varint::unzigzag)
536 }
537
538 fn read_varint_i32(&mut self) -> Result<i32> {
539 self.read_varint_i64()
540 .and_then(|v| i32::try_from(v).map_err(|_| ProtocolError::InvalidValue.into()))
541 }
542
543 fn read_varint_i16(&mut self) -> Result<i16> {
544 self.read_varint_i64()
545 .and_then(|v| i16::try_from(v).map_err(|_| (ProtocolError::InvalidValue).into()))
546 }
547}
548
549impl<B: BufExt> ProtocolReader for CompactProtocolDeserializer<B> {
550 fn read_message_begin<F, T>(&mut self, msgfn: F) -> Result<(T, MessageType, u32)>
551 where
552 F: FnOnce(&[u8]) -> T,
553 {
554 let protocolid = self.read_byte()? as u8;
555 ensure_err!(protocolid == PROTOCOL_ID, ProtocolError::BadVersion);
556
557 let vandty = self.read_byte()? as u8;
558 ensure_err!(
559 (vandty & VERSION_MASK) == COMPACT_PROTOCOL_VERSION,
560 ProtocolError::BadVersion
561 );
562
563 let msgty = (vandty & TYPE_MASK) >> TYPE_SHIFT_AMOUNT;
564 let msgty = MessageType::try_from(msgty as u32)?;
565
566 let seqid = self.read_varint_u64()? as u32;
567
568 let name = {
569 let len = self.read_varint_u64()? as usize;
570 let (len, name) = {
571 if self.peek_bytes(len).is_some() {
572 let namebuf = self.peek_bytes(len).unwrap();
573 (namebuf.len(), msgfn(namebuf))
574 } else {
575 ensure_err!(
576 self.buffer.remaining() >= len,
577 ProtocolError::InvalidDataLength
578 );
579 let namebuf: Vec<u8> = Vec::copy_from_buf(&mut self.buffer, len);
580 (0, msgfn(namebuf.as_slice()))
581 }
582 };
583 self.buffer.advance(len);
584 name
585 };
586
587 Ok((name, msgty, seqid))
588 }
589
590 fn read_message_end(&mut self) -> Result<()> {
591 Ok(())
592 }
593
594 fn read_struct_begin<F, T>(&mut self, namefn: F) -> Result<T>
595 where
596 F: FnOnce(&[u8]) -> T,
597 {
598 self.state.struct_begin();
599 Ok(namefn(&[]))
600 }
601
602 fn read_struct_end(&mut self) -> Result<()> {
603 self.state.struct_end();
604 Ok(())
605 }
606
607 fn read_field_begin<F, T>(&mut self, fieldfn: F, _fields: &[Field]) -> Result<(T, TType, i16)>
608 where
609 F: FnOnce(&[u8]) -> T,
610 {
611 let tyid = self.read_byte()?;
612 let cty = CType::try_from(tyid & 0x0f)?;
613 let didx = (tyid >> 4) & 0x0f;
614
615 let tty = TType::from(cty);
616
617 let idx = match (tty, didx) {
619 (TType::Stop, _) => 0,
620 (_, 0) => self.read_varint_i16()?,
621 (_, didx) => self.state.lastidx + (didx as i16),
622 };
623
624 self.state.lastidx = idx;
625
626 if tty == TType::Bool {
627 self.boolfield = Some(cty);
628 }
629
630 let f = fieldfn(&[]);
631 Ok((f, tty, idx))
632 }
633
634 fn read_field_end(&mut self) -> Result<()> {
635 Ok(())
636 }
637
638 fn read_map_begin(&mut self) -> Result<(TType, TType, Option<usize>)> {
639 let size = self.read_varint_u64()? as usize;
640
641 ensure_err!(
642 self.container_limit.map_or(true, |lim| size < lim),
643 ProtocolError::InvalidDataLength
644 );
645
646 let kvtype = if size > 0 { self.read_byte()? } else { 0 };
648
649 let kcty = CType::try_from((kvtype >> 4) & 0x0f)?;
650 let vcty = CType::try_from((kvtype) & 0x0f)?;
651
652 Ok((TType::from(kcty), TType::from(vcty), Some(size)))
653 }
654
655 #[inline]
656 fn read_map_key_begin(&mut self) -> Result<bool> {
657 Ok(true)
658 }
659
660 #[inline]
661 fn read_map_value_begin(&mut self) -> Result<()> {
662 Ok(())
663 }
664
665 #[inline]
666 fn read_map_value_end(&mut self) -> Result<()> {
667 Ok(())
668 }
669
670 fn read_map_end(&mut self) -> Result<()> {
671 Ok(())
672 }
673
674 fn read_list_begin(&mut self) -> Result<(TType, Option<usize>)> {
675 let szty = self.read_byte()?;
676 let cty = CType::try_from(szty & 0x0f)?;
677 let elem_type = TType::from(cty);
678
679 let size = match (szty >> 4) & 0x0f {
680 0x0f => self.read_varint_u64()? as usize,
681 sz => sz as usize,
682 };
683
684 ensure_err!(
685 self.container_limit.map_or(true, |lim| size < lim),
686 ProtocolError::InvalidDataLength
687 );
688
689 Ok((elem_type, Some(size)))
690 }
691
692 #[inline]
693 fn read_list_value_begin(&mut self) -> Result<bool> {
694 Ok(true)
695 }
696
697 #[inline]
698 fn read_list_value_end(&mut self) -> Result<()> {
699 Ok(())
700 }
701
702 fn read_list_end(&mut self) -> Result<()> {
703 Ok(())
704 }
705
706 fn read_set_begin(&mut self) -> Result<(TType, Option<usize>)> {
707 self.read_list_begin()
708 }
709
710 #[inline]
711 fn read_set_value_begin(&mut self) -> Result<bool> {
712 Ok(true)
713 }
714
715 #[inline]
716 fn read_set_value_end(&mut self) -> Result<()> {
717 Ok(())
718 }
719
720 fn read_set_end(&mut self) -> Result<()> {
721 Ok(())
722 }
723
724 fn read_bool(&mut self) -> Result<bool> {
725 let cty = match self.boolfield.take() {
728 None => CType::try_from(self.read_byte()?)?,
729 Some(cty) => cty,
730 };
731
732 Ok(cty == CType::BoolTrue)
736 }
737
738 fn read_byte(&mut self) -> Result<i8> {
739 ensure_err!(self.buffer.remaining() >= 1, ProtocolError::EOF);
740
741 Ok(self.buffer.get_i8())
742 }
743
744 fn read_i16(&mut self) -> Result<i16> {
745 self.read_varint_i16()
746 }
747
748 fn read_i32(&mut self) -> Result<i32> {
749 self.read_varint_i32()
750 }
751
752 fn read_i64(&mut self) -> Result<i64> {
753 self.read_varint_i64()
754 }
755
756 fn read_double(&mut self) -> Result<f64> {
757 ensure_err!(self.buffer.remaining() >= 8, ProtocolError::EOF);
758
759 Ok(self.buffer.get_f64())
760 }
761
762 fn read_float(&mut self) -> Result<f32> {
763 ensure_err!(self.buffer.remaining() >= 4, ProtocolError::EOF);
764
765 Ok(self.buffer.get_f32())
766 }
767
768 fn read_string(&mut self) -> Result<String> {
769 let vec = self.read_binary::<Vec<u8>>()?;
770
771 String::from_utf8(vec)
772 .map_err(|utf8_error| anyhow!("deserializing `string` from Thrift compact protocol got invalid utf-8, you need to use `binary` instead: {utf8_error}"))
773 }
774
775 fn read_binary<V: CopyFromBuf>(&mut self) -> Result<V> {
776 let received_len = self.read_varint_u64()? as usize;
777 ensure_err!(
778 self.string_limit.map_or(true, |lim| received_len < lim),
779 ProtocolError::InvalidDataLength
780 );
781 ensure_err!(self.buffer.remaining() >= received_len, ProtocolError::EOF);
782
783 Ok(V::copy_from_buf(&mut self.buffer, received_len))
784 }
785}
786
787pub fn serialize_size<T>(v: &T) -> usize
789where
790 T: Serialize<CompactProtocolSerializer<SizeCounter>>,
791{
792 let mut sizer = CompactProtocolSerializer {
793 state: EncState::new(),
794 buffer: SizeCounter::new(),
795 container_limit: None,
796 string_limit: None,
797 };
798 v.write(&mut sizer);
799 sizer.finish()
800}
801
802pub fn serialize_to_buffer<T>(v: T, buffer: BytesMut) -> CompactProtocolSerializer<BytesMut>
806where
807 T: Serialize<CompactProtocolSerializer<BytesMut>>,
808{
809 let mut buf = CompactProtocolSerializer {
811 state: EncState::new(),
812 buffer,
813 container_limit: None,
814 string_limit: None,
815 };
816 v.write(&mut buf);
817 buf
818}
819
820pub trait SerializeRef:
821 Serialize<CompactProtocolSerializer<SizeCounter>> + Serialize<CompactProtocolSerializer<BytesMut>>
822where
823 for<'a> &'a Self: Serialize<CompactProtocolSerializer<SizeCounter>>,
824 for<'a> &'a Self: Serialize<CompactProtocolSerializer<BytesMut>>,
825{
826}
827
828impl<T> SerializeRef for T
829where
830 T: Serialize<CompactProtocolSerializer<BytesMut>>,
831 T: Serialize<CompactProtocolSerializer<SizeCounter>>,
832 for<'a> &'a T: Serialize<CompactProtocolSerializer<BytesMut>>,
833 for<'a> &'a T: Serialize<CompactProtocolSerializer<SizeCounter>>,
834{
835}
836
837pub fn serialize<T>(v: T) -> Bytes
839where
840 T: Serialize<CompactProtocolSerializer<SizeCounter>>
841 + Serialize<CompactProtocolSerializer<BytesMut>>,
842{
843 let sz = serialize_size(&v);
844 let buf = serialize_to_buffer(v, BytesMut::with_capacity(sz));
845 buf.finish()
847}
848
849pub trait DeserializeSlice:
850 for<'a> Deserialize<CompactProtocolDeserializer<Cursor<&'a [u8]>>>
851{
852}
853
854impl<T> DeserializeSlice for T where
855 T: for<'a> Deserialize<CompactProtocolDeserializer<Cursor<&'a [u8]>>>
856{
857}
858
859pub fn deserialize<T, B, C>(b: B) -> Result<T>
861where
862 B: Into<DeserializeSource<C>>,
863 C: BufExt,
864 T: Deserialize<CompactProtocolDeserializer<C>>,
865{
866 let source: DeserializeSource<C> = b.into();
867 let mut deser = CompactProtocolDeserializer::new(source.0);
868 T::read(&mut deser)
869}