forge_worker_sdk/
framing.rs1use std::str::FromStr;
14
15use bytes::{Buf, BufMut, BytesMut};
16use tokio_util::codec::{Decoder, Encoder};
17
18use crate::protocol::{WireEvent, WireRequest, WireResponse};
19
/// Kind tag written in the frame header for a request payload.
pub const KIND_REQUEST: u8 = 1;
/// Kind tag written in the frame header for a response payload.
pub const KIND_RESPONSE: u8 = 2;
/// Kind tag written in the frame header for an event payload.
pub const KIND_EVENT: u8 = 3;

/// Largest payload, in bytes, the decoder will accept in one frame (64 MiB).
pub const MAX_FRAME_PAYLOAD: usize = 64 << 20;

/// Serialization format applied to frame payloads.
///
/// A codec uses one format for every frame kind it reads or writes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Encoding {
    /// MessagePack, written with named struct fields.
    Msgpack,
    /// JSON.
    Json,
}

37impl Encoding {
38 pub fn wire_name(self) -> &'static str {
40 match self {
41 Encoding::Msgpack => "msgpack",
42 Encoding::Json => "json",
43 }
44 }
45}
46
47impl FromStr for Encoding {
48 type Err = std::convert::Infallible;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 Ok(if s.eq_ignore_ascii_case("json") {
52 Encoding::Json
53 } else {
54 Encoding::Msgpack
55 })
56 }
57}
58
59#[derive(Debug)]
63pub enum Frame {
64 Request(WireRequest),
65 Response(WireResponse),
66 Event(WireEvent),
67}
68
69pub struct FrameCodec {
73 encoding: Encoding,
74}
75
76impl FrameCodec {
77 pub fn new(encoding: Encoding) -> Self { Self { encoding } }
78 pub fn msgpack() -> Self { Self::new(Encoding::Msgpack) }
79 pub fn json() -> Self { Self::new(Encoding::Json) }
80}
81
82impl Decoder for FrameCodec {
85 type Item = Frame;
86 type Error = anyhow::Error;
87
88 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
89 if src.len() < 5 {
91 return Ok(None);
92 }
93
94 let length = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
95 let kind = src[4];
96
97 if length > MAX_FRAME_PAYLOAD {
98 return Err(anyhow::anyhow!("frame payload too large: {} bytes (max {})", length, MAX_FRAME_PAYLOAD));
99 }
100
101 if src.len() < 5 + length {
103 src.reserve(5 + length - src.len());
104 return Ok(None);
105 }
106
107 src.advance(5);
109 let payload = src.split_to(length);
110
111 let frame = match kind {
112 KIND_REQUEST => Frame::Request(self.unmarshal(&payload)?),
113 KIND_RESPONSE => Frame::Response(self.unmarshal(&payload)?),
114 KIND_EVENT => Frame::Event(self.unmarshal(&payload)?),
115 k => return Err(anyhow::anyhow!("unknown frame kind: {}", k)),
116 };
117
118 Ok(Some(frame))
119 }
120}
121
122impl Encoder<Frame> for FrameCodec {
125 type Error = anyhow::Error;
126
127 fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
128 let (kind, payload) = match &frame {
129 Frame::Request(r) => (KIND_REQUEST, self.marshal(r)?),
130 Frame::Response(r) => (KIND_RESPONSE, self.marshal(r)?),
131 Frame::Event(e) => (KIND_EVENT, self.marshal(e)?),
132 };
133
134 dst.reserve(5 + payload.len());
135 dst.put_u32(payload.len() as u32); dst.put_u8(kind); dst.put_slice(&payload);
138 Ok(())
139 }
140}
141
142impl FrameCodec {
145 fn marshal<T: serde::Serialize>(&self, v: &T) -> anyhow::Result<Vec<u8>> {
146 match self.encoding {
147 Encoding::Msgpack => rmp_serde::to_vec_named(v).map_err(Into::into),
148 Encoding::Json => serde_json::to_vec(v).map_err(Into::into),
149 }
150 }
151
152 fn unmarshal<T: serde::de::DeserializeOwned>(&self, data: &[u8]) -> anyhow::Result<T> {
153 match self.encoding {
154 Encoding::Msgpack => rmp_serde::from_slice(data).map_err(Into::into),
155 Encoding::Json => serde_json::from_slice(data).map_err(Into::into),
156 }
157 }
158}