1use std::collections::HashMap;
27use std::sync::Arc;
28
29use tokio::sync::{oneshot, Mutex};
30
31use crate::runtime::RedDBRuntime;
32use crate::serde_json::{self, Value as JsonValue};
33use crate::server::output_stream::{
34 self as outs, Clock, OpenStreamError, StreamConfig, SystemClock,
35};
36use reddb_wire::redwire::frame::{Frame, MessageKind};
37
38use super::codec::encode_frame;
39use super::FrameBuilder;
40
/// A decoded `OpenStream` request, as produced by `parse_open_stream`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OpenStreamRequest {
    /// The query text to execute; `parse_open_stream` guarantees it is non-empty.
    pub sql: String,
    /// The `opts` value re-serialised as JSON bytes, kept opaque at this layer.
    /// Empty when the request carried no `opts` field.
    pub opts_raw: Vec<u8>,
}
54
/// Reasons an `OpenStream` payload is rejected by `parse_open_stream`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OpenStreamParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is valid JSON but not an object.
    NotObject,
    /// The `sql` field is absent or is not a string.
    MissingSql,
    /// The `sql` field is present but carries no query text.
    EmptySql,
}

impl OpenStreamParseError {
    /// Stable machine-readable error code sent to the client.
    ///
    /// Both payload-shape errors share one code, and both `sql`-field errors share another.
    pub fn code(&self) -> &'static str {
        match self {
            Self::NotJson | Self::NotObject => "open_stream_invalid_payload",
            Self::MissingSql | Self::EmptySql => "open_stream_missing_sql",
        }
    }

    /// Human-readable description of the error, specific to each variant.
    pub fn message(&self) -> &'static str {
        match self {
            Self::NotJson => "OpenStream payload must be JSON",
            Self::NotObject => "OpenStream payload must be a JSON object",
            Self::MissingSql => "OpenStream payload missing 'sql' string field",
            Self::EmptySql => "OpenStream payload 'sql' must be non-empty",
        }
    }
}

// `Display` renders the same text as `message()`. Together with the `Error` impl
// below, this lets the type work with `?`, `Box<dyn Error>` and `{}` formatting.
impl std::fmt::Display for OpenStreamParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message())
    }
}

impl std::error::Error for OpenStreamParseError {}
79
80pub fn parse_open_stream(payload: &[u8]) -> Result<OpenStreamRequest, OpenStreamParseError> {
81 let v: JsonValue =
82 serde_json::from_slice(payload).map_err(|_| OpenStreamParseError::NotJson)?;
83 let obj = v.as_object().ok_or(OpenStreamParseError::NotObject)?;
84 let sql = obj
85 .get("sql")
86 .and_then(|x| x.as_str())
87 .ok_or(OpenStreamParseError::MissingSql)?;
88 if sql.is_empty() {
89 return Err(OpenStreamParseError::EmptySql);
90 }
91 let opts_raw = obj
92 .get("opts")
93 .map(|v| serde_json::to_vec(v).unwrap_or_default())
94 .unwrap_or_default();
95 Ok(OpenStreamRequest {
96 sql: sql.to_string(),
97 opts_raw,
98 })
99}
100
/// A decoded `StreamCancel` request. Every field is optional, so the default
/// value stands for "cancel, no reason given".
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct StreamCancelRequest {
    /// Free-form reason supplied by the client, if any.
    pub reason: Option<String>,
}
107
108pub fn parse_stream_cancel(payload: &[u8]) -> StreamCancelRequest {
109 if payload.is_empty() {
110 return StreamCancelRequest::default();
111 }
112 let v: JsonValue = match serde_json::from_slice(payload) {
113 Ok(v) => v,
114 Err(_) => return StreamCancelRequest::default(),
115 };
116 let reason = v
117 .as_object()
118 .and_then(|o| o.get("reason"))
119 .and_then(|x| x.as_str())
120 .map(|s| s.to_string());
121 StreamCancelRequest { reason }
122}
123
124pub fn build_open_ack_payload(lease_id: u64, snapshot_lsn: u64, resumable: bool) -> Vec<u8> {
125 let mut obj = serde_json::Map::new();
126 obj.insert(
127 "lease_handle".to_string(),
128 JsonValue::String(lease_id.to_string()),
129 );
130 obj.insert("resumable".to_string(), JsonValue::Bool(resumable));
131 obj.insert(
132 "snapshot_lsn".to_string(),
133 JsonValue::Number(snapshot_lsn as f64),
134 );
135 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
136}
137
138pub fn build_stream_chunk_payload(seq: u64, rows: Vec<JsonValue>, terminal: bool) -> Vec<u8> {
139 let mut obj = serde_json::Map::new();
140 obj.insert("seq".to_string(), JsonValue::Number(seq as f64));
141 obj.insert("rows".to_string(), JsonValue::Array(rows));
142 obj.insert("terminal".to_string(), JsonValue::Bool(terminal));
143 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
144}
145
146pub fn build_stream_error_payload(seq: Option<u64>, code: &str, message: &str) -> Vec<u8> {
147 let mut obj = serde_json::Map::new();
148 if let Some(s) = seq {
149 obj.insert("seq".to_string(), JsonValue::Number(s as f64));
150 }
151 obj.insert("code".to_string(), JsonValue::String(code.to_string()));
152 obj.insert(
153 "message".to_string(),
154 JsonValue::String(message.to_string()),
155 );
156 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
157}
158
159pub fn build_stream_end_payload(
160 row_count: u64,
161 lease_id: u64,
162 snapshot_lsn: u64,
163 cancelled: bool,
164) -> Vec<u8> {
165 let mut obj = serde_json::Map::new();
166 let mut stats = serde_json::Map::new();
167 stats.insert("row_count".to_string(), JsonValue::Number(row_count as f64));
168 stats.insert("lease_id".to_string(), JsonValue::Number(lease_id as f64));
169 stats.insert(
170 "snapshot_lsn".to_string(),
171 JsonValue::Number(snapshot_lsn as f64),
172 );
173 stats.insert("cancelled".to_string(), JsonValue::Bool(cancelled));
174 obj.insert("stats".to_string(), JsonValue::Object(stats));
175 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
176}
177
178#[derive(Default)]
183pub struct StreamRegistry {
184 inner: Mutex<HashMap<u16, oneshot::Sender<()>>>,
185}
186
187impl StreamRegistry {
188 pub fn new() -> Self {
189 Self::default()
190 }
191
192 pub async fn register(&self, stream_id: u16) -> Result<oneshot::Receiver<()>, RegisterError> {
196 if stream_id == 0 {
197 return Err(RegisterError::ReservedStreamId);
198 }
199 let mut guard = self.inner.lock().await;
200 if guard.contains_key(&stream_id) {
201 return Err(RegisterError::StreamInUse);
202 }
203 let (tx, rx) = oneshot::channel();
204 guard.insert(stream_id, tx);
205 Ok(rx)
206 }
207
208 pub async fn cancel(&self, stream_id: u16) -> bool {
212 let mut guard = self.inner.lock().await;
213 match guard.remove(&stream_id) {
214 Some(tx) => {
215 let _ = tx.send(());
216 true
217 }
218 None => false,
219 }
220 }
221
222 pub async fn unregister(&self, stream_id: u16) {
225 let mut guard = self.inner.lock().await;
226 guard.remove(&stream_id);
227 }
228
229 pub async fn active_count(&self) -> usize {
230 self.inner.lock().await.len()
231 }
232}
233
/// Reasons `StreamRegistry::register` refuses a `stream_id`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegisterError {
    /// `stream_id` 0 is reserved (for unsolicited frames, per `message()`).
    ReservedStreamId,
    /// Another active stream already uses this `stream_id`.
    StreamInUse,
}

impl RegisterError {
    /// Stable machine-readable error code sent to the client.
    pub fn code(&self) -> &'static str {
        match self {
            Self::ReservedStreamId => "open_stream_reserved_id",
            Self::StreamInUse => "open_stream_id_in_use",
        }
    }

    /// Human-readable description of the error.
    pub fn message(&self) -> &'static str {
        match self {
            Self::ReservedStreamId => {
                "OpenStream cannot use stream_id 0 (reserved for unsolicited)"
            }
            Self::StreamInUse => "OpenStream stream_id already has an active stream",
        }
    }
}

// `Display` renders the same text as `message()`, so the type composes with `?`
// and `Box<dyn Error>`. This matches `OpenStreamParseError`.
impl std::fmt::Display for RegisterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message())
    }
}

impl std::error::Error for RegisterError {}
256
257pub fn build_stream_error_frame(
261 correlation_id: u64,
262 stream_id: u16,
263 code: &str,
264 message: &str,
265) -> std::io::Result<Frame> {
266 FrameBuilder::reply_to(correlation_id)
267 .kind(MessageKind::StreamError)
268 .stream_id(stream_id)
269 .payload(build_stream_error_payload(None, code, message))
270 .build()
271 .map_err(|e| std::io::Error::other(format!("build StreamError: {e}")))
272}
273
274pub async fn run_output_stream(
284 runtime: Arc<RedDBRuntime>,
285 correlation_id: u64,
286 stream_id: u16,
287 request: OpenStreamRequest,
288 in_transaction: bool,
289 mut cancel_rx: oneshot::Receiver<()>,
290 send: FrameTx,
291) {
292 let clock = SystemClock;
293 let config = StreamConfig::load(&runtime);
294 let snapshot_lsn = runtime.cdc_current_lsn();
295
296 let lease = match outs::open_stream(config, snapshot_lsn, in_transaction, &clock) {
297 Ok(l) => l,
298 Err(OpenStreamError::TransactionActive) => {
299 let err = OpenStreamError::TransactionActive;
300 let frame = match build_stream_error_frame(
301 correlation_id,
302 stream_id,
303 err.code(),
304 err.message(),
305 ) {
306 Ok(f) => f,
307 Err(_) => return,
308 };
309 send.send_frame(frame);
310 return;
311 }
312 };
313
314 let ack = match FrameBuilder::reply_to(correlation_id)
316 .kind(MessageKind::OpenAck)
317 .stream_id(stream_id)
318 .payload(build_open_ack_payload(lease.id, lease.snapshot_lsn, false))
319 .build()
320 {
321 Ok(f) => f,
322 Err(_) => return,
323 };
324 send.send_frame(ack);
325
326 let result = runtime.execute_query(&request.sql);
328
329 let mut seq: u64 = 0;
331 let mut row_count: u64 = 0;
332 let mut cancelled = false;
333 let mut had_error: Option<(String, String)> = None;
334
335 match result {
336 Ok(qr) => {
337 let columns = qr.result.columns.clone();
338 let rows: Vec<JsonValue> = qr
339 .result
340 .records
341 .iter()
342 .map(|r| crate::presentation::query_result_json::unified_record_json(r, &columns))
343 .collect();
344
345 for row in rows {
353 if let Ok(()) = cancel_rx.try_recv() {
355 cancelled = true;
356 break;
357 }
358 if lease.snapshot_expired(clock.now_ms()) {
359 had_error = Some((
360 "snapshot_expired".to_string(),
361 "stream snapshot pin TTL elapsed".to_string(),
362 ));
363 break;
364 }
365 let payload = build_stream_chunk_payload(seq, vec![row], false);
366 let frame = match FrameBuilder::reply_to(correlation_id)
367 .kind(MessageKind::StreamChunk)
368 .stream_id(stream_id)
369 .payload(payload)
370 .build()
371 {
372 Ok(f) => f,
373 Err(_) => break,
374 };
375 send.send_frame(frame);
376 seq += 1;
377 row_count += 1;
378 }
379 let _ = config;
383 }
384 Err(err) => {
385 had_error = Some(("query_failed".to_string(), err.to_string()));
386 }
387 }
388
389 if let Some((code, message)) = had_error {
390 let payload = build_stream_error_payload(Some(seq), &code, &message);
391 if let Ok(frame) = FrameBuilder::reply_to(correlation_id)
392 .kind(MessageKind::StreamError)
393 .stream_id(stream_id)
394 .payload(payload)
395 .build()
396 {
397 send.send_frame(frame);
398 }
399 }
400
401 let end_payload = build_stream_end_payload(row_count, lease.id, lease.snapshot_lsn, cancelled);
405 if let Ok(frame) = FrameBuilder::reply_to(correlation_id)
406 .kind(MessageKind::StreamEnd)
407 .stream_id(stream_id)
408 .payload(end_payload)
409 .build()
410 {
411 send.send_frame(frame);
412 }
413}
414
415#[derive(Clone)]
421pub struct FrameTx {
422 tx: tokio::sync::mpsc::UnboundedSender<Vec<u8>>,
423}
424
425impl FrameTx {
426 pub fn new(tx: tokio::sync::mpsc::UnboundedSender<Vec<u8>>) -> Self {
427 Self { tx }
428 }
429
430 pub fn send_frame(&self, frame: Frame) {
434 let bytes = encode_frame(&frame);
435 let _ = self.tx.send(bytes);
436 }
437}
438
// Unit tests for payload parsing/encoding, the stream registry, and the
// StreamError frame builder.
#[cfg(test)]
mod tests {
    use super::*;

    // Only `sql` is present, so `opts_raw` stays empty.
    #[test]
    fn parse_open_stream_accepts_minimal_payload() {
        let req = parse_open_stream(br#"{"sql":"SELECT 1"}"#).unwrap();
        assert_eq!(req.sql, "SELECT 1");
        assert!(req.opts_raw.is_empty());
    }

    // `opts` is captured as opaque bytes; only its presence is asserted.
    #[test]
    fn parse_open_stream_captures_opts_opaque() {
        let req =
            parse_open_stream(br#"{"sql":"SELECT 1","opts":{"resume_after_rid":42}}"#).unwrap();
        assert_eq!(req.sql, "SELECT 1");
        assert!(!req.opts_raw.is_empty());
    }

    // A JSON string (valid JSON, not an object) is rejected as NotObject.
    #[test]
    fn parse_open_stream_rejects_non_object() {
        assert!(matches!(
            parse_open_stream(b"\"sql\""),
            Err(OpenStreamParseError::NotObject)
        ));
    }

    #[test]
    fn parse_open_stream_rejects_missing_sql() {
        assert!(matches!(
            parse_open_stream(b"{}"),
            Err(OpenStreamParseError::MissingSql)
        ));
    }

    #[test]
    fn parse_open_stream_rejects_empty_sql() {
        assert!(matches!(
            parse_open_stream(br#"{"sql":""}"#),
            Err(OpenStreamParseError::EmptySql)
        ));
    }

    #[test]
    fn parse_open_stream_rejects_invalid_json() {
        assert!(matches!(
            parse_open_stream(b"{not json"),
            Err(OpenStreamParseError::NotJson)
        ));
    }

    #[test]
    fn parse_stream_cancel_with_reason() {
        let r = parse_stream_cancel(br#"{"reason":"client-abort"}"#);
        assert_eq!(r.reason.as_deref(), Some("client-abort"));
    }

    // Both an empty payload and an object without `reason` decode to the default.
    #[test]
    fn parse_stream_cancel_empty_payload_is_default() {
        assert_eq!(parse_stream_cancel(b""), StreamCancelRequest::default());
        assert_eq!(parse_stream_cancel(b"{}"), StreamCancelRequest::default());
    }

    // The lease handle is encoded as a string; the snapshot LSN as a number.
    #[test]
    fn open_ack_payload_round_trips_through_json() {
        let bytes = build_open_ack_payload(42, 1234, false);
        let v: JsonValue = serde_json::from_slice(&bytes).unwrap();
        let obj = v.as_object().unwrap();
        assert_eq!(obj.get("lease_handle").and_then(|x| x.as_str()), Some("42"));
        assert_eq!(obj.get("resumable").and_then(|x| x.as_bool()), Some(false));
        assert_eq!(
            obj.get("snapshot_lsn").and_then(|x| x.as_f64()),
            Some(1234.0)
        );
    }

    // StreamEnd values live under the nested `stats` object.
    #[test]
    fn stream_end_payload_carries_cancelled_flag() {
        let bytes = build_stream_end_payload(5, 7, 99, true);
        let v: JsonValue = serde_json::from_slice(&bytes).unwrap();
        let stats = v
            .as_object()
            .unwrap()
            .get("stats")
            .and_then(|x| x.as_object())
            .unwrap();
        assert_eq!(stats.get("row_count").and_then(|x| x.as_f64()), Some(5.0));
        assert_eq!(stats.get("cancelled").and_then(|x| x.as_bool()), Some(true));
    }

    // `seq` is present only when passed as Some.
    #[test]
    fn stream_error_payload_includes_optional_seq() {
        let with = build_stream_error_payload(Some(3), "x", "y");
        let v: JsonValue = serde_json::from_slice(&with).unwrap();
        assert_eq!(
            v.as_object().unwrap().get("seq").and_then(|x| x.as_f64()),
            Some(3.0)
        );

        let without = build_stream_error_payload(None, "x", "y");
        let v: JsonValue = serde_json::from_slice(&without).unwrap();
        assert!(v.as_object().unwrap().get("seq").is_none());
    }

    // Id 0 is reserved; a live id cannot be registered twice.
    #[tokio::test]
    async fn registry_rejects_reserved_id_and_duplicates() {
        let r = StreamRegistry::new();
        assert!(matches!(
            r.register(0).await,
            Err(RegisterError::ReservedStreamId)
        ));
        let _rx = r.register(1).await.unwrap();
        assert!(matches!(
            r.register(1).await,
            Err(RegisterError::StreamInUse)
        ));
        assert_eq!(r.active_count().await, 1);
    }

    // Cancelling stream 1 fires only its receiver and removes only its entry.
    #[tokio::test]
    async fn registry_cancel_signals_named_stream_only() {
        let r = StreamRegistry::new();
        let rx1 = r.register(1).await.unwrap();
        let mut rx2 = r.register(2).await.unwrap();
        assert!(r.cancel(1).await);
        assert!(rx1.await.is_ok());
        match rx2.try_recv() {
            Err(tokio::sync::oneshot::error::TryRecvError::Empty) => {}
            other => panic!("stream 2 should not be cancelled: {other:?}"),
        }
        assert_eq!(r.active_count().await, 1);
    }

    #[tokio::test]
    async fn registry_cancel_unknown_returns_false() {
        let r = StreamRegistry::new();
        assert!(!r.cancel(99).await);
    }

    // A second unregister of the same id is a harmless no-op.
    #[tokio::test]
    async fn registry_unregister_is_idempotent() {
        let r = StreamRegistry::new();
        let _rx = r.register(1).await.unwrap();
        r.unregister(1).await;
        r.unregister(1).await;
        assert_eq!(r.active_count().await, 0);
    }

    #[test]
    fn build_stream_error_frame_carries_stream_id_and_correlation() {
        let frame = build_stream_error_frame(99, 7, "unknown_stream", "no such stream").unwrap();
        assert_eq!(frame.kind, MessageKind::StreamError);
        assert_eq!(frame.stream_id, 7);
        assert_eq!(frame.correlation_id, 99);
    }
}