1use std::io;
13use std::sync::Arc;
14
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16
17use crate::auth::store::AuthStore;
18use crate::runtime::RedDBRuntime;
19use crate::serde_json::{self, Value as JsonValue};
20use reddb_wire::query_with_params::{
21 decode_query_with_params, ParamValue as RedWireParamValue, FEATURE_PARAMS,
22};
23
24use super::auth::{
25 build_auth_fail, build_auth_ok, build_hello_ack, pick_auth_method, validate_auth_response,
26 AuthOutcome, Hello,
27};
28use super::codec::{decode_frame, encode_frame};
29use super::frame::{Frame, MessageDirection, MessageKind, FRAME_HEADER_SIZE};
30use super::{FrameBuilder, MAX_KNOWN_MINOR_VERSION, REDWIRE_MAGIC};
31
/// Identity established by a successful RedWire handshake.
///
/// Produced by `perform_handshake` / `perform_scram_handshake` once the
/// client has been sent `AuthOk`; those functions return `Ok(None)` instead
/// when the handshake ends without authenticating the client.
#[derive(Debug)]
struct AuthedSession {
    /// Authenticated username; the same value is echoed in the AuthOk payload.
    username: String,
    // NOTE(review): `session_id` is not read anywhere in the code visible
    // here (the session is bound as `_session`); the allow silences the
    // unused-field lint.
    #[allow(dead_code)]
    /// Session identifier minted during the handshake and sent in AuthOk.
    session_id: String,
}
38
39pub async fn handle_session<S>(
40 mut stream: S,
41 runtime: Arc<RedDBRuntime>,
42 auth_store: Option<Arc<AuthStore>>,
43 oauth: Option<Arc<crate::auth::oauth::OAuthValidator>>,
44) -> io::Result<()>
45where
46 S: AsyncRead + AsyncWrite + Unpin + Send,
47{
48 let session = perform_handshake(
52 &mut stream,
53 runtime.as_ref(),
54 auth_store.as_deref(),
55 oauth.as_deref(),
56 )
57 .await?;
58 if session.is_none() {
59 return Ok(());
60 }
61 let _session = session.unwrap();
62
63 let mut stream_session: Option<crate::wire::listener::BulkStreamSession> = None;
66 let mut prepared_stmts: std::collections::HashMap<u32, crate::wire::listener::PreparedStmt> =
67 std::collections::HashMap::new();
68
69 let mut buf = vec![0u8; FRAME_HEADER_SIZE];
70 loop {
71 if let Err(err) = stream.read_exact(&mut buf[..FRAME_HEADER_SIZE]).await {
73 if err.kind() == io::ErrorKind::UnexpectedEof {
74 return Ok(());
75 }
76 return Err(err);
77 }
78 let length = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
79 if length < FRAME_HEADER_SIZE || length > super::frame::MAX_FRAME_SIZE as usize {
80 return Err(io::Error::other(format!("invalid frame length {length}")));
81 }
82 if buf.len() < length {
83 buf.resize(length, 0);
84 }
85 let payload_len = length - FRAME_HEADER_SIZE;
86 if payload_len > 0 {
87 stream
88 .read_exact(&mut buf[FRAME_HEADER_SIZE..length])
89 .await?;
90 }
91 let (frame, _) = decode_frame(&buf[..length])
92 .map_err(|e| io::Error::other(format!("decode frame: {e}")))?;
93
94 if frame.kind.direction() == MessageDirection::ServerToClient {
99 let err_frame = FrameBuilder::reply_to(frame.correlation_id)
100 .kind(MessageKind::Error)
101 .payload(format!("redwire: {:?} is server-only", frame.kind).into_bytes())
102 .build()
103 .map_err(|e| io::Error::other(format!("build Error frame: {e}")))?;
104 stream.write_all(&encode_frame(&err_frame)).await?;
105 continue;
106 }
107
108 match frame.kind {
109 MessageKind::Bye => {
110 let bye = encode_frame(&build_reply(
111 frame.correlation_id,
112 MessageKind::Bye,
113 vec![],
114 )?);
115 let _ = stream.write_all(&bye).await;
116 return Ok(());
117 }
118 MessageKind::Ping => {
119 let pong = encode_frame(&build_reply(
120 frame.correlation_id,
121 MessageKind::Pong,
122 vec![],
123 )?);
124 stream.write_all(&pong).await?;
125 }
126 MessageKind::Query => {
127 let response = run_query(&runtime, &frame);
128 stream.write_all(&encode_frame(&response)).await?;
129 }
130 MessageKind::QueryWithParams => {
131 let response = run_query_with_params(&runtime, &frame);
132 stream.write_all(&encode_frame(&response)).await?;
133 }
134 MessageKind::BulkInsert => {
138 let response = run_insert_dispatch(&runtime, &frame);
139 stream.write_all(&encode_frame(&response)).await?;
140 }
141 MessageKind::BulkInsertBinary => {
145 let raw =
146 crate::wire::listener::handle_bulk_insert_binary(&runtime, &frame.payload);
147 stream
148 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
149 .await?;
150 }
151 MessageKind::BulkInsertPrevalidated => {
152 let raw = crate::wire::listener::handle_bulk_insert_binary_prevalidated(
153 &runtime,
154 &frame.payload,
155 );
156 stream
157 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
158 .await?;
159 }
160 MessageKind::QueryBinary => {
161 let raw = crate::wire::listener::handle_query_binary(&runtime, &frame.payload);
162 stream
163 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
164 .await?;
165 }
166 MessageKind::BulkStreamStart => {
169 let raw =
170 crate::wire::listener::handle_stream_start(&frame.payload, &mut stream_session);
171 stream
172 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
173 .await?;
174 }
175 MessageKind::BulkStreamRows => {
176 let raw = crate::wire::listener::handle_stream_rows(
177 &runtime,
178 &frame.payload,
179 &mut stream_session,
180 );
181 if !raw.is_empty() {
187 stream
188 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
189 .await?;
190 }
191 }
192 MessageKind::BulkStreamCommit => {
193 let raw =
194 crate::wire::listener::handle_stream_commit(&runtime, &mut stream_session);
195 stream
196 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
197 .await?;
198 }
199 MessageKind::Prepare => {
203 let raw = crate::wire::listener::handle_prepare(
204 &runtime,
205 &frame.payload,
206 &mut prepared_stmts,
207 );
208 stream
209 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
210 .await?;
211 }
212 MessageKind::ExecutePrepared => {
213 let raw = crate::wire::listener::handle_execute_prepared(
214 &runtime,
215 &frame.payload,
216 &prepared_stmts,
217 );
218 stream
219 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
220 .await?;
221 }
222 MessageKind::Get => {
223 let response = run_get(&runtime, &frame);
224 stream.write_all(&encode_frame(&response)).await?;
225 }
226 MessageKind::Delete => {
227 let response = run_delete(&runtime, &frame);
228 stream.write_all(&encode_frame(&response)).await?;
229 }
230 other => {
231 let err_frame = FrameBuilder::reply_to(frame.correlation_id)
232 .kind(MessageKind::Error)
233 .payload(format!("redwire: cannot dispatch {other:?} yet").into_bytes())
234 .build()
235 .map_err(|e| io::Error::other(format!("build Error frame: {e}")))?;
236 let err = encode_frame(&err_frame);
237 stream.write_all(&err).await?;
238 }
239 }
240 }
241}
242
243async fn perform_handshake<S>(
246 stream: &mut S,
247 runtime: &RedDBRuntime,
248 auth_store: Option<&AuthStore>,
249 oauth: Option<&crate::auth::oauth::OAuthValidator>,
250) -> io::Result<Option<AuthedSession>>
251where
252 S: AsyncRead + AsyncWrite + Unpin + Send,
253{
254 let mut minor_buf = [0u8; 1];
256 stream.read_exact(&mut minor_buf).await?;
257 let minor = minor_buf[0];
258 if minor > MAX_KNOWN_MINOR_VERSION {
259 return Ok(None);
263 }
264
265 let hello = read_frame(stream).await?;
267 if hello.kind != MessageKind::Hello {
268 let fail = encode_frame(&build_reply(
269 hello.correlation_id,
270 MessageKind::AuthFail,
271 build_auth_fail("first frame after magic must be Hello"),
272 )?);
273 let _ = stream.write_all(&fail).await;
274 return Ok(None);
275 }
276 let hello_msg = match Hello::from_payload(&hello.payload) {
277 Ok(h) => h,
278 Err(e) => {
279 let fail = encode_frame(&build_reply(
280 hello.correlation_id,
281 MessageKind::AuthFail,
282 build_auth_fail(&e),
283 )?);
284 let _ = stream.write_all(&fail).await;
285 return Ok(None);
286 }
287 };
288
289 let chosen_version = hello_msg
290 .versions
291 .iter()
292 .copied()
293 .filter(|v| *v <= MAX_KNOWN_MINOR_VERSION)
294 .max()
295 .unwrap_or(0);
296 if chosen_version == 0 {
297 let fail = encode_frame(&build_reply(
298 hello.correlation_id,
299 MessageKind::AuthFail,
300 build_auth_fail("no overlapping protocol version"),
301 )?);
302 let _ = stream.write_all(&fail).await;
303 return Ok(None);
304 }
305
306 let server_anon_ok = auth_store.map(|s| !s.is_enabled()).unwrap_or(true);
307 let chosen = match pick_auth_method(&hello_msg.auth_methods, server_anon_ok) {
308 Some(m) => m,
309 None => {
310 let fail = encode_frame(&build_reply(
311 hello.correlation_id,
312 MessageKind::AuthFail,
313 build_auth_fail("no overlapping auth method"),
314 )?);
315 let _ = stream.write_all(&fail).await;
316 return Ok(None);
317 }
318 };
319
320 let server_features = FEATURE_PARAMS;
329 let topology = build_topology_for_hello_ack(runtime);
330 let ack_frame = FrameBuilder::reply_to(hello.correlation_id)
331 .kind(MessageKind::HelloAck)
332 .payload(build_hello_ack(
333 chosen_version,
334 chosen,
335 server_features,
336 topology.as_ref(),
337 ))
338 .build()
339 .map_err(|e| io::Error::other(format!("build HelloAck: {e}")))?;
340 let ack = encode_frame(&ack_frame);
341 stream.write_all(&ack).await?;
342
343 if chosen == "scram-sha-256" {
347 return perform_scram_handshake(stream, auth_store, hello.correlation_id, server_features)
348 .await;
349 }
350
351 let resp = read_frame(stream).await?;
354 if resp.kind != MessageKind::AuthResponse {
355 let fail = encode_frame(&build_reply(
356 resp.correlation_id,
357 MessageKind::AuthFail,
358 build_auth_fail("expected AuthResponse"),
359 )?);
360 let _ = stream.write_all(&fail).await;
361 return Ok(None);
362 }
363
364 if chosen == "oauth-jwt" {
368 let validator = match oauth {
369 Some(v) => v,
370 None => {
371 let fail = encode_frame(&build_reply(
372 resp.correlation_id,
373 MessageKind::AuthFail,
374 build_auth_fail("oauth-jwt requires RedWireConfig.oauth"),
375 )?);
376 let _ = stream.write_all(&fail).await;
377 return Ok(None);
378 }
379 };
380 let raw = match crate::serde_json::from_slice::<JsonValue>(&resp.payload)
381 .ok()
382 .and_then(|v| {
383 v.as_object()
384 .and_then(|o| o.get("jwt").cloned())
385 .and_then(|x| x.as_str().map(String::from))
386 }) {
387 Some(s) if !s.is_empty() => s,
388 _ => {
389 let fail = encode_frame(&build_reply(
390 resp.correlation_id,
391 MessageKind::AuthFail,
392 build_auth_fail("oauth-jwt: AuthResponse missing 'jwt' string"),
393 )?);
394 let _ = stream.write_all(&fail).await;
395 return Ok(None);
396 }
397 };
398 match super::auth::validate_oauth_jwt(validator, &raw) {
399 Ok((username, role)) => {
400 let session_id = super::auth::new_session_id_for_scram();
401 let ok = encode_frame(&build_reply(
402 resp.correlation_id,
403 MessageKind::AuthOk,
404 build_auth_ok(&session_id, &username, role, server_features),
405 )?);
406 stream.write_all(&ok).await?;
407 return Ok(Some(AuthedSession {
408 username,
409 session_id,
410 }));
411 }
412 Err(reason) => {
413 let fail = encode_frame(&build_reply(
414 resp.correlation_id,
415 MessageKind::AuthFail,
416 build_auth_fail(&format!("oauth-jwt: {reason}")),
417 )?);
418 let _ = stream.write_all(&fail).await;
419 return Ok(None);
420 }
421 }
422 }
423
424 match validate_auth_response(chosen, &resp.payload, auth_store) {
425 AuthOutcome::Authenticated {
426 username,
427 role,
428 session_id,
429 } => {
430 let ok_frame = FrameBuilder::reply_to(resp.correlation_id)
431 .kind(MessageKind::AuthOk)
432 .payload(build_auth_ok(&session_id, &username, role, server_features))
433 .build()
434 .map_err(|e| io::Error::other(format!("build AuthOk: {e}")))?;
435 let ok = encode_frame(&ok_frame);
436 stream.write_all(&ok).await?;
437 Ok(Some(AuthedSession {
438 username,
439 session_id,
440 }))
441 }
442 AuthOutcome::Refused(reason) => {
443 let fail = encode_frame(&build_reply(
444 resp.correlation_id,
445 MessageKind::AuthFail,
446 build_auth_fail(&reason),
447 )?);
448 let _ = stream.write_all(&fail).await;
449 Ok(None)
450 }
451 }
452}
453
454async fn perform_scram_handshake<S>(
461 stream: &mut S,
462 auth_store: Option<&AuthStore>,
463 initial_correlation: u64,
464 server_features: u32,
465) -> io::Result<Option<AuthedSession>>
466where
467 S: AsyncRead + AsyncWrite + Unpin + Send,
468{
469 let store = match auth_store {
470 Some(s) => s,
471 None => {
472 let fail = encode_frame(&build_reply(
473 initial_correlation,
474 MessageKind::AuthFail,
475 build_auth_fail("scram-sha-256 requires an AuthStore"),
476 )?);
477 let _ = stream.write_all(&fail).await;
478 return Ok(None);
479 }
480 };
481
482 let cf = read_frame(stream).await?;
484 if cf.kind != MessageKind::AuthResponse {
485 let fail = encode_frame(&build_reply(
486 cf.correlation_id,
487 MessageKind::AuthFail,
488 build_auth_fail("expected AuthResponse(client-first-message)"),
489 )?);
490 let _ = stream.write_all(&fail).await;
491 return Ok(None);
492 }
493 let (username, client_nonce, client_first_bare) =
494 match super::auth::parse_scram_client_first(&cf.payload) {
495 Ok(t) => t,
496 Err(e) => {
497 let fail = encode_frame(&build_reply(
498 cf.correlation_id,
499 MessageKind::AuthFail,
500 build_auth_fail(&format!("scram client-first: {e}")),
501 )?);
502 let _ = stream.write_all(&fail).await;
503 return Ok(None);
504 }
505 };
506
507 let verifier = store.lookup_scram_verifier_global(&username);
516 let (salt, iter, stored_key, server_key, user_known) = match &verifier {
517 Some(v) => (v.salt.clone(), v.iter, v.stored_key, v.server_key, true),
518 None => (
519 crate::auth::store::random_bytes(16),
520 crate::auth::scram::DEFAULT_ITER,
521 [0u8; 32],
522 [0u8; 32],
523 false,
524 ),
525 };
526
527 let server_nonce = super::auth::new_server_nonce();
529 let server_first =
530 super::auth::build_scram_server_first(&client_nonce, &server_nonce, &salt, iter);
531 let req = encode_frame(&build_reply(
532 cf.correlation_id,
533 MessageKind::AuthRequest,
534 server_first.as_bytes().to_vec(),
535 )?);
536 stream.write_all(&req).await?;
537
538 let cfinal = read_frame(stream).await?;
540 if cfinal.kind != MessageKind::AuthResponse {
541 let fail = encode_frame(&build_reply(
542 cfinal.correlation_id,
543 MessageKind::AuthFail,
544 build_auth_fail("expected AuthResponse(client-final-message)"),
545 )?);
546 let _ = stream.write_all(&fail).await;
547 return Ok(None);
548 }
549 let (combined_nonce, presented_proof, client_final_no_proof) =
550 match super::auth::parse_scram_client_final(&cfinal.payload) {
551 Ok(t) => t,
552 Err(e) => {
553 let fail = encode_frame(&build_reply(
554 cfinal.correlation_id,
555 MessageKind::AuthFail,
556 build_auth_fail(&format!("scram client-final: {e}")),
557 )?);
558 let _ = stream.write_all(&fail).await;
559 return Ok(None);
560 }
561 };
562 let expected_combined = format!("{client_nonce}{server_nonce}");
563 if combined_nonce != expected_combined {
564 let fail = encode_frame(&build_reply(
565 cfinal.correlation_id,
566 MessageKind::AuthFail,
567 build_auth_fail("scram nonce mismatch — replay protection failed"),
568 )?);
569 let _ = stream.write_all(&fail).await;
570 return Ok(None);
571 }
572
573 let auth_message =
575 crate::auth::scram::auth_message(&client_first_bare, &server_first, &client_final_no_proof);
576 let proof_ok = if user_known {
577 let v = crate::auth::scram::ScramVerifier {
578 salt: salt.clone(),
579 iter,
580 stored_key,
581 server_key,
582 };
583 crate::auth::scram::verify_client_proof(&v, &auth_message, &presented_proof)
584 } else {
585 false
586 };
587 if !proof_ok {
588 let fail = encode_frame(&build_reply(
589 cfinal.correlation_id,
590 MessageKind::AuthFail,
591 build_auth_fail("invalid SCRAM proof"),
592 )?);
593 let _ = stream.write_all(&fail).await;
594 return Ok(None);
595 }
596
597 let role = store
599 .list_users()
600 .into_iter()
601 .find(|u| u.username == username)
602 .map(|u| u.role)
603 .unwrap_or(crate::auth::Role::Read);
604 let server_sig = crate::auth::scram::server_signature(&server_key, &auth_message);
605 let session_id = super::auth::new_session_id_for_scram();
606 let ok_payload = super::auth::build_scram_auth_ok(
607 &session_id,
608 &username,
609 role,
610 server_features,
611 &server_sig,
612 );
613 let ok = encode_frame(&build_reply(
614 cfinal.correlation_id,
615 MessageKind::AuthOk,
616 ok_payload,
617 )?);
618 stream.write_all(&ok).await?;
619 Ok(Some(AuthedSession {
620 username,
621 session_id,
622 }))
623}
624
625async fn read_frame<S>(stream: &mut S) -> io::Result<Frame>
626where
627 S: AsyncRead + AsyncWrite + Unpin + Send,
628{
629 let mut header = [0u8; FRAME_HEADER_SIZE];
630 stream.read_exact(&mut header).await?;
631 let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
632 if length < FRAME_HEADER_SIZE || length > super::frame::MAX_FRAME_SIZE as usize {
633 return Err(io::Error::other(format!(
634 "redwire frame length {length} out of range"
635 )));
636 }
637 let mut buf = vec![0u8; length];
638 buf[..FRAME_HEADER_SIZE].copy_from_slice(&header);
639 if length > FRAME_HEADER_SIZE {
640 stream
641 .read_exact(&mut buf[FRAME_HEADER_SIZE..length])
642 .await?;
643 }
644 let (frame, _) =
645 decode_frame(&buf).map_err(|e| io::Error::other(format!("decode frame: {e}")))?;
646 Ok(frame)
647}
648
649fn run_query(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
650 let sql = match std::str::from_utf8(&frame.payload) {
651 Ok(s) => s,
652 Err(_) => {
653 return error_frame(frame.correlation_id, "Query payload must be UTF-8 SQL");
654 }
655 };
656 match runtime.execute_query(sql) {
657 Ok(result) => {
658 let mut obj = crate::serde_json::Map::new();
659 obj.insert("ok".to_string(), JsonValue::Bool(true));
660 obj.insert(
661 "statement".to_string(),
662 JsonValue::String(result.statement_type.to_string()),
663 );
664 obj.insert(
665 "affected".to_string(),
666 JsonValue::Number(result.affected_rows as f64),
667 );
668 let payload = serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default();
669 build_dispatch_reply(frame.correlation_id, MessageKind::Result, payload)
670 }
671 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
672 }
673}
674
675fn run_query_with_params(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
676 let (sql, params) = match decode_query_with_params(&frame.payload) {
677 Ok(decoded) => decoded,
678 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
679 };
680 let params = params
681 .into_iter()
682 .map(param_to_schema_value)
683 .collect::<Vec<_>>();
684 let parsed = match crate::storage::query::modes::parse_multi(&sql) {
685 Ok(parsed) => parsed,
686 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
687 };
688 let bound = match crate::storage::query::user_params::bind(&parsed, ¶ms) {
689 Ok(bound) => bound,
690 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
691 };
692 match runtime.execute_query_expr(bound) {
693 Ok(result) => {
694 let is_mutation = matches!(result.statement_type, "insert" | "update" | "delete");
695 if is_mutation {
696 let post_lsn = runtime.cdc_current_lsn();
697 if let Err(err) = runtime.enforce_commit_policy(post_lsn) {
698 return error_frame(frame.correlation_id, &err.to_string());
699 }
700 }
701 let payload = serde_json::to_vec(
702 &crate::presentation::query_result_json::runtime_query_json(&result, &None, &None),
703 )
704 .unwrap_or_default();
705 build_dispatch_reply(frame.correlation_id, MessageKind::Result, payload)
706 }
707 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
708 }
709}
710
711fn param_to_schema_value(value: RedWireParamValue) -> crate::storage::schema::Value {
712 use crate::storage::schema::Value;
713 match value {
714 RedWireParamValue::Null => Value::Null,
715 RedWireParamValue::Bool(value) => Value::Boolean(value),
716 RedWireParamValue::Int(value) => Value::Integer(value),
717 RedWireParamValue::Float(value) => Value::Float(value),
718 RedWireParamValue::Text(value) => Value::Text(Arc::from(value.as_str())),
719 RedWireParamValue::Bytes(value) => Value::Blob(value),
720 RedWireParamValue::Vector(value) => Value::Vector(value),
721 RedWireParamValue::Json(value) => Value::Json(value),
722 RedWireParamValue::Timestamp(value) => Value::Timestamp(value),
723 RedWireParamValue::Uuid(value) => Value::Uuid(value),
724 }
725}
726
727fn run_insert_dispatch(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
745 let v: JsonValue = match serde_json::from_slice(&frame.payload) {
746 Ok(v) => v,
747 Err(e) => return error_frame(frame.correlation_id, &format!("Insert: invalid JSON: {e}")),
748 };
749 let obj = match v.as_object() {
750 Some(o) => o,
751 None => {
752 return error_frame(
753 frame.correlation_id,
754 "Insert: payload must be a JSON object",
755 )
756 }
757 };
758 let collection = match obj.get("collection").and_then(|x| x.as_str()) {
759 Some(s) if !s.is_empty() => s,
760 _ => return error_frame(frame.correlation_id, "Insert: missing 'collection' string"),
761 };
762
763 let idempotency_key = obj.get("idempotency_key").and_then(|x| x.as_str());
769 let batch_flag = obj
770 .get("batch")
771 .and_then(|x| x.as_bool())
772 .unwrap_or(false);
773 if idempotency_key.is_some() || batch_flag {
774 let items = match obj.get("payloads").and_then(|x| x.as_array()) {
775 Some(rows) => &rows[..],
776 None => {
777 return error_frame(
778 frame.correlation_id,
779 "BatchInsert: missing 'payloads' array",
780 )
781 }
782 };
783 let outcome = crate::server::handlers_entity::process_batch_insert(
784 runtime,
785 collection,
786 items,
787 idempotency_key,
788 );
789 let kind = if (200..300).contains(&outcome.status) {
794 MessageKind::BulkOk
795 } else {
796 MessageKind::Error
797 };
798 return build_dispatch_reply(frame.correlation_id, kind, outcome.body);
799 }
800
801 if let Some(rows) = obj.get("payloads").and_then(|x| x.as_array()) {
802 let mut objects = Vec::with_capacity(rows.len());
803 for entry in rows {
804 objects.push(match entry.as_object() {
805 Some(o) => o,
806 None => {
807 return error_frame(
808 frame.correlation_id,
809 "Insert: each payload must be a JSON object",
810 )
811 }
812 });
813 }
814
815 if crate::rpc_stdio::should_bulk_insert_graph(runtime, collection, &objects) {
816 return match crate::rpc_stdio::bulk_insert_graph(runtime, collection, &objects) {
817 Ok(body) => {
818 let payload = serde_json::to_vec(&body).unwrap_or_default();
819 build_dispatch_reply(frame.correlation_id, MessageKind::BulkOk, payload)
820 }
821 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
822 };
823 }
824
825 let mut affected: u64 = 0;
826 let mut ids = Vec::with_capacity(objects.len());
827 for row in objects {
828 let sql = crate::rpc_stdio::build_insert_sql(collection, row.iter());
829 match runtime.execute_query(&sql) {
830 Ok(qr) => {
831 affected += qr.affected_rows;
832 if let Some(id) = crate::rpc_stdio::insert_result_to_json(&qr).get("id") {
833 ids.push(id.clone());
834 }
835 }
836 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
837 }
838 }
839 let mut out = crate::serde_json::Map::new();
840 out.insert("affected".to_string(), JsonValue::Number(affected as f64));
841 out.insert("ids".to_string(), JsonValue::Array(ids));
842 let payload = serde_json::to_vec(&JsonValue::Object(out)).unwrap_or_default();
843 return build_dispatch_reply(frame.correlation_id, MessageKind::BulkOk, payload);
844 }
845
846 let row = match obj.get("payload").and_then(|x| x.as_object()) {
847 Some(o) => o,
848 None => {
849 return error_frame(
850 frame.correlation_id,
851 "Insert: missing 'payload' object or 'payloads' array",
852 )
853 }
854 };
855 let sql = crate::rpc_stdio::build_insert_sql(collection, row.iter());
856 match runtime.execute_query(&sql) {
857 Ok(qr) => {
858 let body = crate::rpc_stdio::insert_result_to_json(&qr);
859 let payload = serde_json::to_vec(&body).unwrap_or_default();
860 build_dispatch_reply(frame.correlation_id, MessageKind::BulkOk, payload)
861 }
862 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
863 }
864}
865
866fn build_topology_for_hello_ack(runtime: &RedDBRuntime) -> Option<reddb_wire::topology::Topology> {
878 use crate::auth::middleware::AuthResult;
879 use crate::replication::{LagConfig, TopologyAdvertiser};
880 use reddb_wire::topology::Endpoint;
881
882 let db = runtime.db();
883 let primary_endpoint = Endpoint {
884 addr: runtime.config_string("red.redwire.advertise_addr", ""),
885 region: db.options().replication.region.clone(),
886 };
887 let (replicas, current_lsn, epoch) = match db.replication.as_ref() {
888 Some(repl) => (
889 repl.replica_snapshots(),
890 repl.wal_buffer.current_lsn(),
891 repl.topology_epoch(),
892 ),
893 None => (Vec::new(), 0u64, 0u64),
894 };
895 let lag = LagConfig::from_now();
896 Some(TopologyAdvertiser::advertise(
897 &replicas,
898 &AuthResult::Anonymous,
899 epoch,
900 primary_endpoint,
901 current_lsn,
902 &lag,
903 ))
904}
905
906fn error_frame(correlation_id: u64, msg: &str) -> Frame {
907 FrameBuilder::reply_to(correlation_id)
912 .kind(MessageKind::Error)
913 .payload(msg.as_bytes().to_vec())
914 .build()
915 .expect("error frame fits in MAX_FRAME_SIZE")
916}
917
918fn build_reply(correlation_id: u64, kind: MessageKind, payload: Vec<u8>) -> io::Result<Frame> {
927 FrameBuilder::reply_to(correlation_id)
928 .kind(kind)
929 .payload(payload)
930 .build()
931 .map_err(|e| io::Error::other(format!("build {kind:?}: {e}")))
932}
933
934fn build_dispatch_reply(correlation_id: u64, kind: MessageKind, payload: Vec<u8>) -> Frame {
941 FrameBuilder::reply_to(correlation_id)
942 .kind(kind)
943 .payload(payload)
944 .build()
945 .unwrap_or_else(|e| error_frame(correlation_id, &e.to_string()))
946}
947
948fn rewrap_handler_response(raw_bytes: &[u8], req: &Frame) -> Frame {
960 if raw_bytes.len() < 5 {
961 return error_frame(
962 req.correlation_id,
963 "fast-path handler returned a truncated frame",
964 );
965 }
966 let kind_byte = raw_bytes[4];
967 let kind = MessageKind::from_u8(kind_byte).unwrap_or(MessageKind::Error);
968 let body = raw_bytes[5..].to_vec();
969 build_dispatch_reply(req.correlation_id, kind, body)
970}
971
972fn run_get(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
976 let v: JsonValue = match serde_json::from_slice(&frame.payload) {
977 Ok(v) => v,
978 Err(e) => return error_frame(frame.correlation_id, &format!("Get: invalid JSON: {e}")),
979 };
980 let obj = match v.as_object() {
981 Some(o) => o,
982 None => return error_frame(frame.correlation_id, "Get: payload must be a JSON object"),
983 };
984 let collection = match obj.get("collection").and_then(|x| x.as_str()) {
985 Some(s) if !s.is_empty() => s,
986 _ => return error_frame(frame.correlation_id, "Get: missing 'collection' string"),
987 };
988 let id = match obj.get("id").and_then(|x| x.as_str()) {
989 Some(s) if !s.is_empty() => s,
990 _ => return error_frame(frame.correlation_id, "Get: missing 'id' string"),
991 };
992 let id_lit = crate::rpc_stdio::value_to_sql_literal(&JsonValue::String(id.to_string()));
995 let sql = format!("SELECT * FROM {collection} WHERE _id = {id_lit} LIMIT 1");
996 match runtime.execute_query(&sql) {
997 Ok(qr) => {
998 let mut out = crate::serde_json::Map::new();
999 out.insert("ok".to_string(), JsonValue::Bool(true));
1000 out.insert(
1001 "found".to_string(),
1002 JsonValue::Bool(!qr.result.records.is_empty()),
1003 );
1004 let payload = serde_json::to_vec(&JsonValue::Object(out)).unwrap_or_default();
1007 build_dispatch_reply(frame.correlation_id, MessageKind::Result, payload)
1008 }
1009 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
1010 }
1011}
1012
1013fn run_delete(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
1017 let v: JsonValue = match serde_json::from_slice(&frame.payload) {
1018 Ok(v) => v,
1019 Err(e) => return error_frame(frame.correlation_id, &format!("Delete: invalid JSON: {e}")),
1020 };
1021 let obj = match v.as_object() {
1022 Some(o) => o,
1023 None => {
1024 return error_frame(
1025 frame.correlation_id,
1026 "Delete: payload must be a JSON object",
1027 )
1028 }
1029 };
1030 let collection = match obj.get("collection").and_then(|x| x.as_str()) {
1031 Some(s) if !s.is_empty() => s,
1032 _ => return error_frame(frame.correlation_id, "Delete: missing 'collection' string"),
1033 };
1034 let id = match obj.get("id").and_then(|x| x.as_str()) {
1035 Some(s) if !s.is_empty() => s,
1036 _ => return error_frame(frame.correlation_id, "Delete: missing 'id' string"),
1037 };
1038 let id_lit = crate::rpc_stdio::value_to_sql_literal(&JsonValue::String(id.to_string()));
1039 let sql = format!("DELETE FROM {collection} WHERE _id = {id_lit}");
1040 match runtime.execute_query(&sql) {
1041 Ok(qr) => {
1042 let mut out = crate::serde_json::Map::new();
1043 out.insert(
1044 "affected".to_string(),
1045 JsonValue::Number(qr.affected_rows as f64),
1046 );
1047 let payload = serde_json::to_vec(&JsonValue::Object(out)).unwrap_or_default();
1048 build_dispatch_reply(frame.correlation_id, MessageKind::DeleteOk, payload)
1049 }
1050 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
1051 }
1052}
1053
1054#[cfg(test)]
1055mod tests {
1056 use super::*;
1057
1058 use crate::runtime::RedDBRuntime;
1059 use tokio::io::{AsyncReadExt, AsyncWriteExt};
1060
1061 fn create_graph_collection(runtime: &RedDBRuntime, name: &str) {
1062 let db = runtime.db();
1063 db.store()
1064 .create_collection(name)
1065 .expect("create collection");
1066 let now = std::time::SystemTime::now()
1067 .duration_since(std::time::UNIX_EPOCH)
1068 .unwrap_or_default()
1069 .as_millis();
1070 db.save_collection_contract(crate::physical::CollectionContract {
1071 name: name.to_string(),
1072 declared_model: crate::catalog::CollectionModel::Graph,
1073 schema_mode: crate::catalog::SchemaMode::Dynamic,
1074 origin: crate::physical::ContractOrigin::Explicit,
1075 version: 1,
1076 created_at_unix_ms: now,
1077 updated_at_unix_ms: now,
1078 default_ttl_ms: None,
1079 vector_dimension: None,
1080 vector_metric: None,
1081 context_index_fields: Vec::new(),
1082 declared_columns: Vec::new(),
1083 table_def: None,
1084 timestamps_enabled: false,
1085 context_index_enabled: false,
1086 metrics_raw_retention_ms: None,
1087 metrics_rollup_policies: Vec::new(),
1088 metrics_tenant_identity: None,
1089 metrics_namespace: None,
1090 append_only: false,
1091 subscriptions: Vec::new(),
1092 session_key: None,
1093 session_gap_ms: None,
1094 retention_duration_ms: None,
1095 })
1096 .expect("save graph contract");
1097 }
1098
1099 #[test]
1100 fn magic_byte_is_0xfe() {
1101 assert_eq!(REDWIRE_MAGIC, 0xFE);
1102 }
1103
1104 #[test]
1105 fn redwire_bulk_insert_graph_rows_returns_ids() {
1106 let runtime = RedDBRuntime::in_memory().expect("runtime");
1107 create_graph_collection(&runtime, "network");
1108
1109 let nodes = Frame::new(
1110 MessageKind::BulkInsert,
1111 7,
1112 br#"{"collection":"network","payloads":[{"label":"Host","name":"app"},{"label":"Host","name":"db"}]}"#.to_vec(),
1113 );
1114 let nodes_reply = run_insert_dispatch(&runtime, &nodes);
1115 assert_eq!(nodes_reply.kind, MessageKind::BulkOk);
1116 let node_body: JsonValue =
1117 serde_json::from_slice(&nodes_reply.payload).expect("nodes json");
1118 assert_eq!(
1119 node_body.get("affected").and_then(JsonValue::as_u64),
1120 Some(2)
1121 );
1122 let ids = node_body
1123 .get("ids")
1124 .and_then(JsonValue::as_array)
1125 .expect("node ids");
1126 assert_eq!(ids.len(), 2);
1127
1128 let from = ids[0].as_u64().expect("from id");
1129 let to = ids[1].as_u64().expect("to id");
1130 let edges = Frame::new(
1131 MessageKind::BulkInsert,
1132 8,
1133 format!(
1134 r#"{{"collection":"network","payloads":[{{"label":"connects","from":{from},"to":{to},"role":"primary"}}]}}"#
1135 )
1136 .into_bytes(),
1137 );
1138 let edges_reply = run_insert_dispatch(&runtime, &edges);
1139 assert_eq!(edges_reply.kind, MessageKind::BulkOk);
1140 let edge_body: JsonValue =
1141 serde_json::from_slice(&edges_reply.payload).expect("edges json");
1142 assert_eq!(
1143 edge_body.get("affected").and_then(JsonValue::as_u64),
1144 Some(1)
1145 );
1146 assert_eq!(
1147 edge_body
1148 .get("ids")
1149 .and_then(JsonValue::as_array)
1150 .map(|ids| ids.len()),
1151 Some(1)
1152 );
1153 }
1154
/// A well-formed three-row batch into a table is acknowledged with
/// `ok: true` / `count: 3`, and all three rows are visible to a later scan.
#[test]
fn redwire_batch_insert_happy_path_returns_bulkok_with_count() {
    let runtime = RedDBRuntime::in_memory().expect("runtime");
    runtime
        .execute_query("CREATE TABLE events_587_ok (id INTEGER, name TEXT)")
        .expect("create table");

    let request = br#"{
        "collection":"events_587_ok",
        "idempotency_key":"k-ok",
        "payloads":[
            {"fields":{"id":1,"name":"a"}},
            {"fields":{"id":2,"name":"b"}},
            {"fields":{"id":3,"name":"c"}}
        ]
    }"#;
    let reply = run_insert_dispatch(
        &runtime,
        &Frame::new(MessageKind::BulkInsert, 100, request.to_vec()),
    );
    // Include the reply body in the failure message to make server errors readable.
    assert_eq!(
        reply.kind,
        MessageKind::BulkOk,
        "body={:?}",
        String::from_utf8_lossy(&reply.payload)
    );

    let ack: JsonValue = serde_json::from_slice(&reply.payload).expect("ok body json");
    assert_eq!(ack.get("ok").and_then(JsonValue::as_bool), Some(true));
    assert_eq!(ack.get("count").and_then(JsonValue::as_u64), Some(3));

    // Read rows back ordered by id; only text-valued `name` cells are collected.
    let scan = runtime
        .execute_query("SELECT name FROM events_587_ok ORDER BY id ASC")
        .expect("scan");
    let mut names: Vec<String> = Vec::new();
    for record in scan.result.records.iter() {
        if let Some(crate::storage::schema::Value::Text(text)) = record.get("name") {
            names.push(text.to_string());
        }
    }
    assert_eq!(names, vec!["a", "b", "c"]);
}
1216
/// When one row of a batch is malformed (row 1 carries `not_fields` instead of
/// `fields`), the batch fails with `RowParseFailure` at `row_index` 1, and the
/// otherwise-valid row 0 is not persisted either.
#[test]
fn redwire_batch_insert_row_failure_rolls_back_with_row_index() {
    let runtime = RedDBRuntime::in_memory().expect("runtime");
    runtime
        .execute_query("CREATE TABLE events_587_rollback (id INTEGER, name TEXT)")
        .expect("create table");

    let request = br#"{
        "collection":"events_587_rollback",
        "idempotency_key":"k-rollback",
        "payloads":[
            {"fields":{"id":1,"name":"a"}},
            {"not_fields":{"id":2}},
            {"fields":{"id":3,"name":"c"}}
        ]
    }"#;
    let frame = Frame::new(MessageKind::BulkInsert, 101, request.to_vec());
    let reply = run_insert_dispatch(&runtime, &frame);
    assert_eq!(reply.kind, MessageKind::Error);

    let err: JsonValue = serde_json::from_slice(&reply.payload).expect("err body json");
    let ok_flag = err.get("ok").and_then(JsonValue::as_bool);
    let code = err.get("code").and_then(JsonValue::as_str);
    let failed_row = err.get("row_index").and_then(JsonValue::as_u64);
    assert_eq!(ok_flag, Some(false));
    assert_eq!(code, Some("RowParseFailure"));
    assert_eq!(failed_row, Some(1));

    // All-or-nothing: nothing from the rejected batch may be visible.
    let leftover = runtime
        .execute_query("SELECT name FROM events_587_rollback")
        .expect("scan");
    assert!(
        leftover.result.records.is_empty(),
        "row 0 leaked despite row 1 rejection: {} rows present",
        leftover.result.records.len()
    );
}
1266
/// Sending a second `BulkInsert` with an already-used `idempotency_key` returns
/// the first reply byte-for-byte and does not insert the second request's row.
#[test]
fn redwire_batch_insert_idempotency_key_replays_cached_result() {
    let runtime = RedDBRuntime::in_memory().expect("runtime");
    runtime
        .execute_query("CREATE TABLE events_587_dedup (id INTEGER, name TEXT)")
        .expect("create table");

    // Timestamp suffix keeps the key unique per run; presumably the cache
    // outlives a single runtime, so a fixed key could hit a stale entry.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let key = format!("redwire-587-{}", nanos);

    // Both requests share collection and key; only the row content differs.
    let batch = |id: i64, name: &str| {
        format!(
            r#"{{
        "collection":"events_587_dedup",
        "idempotency_key":"{key}",
        "payloads":[{{"fields":{{"id":{id},"name":"{name}"}}}}]
    }}"#
        )
        .into_bytes()
    };

    let first = run_insert_dispatch(
        &runtime,
        &Frame::new(MessageKind::BulkInsert, 200, batch(1, "first")),
    );
    assert_eq!(first.kind, MessageKind::BulkOk);

    let second = run_insert_dispatch(
        &runtime,
        &Frame::new(MessageKind::BulkInsert, 201, batch(2, "second")),
    );
    assert_eq!(second.kind, MessageKind::BulkOk);
    assert_eq!(
        second.payload, first.payload,
        "replay must return cached body byte-for-byte"
    );

    let stored = runtime
        .execute_query("SELECT name FROM events_587_dedup")
        .expect("scan");
    assert_eq!(
        stored.result.records.len(),
        1,
        "replay re-executed and committed the second row"
    );
}
1336
/// A batch committed over RedWire is retrievable from the cache returned by
/// `global_cache()` under (collection, idempotency key), with status 200 and
/// the exact body the RedWire client received. Per the expectation message,
/// the HTTP transport reads this same cache.
#[test]
fn redwire_batch_insert_cache_shared_with_http_transport() {
    use crate::runtime::batch_insert::global_cache;

    let runtime = RedDBRuntime::in_memory().expect("runtime");
    runtime
        .execute_query("CREATE TABLE events_587_shared (id INTEGER, name TEXT)")
        .expect("create table");

    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap();
    let key = format!("shared-cache-587-{}", since_epoch.as_nanos());

    let request = format!(
        r#"{{
        "collection":"events_587_shared",
        "idempotency_key":"{key}",
        "payloads":[{{"fields":{{"id":1,"name":"x"}}}}]
    }}"#
    )
    .into_bytes();
    let reply = run_insert_dispatch(
        &runtime,
        &Frame::new(MessageKind::BulkInsert, 300, request),
    );
    assert_eq!(reply.kind, MessageKind::BulkOk);

    let now = std::time::Instant::now();
    let hit = global_cache()
        .lookup("events_587_shared", &key, now)
        .expect("shared cache must serve the RedWire write to HTTP");
    assert_eq!(hit.status, 200);
    assert_eq!(hit.body, reply.payload);
}
1381
/// With a JSON schema registered for event `click_587`, a batch whose second
/// row's `payload` carries an unlisted property is rejected with
/// `RowSchemaRejected` at row 1, and the valid first row is not persisted.
#[test]
fn redwire_batch_insert_schema_validation_rejects_unknown_field() {
    use crate::runtime::analytics_schema_registry as reg;

    let runtime = RedDBRuntime::in_memory().expect("runtime");
    runtime
        .execute_query("CREATE TABLE events_587_schema (event_name TEXT, payload TEXT)")
        .expect("create table");

    // NOTE(review): the schema does not set `additionalProperties: false`, so
    // this test relies on the registry rejecting unlisted properties itself.
    // Presumably each row's `event_name` selects the schema and its `payload`
    // string is what gets validated. Confirm against the registry.
    let schema =
        r#"{"type":"object","properties":{"url":{"type":"string"}},"required":["url"]}"#;
    reg::register(runtime.db().store().as_ref(), "click_587", schema)
        .expect("register schema");

    let request = br#"{
        "collection":"events_587_schema",
        "idempotency_key":"k-schema",
        "payloads":[
            {"fields":{"event_name":"click_587","payload":"{\"url\":\"/a\"}"}},
            {"fields":{"event_name":"click_587","payload":"{\"url\":\"/b\",\"extra\":1}"}}
        ]
    }"#;
    let reply = run_insert_dispatch(
        &runtime,
        &Frame::new(MessageKind::BulkInsert, 400, request.to_vec()),
    );
    assert_eq!(reply.kind, MessageKind::Error);

    let err: JsonValue = serde_json::from_slice(&reply.payload).expect("err body json");
    let code = err.get("code").and_then(JsonValue::as_str);
    let failed_row = err.get("row_index").and_then(JsonValue::as_u64);
    assert_eq!(code, Some("RowSchemaRejected"));
    assert_eq!(failed_row, Some(1));

    let leftover = runtime
        .execute_query("SELECT event_name FROM events_587_schema")
        .expect("scan");
    assert!(
        leftover.result.records.is_empty(),
        "row 0 leaked despite row 1 schema rejection"
    );
}
1431
/// A batch with `max + 1` rows is refused with `BatchTooLarge`, and none of
/// its rows reach storage.
#[test]
fn redwire_batch_insert_oversize_returns_error_before_storage() {
    let runtime = RedDBRuntime::in_memory().expect("runtime");
    runtime
        .execute_query("CREATE TABLE events_587_oversize (id INTEGER, name TEXT)")
        .expect("create table");

    // NOTE(review): 10_000 presumably mirrors the server's per-batch row limit;
    // keep in sync if that limit changes.
    let max = 10_000usize;
    let rows: Vec<String> = (0..=max)
        .map(|i| format!(r#"{{"fields":{{"id":{i},"name":"x"}}}}"#))
        .collect();
    let payloads = format!("[{}]", rows.join(","));
    let frame_body = format!(
        r#"{{"collection":"events_587_oversize","idempotency_key":"k-oversize-587","payloads":{payloads}}}"#
    );
    let reply = run_insert_dispatch(
        &runtime,
        &Frame::new(MessageKind::BulkInsert, 500, frame_body.into_bytes()),
    );

    assert_eq!(reply.kind, MessageKind::Error);
    let err: JsonValue = serde_json::from_slice(&reply.payload).expect("err body json");
    assert_eq!(
        err.get("code").and_then(JsonValue::as_str),
        Some("BatchTooLarge")
    );

    let leftover = runtime
        .execute_query("SELECT name FROM events_587_oversize")
        .expect("scan");
    assert!(
        leftover.result.records.is_empty(),
        "oversize batch leaked rows into storage"
    );
}
1481
1482 async fn read_one_frame<R: tokio::io::AsyncRead + Unpin>(r: &mut R) -> Frame {
1484 let mut header = [0u8; FRAME_HEADER_SIZE];
1485 r.read_exact(&mut header).await.expect("read header");
1486 let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
1487 let mut buf = vec![0u8; length];
1488 buf[..FRAME_HEADER_SIZE].copy_from_slice(&header);
1489 if length > FRAME_HEADER_SIZE {
1490 r.read_exact(&mut buf[FRAME_HEADER_SIZE..])
1491 .await
1492 .expect("read body");
1493 }
1494 let (frame, _) = decode_frame(&buf).expect("decode");
1495 frame
1496 }
1497
/// Encodes a `BulkStreamStart` payload. The layout is:
///
/// 1. a little-endian `u16` collection-name length, then the name bytes;
/// 2. a little-endian `u16` column count;
/// 3. for each column, a little-endian `u16` name length, then the name bytes.
///
/// # Panics
///
/// Panics if the collection name, any column name, or the column count does
/// not fit in a `u16` prefix.
fn stream_start_payload(coll: &str, cols: &[&str]) -> Vec<u8> {
    use std::convert::TryFrom;

    // Every length/count on the wire is a u16. Refuse, rather than silently
    // truncate, anything larger so each prefix matches the bytes that follow.
    let u16_le = |n: usize| -> [u8; 2] {
        u16::try_from(n)
            .expect("length or count does not fit in a u16 prefix")
            .to_le_bytes()
    };

    let columns_len: usize = cols.iter().map(|c| 2 + c.len()).sum();
    let mut p = Vec::with_capacity(4 + coll.len() + columns_len);
    p.extend_from_slice(&u16_le(coll.len()));
    p.extend_from_slice(coll.as_bytes());
    p.extend_from_slice(&u16_le(cols.len()));
    for c in cols {
        p.extend_from_slice(&u16_le(c.len()));
        p.extend_from_slice(c.as_bytes());
    }
    p
}
1509
1510 fn stream_rows_payload(rows: &[(i64, &str)]) -> Vec<u8> {
1511 let mut p = Vec::new();
1512 p.extend_from_slice(&(rows.len() as u32).to_le_bytes());
1513 for (id, name) in rows {
1514 crate::wire::protocol::encode_value(
1515 &mut p,
1516 &crate::storage::schema::Value::Integer(*id),
1517 );
1518 crate::wire::protocol::encode_value(
1519 &mut p,
1520 &crate::storage::schema::Value::text(name.to_string()),
1521 );
1522 }
1523 p
1524 }
1525
1526 #[tokio::test]
1533 async fn bulk_stream_rows_success_emits_no_response_frame() {
1534 let runtime = std::sync::Arc::new(RedDBRuntime::in_memory().expect("runtime"));
1536 runtime
1537 .execute_query("CREATE TABLE target (id INT, name TEXT)")
1538 .expect("create table");
1539
1540 let (server_io, mut client) = tokio::io::duplex(64 * 1024);
1543
1544 let server_task = tokio::spawn(async move {
1545 let _ = handle_session(server_io, runtime, None, None).await;
1546 });
1547
1548 client.write_all(&[1u8]).await.expect("write minor");
1551
1552 let hello_payload =
1554 br#"{"versions":[1],"auth_methods":["anonymous"],"features":0,"client_name":"test"}"#
1555 .to_vec();
1556 let hello = encode_frame(&Frame::new(MessageKind::Hello, 1, hello_payload));
1557 client.write_all(&hello).await.expect("write hello");
1558
1559 let ack = read_one_frame(&mut client).await;
1561 assert_eq!(ack.kind, MessageKind::HelloAck);
1562
1563 let authresp = encode_frame(&Frame::new(MessageKind::AuthResponse, 2, b"{}".to_vec()));
1565 client.write_all(&authresp).await.expect("write authresp");
1566
1567 let auth_ok = read_one_frame(&mut client).await;
1569 assert_eq!(auth_ok.kind, MessageKind::AuthOk);
1570
1571 let start = encode_frame(&Frame::new(
1573 MessageKind::BulkStreamStart,
1574 3,
1575 stream_start_payload("target", &["id", "name"]),
1576 ));
1577 client.write_all(&start).await.expect("write start");
1578 let start_ack = read_one_frame(&mut client).await;
1579 assert_eq!(start_ack.kind, MessageKind::BulkStreamAck);
1580 assert_eq!(start_ack.correlation_id, 3);
1581
1582 let rows = encode_frame(&Frame::new(
1584 MessageKind::BulkStreamRows,
1585 4,
1586 stream_rows_payload(&[(1, "a"), (2, "b")]),
1587 ));
1588 client.write_all(&rows).await.expect("write rows");
1589
1590 let commit = encode_frame(&Frame::new(MessageKind::BulkStreamCommit, 5, vec![]));
1596 client.write_all(&commit).await.expect("write commit");
1597
1598 let next = read_one_frame(&mut client).await;
1599 assert_eq!(
1600 next.kind,
1601 MessageKind::BulkOk,
1602 "expected BulkOk after commit; got {:?} — BulkStreamRows leaked an ack frame",
1603 next.kind
1604 );
1605 assert_eq!(
1606 next.correlation_id, 5,
1607 "commit response must carry the commit's correlation id"
1608 );
1609
1610 let bye = encode_frame(&Frame::new(MessageKind::Bye, 6, vec![]));
1612 client.write_all(&bye).await.expect("write bye");
1613 let _ = read_one_frame(&mut client).await; drop(client);
1615 let _ = server_task.await;
1616 }
1617
1618 #[tokio::test]
1621 async fn bulk_stream_rows_error_still_emits_error_frame() {
1622 let runtime = std::sync::Arc::new(RedDBRuntime::in_memory().expect("runtime"));
1623 let (server_io, mut client) = tokio::io::duplex(64 * 1024);
1624
1625 let server_task = tokio::spawn(async move {
1626 let _ = handle_session(server_io, runtime, None, None).await;
1627 });
1628
1629 client.write_all(&[1u8]).await.unwrap();
1630 let hello_payload =
1631 br#"{"versions":[1],"auth_methods":["anonymous"],"features":0}"#.to_vec();
1632 client
1633 .write_all(&encode_frame(&Frame::new(
1634 MessageKind::Hello,
1635 1,
1636 hello_payload,
1637 )))
1638 .await
1639 .unwrap();
1640 let _ack = read_one_frame(&mut client).await;
1641 client
1642 .write_all(&encode_frame(&Frame::new(
1643 MessageKind::AuthResponse,
1644 2,
1645 b"{}".to_vec(),
1646 )))
1647 .await
1648 .unwrap();
1649 let _auth_ok = read_one_frame(&mut client).await;
1650
1651 let rows = encode_frame(&Frame::new(
1655 MessageKind::BulkStreamRows,
1656 7,
1657 stream_rows_payload(&[(1, "a")]),
1658 ));
1659 client.write_all(&rows).await.unwrap();
1660 let resp = read_one_frame(&mut client).await;
1661 assert_eq!(resp.kind, MessageKind::Error);
1662 assert_eq!(resp.correlation_id, 7);
1663
1664 drop(client);
1665 let _ = server_task.await;
1666 }
1667}