Skip to main content

monoio_thrift/codec/
ttheader.rs

1//! TTheader is a transport protocol designed by CloudWeGo.
2//!
3//! For more information, please visit https://www.cloudwego.io/docs/kitex/reference/transport_protocol_ttheader/
4
5use std::collections::HashMap;
6use std::{io, ptr::copy_nonoverlapping};
7
8use smallvec::SmallVec;
9use smol_str::SmolStr;
10
11use monoio_codec::{Decoded, Decoder, Encoder};
12
13use bytes::{Buf, BufMut};
14use num_enum::TryFromPrimitive;
15
16pub type HeaderMap = HashMap<SmolStr, SmolStr>;
17
18#[derive(Clone)]
19pub struct TTHeader {
20    pub header_length: u32,
21    pub payload_length: u32,
22    pub seq_id: i32,
23    pub flags: u16,
24    pub protocol_id: ProtocolId,
25    // int key < IntMetaKey::INDEX_TABLE_SIZE
26    pub int_headers: [Option<SmolStr>; IntMetaKey::INDEX_TABLE_SIZE],
27    // int key >= IntMetaKey::INDEX_TABLE_SIZE
28    pub int_headers_ext: SmallVec<[(u16, SmolStr); 2]>,
29    pub str_headers: HeaderMap,
30    pub acl_token: Option<SmolStr>,
31}
32
33impl Default for TTHeader {
34    fn default() -> Self {
35        Self {
36            header_length: 0,
37            payload_length: 0,
38            seq_id: 0,
39            flags: 0,
40            protocol_id: ProtocolId::Binary,
41            int_headers: Default::default(),
42            int_headers_ext: Default::default(),
43            str_headers: Default::default(),
44            acl_token: None,
45        }
46    }
47}
48
49impl TTHeader {
50    #[inline]
51    pub fn new_for_encode(payload_length_hint: u32) -> Self {
52        Self {
53            header_length: 0,
54            payload_length: payload_length_hint,
55            seq_id: 0,
56            flags: 0,
57            protocol_id: ProtocolId::Binary,
58            int_headers: Default::default(),
59            int_headers_ext: Default::default(),
60            str_headers: Default::default(),
61            acl_token: None,
62        }
63    }
64
65    #[inline]
66    pub fn new() -> Self {
67        Self::default()
68    }
69
70    // TODO: now only supports io::Error
71    fn decode_header(&mut self, total_length: u32, src: &mut bytes::BytesMut) -> io::Result<()> {
72        #[inline]
73        unsafe fn read_u8_unchecked(buf: &[u8], index: &mut usize) -> u8 {
74            let val = *buf.get_unchecked(*index);
75            *index += 1;
76            val
77        }
78
79        #[inline]
80        unsafe fn read_u16_unchecked(buf: &[u8], index: &mut usize) -> u16 {
81            let val = u16::from_be_bytes(
82                buf.get_unchecked(*index..*index + 2)
83                    .try_into()
84                    .unwrap_unchecked(),
85            );
86            *index += 2;
87            val
88        }
89
90        macro_rules! read_u16_checked {
91            ($buf: ident, $index: ident, $len: expr) => {{
92                if $index + 2 > $len as usize {
93                    return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid data"));
94                }
95                unsafe { read_u16_unchecked($buf, &mut $index) }
96            }};
97        }
98
99        #[inline]
100        unsafe fn read_raw_str_unchecked(buf: &[u8], len: usize, index: &mut usize) -> SmolStr {
101            let val = {
102                let str = SmolStr::new(std::str::from_utf8_unchecked(
103                    buf.get_unchecked(*index..(*index + len)),
104                ));
105                str
106            };
107            *index += len;
108            val
109        }
110
111        macro_rules! read_str_checked {
112            ($buf: ident, $index: ident, $len: expr) => {{
113                let val_len = read_u16_checked!($buf, $index, $len);
114                if $index + val_len as usize > $len as usize {
115                    return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid data"));
116                }
117                unsafe { read_raw_str_unchecked($buf, val_len as usize, &mut $index) }
118            }};
119        }
120
121        src.advance(2); // skip magic
122        self.flags = src.get_u16();
123        self.seq_id = src.get_i32();
124        let header_size = src.get_u16();
125        self.header_length = header_size as u32 * 4;
126        if self.header_length as usize > src.len() || header_size < 1 {
127            return Err(io::Error::new(
128                io::ErrorKind::InvalidData,
129                "invalid header length",
130            ));
131        }
132        let header_buf = src.split_to(self.header_length as usize);
133        self.payload_length = total_length - self.header_length - 10;
134        let buf = header_buf.as_ref();
135        let mut index = 0;
136        // It's safe when checked header_size >= 1
137        if let Ok(protocol_id) = unsafe { ProtocolId::try_from(read_u8_unchecked(buf, &mut index)) }
138        {
139            self.protocol_id = protocol_id;
140        }
141        index += 1; // TODO: support transform
142
143        let mut _padding_num = 0usize;
144
145        while index < self.header_length as usize {
146            // It's safe because while expr
147            let info_id = unsafe { read_u8_unchecked(buf, &mut index) };
148            match info_id {
149                info::INFO_PADDING => {
150                    _padding_num += 1;
151                    continue;
152                }
153                info::INFO_KEY_VALUE => {
154                    let kv_size = read_u16_checked!(buf, index, self.header_length);
155                    // TODO: reserve
156                    for _ in 0..kv_size {
157                        let key = read_str_checked!(buf, index, self.header_length);
158                        let val = read_str_checked!(buf, index, self.header_length);
159                        self.str_headers.insert(key, val);
160                    }
161                }
162                info::INFO_INT_KEY_VALUE => {
163                    let kv_size = read_u16_checked!(buf, index, self.header_length);
164                    for _ in 0..kv_size {
165                        let key = read_u16_checked!(buf, index, self.header_length);
166                        let val = read_str_checked!(buf, index, self.header_length);
167
168                        if (key as usize) < IntMetaKey::INDEX_TABLE_SIZE {
169                            // It's safe because `if expr`
170                            unsafe {
171                                *self.int_headers.get_unchecked_mut(key as usize) = Some(val);
172                            }
173                        } else {
174                            self.int_headers_ext.push((key, val));
175                        }
176                    }
177                }
178                info::ACL_TOKEN_KEY_VALUE => {
179                    self.acl_token = Some(read_str_checked!(buf, index, self.header_length));
180                }
181                _ => {
182                    // We are not able to decode the protocol anymore, since we don't know the
183                    // layout
184                    let msg = format!("unexpected info id in ttheader: {info_id}");
185                    tracing::error!("{}", msg);
186                    return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
187                }
188            }
189        }
190        Ok(())
191    }
192}
193
/// Decoder that parses only the TTHeader prefix of a frame; payload bytes stay in the buffer.
#[derive(Default)]
pub struct TTHeaderDecoder;

impl TTHeaderDecoder {
    /// Creates a decoder. Usable in `const` contexts.
    pub const fn new() -> Self {
        TTHeaderDecoder
    }
}
202
203impl Decoder for TTHeaderDecoder {
204    type Item = TTHeader;
205    type Error = io::Error;
206
207    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Decoded<Self::Item>, Self::Error> {
208        if src.len() < MIN_HEADER_LENGTH {
209            return Ok(Decoded::InsufficientAtLeast(MIN_HEADER_LENGTH));
210        }
211
212        if src[4..HEADER_DETECT_LENGTH] == [0x10, 0x00] {
213            let mut header_length = [0; 2];
214            unsafe { copy_nonoverlapping(src.as_ptr().add(12), header_length.as_mut_ptr(), 2) };
215            let header_length = u16::from_be_bytes(header_length) as usize * 4;
216            if src.len() < header_length + MIN_HEADER_LENGTH {
217                return Ok(Decoded::InsufficientAtLeast(
218                    header_length + MIN_HEADER_LENGTH,
219                ));
220            }
221
222            let mut length = [0; 4];
223            unsafe { copy_nonoverlapping(src.as_ptr(), length.as_mut_ptr(), 4) };
224            let length = u32::from_be_bytes(length);
225
226            src.advance(4);
227
228            // decode ttheader
229            let mut ttheader = TTHeader::new();
230            ttheader.decode_header(length, src)?; // TODO: which error type?
231            Ok(Decoded::Some(ttheader))
232        } else {
233            Err(io::Error::new(io::ErrorKind::Other, "illegal ttheader"))
234        }
235    }
236}
237
/// Encoder that writes only the TTHeader prefix of a frame; the payload is written separately.
#[derive(Default)]
pub struct TTHeaderEncoder;

impl TTHeaderEncoder {
    /// Creates an encoder. Usable in `const` contexts.
    pub const fn new() -> Self {
        TTHeaderEncoder
    }
}
246
247impl Encoder<TTHeader> for TTHeaderEncoder {
248    type Error = io::Error;
249
250    fn encode(&mut self, item: TTHeader, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
251        #[inline]
252        fn put_str(s: &SmolStr, dst: &mut bytes::BytesMut) {
253            dst.put_u16(s.len() as u16);
254            dst.put_slice(s.as_bytes());
255        }
256
257        dst.reserve(4 * 1024); // cap 4k
258        let zero_index = dst.len();
259        unsafe {
260            dst.advance_mut(4);
261        }
262
263        // tt header magic
264        dst.put_u16(TT_HEADER_MAGIC);
265        // flags
266        dst.put_u16(item.flags);
267        dst.put_i32(item.seq_id);
268
269        // Alloc 2-byte space as header length
270        unsafe {
271            dst.advance_mut(2);
272        }
273
274        dst.put_u8(item.protocol_id as u8);
275        dst.put_u8(0); // TODO: transform_ids_num
276
277        // Write string KV start.
278        dst.put_u8(info::INFO_KEY_VALUE);
279        dst.put_u16(item.str_headers.len() as u16);
280
281        for (key, val) in item.str_headers.iter() {
282            put_str(key, dst);
283            put_str(val, dst);
284        }
285
286        // Write int KV start.
287        dst.put_u8(info::INFO_INT_KEY_VALUE);
288        let int_kv_index = dst.len();
289        let mut int_kv_len = 0_u16;
290        unsafe {
291            dst.advance_mut(2);
292        }
293
294        for (key, val) in item.int_headers.iter().enumerate() {
295            if let Some(val) = val {
296                dst.put_u16(key as u16);
297                put_str(val, dst);
298                int_kv_len += 1;
299            }
300        }
301
302        for (key, val) in item.int_headers_ext.iter() {
303            dst.put_u16(*key);
304            put_str(val, dst);
305            int_kv_len += 1;
306        }
307
308        // fill int kv length
309        let mut buf = &mut dst[int_kv_index..int_kv_index + 2];
310        buf.put_u16(int_kv_len);
311
312        // fill acl_token
313        if let Some(ref acl_token) = item.acl_token {
314            dst.put_u8(info::ACL_TOKEN_KEY_VALUE);
315            dst.put_u16(acl_token.len() as u16);
316            dst.put_slice(acl_token.as_bytes());
317        }
318
319        // write padding
320        let overflow = (dst.len() - 14 - zero_index) % 4;
321        let padding = (4 - overflow) % 4;
322        (0..padding).for_each(|_| dst.put_u8(0));
323
324        // fill header length
325        let header_size = dst.len() - zero_index;
326        let mut buf = &mut dst[zero_index + 12..zero_index + 12 + 2];
327        buf.put_u16(((header_size - 14) / 4).try_into().unwrap());
328        tracing::trace!(
329            "encode ttheader write header size: {}",
330            (header_size - 14) / 4
331        );
332
333        // fill length
334        let size = dst.len() + item.payload_length as usize - zero_index;
335        let mut buf = &mut dst[zero_index..zero_index + 4];
336        buf.put_u32((size - 4) as u32);
337        tracing::trace!("encode ttheader write length size: {}", size - 4);
338        Ok(())
339    }
340}
341
342pub struct TTHeaderPayload<T> {
343    pub ttheader: TTHeader,
344    pub payload: Option<T>,
345}
346
347impl<T> Clone for TTHeaderPayload<T>
348where
349    T: Clone,
350{
351    fn clone(&self) -> Self {
352        Self {
353            ttheader: self.ttheader.clone(),
354            payload: self.payload.clone(),
355        }
356    }
357}
358
359impl<T> Default for TTHeaderPayload<T> {
360    fn default() -> Self {
361        Self {
362            ttheader: Default::default(),
363            payload: Default::default(),
364        }
365    }
366}
367
368impl<T> TTHeaderPayload<T> {
369    fn new() -> Self {
370        Self {
371            ttheader: TTHeader::new(),
372            payload: None,
373        }
374    }
375}
376
/// Codec for whole TTHeader frames; the payload is handled by the wrapped codec `inner`.
pub struct TTHeaderPayloadCodec<T> {
    inner: T,
}

impl<T> TTHeaderPayloadCodec<T> {
    /// Wraps `inner`, which decodes and encodes the frame payload.
    pub fn new(inner: T) -> Self {
        TTHeaderPayloadCodec { inner }
    }
}
386
387impl<T: Decoder> Decoder for TTHeaderPayloadCodec<T>
388where
389    T::Error: From<io::Error>,
390{
391    type Item = TTHeaderPayload<T::Item>;
392    type Error = T::Error;
393
394    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Decoded<Self::Item>, Self::Error> {
395        if src.len() < MIN_HEADER_LENGTH {
396            return Ok(Decoded::InsufficientAtLeast(MIN_HEADER_LENGTH));
397        }
398
399        if src[4..HEADER_DETECT_LENGTH] == [0x10, 0x00] {
400            let mut length = [0; 4];
401            unsafe { copy_nonoverlapping(src.as_ptr(), length.as_mut_ptr(), 4) };
402            let length = u32::from_be_bytes(length);
403            if src.len() < length as usize + 4 {
404                return Ok(Decoded::InsufficientAtLeast(length as usize + 4));
405            }
406            src.advance(4);
407
408            let mut item = Self::Item::new();
409            item.ttheader.decode_header(length, src)?;
410            match self.inner.decode(src) {
411                Ok(Decoded::Some(payload)) => item.payload = Some(payload),
412                Err(e) => return Err(e),
413                // we have already checked sufficient size, so it's err if Insufficient
414                _ => return Err(io::Error::new(io::ErrorKind::Other, "illegal payload").into()),
415            };
416            Ok(Decoded::Some(item))
417        } else {
418            Err(io::Error::new(io::ErrorKind::Other, "illegal ttheader").into())
419        }
420    }
421}
422
423impl<T, E: Encoder<T>> Encoder<TTHeaderPayload<T>> for TTHeaderPayloadCodec<E> {
424    type Error = E::Error;
425
426    fn encode(
427        &mut self,
428        item: TTHeaderPayload<T>,
429        dst: &mut bytes::BytesMut,
430    ) -> Result<(), Self::Error> {
431        let zero_index = dst.len();
432        let mut ttheader_encoder = TTHeaderEncoder {};
433        ttheader_encoder.encode(item.ttheader, dst)?;
434        self.inner
435            .encode(item.payload.expect("payload must some"), dst)?;
436        // fill length
437        let size = dst.len() - zero_index;
438        let mut buf = &mut dst[zero_index..zero_index + 4];
439        buf.put_u32((size - 4) as u32);
440        tracing::trace!("encode ttheader write length size: {}", size - 4);
441        Ok(())
442    }
443}
444
/// Payload codec that passes raw bytes through unchanged.
#[derive(Default)]
pub struct RawPayloadCodec;

impl RawPayloadCodec {
    /// Creates the codec. Usable in `const` contexts.
    pub const fn new() -> Self {
        RawPayloadCodec
    }
}
453
454impl Decoder for RawPayloadCodec {
455    type Item = bytes::Bytes;
456
457    type Error = io::Error;
458
459    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Decoded<Self::Item>, Self::Error> {
460        Ok(Decoded::Some(bytes::Bytes::from(src.split())))
461    }
462}
463
464impl Encoder<bytes::Bytes> for RawPayloadCodec {
465    type Error = io::Error;
466
467    fn encode(&mut self, item: bytes::Bytes, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
468        dst.reserve(item.len());
469        dst.extend_from_slice(&item);
470        Ok(())
471    }
472}
473
/// Bytes needed to recognise a TTHeader frame: the 4-byte frame length followed by the
/// 2-byte magic.
/// https://www.cloudwego.io/docs/kitex/reference/transport_protocol_ttheader/
const HEADER_DETECT_LENGTH: usize = 4 + 2;
/// Size of the fixed TTHeader prefix:
/// length (4) + magic (2) + flags (2) + sequence id (4) + header size (2).
const MIN_HEADER_LENGTH: usize = 4 + 2 + 2 + 4 + 2;

/// Magic value identifying a TTHeader frame; appears on the wire as the bytes `0x10 0x00`.
pub const TT_HEADER_MAGIC: u16 = 0x1000;
480
/// Info ids that tag each section of the variable TTHeader header.
mod info {
    /// A single padding byte (the header is padded to a 4-byte boundary).
    pub const INFO_PADDING: u8 = 0x00;
    /// String KV section: u16 count, then pairs of u16-length-prefixed key and value.
    pub const INFO_KEY_VALUE: u8 = 0x01;
    /// Integer KV section: u16 count, then pairs of u16 key and u16-length-prefixed value.
    pub const INFO_INT_KEY_VALUE: u8 = 0x10;
    /// ACL token: a single u16-length-prefixed string.
    pub const ACL_TOKEN_KEY_VALUE: u8 = 0x11;
}
487
488#[derive(TryFromPrimitive, Clone, Copy, Default)]
489#[repr(u8)]
490pub enum ProtocolId {
491    #[default]
492    Binary = 0,
493    Compact = 2,   // Apache Thrift compact protocol
494    CompactV2 = 3, // fbthrift compact protocol
495    Protobuf = 4,
496}
497
498#[derive(PartialEq, Eq, Hash, Clone, Copy, TryFromPrimitive, Debug)]
499#[repr(u16)]
500pub enum IntMetaKey {
501    TransportType = 1,
502    // framed / unframed
503    LogId = 2,
504    FromService = 3,
505    FromCluster = 4,
506    FromIdc = 5,
507    ToService = 6,
508    ToMethod = 9,
509    ToCluster = 7,
510    ToIdc = 8,
511    Env = 10,
512    DestAddress = 11,
513    RPCTimeoutMs = 12,
514    RingHashKey = 14,
515    WithHeader = 16,
516    ConnTimeoutMs = 17,
517    TraceSpanCtx = 18,
518    ShortConnection = 19,
519    FromMethod = 20,
520    StressTag = 21,
521    MsgType = 22,
522    ConnectionRecycle = 23,
523    RawRingHashKey = 24,
524    LBType = 25,
525    ClusterShardId = 26,
526}
527
528impl IntMetaKey {
529    const INDEX_TABLE_SIZE: usize = Self::ClusterShardId as usize + 1;
530}