1use std::collections::HashMap;
38
39use crate::runtime::RedDBRuntime;
40use crate::serde_json::{self, Value as JsonValue};
41use reddb_wire::redwire::frame::{Frame, MessageKind};
42
43use super::output_stream::RegisterError;
44use super::FrameBuilder;
45use crate::server::output_stream::{Clock, OpenStreamError, StreamConfig, StreamLease};
46
47pub fn open_stream_is_input(payload: &[u8]) -> bool {
52 serde_json::from_slice::<JsonValue>(payload)
53 .ok()
54 .and_then(|v| {
55 v.as_object()
56 .and_then(|o| o.get("direction"))
57 .and_then(|d| d.as_str())
58 .map(|s| s.eq_ignore_ascii_case("in"))
59 })
60 .unwrap_or(false)
61}
62
/// Validated contents of an input-direction `OpenStream` payload.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OpenInputRequest {
    /// Name of the table rows will be inserted into.
    pub target: String,
    /// Ordered column list; incoming row objects are projected onto these names.
    pub columns: Vec<String>,
}
73
74#[derive(Debug, Clone, PartialEq, Eq)]
75pub enum OpenInputParseError {
76 NotJson,
77 NotObject,
78 MissingTarget,
79 UnsafeTarget,
80 MissingColumns,
81 EmptyColumns,
82 UnsafeColumn,
83}
84
85impl OpenInputParseError {
86 pub fn code(&self) -> &'static str {
87 match self {
88 Self::NotJson | Self::NotObject => "open_stream_invalid_payload",
89 Self::MissingTarget | Self::UnsafeTarget => "open_stream_invalid_target",
90 Self::MissingColumns | Self::EmptyColumns | Self::UnsafeColumn => {
91 "open_stream_invalid_columns"
92 }
93 }
94 }
95 pub fn message(&self) -> &'static str {
96 match self {
97 Self::NotJson => "OpenStream payload must be JSON",
98 Self::NotObject => "OpenStream payload must be a JSON object",
99 Self::MissingTarget => "input OpenStream payload missing 'target' string field",
100 Self::UnsafeTarget => "input OpenStream 'target' is not a safe SQL identifier",
101 Self::MissingColumns => "input OpenStream payload missing 'columns' array field",
102 Self::EmptyColumns => "input OpenStream 'columns' must be a non-empty array",
103 Self::UnsafeColumn => "input OpenStream 'columns' entry is not a safe SQL identifier",
104 }
105 }
106}
107
108pub fn parse_open_input(payload: &[u8]) -> Result<OpenInputRequest, OpenInputParseError> {
109 use crate::server::handlers_query::is_safe_sql_identifier;
110 let v: JsonValue = serde_json::from_slice(payload).map_err(|_| OpenInputParseError::NotJson)?;
111 let obj = v.as_object().ok_or(OpenInputParseError::NotObject)?;
112 let target = obj
113 .get("target")
114 .and_then(|x| x.as_str())
115 .ok_or(OpenInputParseError::MissingTarget)?;
116 if !is_safe_sql_identifier(target) {
117 return Err(OpenInputParseError::UnsafeTarget);
118 }
119 let columns_v = obj
120 .get("columns")
121 .and_then(|x| x.as_array())
122 .ok_or(OpenInputParseError::MissingColumns)?;
123 if columns_v.is_empty() {
124 return Err(OpenInputParseError::EmptyColumns);
125 }
126 let mut columns = Vec::with_capacity(columns_v.len());
127 for c in columns_v {
128 let name = c.as_str().ok_or(OpenInputParseError::UnsafeColumn)?;
129 if !is_safe_sql_identifier(name) {
130 return Err(OpenInputParseError::UnsafeColumn);
131 }
132 columns.push(name.to_string());
133 }
134 Ok(OpenInputRequest {
135 target: target.to_string(),
136 columns,
137 })
138}
139
140#[derive(Debug, Clone, PartialEq)]
147pub struct InputChunk {
148 pub seq: u64,
149 pub rows: Vec<JsonValue>,
150 pub terminal: bool,
151}
152
/// Ways a `StreamChunk` payload can fail `parse_input_chunk`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ChunkParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is JSON but not an object.
    NotObject,
    /// `rows` is present and non-null but not an array.
    RowsNotArray,
}

impl ChunkParseError {
    /// Wire error code. Every variant shares the same code.
    pub fn code(&self) -> &'static str {
        const CODE: &str = "invalid_chunk";
        CODE
    }

    /// Human-readable explanation matching this variant.
    pub fn message(&self) -> &'static str {
        match *self {
            Self::RowsNotArray => "StreamChunk 'rows' must be an array",
            Self::NotObject => "StreamChunk payload must be a JSON object",
            Self::NotJson => "StreamChunk payload must be JSON",
        }
    }
}
172
173pub fn parse_input_chunk(payload: &[u8]) -> Result<InputChunk, ChunkParseError> {
174 let v: JsonValue = serde_json::from_slice(payload).map_err(|_| ChunkParseError::NotJson)?;
175 let obj = v.as_object().ok_or(ChunkParseError::NotObject)?;
176 let seq = obj.get("seq").and_then(|x| x.as_u64()).unwrap_or(0);
177 let terminal = obj
178 .get("terminal")
179 .and_then(|x| x.as_bool())
180 .unwrap_or(false);
181 let rows = match obj.get("rows") {
184 None | Some(JsonValue::Null) => Vec::new(),
185 Some(JsonValue::Array(arr)) => arr.clone(),
186 Some(_) => return Err(ChunkParseError::RowsNotArray),
187 };
188 Ok(InputChunk {
189 seq,
190 rows,
191 terminal,
192 })
193}
194
195#[derive(Debug)]
199pub struct InputStreamState {
200 pub lease: StreamLease,
201 pub target: String,
202 pub columns: Vec<String>,
203 pub committed_rid: u64,
206 pub row_count: u64,
207 pub chunk_count: u64,
208 pub snapshot_lsn: u64,
209}
210
211impl InputStreamState {
212 pub fn new(lease: StreamLease, target: String, columns: Vec<String>) -> Self {
213 let snapshot_lsn = lease.snapshot_lsn;
214 Self {
215 lease,
216 target,
217 columns,
218 committed_rid: snapshot_lsn,
219 row_count: 0,
220 chunk_count: 0,
221 snapshot_lsn,
222 }
223 }
224
225 pub fn commit_chunk(
232 &mut self,
233 runtime: &RedDBRuntime,
234 rows: &[JsonValue],
235 ) -> Result<(), (String, String)> {
236 if rows.is_empty() {
237 return Ok(());
238 }
239 let mut positional: Vec<Vec<JsonValue>> = Vec::with_capacity(rows.len());
242 for row in rows {
243 let obj = row.as_object().ok_or_else(|| {
244 (
245 "invalid_row".to_string(),
246 "row must be a JSON object".to_string(),
247 )
248 })?;
249 let mut values = Vec::with_capacity(self.columns.len());
250 for col in &self.columns {
251 values.push(obj.get(col).cloned().unwrap_or(JsonValue::Null));
252 }
253 positional.push(values);
254 }
255 let sql = crate::server::handlers_query::build_insert_sql(
256 &self.target,
257 &self.columns,
258 &positional,
259 )
260 .map_err(|message| ("invalid_row".to_string(), message))?;
261 match runtime.execute_query(&sql) {
262 Ok(_) => {
263 self.row_count += rows.len() as u64;
264 self.committed_rid = runtime.cdc_current_lsn();
265 self.chunk_count += 1;
266 Ok(())
267 }
268 Err(err) => Err(("chunk_commit_failed".to_string(), err.to_string())),
269 }
270 }
271}
272
273#[derive(Default)]
278pub struct InputStreamRegistry {
279 inner: HashMap<u16, InputStreamState>,
280}
281
282impl InputStreamRegistry {
283 pub fn new() -> Self {
284 Self::default()
285 }
286
287 pub fn register(
291 &mut self,
292 stream_id: u16,
293 state: InputStreamState,
294 ) -> Result<(), RegisterError> {
295 if stream_id == 0 {
296 return Err(RegisterError::ReservedStreamId);
297 }
298 if self.inner.contains_key(&stream_id) {
299 return Err(RegisterError::StreamInUse);
300 }
301 self.inner.insert(stream_id, state);
302 Ok(())
303 }
304
305 pub fn get_mut(&mut self, stream_id: u16) -> Option<&mut InputStreamState> {
306 self.inner.get_mut(&stream_id)
307 }
308
309 pub fn contains(&self, stream_id: u16) -> bool {
310 self.inner.contains_key(&stream_id)
311 }
312
313 pub fn remove(&mut self, stream_id: u16) -> Option<InputStreamState> {
317 self.inner.remove(&stream_id)
318 }
319
320 pub fn active_count(&self) -> usize {
321 self.inner.len()
322 }
323}
324
325pub fn build_input_stream_end_payload(
329 row_count: u64,
330 chunk_count: u64,
331 committed_rid: u64,
332 snapshot_lsn: u64,
333 cancelled: bool,
334) -> Vec<u8> {
335 let mut obj = serde_json::Map::new();
336 let mut stats = serde_json::Map::new();
337 stats.insert("row_count".to_string(), JsonValue::Number(row_count as f64));
338 stats.insert(
339 "chunk_count".to_string(),
340 JsonValue::Number(chunk_count as f64),
341 );
342 stats.insert(
343 "committed_rid".to_string(),
344 JsonValue::Number(committed_rid as f64),
345 );
346 stats.insert(
347 "snapshot_lsn".to_string(),
348 JsonValue::Number(snapshot_lsn as f64),
349 );
350 stats.insert("cancelled".to_string(), JsonValue::Bool(cancelled));
351 obj.insert("stats".to_string(), JsonValue::Object(stats));
352 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
353}
354
355pub fn build_input_stream_error_payload(
359 code: &str,
360 message: &str,
361 chunk_seq: u64,
362 recoverable_rid: u64,
363) -> Vec<u8> {
364 let mut obj = serde_json::Map::new();
365 obj.insert("code".to_string(), JsonValue::String(code.to_string()));
366 obj.insert(
367 "message".to_string(),
368 JsonValue::String(message.to_string()),
369 );
370 obj.insert("chunk_seq".to_string(), JsonValue::Number(chunk_seq as f64));
371 obj.insert(
372 "recoverable_rid".to_string(),
373 JsonValue::Number(recoverable_rid as f64),
374 );
375 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
376}
377
378pub fn build_input_stream_error_frame(
381 correlation_id: u64,
382 stream_id: u16,
383 code: &str,
384 message: &str,
385 chunk_seq: u64,
386 recoverable_rid: u64,
387) -> std::io::Result<Frame> {
388 FrameBuilder::reply_to(correlation_id)
389 .kind(MessageKind::StreamError)
390 .stream_id(stream_id)
391 .payload(build_input_stream_error_payload(
392 code,
393 message,
394 chunk_seq,
395 recoverable_rid,
396 ))
397 .build()
398 .map_err(|e| std::io::Error::other(format!("build input StreamError: {e}")))
399}
400
401pub fn build_input_stream_end_frame(
403 correlation_id: u64,
404 stream_id: u16,
405 row_count: u64,
406 chunk_count: u64,
407 committed_rid: u64,
408 snapshot_lsn: u64,
409 cancelled: bool,
410) -> std::io::Result<Frame> {
411 FrameBuilder::reply_to(correlation_id)
412 .kind(MessageKind::StreamEnd)
413 .stream_id(stream_id)
414 .payload(build_input_stream_end_payload(
415 row_count,
416 chunk_count,
417 committed_rid,
418 snapshot_lsn,
419 cancelled,
420 ))
421 .build()
422 .map_err(|e| std::io::Error::other(format!("build input StreamEnd: {e}")))
423}
424
425pub fn open_input_lease(
429 config: StreamConfig,
430 snapshot_lsn: u64,
431 in_transaction: bool,
432 clock: &dyn Clock,
433) -> Result<StreamLease, OpenStreamError> {
434 crate::server::output_stream::open_stream(config, snapshot_lsn, in_transaction, clock)
435}
436
// Unit tests for input-stream payload parsing, registry bookkeeping and
// StreamEnd/StreamError payload construction.
#[cfg(test)]
mod tests {
    use super::*;

    // Direction detection: "in" matches case-insensitively; a missing direction,
    // "out", or non-JSON input does not.
    #[test]
    fn detects_input_direction() {
        assert!(open_stream_is_input(
            br#"{"direction":"in","target":"t","columns":["a"]}"#
        ));
        assert!(open_stream_is_input(br#"{"direction":"IN"}"#));
        assert!(!open_stream_is_input(br#"{"sql":"SELECT 1"}"#));
        assert!(!open_stream_is_input(br#"{"direction":"out"}"#));
        assert!(!open_stream_is_input(b"not json"));
    }

    // Happy path: target and columns are carried through unchanged and in order.
    #[test]
    fn parse_open_input_accepts_target_and_columns() {
        let req =
            parse_open_input(br#"{"direction":"in","target":"events","columns":["id","name"]}"#)
                .unwrap();
        assert_eq!(req.target, "events");
        assert_eq!(req.columns, vec!["id".to_string(), "name".to_string()]);
    }

    #[test]
    fn parse_open_input_rejects_missing_target() {
        assert!(matches!(
            parse_open_input(br#"{"direction":"in","columns":["a"]}"#),
            Err(OpenInputParseError::MissingTarget)
        ));
    }

    // A target containing SQL punctuation must fail the identifier safety check.
    #[test]
    fn parse_open_input_rejects_unsafe_target() {
        assert!(matches!(
            parse_open_input(br#"{"direction":"in","target":"t;DROP","columns":["a"]}"#),
            Err(OpenInputParseError::UnsafeTarget)
        ));
    }

    // An empty array and an absent field map to distinct error variants.
    #[test]
    fn parse_open_input_rejects_empty_or_missing_columns() {
        assert!(matches!(
            parse_open_input(br#"{"direction":"in","target":"t","columns":[]}"#),
            Err(OpenInputParseError::EmptyColumns)
        ));
        assert!(matches!(
            parse_open_input(br#"{"direction":"in","target":"t"}"#),
            Err(OpenInputParseError::MissingColumns)
        ));
    }

    // One bad column (here containing a space) rejects the whole request.
    #[test]
    fn parse_open_input_rejects_unsafe_column() {
        assert!(matches!(
            parse_open_input(br#"{"direction":"in","target":"t","columns":["ok","b ad"]}"#),
            Err(OpenInputParseError::UnsafeColumn)
        ));
    }

    #[test]
    fn parse_chunk_extracts_rows_seq_terminal() {
        let chunk =
            parse_input_chunk(br#"{"seq":3,"rows":[{"id":1},{"id":2}],"terminal":true}"#).unwrap();
        assert_eq!(chunk.seq, 3);
        assert_eq!(chunk.rows.len(), 2);
        assert!(chunk.terminal);
    }

    // A terminal marker with no rows and no seq parses with defaults (empty rows, seq 0).
    #[test]
    fn parse_chunk_allows_bare_terminal() {
        let chunk = parse_input_chunk(br#"{"terminal":true}"#).unwrap();
        assert!(chunk.rows.is_empty());
        assert!(chunk.terminal);
        assert_eq!(chunk.seq, 0);
    }

    #[test]
    fn parse_chunk_rejects_non_array_rows() {
        assert!(matches!(
            parse_input_chunk(br#"{"rows":5}"#),
            Err(ChunkParseError::RowsNotArray)
        ));
    }

    // Stream id 0 is reserved and a second registration on a live id is refused.
    // The refused attempt must leave exactly the one original entry in place.
    #[test]
    fn registry_register_rejects_reserved_and_duplicate() {
        let mut reg = InputStreamRegistry::new();
        let lease = StreamLease {
            id: 1,
            lease_handle: "h".to_string(),
            snapshot_lsn: 10,
            opened_at_ms: 0,
            config: StreamConfig::default(),
        };
        assert!(matches!(
            reg.register(
                0,
                InputStreamState::new(
                    StreamLease {
                        id: 2,
                        lease_handle: "h2".to_string(),
                        snapshot_lsn: 10,
                        opened_at_ms: 0,
                        config: StreamConfig::default(),
                    },
                    "t".to_string(),
                    vec!["a".to_string()],
                )
            ),
            Err(RegisterError::ReservedStreamId)
        ));
        reg.register(
            5,
            InputStreamState::new(lease, "t".to_string(), vec!["a".to_string()]),
        )
        .unwrap();
        assert!(reg.contains(5));
        assert!(matches!(
            reg.register(
                5,
                InputStreamState::new(
                    StreamLease {
                        id: 3,
                        lease_handle: "h3".to_string(),
                        snapshot_lsn: 10,
                        opened_at_ms: 0,
                        config: StreamConfig::default(),
                    },
                    "t".to_string(),
                    vec!["a".to_string()],
                )
            ),
            Err(RegisterError::StreamInUse)
        ));
        assert_eq!(reg.active_count(), 1);
        // Removal is one-shot: the second remove finds nothing.
        assert!(reg.remove(5).is_some());
        assert!(reg.remove(5).is_none());
    }

    // Every stat written into the StreamEnd payload round-trips through JSON parsing.
    #[test]
    fn end_payload_carries_committed_rid_range_and_stats() {
        let bytes = build_input_stream_end_payload(3, 2, 42, 40, false);
        let v: JsonValue = serde_json::from_slice(&bytes).unwrap();
        let stats = v.as_object().unwrap().get("stats").unwrap();
        assert_eq!(stats.get("row_count").and_then(|x| x.as_u64()), Some(3));
        assert_eq!(stats.get("chunk_count").and_then(|x| x.as_u64()), Some(2));
        assert_eq!(
            stats.get("committed_rid").and_then(|x| x.as_u64()),
            Some(42)
        );
        assert_eq!(stats.get("snapshot_lsn").and_then(|x| x.as_u64()), Some(40));
        assert_eq!(
            stats.get("cancelled").and_then(|x| x.as_bool()),
            Some(false)
        );
    }

    // The StreamError payload exposes code, chunk_seq and recoverable_rid at the top level.
    #[test]
    fn error_payload_carries_recoverable_rid_and_chunk_seq() {
        let bytes = build_input_stream_error_payload("invalid_row", "bad", 2, 41);
        let v: JsonValue = serde_json::from_slice(&bytes).unwrap();
        let obj = v.as_object().unwrap();
        assert_eq!(
            obj.get("code").and_then(|x| x.as_str()),
            Some("invalid_row")
        );
        assert_eq!(obj.get("chunk_seq").and_then(|x| x.as_u64()), Some(2));
        assert_eq!(
            obj.get("recoverable_rid").and_then(|x| x.as_u64()),
            Some(41)
        );
    }
}