1use std::io;
13use std::sync::Arc;
14
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16
17use crate::auth::store::AuthStore;
18use crate::runtime::RedDBRuntime;
19use crate::serde_json::{self, Value as JsonValue};
20use reddb_wire::query_with_params::{
21 decode_query_with_params, ParamValue as RedWireParamValue, FEATURE_PARAMS,
22};
23
24use super::auth::{
25 build_auth_fail, build_auth_ok, build_hello_ack, pick_auth_method, validate_auth_response,
26 AuthOutcome, Hello,
27};
28use super::codec::{decode_frame, encode_frame};
29use super::frame::{Frame, MessageDirection, MessageKind, FRAME_HEADER_SIZE};
30use super::{FrameBuilder, MAX_KNOWN_MINOR_VERSION, REDWIRE_MAGIC};
31
/// Identity established by a successful handshake; `AuthOk` has already been
/// sent to the client when one of these is returned.
///
/// Produced by `perform_handshake` / `perform_scram_handshake`. The role
/// negotiated during authentication is not stored here.
#[derive(Debug)]
struct AuthedSession {
    /// User name that was reported back in the `AuthOk` reply.
    username: String,
    /// Session identifier issued in the `AuthOk` reply.
    // NOTE(review): not read anywhere in this file; presumably kept for
    // diagnostics/logging — confirm before removing the allow.
    #[allow(dead_code)]
    session_id: String,
}
38
39pub async fn handle_session<S>(
40 mut stream: S,
41 runtime: Arc<RedDBRuntime>,
42 auth_store: Option<Arc<AuthStore>>,
43 oauth: Option<Arc<crate::auth::oauth::OAuthValidator>>,
44) -> io::Result<()>
45where
46 S: AsyncRead + AsyncWrite + Unpin + Send,
47{
48 let session = perform_handshake(
52 &mut stream,
53 runtime.as_ref(),
54 auth_store.as_deref(),
55 oauth.as_deref(),
56 )
57 .await?;
58 if session.is_none() {
59 return Ok(());
60 }
61 let _session = session.unwrap();
62
63 let mut stream_session: Option<crate::wire::listener::BulkStreamSession> = None;
66 let mut prepared_stmts: std::collections::HashMap<u32, crate::wire::listener::PreparedStmt> =
67 std::collections::HashMap::new();
68
69 let mut buf = vec![0u8; FRAME_HEADER_SIZE];
70 loop {
71 if let Err(err) = stream.read_exact(&mut buf[..FRAME_HEADER_SIZE]).await {
73 if err.kind() == io::ErrorKind::UnexpectedEof {
74 return Ok(());
75 }
76 return Err(err);
77 }
78 let length = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
79 if length < FRAME_HEADER_SIZE || length > super::frame::MAX_FRAME_SIZE as usize {
80 return Err(io::Error::other(format!("invalid frame length {length}")));
81 }
82 if buf.len() < length {
83 buf.resize(length, 0);
84 }
85 let payload_len = length - FRAME_HEADER_SIZE;
86 if payload_len > 0 {
87 stream
88 .read_exact(&mut buf[FRAME_HEADER_SIZE..length])
89 .await?;
90 }
91 let (frame, _) = decode_frame(&buf[..length])
92 .map_err(|e| io::Error::other(format!("decode frame: {e}")))?;
93
94 if frame.kind.direction() == MessageDirection::ServerToClient {
99 let err_frame = FrameBuilder::reply_to(frame.correlation_id)
100 .kind(MessageKind::Error)
101 .payload(format!("redwire: {:?} is server-only", frame.kind).into_bytes())
102 .build()
103 .map_err(|e| io::Error::other(format!("build Error frame: {e}")))?;
104 stream.write_all(&encode_frame(&err_frame)).await?;
105 continue;
106 }
107
108 match frame.kind {
109 MessageKind::Bye => {
110 let bye = encode_frame(&build_reply(
111 frame.correlation_id,
112 MessageKind::Bye,
113 vec![],
114 )?);
115 let _ = stream.write_all(&bye).await;
116 return Ok(());
117 }
118 MessageKind::Ping => {
119 let pong = encode_frame(&build_reply(
120 frame.correlation_id,
121 MessageKind::Pong,
122 vec![],
123 )?);
124 stream.write_all(&pong).await?;
125 }
126 MessageKind::Query => {
127 let response = run_query(&runtime, &frame);
128 stream.write_all(&encode_frame(&response)).await?;
129 }
130 MessageKind::QueryWithParams => {
131 let response = run_query_with_params(&runtime, &frame);
132 stream.write_all(&encode_frame(&response)).await?;
133 }
134 MessageKind::BulkInsert => {
138 let response = run_insert_dispatch(&runtime, &frame);
139 stream.write_all(&encode_frame(&response)).await?;
140 }
141 MessageKind::BulkInsertBinary => {
145 let raw =
146 crate::wire::listener::handle_bulk_insert_binary(&runtime, &frame.payload);
147 stream
148 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
149 .await?;
150 }
151 MessageKind::BulkInsertPrevalidated => {
152 let raw = crate::wire::listener::handle_bulk_insert_binary_prevalidated(
153 &runtime,
154 &frame.payload,
155 );
156 stream
157 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
158 .await?;
159 }
160 MessageKind::QueryBinary => {
161 let raw = crate::wire::listener::handle_query_binary(&runtime, &frame.payload);
162 stream
163 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
164 .await?;
165 }
166 MessageKind::BulkStreamStart => {
169 let raw =
170 crate::wire::listener::handle_stream_start(&frame.payload, &mut stream_session);
171 stream
172 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
173 .await?;
174 }
175 MessageKind::BulkStreamRows => {
176 let raw = crate::wire::listener::handle_stream_rows(
177 &runtime,
178 &frame.payload,
179 &mut stream_session,
180 );
181 if !raw.is_empty() {
187 stream
188 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
189 .await?;
190 }
191 }
192 MessageKind::BulkStreamCommit => {
193 let raw =
194 crate::wire::listener::handle_stream_commit(&runtime, &mut stream_session);
195 stream
196 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
197 .await?;
198 }
199 MessageKind::Prepare => {
203 let raw = crate::wire::listener::handle_prepare(
204 &runtime,
205 &frame.payload,
206 &mut prepared_stmts,
207 );
208 stream
209 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
210 .await?;
211 }
212 MessageKind::ExecutePrepared => {
213 let raw = crate::wire::listener::handle_execute_prepared(
214 &runtime,
215 &frame.payload,
216 &prepared_stmts,
217 );
218 stream
219 .write_all(&encode_frame(&rewrap_handler_response(&raw, &frame)))
220 .await?;
221 }
222 MessageKind::Get => {
223 let response = run_get(&runtime, &frame);
224 stream.write_all(&encode_frame(&response)).await?;
225 }
226 MessageKind::Delete => {
227 let response = run_delete(&runtime, &frame);
228 stream.write_all(&encode_frame(&response)).await?;
229 }
230 other => {
231 let err_frame = FrameBuilder::reply_to(frame.correlation_id)
232 .kind(MessageKind::Error)
233 .payload(format!("redwire: cannot dispatch {other:?} yet").into_bytes())
234 .build()
235 .map_err(|e| io::Error::other(format!("build Error frame: {e}")))?;
236 let err = encode_frame(&err_frame);
237 stream.write_all(&err).await?;
238 }
239 }
240 }
241}
242
243async fn perform_handshake<S>(
246 stream: &mut S,
247 runtime: &RedDBRuntime,
248 auth_store: Option<&AuthStore>,
249 oauth: Option<&crate::auth::oauth::OAuthValidator>,
250) -> io::Result<Option<AuthedSession>>
251where
252 S: AsyncRead + AsyncWrite + Unpin + Send,
253{
254 let mut minor_buf = [0u8; 1];
256 stream.read_exact(&mut minor_buf).await?;
257 let minor = minor_buf[0];
258 if minor > MAX_KNOWN_MINOR_VERSION {
259 return Ok(None);
263 }
264
265 let hello = read_frame(stream).await?;
267 if hello.kind != MessageKind::Hello {
268 let fail = encode_frame(&build_reply(
269 hello.correlation_id,
270 MessageKind::AuthFail,
271 build_auth_fail("first frame after magic must be Hello"),
272 )?);
273 let _ = stream.write_all(&fail).await;
274 return Ok(None);
275 }
276 let hello_msg = match Hello::from_payload(&hello.payload) {
277 Ok(h) => h,
278 Err(e) => {
279 let fail = encode_frame(&build_reply(
280 hello.correlation_id,
281 MessageKind::AuthFail,
282 build_auth_fail(&e),
283 )?);
284 let _ = stream.write_all(&fail).await;
285 return Ok(None);
286 }
287 };
288
289 let chosen_version = hello_msg
290 .versions
291 .iter()
292 .copied()
293 .filter(|v| *v <= MAX_KNOWN_MINOR_VERSION)
294 .max()
295 .unwrap_or(0);
296 if chosen_version == 0 {
297 let fail = encode_frame(&build_reply(
298 hello.correlation_id,
299 MessageKind::AuthFail,
300 build_auth_fail("no overlapping protocol version"),
301 )?);
302 let _ = stream.write_all(&fail).await;
303 return Ok(None);
304 }
305
306 let server_anon_ok = auth_store.map(|s| !s.is_enabled()).unwrap_or(true);
307 let chosen = match pick_auth_method(&hello_msg.auth_methods, server_anon_ok) {
308 Some(m) => m,
309 None => {
310 let fail = encode_frame(&build_reply(
311 hello.correlation_id,
312 MessageKind::AuthFail,
313 build_auth_fail("no overlapping auth method"),
314 )?);
315 let _ = stream.write_all(&fail).await;
316 return Ok(None);
317 }
318 };
319
320 let server_features = FEATURE_PARAMS;
329 let topology = build_topology_for_hello_ack(runtime);
330 let ack_frame = FrameBuilder::reply_to(hello.correlation_id)
331 .kind(MessageKind::HelloAck)
332 .payload(build_hello_ack(
333 chosen_version,
334 chosen,
335 server_features,
336 topology.as_ref(),
337 ))
338 .build()
339 .map_err(|e| io::Error::other(format!("build HelloAck: {e}")))?;
340 let ack = encode_frame(&ack_frame);
341 stream.write_all(&ack).await?;
342
343 if chosen == "scram-sha-256" {
347 return perform_scram_handshake(stream, auth_store, hello.correlation_id, server_features)
348 .await;
349 }
350
351 let resp = read_frame(stream).await?;
354 if resp.kind != MessageKind::AuthResponse {
355 let fail = encode_frame(&build_reply(
356 resp.correlation_id,
357 MessageKind::AuthFail,
358 build_auth_fail("expected AuthResponse"),
359 )?);
360 let _ = stream.write_all(&fail).await;
361 return Ok(None);
362 }
363
364 if chosen == "oauth-jwt" {
368 let validator = match oauth {
369 Some(v) => v,
370 None => {
371 let fail = encode_frame(&build_reply(
372 resp.correlation_id,
373 MessageKind::AuthFail,
374 build_auth_fail("oauth-jwt requires RedWireConfig.oauth"),
375 )?);
376 let _ = stream.write_all(&fail).await;
377 return Ok(None);
378 }
379 };
380 let raw = match crate::serde_json::from_slice::<JsonValue>(&resp.payload)
381 .ok()
382 .and_then(|v| {
383 v.as_object()
384 .and_then(|o| o.get("jwt").cloned())
385 .and_then(|x| x.as_str().map(String::from))
386 }) {
387 Some(s) if !s.is_empty() => s,
388 _ => {
389 let fail = encode_frame(&build_reply(
390 resp.correlation_id,
391 MessageKind::AuthFail,
392 build_auth_fail("oauth-jwt: AuthResponse missing 'jwt' string"),
393 )?);
394 let _ = stream.write_all(&fail).await;
395 return Ok(None);
396 }
397 };
398 match super::auth::validate_oauth_jwt(validator, &raw) {
399 Ok((username, role)) => {
400 let session_id = super::auth::new_session_id_for_scram();
401 let ok = encode_frame(&build_reply(
402 resp.correlation_id,
403 MessageKind::AuthOk,
404 build_auth_ok(&session_id, &username, role, server_features),
405 )?);
406 stream.write_all(&ok).await?;
407 return Ok(Some(AuthedSession {
408 username,
409 session_id,
410 }));
411 }
412 Err(reason) => {
413 let fail = encode_frame(&build_reply(
414 resp.correlation_id,
415 MessageKind::AuthFail,
416 build_auth_fail(&format!("oauth-jwt: {reason}")),
417 )?);
418 let _ = stream.write_all(&fail).await;
419 return Ok(None);
420 }
421 }
422 }
423
424 match validate_auth_response(chosen, &resp.payload, auth_store) {
425 AuthOutcome::Authenticated {
426 username,
427 role,
428 session_id,
429 } => {
430 let ok_frame = FrameBuilder::reply_to(resp.correlation_id)
431 .kind(MessageKind::AuthOk)
432 .payload(build_auth_ok(&session_id, &username, role, server_features))
433 .build()
434 .map_err(|e| io::Error::other(format!("build AuthOk: {e}")))?;
435 let ok = encode_frame(&ok_frame);
436 stream.write_all(&ok).await?;
437 Ok(Some(AuthedSession {
438 username,
439 session_id,
440 }))
441 }
442 AuthOutcome::Refused(reason) => {
443 let fail = encode_frame(&build_reply(
444 resp.correlation_id,
445 MessageKind::AuthFail,
446 build_auth_fail(&reason),
447 )?);
448 let _ = stream.write_all(&fail).await;
449 Ok(None)
450 }
451 }
452}
453
454async fn perform_scram_handshake<S>(
461 stream: &mut S,
462 auth_store: Option<&AuthStore>,
463 initial_correlation: u64,
464 server_features: u32,
465) -> io::Result<Option<AuthedSession>>
466where
467 S: AsyncRead + AsyncWrite + Unpin + Send,
468{
469 let store = match auth_store {
470 Some(s) => s,
471 None => {
472 let fail = encode_frame(&build_reply(
473 initial_correlation,
474 MessageKind::AuthFail,
475 build_auth_fail("scram-sha-256 requires an AuthStore"),
476 )?);
477 let _ = stream.write_all(&fail).await;
478 return Ok(None);
479 }
480 };
481
482 let cf = read_frame(stream).await?;
484 if cf.kind != MessageKind::AuthResponse {
485 let fail = encode_frame(&build_reply(
486 cf.correlation_id,
487 MessageKind::AuthFail,
488 build_auth_fail("expected AuthResponse(client-first-message)"),
489 )?);
490 let _ = stream.write_all(&fail).await;
491 return Ok(None);
492 }
493 let (username, client_nonce, client_first_bare) =
494 match super::auth::parse_scram_client_first(&cf.payload) {
495 Ok(t) => t,
496 Err(e) => {
497 let fail = encode_frame(&build_reply(
498 cf.correlation_id,
499 MessageKind::AuthFail,
500 build_auth_fail(&format!("scram client-first: {e}")),
501 )?);
502 let _ = stream.write_all(&fail).await;
503 return Ok(None);
504 }
505 };
506
507 let verifier = store.lookup_scram_verifier_global(&username);
516 let (salt, iter, stored_key, server_key, user_known) = match &verifier {
517 Some(v) => (v.salt.clone(), v.iter, v.stored_key, v.server_key, true),
518 None => (
519 crate::auth::store::random_bytes(16),
520 crate::auth::scram::DEFAULT_ITER,
521 [0u8; 32],
522 [0u8; 32],
523 false,
524 ),
525 };
526
527 let server_nonce = super::auth::new_server_nonce();
529 let server_first =
530 super::auth::build_scram_server_first(&client_nonce, &server_nonce, &salt, iter);
531 let req = encode_frame(&build_reply(
532 cf.correlation_id,
533 MessageKind::AuthRequest,
534 server_first.as_bytes().to_vec(),
535 )?);
536 stream.write_all(&req).await?;
537
538 let cfinal = read_frame(stream).await?;
540 if cfinal.kind != MessageKind::AuthResponse {
541 let fail = encode_frame(&build_reply(
542 cfinal.correlation_id,
543 MessageKind::AuthFail,
544 build_auth_fail("expected AuthResponse(client-final-message)"),
545 )?);
546 let _ = stream.write_all(&fail).await;
547 return Ok(None);
548 }
549 let (combined_nonce, presented_proof, client_final_no_proof) =
550 match super::auth::parse_scram_client_final(&cfinal.payload) {
551 Ok(t) => t,
552 Err(e) => {
553 let fail = encode_frame(&build_reply(
554 cfinal.correlation_id,
555 MessageKind::AuthFail,
556 build_auth_fail(&format!("scram client-final: {e}")),
557 )?);
558 let _ = stream.write_all(&fail).await;
559 return Ok(None);
560 }
561 };
562 let expected_combined = format!("{client_nonce}{server_nonce}");
563 if combined_nonce != expected_combined {
564 let fail = encode_frame(&build_reply(
565 cfinal.correlation_id,
566 MessageKind::AuthFail,
567 build_auth_fail("scram nonce mismatch — replay protection failed"),
568 )?);
569 let _ = stream.write_all(&fail).await;
570 return Ok(None);
571 }
572
573 let auth_message =
575 crate::auth::scram::auth_message(&client_first_bare, &server_first, &client_final_no_proof);
576 let proof_ok = if user_known {
577 let v = crate::auth::scram::ScramVerifier {
578 salt: salt.clone(),
579 iter,
580 stored_key,
581 server_key,
582 };
583 crate::auth::scram::verify_client_proof(&v, &auth_message, &presented_proof)
584 } else {
585 false
586 };
587 if !proof_ok {
588 let fail = encode_frame(&build_reply(
589 cfinal.correlation_id,
590 MessageKind::AuthFail,
591 build_auth_fail("invalid SCRAM proof"),
592 )?);
593 let _ = stream.write_all(&fail).await;
594 return Ok(None);
595 }
596
597 let role = store
599 .list_users()
600 .into_iter()
601 .find(|u| u.username == username)
602 .map(|u| u.role)
603 .unwrap_or(crate::auth::Role::Read);
604 let server_sig = crate::auth::scram::server_signature(&server_key, &auth_message);
605 let session_id = super::auth::new_session_id_for_scram();
606 let ok_payload = super::auth::build_scram_auth_ok(
607 &session_id,
608 &username,
609 role,
610 server_features,
611 &server_sig,
612 );
613 let ok = encode_frame(&build_reply(
614 cfinal.correlation_id,
615 MessageKind::AuthOk,
616 ok_payload,
617 )?);
618 stream.write_all(&ok).await?;
619 Ok(Some(AuthedSession {
620 username,
621 session_id,
622 }))
623}
624
625async fn read_frame<S>(stream: &mut S) -> io::Result<Frame>
626where
627 S: AsyncRead + AsyncWrite + Unpin + Send,
628{
629 let mut header = [0u8; FRAME_HEADER_SIZE];
630 stream.read_exact(&mut header).await?;
631 let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
632 if length < FRAME_HEADER_SIZE || length > super::frame::MAX_FRAME_SIZE as usize {
633 return Err(io::Error::other(format!(
634 "redwire frame length {length} out of range"
635 )));
636 }
637 let mut buf = vec![0u8; length];
638 buf[..FRAME_HEADER_SIZE].copy_from_slice(&header);
639 if length > FRAME_HEADER_SIZE {
640 stream
641 .read_exact(&mut buf[FRAME_HEADER_SIZE..length])
642 .await?;
643 }
644 let (frame, _) =
645 decode_frame(&buf).map_err(|e| io::Error::other(format!("decode frame: {e}")))?;
646 Ok(frame)
647}
648
649fn run_query(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
650 let sql = match std::str::from_utf8(&frame.payload) {
651 Ok(s) => s,
652 Err(_) => {
653 return error_frame(frame.correlation_id, "Query payload must be UTF-8 SQL");
654 }
655 };
656 match runtime.execute_query(sql) {
657 Ok(result) => {
658 let mut obj = crate::serde_json::Map::new();
659 obj.insert("ok".to_string(), JsonValue::Bool(true));
660 obj.insert(
661 "statement".to_string(),
662 JsonValue::String(result.statement_type.to_string()),
663 );
664 obj.insert(
665 "affected".to_string(),
666 JsonValue::Number(result.affected_rows as f64),
667 );
668 let payload = serde_json::to_vec(&JsonValue::Object(obj)).unwrap_or_default();
669 build_dispatch_reply(frame.correlation_id, MessageKind::Result, payload)
670 }
671 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
672 }
673}
674
675fn run_query_with_params(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
676 let (sql, params) = match decode_query_with_params(&frame.payload) {
677 Ok(decoded) => decoded,
678 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
679 };
680 let params = params
681 .into_iter()
682 .map(param_to_schema_value)
683 .collect::<Vec<_>>();
684 let parsed = match crate::storage::query::modes::parse_multi(&sql) {
685 Ok(parsed) => parsed,
686 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
687 };
688 let bound = match crate::storage::query::user_params::bind(&parsed, ¶ms) {
689 Ok(bound) => bound,
690 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
691 };
692 match runtime.execute_query_expr(bound) {
693 Ok(result) => {
694 let is_mutation = matches!(result.statement_type, "insert" | "update" | "delete");
695 if is_mutation {
696 let post_lsn = runtime.cdc_current_lsn();
697 if let Err(err) = runtime.enforce_commit_policy(post_lsn) {
698 return error_frame(frame.correlation_id, &err.to_string());
699 }
700 }
701 let payload = serde_json::to_vec(
702 &crate::presentation::query_result_json::runtime_query_json(&result, &None, &None),
703 )
704 .unwrap_or_default();
705 build_dispatch_reply(frame.correlation_id, MessageKind::Result, payload)
706 }
707 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
708 }
709}
710
711fn param_to_schema_value(value: RedWireParamValue) -> crate::storage::schema::Value {
712 use crate::storage::schema::Value;
713 match value {
714 RedWireParamValue::Null => Value::Null,
715 RedWireParamValue::Bool(value) => Value::Boolean(value),
716 RedWireParamValue::Int(value) => Value::Integer(value),
717 RedWireParamValue::Float(value) => Value::Float(value),
718 RedWireParamValue::Text(value) => Value::Text(Arc::from(value.as_str())),
719 RedWireParamValue::Bytes(value) => Value::Blob(value),
720 RedWireParamValue::Vector(value) => Value::Vector(value),
721 RedWireParamValue::Json(value) => Value::Json(value),
722 RedWireParamValue::Timestamp(value) => Value::Timestamp(value),
723 RedWireParamValue::Uuid(value) => Value::Uuid(value),
724 }
725}
726
727fn run_insert_dispatch(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
735 let v: JsonValue = match serde_json::from_slice(&frame.payload) {
736 Ok(v) => v,
737 Err(e) => return error_frame(frame.correlation_id, &format!("Insert: invalid JSON: {e}")),
738 };
739 let obj = match v.as_object() {
740 Some(o) => o,
741 None => {
742 return error_frame(
743 frame.correlation_id,
744 "Insert: payload must be a JSON object",
745 )
746 }
747 };
748 let collection = match obj.get("collection").and_then(|x| x.as_str()) {
749 Some(s) if !s.is_empty() => s,
750 _ => return error_frame(frame.correlation_id, "Insert: missing 'collection' string"),
751 };
752
753 if let Some(rows) = obj.get("payloads").and_then(|x| x.as_array()) {
754 let mut objects = Vec::with_capacity(rows.len());
755 for entry in rows {
756 objects.push(match entry.as_object() {
757 Some(o) => o,
758 None => {
759 return error_frame(
760 frame.correlation_id,
761 "Insert: each payload must be a JSON object",
762 )
763 }
764 });
765 }
766
767 if crate::rpc_stdio::should_bulk_insert_graph(runtime, collection, &objects) {
768 return match crate::rpc_stdio::bulk_insert_graph(runtime, collection, &objects) {
769 Ok(body) => {
770 let payload = serde_json::to_vec(&body).unwrap_or_default();
771 build_dispatch_reply(frame.correlation_id, MessageKind::BulkOk, payload)
772 }
773 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
774 };
775 }
776
777 let mut affected: u64 = 0;
778 let mut ids = Vec::with_capacity(objects.len());
779 for row in objects {
780 let sql = crate::rpc_stdio::build_insert_sql(collection, row.iter());
781 match runtime.execute_query(&sql) {
782 Ok(qr) => {
783 affected += qr.affected_rows;
784 if let Some(id) = crate::rpc_stdio::insert_result_to_json(&qr).get("id") {
785 ids.push(id.clone());
786 }
787 }
788 Err(err) => return error_frame(frame.correlation_id, &err.to_string()),
789 }
790 }
791 let mut out = crate::serde_json::Map::new();
792 out.insert("affected".to_string(), JsonValue::Number(affected as f64));
793 out.insert("ids".to_string(), JsonValue::Array(ids));
794 let payload = serde_json::to_vec(&JsonValue::Object(out)).unwrap_or_default();
795 return build_dispatch_reply(frame.correlation_id, MessageKind::BulkOk, payload);
796 }
797
798 let row = match obj.get("payload").and_then(|x| x.as_object()) {
799 Some(o) => o,
800 None => {
801 return error_frame(
802 frame.correlation_id,
803 "Insert: missing 'payload' object or 'payloads' array",
804 )
805 }
806 };
807 let sql = crate::rpc_stdio::build_insert_sql(collection, row.iter());
808 match runtime.execute_query(&sql) {
809 Ok(qr) => {
810 let body = crate::rpc_stdio::insert_result_to_json(&qr);
811 let payload = serde_json::to_vec(&body).unwrap_or_default();
812 build_dispatch_reply(frame.correlation_id, MessageKind::BulkOk, payload)
813 }
814 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
815 }
816}
817
818fn build_topology_for_hello_ack(runtime: &RedDBRuntime) -> Option<reddb_wire::topology::Topology> {
830 use crate::auth::middleware::AuthResult;
831 use crate::replication::{LagConfig, TopologyAdvertiser};
832 use reddb_wire::topology::Endpoint;
833
834 let db = runtime.db();
835 let primary_endpoint = Endpoint {
836 addr: runtime.config_string("red.redwire.advertise_addr", ""),
837 region: db.options().replication.region.clone(),
838 };
839 let (replicas, current_lsn, epoch) = match db.replication.as_ref() {
840 Some(repl) => (
841 repl.replica_snapshots(),
842 repl.wal_buffer.current_lsn(),
843 repl.topology_epoch(),
844 ),
845 None => (Vec::new(), 0u64, 0u64),
846 };
847 let lag = LagConfig::from_now();
848 Some(TopologyAdvertiser::advertise(
849 &replicas,
850 &AuthResult::Anonymous,
851 epoch,
852 primary_endpoint,
853 current_lsn,
854 &lag,
855 ))
856}
857
858fn error_frame(correlation_id: u64, msg: &str) -> Frame {
859 FrameBuilder::reply_to(correlation_id)
864 .kind(MessageKind::Error)
865 .payload(msg.as_bytes().to_vec())
866 .build()
867 .expect("error frame fits in MAX_FRAME_SIZE")
868}
869
870fn build_reply(correlation_id: u64, kind: MessageKind, payload: Vec<u8>) -> io::Result<Frame> {
879 FrameBuilder::reply_to(correlation_id)
880 .kind(kind)
881 .payload(payload)
882 .build()
883 .map_err(|e| io::Error::other(format!("build {kind:?}: {e}")))
884}
885
886fn build_dispatch_reply(correlation_id: u64, kind: MessageKind, payload: Vec<u8>) -> Frame {
893 FrameBuilder::reply_to(correlation_id)
894 .kind(kind)
895 .payload(payload)
896 .build()
897 .unwrap_or_else(|e| error_frame(correlation_id, &e.to_string()))
898}
899
900fn rewrap_handler_response(raw_bytes: &[u8], req: &Frame) -> Frame {
912 if raw_bytes.len() < 5 {
913 return error_frame(
914 req.correlation_id,
915 "fast-path handler returned a truncated frame",
916 );
917 }
918 let kind_byte = raw_bytes[4];
919 let kind = MessageKind::from_u8(kind_byte).unwrap_or(MessageKind::Error);
920 let body = raw_bytes[5..].to_vec();
921 build_dispatch_reply(req.correlation_id, kind, body)
922}
923
924fn run_get(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
928 let v: JsonValue = match serde_json::from_slice(&frame.payload) {
929 Ok(v) => v,
930 Err(e) => return error_frame(frame.correlation_id, &format!("Get: invalid JSON: {e}")),
931 };
932 let obj = match v.as_object() {
933 Some(o) => o,
934 None => return error_frame(frame.correlation_id, "Get: payload must be a JSON object"),
935 };
936 let collection = match obj.get("collection").and_then(|x| x.as_str()) {
937 Some(s) if !s.is_empty() => s,
938 _ => return error_frame(frame.correlation_id, "Get: missing 'collection' string"),
939 };
940 let id = match obj.get("id").and_then(|x| x.as_str()) {
941 Some(s) if !s.is_empty() => s,
942 _ => return error_frame(frame.correlation_id, "Get: missing 'id' string"),
943 };
944 let id_lit = crate::rpc_stdio::value_to_sql_literal(&JsonValue::String(id.to_string()));
947 let sql = format!("SELECT * FROM {collection} WHERE _id = {id_lit} LIMIT 1");
948 match runtime.execute_query(&sql) {
949 Ok(qr) => {
950 let mut out = crate::serde_json::Map::new();
951 out.insert("ok".to_string(), JsonValue::Bool(true));
952 out.insert(
953 "found".to_string(),
954 JsonValue::Bool(!qr.result.records.is_empty()),
955 );
956 let payload = serde_json::to_vec(&JsonValue::Object(out)).unwrap_or_default();
959 build_dispatch_reply(frame.correlation_id, MessageKind::Result, payload)
960 }
961 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
962 }
963}
964
965fn run_delete(runtime: &RedDBRuntime, frame: &Frame) -> Frame {
969 let v: JsonValue = match serde_json::from_slice(&frame.payload) {
970 Ok(v) => v,
971 Err(e) => return error_frame(frame.correlation_id, &format!("Delete: invalid JSON: {e}")),
972 };
973 let obj = match v.as_object() {
974 Some(o) => o,
975 None => {
976 return error_frame(
977 frame.correlation_id,
978 "Delete: payload must be a JSON object",
979 )
980 }
981 };
982 let collection = match obj.get("collection").and_then(|x| x.as_str()) {
983 Some(s) if !s.is_empty() => s,
984 _ => return error_frame(frame.correlation_id, "Delete: missing 'collection' string"),
985 };
986 let id = match obj.get("id").and_then(|x| x.as_str()) {
987 Some(s) if !s.is_empty() => s,
988 _ => return error_frame(frame.correlation_id, "Delete: missing 'id' string"),
989 };
990 let id_lit = crate::rpc_stdio::value_to_sql_literal(&JsonValue::String(id.to_string()));
991 let sql = format!("DELETE FROM {collection} WHERE _id = {id_lit}");
992 match runtime.execute_query(&sql) {
993 Ok(qr) => {
994 let mut out = crate::serde_json::Map::new();
995 out.insert(
996 "affected".to_string(),
997 JsonValue::Number(qr.affected_rows as f64),
998 );
999 let payload = serde_json::to_vec(&JsonValue::Object(out)).unwrap_or_default();
1000 build_dispatch_reply(frame.correlation_id, MessageKind::DeleteOk, payload)
1001 }
1002 Err(err) => error_frame(frame.correlation_id, &err.to_string()),
1003 }
1004}
1005
1006#[cfg(test)]
1007mod tests {
1008 use super::*;
1009
1010 use crate::runtime::RedDBRuntime;
1011 use tokio::io::{AsyncReadExt, AsyncWriteExt};
1012
1013 fn create_graph_collection(runtime: &RedDBRuntime, name: &str) {
1014 let db = runtime.db();
1015 db.store()
1016 .create_collection(name)
1017 .expect("create collection");
1018 let now = std::time::SystemTime::now()
1019 .duration_since(std::time::UNIX_EPOCH)
1020 .unwrap_or_default()
1021 .as_millis();
1022 db.save_collection_contract(crate::physical::CollectionContract {
1023 name: name.to_string(),
1024 declared_model: crate::catalog::CollectionModel::Graph,
1025 schema_mode: crate::catalog::SchemaMode::Dynamic,
1026 origin: crate::physical::ContractOrigin::Explicit,
1027 version: 1,
1028 created_at_unix_ms: now,
1029 updated_at_unix_ms: now,
1030 default_ttl_ms: None,
1031 vector_dimension: None,
1032 vector_metric: None,
1033 context_index_fields: Vec::new(),
1034 declared_columns: Vec::new(),
1035 table_def: None,
1036 timestamps_enabled: false,
1037 context_index_enabled: false,
1038 metrics_raw_retention_ms: None,
1039 metrics_rollup_policies: Vec::new(),
1040 metrics_tenant_identity: None,
1041 metrics_namespace: None,
1042 append_only: false,
1043 subscriptions: Vec::new(),
1044 })
1045 .expect("save graph contract");
1046 }
1047
1048 #[test]
1049 fn magic_byte_is_0xfe() {
1050 assert_eq!(REDWIRE_MAGIC, 0xFE);
1051 }
1052
1053 #[test]
1054 fn redwire_bulk_insert_graph_rows_returns_ids() {
1055 let runtime = RedDBRuntime::in_memory().expect("runtime");
1056 create_graph_collection(&runtime, "network");
1057
1058 let nodes = Frame::new(
1059 MessageKind::BulkInsert,
1060 7,
1061 br#"{"collection":"network","payloads":[{"label":"Host","name":"app"},{"label":"Host","name":"db"}]}"#.to_vec(),
1062 );
1063 let nodes_reply = run_insert_dispatch(&runtime, &nodes);
1064 assert_eq!(nodes_reply.kind, MessageKind::BulkOk);
1065 let node_body: JsonValue =
1066 serde_json::from_slice(&nodes_reply.payload).expect("nodes json");
1067 assert_eq!(
1068 node_body.get("affected").and_then(JsonValue::as_u64),
1069 Some(2)
1070 );
1071 let ids = node_body
1072 .get("ids")
1073 .and_then(JsonValue::as_array)
1074 .expect("node ids");
1075 assert_eq!(ids.len(), 2);
1076
1077 let from = ids[0].as_u64().expect("from id");
1078 let to = ids[1].as_u64().expect("to id");
1079 let edges = Frame::new(
1080 MessageKind::BulkInsert,
1081 8,
1082 format!(
1083 r#"{{"collection":"network","payloads":[{{"label":"connects","from":{from},"to":{to},"role":"primary"}}]}}"#
1084 )
1085 .into_bytes(),
1086 );
1087 let edges_reply = run_insert_dispatch(&runtime, &edges);
1088 assert_eq!(edges_reply.kind, MessageKind::BulkOk);
1089 let edge_body: JsonValue =
1090 serde_json::from_slice(&edges_reply.payload).expect("edges json");
1091 assert_eq!(
1092 edge_body.get("affected").and_then(JsonValue::as_u64),
1093 Some(1)
1094 );
1095 assert_eq!(
1096 edge_body
1097 .get("ids")
1098 .and_then(JsonValue::as_array)
1099 .map(|ids| ids.len()),
1100 Some(1)
1101 );
1102 }
1103
1104 async fn read_one_frame<R: tokio::io::AsyncRead + Unpin>(r: &mut R) -> Frame {
1106 let mut header = [0u8; FRAME_HEADER_SIZE];
1107 r.read_exact(&mut header).await.expect("read header");
1108 let length = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;
1109 let mut buf = vec![0u8; length];
1110 buf[..FRAME_HEADER_SIZE].copy_from_slice(&header);
1111 if length > FRAME_HEADER_SIZE {
1112 r.read_exact(&mut buf[FRAME_HEADER_SIZE..])
1113 .await
1114 .expect("read body");
1115 }
1116 let (frame, _) = decode_frame(&buf).expect("decode");
1117 frame
1118 }
1119
1120 fn stream_start_payload(coll: &str, cols: &[&str]) -> Vec<u8> {
1121 let mut p = Vec::new();
1122 p.extend_from_slice(&(coll.len() as u16).to_le_bytes());
1123 p.extend_from_slice(coll.as_bytes());
1124 p.extend_from_slice(&(cols.len() as u16).to_le_bytes());
1125 for c in cols {
1126 p.extend_from_slice(&(c.len() as u16).to_le_bytes());
1127 p.extend_from_slice(c.as_bytes());
1128 }
1129 p
1130 }
1131
1132 fn stream_rows_payload(rows: &[(i64, &str)]) -> Vec<u8> {
1133 let mut p = Vec::new();
1134 p.extend_from_slice(&(rows.len() as u32).to_le_bytes());
1135 for (id, name) in rows {
1136 crate::wire::protocol::encode_value(
1137 &mut p,
1138 &crate::storage::schema::Value::Integer(*id),
1139 );
1140 crate::wire::protocol::encode_value(
1141 &mut p,
1142 &crate::storage::schema::Value::text(name.to_string()),
1143 );
1144 }
1145 p
1146 }
1147
1148 #[tokio::test]
1155 async fn bulk_stream_rows_success_emits_no_response_frame() {
1156 let runtime = std::sync::Arc::new(RedDBRuntime::in_memory().expect("runtime"));
1158 runtime
1159 .execute_query("CREATE TABLE target (id INT, name TEXT)")
1160 .expect("create table");
1161
1162 let (server_io, mut client) = tokio::io::duplex(64 * 1024);
1165
1166 let server_task = tokio::spawn(async move {
1167 let _ = handle_session(server_io, runtime, None, None).await;
1168 });
1169
1170 client.write_all(&[1u8]).await.expect("write minor");
1173
1174 let hello_payload =
1176 br#"{"versions":[1],"auth_methods":["anonymous"],"features":0,"client_name":"test"}"#
1177 .to_vec();
1178 let hello = encode_frame(&Frame::new(MessageKind::Hello, 1, hello_payload));
1179 client.write_all(&hello).await.expect("write hello");
1180
1181 let ack = read_one_frame(&mut client).await;
1183 assert_eq!(ack.kind, MessageKind::HelloAck);
1184
1185 let authresp = encode_frame(&Frame::new(MessageKind::AuthResponse, 2, b"{}".to_vec()));
1187 client.write_all(&authresp).await.expect("write authresp");
1188
1189 let auth_ok = read_one_frame(&mut client).await;
1191 assert_eq!(auth_ok.kind, MessageKind::AuthOk);
1192
1193 let start = encode_frame(&Frame::new(
1195 MessageKind::BulkStreamStart,
1196 3,
1197 stream_start_payload("target", &["id", "name"]),
1198 ));
1199 client.write_all(&start).await.expect("write start");
1200 let start_ack = read_one_frame(&mut client).await;
1201 assert_eq!(start_ack.kind, MessageKind::BulkStreamAck);
1202 assert_eq!(start_ack.correlation_id, 3);
1203
1204 let rows = encode_frame(&Frame::new(
1206 MessageKind::BulkStreamRows,
1207 4,
1208 stream_rows_payload(&[(1, "a"), (2, "b")]),
1209 ));
1210 client.write_all(&rows).await.expect("write rows");
1211
1212 let commit = encode_frame(&Frame::new(MessageKind::BulkStreamCommit, 5, vec![]));
1218 client.write_all(&commit).await.expect("write commit");
1219
1220 let next = read_one_frame(&mut client).await;
1221 assert_eq!(
1222 next.kind,
1223 MessageKind::BulkOk,
1224 "expected BulkOk after commit; got {:?} — BulkStreamRows leaked an ack frame",
1225 next.kind
1226 );
1227 assert_eq!(
1228 next.correlation_id, 5,
1229 "commit response must carry the commit's correlation id"
1230 );
1231
1232 let bye = encode_frame(&Frame::new(MessageKind::Bye, 6, vec![]));
1234 client.write_all(&bye).await.expect("write bye");
1235 let _ = read_one_frame(&mut client).await; drop(client);
1237 let _ = server_task.await;
1238 }
1239
1240 #[tokio::test]
1243 async fn bulk_stream_rows_error_still_emits_error_frame() {
1244 let runtime = std::sync::Arc::new(RedDBRuntime::in_memory().expect("runtime"));
1245 let (server_io, mut client) = tokio::io::duplex(64 * 1024);
1246
1247 let server_task = tokio::spawn(async move {
1248 let _ = handle_session(server_io, runtime, None, None).await;
1249 });
1250
1251 client.write_all(&[1u8]).await.unwrap();
1252 let hello_payload =
1253 br#"{"versions":[1],"auth_methods":["anonymous"],"features":0}"#.to_vec();
1254 client
1255 .write_all(&encode_frame(&Frame::new(
1256 MessageKind::Hello,
1257 1,
1258 hello_payload,
1259 )))
1260 .await
1261 .unwrap();
1262 let _ack = read_one_frame(&mut client).await;
1263 client
1264 .write_all(&encode_frame(&Frame::new(
1265 MessageKind::AuthResponse,
1266 2,
1267 b"{}".to_vec(),
1268 )))
1269 .await
1270 .unwrap();
1271 let _auth_ok = read_one_frame(&mut client).await;
1272
1273 let rows = encode_frame(&Frame::new(
1277 MessageKind::BulkStreamRows,
1278 7,
1279 stream_rows_payload(&[(1, "a")]),
1280 ));
1281 client.write_all(&rows).await.unwrap();
1282 let resp = read_one_frame(&mut client).await;
1283 assert_eq!(resp.kind, MessageKind::Error);
1284 assert_eq!(resp.correlation_id, 7);
1285
1286 drop(client);
1287 let _ = server_task.await;
1288 }
1289}