use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, instrument, warn};
use crate::bridge::envelope::{PhysicalPlan, Priority, Request, Status};
use crate::bridge::physical_plan::{CrdtOp, DocumentOp, GraphOp, VectorOp};
use crate::control::state::SharedState;
use crate::types::{ReadConsistency, RequestId, TenantId, VShardId};
const MAX_FRAME_SIZE: u32 = 16 * 1024 * 1024;
use super::conn_stream::ConnStream;
pub struct Session {
stream: ConnStream,
peer_addr: SocketAddr,
next_request_id: AtomicU64,
state: Arc<SharedState>,
auth_mode: crate::config::auth::AuthMode,
identity: Option<crate::control::security::identity::AuthenticatedIdentity>,
}
impl Session {
fn with_stream(
stream: ConnStream,
peer_addr: SocketAddr,
state: Arc<SharedState>,
auth_mode: crate::config::auth::AuthMode,
) -> Self {
Self {
stream,
peer_addr,
next_request_id: AtomicU64::new(1),
state,
auth_mode,
identity: None,
}
}
pub fn new(
stream: TcpStream,
peer_addr: SocketAddr,
state: Arc<SharedState>,
auth_mode: crate::config::auth::AuthMode,
) -> Self {
Self::with_stream(ConnStream::plain(stream), peer_addr, state, auth_mode)
}
pub fn new_tls(
stream: tokio_rustls::server::TlsStream<TcpStream>,
peer_addr: SocketAddr,
state: Arc<SharedState>,
auth_mode: crate::config::auth::AuthMode,
) -> Self {
Self::with_stream(ConnStream::tls(stream), peer_addr, state, auth_mode)
}
fn next_request_id(&self) -> RequestId {
RequestId::new(self.next_request_id.fetch_add(1, Ordering::Relaxed))
}
#[instrument(skip(self), fields(peer = %self.peer_addr))]
pub async fn run(mut self) -> crate::Result<()> {
let idle_timeout_secs = self.state.idle_timeout_secs();
loop {
let mut len_buf = [0u8; 4];
let read_result: std::io::Result<usize> = if idle_timeout_secs > 0 {
match tokio::time::timeout(
Duration::from_secs(idle_timeout_secs),
self.stream.read_exact(&mut len_buf),
)
.await
{
Ok(result) => result,
Err(_) => {
debug!("session idle timeout ({}s)", idle_timeout_secs);
return Ok(());
}
}
} else {
self.stream.read_exact(&mut len_buf).await
};
match read_result {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
debug!("client disconnected");
return Ok(());
}
Err(e) => return Err(e.into()),
}
let payload_len = u32::from_be_bytes(len_buf);
if payload_len > MAX_FRAME_SIZE {
warn!(payload_len, "frame too large, closing connection");
return Err(crate::Error::BadRequest {
detail: format!("frame size {payload_len} exceeds maximum {MAX_FRAME_SIZE}"),
});
}
let mut payload = vec![0u8; payload_len as usize];
self.stream.read_exact(&mut payload).await?;
let request_id = self.next_request_id();
match self.handle_frame(request_id, &payload).await {
Ok(response_bytes) => {
let resp_len = (response_bytes.len() as u32).to_be_bytes();
self.stream.write_all(&resp_len).await?;
self.stream.write_all(&response_bytes).await?;
}
Err(e) => {
let error_json = format!(r#"{{"status":"error","error":"{e}"}}"#);
let resp_len = (error_json.len() as u32).to_be_bytes();
self.stream.write_all(&resp_len).await?;
self.stream.write_all(error_json.as_bytes()).await?;
}
}
}
}
async fn handle_frame(
&mut self,
request_id: RequestId,
payload: &[u8],
) -> crate::Result<Vec<u8>> {
let body: serde_json::Value =
serde_json::from_slice(payload).map_err(|e| crate::Error::BadRequest {
detail: format!("invalid JSON: {e}"),
})?;
let op = body["op"]
.as_str()
.ok_or_else(|| crate::Error::BadRequest {
detail: "missing 'op' field".into(),
})?;
if op == "auth" {
let identity = super::session_auth::authenticate(
&self.state,
&self.auth_mode,
&body,
&self.peer_addr.to_string(),
)?;
let resp = format!(
r#"{{"status":"ok","username":"{}","tenant_id":{}}}"#,
identity.username,
identity.tenant_id.as_u32()
);
self.identity = Some(identity);
return Ok(resp.into_bytes());
}
if self.identity.is_none() {
if self.auth_mode == crate::config::auth::AuthMode::Trust {
self.identity = Some(super::session_auth::trust_identity(
&self.state,
"anonymous",
));
} else {
return Err(crate::Error::RejectedAuthz {
tenant_id: TenantId::new(0),
resource: r#"not authenticated. Send {"op":"auth",...} first."#.into(),
});
}
}
let identity = match self.identity.as_ref() {
Some(id) => id,
None => {
return Err(crate::Error::RejectedAuthz {
tenant_id: TenantId::new(0),
resource: "not authenticated".into(),
});
}
};
let tenant_id = identity.tenant_id;
let collection = body["collection"].as_str().unwrap_or("default").to_string();
let vshard_key = body["document_id"].as_str().unwrap_or(&collection);
let vshard_id = VShardId::from_key(vshard_key.as_bytes());
let plan = match op {
"point_get" => {
let document_id = body["document_id"]
.as_str()
.ok_or_else(|| crate::Error::BadRequest {
detail: "missing 'document_id'".into(),
})?
.to_string();
PhysicalPlan::Document(DocumentOp::PointGet {
collection,
document_id,
rls_filters: Vec::new(),
})
}
"vector_search" => {
let query_vector: Vec<f32> = body["query_vector"]
.as_array()
.ok_or_else(|| crate::Error::BadRequest {
detail: "missing 'query_vector'".into(),
})?
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
let top_k = body["top_k"].as_u64().unwrap_or(10) as usize;
PhysicalPlan::Vector(VectorOp::Search {
collection,
query_vector: Arc::from(query_vector.into_boxed_slice()),
top_k,
ef_search: 0,
filter_bitmap: None,
field_name: String::new(),
rls_filters: Vec::new(),
})
}
"range_scan" => {
let field = body["field"]
.as_str()
.ok_or_else(|| crate::Error::BadRequest {
detail: "missing 'field'".into(),
})?
.to_string();
let limit = body["limit"].as_u64().unwrap_or(100) as usize;
PhysicalPlan::Document(DocumentOp::RangeScan {
collection,
field,
lower: None,
upper: None,
limit,
})
}
"crdt_read" => {
let document_id = body["document_id"]
.as_str()
.ok_or_else(|| crate::Error::BadRequest {
detail: "missing 'document_id'".into(),
})?
.to_string();
PhysicalPlan::Crdt(CrdtOp::Read {
collection,
document_id,
})
}
"crdt_apply" => {
let document_id = body["document_id"]
.as_str()
.ok_or_else(|| crate::Error::BadRequest {
detail: "missing 'document_id'".into(),
})?
.to_string();
let delta_b64 = body["delta"]
.as_str()
.ok_or_else(|| crate::Error::BadRequest {
detail: "missing 'delta'".into(),
})?;
let delta = delta_b64.as_bytes().to_vec();
let peer_id = body["peer_id"].as_u64().unwrap_or(0);
PhysicalPlan::Crdt(CrdtOp::Apply {
collection,
document_id,
delta,
peer_id,
mutation_id: 0,
})
}
"graph_rag_fusion" => {
let query_vector: Vec<f32> = body["query_vector"]
.as_array()
.ok_or_else(|| crate::Error::BadRequest {
detail: "missing 'query_vector'".into(),
})?
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
let vector_top_k = body["vector_top_k"].as_u64().unwrap_or(20) as usize;
let edge_label = body["edge_label"].as_str().map(String::from);
let direction_str = body["direction"].as_str().unwrap_or("out");
let direction = match direction_str {
"in" => crate::engine::graph::edge_store::Direction::In,
"both" => crate::engine::graph::edge_store::Direction::Both,
_ => crate::engine::graph::edge_store::Direction::Out,
};
let expansion_depth = body["expansion_depth"].as_u64().unwrap_or(2) as usize;
let final_top_k = body["final_top_k"].as_u64().unwrap_or(10) as usize;
let vector_k = body["vector_k"].as_f64().unwrap_or(60.0);
let graph_k = body["graph_k"].as_f64().unwrap_or(10.0);
PhysicalPlan::Graph(GraphOp::RagFusion {
collection,
query_vector: Arc::from(query_vector.into_boxed_slice()),
vector_top_k,
edge_label,
direction,
expansion_depth,
final_top_k,
rrf_k: (vector_k, graph_k),
options: Default::default(),
})
}
"alter_collection_policy" => {
let policy = &body["policy"];
if policy.is_null() {
return Err(crate::Error::BadRequest {
detail: "missing 'policy' field".into(),
});
}
let policy_json =
serde_json::to_string(policy).map_err(|e| crate::Error::BadRequest {
detail: format!("invalid policy JSON: {e}"),
})?;
PhysicalPlan::Crdt(CrdtOp::SetPolicy {
collection,
policy_json,
})
}
_ => {
return Err(crate::Error::BadRequest {
detail: format!("unknown op: {op}"),
});
}
};
let request = Request {
request_id,
tenant_id,
vshard_id,
plan,
deadline: Instant::now()
+ Duration::from_secs(self.state.tuning.network.default_deadline_secs),
priority: Priority::Normal,
trace_id: 0,
consistency: ReadConsistency::Strong,
idempotency_key: None,
};
let rx = self.state.tracker.register_oneshot(request_id);
match self.state.dispatcher.lock() {
Ok(mut d) => d.dispatch(request)?,
Err(poisoned) => poisoned.into_inner().dispatch(request)?,
};
let response = tokio::time::timeout(
Duration::from_secs(self.state.tuning.network.default_deadline_secs),
rx,
)
.await
.map_err(|_| crate::Error::DeadlineExceeded { request_id })?
.map_err(|_| crate::Error::Dispatch {
detail: "response channel closed".into(),
})?;
let status_str = match response.status {
Status::Ok => "ok",
Status::Partial => "partial",
Status::Error => "error",
};
let payload_b64 = if response.payload.is_empty() {
String::new()
} else {
String::from_utf8_lossy(&response.payload).into_owned()
};
let error_code_str = response.error_code.as_ref().map(|ec| format!("{ec:?}"));
let resp_json = format!(
r#"{{"request_id":{},"status":"{}","payload":"{}","watermark_lsn":{},"error_code":{}}}"#,
response.request_id.as_u64(),
status_str,
payload_b64,
response.watermark_lsn.as_u64(),
error_code_str
.as_ref()
.map(|s| format!("\"{s}\""))
.unwrap_or_else(|| "null".to_string()),
);
Ok(resp_json.into_bytes())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bridge::dispatch::Dispatcher;
use crate::data::executor::core_loop::CoreLoop;
use crate::wal::WalManager;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
#[tokio::test]
async fn full_request_response_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let wal_path = dir.path().join("test.wal");
let wal = Arc::new(WalManager::open_for_testing(&wal_path).unwrap());
let (dispatcher, data_sides) = Dispatcher::new(1, 64);
let shared = SharedState::new(dispatcher, wal);
let data_side = data_sides.into_iter().next().unwrap();
let core_dir = dir.path().to_path_buf();
let (core_stop_tx, core_stop_rx) = std::sync::mpsc::channel::<()>();
let core_handle = tokio::task::spawn_blocking(move || {
let mut core =
CoreLoop::open(0, data_side.request_rx, data_side.response_tx, &core_dir).unwrap();
while core_stop_rx.try_recv().is_err() {
core.tick();
std::thread::sleep(Duration::from_millis(1));
}
});
let shared_poller = Arc::clone(&shared);
let (poller_stop_tx, mut poller_stop_rx) = tokio::sync::watch::channel(false);
let poller_handle = tokio::spawn(async move {
loop {
shared_poller.poll_and_route_responses();
tokio::select! {
_ = tokio::time::sleep(Duration::from_millis(1)) => {}
_ = poller_stop_rx.changed() => break,
}
}
});
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let shared_session = Arc::clone(&shared);
let session_handle = tokio::spawn(async move {
let (stream, peer_addr) = listener.accept().await.unwrap();
let session = Session::new(
stream,
peer_addr,
shared_session,
crate::config::auth::AuthMode::Trust,
);
session.run().await
});
let mut client = tokio::net::TcpStream::connect(addr).await.unwrap();
let request_json =
r#"{"op":"point_get","tenant_id":1,"collection":"users","document_id":"u1"}"#;
let len = (request_json.len() as u32).to_be_bytes();
client.write_all(&len).await.unwrap();
client.write_all(request_json.as_bytes()).await.unwrap();
let mut resp_len_buf = [0u8; 4];
client.read_exact(&mut resp_len_buf).await.unwrap();
let resp_len = u32::from_be_bytes(resp_len_buf) as usize;
let mut resp_buf = vec![0u8; resp_len];
client.read_exact(&mut resp_buf).await.unwrap();
let resp_str = String::from_utf8(resp_buf).unwrap();
assert!(
resp_str.contains(r#""status""#),
"expected a valid response, got: {resp_str}"
);
assert!(
resp_str.contains(r#""request_id""#),
"expected request_id in response, got: {resp_str}"
);
drop(client);
let _ = session_handle.await;
let _ = poller_stop_tx.send(true);
let _ = poller_handle.await;
let _ = core_stop_tx.send(());
let _ = core_handle.await;
}
}