Skip to main content

pyrosql_protocol/
codec.rs

//! Arrow IPC framing for PyroLink QUIC streams.
//!
//! Each message on a bidirectional QUIC stream is encoded as:
//!
//! ```text
//! ┌────────┬──────────┬──────────────────────────────────────┐
//! │ type 1B│ len 4B LE│ Arrow IPC message (Flight payload)   │
//! └────────┴──────────┴──────────────────────────────────────┘
//! ```
//!
//! For streaming responses (e.g. DoGet), the server sends multiple framed
//! messages followed by a single [`MSG_EOS`] byte with no length field.
//!
//! All multi-byte integers are **little-endian**.
15
16use crate::error::PyroLinkError;
17use bytes::Bytes;
18use quinn::{RecvStream, SendStream};
19
20// ── Message type constants ────────────────────────────────────────────────────
21
/// Wire type byte for an Arrow Schema message.
///
/// This is a frame tag, not an RPC selector: `RpcType::try_from` does not
/// accept it.
pub const MSG_SCHEMA: u8 = 0x01;
/// Wire type byte for an Arrow RecordBatch message.
///
/// This is a frame tag, not an RPC selector: `RpcType::try_from` does not
/// accept it.
pub const MSG_RECORD_BATCH: u8 = 0x02;
/// Wire type byte for a GetFlightInfo request (maps to `RpcType::GetFlightInfo`).
pub const MSG_GET_FLIGHT_INFO: u8 = 0x03;
/// Wire type byte for a DoGet request (maps to `RpcType::DoGet`).
pub const MSG_DO_GET: u8 = 0x04;
/// Wire type byte for a DoAction request (maps to `RpcType::DoAction`).
pub const MSG_DO_ACTION: u8 = 0x05;
/// Wire type byte for a DoPut request (maps to `RpcType::DoPut`).
pub const MSG_DO_PUT: u8 = 0x06;
/// Wire type byte for a ListActions request (maps to `RpcType::ListActions`).
pub const MSG_LIST_ACTIONS: u8 = 0x07;
/// Wire type byte for a PrepareStatement request (maps to `RpcType::PrepareStatement`).
pub const MSG_PREPARE_STATEMENT: u8 = 0x08;
/// Wire type byte for a direct SQL query (single-roundtrip: SQL in, results out).
///
/// Maps to `RpcType::Query`.
pub const MSG_QUERY: u8 = 0x09;
/// Wire type byte for a topology request (adaptive transport negotiation).
///
/// Maps to `RpcType::Topology`.
pub const MSG_TOPOLOGY: u8 = 0x0A;
/// Wire type byte for a server-pushed notification (LISTEN/NOTIFY, WATCH).
///
/// Sent on server-initiated unidirectional streams.  Payload is a JSON object:
/// `{"channel": "<name>", "payload": "<text>"}`.
///
/// Values `0x0B`..=`0x0E` are not assigned by any constant in this file, and
/// this byte is not accepted by `RpcType::try_from`.
pub const MSG_NOTIFICATION: u8 = 0x0F;
/// Wire type byte marking end-of-stream (no length or payload follows).
///
/// `read_message` returns `Ok(None)` as soon as it reads this byte, without
/// attempting to read a length field.
pub const MSG_EOS: u8 = 0xFF;
49
50// ── RPC type enum ─────────────────────────────────────────────────────────────
51
/// The RPC operation encoded in the first byte of a new QUIC stream.
///
/// Each variant corresponds to exactly one `MSG_*` wire byte; the byte-to-variant
/// mapping is defined by the `TryFrom<u8>` implementation for this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RpcType {
    /// `GetFlightInfo` — plan a SQL query, return a ticket.
    ///
    /// Wire byte: `MSG_GET_FLIGHT_INFO`.
    GetFlightInfo,
    /// `DoGet` — stream query results for a previously planned ticket.
    ///
    /// Wire byte: `MSG_DO_GET`.
    DoGet,
    /// `DoPut` — bulk-insert a RecordBatch stream into a table.
    ///
    /// Wire byte: `MSG_DO_PUT`.
    DoPut,
    /// `DoAction` — execute a named action (transactions, DDL, etc.).
    ///
    /// Wire byte: `MSG_DO_ACTION`.
    DoAction,
    /// `ListActions` — enumerate supported server actions.
    ///
    /// Wire byte: `MSG_LIST_ACTIONS`.
    ListActions,
    /// `PrepareStatement` — create a prepared statement handle.
    ///
    /// Wire byte: `MSG_PREPARE_STATEMENT`.
    PrepareStatement,
    /// `Query` — direct SQL execution in a single roundtrip (SQL in, results out).
    ///
    /// Wire byte: `MSG_QUERY`.
    Query,
    /// `Topology` — request server topology hints for adaptive transport negotiation.
    ///
    /// Wire byte: `MSG_TOPOLOGY`.
    Topology,
}
72
73impl TryFrom<u8> for RpcType {
74    type Error = PyroLinkError;
75
76    fn try_from(byte: u8) -> Result<Self, Self::Error> {
77        match byte {
78            MSG_GET_FLIGHT_INFO => Ok(Self::GetFlightInfo),
79            MSG_DO_GET => Ok(Self::DoGet),
80            MSG_DO_PUT => Ok(Self::DoPut),
81            MSG_DO_ACTION => Ok(Self::DoAction),
82            MSG_LIST_ACTIONS => Ok(Self::ListActions),
83            MSG_PREPARE_STATEMENT => Ok(Self::PrepareStatement),
84            MSG_QUERY => Ok(Self::Query),
85            MSG_TOPOLOGY => Ok(Self::Topology),
86            other => Err(PyroLinkError::UnknownRpcType(other)),
87        }
88    }
89}
90
91// ── Low-level read helpers ────────────────────────────────────────────────────
92
93/// Read the RPC type byte from the start of a new QUIC stream.
94///
95/// This is the first byte sent by the client on every new bidirectional stream.
96/// It tells the server which Arrow Flight RPC this stream carries.
97///
98/// # Errors
99///
100/// Returns [`PyroLinkError`] if the stream ends before a byte is available
101/// or if the byte value is not a known RPC type.
102pub async fn read_rpc_type(recv: &mut RecvStream) -> Result<RpcType, PyroLinkError> {
103    let mut buf = [0u8; 1];
104    recv.read_exact(&mut buf).await?;
105    RpcType::try_from(buf[0])
106}
107
108/// Read one framed message from the stream.
109///
110/// Reads the 1-byte type tag and 4-byte LE length, then reads exactly that many
111/// bytes of payload.  Returns `None` if the stream signals EOS (`0xFF` type byte).
112///
113/// # Errors
114///
115/// Returns [`PyroLinkError::Framing`] if the length field overflows or the
116/// stream ends prematurely.
117pub async fn read_message(
118    recv: &mut RecvStream,
119) -> Result<Option<(u8, Bytes)>, PyroLinkError> {
120    // Read type byte first (EOS has no length field).
121    let mut header = [0u8; 5];
122    recv.read_exact(&mut header[..1]).await?;
123
124    if header[0] == MSG_EOS {
125        return Ok(None);
126    }
127
128    // Read 4-byte LE length.
129    recv.read_exact(&mut header[1..5]).await?;
130    let len = u32::from_le_bytes([header[1], header[2], header[3], header[4]]) as usize;
131
132    // Safety limit: 256 MiB per message
133    if len > 256 * 1024 * 1024 {
134        return Err(PyroLinkError::Framing(format!(
135            "message length {len} exceeds the 256 MiB limit"
136        )));
137    }
138
139    // Read payload
140    let mut payload = vec![0u8; len];
141    recv.read_exact(&mut payload).await?;
142
143    Ok(Some((header[0], Bytes::from(payload))))
144}
145
146/// Write one framed message to the stream.
147///
148/// Encodes `type_byte + len(4B LE) + payload` and writes it to the stream.
149///
150/// # Errors
151///
152/// Returns [`PyroLinkError::Stream`] if the underlying write fails.
153pub async fn write_message(
154    send: &mut SendStream,
155    type_byte: u8,
156    payload: &[u8],
157) -> Result<(), PyroLinkError> {
158    // Stack-allocated 5-byte header avoids a Vec allocation per message.
159    let len = (payload.len() as u32).to_le_bytes();
160    let header = [type_byte, len[0], len[1], len[2], len[3]];
161    send.write_all(&header)
162        .await
163        .map_err(|e| PyroLinkError::Stream(e.to_string()))?;
164    if !payload.is_empty() {
165        send.write_all(payload)
166            .await
167            .map_err(|e| PyroLinkError::Stream(e.to_string()))?;
168    }
169    Ok(())
170}
171
172/// Write the EOS marker (`0xFF`) to the stream and finish it.
173///
174/// No length or payload follows the EOS byte.
175///
176/// # Errors
177///
178/// Returns [`PyroLinkError::Stream`] if the write or finish call fails.
179pub async fn write_eos(send: &mut SendStream) -> Result<(), PyroLinkError> {
180    send.write_all(&[MSG_EOS])
181        .await
182        .map_err(|e| PyroLinkError::Stream(e.to_string()))?;
183    send.finish()
184        .map_err(|e| PyroLinkError::Stream(e.to_string()))
185}
186
187// ── Tests ─────────────────────────────────────────────────────────────────────
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Every RPC wire byte paired with the variant it must decode to.
    const RPC_CASES: [(u8, RpcType); 8] = [
        (MSG_GET_FLIGHT_INFO, RpcType::GetFlightInfo),
        (MSG_DO_GET, RpcType::DoGet),
        (MSG_DO_PUT, RpcType::DoPut),
        (MSG_DO_ACTION, RpcType::DoAction),
        (MSG_LIST_ACTIONS, RpcType::ListActions),
        (MSG_PREPARE_STATEMENT, RpcType::PrepareStatement),
        (MSG_QUERY, RpcType::Query),
        (MSG_TOPOLOGY, RpcType::Topology),
    ];

    #[test]
    fn rpc_type_round_trips() {
        for (byte, expected) in RPC_CASES {
            let got = RpcType::try_from(byte).expect("known byte must convert");
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn unknown_rpc_type_byte_returns_error() {
        let result = RpcType::try_from(0x42u8);
        assert!(matches!(result, Err(PyroLinkError::UnknownRpcType(0x42))));
    }

    #[test]
    fn eos_byte_is_distinct_from_all_rpc_types() {
        // MSG_EOS must not collide with any recognised RPC type byte.
        assert!(RpcType::try_from(MSG_EOS).is_err());
    }

    // ── Wire-frame encoding helpers ───────────────────────────────────────────
    //
    // `write_message` and `read_message` operate on live QUIC streams and
    // cannot be called in unit tests without a full QUIC endpoint.  Instead
    // we verify the *frame layout* produced by the same byte-packing logic
    // that `write_message` uses: [type_byte | len_le_u32 | payload].

    /// Build a PyroLink frame with the same layout `write_message` emits.
    ///
    /// Panics if the payload length does not fit in the 4-byte length field,
    /// rather than silently truncating it.
    fn build_frame(type_byte: u8, payload: &[u8]) -> Vec<u8> {
        let len = u32::try_from(payload.len()).expect("test payload must fit in a u32 length");
        let mut frame = Vec::with_capacity(5 + payload.len());
        frame.push(type_byte);
        frame.extend_from_slice(&len.to_le_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// Decode a frame built by `build_frame`, returning `(type_byte, payload)`.
    /// Returns `None` if the first byte is `MSG_EOS`.
    ///
    /// Panics (via assertions) when the frame is empty, lacks a full length
    /// field, or its declared length disagrees with the bytes present.
    fn decode_frame(frame: &[u8]) -> Option<(u8, &[u8])> {
        assert!(!frame.is_empty(), "frame too short");
        let type_byte = frame[0];
        if type_byte == MSG_EOS {
            return None;
        }
        assert!(frame.len() >= 5, "frame missing length field");
        let len = u32::from_le_bytes([frame[1], frame[2], frame[3], frame[4]]) as usize;
        assert_eq!(frame.len(), 5 + len, "frame payload length mismatch");
        Some((type_byte, &frame[5..]))
    }

    #[test]
    fn frame_wire_format_type_and_length_prefix() {
        // A MSG_SCHEMA frame with a 3-byte payload must have:
        //   byte 0 = 0x01 (MSG_SCHEMA)
        //   bytes 1-4 = 3 as little-endian u32
        //   bytes 5-7 = payload
        let payload = b"abc";
        let frame = build_frame(MSG_SCHEMA, payload);
        assert_eq!(frame.len(), 8);
        assert_eq!(frame[0], MSG_SCHEMA);
        assert_eq!(&frame[1..5], &3u32.to_le_bytes());
        assert_eq!(&frame[5..], payload);
    }

    #[test]
    fn frame_roundtrip_record_batch() {
        let payload: Vec<u8> = (0u8..=15u8).collect();
        let frame = build_frame(MSG_RECORD_BATCH, &payload);
        let (got_type, got_payload) = decode_frame(&frame).expect("not EOS");
        assert_eq!(got_type, MSG_RECORD_BATCH);
        assert_eq!(got_payload, payload.as_slice());
    }

    #[test]
    fn frame_roundtrip_empty_payload() {
        let frame = build_frame(MSG_DO_ACTION, &[]);
        assert_eq!(frame.len(), 5, "header only for empty payload");
        let (got_type, got_payload) = decode_frame(&frame).expect("not EOS");
        assert_eq!(got_type, MSG_DO_ACTION);
        assert!(got_payload.is_empty());
    }

    #[test]
    fn eos_frame_is_single_byte() {
        // write_eos writes exactly one byte: MSG_EOS (0xFF), then finishes.
        let eos_frame = [MSG_EOS];
        let result = decode_frame(&eos_frame);
        assert!(result.is_none(), "EOS frame should return None");
    }

    #[test]
    fn frame_length_is_little_endian() {
        // 0x100 = 256 bytes.  In LE: [0x00, 0x01, 0x00, 0x00].
        let payload = vec![0u8; 256];
        let frame = build_frame(MSG_DO_GET, &payload);
        // Check LE byte order explicitly.
        assert_eq!(frame[1], 0x00, "LSB of 256");
        assert_eq!(frame[2], 0x01, "next byte of 256");
        assert_eq!(frame[3], 0x00);
        assert_eq!(frame[4], 0x00);
    }

    #[test]
    fn frame_large_payload_65536_bytes() {
        // Verify a >64KB payload encodes and decodes correctly.
        let payload: Vec<u8> = (0u8..=255u8).cycle().take(65_536).collect();
        let frame = build_frame(MSG_RECORD_BATCH, &payload);
        let (got_type, got_payload) = decode_frame(&frame).expect("not EOS");
        assert_eq!(got_type, MSG_RECORD_BATCH);
        assert_eq!(got_payload.len(), 65_536);
        assert_eq!(got_payload, payload.as_slice());
    }

    #[test]
    fn all_rpc_type_bytes_covered_by_try_from() {
        // Exhaustive check over every possible byte value (including 0xFF):
        // only the eight known RPC bytes parse successfully.
        for b in 0u8..=u8::MAX {
            let is_known = RPC_CASES.iter().any(|&(byte, _)| byte == b);
            let result = RpcType::try_from(b);
            if is_known {
                assert!(result.is_ok(), "byte {b:#04x} should map to a known RpcType");
            } else {
                assert!(result.is_err(), "byte {b:#04x} should be unknown");
            }
        }
    }
}