1use std::io;
13use std::sync::Arc;
14
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16
17use crate::auth::store::AuthStore;
18use crate::runtime::RedDBRuntime;
19use crate::serde_json::{self, Value as JsonValue};
20use reddb_wire::query_with_params::{
21 decode_query_with_params, ParamValue as RedWireParamValue, FEATURE_PARAMS,
22};
23
24use super::auth::{
25 build_auth_fail, build_auth_ok, build_hello_ack, pick_auth_method, validate_auth_response,
26 AuthOutcome, Hello,
27};
28use super::codec::{decode_frame, encode_frame};
29use super::frame::{Frame, MessageDirection, MessageKind, FRAME_HEADER_SIZE};
30use super::{FrameBuilder, MAX_KNOWN_MINOR_VERSION, REDWIRE_MAGIC};
31
/// Identity produced by a successful RedWire handshake.
///
/// Built by every authentication path (anonymous/password, `oauth-jwt`,
/// `scram-sha-256`) right after the `AuthOk` frame has been written.
#[derive(Debug)]
struct AuthedSession {
    /// Principal name that was echoed back to the client in `AuthOk`.
    username: String,
    /// Session identifier that was echoed back to the client in `AuthOk`.
    // Presumably not consumed after the handshake yet; the allow silences the
    // unused-field lint until it is.
    #[allow(dead_code)]
    session_id: String,
}
38
39pub async fn handle_session<S>(
40 mut stream: S,
41 runtime: Arc<RedDBRuntime>,
42 auth_store: Option<Arc<AuthStore>>,
43 oauth: Option<Arc<crate::auth::oauth::OAuthValidator>>,
44) -> io::Result<()>
45where
46 S: AsyncRead + AsyncWrite + Unpin + Send,
47{
48 let session = perform_handshake(
52 &mut stream,
53 runtime.as_ref(),
54 auth_store.as_deref(),
55 oauth.as_deref(),
56 )
57 .await?;
58 if session.is_none() {
59 return Ok(());
60 }
61 let _session = session.unwrap();
62
63 let mut stream_session: Option<crate::wire::listener::BulkStreamSession> = None;
66 let mut prepared_stmts: std::collections::HashMap<u32, crate::wire::listener::PreparedStmt> =
67 std::collections::HashMap::new();
68
69 let mut buf = vec![0u8; FRAME_HEADER_SIZE];
70 loop {
71 if let Err(err) = stream.read_exact(&mut buf[..FRAME_HEADER_SIZE]).await {
73 if err.kind() == io::ErrorKind::UnexpectedEof {
74 return Ok(());
75 }
76 return Err(err);
77 }
78 let length = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
79 if length < FRAME_HEADER_SIZE || length > super::frame::MAX_FRAME_SIZE as usize {
80 return Err(io::Error::other(format!("invalid frame length {length}")));
81 }
82 if buf.len() < length {
83 buf.resize(length, 0);
84 }
85 let payload_len = length - FRAME_HEADER_SIZE;
86 if payload_len > 0 {
87 stream
88 .read_exact(&mut buf[FRAME_HEADER_SIZE..length])
89 .await?;
90 }
91 let (frame, _) = decode_frame(&buf[..length])
92 .map_err(|e| io::Error::other(format!("decode frame: {e}")))?;
93
94 if frame.kind.direction() == MessageDirection::ServerToClient {
99 let err_frame = FrameBuilder::reply_to(frame.correlation_id)
100 .kind(MessageKind::Error)
101 .payload(format!("redwire: {:?} is server-only", frame.kind).into_bytes())
102 .build()
103 .map_err(|e| io::Error::other(format!("build Error frame: {e}")))?;
104 stream.write_all(&encode_frame(&err_frame)).await?;
105 continue;
106 }
107
108 match frame.kind {
109 MessageKind::Bye => {
110 let bye = encode_frame(&build_reply(
111 frame.correlation_id,
112 MessageKind::Bye,
113 vec![],
114 )?);
115 let _ = stream.write_all(&bye).await;
116 return Ok(());
117 }
118 MessageKind::Ping => {
119 let pong = encode_frame(&build_reply(
120 frame.correlation_id,
121 MessageKind::Pong,
122 vec![],
123 )?);
124 stream.write_all(&pong).await?;
125 }
126 MessageKind::Query => {
127 let response = run_query(&runtime, &frame);
128 stream.write_all(&encode_frame(&response)).await?;
129 }
130 MessageKind::QueryWithParams => {
131 let response = run_query_with_params(&runtime, &frame);
132 stream.write_all(&encode_frame(&response)).await?;
133 }
134 MessageKind::BulkInsert => {
138 let response = run_insert_dispatch(&runtime, &frame);
139 stream.write_all(&encode_frame(&response)).await?;
140 }
141 MessageKind::BulkInsertBinary => {
145 let raw =
146 crate::wire::listener::handle_bulk_insert_binary(&runtime, &frame.payload);
147 stream
148 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
149 .await?;
150 }
151 MessageKind::BulkInsertPrevalidated => {
152 let raw = crate::wire::listener::handle_bulk_insert_binary_prevalidated(
153 &runtime,
154 &frame.payload,
155 );
156 stream
157 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
158 .await?;
159 }
160 MessageKind::QueryBinary => {
161 let raw = crate::wire::listener::handle_query_binary(&runtime, &frame.payload);
162 stream
163 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
164 .await?;
165 }
166 MessageKind::BulkStreamStart => {
169 let raw =
170 crate::wire::listener::handle_stream_start(&frame.payload, &mut stream_session);
171 stream
172 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
173 .await?;
174 }
175 MessageKind::BulkStreamRows => {
176 let raw = crate::wire::listener::handle_stream_rows(
177 &runtime,
178 &frame.payload,
179 &mut stream_session,
180 );
181 if !raw.is_empty() {
187 stream
188 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
189 .await?;
190 }
191 }
192 MessageKind::BulkStreamCommit => {
193 let raw =
194 crate::wire::listener::handle_stream_commit(&runtime, &mut stream_session);
195 stream
196 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
197 .await?;
198 }
199 MessageKind::Prepare => {
203 let raw = crate::wire::listener::handle_prepare(
204 &runtime,
205 &frame.payload,
206 &mut prepared_stmts,
207 );
208 stream
209 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
210 .await?;
211 }
212 MessageKind::ExecutePrepared => {
213 let raw = crate::wire::listener::handle_execute_prepared(
214 &runtime,
215 &frame.payload,
216 &prepared_stmts,
217 );
218 stream
219 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
220 .await?;
221 }
222 MessageKind::Get => {
223 let response = run_get(&runtime, &frame);
224 stream.write_all(&encode_frame(&response)).await?;
225 }
226 MessageKind::Delete => {
227 let response = run_delete(&runtime, &frame);
228 stream.write_all(&encode_frame(&response)).await?;
229 }
230 other => {
231 let err_frame = FrameBuilder::reply_to(frame.correlation_id)
232 .kind(MessageKind::Error)
233 .payload(format!("redwire: cannot dispatch {other:?} yet").into_bytes())
234 .build()
235 .map_err(|e| io::Error::other(format!("build Error frame: {e}")))?;
236 let err = encode_frame(&err_frame);
237 stream.write_all(&err).await?;
238 }
239 }
240 }
241}
242
243async fn perform_handshake<S>(
246 stream: &mut S,
247 runtime: &RedDBRuntime,
248 auth_store: Option<&AuthStore>,
249 oauth: Option<&crate::auth::oauth::OAuthValidator>,
250) -> io::Result<Option<AuthedSession>>
251where
252 S: AsyncRead + AsyncWrite + Unpin + Send,
253{
254 let mut minor_buf = [0u8; 1];
256 stream.read_exact(&mut minor_buf).await?;
257 let minor = minor_buf[0];
258 if minor > MAX_KNOWN_MINOR_VERSION {
259 return Ok(None);
263 }
264
265 let hello = read_frame(stream).await?;
267 if hello.kind != MessageKind::Hello {
268 let fail = encode_frame(&build_reply(
269 hello.correlation_id,
270 MessageKind::AuthFail,
271 build_auth_fail("first frame after magic must be Hello"),
272 )?);
273 let _ = stream.write_all(&fail).await;
274 return Ok(None);
275 }
276 let hello_msg = match Hello::from_payload(&hello.payload) {
277 Ok(h) => h,
278 Err(e) => {
279 let fail = encode_frame(&build_reply(
280 hello.correlation_id,
281 MessageKind::AuthFail,
282 build_auth_fail(&e),
283 )?);
284 let _ = stream.write_all(&fail).await;
285 return Ok(None);
286 }
287 };
288
289 let chosen_version = hello_msg
290 .versions
291 .iter()
292 .copied()
293 .filter(|v| *v <= MAX_KNOWN_MINOR_VERSION)
294 .max()
295 .unwrap_or(0);
296 if chosen_version == 0 {
297 let fail = encode_frame(&build_reply(
298 hello.correlation_id,
299 MessageKind::AuthFail,
300 build_auth_fail("no overlapping protocol version"),
301 )?);
302 let _ = stream.write_all(&fail).await;
303 return Ok(None);
304 }
305
306 let server_anon_ok = auth_store.map(|s| !s.is_enabled()).unwrap_or(true);
307 let chosen = match pick_auth_method(&hello_msg.auth_methods, server_anon_ok) {
308 Some(m) => m,
309 None => {
310 let fail = encode_frame(&build_reply(
311 hello.correlation_id,
312 MessageKind::AuthFail,
313 build_auth_fail("no overlapping auth method"),
314 )?);
315 let _ = stream.write_all(&fail).await;
316 return Ok(None);
317 }
318 };
319
320 let server_features = FEATURE_PARAMS;
329 let topology = build_topology_for_hello_ack(runtime);
330 let ack_frame = FrameBuilder::reply_to(hello.correlation_id)
331 .kind(MessageKind::HelloAck)
332 .payload(build_hello_ack(
333 chosen_version,
334 chosen,
335 server_features,
336 topology.as_ref(),
337 ))
338 .build()
339 .map_err(|e| io::Error::other(format!("build HelloAck: {e}")))?;
340 let ack = encode_frame(&ack_frame);
341 stream.write_all(&ack).await?;
342
343 if chosen == "scram-sha-256" {
347 return perform_scram_handshake(stream, auth_store, hello.correlation_id, server_features)
348 .await;
349 }
350
351 let resp = read_frame(stream).await?;
354 if resp.kind != MessageKind::AuthResponse {
355 let fail = encode_frame(&build_reply(
356 resp.correlation_id,
357 MessageKind::AuthFail,
358 build_auth_fail("expected AuthResponse"),
359 )?);
360 let _ = stream.write_all(&fail).await;
361 return Ok(None);
362 }
363
364 if chosen == "oauth-jwt" {
368 let validator = match oauth {
369 Some(v) => v,
370 None => {
371 let fail = encode_frame(&build_reply(
372 resp.correlation_id,
373 MessageKind::AuthFail,
374 build_auth_fail("oauth-jwt requires RedWireConfig.oauth"),
375 )?);
376 let _ = stream.write_all(&fail).await;
377 return Ok(None);
378 }
379 };
380 let raw = match crate::serde_json::from_slice::<JsonValue>(&resp.payload)
381 .ok()
382 .and_then(|v| {
383 v.as_object()
384 .and_then(|o| o.get("jwt").cloned())
385 .and_then(|x| x.as_str().map(String::from))
386 }) {
387 Some(s) if !s.is_empty() => s,
388 _ => {
389 let fail = encode_frame(&build_reply(
390 resp.correlation_id,
391 MessageKind::AuthFail,
392 build_auth_fail("oauth-jwt: AuthResponse missing 'jwt' string"),
393 )?);
394 let _ = stream.write_all(&fail).await;
395 return Ok(None);
396 }
397 };
398 match super::auth::validate_oauth_jwt(validator, &raw) {
399 Ok((username, role)) => {
400 let session_id = super::auth::new_session_id_for_scram();
401 let ok = encode_frame(&build_reply(
402 resp.correlation_id,
403 MessageKind::AuthOk,
404 build_auth_ok(&session_id, &username, role, server_features),
405 )?);
406 stream.write_all(&ok).await?;
407 return Ok(Some(AuthedSession {
408 username,
409 session_id,
410 }));
411 }
412 Err(reason) => {
413 let fail = encode_frame(&build_reply(
414 resp.correlation_id,
415 MessageKind::AuthFail,
416 build_auth_fail(&format!("oauth-jwt: {reason}")),
417 )?);
418 let _ = stream.write_all(&fail).await;
419 return Ok(None);
420 }
421 }
422 }
423
424 match validate_auth_response(chosen, &resp.payload, auth_store) {
425 AuthOutcome::Authenticated {
426 username,
427 role,
428 session_id,
429 } => {
430 let ok_frame = FrameBuilder::reply_to(resp.correlation_id)
431 .kind(MessageKind::AuthOk)
432 .payload(build_auth_ok(&session_id, &username, role, server_features))
433 .build()
434 .map_err(|e| io::Error::other(format!("build AuthOk: {e}")))?;
435 let ok = encode_frame(&ok_frame);
436 stream.write_all(&ok).await?;
437 Ok(Some(AuthedSession {
438 username,
439 session_id,
440 }))
441 }
442 AuthOutcome::Refused(reason) => {
443 let fail = encode_frame(&build_reply(
444 resp.correlation_id,
445 MessageKind::AuthFail,
446 build_auth_fail(&reason),
447 )?);
448 let _ = stream.write_all(&fail).await;
449 Ok(None)
450 }
451 }
452}
453
454async fn perform_scram_handshake<S>(
461 stream: &mut S,
462 auth_store: Option<&AuthStore>,
463 initial_correlation: u64,
464 server_features: u32,
465) -> io::Result<Option<AuthedSession>>
466where
467 S: AsyncRead + AsyncWrite + Unpin + Send,
468{
469 let store = match auth_store {
470 Some(s) => s,
471 None => {
472 let fail = encode_frame(&build_reply(
473 initial_correlation,
474 MessageKind::AuthFail,
475 build_auth_fail("scram-sha-256 requires an AuthStore"),
476 )?);
477 let _ = stream.write_all(&fail).await;
478 return Ok(None);
479 }
480 };
481
482 let cf = read_frame(stream).await?;
484 if cf.kind != MessageKind::AuthResponse {
485 let fail = encode_frame(&build_reply(
486 cf.correlation_id,
487 MessageKind::AuthFail,
488 build_auth_fail("expected AuthResponse(client-first-message)"),
489 )?);
490 let _ = stream.write_all(&fail).await;
491 return Ok(None);
492 }
493 let (username, client_nonce, client_first_bare) =
494 match super::auth::parse_scram_client_first(&cf.payload) {
495 Ok(t) => t,
496 Err(e) => {
497 let fail = encode_frame(&build_reply(
498 cf.correlation_id,
499 MessageKind::AuthFail,
500 build_auth_fail(&format!("scram client-first: {e}")),
501 )?);
502 let _ = stream.write_all(&fail).await;
503 return Ok(None);
504 }
505 };
506
507 let verifier = store.lookup_scram_verifier_global(&username);
516 let (salt, iter, stored_key, server_key, user_known) = match &verifier {
517 Some(v) => (v.salt.clone(), v.iter, v.stored_key, v.server_key, true),
518 None => (
519 crate::auth::store::random_bytes(16),
520 crate::auth::scram::DEFAULT_ITER,
521 [0u8; 32],
522 [0u8; 32],
523 false,
524 ),
525 };
526
527 let server_nonce = super::auth::new_server_nonce();
529 let server_first =
530 super::auth::build_scram_server_first(&client_nonce, &server_nonce, &salt, iter);
531 let req = encode_frame(&build_reply(
532 cf.correlation_id,
533 MessageKind::AuthRequest,
534 server_first.as_bytes().to_vec(),
535 )?);
536 stream.write_all(&req).await?;
537
538 let cfinal = read_frame(stream).await?;
540 if cfinal.kind != MessageKind::AuthResponse {
541 let fail = encode_frame(&build_reply(
542 cfinal.correlation_id,
543 MessageKind::AuthFail,
544 build_auth_fail("expected AuthResponse(client-final-message)"),
545 )?);
546 let _ = stream.write_all(&fail).await;
547 return Ok(None);
548 }
549 let (combined_nonce, presented_proof, client_final_no_proof) =
550 match super::auth::parse_scram_client_final(&cfinal.payload) {
551 Ok(t) => t,
552 Err(e) => {
553 let fail = encode_frame(&build_reply(
554 cfinal.correlation_id,
555 MessageKind::AuthFail,
556 build_auth_fail(&format!("scram client-final: {e}")),
557 )?);
558 let _ = stream.write_all(&fail).await;
559 return Ok(None);
560 }
561 };
562 let expected_combined = format!("{client_nonce}{server_nonce}");
563 if combined_nonce != expected_combined {
564 let fail = encode_frame(&build_reply(
565 cfinal.correlation_id,
566 MessageKind::AuthFail,
567 build_auth_fail("scram nonce mismatch — replay protection failed"),
568 )?);
569 let _ = stream.write_all(&fail).await;
570 return Ok(None);
571 }
572
573 let auth_message =
575 crate::auth::scram::auth_message(&client_first_bare, &server_first, &client_final_no_proof);
576 let proof_ok = if user_known {
577 let v = crate::auth::scram::ScramVerifier {
578 salt: salt.clone(),
579 iter,
580 stored_key,
581 server_key,
582 };
583 crate::auth::scram::verify_client_proof(&v, &auth_message, &presented_proof)
584 } else {
585 false
586 };
587 if !proof_ok {
588 let fail = encode_frame(&build_reply(
589 cfinal.correlation_id,
590 MessageKind::AuthFail,
591 build_auth_fail("invalid SCRAM proof"),
592 )?);
593 let _ = stream.write_all(&fail).await;
594 return Ok(None);
595 }
596
597 let role = store
599 .list_users()
600 .into_iter()
601 .find(|u| u.username == username)
602 .map(|u| u.role)
603 .unwrap_or(crate::auth::Role::Read);
604 let server_sig = crate::auth::scram::server_signature(&server_key, &auth_message);
605 let session_id = super::auth::new_session_id_for_scram();
606 let ok_payload = super::auth::build_scram_auth_ok(
607 &session_id,
608 &username,
609 role,
610 server_features,
611 &server_sig,
612 );
613 let ok = encode_frame(&build_reply(
614 cfinal.correlation_id,
615 MessageKind::AuthOk,
616 ok_payload,
617 )?);
618 stream.write_all(&ok).await?;
619 Ok(Some(AuthedSession {
620 username,
621 session_id,
622 }))
623}
624
625async fn read_frame<S>(stream: &mut S) -> io::Result<Frame>
626where
627 S: AsyncRead + AsyncWrite + Unpin + Send,
628{
629 let mut header = [0u8; FRAME_HEADER_SIZE];
630 stream.read_exact(&mut header).await?;
631 let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
632 if length < FRAME_HEADER_SIZE || length > super::frame::MAX_FRAME_SIZE as usize {
633 return Err(io::Error::other(format!(
634 "redwire frame length {length} out of range"
635 )));
636 }
637 let mut buf = vec![0u8; length];
638 buf[..FRAME_HEADER_SIZE].copy_from_slice(&header);
639 if length > FRAME_HEADER_SIZE {
640 stream
641 .read_exact(&mut buf[FRAME_HEADER_SIZE..length])
642 .await?;
643 }
644 let (frame, _) =
645 decode_frame(&buf).map_err(|e| io::Error::other(format!("decode frame: {e}")))?;
646 Ok(frame)
647}
648
649fn run_query(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
650 let sql = match std::str::from_utf8(&frame.payload) {
651 Ok(s) => s,
652 Err(_) => {
653 return error_frame(frame.correlation_id, "Query payload must be UTF-8 SQL");
654 }
655 };
656 match runtime.execute_query(sql) {
657 Ok(result) => {
658 let mut obj = crate::serde_json::Map::new();
659 obj.insert("ok".to_string(), JsonValue::Bool(true));
660 obj.insert(
661 "statement".to_string(),
662 JsonValue::String(result.statement_type.to_string()),
663 );
664 obj.insert(
665 "affected".to_string(),
666 JsonValue::Number(result.affected_rows as f64),
667 );
668 let payload = serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default();
669 build_dispatch_reply(frame.correlation_id, MessageKind::Result, payload)
670 }
671 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
672 }
673}
674
675fn run_query_with_params(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
676 let (sql, params) = match decode_query_with_params(&frame.payload) {
677 Ok(decoded) => decoded,
678 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
679 };
680 let params = params
681 .into_iter()
682 .map(param_to_schema_value)
683 .collect::<Vec<_>>();
684 let parsed = match crate::storage::query::modes::parse_multi(&sql) {
685 Ok(parsed) => parsed,
686 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
687 };
688 let bound = match crate::storage::query::user_params::bind(&parsed, ¶ms) {
689 Ok(bound) => bound,
690 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
691 };
692 match runtime.execute_query_expr(bound) {
693 Ok(result) => {
694 let is_mutation = matches!(result.statement_type, "insert" | "update" | "delete");
695 if is_mutation {
696 let post_lsn = runtime.cdc_current_lsn();
697 if let Err(err) = runtime.enforce_commit_policy(post_lsn) {
698 return error_frame(frame.correlation_id, &err.to_string());
699 }
700 }
701 let payload = serde_json::to_vec(
702 &crate::presentation::query_result_json::runtime_query_json(&result, &None, &None),
703 )
704 .unwrap_or_default();
705 build_dispatch_reply(frame.correlation_id, MessageKind::Result, payload)
706 }
707 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
708 }
709}
710
711fn param_to_schema_value(value: RedWireParamValue) -> crate::storage::schema::Value {
712 use crate::storage::schema::Value;
713 match value {
714 RedWireParamValue::Null => Value::Null,
715 RedWireParamValue::Bool(value) => Value::Boolean(value),
716 RedWireParamValue::Int(value) => Value::Integer(value),
717 RedWireParamValue::Float(value) => Value::Float(value),
718 RedWireParamValue::Text(value) => Value::Text(Arc::from(value.as_str())),
719 RedWireParamValue::Bytes(value) => Value::Blob(value),
720 RedWireParamValue::Vector(value) => Value::Vector(value),
721 RedWireParamValue::Json(value) => Value::Json(value),
722 RedWireParamValue::Timestamp(value) => Value::Timestamp(value),
723 RedWireParamValue::Uuid(value) => Value::Uuid(value),
724 }
725}
726
727fn run_insert_dispatch(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
745 let v: JsonValue = match serde_json::from_slice(&frame.payload) {
746 Ok(v) => v,
747 Err(e) => return error_frame(frame.correlation_id, &format!("Insert: invalid JSON: {e}")),
748 };
749 let obj = match v.as_object() {
750 Some(o) => o,
751 None => {
752 return error_frame(
753 frame.correlation_id,
754 "Insert: payload must be a JSON object",
755 )
756 }
757 };
758 let collection = match obj.get("collection").and_then(|x| x.as_str()) {
759 Some(s) if !s.is_empty() => s,
760 _ => return error_frame(frame.correlation_id, "Insert: missing 'collection' string"),
761 };
762
763 let idempotency_key = obj.get("idempotency_key").and_then(|x| x.as_str());
769 let batch_flag = obj.get("batch").and_then(|x| x.as_bool()).unwrap_or(false);
770 if idempotency_key.is_some() || batch_flag {
771 let items = match obj.get("payloads").and_then(|x| x.as_array()) {
772 Some(rows) => rows,
773 None => {
774 return error_frame(
775 frame.correlation_id,
776 "BatchInsert: missing 'payloads' array",
777 )
778 }
779 };
780 let outcome = crate::server::handlers_entity::process_batch_insert(
781 runtime,
782 collection,
783 items,
784 idempotency_key,
785 );
786 let kind = if (200..300).contains(&outcome.status) {
791 MessageKind::BulkOk
792 } else {
793 MessageKind::Error
794 };
795 return build_dispatch_reply(frame.correlation_id, kind, outcome.body);
796 }
797
798 if let Some(rows) = obj.get("payloads").and_then(|x| x.as_array()) {
799 let mut objects = Vec::with_capacity(rows.len());
800 for entry in rows {
801 objects.push(match entry.as_object() {
802 Some(o) => o,
803 None => {
804 return error_frame(
805 frame.correlation_id,
806 "Insert: each payload must be a JSON object",
807 )
808 }
809 });
810 }
811
812 if crate::rpc_stdio::should_bulk_insert_graph(runtime, collection, &objects) {
813 return match crate::rpc_stdio::bulk_insert_graph(runtime, collection, &objects) {
814 Ok(body) => {
815 let payload = serde_json::to_vec(&body).unwrap_or_default();
816 build_dispatch_reply(frame.correlation_id, MessageKind::BulkOk, payload)
817 }
818 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
819 };
820 }
821
822 let mut affected: u64 = 0;
823 let mut ids = Vec::with_capacity(objects.len());
824 for row in objects {
825 let sql = crate::rpc_stdio::build_insert_sql(collection, row.iter());
826 match runtime.execute_query(&sql) {
827 Ok(qr) => {
828 affected += qr.affected_rows;
829 if let Some(id) = crate::rpc_stdio::insert_result_to_json(&qr).get("id") {
830 ids.push(id.clone());
831 }
832 }
833 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
834 }
835 }
836 let mut out = crate::serde_json::Map::new();
837 out.insert("affected".to_string(), JsonValue::Number(affected as f64));
838 out.insert("ids".to_string(), JsonValue::Array(ids));
839 let payload = serde_json::to_vec(&JsonValue::Object(out)).unwrap_or_default();
840 return build_dispatch_reply(frame.correlation_id, MessageKind::BulkOk, payload);
841 }
842
843 let row = match obj.get("payload").and_then(|x| x.as_object()) {
844 Some(o) => o,
845 None => {
846 return error_frame(
847 frame.correlation_id,
848 "Insert: missing 'payload' object or 'payloads' array",
849 )
850 }
851 };
852 let sql = crate::rpc_stdio::build_insert_sql(collection, row.iter());
853 match runtime.execute_query(&sql) {
854 Ok(qr) => {
855 let body = crate::rpc_stdio::insert_result_to_json(&qr);
856 let payload = serde_json::to_vec(&body).unwrap_or_default();
857 build_dispatch_reply(frame.correlation_id, MessageKind::BulkOk, payload)
858 }
859 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
860 }
861}
862
863fn build_topology_for_hello_ack(runtime: &RedDBRuntime) -> Option<reddb_wire::topology::Topology> {
875 use crate::auth::middleware::AuthResult;
876 use crate::replication::{LagConfig, TopologyAdvertiser};
877 use reddb_wire::topology::Endpoint;
878
879 let db = runtime.db();
880 let primary_endpoint = Endpoint {
881 addr: runtime.config_string("red.redwire.advertise_addr", ""),
882 region: db.options().replication.region.clone(),
883 };
884 let (replicas, current_lsn, epoch) = match db.replication.as_ref() {
885 Some(repl) => (
886 repl.replica_snapshots(),
887 repl.wal_buffer.current_lsn(),
888 repl.topology_epoch(),
889 ),
890 None => (Vec::new(), 0u64, 0u64),
891 };
892 let lag = LagConfig::from_now();
893 Some(TopologyAdvertiser::advertise(
894 &replicas,
895 &AuthResult::Anonymous,
896 epoch,
897 primary_endpoint,
898 current_lsn,
899 &lag,
900 ))
901}
902
903fn error_frame(correlation_id: u64, msg: &str) -> Frame {
904 FrameBuilder::reply_to(correlation_id)
909 .kind(MessageKind::Error)
910 .payload(msg.as_bytes().to_vec())
911 .build()
912 .expect("error frame fits in MAX_FRAME_SIZE")
913}
914
915fn build_reply(correlation_id: u64, kind: MessageKind, payload: Vec<u8>) -> io::Result<Frame> {
924 FrameBuilder::reply_to(correlation_id)
925 .kind(kind)
926 .payload(payload)
927 .build()
928 .map_err(|e| io::Error::other(format!("build {kind:?}: {e}")))
929}
930
931fn build_dispatch_reply(correlation_id: u64, kind: MessageKind, payload: Vec<u8>) -> Frame {
938 FrameBuilder::reply_to(correlation_id)
939 .kind(kind)
940 .payload(payload)
941 .build()
942 .unwrap_or_else(|e| error_frame(correlation_id, &e.to_string()))
943}
944
945fn rewrap_handler_response(raw_bytes: &[u8], req: &Frame) -> Frame {
957 if raw_bytes.len() < 5 {
958 return error_frame(
959 req.correlation_id,
960 "fast-path handler returned a truncated frame",
961 );
962 }
963 let kind_byte = raw_bytes[4];
964 let kind = MessageKind::from_u8(kind_byte).unwrap_or(MessageKind::Error);
965 let body = raw_bytes[5..].to_vec();
966 build_dispatch_reply(req.correlation_id, kind, body)
967}
968
969fn run_get(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
973 let v: JsonValue = match serde_json::from_slice(&frame.payload) {
974 Ok(v) => v,
975 Err(e) => return error_frame(frame.correlation_id, &format!("Get: invalid JSON: {e}")),
976 };
977 let obj = match v.as_object() {
978 Some(o) => o,
979 None => return error_frame(frame.correlation_id, "Get: payload must be a JSON object"),
980 };
981 let collection = match obj.get("collection").and_then(|x| x.as_str()) {
982 Some(s) if !s.is_empty() => s,
983 _ => return error_frame(frame.correlation_id, "Get: missing 'collection' string"),
984 };
985 let id = match obj.get("id").and_then(|x| x.as_str()) {
986 Some(s) if !s.is_empty() => s,
987 _ => return error_frame(frame.correlation_id, "Get: missing 'id' string"),
988 };
989 let id_lit = crate::rpc_stdio::value_to_sql_literal(&JsonValue::String(id.to_string()));
992 let sql = format!("SELECT * FROM {collection} WHERE _id = {id_lit} LIMIT 1");
993 match runtime.execute_query(&sql) {
994 Ok(qr) => {
995 let mut out = crate::serde_json::Map::new();
996 out.insert("ok".to_string(), JsonValue::Bool(true));
997 out.insert(
998 "found".to_string(),
999 JsonValue::Bool(!qr.result.records.is_empty()),
1000 );
1001 let payload = serde_json::to_vec(&JsonValue::Object(out)).unwrap_or_default();
1004 build_dispatch_reply(frame.correlation_id, MessageKind::Result, payload)
1005 }
1006 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
1007 }
1008}
1009
1010fn run_delete(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
1014 let v: JsonValue = match serde_json::from_slice(&frame.payload) {
1015 Ok(v) => v,
1016 Err(e) => return error_frame(frame.correlation_id, &format!("Delete: invalid JSON: {e}")),
1017 };
1018 let obj = match v.as_object() {
1019 Some(o) => o,
1020 None => {
1021 return error_frame(
1022 frame.correlation_id,
1023 "Delete: payload must be a JSON object",
1024 )
1025 }
1026 };
1027 let collection = match obj.get("collection").and_then(|x| x.as_str()) {
1028 Some(s) if !s.is_empty() => s,
1029 _ => return error_frame(frame.correlation_id, "Delete: missing 'collection' string"),
1030 };
1031 let id = match obj.get("id").and_then(|x| x.as_str()) {
1032 Some(s) if !s.is_empty() => s,
1033 _ => return error_frame(frame.correlation_id, "Delete: missing 'id' string"),
1034 };
1035 let id_lit = crate::rpc_stdio::value_to_sql_literal(&JsonValue::String(id.to_string()));
1036 let sql = format!("DELETE FROM {collection} WHERE _id = {id_lit}");
1037 match runtime.execute_query(&sql) {
1038 Ok(qr) => {
1039 let mut out = crate::serde_json::Map::new();
1040 out.insert(
1041 "affected".to_string(),
1042 JsonValue::Number(qr.affected_rows as f64),
1043 );
1044 let payload = serde_json::to_vec(&JsonValue::Object(out)).unwrap_or_default();
1045 build_dispatch_reply(frame.correlation_id, MessageKind::DeleteOk, payload)
1046 }
1047 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
1048 }
1049}
1050
1051#[cfg(test)]
1052mod tests {
1053 use super::*;
1054
1055 use crate::runtime::RedDBRuntime;
1056 use tokio::io::{AsyncReadExt, AsyncWriteExt};
1057
1058 fn create_graph_collection(runtime: &RedDBRuntime, name: &str) {
1059 let db = runtime.db();
1060 db.store()
1061 .create_collection(name)
1062 .expect("create collection");
1063 let now = std::time::SystemTime::now()
1064 .duration_since(std::time::UNIX_EPOCH)
1065 .unwrap_or_default()
1066 .as_millis();
1067 db.save_collection_contract(crate::physical::CollectionContract {
1068 name: name.to_string(),
1069 declared_model: crate::catalog::CollectionModel::Graph,
1070 schema_mode: crate::catalog::SchemaMode::Dynamic,
1071 origin: crate::physical::ContractOrigin::Explicit,
1072 version: 1,
1073 created_at_unix_ms: now,
1074 updated_at_unix_ms: now,
1075 default_ttl_ms: None,
1076 vector_dimension: None,
1077 vector_metric: None,
1078 context_index_fields: Vec::new(),
1079 declared_columns: Vec::new(),
1080 table_def: None,
1081 timestamps_enabled: false,
1082 context_index_enabled: false,
1083 metrics_raw_retention_ms: None,
1084 metrics_rollup_policies: Vec::new(),
1085 metrics_tenant_identity: None,
1086 metrics_namespace: None,
1087 append_only: false,
1088 subscriptions: Vec::new(),
1089 session_key: None,
1090 session_gap_ms: None,
1091 retention_duration_ms: None,
1092 })
1093 .expect("save graph contract");
1094 }
1095
1096 #[test]
1097 fn magic_byte_is_0xfe() {
1098 assert_eq!(REDWIRE_MAGIC, 0xFE);
1099 }
1100
1101 #[test]
1102 fn redwire_bulk_insert_graph_rows_returns_ids() {
1103 let runtime = RedDBRuntime::in_memory().expect("runtime");
1104 create_graph_collection(&runtime, "network");
1105
1106 let nodes = Frame::new(
1107 MessageKind::BulkInsert,
1108 7,
1109 br#"{"collection":"network","payloads":[{"label":"Host","name":"app"},{"label":"Host","name":"db"}]}"#.to_vec(),
1110 );
1111 let nodes_reply = run_insert_dispatch(&runtime, &nodes);
1112 assert_eq!(nodes_reply.kind, MessageKind::BulkOk);
1113 let node_body: JsonValue =
1114 serde_json::from_slice(&nodes_reply.payload).expect("nodes json");
1115 assert_eq!(
1116 node_body.get("affected").and_then(JsonValue::as_u64),
1117 Some(2)
1118 );
1119 let ids = node_body
1120 .get("ids")
1121 .and_then(JsonValue::as_array)
1122 .expect("node ids");
1123 assert_eq!(ids.len(), 2);
1124
1125 let from = ids[0].as_u64().expect("from id");
1126 let to = ids[1].as_u64().expect("to id");
1127 let edges = Frame::new(
1128 MessageKind::BulkInsert,
1129 8,
1130 format!(
1131 r#"{{"collection":"network","payloads":[{{"label":"connects","from":{from},"to":{to},"role":"primary"}}]}}"#
1132 )
1133 .into_bytes(),
1134 );
1135 let edges_reply = run_insert_dispatch(&runtime, &edges);
1136 assert_eq!(edges_reply.kind, MessageKind::BulkOk);
1137 let edge_body: JsonValue =
1138 serde_json::from_slice(&edges_reply.payload).expect("edges json");
1139 assert_eq!(
1140 edge_body.get("affected").and_then(JsonValue::as_u64),
1141 Some(1)
1142 );
1143 assert_eq!(
1144 edge_body
1145 .get("ids")
1146 .and_then(JsonValue::as_array)
1147 .map(|ids| ids.len()),
1148 Some(1)
1149 );
1150 }
1151
1152 #[test]
1167 fn redwire_batch_insert_happy_path_returns_bulkok_with_count() {
1168 let runtime = RedDBRuntime::in_memory().expect("runtime");
1169 runtime
1170 .execute_query("CREATE TABLE events_587_ok (id INTEGER, name TEXT)")
1171 .expect("create table");
1172
1173 let frame = Frame::new(
1174 MessageKind::BulkInsert,
1175 100,
1176 br#"{
1177 "collection":"events_587_ok",
1178 "idempotency_key":"k-ok",
1179 "payloads":[
1180 {"fields":{"id":1,"name":"a"}},
1181 {"fields":{"id":2,"name":"b"}},
1182 {"fields":{"id":3,"name":"c"}}
1183 ]
1184 }"#
1185 .to_vec(),
1186 );
1187 let reply = run_insert_dispatch(&runtime, &frame);
1188 assert_eq!(
1189 reply.kind,
1190 MessageKind::BulkOk,
1191 "body={:?}",
1192 String::from_utf8_lossy(&reply.payload)
1193 );
1194 let body: JsonValue = serde_json::from_slice(&reply.payload).expect("ok body json");
1195 assert_eq!(body.get("ok").and_then(JsonValue::as_bool), Some(true));
1196 assert_eq!(body.get("count").and_then(JsonValue::as_u64), Some(3));
1197
1198 let qr = runtime
1204 .execute_query("SELECT name FROM events_587_ok ORDER BY id ASC")
1205 .expect("scan");
1206 let names: Vec<String> = qr
1207 .result
1208 .records
1209 .iter()
1210 .filter_map(|record| match record.get("name") {
1211 Some(crate::storage::schema::Value::Text(s)) => Some(s.to_string()),
1212 _ => None,
1213 })
1214 .collect();
1215 assert_eq!(names, vec!["a", "b", "c"]);
1216 }
1217
1218 #[test]
1223 fn redwire_batch_insert_row_failure_rolls_back_with_row_index() {
1224 let runtime = RedDBRuntime::in_memory().expect("runtime");
1225 runtime
1226 .execute_query("CREATE TABLE events_587_rollback (id INTEGER, name TEXT)")
1227 .expect("create table");
1228
1229 let frame = Frame::new(
1232 MessageKind::BulkInsert,
1233 101,
1234 br#"{
1235 "collection":"events_587_rollback",
1236 "idempotency_key":"k-rollback",
1237 "payloads":[
1238 {"fields":{"id":1,"name":"a"}},
1239 {"not_fields":{"id":2}},
1240 {"fields":{"id":3,"name":"c"}}
1241 ]
1242 }"#
1243 .to_vec(),
1244 );
1245 let reply = run_insert_dispatch(&runtime, &frame);
1246 assert_eq!(reply.kind, MessageKind::Error);
1247 let body: JsonValue = serde_json::from_slice(&reply.payload).expect("err body json");
1248 assert_eq!(body.get("ok").and_then(JsonValue::as_bool), Some(false));
1249 assert_eq!(
1250 body.get("code").and_then(JsonValue::as_str),
1251 Some("RowParseFailure")
1252 );
1253 assert_eq!(body.get("row_index").and_then(JsonValue::as_u64), Some(1));
1254
1255 let qr = runtime
1258 .execute_query("SELECT name FROM events_587_rollback")
1259 .expect("scan");
1260 assert!(
1261 qr.result.records.is_empty(),
1262 "row 0 leaked despite row 1 rejection: {} rows present",
1263 qr.result.records.len()
1264 );
1265 }
1266
1267 #[test]
1275 fn redwire_batch_insert_idempotency_key_replays_cached_result() {
1276 let runtime = RedDBRuntime::in_memory().expect("runtime");
1277 runtime
1278 .execute_query("CREATE TABLE events_587_dedup (id INTEGER, name TEXT)")
1279 .expect("create table");
1280
1281 let key = format!(
1285 "redwire-587-{}",
1286 std::time::SystemTime::now()
1287 .duration_since(std::time::UNIX_EPOCH)
1288 .unwrap()
1289 .as_nanos()
1290 );
1291
1292 let frame1 = Frame::new(
1293 MessageKind::BulkInsert,
1294 200,
1295 format!(
1296 r#"{{
1297 "collection":"events_587_dedup",
1298 "idempotency_key":"{key}",
1299 "payloads":[{{"fields":{{"id":1,"name":"first"}}}}]
1300 }}"#
1301 )
1302 .into_bytes(),
1303 );
1304 let reply1 = run_insert_dispatch(&runtime, &frame1);
1305 assert_eq!(reply1.kind, MessageKind::BulkOk);
1306 let body1 = reply1.payload.clone();
1307
1308 let frame2 = Frame::new(
1312 MessageKind::BulkInsert,
1313 201,
1314 format!(
1315 r#"{{
1316 "collection":"events_587_dedup",
1317 "idempotency_key":"{key}",
1318 "payloads":[{{"fields":{{"id":2,"name":"second"}}}}]
1319 }}"#
1320 )
1321 .into_bytes(),
1322 );
1323 let reply2 = run_insert_dispatch(&runtime, &frame2);
1324 assert_eq!(reply2.kind, MessageKind::BulkOk);
1325 assert_eq!(
1326 reply2.payload, body1,
1327 "replay must return cached body byte-for-byte"
1328 );
1329
1330 let qr = runtime
1331 .execute_query("SELECT name FROM events_587_dedup")
1332 .expect("scan");
1333 assert_eq!(
1334 qr.result.records.len(),
1335 1,
1336 "replay re-executed and committed the second row"
1337 );
1338 }
1339
1340 #[test]
1344 fn redwire_batch_insert_cache_shared_with_http_transport() {
1345 use crate::runtime::batch_insert::global_cache;
1346
1347 let runtime = RedDBRuntime::in_memory().expect("runtime");
1348 runtime
1349 .execute_query("CREATE TABLE events_587_shared (id INTEGER, name TEXT)")
1350 .expect("create table");
1351
1352 let key = format!(
1353 "shared-cache-587-{}",
1354 std::time::SystemTime::now()
1355 .duration_since(std::time::UNIX_EPOCH)
1356 .unwrap()
1357 .as_nanos()
1358 );
1359
1360 let frame = Frame::new(
1361 MessageKind::BulkInsert,
1362 300,
1363 format!(
1364 r#"{{
1365 "collection":"events_587_shared",
1366 "idempotency_key":"{key}",
1367 "payloads":[{{"fields":{{"id":1,"name":"x"}}}}]
1368 }}"#
1369 )
1370 .into_bytes(),
1371 );
1372 let reply = run_insert_dispatch(&runtime, &frame);
1373 assert_eq!(reply.kind, MessageKind::BulkOk);
1374
1375 let hit = global_cache()
1379 .lookup("events_587_shared", &key, std::time::Instant::now())
1380 .expect("shared cache must serve the RedWire write to HTTP");
1381 assert_eq!(hit.status, 200);
1382 assert_eq!(hit.body, reply.payload);
1383 }
1384
1385 #[test]
1390 fn redwire_batch_insert_schema_validation_rejects_unknown_field() {
1391 use crate::runtime::analytics_schema_registry as reg;
1392
1393 let runtime = RedDBRuntime::in_memory().expect("runtime");
1394 runtime
1395 .execute_query("CREATE TABLE events_587_schema (event_name TEXT, payload TEXT)")
1396 .expect("create table");
1397
1398 let schema =
1399 r#"{"type":"object","properties":{"url":{"type":"string"}},"required":["url"]}"#;
1400 reg::register(runtime.db().store().as_ref(), "click_587", schema).expect("register schema");
1401
1402 let frame = Frame::new(
1403 MessageKind::BulkInsert,
1404 400,
1405 br#"{
1406 "collection":"events_587_schema",
1407 "idempotency_key":"k-schema",
1408 "payloads":[
1409 {"fields":{"event_name":"click_587","payload":"{\"url\":\"/a\"}"}},
1410 {"fields":{"event_name":"click_587","payload":"{\"url\":\"/b\",\"extra\":1}"}}
1411 ]
1412 }"#
1413 .to_vec(),
1414 );
1415 let reply = run_insert_dispatch(&runtime, &frame);
1416 assert_eq!(reply.kind, MessageKind::Error);
1417 let body: JsonValue = serde_json::from_slice(&reply.payload).expect("err body json");
1418 assert_eq!(
1419 body.get("code").and_then(JsonValue::as_str),
1420 Some("RowSchemaRejected")
1421 );
1422 assert_eq!(body.get("row_index").and_then(JsonValue::as_u64), Some(1));
1423
1424 let qr = runtime
1425 .execute_query("SELECT event_name FROM events_587_schema")
1426 .expect("scan");
1427 assert!(
1428 qr.result.records.is_empty(),
1429 "row 0 leaked despite row 1 schema rejection"
1430 );
1431 }
1432
1433 #[test]
1444 fn redwire_batch_insert_oversize_returns_error_before_storage() {
1445 let runtime = RedDBRuntime::in_memory().expect("runtime");
1446 runtime
1447 .execute_query("CREATE TABLE events_587_oversize (id INTEGER, name TEXT)")
1448 .expect("create table");
1449
1450 let max = 10_000usize;
1452 let mut payloads = String::with_capacity(max * 32);
1453 payloads.push('[');
1454 for i in 0..(max + 1) {
1455 if i > 0 {
1456 payloads.push(',');
1457 }
1458 payloads.push_str(&format!(r#"{{"fields":{{"id":{i},"name":"x"}}}}"#));
1459 }
1460 payloads.push(']');
1461 let frame_body = format!(
1462 r#"{{"collection":"events_587_oversize","idempotency_key":"k-oversize-587","payloads":{payloads}}}"#
1463 );
1464 let frame = Frame::new(MessageKind::BulkInsert, 500, frame_body.into_bytes());
1465 let reply = run_insert_dispatch(&runtime, &frame);
1466
1467 assert_eq!(reply.kind, MessageKind::Error);
1468 let body: JsonValue = serde_json::from_slice(&reply.payload).expect("err body json");
1469 assert_eq!(
1470 body.get("code").and_then(JsonValue::as_str),
1471 Some("BatchTooLarge")
1472 );
1473 let qr = runtime
1474 .execute_query("SELECT name FROM events_587_oversize")
1475 .expect("scan");
1476 assert!(
1477 qr.result.records.is_empty(),
1478 "oversize batch leaked rows into storage"
1479 );
1480 }
1481
1482 async fn read_one_frame<R: tokio::io::AsyncRead + Unpin>(r: &mut R) -> Frame {
1484 let mut header = [0u8; FRAME_HEADER_SIZE];
1485 r.read_exact(&mut header).await.expect("read header");
1486 let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
1487 let mut buf = vec![0u8; length];
1488 buf[..FRAME_HEADER_SIZE].copy_from_slice(&header);
1489 if length > FRAME_HEADER_SIZE {
1490 r.read_exact(&mut buf[FRAME_HEADER_SIZE..])
1491 .await
1492 .expect("read body");
1493 }
1494 let (frame, _) = decode_frame(&buf).expect("decode");
1495 frame
1496 }
1497
1498 fn stream_start_payload(coll: &str, cols: &[&str]) -> Vec<u8> {
1499 let mut p = Vec::new();
1500 p.extend_from_slice(&(coll.len() as u16).to_le_bytes());
1501 p.extend_from_slice(coll.as_bytes());
1502 p.extend_from_slice(&(cols.len() as u16).to_le_bytes());
1503 for c in cols {
1504 p.extend_from_slice(&(c.len() as u16).to_le_bytes());
1505 p.extend_from_slice(c.as_bytes());
1506 }
1507 p
1508 }
1509
1510 fn stream_rows_payload(rows: &[(i64, &str)]) -> Vec<u8> {
1511 let mut p = Vec::new();
1512 p.extend_from_slice(&(rows.len() as u32).to_le_bytes());
1513 for (id, name) in rows {
1514 crate::wire::protocol::encode_value(
1515 &mut p,
1516 &crate::storage::schema::Value::Integer(*id),
1517 );
1518 crate::wire::protocol::encode_value(
1519 &mut p,
1520 &crate::storage::schema::Value::text(name.to_string()),
1521 );
1522 }
1523 p
1524 }
1525
1526 #[tokio::test]
1533 async fn bulk_stream_rows_success_emits_no_response_frame() {
1534 let runtime = std::sync::Arc::new(RedDBRuntime::in_memory().expect("runtime"));
1536 runtime
1537 .execute_query("CREATE TABLE target (id INT, name TEXT)")
1538 .expect("create table");
1539
1540 let (server_io, mut client) = tokio::io::duplex(64 * 1024);
1543
1544 let server_task = tokio::spawn(async move {
1545 let _ = handle_session(server_io, runtime, None, None).await;
1546 });
1547
1548 client.write_all(&[1u8]).await.expect("write minor");
1551
1552 let hello_payload =
1554 br#"{"versions":[1],"auth_methods":["anonymous"],"features":0,"client_name":"test"}"#
1555 .to_vec();
1556 let hello = encode_frame(&Frame::new(MessageKind::Hello, 1, hello_payload));
1557 client.write_all(&hello).await.expect("write hello");
1558
1559 let ack = read_one_frame(&mut client).await;
1561 assert_eq!(ack.kind, MessageKind::HelloAck);
1562
1563 let authresp = encode_frame(&Frame::new(MessageKind::AuthResponse, 2, b"{}".to_vec()));
1565 client.write_all(&authresp).await.expect("write authresp");
1566
1567 let auth_ok = read_one_frame(&mut client).await;
1569 assert_eq!(auth_ok.kind, MessageKind::AuthOk);
1570
1571 let start = encode_frame(&Frame::new(
1573 MessageKind::BulkStreamStart,
1574 3,
1575 stream_start_payload("target", &["id", "name"]),
1576 ));
1577 client.write_all(&start).await.expect("write start");
1578 let start_ack = read_one_frame(&mut client).await;
1579 assert_eq!(start_ack.kind, MessageKind::BulkStreamAck);
1580 assert_eq!(start_ack.correlation_id, 3);
1581
1582 let rows = encode_frame(&Frame::new(
1584 MessageKind::BulkStreamRows,
1585 4,
1586 stream_rows_payload(&[(1, "a"), (2, "b")]),
1587 ));
1588 client.write_all(&rows).await.expect("write rows");
1589
1590 let commit = encode_frame(&Frame::new(MessageKind::BulkStreamCommit, 5, vec![]));
1596 client.write_all(&commit).await.expect("write commit");
1597
1598 let next = read_one_frame(&mut client).await;
1599 assert_eq!(
1600 next.kind,
1601 MessageKind::BulkOk,
1602 "expected BulkOk after commit; got {:?} — BulkStreamRows leaked an ack frame",
1603 next.kind
1604 );
1605 assert_eq!(
1606 next.correlation_id, 5,
1607 "commit response must carry the commit's correlation id"
1608 );
1609
1610 let bye = encode_frame(&Frame::new(MessageKind::Bye, 6, vec![]));
1612 client.write_all(&bye).await.expect("write bye");
1613 let _ = read_one_frame(&mut client).await; drop(client);
1615 let _ = server_task.await;
1616 }
1617
1618 #[tokio::test]
1621 async fn bulk_stream_rows_error_still_emits_error_frame() {
1622 let runtime = std::sync::Arc::new(RedDBRuntime::in_memory().expect("runtime"));
1623 let (server_io, mut client) = tokio::io::duplex(64 * 1024);
1624
1625 let server_task = tokio::spawn(async move {
1626 let _ = handle_session(server_io, runtime, None, None).await;
1627 });
1628
1629 client.write_all(&[1u8]).await.unwrap();
1630 let hello_payload =
1631 br#"{"versions":[1],"auth_methods":["anonymous"],"features":0}"#.to_vec();
1632 client
1633 .write_all(&encode_frame(&Frame::new(
1634 MessageKind::Hello,
1635 1,
1636 hello_payload,
1637 )))
1638 .await
1639 .unwrap();
1640 let _ack = read_one_frame(&mut client).await;
1641 client
1642 .write_all(&encode_frame(&Frame::new(
1643 MessageKind::AuthResponse,
1644 2,
1645 b"{}".to_vec(),
1646 )))
1647 .await
1648 .unwrap();
1649 let _auth_ok = read_one_frame(&mut client).await;
1650
1651 let rows = encode_frame(&Frame::new(
1655 MessageKind::BulkStreamRows,
1656 7,
1657 stream_rows_payload(&[(1, "a")]),
1658 ));
1659 client.write_all(&rows).await.unwrap();
1660 let resp = read_one_frame(&mut client).await;
1661 assert_eq!(resp.kind, MessageKind::Error);
1662 assert_eq!(resp.correlation_id, 7);
1663
1664 drop(client);
1665 let _ = server_task.await;
1666 }
1667}