1use serde_json::Value as JsonValue;
4
5use super::{BuildError, Frame, FrameBuilder, MessageKind};
6
/// Decoded output-stream `OpenStream` request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OpenStreamRequest {
    /// SQL text to run; non-empty when produced by `parse_open_stream`.
    pub sql: String,
    /// The request's `opts` value re-encoded as JSON bytes; empty when the request carried
    /// no `opts` field.
    pub opts_raw: Vec<u8>,
}
12
/// Reasons `parse_open_stream` can reject an output `OpenStream` payload.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OpenStreamParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is valid JSON but not an object.
    NotObject,
    /// The `sql` field is absent or is not a string.
    MissingSql,
    /// The `sql` field is an empty string.
    EmptySql,
}

impl OpenStreamParseError {
    /// Machine-readable error code.
    ///
    /// Payload-shape failures share one code; both `sql` failures share another.
    pub fn code(&self) -> &'static str {
        match self {
            Self::MissingSql | Self::EmptySql => "open_stream_missing_sql",
            Self::NotJson | Self::NotObject => "open_stream_invalid_payload",
        }
    }

    /// Human-readable description of the failure, one distinct text per variant.
    pub fn message(&self) -> &'static str {
        match self {
            Self::EmptySql => "OpenStream payload 'sql' must be non-empty",
            Self::MissingSql => "OpenStream payload missing 'sql' string field",
            Self::NotObject => "OpenStream payload must be a JSON object",
            Self::NotJson => "OpenStream payload must be JSON",
        }
    }
}
38
39pub fn parse_open_stream(payload: &[u8]) -> Result<OpenStreamRequest, OpenStreamParseError> {
40 let v: JsonValue =
41 serde_json::from_slice(payload).map_err(|_| OpenStreamParseError::NotJson)?;
42 let obj = v.as_object().ok_or(OpenStreamParseError::NotObject)?;
43 let sql = obj
44 .get("sql")
45 .and_then(|x| x.as_str())
46 .ok_or(OpenStreamParseError::MissingSql)?;
47 if sql.is_empty() {
48 return Err(OpenStreamParseError::EmptySql);
49 }
50 let opts_raw = obj
51 .get("opts")
52 .map(|v| serde_json::to_vec(v).unwrap_or_default())
53 .unwrap_or_default();
54 Ok(OpenStreamRequest {
55 sql: sql.to_string(),
56 opts_raw,
57 })
58}
59
60#[derive(Debug, Clone, Default, PartialEq, Eq)]
61pub struct StreamCancelRequest {
62 pub reason: Option<String>,
63}
64
65pub fn parse_stream_cancel(payload: &[u8]) -> StreamCancelRequest {
66 if payload.is_empty() {
67 return StreamCancelRequest::default();
68 }
69 let v: JsonValue = match serde_json::from_slice(payload) {
70 Ok(v) => v,
71 Err(_) => return StreamCancelRequest::default(),
72 };
73 let reason = v
74 .as_object()
75 .and_then(|o| o.get("reason"))
76 .and_then(|x| x.as_str())
77 .map(|s| s.to_string());
78 StreamCancelRequest { reason }
79}
80
81pub fn build_open_stream_payload(request: &OpenStreamRequest) -> Vec<u8> {
82 let mut obj = serde_json::Map::new();
83 obj.insert("sql".to_string(), JsonValue::String(request.sql.clone()));
84 if !request.opts_raw.is_empty() {
85 let opts = serde_json::from_slice(&request.opts_raw).unwrap_or(JsonValue::Null);
86 obj.insert("opts".to_string(), opts);
87 }
88 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
89}
90
91pub fn build_open_stream_frame(
92 correlation_id: u64,
93 stream_id: u16,
94 request: &OpenStreamRequest,
95) -> Result<Frame, BuildError> {
96 FrameBuilder::request(correlation_id)
97 .kind(MessageKind::OpenStream)
98 .stream_id(stream_id)
99 .payload(build_open_stream_payload(request))
100 .build()
101}
102
103pub fn build_open_ack_payload(lease_id: u64, snapshot_lsn: u64, resumable: bool) -> Vec<u8> {
104 let mut obj = serde_json::Map::new();
105 obj.insert(
106 "lease_handle".to_string(),
107 JsonValue::String(lease_id.to_string()),
108 );
109 obj.insert("resumable".to_string(), JsonValue::Bool(resumable));
110 obj.insert(
111 "snapshot_lsn".to_string(),
112 JsonValue::Number(snapshot_lsn.into()),
113 );
114 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
115}
116
117pub fn build_open_ack_frame(
118 correlation_id: u64,
119 stream_id: u16,
120 lease_id: u64,
121 snapshot_lsn: u64,
122 resumable: bool,
123) -> Result<Frame, BuildError> {
124 FrameBuilder::reply_to(correlation_id)
125 .kind(MessageKind::OpenAck)
126 .stream_id(stream_id)
127 .payload(build_open_ack_payload(lease_id, snapshot_lsn, resumable))
128 .build()
129}
130
131pub fn build_stream_chunk_payload(seq: u64, rows: Vec<JsonValue>, terminal: bool) -> Vec<u8> {
132 let mut obj = serde_json::Map::new();
133 obj.insert("seq".to_string(), JsonValue::Number(seq.into()));
134 obj.insert("rows".to_string(), JsonValue::Array(rows));
135 obj.insert("terminal".to_string(), JsonValue::Bool(terminal));
136 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
137}
138
139pub fn build_stream_chunk_payload_from_json_bytes(
140 seq: u64,
141 rows: Vec<Vec<u8>>,
142 terminal: bool,
143) -> Vec<u8> {
144 let rows = rows
145 .into_iter()
146 .map(|row| serde_json::from_slice(&row).unwrap_or(JsonValue::Null))
147 .collect();
148 build_stream_chunk_payload(seq, rows, terminal)
149}
150
151pub fn build_stream_chunk_frame_from_json_bytes(
152 correlation_id: u64,
153 stream_id: u16,
154 seq: u64,
155 rows: Vec<Vec<u8>>,
156 terminal: bool,
157) -> Result<Frame, BuildError> {
158 FrameBuilder::reply_to(correlation_id)
159 .kind(MessageKind::StreamChunk)
160 .stream_id(stream_id)
161 .payload(build_stream_chunk_payload_from_json_bytes(
162 seq, rows, terminal,
163 ))
164 .build()
165}
166
167pub fn build_stream_error_payload(seq: Option<u64>, code: &str, message: &str) -> Vec<u8> {
168 let mut obj = serde_json::Map::new();
169 if let Some(s) = seq {
170 obj.insert("seq".to_string(), JsonValue::Number(s.into()));
171 }
172 obj.insert("code".to_string(), JsonValue::String(code.to_string()));
173 obj.insert(
174 "message".to_string(),
175 JsonValue::String(message.to_string()),
176 );
177 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
178}
179
180pub fn build_stream_error_frame(
181 correlation_id: u64,
182 stream_id: u16,
183 seq: Option<u64>,
184 code: &str,
185 message: &str,
186) -> Result<Frame, BuildError> {
187 FrameBuilder::reply_to(correlation_id)
188 .kind(MessageKind::StreamError)
189 .stream_id(stream_id)
190 .payload(build_stream_error_payload(seq, code, message))
191 .build()
192}
193
194pub fn build_stream_end_payload(
195 row_count: u64,
196 lease_id: u64,
197 snapshot_lsn: u64,
198 cancelled: bool,
199) -> Vec<u8> {
200 let mut obj = serde_json::Map::new();
201 let mut stats = serde_json::Map::new();
202 stats.insert("row_count".to_string(), JsonValue::Number(row_count.into()));
203 stats.insert("lease_id".to_string(), JsonValue::Number(lease_id.into()));
204 stats.insert(
205 "snapshot_lsn".to_string(),
206 JsonValue::Number(snapshot_lsn.into()),
207 );
208 stats.insert("cancelled".to_string(), JsonValue::Bool(cancelled));
209 obj.insert("stats".to_string(), JsonValue::Object(stats));
210 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
211}
212
213pub fn build_stream_end_frame(
214 correlation_id: u64,
215 stream_id: u16,
216 row_count: u64,
217 lease_id: u64,
218 snapshot_lsn: u64,
219 cancelled: bool,
220) -> Result<Frame, BuildError> {
221 FrameBuilder::reply_to(correlation_id)
222 .kind(MessageKind::StreamEnd)
223 .stream_id(stream_id)
224 .payload(build_stream_end_payload(
225 row_count,
226 lease_id,
227 snapshot_lsn,
228 cancelled,
229 ))
230 .build()
231}
232
233pub fn open_stream_is_input(payload: &[u8]) -> bool {
234 serde_json::from_slice::<JsonValue>(payload)
235 .ok()
236 .and_then(|v| {
237 v.as_object()
238 .and_then(|o| o.get("direction"))
239 .and_then(|d| d.as_str())
240 .map(|s| s.eq_ignore_ascii_case("in"))
241 })
242 .unwrap_or(false)
243}
244
/// Decoded input-stream `OpenStream` request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OpenInputRequest {
    /// Target SQL identifier; `parse_open_input` only accepts `[A-Za-z_][A-Za-z0-9_]*`.
    pub target: String,
    /// Column identifiers in request order, validated the same way as `target`.
    pub columns: Vec<String>,
}
250
/// Reasons `parse_open_input` can reject an input-stream `OpenStream` payload.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OpenInputParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is valid JSON but not an object.
    NotObject,
    /// The `target` field is absent or not a string.
    MissingTarget,
    /// `target` is a string but not a safe SQL identifier.
    UnsafeTarget,
    /// The `columns` field is absent or not an array.
    MissingColumns,
    /// `columns` is an empty array.
    EmptyColumns,
    /// Some `columns` entry is not a string, or not a safe SQL identifier.
    UnsafeColumn,
}

impl OpenInputParseError {
    /// Machine-readable error code, grouped by the field at fault.
    ///
    /// The payload-shape code is the same one `OpenStreamParseError` uses.
    pub fn code(&self) -> &'static str {
        match self {
            Self::MissingColumns | Self::EmptyColumns | Self::UnsafeColumn => {
                "open_stream_invalid_columns"
            }
            Self::MissingTarget | Self::UnsafeTarget => "open_stream_invalid_target",
            Self::NotJson | Self::NotObject => "open_stream_invalid_payload",
        }
    }

    /// Human-readable description of the failure, one distinct text per variant.
    pub fn message(&self) -> &'static str {
        match self {
            Self::UnsafeColumn => "input OpenStream 'columns' entry is not a safe SQL identifier",
            Self::EmptyColumns => "input OpenStream 'columns' must be a non-empty array",
            Self::MissingColumns => "input OpenStream payload missing 'columns' array field",
            Self::UnsafeTarget => "input OpenStream 'target' is not a safe SQL identifier",
            Self::MissingTarget => "input OpenStream payload missing 'target' string field",
            Self::NotObject => "OpenStream payload must be a JSON object",
            Self::NotJson => "OpenStream payload must be JSON",
        }
    }
}
285
286pub fn parse_open_input(payload: &[u8]) -> Result<OpenInputRequest, OpenInputParseError> {
287 let v: JsonValue = serde_json::from_slice(payload).map_err(|_| OpenInputParseError::NotJson)?;
288 let obj = v.as_object().ok_or(OpenInputParseError::NotObject)?;
289 let target = obj
290 .get("target")
291 .and_then(|x| x.as_str())
292 .ok_or(OpenInputParseError::MissingTarget)?;
293 if !is_safe_sql_identifier(target) {
294 return Err(OpenInputParseError::UnsafeTarget);
295 }
296 let columns_v = obj
297 .get("columns")
298 .and_then(|x| x.as_array())
299 .ok_or(OpenInputParseError::MissingColumns)?;
300 if columns_v.is_empty() {
301 return Err(OpenInputParseError::EmptyColumns);
302 }
303 let mut columns = Vec::with_capacity(columns_v.len());
304 for c in columns_v {
305 let name = c.as_str().ok_or(OpenInputParseError::UnsafeColumn)?;
306 if !is_safe_sql_identifier(name) {
307 return Err(OpenInputParseError::UnsafeColumn);
308 }
309 columns.push(name.to_string());
310 }
311 Ok(OpenInputRequest {
312 target: target.to_string(),
313 columns,
314 })
315}
316
/// Returns `true` when `name` matches `[A-Za-z_][A-Za-z0-9_]*`, and is therefore safe to
/// splice into SQL unquoted.
///
/// Any non-ASCII character makes the name unsafe. So does an empty string.
fn is_safe_sql_identifier(name: &str) -> bool {
    // Byte-wise scanning is equivalent to a char-wise scan here: every UTF-8 byte of a
    // non-ASCII char is >= 0x80, so it fails the ASCII class checks just like the char would.
    match name.as_bytes().split_first() {
        Some((&head, tail)) => {
            (head == b'_' || head.is_ascii_alphabetic())
                && tail.iter().all(|&b| b == b'_' || b.is_ascii_alphanumeric())
        }
        None => false,
    }
}
325
326#[derive(Debug, Clone, PartialEq)]
327pub struct InputChunk {
328 pub seq: u64,
329 pub rows: Vec<JsonValue>,
330 pub terminal: bool,
331}
332
/// Same data as `InputChunk`, but each row is re-encoded as standalone JSON bytes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InputChunkJson {
    /// Chunk sequence number.
    pub seq: u64,
    /// One compact JSON encoding per row, in payload order.
    pub rows_json: Vec<Vec<u8>>,
    /// Whether the sender marked this as the last chunk.
    pub terminal: bool,
}
339
/// Reasons `parse_input_chunk` can reject a `StreamChunk` payload.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ChunkParseError {
    /// The payload bytes are not valid JSON.
    NotJson,
    /// The payload is valid JSON but not an object.
    NotObject,
    /// `rows` is present, not `null`, and not an array.
    RowsNotArray,
}

impl ChunkParseError {
    /// Machine-readable error code; every chunk failure shares `invalid_chunk`.
    pub fn code(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant forces a decision about its code.
        match self {
            Self::NotJson | Self::NotObject | Self::RowsNotArray => "invalid_chunk",
        }
    }

    /// Human-readable description of the failure, one distinct text per variant.
    pub fn message(&self) -> &'static str {
        match self {
            Self::RowsNotArray => "StreamChunk 'rows' must be an array",
            Self::NotObject => "StreamChunk payload must be a JSON object",
            Self::NotJson => "StreamChunk payload must be JSON",
        }
    }
}
360
361pub fn parse_input_chunk(payload: &[u8]) -> Result<InputChunk, ChunkParseError> {
362 let v: JsonValue = serde_json::from_slice(payload).map_err(|_| ChunkParseError::NotJson)?;
363 let obj = v.as_object().ok_or(ChunkParseError::NotObject)?;
364 let seq = obj.get("seq").and_then(|x| x.as_u64()).unwrap_or(0);
365 let terminal = obj
366 .get("terminal")
367 .and_then(|x| x.as_bool())
368 .unwrap_or(false);
369 let rows = match obj.get("rows") {
370 None | Some(JsonValue::Null) => Vec::new(),
371 Some(JsonValue::Array(arr)) => arr.clone(),
372 Some(_) => return Err(ChunkParseError::RowsNotArray),
373 };
374 Ok(InputChunk {
375 seq,
376 rows,
377 terminal,
378 })
379}
380
381pub fn parse_input_chunk_json(payload: &[u8]) -> Result<InputChunkJson, ChunkParseError> {
382 let chunk = parse_input_chunk(payload)?;
383 let rows_json = chunk
384 .rows
385 .iter()
386 .map(|row| serde_json::to_vec(row).unwrap_or_default())
387 .collect();
388 Ok(InputChunkJson {
389 seq: chunk.seq,
390 rows_json,
391 terminal: chunk.terminal,
392 })
393}
394
395pub fn build_input_stream_end_payload(
396 row_count: u64,
397 chunk_count: u64,
398 committed_rid: u64,
399 snapshot_lsn: u64,
400 cancelled: bool,
401) -> Vec<u8> {
402 let mut obj = serde_json::Map::new();
403 let mut stats = serde_json::Map::new();
404 stats.insert("row_count".to_string(), JsonValue::Number(row_count.into()));
405 stats.insert(
406 "chunk_count".to_string(),
407 JsonValue::Number(chunk_count.into()),
408 );
409 stats.insert(
410 "committed_rid".to_string(),
411 JsonValue::Number(committed_rid.into()),
412 );
413 stats.insert(
414 "snapshot_lsn".to_string(),
415 JsonValue::Number(snapshot_lsn.into()),
416 );
417 stats.insert("cancelled".to_string(), JsonValue::Bool(cancelled));
418 obj.insert("stats".to_string(), JsonValue::Object(stats));
419 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
420}
421
422pub fn build_input_stream_end_frame(
423 correlation_id: u64,
424 stream_id: u16,
425 row_count: u64,
426 chunk_count: u64,
427 committed_rid: u64,
428 snapshot_lsn: u64,
429 cancelled: bool,
430) -> Result<Frame, BuildError> {
431 FrameBuilder::reply_to(correlation_id)
432 .kind(MessageKind::StreamEnd)
433 .stream_id(stream_id)
434 .payload(build_input_stream_end_payload(
435 row_count,
436 chunk_count,
437 committed_rid,
438 snapshot_lsn,
439 cancelled,
440 ))
441 .build()
442}
443
444pub fn build_input_stream_error_payload(
445 code: &str,
446 message: &str,
447 chunk_seq: u64,
448 recoverable_rid: u64,
449) -> Vec<u8> {
450 let mut obj = serde_json::Map::new();
451 obj.insert("code".to_string(), JsonValue::String(code.to_string()));
452 obj.insert(
453 "message".to_string(),
454 JsonValue::String(message.to_string()),
455 );
456 obj.insert("chunk_seq".to_string(), JsonValue::Number(chunk_seq.into()));
457 obj.insert(
458 "recoverable_rid".to_string(),
459 JsonValue::Number(recoverable_rid.into()),
460 );
461 serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default()
462}
463
464pub fn build_input_stream_error_frame(
465 correlation_id: u64,
466 stream_id: u16,
467 code: &str,
468 message: &str,
469 chunk_seq: u64,
470 recoverable_rid: u64,
471) -> Result<Frame, BuildError> {
472 FrameBuilder::reply_to(correlation_id)
473 .kind(MessageKind::StreamError)
474 .stream_id(stream_id)
475 .payload(build_input_stream_error_payload(
476 code,
477 message,
478 chunk_seq,
479 recoverable_rid,
480 ))
481 .build()
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    // Output-stream OpenStream contract.

    #[test]
    fn output_open_stream_contract_parses_opts() {
        let req = parse_open_stream(br#"{"sql":"SELECT 1","opts":{"resume_after_rid":42}}"#)
            .expect("parse open stream");
        assert_eq!(req.sql, "SELECT 1");
        assert!(!req.opts_raw.is_empty());
    }

    #[test]
    fn output_open_stream_contract_rejects_bad_payloads() {
        assert_eq!(
            parse_open_stream(b"not json"),
            Err(OpenStreamParseError::NotJson)
        );
        assert_eq!(parse_open_stream(b"[1]"), Err(OpenStreamParseError::NotObject));
        assert_eq!(
            parse_open_stream(br#"{"sql":1}"#),
            Err(OpenStreamParseError::MissingSql)
        );
        assert_eq!(
            parse_open_stream(br#"{"sql":""}"#),
            Err(OpenStreamParseError::EmptySql)
        );
        // Without an `opts` field there is nothing to forward.
        let req = parse_open_stream(br#"{"sql":"SELECT 1"}"#).unwrap();
        assert!(req.opts_raw.is_empty());
    }

    #[test]
    fn output_open_stream_builder_round_trips_request() {
        let request = OpenStreamRequest {
            sql: "SELECT id FROM widgets".to_string(),
            opts_raw: br#"{"resume_after_rid":42}"#.to_vec(),
        };
        let frame = build_open_stream_frame(12, 4, &request).unwrap();
        assert_eq!(frame.kind, MessageKind::OpenStream);
        assert_eq!(frame.correlation_id, 12);
        assert_eq!(frame.stream_id, 4);
        let parsed = parse_open_stream(&frame.payload).unwrap();
        assert_eq!(parsed.sql, request.sql);
        // Compare decoded values: re-encoding may not reproduce the original bytes exactly.
        assert_eq!(
            serde_json::from_slice::<JsonValue>(&parsed.opts_raw).unwrap(),
            serde_json::from_slice::<JsonValue>(&request.opts_raw).unwrap()
        );
    }

    // Stream cancel and direction detection.

    #[test]
    fn stream_cancel_parse_is_best_effort() {
        assert_eq!(parse_stream_cancel(b""), StreamCancelRequest::default());
        assert_eq!(parse_stream_cancel(b"garbage").reason, None);
        assert_eq!(parse_stream_cancel(br#"{"reason":7}"#).reason, None);
        assert_eq!(
            parse_stream_cancel(br#"{"reason":"user"}"#).reason.as_deref(),
            Some("user")
        );
    }

    #[test]
    fn open_stream_direction_detection() {
        assert!(open_stream_is_input(br#"{"direction":"in"}"#));
        assert!(open_stream_is_input(br#"{"direction":"IN","target":"t"}"#));
        assert!(!open_stream_is_input(br#"{"direction":"out"}"#));
        assert!(!open_stream_is_input(br#"{"sql":"SELECT 1"}"#));
        assert!(!open_stream_is_input(b"not json"));
    }

    // Input-stream OpenStream contract.

    #[test]
    fn input_open_contract_rejects_unsafe_identifiers() {
        assert_eq!(
            parse_open_input(br#"{"direction":"in","target":"t;drop","columns":["id"]}"#),
            Err(OpenInputParseError::UnsafeTarget)
        );
        assert_eq!(
            parse_open_input(br#"{"direction":"in","target":"t","columns":["bad name"]}"#),
            Err(OpenInputParseError::UnsafeColumn)
        );
    }

    #[test]
    fn input_open_contract_accepts_valid_and_rejects_malformed_requests() {
        let req = parse_open_input(
            br#"{"direction":"in","target":"widgets","columns":["id","name"]}"#,
        )
        .unwrap();
        assert_eq!(req.target, "widgets");
        assert_eq!(req.columns, vec!["id".to_string(), "name".to_string()]);

        assert_eq!(
            parse_open_input(br#"{"columns":["id"]}"#),
            Err(OpenInputParseError::MissingTarget)
        );
        assert_eq!(
            parse_open_input(br#"{"target":"t"}"#),
            Err(OpenInputParseError::MissingColumns)
        );
        assert_eq!(
            parse_open_input(br#"{"target":"t","columns":[]}"#),
            Err(OpenInputParseError::EmptyColumns)
        );
        assert_eq!(
            parse_open_input(br#"{"target":"t","columns":[1]}"#),
            Err(OpenInputParseError::UnsafeColumn)
        );
    }

    #[test]
    fn sql_identifier_check_is_ascii_only() {
        assert!(is_safe_sql_identifier("_t1"));
        assert!(!is_safe_sql_identifier(""));
        assert!(!is_safe_sql_identifier("1t"));
        assert!(!is_safe_sql_identifier("café"));
    }

    // Input chunks.

    #[test]
    fn input_chunk_json_preserves_rows_as_json_bytes() {
        let chunk =
            parse_input_chunk_json(br#"{"seq":3,"rows":[{"id":1}],"terminal":true}"#).unwrap();
        assert_eq!(chunk.seq, 3);
        assert_eq!(chunk.rows_json.len(), 1);
        assert!(std::str::from_utf8(&chunk.rows_json[0])
            .unwrap()
            .contains("\"id\""));
        assert!(chunk.terminal);
    }

    #[test]
    fn input_chunk_rows_field_is_optional_but_must_be_array() {
        let chunk = parse_input_chunk(br#"{"seq":1}"#).unwrap();
        assert_eq!(chunk.seq, 1);
        assert!(chunk.rows.is_empty());
        assert!(!chunk.terminal);

        let chunk = parse_input_chunk(br#"{"seq":2,"rows":null,"terminal":true}"#).unwrap();
        assert!(chunk.rows.is_empty());
        assert!(chunk.terminal);

        assert_eq!(
            parse_input_chunk(br#"{"rows":5}"#),
            Err(ChunkParseError::RowsNotArray)
        );
        assert_eq!(parse_input_chunk(b"[]"), Err(ChunkParseError::NotObject));
        assert_eq!(parse_input_chunk(b"{"), Err(ChunkParseError::NotJson));
    }

    // Payload and frame builders.

    #[test]
    fn stream_payload_builders_emit_json_objects() {
        let ack = build_open_ack_payload(42, 7, false);
        let value: JsonValue = serde_json::from_slice(&ack).unwrap();
        assert_eq!(value["lease_handle"], "42");
        assert_eq!(value["resumable"], false);
        assert_eq!(value["snapshot_lsn"], 7);

        let end = build_stream_end_payload(5, 42, 7, true);
        let value: JsonValue = serde_json::from_slice(&end).unwrap();
        assert_eq!(value["stats"]["row_count"], 5);
        assert_eq!(value["stats"]["lease_id"], 42);
        assert_eq!(value["stats"]["snapshot_lsn"], 7);
        assert_eq!(value["stats"]["cancelled"], true);

        let with_seq = build_stream_error_payload(Some(3), "x", "y");
        let value: JsonValue = serde_json::from_slice(&with_seq).unwrap();
        assert_eq!(value["seq"], 3);
        assert_eq!(value["code"], "x");
        assert_eq!(value["message"], "y");

        // `seq` must be omitted entirely, not serialized as null.
        let without_seq = build_stream_error_payload(None, "x", "y");
        let value: JsonValue = serde_json::from_slice(&without_seq).unwrap();
        assert!(value.as_object().unwrap().get("seq").is_none());
    }

    #[test]
    fn input_stream_payload_builders_emit_committed_range_and_error_cursor() {
        let end = build_input_stream_end_payload(3, 2, 42, 40, false);
        let value: JsonValue = serde_json::from_slice(&end).unwrap();
        assert_eq!(value["stats"]["row_count"], 3);
        assert_eq!(value["stats"]["chunk_count"], 2);
        assert_eq!(value["stats"]["committed_rid"], 42);
        assert_eq!(value["stats"]["snapshot_lsn"], 40);
        assert_eq!(value["stats"]["cancelled"], false);

        let error = build_input_stream_error_payload("invalid_row", "bad", 2, 41);
        let value: JsonValue = serde_json::from_slice(&error).unwrap();
        assert_eq!(value["code"], "invalid_row");
        assert_eq!(value["message"], "bad");
        assert_eq!(value["chunk_seq"], 2);
        assert_eq!(value["recoverable_rid"], 41);
    }

    #[test]
    fn stream_frame_builders_echo_stream_and_correlation() {
        let ack = build_open_ack_frame(99, 7, 42, 100, false).unwrap();
        assert_eq!(ack.kind, MessageKind::OpenAck);
        assert_eq!(ack.correlation_id, 99);
        assert_eq!(ack.stream_id, 7);

        let chunk = build_stream_chunk_frame_from_json_bytes(
            99,
            7,
            1,
            vec![br#"{"id":1}"#.to_vec()],
            false,
        )
        .unwrap();
        assert_eq!(chunk.kind, MessageKind::StreamChunk);
        assert_eq!(chunk.stream_id, 7);

        let error = build_stream_error_frame(99, 7, Some(1), "bad", "failed").unwrap();
        assert_eq!(error.kind, MessageKind::StreamError);
        assert_eq!(error.correlation_id, 99);

        let end = build_stream_end_frame(99, 7, 5, 42, 100, true).unwrap();
        assert_eq!(end.kind, MessageKind::StreamEnd);
        assert_eq!(end.stream_id, 7);

        let input_error =
            build_input_stream_error_frame(99, 8, "invalid_row", "bad", 2, 41).unwrap();
        assert_eq!(input_error.kind, MessageKind::StreamError);
        assert_eq!(input_error.stream_id, 8);

        let input_end = build_input_stream_end_frame(99, 8, 3, 2, 42, 40, false).unwrap();
        assert_eq!(input_end.kind, MessageKind::StreamEnd);
        assert_eq!(input_end.correlation_id, 99);
    }
}