runtara_protocol/
frame.rs

1// Copyright (C) 2025 SyncMyOrders Sp. z o.o.
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//! Wire format for QUIC stream framing.
4//!
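//! Each QUIC stream carries one RPC call, exchanged as a sequence of frames. Every frame
//! has the following layout:
//! - 4 bytes: payload length N (big-endian `u32`, at most `MAX_FRAME_SIZE`)
//! - 2 bytes: message type (big-endian `u16`, see `MessageType`)
//! - N bytes: protobuf payload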
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use prost::Message;
12use thiserror::Error;
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14
/// Largest payload, in bytes, that a single frame may carry (64 MiB).
///
/// Sized generously so that large compiled workflow binaries fit in one frame.
pub const MAX_FRAME_SIZE: usize = 64 << 20;

/// Size in bytes of the fixed frame header: a `u32` length followed by a `u16` message type.
pub const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u16>();
21
22/// Message types for the wire protocol
23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
24#[repr(u16)]
25pub enum MessageType {
26    /// Request message
27    Request = 1,
28    /// Response message
29    Response = 2,
30    /// Start of a streaming response
31    StreamStart = 3,
32    /// Data chunk in a streaming response
33    StreamData = 4,
34    /// End of a streaming response
35    StreamEnd = 5,
36    /// Error response
37    Error = 6,
38}
39
40impl TryFrom<u16> for MessageType {
41    type Error = FrameError;
42
43    fn try_from(value: u16) -> Result<Self, <Self as TryFrom<u16>>::Error> {
44        match value {
45            1 => Ok(MessageType::Request),
46            2 => Ok(MessageType::Response),
47            3 => Ok(MessageType::StreamStart),
48            4 => Ok(MessageType::StreamData),
49            5 => Ok(MessageType::StreamEnd),
50            6 => Ok(MessageType::Error),
51            _ => Err(FrameError::InvalidMessageType(value)),
52        }
53    }
54}
55
56/// Errors that can occur during frame encoding/decoding
57#[derive(Debug, Error)]
58pub enum FrameError {
59    #[error("frame too large: {0} bytes (max: {MAX_FRAME_SIZE})")]
60    FrameTooLarge(usize),
61
62    #[error("invalid message type: {0}")]
63    InvalidMessageType(u16),
64
65    #[error("IO error: {0}")]
66    Io(#[from] std::io::Error),
67
68    #[error("protobuf decode error: {0}")]
69    Decode(#[from] prost::DecodeError),
70
71    #[error("connection closed")]
72    ConnectionClosed,
73}
74
75/// A framed message with type and payload
76#[derive(Debug, Clone)]
77pub struct Frame {
78    pub message_type: MessageType,
79    pub payload: Bytes,
80}
81
82impl Frame {
83    /// Create a new request frame
84    pub fn request<M: Message>(msg: &M) -> Result<Self, FrameError> {
85        Self::new(MessageType::Request, msg)
86    }
87
88    /// Create a new response frame
89    pub fn response<M: Message>(msg: &M) -> Result<Self, FrameError> {
90        Self::new(MessageType::Response, msg)
91    }
92
93    /// Create a new error frame
94    pub fn error<M: Message>(msg: &M) -> Result<Self, FrameError> {
95        Self::new(MessageType::Error, msg)
96    }
97
98    /// Create a new stream data frame
99    pub fn stream_data<M: Message>(msg: &M) -> Result<Self, FrameError> {
100        Self::new(MessageType::StreamData, msg)
101    }
102
103    /// Create a new frame with the given type and message
104    pub fn new<M: Message>(message_type: MessageType, msg: &M) -> Result<Self, FrameError> {
105        let payload = msg.encode_to_vec();
106        if payload.len() > MAX_FRAME_SIZE {
107            return Err(FrameError::FrameTooLarge(payload.len()));
108        }
109        Ok(Self {
110            message_type,
111            payload: Bytes::from(payload),
112        })
113    }
114
115    /// Decode the payload as a protobuf message
116    pub fn decode<M: Message + Default>(&self) -> Result<M, FrameError> {
117        Ok(M::decode(self.payload.clone())?)
118    }
119
120    /// Encode the frame to bytes for wire transmission
121    pub fn encode(&self) -> Bytes {
122        let mut buf = BytesMut::with_capacity(HEADER_SIZE + self.payload.len());
123        buf.put_u32(self.payload.len() as u32);
124        buf.put_u16(self.message_type as u16);
125        buf.put(self.payload.clone());
126        buf.freeze()
127    }
128
129    /// Decode a frame from bytes
130    pub fn decode_from_bytes(mut bytes: Bytes) -> Result<Self, FrameError> {
131        if bytes.len() < HEADER_SIZE {
132            return Err(FrameError::Io(std::io::Error::new(
133                std::io::ErrorKind::UnexpectedEof,
134                "incomplete frame header",
135            )));
136        }
137
138        let length = bytes.get_u32() as usize;
139        let message_type = MessageType::try_from(bytes.get_u16())?;
140
141        if length > MAX_FRAME_SIZE {
142            return Err(FrameError::FrameTooLarge(length));
143        }
144
145        if bytes.len() < length {
146            return Err(FrameError::Io(std::io::Error::new(
147                std::io::ErrorKind::UnexpectedEof,
148                "incomplete frame payload",
149            )));
150        }
151
152        let payload = bytes.split_to(length);
153        Ok(Self {
154            message_type,
155            payload,
156        })
157    }
158}
159
160/// Write a frame to an async writer
161pub async fn write_frame<W: AsyncWrite + Unpin>(
162    writer: &mut W,
163    frame: &Frame,
164) -> Result<(), FrameError> {
165    let encoded = frame.encode();
166    writer.write_all(&encoded).await?;
167    Ok(())
168}
169
170/// Read a frame from an async reader
171pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Frame, FrameError> {
172    // Read header
173    let mut header = [0u8; HEADER_SIZE];
174    match reader.read_exact(&mut header).await {
175        Ok(_) => {}
176        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
177            return Err(FrameError::ConnectionClosed);
178        }
179        Err(e) => return Err(e.into()),
180    }
181
182    let length = u32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
183    let message_type = MessageType::try_from(u16::from_be_bytes([header[4], header[5]]))?;
184
185    if length > MAX_FRAME_SIZE {
186        return Err(FrameError::FrameTooLarge(length));
187    }
188
189    // Read payload
190    let mut payload = vec![0u8; length];
191    reader.read_exact(&mut payload).await?;
192
193    Ok(Frame {
194        message_type,
195        payload: Bytes::from(payload),
196    })
197}
198
199/// Framed codec for encoding/decoding frames on a stream
200pub struct FramedStream<S> {
201    stream: S,
202}
203
204impl<S> FramedStream<S> {
205    pub fn new(stream: S) -> Self {
206        Self { stream }
207    }
208
209    pub fn into_inner(self) -> S {
210        self.stream
211    }
212}
213
214impl<S: AsyncRead + Unpin> FramedStream<S> {
215    /// Read the next frame from the stream
216    pub async fn read_frame(&mut self) -> Result<Frame, FrameError> {
217        read_frame(&mut self.stream).await
218    }
219}
220
221impl<S: AsyncWrite + Unpin> FramedStream<S> {
222    /// Write a frame to the stream
223    pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), FrameError> {
224        write_frame(&mut self.stream, frame).await
225    }
226}
227
228impl<S: AsyncRead + AsyncWrite + Unpin> FramedStream<S> {
229    /// Send a request and wait for a response
230    pub async fn request<Req: Message, Resp: Message + Default>(
231        &mut self,
232        request: &Req,
233    ) -> Result<Resp, FrameError> {
234        let frame = Frame::request(request)?;
235        self.write_frame(&frame).await?;
236
237        let response_frame = self.read_frame().await?;
238        match response_frame.message_type {
239            MessageType::Response => response_frame.decode(),
240            MessageType::Error => {
241                // Try to decode as error message
242                Err(FrameError::Io(std::io::Error::other(
243                    "received error response",
244                )))
245            }
246            _ => Err(FrameError::Io(std::io::Error::new(
247                std::io::ErrorKind::InvalidData,
248                "unexpected message type",
249            ))),
250        }
251    }
252
253    /// Send a response
254    pub async fn respond<Resp: Message>(&mut self, response: &Resp) -> Result<(), FrameError> {
255        let frame = Frame::response(response)?;
256        self.write_frame(&frame).await
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a raw frame by hand so decoding tests do not depend on `Frame::encode`.
    fn raw_frame(length: u32, message_type: u16, payload: &[u8]) -> Bytes {
        let mut buf = BytesMut::with_capacity(HEADER_SIZE + payload.len());
        buf.put_u32(length);
        buf.put_u16(message_type);
        buf.put_slice(payload);
        buf.freeze()
    }

    #[test]
    fn test_message_type_round_trip() {
        for &mt in &[
            MessageType::Request,
            MessageType::Response,
            MessageType::StreamStart,
            MessageType::StreamData,
            MessageType::StreamEnd,
            MessageType::Error,
        ] {
            let value = mt as u16;
            let decoded = MessageType::try_from(value).unwrap();
            assert_eq!(mt, decoded);
        }
    }

    #[test]
    fn test_message_type_rejects_unknown_values() {
        for value in [0u16, 7, u16::MAX] {
            let err = MessageType::try_from(value).unwrap_err();
            assert!(
                matches!(err, FrameError::InvalidMessageType(v) if v == value),
                "unexpected error for {value}: {err:?}"
            );
        }
    }

    #[test]
    fn test_frame_encode_decode() {
        use crate::management_proto::HealthCheckRequest;

        let msg = HealthCheckRequest {};
        let frame = Frame::request(&msg).unwrap();
        let encoded = frame.encode();
        let decoded = Frame::decode_from_bytes(encoded).unwrap();

        assert_eq!(frame.message_type, decoded.message_type);
        assert_eq!(frame.payload, decoded.payload);
    }

    #[test]
    fn test_encode_wire_layout() {
        let frame = Frame {
            message_type: MessageType::StreamData,
            payload: Bytes::from_static(b"abc"),
        };
        let encoded = frame.encode();
        // 4-byte big-endian length, 2-byte big-endian type, then the payload.
        assert_eq!(&encoded[..], b"\x00\x00\x00\x03\x00\x04abc");
    }

    #[test]
    fn test_decode_from_bytes_round_trips_payload() {
        let frame =
            Frame::decode_from_bytes(raw_frame(3, MessageType::StreamEnd as u16, b"xyz")).unwrap();
        assert_eq!(frame.message_type, MessageType::StreamEnd);
        assert_eq!(frame.payload, Bytes::from_static(b"xyz"));
    }

    #[test]
    fn test_decode_from_bytes_rejects_short_header() {
        let err = Frame::decode_from_bytes(Bytes::from_static(&[0, 0, 0])).unwrap_err();
        assert!(
            matches!(err, FrameError::Io(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof),
            "{err:?}"
        );
    }

    #[test]
    fn test_decode_from_bytes_rejects_invalid_message_type() {
        let err = Frame::decode_from_bytes(raw_frame(0, 99, &[])).unwrap_err();
        assert!(matches!(err, FrameError::InvalidMessageType(99)), "{err:?}");
    }

    #[test]
    fn test_decode_from_bytes_rejects_oversized_length() {
        let too_big = MAX_FRAME_SIZE + 1;
        let announced = u32::try_from(too_big).unwrap();
        let err = Frame::decode_from_bytes(raw_frame(announced, MessageType::Request as u16, &[]))
            .unwrap_err();
        assert!(
            matches!(err, FrameError::FrameTooLarge(n) if n == too_big),
            "{err:?}"
        );
    }

    #[test]
    fn test_decode_from_bytes_rejects_truncated_payload() {
        let err = Frame::decode_from_bytes(raw_frame(10, MessageType::Response as u16, b"abc"))
            .unwrap_err();
        assert!(
            matches!(err, FrameError::Io(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof),
            "{err:?}"
        );
    }
}