Skip to main content

reddb_wire/redwire/
stream.rs

1//! RedWire stream payload contracts.
2
3use serde_json::Value as JsonValue;
4
5use super::{BuildError, Frame, FrameBuilder, MessageKind};
6
/// Decoded body of an output-direction `OpenStream` request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenStreamRequest {
    /// SQL text whose results should be streamed.
    pub sql: String,
    /// JSON bytes of the optional `opts` field; empty when `opts` was absent.
    pub opts_raw: Vec<u8>,
}

/// Reasons an `OpenStream` payload can be rejected by [`parse_open_stream`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OpenStreamParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is valid JSON but not an object.
    NotObject,
    /// The `sql` field is absent or is not a string.
    MissingSql,
    /// The `sql` field is present but empty.
    EmptySql,
}

impl OpenStreamParseError {
    /// Machine-readable error code for this rejection.
    ///
    /// Both SQL-related variants share one code, as do both envelope-shape variants.
    pub fn code(&self) -> &'static str {
        match self {
            Self::MissingSql | Self::EmptySql => "open_stream_missing_sql",
            Self::NotJson | Self::NotObject => "open_stream_invalid_payload",
        }
    }

    /// Human-readable description of this rejection.
    pub fn message(&self) -> &'static str {
        match self {
            Self::EmptySql => "OpenStream payload 'sql' must be non-empty",
            Self::MissingSql => "OpenStream payload missing 'sql' string field",
            Self::NotObject => "OpenStream payload must be a JSON object",
            Self::NotJson => "OpenStream payload must be JSON",
        }
    }
}
38
39pub fn parse_open_stream(payload: &[u8]) -> Result<OpenStreamRequest, OpenStreamParseError> {
40    let v: JsonValue =
41        serde_json::from_slice(payload).map_err(|_| OpenStreamParseError::NotJson)?;
42    let obj = v.as_object().ok_or(OpenStreamParseError::NotObject)?;
43    let sql = obj
44        .get("sql")
45        .and_then(|x| x.as_str())
46        .ok_or(OpenStreamParseError::MissingSql)?;
47    if sql.is_empty() {
48        return Err(OpenStreamParseError::EmptySql);
49    }
50    let opts_raw = obj
51        .get("opts")
52        .map(|v| serde_json::to_vec(v).unwrap_or_default())
53        .unwrap_or_default();
54    Ok(OpenStreamRequest {
55        sql: sql.to_string(),
56        opts_raw,
57    })
58}
59
/// Decoded body of a stream-cancel payload; every field is optional.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamCancelRequest {
    /// Free-form cancellation reason, when the payload carried a string `reason`.
    pub reason: Option<String>,
}
64
65pub fn parse_stream_cancel(payload: &[u8]) -> StreamCancelRequest {
66    if payload.is_empty() {
67        return StreamCancelRequest::default();
68    }
69    let v: JsonValue = match serde_json::from_slice(payload) {
70        Ok(v) => v,
71        Err(_) => return StreamCancelRequest::default(),
72    };
73    let reason = v
74        .as_object()
75        .and_then(|o| o.get("reason"))
76        .and_then(|x| x.as_str())
77        .map(|s| s.to_string());
78    StreamCancelRequest { reason }
79}
80
81pub fn build_open_stream_payload(request: &OpenStreamRequest) -> Vec<u8> {
82    let mut obj = serde_json::Map::new();
83    obj.insert("sql".to_string(), JsonValue::String(request.sql.clone()));
84    if !request.opts_raw.is_empty() {
85        let opts = serde_json::from_slice(&request.opts_raw).unwrap_or(JsonValue::Null);
86        obj.insert("opts".to_string(), opts);
87    }
88    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
89}
90
91pub fn build_open_stream_frame(
92    correlation_id: u64,
93    stream_id: u16,
94    request: &OpenStreamRequest,
95) -> Result<Frame, BuildError> {
96    FrameBuilder::request(correlation_id)
97        .kind(MessageKind::OpenStream)
98        .stream_id(stream_id)
99        .payload(build_open_stream_payload(request))
100        .build()
101}
102
103pub fn build_open_ack_payload(lease_id: u64, snapshot_lsn: u64, resumable: bool) -> Vec<u8> {
104    let mut obj = serde_json::Map::new();
105    obj.insert(
106        "lease_handle".to_string(),
107        JsonValue::String(lease_id.to_string()),
108    );
109    obj.insert("resumable".to_string(), JsonValue::Bool(resumable));
110    obj.insert(
111        "snapshot_lsn".to_string(),
112        JsonValue::Number(snapshot_lsn.into()),
113    );
114    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
115}
116
117pub fn build_open_ack_frame(
118    correlation_id: u64,
119    stream_id: u16,
120    lease_id: u64,
121    snapshot_lsn: u64,
122    resumable: bool,
123) -> Result<Frame, BuildError> {
124    FrameBuilder::reply_to(correlation_id)
125        .kind(MessageKind::OpenAck)
126        .stream_id(stream_id)
127        .payload(build_open_ack_payload(lease_id, snapshot_lsn, resumable))
128        .build()
129}
130
131pub fn build_stream_chunk_payload(seq: u64, rows: Vec<JsonValue>, terminal: bool) -> Vec<u8> {
132    let mut obj = serde_json::Map::new();
133    obj.insert("seq".to_string(), JsonValue::Number(seq.into()));
134    obj.insert("rows".to_string(), JsonValue::Array(rows));
135    obj.insert("terminal".to_string(), JsonValue::Bool(terminal));
136    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
137}
138
139pub fn build_stream_chunk_payload_from_json_bytes(
140    seq: u64,
141    rows: Vec<Vec<u8>>,
142    terminal: bool,
143) -> Vec<u8> {
144    let rows = rows
145        .into_iter()
146        .map(|row| serde_json::from_slice(&row).unwrap_or(JsonValue::Null))
147        .collect();
148    build_stream_chunk_payload(seq, rows, terminal)
149}
150
151pub fn build_stream_chunk_frame_from_json_bytes(
152    correlation_id: u64,
153    stream_id: u16,
154    seq: u64,
155    rows: Vec<Vec<u8>>,
156    terminal: bool,
157) -> Result<Frame, BuildError> {
158    FrameBuilder::reply_to(correlation_id)
159        .kind(MessageKind::StreamChunk)
160        .stream_id(stream_id)
161        .payload(build_stream_chunk_payload_from_json_bytes(
162            seq, rows, terminal,
163        ))
164        .build()
165}
166
167pub fn build_stream_error_payload(seq: Option<u64>, code: &str, message: &str) -> Vec<u8> {
168    let mut obj = serde_json::Map::new();
169    if let Some(s) = seq {
170        obj.insert("seq".to_string(), JsonValue::Number(s.into()));
171    }
172    obj.insert("code".to_string(), JsonValue::String(code.to_string()));
173    obj.insert(
174        "message".to_string(),
175        JsonValue::String(message.to_string()),
176    );
177    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
178}
179
180pub fn build_stream_error_frame(
181    correlation_id: u64,
182    stream_id: u16,
183    seq: Option<u64>,
184    code: &str,
185    message: &str,
186) -> Result<Frame, BuildError> {
187    FrameBuilder::reply_to(correlation_id)
188        .kind(MessageKind::StreamError)
189        .stream_id(stream_id)
190        .payload(build_stream_error_payload(seq, code, message))
191        .build()
192}
193
194pub fn build_stream_end_payload(
195    row_count: u64,
196    lease_id: u64,
197    snapshot_lsn: u64,
198    cancelled: bool,
199) -> Vec<u8> {
200    let mut obj = serde_json::Map::new();
201    let mut stats = serde_json::Map::new();
202    stats.insert("row_count".to_string(), JsonValue::Number(row_count.into()));
203    stats.insert("lease_id".to_string(), JsonValue::Number(lease_id.into()));
204    stats.insert(
205        "snapshot_lsn".to_string(),
206        JsonValue::Number(snapshot_lsn.into()),
207    );
208    stats.insert("cancelled".to_string(), JsonValue::Bool(cancelled));
209    obj.insert("stats".to_string(), JsonValue::Object(stats));
210    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
211}
212
213pub fn build_stream_end_frame(
214    correlation_id: u64,
215    stream_id: u16,
216    row_count: u64,
217    lease_id: u64,
218    snapshot_lsn: u64,
219    cancelled: bool,
220) -> Result<Frame, BuildError> {
221    FrameBuilder::reply_to(correlation_id)
222        .kind(MessageKind::StreamEnd)
223        .stream_id(stream_id)
224        .payload(build_stream_end_payload(
225            row_count,
226            lease_id,
227            snapshot_lsn,
228            cancelled,
229        ))
230        .build()
231}
232
233pub fn open_stream_is_input(payload: &[u8]) -> bool {
234    serde_json::from_slice::<JsonValue>(payload)
235        .ok()
236        .and_then(|v| {
237            v.as_object()
238                .and_then(|o| o.get("direction"))
239                .and_then(|d| d.as_str())
240                .map(|s| s.eq_ignore_ascii_case("in"))
241        })
242        .unwrap_or(false)
243}
244
/// Decoded body of an input-direction `OpenStream` request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenInputRequest {
    /// Target table name; checked by `is_safe_sql_identifier` in [`parse_open_input`].
    pub target: String,
    /// Column names in payload order; each checked the same way as `target`.
    pub columns: Vec<String>,
}

/// Reasons an input `OpenStream` payload can be rejected by [`parse_open_input`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OpenInputParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is valid JSON but not an object.
    NotObject,
    /// `target` is absent or not a string.
    MissingTarget,
    /// `target` is not a plain ASCII identifier.
    UnsafeTarget,
    /// `columns` is absent or not an array.
    MissingColumns,
    /// `columns` is an empty array.
    EmptyColumns,
    /// A `columns` entry is not a string or not a plain ASCII identifier.
    UnsafeColumn,
}

impl OpenInputParseError {
    /// Machine-readable error code, grouped by which field was at fault.
    pub fn code(&self) -> &'static str {
        match self {
            Self::MissingColumns | Self::EmptyColumns | Self::UnsafeColumn => {
                "open_stream_invalid_columns"
            }
            Self::MissingTarget | Self::UnsafeTarget => "open_stream_invalid_target",
            Self::NotJson | Self::NotObject => "open_stream_invalid_payload",
        }
    }

    /// Human-readable description of this rejection.
    pub fn message(&self) -> &'static str {
        match self {
            Self::UnsafeColumn => "input OpenStream 'columns' entry is not a safe SQL identifier",
            Self::EmptyColumns => "input OpenStream 'columns' must be a non-empty array",
            Self::MissingColumns => "input OpenStream payload missing 'columns' array field",
            Self::UnsafeTarget => "input OpenStream 'target' is not a safe SQL identifier",
            Self::MissingTarget => "input OpenStream payload missing 'target' string field",
            Self::NotObject => "OpenStream payload must be a JSON object",
            Self::NotJson => "OpenStream payload must be JSON",
        }
    }
}
285
286pub fn parse_open_input(payload: &[u8]) -> Result<OpenInputRequest, OpenInputParseError> {
287    let v: JsonValue = serde_json::from_slice(payload).map_err(|_| OpenInputParseError::NotJson)?;
288    let obj = v.as_object().ok_or(OpenInputParseError::NotObject)?;
289    let target = obj
290        .get("target")
291        .and_then(|x| x.as_str())
292        .ok_or(OpenInputParseError::MissingTarget)?;
293    if !is_safe_sql_identifier(target) {
294        return Err(OpenInputParseError::UnsafeTarget);
295    }
296    let columns_v = obj
297        .get("columns")
298        .and_then(|x| x.as_array())
299        .ok_or(OpenInputParseError::MissingColumns)?;
300    if columns_v.is_empty() {
301        return Err(OpenInputParseError::EmptyColumns);
302    }
303    let mut columns = Vec::with_capacity(columns_v.len());
304    for c in columns_v {
305        let name = c.as_str().ok_or(OpenInputParseError::UnsafeColumn)?;
306        if !is_safe_sql_identifier(name) {
307            return Err(OpenInputParseError::UnsafeColumn);
308        }
309        columns.push(name.to_string());
310    }
311    Ok(OpenInputRequest {
312        target: target.to_string(),
313        columns,
314    })
315}
316
/// Returns `true` when `name` matches `[A-Za-z_][A-Za-z0-9_]*`.
///
/// Any non-ASCII character makes the name unsafe. Each of its UTF-8 bytes is
/// `>= 0x80`, so it fails both byte predicates below.
fn is_safe_sql_identifier(name: &str) -> bool {
    match name.as_bytes().split_first() {
        Some((&head, tail)) => {
            (head.is_ascii_alphabetic() || head == b'_')
                && tail.iter().all(|&b| b.is_ascii_alphanumeric() || b == b'_')
        }
        None => false,
    }
}
325
326#[derive(Debug, Clone, PartialEq)]
327pub struct InputChunk {
328    pub seq: u64,
329    pub rows: Vec<JsonValue>,
330    pub terminal: bool,
331}
332
333#[derive(Debug, Clone, PartialEq, Eq)]
334pub struct InputChunkJson {
335    pub seq: u64,
336    pub rows_json: Vec<Vec<u8>>,
337    pub terminal: bool,
338}
339
/// Reasons an input `StreamChunk` payload can be rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChunkParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is valid JSON but not an object.
    NotObject,
    /// `rows` is present but is neither `null` nor an array.
    RowsNotArray,
}

impl ChunkParseError {
    /// Machine-readable error code; every variant shares `"invalid_chunk"`.
    pub fn code(&self) -> &'static str {
        // Spelled out exhaustively so adding a variant forces a decision here.
        match self {
            Self::NotJson | Self::NotObject | Self::RowsNotArray => "invalid_chunk",
        }
    }

    /// Human-readable description of this rejection.
    pub fn message(&self) -> &'static str {
        match self {
            Self::RowsNotArray => "StreamChunk 'rows' must be an array",
            Self::NotObject => "StreamChunk payload must be a JSON object",
            Self::NotJson => "StreamChunk payload must be JSON",
        }
    }
}
360
361pub fn parse_input_chunk(payload: &[u8]) -> Result<InputChunk, ChunkParseError> {
362    let v: JsonValue = serde_json::from_slice(payload).map_err(|_| ChunkParseError::NotJson)?;
363    let obj = v.as_object().ok_or(ChunkParseError::NotObject)?;
364    let seq = obj.get("seq").and_then(|x| x.as_u64()).unwrap_or(0);
365    let terminal = obj
366        .get("terminal")
367        .and_then(|x| x.as_bool())
368        .unwrap_or(false);
369    let rows = match obj.get("rows") {
370        None | Some(JsonValue::Null) => Vec::new(),
371        Some(JsonValue::Array(arr)) => arr.clone(),
372        Some(_) => return Err(ChunkParseError::RowsNotArray),
373    };
374    Ok(InputChunk {
375        seq,
376        rows,
377        terminal,
378    })
379}
380
381pub fn parse_input_chunk_json(payload: &[u8]) -> Result<InputChunkJson, ChunkParseError> {
382    let chunk = parse_input_chunk(payload)?;
383    let rows_json = chunk
384        .rows
385        .iter()
386        .map(|row| serde_json::to_vec(row).unwrap_or_default())
387        .collect();
388    Ok(InputChunkJson {
389        seq: chunk.seq,
390        rows_json,
391        terminal: chunk.terminal,
392    })
393}
394
395pub fn build_input_stream_end_payload(
396    row_count: u64,
397    chunk_count: u64,
398    committed_rid: u64,
399    snapshot_lsn: u64,
400    cancelled: bool,
401) -> Vec<u8> {
402    let mut obj = serde_json::Map::new();
403    let mut stats = serde_json::Map::new();
404    stats.insert("row_count".to_string(), JsonValue::Number(row_count.into()));
405    stats.insert(
406        "chunk_count".to_string(),
407        JsonValue::Number(chunk_count.into()),
408    );
409    stats.insert(
410        "committed_rid".to_string(),
411        JsonValue::Number(committed_rid.into()),
412    );
413    stats.insert(
414        "snapshot_lsn".to_string(),
415        JsonValue::Number(snapshot_lsn.into()),
416    );
417    stats.insert("cancelled".to_string(), JsonValue::Bool(cancelled));
418    obj.insert("stats".to_string(), JsonValue::Object(stats));
419    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
420}
421
422pub fn build_input_stream_end_frame(
423    correlation_id: u64,
424    stream_id: u16,
425    row_count: u64,
426    chunk_count: u64,
427    committed_rid: u64,
428    snapshot_lsn: u64,
429    cancelled: bool,
430) -> Result<Frame, BuildError> {
431    FrameBuilder::reply_to(correlation_id)
432        .kind(MessageKind::StreamEnd)
433        .stream_id(stream_id)
434        .payload(build_input_stream_end_payload(
435            row_count,
436            chunk_count,
437            committed_rid,
438            snapshot_lsn,
439            cancelled,
440        ))
441        .build()
442}
443
444pub fn build_input_stream_error_payload(
445    code: &str,
446    message: &str,
447    chunk_seq: u64,
448    recoverable_rid: u64,
449) -> Vec<u8> {
450    let mut obj = serde_json::Map::new();
451    obj.insert("code".to_string(), JsonValue::String(code.to_string()));
452    obj.insert(
453        "message".to_string(),
454        JsonValue::String(message.to_string()),
455    );
456    obj.insert("chunk_seq".to_string(), JsonValue::Number(chunk_seq.into()));
457    obj.insert(
458        "recoverable_rid".to_string(),
459        JsonValue::Number(recoverable_rid.into()),
460    );
461    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
462}
463
464pub fn build_input_stream_error_frame(
465    correlation_id: u64,
466    stream_id: u16,
467    code: &str,
468    message: &str,
469    chunk_seq: u64,
470    recoverable_rid: u64,
471) -> Result<Frame, BuildError> {
472    FrameBuilder::reply_to(correlation_id)
473        .kind(MessageKind::StreamError)
474        .stream_id(stream_id)
475        .payload(build_input_stream_error_payload(
476            code,
477            message,
478            chunk_seq,
479            recoverable_rid,
480        ))
481        .build()
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn output_open_stream_contract_parses_opts() {
        let req = parse_open_stream(br#"{"sql":"SELECT 1","opts":{"resume_after_rid":42}}"#)
            .expect("parse open stream");
        assert_eq!(req.sql, "SELECT 1");
        assert!(!req.opts_raw.is_empty());
    }

    // Every rejection path of the output OpenStream parser, plus the
    // "no opts" case, which must leave `opts_raw` empty.
    #[test]
    fn output_open_stream_contract_reports_each_rejection() {
        assert_eq!(
            parse_open_stream(b"not json"),
            Err(OpenStreamParseError::NotJson)
        );
        assert_eq!(
            parse_open_stream(b"[1,2]"),
            Err(OpenStreamParseError::NotObject)
        );
        assert_eq!(
            parse_open_stream(br#"{"sql":42}"#),
            Err(OpenStreamParseError::MissingSql)
        );
        assert_eq!(
            parse_open_stream(br#"{}"#),
            Err(OpenStreamParseError::MissingSql)
        );
        assert_eq!(
            parse_open_stream(br#"{"sql":""}"#),
            Err(OpenStreamParseError::EmptySql)
        );
        let req = parse_open_stream(br#"{"sql":"SELECT 1"}"#).unwrap();
        assert!(req.opts_raw.is_empty());
    }

    #[test]
    fn output_open_stream_builder_round_trips_request() {
        let request = OpenStreamRequest {
            sql: "SELECT id FROM widgets".to_string(),
            opts_raw: br#"{"resume_after_rid":42}"#.to_vec(),
        };
        let frame = build_open_stream_frame(12, 4, &request).unwrap();
        assert_eq!(frame.kind, MessageKind::OpenStream);
        assert_eq!(frame.correlation_id, 12);
        assert_eq!(frame.stream_id, 4);
        let parsed = parse_open_stream(&frame.payload).unwrap();
        assert_eq!(parsed.sql, request.sql);
        assert_eq!(
            serde_json::from_slice::<JsonValue>(&parsed.opts_raw).unwrap(),
            serde_json::from_slice::<JsonValue>(&request.opts_raw).unwrap()
        );
    }

    // Cancellation parsing never fails; bad input just drops the reason.
    #[test]
    fn stream_cancel_reason_is_optional_and_lenient() {
        assert_eq!(parse_stream_cancel(b""), StreamCancelRequest::default());
        assert_eq!(parse_stream_cancel(b"{oops"), StreamCancelRequest::default());
        assert_eq!(parse_stream_cancel(br#"["reason"]"#).reason, None);
        assert_eq!(parse_stream_cancel(br#"{"reason":7}"#).reason, None);
        assert_eq!(
            parse_stream_cancel(br#"{"reason":"client closed"}"#)
                .reason
                .as_deref(),
            Some("client closed")
        );
    }

    #[test]
    fn open_stream_direction_detection_is_case_insensitive() {
        assert!(open_stream_is_input(br#"{"direction":"in"}"#));
        assert!(open_stream_is_input(br#"{"direction":"IN"}"#));
        assert!(!open_stream_is_input(br#"{"direction":"out"}"#));
        assert!(!open_stream_is_input(br#"{"direction":1}"#));
        assert!(!open_stream_is_input(br#"{"sql":"SELECT 1"}"#));
        assert!(!open_stream_is_input(b"garbage"));
    }

    #[test]
    fn input_open_contract_rejects_unsafe_identifiers() {
        assert_eq!(
            parse_open_input(br#"{"direction":"in","target":"t;drop","columns":["id"]}"#),
            Err(OpenInputParseError::UnsafeTarget)
        );
        assert_eq!(
            parse_open_input(br#"{"direction":"in","target":"t","columns":["bad name"]}"#),
            Err(OpenInputParseError::UnsafeColumn)
        );
    }

    #[test]
    fn input_open_contract_accepts_valid_and_reports_shape_errors() {
        let req = parse_open_input(br#"{"target":"_t1","columns":["id","name"]}"#).unwrap();
        assert_eq!(req.target, "_t1");
        assert_eq!(req.columns, vec!["id".to_string(), "name".to_string()]);

        assert_eq!(
            parse_open_input(b"nope"),
            Err(OpenInputParseError::NotJson)
        );
        assert_eq!(
            parse_open_input(b"42"),
            Err(OpenInputParseError::NotObject)
        );
        assert_eq!(
            parse_open_input(br#"{"columns":["id"]}"#),
            Err(OpenInputParseError::MissingTarget)
        );
        assert_eq!(
            parse_open_input(br#"{"target":"t"}"#),
            Err(OpenInputParseError::MissingColumns)
        );
        assert_eq!(
            parse_open_input(br#"{"target":"t","columns":[]}"#),
            Err(OpenInputParseError::EmptyColumns)
        );
        assert_eq!(
            parse_open_input(br#"{"target":"t","columns":[1]}"#),
            Err(OpenInputParseError::UnsafeColumn)
        );
    }

    #[test]
    fn sql_identifier_guard_is_ascii_only() {
        assert!(is_safe_sql_identifier("_a1"));
        assert!(is_safe_sql_identifier("Widgets"));
        assert!(!is_safe_sql_identifier(""));
        assert!(!is_safe_sql_identifier("1a"));
        assert!(!is_safe_sql_identifier("schema.table"));
        assert!(!is_safe_sql_identifier("café"));
    }

    #[test]
    fn input_chunk_json_preserves_rows_as_json_bytes() {
        let chunk =
            parse_input_chunk_json(br#"{"seq":3,"rows":[{"id":1}],"terminal":true}"#).unwrap();
        assert_eq!(chunk.seq, 3);
        assert_eq!(chunk.rows_json.len(), 1);
        assert!(std::str::from_utf8(&chunk.rows_json[0])
            .unwrap()
            .contains("\"id\""));
        assert!(chunk.terminal);
    }

    // Optional chunk fields default leniently; only a non-array `rows` is fatal.
    #[test]
    fn input_chunk_contract_defaults_and_rejects_non_array_rows() {
        let chunk = parse_input_chunk(br#"{"rows":null}"#).unwrap();
        assert_eq!(chunk.seq, 0);
        assert!(chunk.rows.is_empty());
        assert!(!chunk.terminal);

        let chunk = parse_input_chunk(br#"{"seq":"x","terminal":"yes"}"#).unwrap();
        assert_eq!(chunk.seq, 0);
        assert!(!chunk.terminal);

        assert_eq!(
            parse_input_chunk(br#"{"rows":{}}"#),
            Err(ChunkParseError::RowsNotArray)
        );
        assert_eq!(parse_input_chunk(b"[]"), Err(ChunkParseError::NotObject));
        assert_eq!(parse_input_chunk(b"nope"), Err(ChunkParseError::NotJson));
    }

    // Output chunks can be read back by the input parser; undecodable row
    // bytes come through as JSON null.
    #[test]
    fn chunk_payload_round_trips_through_input_parser() {
        let payload = build_stream_chunk_payload_from_json_bytes(
            4,
            vec![br#"{"id":1}"#.to_vec(), b"not json".to_vec()],
            true,
        );
        let chunk = parse_input_chunk(&payload).unwrap();
        assert_eq!(chunk.seq, 4);
        assert!(chunk.terminal);
        assert_eq!(chunk.rows.len(), 2);
        assert_eq!(chunk.rows[0]["id"], 1);
        assert_eq!(chunk.rows[1], JsonValue::Null);
    }

    #[test]
    fn stream_payload_builders_emit_json_objects() {
        let ack = build_open_ack_payload(42, 7, false);
        let value: JsonValue = serde_json::from_slice(&ack).unwrap();
        assert_eq!(value["lease_handle"], "42");
        assert_eq!(value["resumable"], false);
        assert_eq!(value["snapshot_lsn"], 7);

        let end = build_stream_end_payload(5, 42, 7, true);
        let value: JsonValue = serde_json::from_slice(&end).unwrap();
        assert_eq!(value["stats"]["row_count"], 5);
        assert_eq!(value["stats"]["lease_id"], 42);
        assert_eq!(value["stats"]["snapshot_lsn"], 7);
        assert_eq!(value["stats"]["cancelled"], true);

        let with_seq = build_stream_error_payload(Some(3), "x", "y");
        let value: JsonValue = serde_json::from_slice(&with_seq).unwrap();
        assert_eq!(value["seq"], 3);
        assert_eq!(value["code"], "x");
        assert_eq!(value["message"], "y");

        let without_seq = build_stream_error_payload(None, "x", "y");
        let value: JsonValue = serde_json::from_slice(&without_seq).unwrap();
        assert!(value.as_object().unwrap().get("seq").is_none());
    }

    #[test]
    fn input_stream_payload_builders_emit_committed_range_and_error_cursor() {
        let end = build_input_stream_end_payload(3, 2, 42, 40, false);
        let value: JsonValue = serde_json::from_slice(&end).unwrap();
        assert_eq!(value["stats"]["row_count"], 3);
        assert_eq!(value["stats"]["chunk_count"], 2);
        assert_eq!(value["stats"]["committed_rid"], 42);
        assert_eq!(value["stats"]["snapshot_lsn"], 40);
        assert_eq!(value["stats"]["cancelled"], false);

        let error = build_input_stream_error_payload("invalid_row", "bad", 2, 41);
        let value: JsonValue = serde_json::from_slice(&error).unwrap();
        assert_eq!(value["code"], "invalid_row");
        assert_eq!(value["message"], "bad");
        assert_eq!(value["chunk_seq"], 2);
        assert_eq!(value["recoverable_rid"], 41);
    }

    #[test]
    fn stream_frame_builders_echo_stream_and_correlation() {
        let ack = build_open_ack_frame(99, 7, 42, 100, false).unwrap();
        assert_eq!(ack.kind, MessageKind::OpenAck);
        assert_eq!(ack.correlation_id, 99);
        assert_eq!(ack.stream_id, 7);

        let chunk = build_stream_chunk_frame_from_json_bytes(
            99,
            7,
            1,
            vec![br#"{"id":1}"#.to_vec()],
            false,
        )
        .unwrap();
        assert_eq!(chunk.kind, MessageKind::StreamChunk);
        assert_eq!(chunk.stream_id, 7);

        let error = build_stream_error_frame(99, 7, Some(1), "bad", "failed").unwrap();
        assert_eq!(error.kind, MessageKind::StreamError);
        assert_eq!(error.correlation_id, 99);

        let end = build_stream_end_frame(99, 7, 5, 42, 100, true).unwrap();
        assert_eq!(end.kind, MessageKind::StreamEnd);
        assert_eq!(end.stream_id, 7);

        let input_error =
            build_input_stream_error_frame(99, 8, "invalid_row", "bad", 2, 41).unwrap();
        assert_eq!(input_error.kind, MessageKind::StreamError);
        assert_eq!(input_error.stream_id, 8);

        let input_end = build_input_stream_end_frame(99, 8, 3, 2, 42, 40, false).unwrap();
        assert_eq!(input_end.kind, MessageKind::StreamEnd);
        assert_eq!(input_end.correlation_id, 99);
    }
}