1use std::collections::HashMap;
6use std::{io, ptr::copy_nonoverlapping};
7
8use smallvec::SmallVec;
9use smol_str::SmolStr;
10
11use monoio_codec::{Decoded, Decoder, Encoder};
12
13use bytes::{Buf, BufMut};
14use num_enum::TryFromPrimitive;
15
16pub type HeaderMap = HashMap<SmolStr, SmolStr>;
17
18#[derive(Clone)]
19pub struct TTHeader {
20 pub header_length: u32,
21 pub payload_length: u32,
22 pub seq_id: i32,
23 pub flags: u16,
24 pub protocol_id: ProtocolId,
25 pub int_headers: [Option<SmolStr>; IntMetaKey::INDEX_TABLE_SIZE],
27 pub int_headers_ext: SmallVec<[(u16, SmolStr); 2]>,
29 pub str_headers: HeaderMap,
30 pub acl_token: Option<SmolStr>,
31}
32
33impl Default for TTHeader {
34 fn default() -> Self {
35 Self {
36 header_length: 0,
37 payload_length: 0,
38 seq_id: 0,
39 flags: 0,
40 protocol_id: ProtocolId::Binary,
41 int_headers: Default::default(),
42 int_headers_ext: Default::default(),
43 str_headers: Default::default(),
44 acl_token: None,
45 }
46 }
47}
48
49impl TTHeader {
50 #[inline]
51 pub fn new_for_encode(payload_length_hint: u32) -> Self {
52 Self {
53 header_length: 0,
54 payload_length: payload_length_hint,
55 seq_id: 0,
56 flags: 0,
57 protocol_id: ProtocolId::Binary,
58 int_headers: Default::default(),
59 int_headers_ext: Default::default(),
60 str_headers: Default::default(),
61 acl_token: None,
62 }
63 }
64
65 #[inline]
66 pub fn new() -> Self {
67 Self::default()
68 }
69
70 fn decode_header(&mut self, total_length: u32, src: &mut bytes::BytesMut) -> io::Result<()> {
72 #[inline]
73 unsafe fn read_u8_unchecked(buf: &[u8], index: &mut usize) -> u8 {
74 let val = *buf.get_unchecked(*index);
75 *index += 1;
76 val
77 }
78
79 #[inline]
80 unsafe fn read_u16_unchecked(buf: &[u8], index: &mut usize) -> u16 {
81 let val = u16::from_be_bytes(
82 buf.get_unchecked(*index..*index + 2)
83 .try_into()
84 .unwrap_unchecked(),
85 );
86 *index += 2;
87 val
88 }
89
90 macro_rules! read_u16_checked {
91 ($buf: ident, $index: ident, $len: expr) => {{
92 if $index + 2 > $len as usize {
93 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid data"));
94 }
95 unsafe { read_u16_unchecked($buf, &mut $index) }
96 }};
97 }
98
99 #[inline]
100 unsafe fn read_raw_str_unchecked(buf: &[u8], len: usize, index: &mut usize) -> SmolStr {
101 let val = {
102 let str = SmolStr::new(std::str::from_utf8_unchecked(
103 buf.get_unchecked(*index..(*index + len)),
104 ));
105 str
106 };
107 *index += len;
108 val
109 }
110
111 macro_rules! read_str_checked {
112 ($buf: ident, $index: ident, $len: expr) => {{
113 let val_len = read_u16_checked!($buf, $index, $len);
114 if $index + val_len as usize > $len as usize {
115 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid data"));
116 }
117 unsafe { read_raw_str_unchecked($buf, val_len as usize, &mut $index) }
118 }};
119 }
120
121 src.advance(2); self.flags = src.get_u16();
123 self.seq_id = src.get_i32();
124 let header_size = src.get_u16();
125 self.header_length = header_size as u32 * 4;
126 if self.header_length as usize > src.len() || header_size < 1 {
127 return Err(io::Error::new(
128 io::ErrorKind::InvalidData,
129 "invalid header length",
130 ));
131 }
132 let header_buf = src.split_to(self.header_length as usize);
133 self.payload_length = total_length - self.header_length - 10;
134 let buf = header_buf.as_ref();
135 let mut index = 0;
136 if let Ok(protocol_id) = unsafe { ProtocolId::try_from(read_u8_unchecked(buf, &mut index)) }
138 {
139 self.protocol_id = protocol_id;
140 }
141 index += 1; let mut _padding_num = 0usize;
144
145 while index < self.header_length as usize {
146 let info_id = unsafe { read_u8_unchecked(buf, &mut index) };
148 match info_id {
149 info::INFO_PADDING => {
150 _padding_num += 1;
151 continue;
152 }
153 info::INFO_KEY_VALUE => {
154 let kv_size = read_u16_checked!(buf, index, self.header_length);
155 for _ in 0..kv_size {
157 let key = read_str_checked!(buf, index, self.header_length);
158 let val = read_str_checked!(buf, index, self.header_length);
159 self.str_headers.insert(key, val);
160 }
161 }
162 info::INFO_INT_KEY_VALUE => {
163 let kv_size = read_u16_checked!(buf, index, self.header_length);
164 for _ in 0..kv_size {
165 let key = read_u16_checked!(buf, index, self.header_length);
166 let val = read_str_checked!(buf, index, self.header_length);
167
168 if (key as usize) < IntMetaKey::INDEX_TABLE_SIZE {
169 unsafe {
171 *self.int_headers.get_unchecked_mut(key as usize) = Some(val);
172 }
173 } else {
174 self.int_headers_ext.push((key, val));
175 }
176 }
177 }
178 info::ACL_TOKEN_KEY_VALUE => {
179 self.acl_token = Some(read_str_checked!(buf, index, self.header_length));
180 }
181 _ => {
182 let msg = format!("unexpected info id in ttheader: {info_id}");
185 tracing::error!("{}", msg);
186 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
187 }
188 }
189 }
190 Ok(())
191 }
192}
193
/// Stateless decoder that extracts only the length prefix and header section of a TTHeader
/// frame; the payload bytes are left in the buffer.
#[derive(Default)]
pub struct TTHeaderDecoder;

impl TTHeaderDecoder {
    /// Returns a new decoder (the type carries no state).
    pub const fn new() -> Self {
        TTHeaderDecoder
    }
}
202
203impl Decoder for TTHeaderDecoder {
204 type Item = TTHeader;
205 type Error = io::Error;
206
207 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Decoded<Self::Item>, Self::Error> {
208 if src.len() < MIN_HEADER_LENGTH {
209 return Ok(Decoded::InsufficientAtLeast(MIN_HEADER_LENGTH));
210 }
211
212 if src[4..HEADER_DETECT_LENGTH] == [0x10, 0x00] {
213 let mut header_length = [0; 2];
214 unsafe { copy_nonoverlapping(src.as_ptr().add(12), header_length.as_mut_ptr(), 2) };
215 let header_length = u16::from_be_bytes(header_length) as usize * 4;
216 if src.len() < header_length + MIN_HEADER_LENGTH {
217 return Ok(Decoded::InsufficientAtLeast(
218 header_length + MIN_HEADER_LENGTH,
219 ));
220 }
221
222 let mut length = [0; 4];
223 unsafe { copy_nonoverlapping(src.as_ptr(), length.as_mut_ptr(), 4) };
224 let length = u32::from_be_bytes(length);
225
226 src.advance(4);
227
228 let mut ttheader = TTHeader::new();
230 ttheader.decode_header(length, src)?; Ok(Decoded::Some(ttheader))
232 } else {
233 Err(io::Error::new(io::ErrorKind::Other, "illegal ttheader"))
234 }
235 }
236}
237
/// Stateless encoder that writes the length prefix and header section of a TTHeader frame.
#[derive(Default)]
pub struct TTHeaderEncoder;

impl TTHeaderEncoder {
    /// Returns a new encoder (the type carries no state).
    pub const fn new() -> Self {
        TTHeaderEncoder
    }
}
246
247impl Encoder<TTHeader> for TTHeaderEncoder {
248 type Error = io::Error;
249
250 fn encode(&mut self, item: TTHeader, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
251 #[inline]
252 fn put_str(s: &SmolStr, dst: &mut bytes::BytesMut) {
253 dst.put_u16(s.len() as u16);
254 dst.put_slice(s.as_bytes());
255 }
256
257 dst.reserve(4 * 1024); let zero_index = dst.len();
259 unsafe {
260 dst.advance_mut(4);
261 }
262
263 dst.put_u16(TT_HEADER_MAGIC);
265 dst.put_u16(item.flags);
267 dst.put_i32(item.seq_id);
268
269 unsafe {
271 dst.advance_mut(2);
272 }
273
274 dst.put_u8(item.protocol_id as u8);
275 dst.put_u8(0); dst.put_u8(info::INFO_KEY_VALUE);
279 dst.put_u16(item.str_headers.len() as u16);
280
281 for (key, val) in item.str_headers.iter() {
282 put_str(key, dst);
283 put_str(val, dst);
284 }
285
286 dst.put_u8(info::INFO_INT_KEY_VALUE);
288 let int_kv_index = dst.len();
289 let mut int_kv_len = 0_u16;
290 unsafe {
291 dst.advance_mut(2);
292 }
293
294 for (key, val) in item.int_headers.iter().enumerate() {
295 if let Some(val) = val {
296 dst.put_u16(key as u16);
297 put_str(val, dst);
298 int_kv_len += 1;
299 }
300 }
301
302 for (key, val) in item.int_headers_ext.iter() {
303 dst.put_u16(*key);
304 put_str(val, dst);
305 int_kv_len += 1;
306 }
307
308 let mut buf = &mut dst[int_kv_index..int_kv_index + 2];
310 buf.put_u16(int_kv_len);
311
312 if let Some(ref acl_token) = item.acl_token {
314 dst.put_u8(info::ACL_TOKEN_KEY_VALUE);
315 dst.put_u16(acl_token.len() as u16);
316 dst.put_slice(acl_token.as_bytes());
317 }
318
319 let overflow = (dst.len() - 14 - zero_index) % 4;
321 let padding = (4 - overflow) % 4;
322 (0..padding).for_each(|_| dst.put_u8(0));
323
324 let header_size = dst.len() - zero_index;
326 let mut buf = &mut dst[zero_index + 12..zero_index + 12 + 2];
327 buf.put_u16(((header_size - 14) / 4).try_into().unwrap());
328 tracing::trace!(
329 "encode ttheader write header size: {}",
330 (header_size - 14) / 4
331 );
332
333 let size = dst.len() + item.payload_length as usize - zero_index;
335 let mut buf = &mut dst[zero_index..zero_index + 4];
336 buf.put_u32((size - 4) as u32);
337 tracing::trace!("encode ttheader write length size: {}", size - 4);
338 Ok(())
339 }
340}
341
342pub struct TTHeaderPayload<T> {
343 pub ttheader: TTHeader,
344 pub payload: Option<T>,
345}
346
347impl<T> Clone for TTHeaderPayload<T>
348where
349 T: Clone,
350{
351 fn clone(&self) -> Self {
352 Self {
353 ttheader: self.ttheader.clone(),
354 payload: self.payload.clone(),
355 }
356 }
357}
358
359impl<T> Default for TTHeaderPayload<T> {
360 fn default() -> Self {
361 Self {
362 ttheader: Default::default(),
363 payload: Default::default(),
364 }
365 }
366}
367
368impl<T> TTHeaderPayload<T> {
369 fn new() -> Self {
370 Self {
371 ttheader: TTHeader::new(),
372 payload: None,
373 }
374 }
375}
376
/// Codec for complete TTHeader frames; the payload is handled by the wrapped codec.
pub struct TTHeaderPayloadCodec<T> {
    inner: T,
}

impl<T> TTHeaderPayloadCodec<T> {
    /// Wraps `inner`, the codec used for the payload that follows the header section.
    pub fn new(inner: T) -> Self {
        TTHeaderPayloadCodec { inner }
    }
}
386
387impl<T: Decoder> Decoder for TTHeaderPayloadCodec<T>
388where
389 T::Error: From<io::Error>,
390{
391 type Item = TTHeaderPayload<T::Item>;
392 type Error = T::Error;
393
394 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Decoded<Self::Item>, Self::Error> {
395 if src.len() < MIN_HEADER_LENGTH {
396 return Ok(Decoded::InsufficientAtLeast(MIN_HEADER_LENGTH));
397 }
398
399 if src[4..HEADER_DETECT_LENGTH] == [0x10, 0x00] {
400 let mut length = [0; 4];
401 unsafe { copy_nonoverlapping(src.as_ptr(), length.as_mut_ptr(), 4) };
402 let length = u32::from_be_bytes(length);
403 if src.len() < length as usize + 4 {
404 return Ok(Decoded::InsufficientAtLeast(length as usize + 4));
405 }
406 src.advance(4);
407
408 let mut item = Self::Item::new();
409 item.ttheader.decode_header(length, src)?;
410 match self.inner.decode(src) {
411 Ok(Decoded::Some(payload)) => item.payload = Some(payload),
412 Err(e) => return Err(e),
413 _ => return Err(io::Error::new(io::ErrorKind::Other, "illegal payload").into()),
415 };
416 Ok(Decoded::Some(item))
417 } else {
418 Err(io::Error::new(io::ErrorKind::Other, "illegal ttheader").into())
419 }
420 }
421}
422
423impl<T, E: Encoder<T>> Encoder<TTHeaderPayload<T>> for TTHeaderPayloadCodec<E> {
424 type Error = E::Error;
425
426 fn encode(
427 &mut self,
428 item: TTHeaderPayload<T>,
429 dst: &mut bytes::BytesMut,
430 ) -> Result<(), Self::Error> {
431 let zero_index = dst.len();
432 let mut ttheader_encoder = TTHeaderEncoder {};
433 ttheader_encoder.encode(item.ttheader, dst)?;
434 self.inner
435 .encode(item.payload.expect("payload must some"), dst)?;
436 let size = dst.len() - zero_index;
438 let mut buf = &mut dst[zero_index..zero_index + 4];
439 buf.put_u32((size - 4) as u32);
440 tracing::trace!("encode ttheader write length size: {}", size - 4);
441 Ok(())
442 }
443}
444
445#[derive(Default)]
446pub struct RawPayloadCodec;
447
448impl RawPayloadCodec {
449 pub const fn new() -> Self {
450 Self
451 }
452}
453
454impl Decoder for RawPayloadCodec {
455 type Item = bytes::Bytes;
456
457 type Error = io::Error;
458
459 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Decoded<Self::Item>, Self::Error> {
460 Ok(Decoded::Some(bytes::Bytes::from(src.split())))
461 }
462}
463
464impl Encoder<bytes::Bytes> for RawPayloadCodec {
465 type Error = io::Error;
466
467 fn encode(&mut self, item: bytes::Bytes, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
468 dst.reserve(item.len());
469 dst.extend_from_slice(&item);
470 Ok(())
471 }
472}
473
/// Bytes needed to detect a TTHeader frame: length prefix (4) + magic (2).
const HEADER_DETECT_LENGTH: usize = 4 + 2;
/// Fixed frame prefix: length (4) + magic (2) + flags (2) + seq_id (4) + header size (2).
const MIN_HEADER_LENGTH: usize = 4 + 2 + 2 + 4 + 2;

/// Magic value written big-endian right after the length prefix of every frame.
pub const TT_HEADER_MAGIC: u16 = 0x1000;
480
/// Info-block ids that introduce each entry of the variable header section.
mod info {
    /// A single padding byte; skipped by the decoder.
    pub const INFO_PADDING: u8 = 0x00;
    /// String pairs: u16 count, then (u16-length key, u16-length value) entries.
    pub const INFO_KEY_VALUE: u8 = 0x01;
    /// Integer-keyed pairs: u16 count, then (u16 key, u16-length value) entries.
    pub const INFO_INT_KEY_VALUE: u8 = 0x10;
    /// ACL token: one u16-length-prefixed string.
    pub const ACL_TOKEN_KEY_VALUE: u8 = 0x11;
}
487
488#[derive(TryFromPrimitive, Clone, Copy, Default)]
489#[repr(u8)]
490pub enum ProtocolId {
491 #[default]
492 Binary = 0,
493 Compact = 2, CompactV2 = 3, Protobuf = 4,
496}
497
498#[derive(PartialEq, Eq, Hash, Clone, Copy, TryFromPrimitive, Debug)]
499#[repr(u16)]
500pub enum IntMetaKey {
501 TransportType = 1,
502 LogId = 2,
504 FromService = 3,
505 FromCluster = 4,
506 FromIdc = 5,
507 ToService = 6,
508 ToMethod = 9,
509 ToCluster = 7,
510 ToIdc = 8,
511 Env = 10,
512 DestAddress = 11,
513 RPCTimeoutMs = 12,
514 RingHashKey = 14,
515 WithHeader = 16,
516 ConnTimeoutMs = 17,
517 TraceSpanCtx = 18,
518 ShortConnection = 19,
519 FromMethod = 20,
520 StressTag = 21,
521 MsgType = 22,
522 ConnectionRecycle = 23,
523 RawRingHashKey = 24,
524 LBType = 25,
525 ClusterShardId = 26,
526}
527
528impl IntMetaKey {
529 const INDEX_TABLE_SIZE: usize = Self::ClusterShardId as usize + 1;
530}