Skip to main content

reddb_server/wire/redwire/
input_stream.rs

1//! RedWire input-stream dispatch (issue #764, PRD #759 S5).
2//!
3//! Brings the S4 HTTP NDJSON input-stream behaviour
4//! ([`crate::server::handlers_query::handle_query_ndjson_input_stream`])
5//! to the RedWire protocol, reusing the S3 envelope vocabulary:
6//!
7//!   - `OpenStream` (client→server) — carries `direction: "in"` plus a
8//!     `target` table and `columns`. The output-stream variant
9//!     (`direction: "out"`, the default) keeps using `sql` and is
10//!     handled by [`super::output_stream`]; the two never collide
11//!     because the dispatch loop branches on `direction` first.
12//!   - `OpenAck`    (server→client) — input stream accepted; carries
13//!     the lease handle + snapshot LSN, identical to the output ack.
14//!   - `StreamChunk`(client→server) — one chunk of rows. Each chunk
15//!     is committed atomically (one multi-row `INSERT`) before the
16//!     next frame is read, so rows from chunk K are durable and
17//!     visible before chunk K+1 arrives (auto-commit per chunk). A
18//!     chunk with `terminal: true` closes the input phase.
19//!   - `StreamEnd`  (server→client) — success terminal carrying the
20//!     committed RID range (`snapshot_lsn` .. `committed_rid`) and
21//!     stats (`row_count`, `chunk_count`).
22//!   - `StreamError`(server→client) — a chunk failed to commit. Rows
23//!     from earlier chunks remain durable; the error carries
24//!     `recoverable_rid` (the CDC LSN at the last good commit) and
25//!     the failing `chunk_seq`. No further frames are emitted for the
26//!     `stream_id` (AC #3).
27//!   - `StreamCancel`(client→server) — discard the in-flight (not yet
28//!     committed) chunk; prior committed chunks stay durable (AC #4).
29//!
30//! Input streams are driven *inline* from the per-connection reader
31//! loop (each `StreamChunk` commits synchronously) and tracked in an
32//! [`InputStreamRegistry`] keyed by `stream_id`, kept separate from
33//! the spawned-worker [`super::output_stream::StreamRegistry`]. Both
34//! kinds of stream therefore coexist on one connection, dispatched by
35//! `stream_id` (AC #2).
36
37use std::collections::HashMap;
38
39use crate::runtime::RedDBRuntime;
40use crate::serde_json::{self, Value as JsonValue};
41use reddb_wire::redwire::frame::{Frame, MessageKind};
42
43use super::output_stream::RegisterError;
44use super::FrameBuilder;
45use crate::server::output_stream::{Clock, OpenStreamError, StreamConfig, StreamLease};
46
47/// `true` when an `OpenStream` payload requests the input direction
48/// (`{"direction":"in", ...}`). Any other value — including a missing
49/// field or a malformed payload — is treated as the output direction
50/// so the existing S3 path keeps owning the default.
51pub fn open_stream_is_input(payload: &[u8]) -> bool {
52    serde_json::from_slice::<JsonValue>(payload)
53        .ok()
54        .and_then(|v| {
55            v.as_object()
56                .and_then(|o| o.get("direction"))
57                .and_then(|d| d.as_str())
58                .map(|s| s.eq_ignore_ascii_case("in"))
59        })
60        .unwrap_or(false)
61}
62
/// A validated input-direction `OpenStream` request.
///
/// Wire shape:
///
/// ```json
/// { "direction": "in", "target": "<table>", "columns": ["c1", "c2"] }
/// ```
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OpenInputRequest {
    /// Name of the table the stream inserts into.
    pub target: String,
    /// Column names, in the order given by the client.
    pub columns: Vec<String>,
}
73
/// Reasons an input-direction `OpenStream` payload is rejected by
/// [`parse_open_input`].
///
/// Implements [`std::fmt::Display`] (rendering [`Self::message`]) and
/// [`std::error::Error`] so it composes with `?` and boxed errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OpenInputParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is JSON but not an object.
    NotObject,
    /// `target` is absent or not a string.
    MissingTarget,
    /// `target` failed the safe-SQL-identifier check.
    UnsafeTarget,
    /// `columns` is absent or not an array.
    MissingColumns,
    /// `columns` is an empty array.
    EmptyColumns,
    /// A `columns` entry is not a string or failed the
    /// safe-SQL-identifier check.
    UnsafeColumn,
}

impl OpenInputParseError {
    /// Stable machine-readable error code; variants are grouped by the
    /// part of the payload they concern (payload / target / columns).
    pub fn code(&self) -> &'static str {
        match self {
            Self::NotJson | Self::NotObject => "open_stream_invalid_payload",
            Self::MissingTarget | Self::UnsafeTarget => "open_stream_invalid_target",
            Self::MissingColumns | Self::EmptyColumns | Self::UnsafeColumn => {
                "open_stream_invalid_columns"
            }
        }
    }

    /// Human-readable description of the failure.
    pub fn message(&self) -> &'static str {
        match self {
            Self::NotJson => "OpenStream payload must be JSON",
            Self::NotObject => "OpenStream payload must be a JSON object",
            Self::MissingTarget => "input OpenStream payload missing 'target' string field",
            Self::UnsafeTarget => "input OpenStream 'target' is not a safe SQL identifier",
            Self::MissingColumns => "input OpenStream payload missing 'columns' array field",
            Self::EmptyColumns => "input OpenStream 'columns' must be a non-empty array",
            Self::UnsafeColumn => "input OpenStream 'columns' entry is not a safe SQL identifier",
        }
    }
}

impl std::fmt::Display for OpenInputParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single source of truth: the display text is the wire message.
        f.write_str(self.message())
    }
}

impl std::error::Error for OpenInputParseError {}
107
108pub fn parse_open_input(payload: &[u8]) -> Result<OpenInputRequest, OpenInputParseError> {
109    use crate::server::handlers_query::is_safe_sql_identifier;
110    let v: JsonValue = serde_json::from_slice(payload).map_err(|_| OpenInputParseError::NotJson)?;
111    let obj = v.as_object().ok_or(OpenInputParseError::NotObject)?;
112    let target = obj
113        .get("target")
114        .and_then(|x| x.as_str())
115        .ok_or(OpenInputParseError::MissingTarget)?;
116    if !is_safe_sql_identifier(target) {
117        return Err(OpenInputParseError::UnsafeTarget);
118    }
119    let columns_v = obj
120        .get("columns")
121        .and_then(|x| x.as_array())
122        .ok_or(OpenInputParseError::MissingColumns)?;
123    if columns_v.is_empty() {
124        return Err(OpenInputParseError::EmptyColumns);
125    }
126    let mut columns = Vec::with_capacity(columns_v.len());
127    for c in columns_v {
128        let name = c.as_str().ok_or(OpenInputParseError::UnsafeColumn)?;
129        if !is_safe_sql_identifier(name) {
130            return Err(OpenInputParseError::UnsafeColumn);
131        }
132        columns.push(name.to_string());
133    }
134    Ok(OpenInputRequest {
135        target: target.to_string(),
136        columns,
137    })
138}
139
140/// Parsed `StreamChunk` payload sent by an input-stream client. Shape
141/// mirrors the output-stream chunk (`{"seq", "rows", "terminal"}`) but
142/// the rows are JSON objects keyed by column rather than already-shaped
143/// output rows.
144// No `Eq`: `serde_json::Value` rows may carry floats, which are only
145// `PartialEq`.
146#[derive(Debug, Clone, PartialEq)]
147pub struct InputChunk {
148    pub seq: u64,
149    pub rows: Vec<JsonValue>,
150    pub terminal: bool,
151}
152
/// Reasons a `StreamChunk` payload is rejected by [`parse_input_chunk`].
///
/// Implements [`std::fmt::Display`] (rendering [`Self::message`]) and
/// [`std::error::Error`] so it composes with `?` and boxed errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChunkParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is JSON but not an object.
    NotObject,
    /// `rows` is present but is neither an array nor `null`.
    RowsNotArray,
}

impl ChunkParseError {
    /// Stable machine-readable error code; every variant shares
    /// `"invalid_chunk"`.
    pub fn code(&self) -> &'static str {
        "invalid_chunk"
    }

    /// Human-readable description of the failure.
    pub fn message(&self) -> &'static str {
        match self {
            Self::NotJson => "StreamChunk payload must be JSON",
            Self::NotObject => "StreamChunk payload must be a JSON object",
            Self::RowsNotArray => "StreamChunk 'rows' must be an array",
        }
    }
}

impl std::fmt::Display for ChunkParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single source of truth: the display text is the wire message.
        f.write_str(self.message())
    }
}

impl std::error::Error for ChunkParseError {}
172
173pub fn parse_input_chunk(payload: &[u8]) -> Result<InputChunk, ChunkParseError> {
174    let v: JsonValue = serde_json::from_slice(payload).map_err(|_| ChunkParseError::NotJson)?;
175    let obj = v.as_object().ok_or(ChunkParseError::NotObject)?;
176    let seq = obj.get("seq").and_then(|x| x.as_u64()).unwrap_or(0);
177    let terminal = obj
178        .get("terminal")
179        .and_then(|x| x.as_bool())
180        .unwrap_or(false);
181    // `rows` is optional so a bare terminal frame (`{"terminal":true}`)
182    // can close the stream without carrying a final batch.
183    let rows = match obj.get("rows") {
184        None | Some(JsonValue::Null) => Vec::new(),
185        Some(JsonValue::Array(arr)) => arr.clone(),
186        Some(_) => return Err(ChunkParseError::RowsNotArray),
187    };
188    Ok(InputChunk {
189        seq,
190        rows,
191        terminal,
192    })
193}
194
195/// Per-stream state for an in-flight input stream. Lives in the
196/// session loop's [`InputStreamRegistry`] and is mutated synchronously
197/// as each `StreamChunk` is committed.
198#[derive(Debug)]
199pub struct InputStreamState {
200    pub lease: StreamLease,
201    pub target: String,
202    pub columns: Vec<String>,
203    /// CDC LSN at the last successful per-chunk commit; the start of
204    /// the committed RID range is the lease's `snapshot_lsn`.
205    pub committed_rid: u64,
206    pub row_count: u64,
207    pub chunk_count: u64,
208    pub snapshot_lsn: u64,
209}
210
211impl InputStreamState {
212    pub fn new(lease: StreamLease, target: String, columns: Vec<String>) -> Self {
213        let snapshot_lsn = lease.snapshot_lsn;
214        Self {
215            lease,
216            target,
217            columns,
218            committed_rid: snapshot_lsn,
219            row_count: 0,
220            chunk_count: 0,
221            snapshot_lsn,
222        }
223    }
224
225    /// Commit one chunk of rows as a single atomic multi-row `INSERT`.
226    /// On success the rows are durable and `committed_rid` advances to
227    /// the post-commit CDC LSN. On failure nothing in this chunk
228    /// commits — `committed_rid` (and therefore `recoverable_rid`)
229    /// stays at the last good commit, so chunks `1..N-1` remain
230    /// durable (AC #3).
231    pub fn commit_chunk(
232        &mut self,
233        runtime: &RedDBRuntime,
234        rows: &[JsonValue],
235    ) -> Result<(), (String, String)> {
236        if rows.is_empty() {
237            return Ok(());
238        }
239        // Project each row object onto the declared columns (missing
240        // keys become NULL), matching the S4 `parse_row_frame` shape.
241        let mut positional: Vec<Vec<JsonValue>> = Vec::with_capacity(rows.len());
242        for row in rows {
243            let obj = row.as_object().ok_or_else(|| {
244                (
245                    "invalid_row".to_string(),
246                    "row must be a JSON object".to_string(),
247                )
248            })?;
249            let mut values = Vec::with_capacity(self.columns.len());
250            for col in &self.columns {
251                values.push(obj.get(col).cloned().unwrap_or(JsonValue::Null));
252            }
253            positional.push(values);
254        }
255        let sql = crate::server::handlers_query::build_insert_sql(
256            &self.target,
257            &self.columns,
258            &positional,
259        )
260        .map_err(|message| ("invalid_row".to_string(), message))?;
261        match runtime.execute_query(&sql) {
262            Ok(_) => {
263                self.row_count += rows.len() as u64;
264                self.committed_rid = runtime.cdc_current_lsn();
265                self.chunk_count += 1;
266                Ok(())
267            }
268            Err(err) => Err(("chunk_commit_failed".to_string(), err.to_string())),
269        }
270    }
271}
272
273/// Per-connection registry of in-flight input streams. Keyed by
274/// `stream_id`, separate from the output-stream worker registry so an
275/// input and an output stream may share one connection without
276/// colliding (AC #2).
277#[derive(Default)]
278pub struct InputStreamRegistry {
279    inner: HashMap<u16, InputStreamState>,
280}
281
282impl InputStreamRegistry {
283    pub fn new() -> Self {
284        Self::default()
285    }
286
287    /// Register a freshly-opened input stream. Mirrors the output
288    /// registry's reserved-id / duplicate guards and reuses its
289    /// [`RegisterError`] codes so clients see one taxonomy.
290    pub fn register(
291        &mut self,
292        stream_id: u16,
293        state: InputStreamState,
294    ) -> Result<(), RegisterError> {
295        if stream_id == 0 {
296            return Err(RegisterError::ReservedStreamId);
297        }
298        if self.inner.contains_key(&stream_id) {
299            return Err(RegisterError::StreamInUse);
300        }
301        self.inner.insert(stream_id, state);
302        Ok(())
303    }
304
305    pub fn get_mut(&mut self, stream_id: u16) -> Option<&mut InputStreamState> {
306        self.inner.get_mut(&stream_id)
307    }
308
309    pub fn contains(&self, stream_id: u16) -> bool {
310        self.inner.contains_key(&stream_id)
311    }
312
313    /// Drop the stream from the registry, returning its state so the
314    /// caller can read final stats for a terminal frame. Idempotent —
315    /// a second remove returns `None`.
316    pub fn remove(&mut self, stream_id: u16) -> Option<InputStreamState> {
317        self.inner.remove(&stream_id)
318    }
319
320    pub fn active_count(&self) -> usize {
321        self.inner.len()
322    }
323}
324
325/// Build the success terminal `StreamEnd` payload for an input stream.
326/// Carries the committed RID range (`snapshot_lsn` .. `committed_rid`)
327/// and ingest stats.
328pub fn build_input_stream_end_payload(
329    row_count: u64,
330    chunk_count: u64,
331    committed_rid: u64,
332    snapshot_lsn: u64,
333    cancelled: bool,
334) -> Vec<u8> {
335    let mut obj = serde_json::Map::new();
336    let mut stats = serde_json::Map::new();
337    stats.insert("row_count".to_string(), JsonValue::Number(row_count as f64));
338    stats.insert(
339        "chunk_count".to_string(),
340        JsonValue::Number(chunk_count as f64),
341    );
342    stats.insert(
343        "committed_rid".to_string(),
344        JsonValue::Number(committed_rid as f64),
345    );
346    stats.insert(
347        "snapshot_lsn".to_string(),
348        JsonValue::Number(snapshot_lsn as f64),
349    );
350    stats.insert("cancelled".to_string(), JsonValue::Bool(cancelled));
351    obj.insert("stats".to_string(), JsonValue::Object(stats));
352    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
353}
354
355/// Build the input-stream `StreamError` payload. Unlike the output
356/// variant it carries the `recoverable_rid` prefix (the CDC LSN of the
357/// last good commit) and the failing `chunk_seq`.
358pub fn build_input_stream_error_payload(
359    code: &str,
360    message: &str,
361    chunk_seq: u64,
362    recoverable_rid: u64,
363) -> Vec<u8> {
364    let mut obj = serde_json::Map::new();
365    obj.insert("code".to_string(), JsonValue::String(code.to_string()));
366    obj.insert(
367        "message".to_string(),
368        JsonValue::String(message.to_string()),
369    );
370    obj.insert("chunk_seq".to_string(), JsonValue::Number(chunk_seq as f64));
371    obj.insert(
372        "recoverable_rid".to_string(),
373        JsonValue::Number(recoverable_rid as f64),
374    );
375    serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
376}
377
378/// Build an input-stream `StreamError` frame addressed to `stream_id`,
379/// echoing `correlation_id` so the client can pair it to the request.
380pub fn build_input_stream_error_frame(
381    correlation_id: u64,
382    stream_id: u16,
383    code: &str,
384    message: &str,
385    chunk_seq: u64,
386    recoverable_rid: u64,
387) -> std::io::Result<Frame> {
388    FrameBuilder::reply_to(correlation_id)
389        .kind(MessageKind::StreamError)
390        .stream_id(stream_id)
391        .payload(build_input_stream_error_payload(
392            code,
393            message,
394            chunk_seq,
395            recoverable_rid,
396        ))
397        .build()
398        .map_err(|e| std::io::Error::other(format!("build input StreamError: {e}")))
399}
400
401/// Build the input-stream `StreamEnd` frame.
402pub fn build_input_stream_end_frame(
403    correlation_id: u64,
404    stream_id: u16,
405    row_count: u64,
406    chunk_count: u64,
407    committed_rid: u64,
408    snapshot_lsn: u64,
409    cancelled: bool,
410) -> std::io::Result<Frame> {
411    FrameBuilder::reply_to(correlation_id)
412        .kind(MessageKind::StreamEnd)
413        .stream_id(stream_id)
414        .payload(build_input_stream_end_payload(
415            row_count,
416            chunk_count,
417            committed_rid,
418            snapshot_lsn,
419            cancelled,
420        ))
421        .build()
422        .map_err(|e| std::io::Error::other(format!("build input StreamEnd: {e}")))
423}
424
425/// Open an input-stream lease, reusing the output-stream lease
426/// primitive so HTTP, output, and input streams agree on TTL and the
427/// in-transaction refusal (AC mirrors S4 #4).
428pub fn open_input_lease(
429    config: StreamConfig,
430    snapshot_lsn: u64,
431    in_transaction: bool,
432    clock: &dyn Clock,
433) -> Result<StreamLease, OpenStreamError> {
434    crate::server::output_stream::open_stream(config, snapshot_lsn, in_transaction, clock)
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detects_input_direction() {
        // Payloads that select the input direction (case-insensitive).
        let input_payloads: [&[u8]; 2] = [
            br#"{"direction":"in","target":"t","columns":["a"]}"#,
            br#"{"direction":"IN"}"#,
        ];
        for payload in input_payloads {
            assert!(open_stream_is_input(payload));
        }

        // Default / output direction, including malformed payloads.
        let output_payloads: [&[u8]; 3] = [
            br#"{"sql":"SELECT 1"}"#,
            br#"{"direction":"out"}"#,
            b"not json",
        ];
        for payload in output_payloads {
            assert!(!open_stream_is_input(payload));
        }
    }

    #[test]
    fn parse_open_input_accepts_target_and_columns() {
        let parsed =
            parse_open_input(br#"{"direction":"in","target":"events","columns":["id","name"]}"#);
        let req = parsed.unwrap();
        assert_eq!(req.target, "events");
        assert_eq!(req.columns, vec!["id".to_string(), "name".to_string()]);
    }

    #[test]
    fn parse_open_input_rejects_missing_target() {
        assert_eq!(
            parse_open_input(br#"{"direction":"in","columns":["a"]}"#),
            Err(OpenInputParseError::MissingTarget)
        );
    }

    #[test]
    fn parse_open_input_rejects_unsafe_target() {
        assert_eq!(
            parse_open_input(br#"{"direction":"in","target":"t;DROP","columns":["a"]}"#),
            Err(OpenInputParseError::UnsafeTarget)
        );
    }

    #[test]
    fn parse_open_input_rejects_empty_or_missing_columns() {
        assert_eq!(
            parse_open_input(br#"{"direction":"in","target":"t","columns":[]}"#),
            Err(OpenInputParseError::EmptyColumns)
        );
        assert_eq!(
            parse_open_input(br#"{"direction":"in","target":"t"}"#),
            Err(OpenInputParseError::MissingColumns)
        );
    }

    #[test]
    fn parse_open_input_rejects_unsafe_column() {
        assert_eq!(
            parse_open_input(br#"{"direction":"in","target":"t","columns":["ok","b ad"]}"#),
            Err(OpenInputParseError::UnsafeColumn)
        );
    }

    #[test]
    fn parse_chunk_extracts_rows_seq_terminal() {
        let parsed = parse_input_chunk(br#"{"seq":3,"rows":[{"id":1},{"id":2}],"terminal":true}"#);
        let chunk = parsed.unwrap();
        assert_eq!(chunk.seq, 3);
        assert_eq!(chunk.rows.len(), 2);
        assert!(chunk.terminal);
    }

    #[test]
    fn parse_chunk_allows_bare_terminal() {
        let chunk = parse_input_chunk(br#"{"terminal":true}"#).unwrap();
        assert!(chunk.terminal);
        assert!(chunk.rows.is_empty());
        assert_eq!(chunk.seq, 0);
    }

    #[test]
    fn parse_chunk_rejects_non_array_rows() {
        assert_eq!(
            parse_input_chunk(br#"{"rows":5}"#),
            Err(ChunkParseError::RowsNotArray)
        );
    }

    #[test]
    fn registry_register_rejects_reserved_and_duplicate() {
        // Every state in this test differs only in lease id and handle.
        let state_with_lease = |id, handle: &str| {
            InputStreamState::new(
                StreamLease {
                    id,
                    lease_handle: handle.to_string(),
                    snapshot_lsn: 10,
                    opened_at_ms: 0,
                    config: StreamConfig::default(),
                },
                "t".to_string(),
                vec!["a".to_string()],
            )
        };

        let mut reg = InputStreamRegistry::new();

        // Stream id 0 is reserved.
        assert!(matches!(
            reg.register(0, state_with_lease(2, "h2")),
            Err(RegisterError::ReservedStreamId)
        ));

        // First registration of id 5 succeeds...
        reg.register(5, state_with_lease(1, "h")).unwrap();
        assert!(reg.contains(5));

        // ...and a second one for the same id is refused.
        assert!(matches!(
            reg.register(5, state_with_lease(3, "h3")),
            Err(RegisterError::StreamInUse)
        ));
        assert_eq!(reg.active_count(), 1);

        // Removal is idempotent.
        assert!(reg.remove(5).is_some());
        assert!(reg.remove(5).is_none());
    }

    #[test]
    fn end_payload_carries_committed_rid_range_and_stats() {
        let bytes = build_input_stream_end_payload(3, 2, 42, 40, false);
        let v: JsonValue = serde_json::from_slice(&bytes).unwrap();
        let stats = v.as_object().unwrap().get("stats").unwrap();

        let row_count = stats.get("row_count").and_then(|x| x.as_u64());
        let chunk_count = stats.get("chunk_count").and_then(|x| x.as_u64());
        let committed_rid = stats.get("committed_rid").and_then(|x| x.as_u64());
        let snapshot_lsn = stats.get("snapshot_lsn").and_then(|x| x.as_u64());
        let cancelled = stats.get("cancelled").and_then(|x| x.as_bool());

        assert_eq!(row_count, Some(3));
        assert_eq!(chunk_count, Some(2));
        assert_eq!(committed_rid, Some(42));
        assert_eq!(snapshot_lsn, Some(40));
        assert_eq!(cancelled, Some(false));
    }

    #[test]
    fn error_payload_carries_recoverable_rid_and_chunk_seq() {
        let bytes = build_input_stream_error_payload("invalid_row", "bad", 2, 41);
        let v: JsonValue = serde_json::from_slice(&bytes).unwrap();
        let obj = v.as_object().unwrap();

        let code = obj.get("code").and_then(|x| x.as_str());
        let chunk_seq = obj.get("chunk_seq").and_then(|x| x.as_u64());
        let recoverable_rid = obj.get("recoverable_rid").and_then(|x| x.as_u64());

        assert_eq!(code, Some("invalid_row"));
        assert_eq!(chunk_seq, Some(2));
        assert_eq!(recoverable_rid, Some(41));
    }
}
611}