1use crate::error::PyroLinkError;
17use bytes::Bytes;
18use quinn::{RecvStream, SendStream};
19
// Message type bytes. `read_message` returns the leading byte of each frame
// as the message type; these are the values defined for it.

// -- Types that `RpcType::try_from` does not accept as RPC selectors --------

/// Message type byte: schema.
pub const MSG_SCHEMA: u8 = 0x01;
/// Message type byte: record batch.
pub const MSG_RECORD_BATCH: u8 = 0x02;

// -- RPC selectors, each decoded by `RpcType::try_from` ---------------------

/// Selector byte decoded as [`RpcType::GetFlightInfo`].
pub const MSG_GET_FLIGHT_INFO: u8 = 0x03;
/// Selector byte decoded as [`RpcType::DoGet`].
pub const MSG_DO_GET: u8 = 0x04;
/// Selector byte decoded as [`RpcType::DoAction`].
pub const MSG_DO_ACTION: u8 = 0x05;
/// Selector byte decoded as [`RpcType::DoPut`].
pub const MSG_DO_PUT: u8 = 0x06;
/// Selector byte decoded as [`RpcType::ListActions`].
pub const MSG_LIST_ACTIONS: u8 = 0x07;
/// Selector byte decoded as [`RpcType::PrepareStatement`].
pub const MSG_PREPARE_STATEMENT: u8 = 0x08;
/// Selector byte decoded as [`RpcType::Query`].
pub const MSG_QUERY: u8 = 0x09;
/// Selector byte decoded as [`RpcType::Topology`].
pub const MSG_TOPOLOGY: u8 = 0x0a;

// -- Other message types ----------------------------------------------------

/// Message type byte: notification. Not an RPC selector.
pub const MSG_NOTIFICATION: u8 = 0x0f;
/// End-of-stream marker. It is written as a single byte with no length
/// prefix and no payload; `read_message` reports it as `Ok(None)`.
pub const MSG_EOS: u8 = 0xff;
49
/// The RPC a peer selects with the leading byte of a stream.
///
/// Values are produced from a raw byte via `RpcType::try_from`; each variant
/// corresponds to exactly one `MSG_*` selector constant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RpcType {
    /// Selected by `MSG_GET_FLIGHT_INFO`.
    GetFlightInfo,
    /// Selected by `MSG_DO_GET`.
    DoGet,
    /// Selected by `MSG_DO_PUT`.
    DoPut,
    /// Selected by `MSG_DO_ACTION`.
    DoAction,
    /// Selected by `MSG_LIST_ACTIONS`.
    ListActions,
    /// Selected by `MSG_PREPARE_STATEMENT`.
    PrepareStatement,
    /// Selected by `MSG_QUERY`.
    Query,
    /// Selected by `MSG_TOPOLOGY`.
    Topology,
}
72
73impl TryFrom<u8> for RpcType {
74 type Error = PyroLinkError;
75
76 fn try_from(byte: u8) -> Result<Self, Self::Error> {
77 match byte {
78 MSG_GET_FLIGHT_INFO => Ok(Self::GetFlightInfo),
79 MSG_DO_GET => Ok(Self::DoGet),
80 MSG_DO_PUT => Ok(Self::DoPut),
81 MSG_DO_ACTION => Ok(Self::DoAction),
82 MSG_LIST_ACTIONS => Ok(Self::ListActions),
83 MSG_PREPARE_STATEMENT => Ok(Self::PrepareStatement),
84 MSG_QUERY => Ok(Self::Query),
85 MSG_TOPOLOGY => Ok(Self::Topology),
86 other => Err(PyroLinkError::UnknownRpcType(other)),
87 }
88 }
89}
90
91pub async fn read_rpc_type(recv: &mut RecvStream) -> Result<RpcType, PyroLinkError> {
103 let mut buf = [0u8; 1];
104 recv.read_exact(&mut buf).await?;
105 RpcType::try_from(buf[0])
106}
107
108pub async fn read_message(
118 recv: &mut RecvStream,
119) -> Result<Option<(u8, Bytes)>, PyroLinkError> {
120 let mut header = [0u8; 5];
122 recv.read_exact(&mut header[..1]).await?;
123
124 if header[0] == MSG_EOS {
125 return Ok(None);
126 }
127
128 recv.read_exact(&mut header[1..5]).await?;
130 let len = u32::from_le_bytes([header[1], header[2], header[3], header[4]]) as usize;
131
132 if len > 256 * 1024 * 1024 {
134 return Err(PyroLinkError::Framing(format!(
135 "message length {len} exceeds the 256 MiB limit"
136 )));
137 }
138
139 let mut payload = vec![0u8; len];
141 recv.read_exact(&mut payload).await?;
142
143 Ok(Some((header[0], Bytes::from(payload))))
144}
145
146pub async fn write_message(
154 send: &mut SendStream,
155 type_byte: u8,
156 payload: &[u8],
157) -> Result<(), PyroLinkError> {
158 let len = (payload.len() as u32).to_le_bytes();
160 let header = [type_byte, len[0], len[1], len[2], len[3]];
161 send.write_all(&header)
162 .await
163 .map_err(|e| PyroLinkError::Stream(e.to_string()))?;
164 if !payload.is_empty() {
165 send.write_all(payload)
166 .await
167 .map_err(|e| PyroLinkError::Stream(e.to_string()))?;
168 }
169 Ok(())
170}
171
172pub async fn write_eos(send: &mut SendStream) -> Result<(), PyroLinkError> {
180 send.write_all(&[MSG_EOS])
181 .await
182 .map_err(|e| PyroLinkError::Stream(e.to_string()))?;
183 send.finish()
184 .map_err(|e| PyroLinkError::Stream(e.to_string()))
185}
186
/// Unit tests for the selector decoding and for the frame wire layout.
///
/// The layout tests use the local `build_frame` / `decode_frame` helpers,
/// which mirror the byte layout in memory; they do not drive the async
/// stream functions.
#[cfg(test)]
mod tests {
    use super::*;

    /// Every RPC selector constant decodes to its matching variant.
    #[test]
    fn rpc_type_round_trips() {
        let cases = [
            (MSG_GET_FLIGHT_INFO, RpcType::GetFlightInfo),
            (MSG_DO_GET, RpcType::DoGet),
            (MSG_DO_PUT, RpcType::DoPut),
            (MSG_DO_ACTION, RpcType::DoAction),
            (MSG_LIST_ACTIONS, RpcType::ListActions),
            (MSG_PREPARE_STATEMENT, RpcType::PrepareStatement),
            (MSG_QUERY, RpcType::Query),
            (MSG_TOPOLOGY, RpcType::Topology),
        ];
        for (byte, expected) in cases {
            let got = RpcType::try_from(byte).expect("known byte must convert");
            assert_eq!(got, expected);
        }
    }

    /// An unassigned byte is reported back inside `UnknownRpcType`.
    #[test]
    fn unknown_rpc_type_byte_returns_error() {
        let result = RpcType::try_from(0x42u8);
        assert!(matches!(result, Err(PyroLinkError::UnknownRpcType(0x42))));
    }

    /// The end-of-stream marker must never decode as an RPC selector.
    #[test]
    fn eos_byte_is_distinct_from_all_rpc_types() {
        assert!(RpcType::try_from(MSG_EOS).is_err());
    }

    /// Builds `[type_byte][u32 LE length][payload]` in memory.
    fn build_frame(type_byte: u8, payload: &[u8]) -> Vec<u8> {
        let len = payload.len() as u32;
        let mut frame = Vec::with_capacity(5 + payload.len());
        frame.push(type_byte);
        frame.extend_from_slice(&len.to_le_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// Splits an in-memory frame into `(type_byte, payload)`.
    ///
    /// Returns `None` for the end-of-stream marker. Panics (via assertions)
    /// on an empty frame, a missing length field, or a length mismatch.
    fn decode_frame(frame: &[u8]) -> Option<(u8, &[u8])> {
        assert!(!frame.is_empty(), "frame too short");
        let type_byte = frame[0];
        if type_byte == MSG_EOS {
            return None;
        }
        assert!(frame.len() >= 5, "frame missing length field");
        let len = u32::from_le_bytes([frame[1], frame[2], frame[3], frame[4]]) as usize;
        assert_eq!(frame.len(), 5 + len, "frame payload length mismatch");
        Some((type_byte, &frame[5..]))
    }

    /// The header is one type byte followed by a 4-byte length.
    #[test]
    fn frame_wire_format_type_and_length_prefix() {
        let payload = b"abc";
        let frame = build_frame(MSG_SCHEMA, payload);
        assert_eq!(frame.len(), 8);
        assert_eq!(frame[0], MSG_SCHEMA);
        assert_eq!(&frame[1..5], &3u32.to_le_bytes());
        assert_eq!(&frame[5..], payload);
    }

    #[test]
    fn frame_roundtrip_record_batch() {
        let payload: Vec<u8> = (0u8..=15u8).collect();
        let frame = build_frame(MSG_RECORD_BATCH, &payload);
        let (got_type, got_payload) = decode_frame(&frame).expect("not EOS");
        assert_eq!(got_type, MSG_RECORD_BATCH);
        assert_eq!(got_payload, payload.as_slice());
    }

    #[test]
    fn frame_roundtrip_empty_payload() {
        let frame = build_frame(MSG_DO_ACTION, &[]);
        assert_eq!(frame.len(), 5, "header only for empty payload");
        let (got_type, got_payload) = decode_frame(&frame).expect("not EOS");
        assert_eq!(got_type, MSG_DO_ACTION);
        assert!(got_payload.is_empty());
    }

    /// End-of-stream is the bare marker byte, with no length field.
    #[test]
    fn eos_frame_is_single_byte() {
        let eos_frame = [MSG_EOS];
        let result = decode_frame(&eos_frame);
        assert!(result.is_none(), "EOS frame should return None");
    }

    /// 256 == 0x0000_0100, so little-endian puts 0x00 first and 0x01 second.
    #[test]
    fn frame_length_is_little_endian() {
        let payload = [0u8; 256];
        let frame = build_frame(MSG_DO_GET, &payload);
        assert_eq!(frame[1], 0x00, "LSB of 256");
        assert_eq!(frame[2], 0x01, "next byte of 256");
        assert_eq!(frame[3], 0x00);
        assert_eq!(frame[4], 0x00);
    }

    #[test]
    fn frame_large_payload_65536_bytes() {
        let payload: Vec<u8> = (0u8..=255u8).cycle().take(65_536).collect();
        let frame = build_frame(MSG_RECORD_BATCH, &payload);
        let (got_type, got_payload) = decode_frame(&frame).expect("not EOS");
        assert_eq!(got_type, MSG_RECORD_BATCH);
        assert_eq!(got_payload.len(), 65_536);
        assert_eq!(got_payload, payload.as_slice());
    }

    /// Sweeps the whole `u8` range: exactly the eight selector constants
    /// decode, and every other byte (including `MSG_EOS`) is rejected.
    #[test]
    fn all_rpc_type_bytes_covered_by_try_from() {
        let known = [
            MSG_GET_FLIGHT_INFO,
            MSG_DO_GET,
            MSG_DO_PUT,
            MSG_DO_ACTION,
            MSG_LIST_ACTIONS,
            MSG_PREPARE_STATEMENT,
            MSG_QUERY,
            MSG_TOPOLOGY,
        ];
        for b in 0u8..=u8::MAX {
            let result = RpcType::try_from(b);
            if known.contains(&b) {
                assert!(result.is_ok(), "byte {b:#04x} should map to a known RpcType");
            } else {
                assert!(result.is_err(), "byte {b:#04x} should be unknown");
            }
        }
    }
}