1use std::io::Cursor;
18
19use bufsize::SizeCounter;
20use bytes::Bytes;
21use bytes::BytesMut;
22use ghost::phantom;
23
24use crate::binary_type::CopyFromBuf;
25use crate::bufext::BufExt;
26use crate::bufext::BufMutExt;
27use crate::bufext::DeserializeSource;
28use crate::deserialize::Deserialize;
29use crate::errors::ProtocolError;
30use crate::framing::Framing;
31use crate::protocol::Field;
32use crate::protocol::Protocol;
33use crate::protocol::ProtocolReader;
34use crate::protocol::ProtocolWriter;
35use crate::serialize::Serialize;
36use crate::thrift_protocol::MessageType;
37use crate::thrift_protocol::ProtocolID;
38use crate::ttype::TType;
39use crate::varint;
40use crate::Result;
41
/// Version written into (and required in) the low bits of the second message
/// header byte.
const COMPACT_PROTOCOL_VERSION: u8 = 0x02;
/// First byte of every compact-protocol message header.
const PROTOCOL_ID: u8 = 0x82;
/// Bits of the second header byte that carry the message type (top three).
const TYPE_MASK: u8 = 0b1110_0000;
/// Left shift that moves a message type into the `TYPE_MASK` bits.
const TYPE_SHIFT_AMOUNT: usize = 5;
/// Bits of the second header byte that carry the protocol version (low five).
const VERSION_MASK: u8 = 0b0001_1111;
47
/// Type tags used by the compact protocol on the wire.
///
/// Every tag fits in four bits. It forms the low nibble of a field header
/// byte and of a list/set header byte. A map header packs the key and value
/// tags into the high and low nibbles of one byte.
///
/// Booleans have two tags because a boolean struct field carries its value in
/// the field header's type tag rather than in a separate payload byte.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
#[repr(u8)]
pub enum CType {
    /// Terminates a struct's field list.
    Stop = 0,
    /// Boolean `true`; also the tag used for boolean container elements.
    BoolTrue = 1,
    /// Boolean `false`.
    BoolFalse = 2,
    /// Single signed byte.
    Byte = 3,
    /// 16-bit integer, encoded as a varint.
    I16 = 4,
    /// 32-bit integer, encoded as a varint.
    I32 = 5,
    /// 64-bit integer, encoded as a varint.
    I64 = 6,
    /// 64-bit float.
    Double = 7,
    /// Length-prefixed bytes; also used for strings.
    Binary = 8,
    /// List container.
    List = 9,
    /// Set container.
    Set = 10,
    /// Map container.
    Map = 11,
    /// Nested struct.
    Struct = 12,
    /// 32-bit float.
    Float = 13,
}
66
67impl TryFrom<i8> for CType {
68 type Error = anyhow::Error;
69
70 fn try_from(v: i8) -> Result<Self> {
71 let ret = match v {
72 0 => CType::Stop,
73 1 => CType::BoolTrue,
74 2 => CType::BoolFalse,
75 3 => CType::Byte,
76 4 => CType::I16,
77 5 => CType::I32,
78 6 => CType::I64,
79 7 => CType::Double,
80 8 => CType::Binary,
81 9 => CType::List,
82 10 => CType::Set,
83 11 => CType::Map,
84 12 => CType::Struct,
85 13 => CType::Float,
86 _ => bail_err!(ProtocolError::InvalidTypeTag),
87 };
88 Ok(ret)
89 }
90}
91
92impl From<CType> for TType {
93 fn from(cty: CType) -> TType {
94 match cty {
95 CType::Stop => TType::Stop,
96 CType::BoolFalse | CType::BoolTrue => TType::Bool,
97 CType::Byte => TType::Byte,
98 CType::Double => TType::Double,
99 CType::I16 => TType::I16,
100 CType::I32 => TType::I32,
101 CType::I64 => TType::I64,
102 CType::Binary => TType::String,
103 CType::Struct => TType::Struct,
104 CType::Map => TType::Map,
105 CType::Set => TType::Set,
106 CType::List => TType::List,
107 CType::Float => TType::Float,
108 }
109 }
110}
111
112impl From<TType> for CType {
113 fn from(tty: TType) -> CType {
114 match tty {
115 TType::Stop => CType::Stop,
116 TType::Bool => CType::BoolTrue,
117 TType::Byte => CType::Byte,
118 TType::Double => CType::Double,
119 TType::I16 => CType::I16,
120 TType::I32 => CType::I32,
121 TType::I64 => CType::I64,
122 TType::String => CType::Binary,
123 TType::Struct => CType::Struct,
124 TType::Map => CType::Map,
125 TType::Set => CType::Set,
126 TType::List => CType::List,
127 TType::Float => CType::Float,
128 bad => panic!("Don't know how to convert TType {:?} to CType", bad),
129 }
130 }
131}
132
133impl From<bool> for CType {
134 fn from(b: bool) -> CType {
135 if b { CType::BoolTrue } else { CType::BoolFalse }
136 }
137}
138
#[phantom]
/// The Thrift compact protocol, parameterized by its framing `F`
/// (`Bytes` by default).
///
/// This is a data-less marker type. Its `Protocol` implementation writes
/// with `CompactProtocolSerializer` over the framing's encode buffer and
/// reads with `CompactProtocolDeserializer` over its decode buffer.
#[derive(Copy, Clone)]
pub struct CompactProtocol<F = Bytes>;
162
163#[derive(Debug, Clone)]
164struct EncState {
165 idxstack: Vec<i16>,
166 lastidx: i16,
167 field: Option<(TType, i16)>,
168}
169
170impl EncState {
171 fn new() -> Self {
172 EncState {
173 idxstack: Vec::new(),
174 lastidx: 0,
175 field: None,
176 }
177 }
178
179 fn struct_begin(&mut self) {
180 self.idxstack.push(self.lastidx);
181 self.lastidx = 0;
182 }
183
184 fn struct_end(&mut self) {
185 self.lastidx = self.idxstack.pop().expect("struct stack underrun");
186 }
187
188 fn field_begin(&mut self, tty: TType, idx: i16) {
189 self.field = Some((tty, idx));
190 }
191
192 fn field_get(&mut self) -> Option<(TType, i16)> {
193 self.field.take()
194 }
195
196 fn field_end(&mut self) {
197 debug_assert!(
198 self.field.is_none() ||
200 matches!(self.field, Some((TType::Void, _)))
202 )
203 }
204
205 fn in_field(&self) -> bool {
206 self.field.is_some()
207 }
208}
209
/// Writer for the compact protocol.
///
/// The writer methods are implemented for buffers that implement
/// `BufMutExt`. `SizeCounter` is used when only measuring the encoded size.
pub struct CompactProtocolSerializer<B> {
    // Field-id delta state plus the pending (not yet written) field header.
    state: EncState,
    // Destination for the encoded bytes.
    buffer: B,
    // Optional exclusive upper bound on string/binary length; writing a
    // longer value panics.
    string_limit: Option<usize>,
    // Optional exclusive upper bound on container element count; writing a
    // larger container panics.
    container_limit: Option<usize>,
}
216
/// Reader for the compact protocol.
///
/// The reader methods are implemented for buffers that implement `BufExt`.
pub struct CompactProtocolDeserializer<B> {
    // Field-id delta state across nested structs.
    state: EncState,
    // Boolean tag taken from the most recent boolean field header; consumed
    // by the next `read_bool`.
    boolfield: Option<CType>,
    // Source of the encoded bytes.
    buffer: B,
    // Optional exclusive upper bound on string/binary length; a longer value
    // is rejected with `InvalidDataLength`.
    string_limit: Option<usize>,
    // Optional exclusive upper bound on container sizes; a larger container
    // is rejected with `InvalidDataLength`.
    container_limit: Option<usize>,
}
224
225impl<F> Protocol for CompactProtocol<F>
226where
227 F: Framing + 'static,
228{
229 type Frame = F;
230 type Sizer = CompactProtocolSerializer<SizeCounter>;
231 type Serializer = CompactProtocolSerializer<F::EncBuf>;
232 type Deserializer = CompactProtocolDeserializer<F::DecBuf>;
233
234 const PROTOCOL_ID: ProtocolID = ProtocolID::CompactProtocol;
235
236 fn serializer<SZ, SER>(size: SZ, ser: SER) -> <Self::Serializer as ProtocolWriter>::Final
237 where
238 SZ: FnOnce(&mut Self::Sizer),
239 SER: FnOnce(&mut Self::Serializer),
240 {
241 let mut sizer = CompactProtocolSerializer {
243 state: EncState::new(),
244 buffer: SizeCounter::new(),
245 container_limit: None,
246 string_limit: None,
247 };
248 size(&mut sizer);
249 let sz = sizer.finish();
250
251 let mut buf = CompactProtocolSerializer {
253 state: EncState::new(),
254 buffer: F::enc_with_capacity(sz),
255 container_limit: None,
256 string_limit: None,
257 };
258 ser(&mut buf);
259
260 buf.finish()
262 }
263
264 fn deserializer(buf: F::DecBuf) -> Self::Deserializer {
265 CompactProtocolDeserializer::new(buf)
266 }
267
268 fn into_buffer(deser: Self::Deserializer) -> F::DecBuf {
269 deser.into_inner()
270 }
271}
272
273impl<B> CompactProtocolSerializer<B>
274where
275 B: BufMutExt,
276{
277 pub fn with_buffer(buffer: B) -> Self {
278 Self {
279 state: EncState::new(),
280 buffer,
281 string_limit: None,
282 container_limit: None,
283 }
284 }
285
286 pub fn into_inner(self) -> B {
287 self.buffer
288 }
289
290 #[inline]
291 fn write_varint_i64(&mut self, v: i64) {
292 self.buffer.put_varint_i64(v)
293 }
294
295 #[inline]
296 fn write_varint_i32(&mut self, v: i32) {
297 self.write_varint_i64(v as i64);
298 }
299
300 #[inline]
301 fn write_varint_i16(&mut self, v: i16) {
302 self.write_varint_i64(v as i64);
303 }
304
305 fn write_field_id(&mut self, cty: CType) {
309 match self.state.field_get() {
310 None => {}
311 Some((tty, idx)) => {
312 debug_assert_eq!(tty, TType::from(cty));
313
314 let delta = idx - self.state.lastidx;
315 self.state.lastidx = idx;
316
317 if delta <= 0 || delta > 15 {
318 self.buffer.put_u8(cty as u8);
319 self.write_varint_i16(idx);
320 } else {
321 self.buffer.put_u8(((delta as u8) << 4) | (cty as u8));
322 }
323 }
324 }
325 }
326
327 fn write_sequence(&mut self, elem_type: TType, size: usize) {
329 assert!(
330 self.container_limit.map_or(true, |lim| size < lim),
331 "container too large {}, lim {:?}",
332 size,
333 self.container_limit
334 );
335
336 let cty = CType::from(elem_type) as u8;
337
338 if size < 0x0f {
340 self.buffer.put_u8((size as u8) << 4 | cty);
341 } else {
342 self.buffer.put_u8((0x0f << 4) | cty);
343 self.buffer.put_varint_u64(size as u64);
344 }
345 }
346}
347
348impl<B: BufMutExt> ProtocolWriter for CompactProtocolSerializer<B> {
349 type Final = B::Final; fn write_message_begin(&mut self, name: &str, msgtype: MessageType, seqid: u32) {
352 let msgtype = msgtype as u8;
353 self.buffer.put_u8(PROTOCOL_ID);
354 self.buffer
355 .put_u8(COMPACT_PROTOCOL_VERSION | ((msgtype << TYPE_SHIFT_AMOUNT) & TYPE_MASK));
356 self.buffer.put_varint_u64(seqid as u64);
357 self.write_string(name);
358 }
359
360 #[inline]
361 fn write_message_end(&mut self) {}
362
363 #[inline]
364 fn write_struct_begin(&mut self, _name: &str) {
365 self.write_field_id(CType::Struct);
366 self.state.struct_begin();
367 }
368
369 #[inline]
370 fn write_struct_end(&mut self) {
371 self.state.struct_end();
372 }
373
374 #[inline]
375 fn write_field_begin(&mut self, _name: &str, type_id: TType, id: i16) {
376 self.state.field_begin(type_id, id);
378 }
379
380 #[inline]
381 fn write_field_end(&mut self) {
382 self.state.field_end()
383 }
384
385 #[inline]
386 fn write_field_stop(&mut self) {
387 self.buffer.put_u8(CType::Stop as u8)
388 }
389
390 fn write_map_begin(&mut self, key_type: TType, value_type: TType, size: usize) {
391 assert!(
392 self.container_limit.map_or(true, |lim| size < lim),
393 "map too large {}, lim {:?}",
394 size,
395 self.container_limit
396 );
397
398 self.write_field_id(CType::Map);
399 self.buffer.put_varint_u64(size as u64);
400 if size > 0 {
401 let ckty = CType::from(key_type);
402 let cvty = CType::from(value_type);
403
404 self.buffer.put_u8((ckty as u8) << 4 | (cvty as u8));
405 }
406 }
407
408 #[inline]
409 fn write_map_key_begin(&mut self) {}
410
411 #[inline]
412 fn write_map_value_begin(&mut self) {}
413
414 #[inline]
415 fn write_map_end(&mut self) {}
416
417 fn write_list_begin(&mut self, elem_type: TType, size: usize) {
418 assert!(self.container_limit.map_or(true, |lim| size < lim));
419 self.write_field_id(CType::List);
420 self.write_sequence(elem_type, size);
421 }
422
423 #[inline]
424 fn write_list_value_begin(&mut self) {}
425
426 #[inline]
427 fn write_list_end(&mut self) {}
428
429 fn write_set_begin(&mut self, elem_type: TType, size: usize) {
430 self.write_field_id(CType::Set);
431 self.write_sequence(elem_type, size);
432 }
433
434 #[inline]
435 fn write_set_value_begin(&mut self) {}
436
437 fn write_set_end(&mut self) {}
438
439 fn write_bool(&mut self, value: bool) {
440 if self.state.in_field() {
444 self.write_field_id(CType::from(value))
445 } else if value {
446 self.write_byte(CType::BoolTrue as i8)
447 } else {
448 self.write_byte(CType::BoolFalse as i8)
449 }
450 }
451
452 fn write_byte(&mut self, value: i8) {
453 self.write_field_id(CType::Byte);
454 self.buffer.put_i8(value)
455 }
456
457 fn write_i16(&mut self, value: i16) {
458 self.write_field_id(CType::I16);
459 self.write_varint_i16(value)
460 }
461
462 fn write_i32(&mut self, value: i32) {
463 self.write_field_id(CType::I32);
464 self.write_varint_i32(value)
465 }
466
467 fn write_i64(&mut self, value: i64) {
468 self.write_field_id(CType::I64);
469 self.write_varint_i64(value)
470 }
471
472 fn write_double(&mut self, value: f64) {
473 self.write_field_id(CType::Double);
474 self.buffer.put_f64(value)
475 }
476
477 fn write_float(&mut self, value: f32) {
478 self.write_field_id(CType::Float);
479 self.buffer.put_f32(value)
480 }
481
482 #[inline]
483 fn write_string(&mut self, value: &str) {
484 self.write_binary(value.as_bytes());
485 }
486
487 fn write_binary(&mut self, value: &[u8]) {
488 let size = value.len();
489 assert!(
490 self.string_limit.map_or(true, |lim| size < lim),
491 "string too large {}, lim {:?}",
492 size,
493 self.string_limit
494 );
495
496 self.write_field_id(CType::Binary);
497 self.buffer.put_varint_u64(size as u64);
498 self.buffer.put_slice(value)
499 }
500
501 fn finish(self) -> B::Final {
502 self.buffer.finalize()
503 }
504}
505
506impl<B: BufExt> CompactProtocolDeserializer<B> {
507 pub fn new(buffer: B) -> Self {
508 CompactProtocolDeserializer {
509 state: EncState::new(),
510 boolfield: None,
511 buffer,
512 string_limit: None,
513 container_limit: None,
514 }
515 }
516
517 pub fn into_inner(self) -> B {
518 self.buffer
519 }
520
521 fn peek_bytes(&self, len: usize) -> Option<&[u8]> {
522 if self.buffer.chunk().len() >= len {
523 Some(&self.buffer.chunk()[..len])
524 } else {
525 None
526 }
527 }
528
529 fn read_varint_u64(&mut self) -> Result<u64> {
530 varint::read_u64(&mut self.buffer)
531 }
532
533 fn read_varint_i64(&mut self) -> Result<i64> {
534 self.read_varint_u64().map(varint::unzigzag)
535 }
536
537 fn read_varint_i32(&mut self) -> Result<i32> {
538 self.read_varint_i64()
539 .and_then(|v| i32::try_from(v).map_err(|_| ProtocolError::InvalidValue.into()))
540 }
541
542 fn read_varint_i16(&mut self) -> Result<i16> {
543 self.read_varint_i64()
544 .and_then(|v| i16::try_from(v).map_err(|_| (ProtocolError::InvalidValue).into()))
545 }
546}
547
548impl<B: BufExt> ProtocolReader for CompactProtocolDeserializer<B> {
549 fn read_message_begin<F, T>(&mut self, msgfn: F) -> Result<(T, MessageType, u32)>
550 where
551 F: FnOnce(&[u8]) -> T,
552 {
553 let protocolid = self.read_byte()? as u8;
554 ensure_err!(protocolid == PROTOCOL_ID, ProtocolError::BadVersion);
555
556 let vandty = self.read_byte()? as u8;
557 ensure_err!(
558 (vandty & VERSION_MASK) == COMPACT_PROTOCOL_VERSION,
559 ProtocolError::BadVersion
560 );
561
562 let msgty = (vandty & TYPE_MASK) >> TYPE_SHIFT_AMOUNT;
563 let msgty = MessageType::try_from(msgty as u32)?;
564
565 let seqid = self.read_varint_u64()? as u32;
566
567 let name = {
568 let len = self.read_varint_u64()? as usize;
569 let (len, name) = {
570 if self.peek_bytes(len).is_some() {
571 let namebuf = self.peek_bytes(len).unwrap();
572 (namebuf.len(), msgfn(namebuf))
573 } else {
574 ensure_err!(
575 self.buffer.remaining() >= len,
576 ProtocolError::InvalidDataLength
577 );
578 let namebuf: Vec<u8> = Vec::copy_from_buf(&mut self.buffer, len);
579 (0, msgfn(namebuf.as_slice()))
580 }
581 };
582 self.buffer.advance(len);
583 name
584 };
585
586 Ok((name, msgty, seqid))
587 }
588
589 fn read_message_end(&mut self) -> Result<()> {
590 Ok(())
591 }
592
593 fn read_struct_begin<F, T>(&mut self, namefn: F) -> Result<T>
594 where
595 F: FnOnce(&[u8]) -> T,
596 {
597 self.state.struct_begin();
598 Ok(namefn(&[]))
599 }
600
601 fn read_struct_end(&mut self) -> Result<()> {
602 self.state.struct_end();
603 Ok(())
604 }
605
606 fn read_field_begin<F, T>(&mut self, fieldfn: F, _fields: &[Field]) -> Result<(T, TType, i16)>
607 where
608 F: FnOnce(&[u8]) -> T,
609 {
610 let tyid = self.read_byte()? as i8;
611 let cty = CType::try_from(tyid & 0x0f)?;
612 let didx = (tyid >> 4) & 0x0f;
613
614 let tty = TType::from(cty);
615
616 let idx = match (tty, didx) {
618 (TType::Stop, _) => 0,
619 (_, 0) => self.read_varint_i16()?,
620 (_, didx) => self.state.lastidx + (didx as i16),
621 };
622
623 self.state.lastidx = idx;
624
625 if tty == TType::Bool {
626 self.boolfield = Some(cty);
627 }
628
629 let f = fieldfn(&[]);
630 Ok((f, tty, idx))
631 }
632
633 fn read_field_end(&mut self) -> Result<()> {
634 Ok(())
635 }
636
637 fn read_map_begin(&mut self) -> Result<(TType, TType, Option<usize>)> {
638 let size = self.read_varint_u64()? as usize;
639
640 ensure_err!(
641 self.container_limit.map_or(true, |lim| size < lim),
642 ProtocolError::InvalidDataLength
643 );
644
645 let kvtype = if size > 0 { self.read_byte()? } else { 0 };
647
648 let kcty = CType::try_from((kvtype >> 4) & 0x0f)?;
649 let vcty = CType::try_from((kvtype) & 0x0f)?;
650
651 Ok((TType::from(kcty), TType::from(vcty), Some(size as usize)))
652 }
653
654 #[inline]
655 fn read_map_key_begin(&mut self) -> Result<bool> {
656 Ok(true)
657 }
658
659 #[inline]
660 fn read_map_value_begin(&mut self) -> Result<()> {
661 Ok(())
662 }
663
664 #[inline]
665 fn read_map_value_end(&mut self) -> Result<()> {
666 Ok(())
667 }
668
669 fn read_map_end(&mut self) -> Result<()> {
670 Ok(())
671 }
672
673 fn read_list_begin(&mut self) -> Result<(TType, Option<usize>)> {
674 let szty = self.read_byte()?;
675 let cty = CType::try_from(szty & 0x0f)?;
676 let elem_type = TType::from(cty);
677
678 let size = match (szty >> 4) & 0x0f {
679 0x0f => self.read_varint_u64()? as usize,
680 sz => sz as usize,
681 };
682
683 ensure_err!(
684 self.container_limit.map_or(true, |lim| size < lim),
685 ProtocolError::InvalidDataLength
686 );
687
688 Ok((elem_type, Some(size)))
689 }
690
691 #[inline]
692 fn read_list_value_begin(&mut self) -> Result<bool> {
693 Ok(true)
694 }
695
696 #[inline]
697 fn read_list_value_end(&mut self) -> Result<()> {
698 Ok(())
699 }
700
701 fn read_list_end(&mut self) -> Result<()> {
702 Ok(())
703 }
704
705 fn read_set_begin(&mut self) -> Result<(TType, Option<usize>)> {
706 self.read_list_begin()
707 }
708
709 #[inline]
710 fn read_set_value_begin(&mut self) -> Result<bool> {
711 Ok(true)
712 }
713
714 #[inline]
715 fn read_set_value_end(&mut self) -> Result<()> {
716 Ok(())
717 }
718
719 fn read_set_end(&mut self) -> Result<()> {
720 Ok(())
721 }
722
723 fn read_bool(&mut self) -> Result<bool> {
724 let cty = match self.boolfield.take() {
727 None => CType::try_from(self.read_byte()?)?,
728 Some(cty) => cty,
729 };
730
731 Ok(cty == CType::BoolTrue)
735 }
736
737 fn read_byte(&mut self) -> Result<i8> {
738 ensure_err!(self.buffer.remaining() >= 1, ProtocolError::EOF);
739
740 Ok(self.buffer.get_i8())
741 }
742
743 fn read_i16(&mut self) -> Result<i16> {
744 self.read_varint_i16()
745 }
746
747 fn read_i32(&mut self) -> Result<i32> {
748 self.read_varint_i32()
749 }
750
751 fn read_i64(&mut self) -> Result<i64> {
752 self.read_varint_i64()
753 }
754
755 fn read_double(&mut self) -> Result<f64> {
756 ensure_err!(self.buffer.remaining() >= 8, ProtocolError::EOF);
757
758 Ok(self.buffer.get_f64())
759 }
760
761 fn read_float(&mut self) -> Result<f32> {
762 ensure_err!(self.buffer.remaining() >= 4, ProtocolError::EOF);
763
764 Ok(self.buffer.get_f32())
765 }
766
767 fn read_string(&mut self) -> Result<String> {
768 let vec = self.read_binary::<Vec<u8>>()?;
769
770 Ok(String::from_utf8(vec)?)
771 }
772
773 fn read_binary<V: CopyFromBuf>(&mut self) -> Result<V> {
774 let received_len = self.read_varint_u64()? as usize;
775 ensure_err!(
776 self.string_limit.map_or(true, |lim| received_len < lim),
777 ProtocolError::InvalidDataLength
778 );
779 ensure_err!(self.buffer.remaining() >= received_len, ProtocolError::EOF);
780
781 Ok(V::copy_from_buf(&mut self.buffer, received_len))
782 }
783}
784
785pub fn serialize_size<T>(v: &T) -> usize
787where
788 T: Serialize<CompactProtocolSerializer<SizeCounter>>,
789{
790 let mut sizer = CompactProtocolSerializer {
791 state: EncState::new(),
792 buffer: SizeCounter::new(),
793 container_limit: None,
794 string_limit: None,
795 };
796 v.write(&mut sizer);
797 sizer.finish()
798}
799
800pub fn serialize_to_buffer<T>(v: T, buffer: BytesMut) -> CompactProtocolSerializer<BytesMut>
804where
805 T: Serialize<CompactProtocolSerializer<BytesMut>>,
806{
807 let mut buf = CompactProtocolSerializer {
809 state: EncState::new(),
810 buffer,
811 container_limit: None,
812 string_limit: None,
813 };
814 v.write(&mut buf);
815 buf
816}
817
/// Convenience bound for values that can be compact-serialized both by value
/// and by reference, into a `SizeCounter` (for sizing) and into a `BytesMut`.
pub trait SerializeRef:
    Serialize<CompactProtocolSerializer<SizeCounter>> + Serialize<CompactProtocolSerializer<BytesMut>>
where
    for<'a> &'a Self: Serialize<CompactProtocolSerializer<SizeCounter>>,
    for<'a> &'a Self: Serialize<CompactProtocolSerializer<BytesMut>>,
{
}

// Blanket impl: every type satisfying the bounds above is `SerializeRef`.
impl<T> SerializeRef for T
where
    T: Serialize<CompactProtocolSerializer<BytesMut>>,
    T: Serialize<CompactProtocolSerializer<SizeCounter>>,
    for<'a> &'a T: Serialize<CompactProtocolSerializer<BytesMut>>,
    for<'a> &'a T: Serialize<CompactProtocolSerializer<SizeCounter>>,
{
}
834
835pub fn serialize<T>(v: T) -> Bytes
837where
838 T: Serialize<CompactProtocolSerializer<SizeCounter>>
839 + Serialize<CompactProtocolSerializer<BytesMut>>,
840{
841 let sz = serialize_size(&v);
842 let buf = serialize_to_buffer(v, BytesMut::with_capacity(sz));
843 buf.finish()
845}
846
/// Convenience bound for types that can be compact-deserialized from a
/// borrowed byte slice (`Cursor<&[u8]>`) of any lifetime.
pub trait DeserializeSlice:
    for<'a> Deserialize<CompactProtocolDeserializer<Cursor<&'a [u8]>>>
{
}

// Blanket impl: every type satisfying the bound above is `DeserializeSlice`.
impl<T> DeserializeSlice for T where
    T: for<'a> Deserialize<CompactProtocolDeserializer<Cursor<&'a [u8]>>>
{
}
856
857pub fn deserialize<T, B, C>(b: B) -> Result<T>
859where
860 B: Into<DeserializeSource<C>>,
861 C: BufExt,
862 T: Deserialize<CompactProtocolDeserializer<C>>,
863{
864 let source: DeserializeSource<C> = b.into();
865 let mut deser = CompactProtocolDeserializer::new(source.0);
866 Ok(T::read(&mut deser)?)
867}