use std::collections::{HashMap, HashSet};
use std::pin::Pin;
use std::sync::{Arc, Weak};
use async_trait::async_trait;
use dashmap::DashMap;
use futures::stream::Stream;
use futures::StreamExt;
use parking_lot::RwLock;
use tokio::sync::mpsc;
use tokio::sync::Notify;
use tokio_stream::wrappers::ReceiverStream;
use super::error::MeshError;
use super::executor::MeshQueryExecutor;
use super::federated::{MeshDbTransport, ResponseStream, TransportError};
use super::protocol::{MeshDbFrame, MeshDbRequest, MeshDbResponse, ResultBatch};
/// Capacity of the per-call mpsc inbox that buffers responses routed to a
/// local caller. When it is full, routing a response reports
/// [`MeshDbRouteError::InboxFull`].
pub const MESHDB_RESPONSE_INBOX_CAPACITY: usize = 64;
/// Outbox capacity constant for the server side.
/// NOTE(review): not referenced by the code in this file — confirm it is still needed.
pub const MESHDB_SERVER_OUTBOX_CAPACITY: usize = 64;
/// Maximum number of rows the server packs into a single `Batch` frame.
pub const MESHDB_SERVER_BATCH_ROWS: usize = 64;
/// Upper bound on cancels parked for calls whose `Execute` has not been seen.
pub const MESHDB_SERVER_PENDING_CANCELS_CAP: usize = 256;
/// Largest inbound frame, in bytes (1 MiB), that is handed to the decoder.
pub const MESHDB_MAX_INBOUND_FRAME_BYTES: usize = 1 << 20;
/// Sink for raw meshdb subprotocol bytes received from a peer node.
///
/// `from_node` identifies the node the bytes arrived from; `bytes` is the
/// still-encoded frame. In this file the implementor is
/// `MeshDbWireDispatcher`, which `enable_meshdb_on_mesh` installs on a
/// `MeshNode` through `set_meshdb_inbound_router`.
pub trait MeshDbInboundRouter: Send + Sync {
/// Routes one inbound frame to its local handler, or reports why it could not.
fn try_route(&self, from_node: u64, bytes: &[u8]) -> Result<(), MeshDbRouteError>;
}
/// Reasons an inbound meshdb frame could not be routed.
#[derive(Debug)]
pub enum MeshDbRouteError {
/// The bytes did not decode as a `MeshDbFrame`.
Decode(postcard::Error),
/// The frame exceeded `MESHDB_MAX_INBOUND_FRAME_BYTES`; carries the actual length.
FrameTooLarge(usize),
/// A request arrived but no `MeshDbServer` is installed on this node.
NoServer,
/// A response carried a `call_id` with no live local caller. The dispatcher
/// in this file also reports it when the caller's channel is already closed.
UnknownCallId(u64),
/// A response arrived from a node other than the one the call was sent to.
WrongPeer {
call_id: u64,
/// Node the in-flight call targeted.
expected: u64,
/// Node the response actually came from.
actual: u64,
},
/// The caller's response inbox was full; carries the `call_id`.
InboxFull(u64),
}
impl std::fmt::Display for MeshDbRouteError {
    /// Renders the routing failure as one human-readable line; call ids and
    /// node ids are printed in hex.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Decode(ref e) => write!(f, "meshdb frame decode failed: {e}"),
            Self::FrameTooLarge(n) => write!(
                f,
                "meshdb inbound frame too large: {n} > {MESHDB_MAX_INBOUND_FRAME_BYTES} bytes"
            ),
            Self::NoServer => f.write_str("no MeshDbServer installed on this node"),
            Self::UnknownCallId(id) => {
                write!(f, "no in-flight caller for call_id={id:#x}")
            }
            Self::WrongPeer {
                call_id,
                expected,
                actual,
            } => write!(
                f,
                "meshdb response for call_id={call_id:#x} arrived from node {actual:#x}; expected {expected:#x}"
            ),
            Self::InboxFull(id) => write!(f, "caller mpsc full for call_id={id:#x}"),
        }
    }
}

impl std::error::Error for MeshDbRouteError {}
/// Outbound half of the meshdb wire: delivers one encoded frame to a node.
///
/// In this file it is used both by `MeshDbWireTransport` (to send requests)
/// and by `MeshDbServer` (to send responses). `MeshNodeMeshDbSender` is the
/// `MeshNode`-backed implementation.
#[async_trait]
pub trait MeshDbWireSender: Send + Sync {
async fn send_frame(&self, target_node: u64, frame: MeshDbFrame) -> Result<(), TransportError>;
}
/// Client-side record for one outstanding call, keyed by `call_id`.
struct InflightCaller {
/// Feeds routed responses into the caller's `ResponseStream`.
tx: mpsc::Sender<MeshDbResponse>,
/// Node the request was sent to; responses from any other node are rejected.
target_node: u64,
}
/// Routes inbound meshdb frames for one node and owns the state shared with
/// the client transports it hands out.
///
/// Inbound requests go to the installed `MeshDbServer`; inbound responses are
/// matched to callers in `inflight` by `call_id`.
pub struct MeshDbWireDispatcher {
/// Outbound path, shared with transports and passed to the server.
sender: Arc<dyn MeshDbWireSender>,
/// Outstanding client calls by `call_id`; shared with every transport from `transport()`.
inflight: Arc<DashMap<u64, InflightCaller>>,
/// Server for inbound requests; `None` makes requests fail with `NoServer`.
server: Arc<RwLock<Option<Arc<MeshDbServer>>>>,
}
impl MeshDbWireDispatcher {
    /// Creates a dispatcher that sends outbound frames through `sender`.
    ///
    /// It starts with an empty in-flight caller table and no server, so inbound
    /// requests are rejected with [`MeshDbRouteError::NoServer`] until
    /// [`Self::set_server`] installs one.
    pub fn new(sender: Arc<dyn MeshDbWireSender>) -> Self {
        let inflight = Arc::new(DashMap::new());
        let server = Arc::new(RwLock::new(None));
        Self {
            sender,
            inflight,
            server,
        }
    }

    /// Installs (`Some`) or removes (`None`) the server that answers inbound
    /// requests, replacing whatever was installed before.
    pub fn set_server(&self, server: Option<Arc<MeshDbServer>>) {
        let mut slot = self.server.write();
        *slot = server;
    }

    /// Returns a client transport that shares this dispatcher's sender and
    /// in-flight caller table, so responses routed here reach its callers.
    pub fn transport(&self) -> Arc<MeshDbWireTransport> {
        let transport = MeshDbWireTransport {
            sender: Arc::clone(&self.sender),
            inflight: Arc::clone(&self.inflight),
        };
        Arc::new(transport)
    }
}
impl MeshDbInboundRouter for MeshDbWireDispatcher {
fn try_route(&self, from_node: u64, bytes: &[u8]) -> Result<(), MeshDbRouteError> {
if bytes.len() > MESHDB_MAX_INBOUND_FRAME_BYTES {
return Err(MeshDbRouteError::FrameTooLarge(bytes.len()));
}
let frame = MeshDbFrame::decode(bytes).map_err(MeshDbRouteError::Decode)?;
match frame {
MeshDbFrame::Request(req) => {
let server = self.server.read().clone();
match server {
Some(srv) => {
srv.dispatch_request(from_node, req, self.sender.clone());
Ok(())
}
None => Err(MeshDbRouteError::NoServer),
}
}
MeshDbFrame::Response(resp) => {
let call_id = response_call_id(&resp);
let entry = self
.inflight
.get(&call_id)
.ok_or(MeshDbRouteError::UnknownCallId(call_id))?;
if entry.target_node != from_node {
return Err(MeshDbRouteError::WrongPeer {
call_id,
expected: entry.target_node,
actual: from_node,
});
}
entry.tx.try_send(resp).map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => MeshDbRouteError::InboxFull(call_id),
mpsc::error::TrySendError::Closed(_) => {
MeshDbRouteError::UnknownCallId(call_id)
}
})?;
Ok(())
}
}
}
}
fn response_call_id(r: &MeshDbResponse) -> u64 {
match r {
MeshDbResponse::Batch { call_id, .. } => *call_id,
MeshDbResponse::End { call_id } => *call_id,
MeshDbResponse::Error { call_id, .. } => *call_id,
}
}
/// Client transport that sends meshdb requests over the wire and exposes the
/// routed responses as a stream.
///
/// Obtained from `MeshDbWireDispatcher::transport`, with which it shares the
/// in-flight caller table.
pub struct MeshDbWireTransport {
sender: Arc<dyn MeshDbWireSender>,
inflight: Arc<DashMap<u64, InflightCaller>>,
}
#[async_trait]
impl MeshDbTransport for MeshDbWireTransport {
    /// Sends `request` to `node` and returns a stream of its responses.
    ///
    /// The caller is registered in the shared in-flight table under the
    /// request's `call_id` *before* the frame goes out, so a fast reply cannot
    /// overtake the registration. The returned stream removes the registration
    /// when it is dropped.
    ///
    /// # Errors
    ///
    /// Returns the sender's [`TransportError`] if the request frame cannot be
    /// sent; the registration is removed in that case.
    ///
    /// # Cancellation
    ///
    /// The deregistering guard is built before the send is awaited. If this
    /// future is dropped mid-send, the registration is still removed; the
    /// previous version leaked it in the in-flight table.
    async fn send(
        &self,
        node: u64,
        request: MeshDbRequest,
    ) -> Result<ResponseStream, TransportError> {
        let call_id = request_call_id(&request);
        let (tx, rx) = mpsc::channel(MESHDB_RESPONSE_INBOX_CAPACITY);
        let prev = self.inflight.insert(
            call_id,
            InflightCaller {
                tx,
                target_node: node,
            },
        );
        debug_assert!(
            prev.is_none(),
            "duplicate inflight call_id={call_id:#x}; previous caller silently overwritten",
        );
        // Release any overwritten caller's sender now rather than after the await.
        drop(prev);
        let stream = ResponseStreamGuard {
            inner: ReceiverStream::new(rx),
            call_id,
            inflight: self.inflight.clone(),
            terminated: false,
        };
        // On failure `?` drops `stream`, whose `Drop` removes the registration.
        self.sender
            .send_frame(node, MeshDbFrame::Request(request))
            .await?;
        Ok(Box::pin(stream))
    }
}
/// Returns the `call_id` carried by any [`MeshDbRequest`] variant.
fn request_call_id(r: &MeshDbRequest) -> u64 {
    match r {
        MeshDbRequest::Execute { call_id, .. }
        | MeshDbRequest::Resume { call_id, .. }
        | MeshDbRequest::Cancel { call_id } => *call_id,
    }
}
/// `ResponseStream` returned by `MeshDbWireTransport::send`.
///
/// Owns the receiving end of the caller's inbox and removes `call_id` from
/// the shared in-flight table when dropped.
struct ResponseStreamGuard {
inner: ReceiverStream<MeshDbResponse>,
call_id: u64,
inflight: Arc<DashMap<u64, InflightCaller>>,
/// Set once a terminal response (`End`, `Error`, or a final `Batch`) has been yielded.
terminated: bool,
}
impl Drop for ResponseStreamGuard {
fn drop(&mut self) {
// Removes whatever is registered under `call_id`.
// NOTE(review): if a duplicate call_id ever overwrote this caller's entry,
// this also removes the newer caller's registration.
self.inflight.remove(&self.call_id);
}
}
impl Stream for ResponseStreamGuard {
    type Item = MeshDbResponse;

    /// Yields routed responses up to and including the first terminal one
    /// (`End`, `Error`, or a `Batch` whose `final` flag is set), then reports
    /// end-of-stream on every later poll, even if more responses are queued.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;
        if self.terminated {
            return Poll::Ready(None);
        }
        let polled = self.inner.poll_next_unpin(cx);
        if let Poll::Ready(Some(resp)) = &polled {
            // `terminated` is known to be false here, so plain assignment is
            // the same as only ever flipping it to true.
            self.terminated = matches!(
                resp,
                MeshDbResponse::End { .. }
                    | MeshDbResponse::Error { .. }
                    | MeshDbResponse::Batch {
                        batch: ResultBatch { r#final: true, .. },
                        ..
                    }
            );
        }
        polled
    }
}
/// Serves meshdb requests from peers by running them on a local
/// `MeshQueryExecutor` and streaming the results back.
pub struct MeshDbServer {
executor: Arc<dyn MeshQueryExecutor>,
/// Running calls keyed by `(peer, call_id)`; an entry is removed when its task finishes.
inflight: Arc<RwLock<HashMap<(u64, u64), ServerCallHandle>>>,
/// `(peer, call_id)` pairs whose `Cancel` arrived while no matching call was
/// registered; bounded by `MESHDB_SERVER_PENDING_CANCELS_CAP`.
pending_cancels: Arc<RwLock<HashSet<(u64, u64)>>>,
}
/// Server-side handle for one running call.
struct ServerCallHandle {
/// Notified when the peer cancels the call.
cancel: Arc<Notify>,
}
impl MeshDbServer {
pub fn new(executor: Arc<dyn MeshQueryExecutor>) -> Arc<Self> {
Arc::new(Self {
executor,
inflight: Arc::new(RwLock::new(HashMap::new())),
pending_cancels: Arc::new(RwLock::new(HashSet::new())),
})
}
pub fn inflight_calls(&self) -> usize {
self.inflight.read().len()
}
pub fn pending_cancels(&self) -> usize {
self.pending_cancels.read().len()
}
fn dispatch_request(
self: &Arc<Self>,
peer: u64,
request: MeshDbRequest,
sender: Arc<dyn MeshDbWireSender>,
) {
match request {
MeshDbRequest::Execute { call_id, plan } => {
let cancelled_early = self.pending_cancels.write().remove(&(peer, call_id));
if cancelled_early {
let sender_clone = sender.clone();
tokio::spawn(async move {
let _ = sender_clone
.send_frame(
peer,
MeshDbFrame::Response(MeshDbResponse::Error {
call_id,
error: MeshError::QueryCancelled,
}),
)
.await;
});
return;
}
let cancel = Arc::new(Notify::new());
self.inflight.write().insert(
(peer, call_id),
ServerCallHandle {
cancel: cancel.clone(),
},
);
let executor = self.executor.clone();
let inflight = self.inflight.clone();
tokio::spawn(async move {
run_server_call(peer, call_id, plan, executor, sender, cancel, inflight).await;
});
}
MeshDbRequest::Cancel { call_id } => {
let guard = self.inflight.read();
if let Some(handle) = guard.get(&(peer, call_id)) {
handle.cancel.notify_one();
} else {
drop(guard);
let mut pending = self.pending_cancels.write();
if pending.len() < MESHDB_SERVER_PENDING_CANCELS_CAP {
pending.insert((peer, call_id));
}
}
}
MeshDbRequest::Resume { call_id, .. } => {
let sender_clone = sender.clone();
tokio::spawn(async move {
let _ = sender_clone
.send_frame(
peer,
MeshDbFrame::Response(MeshDbResponse::Error {
call_id,
error: MeshError::PlannerError {
detail:
"MeshDbRequest::Resume is not yet supported by the server"
.to_string(),
},
}),
)
.await;
});
}
}
}
}
/// Executes `plan` for `(peer, call_id)` and streams the result back to `peer`.
///
/// Rows are sent in `Batch` frames of [`MESHDB_SERVER_BATCH_ROWS`] rows. Any
/// leftover rows go out in one last batch built with `ResultBatch::last`,
/// followed by an `End` frame. Executor and row errors are reported as an
/// `Error` frame. A notification on `cancel` cancels the running query through
/// its handle and replies `Error { QueryCancelled }`.
///
/// The `(peer, call_id)` entry in `inflight` is removed on every return path
/// by a drop guard.
///
/// If sending a full batch fails, the peer can no longer receive the result.
/// The running query is then cancelled through its handle and the call is
/// abandoned, rather than leaving the executor producing rows for nobody.
/// Failures while sending error, cancel, or terminal frames are ignored
/// (best effort).
async fn run_server_call(
    peer: u64,
    call_id: u64,
    plan: super::planner::ExecutionPlan,
    executor: Arc<dyn MeshQueryExecutor>,
    sender: Arc<dyn MeshDbWireSender>,
    cancel: Arc<Notify>,
    inflight: Arc<RwLock<HashMap<(u64, u64), ServerCallHandle>>>,
) {
    // Deregisters the call from the server's in-flight table when dropped.
    struct InflightGuard {
        peer: u64,
        call_id: u64,
        inflight: Arc<RwLock<HashMap<(u64, u64), ServerCallHandle>>>,
    }
    impl Drop for InflightGuard {
        fn drop(&mut self) {
            self.inflight.write().remove(&(self.peer, self.call_id));
        }
    }
    let _guard = InflightGuard {
        peer,
        call_id,
        inflight,
    };
    let running = match executor.execute(plan).await {
        Ok(r) => r,
        Err(err) => {
            let _ = sender
                .send_frame(
                    peer,
                    MeshDbFrame::Response(MeshDbResponse::Error {
                        call_id,
                        error: err,
                    }),
                )
                .await;
            return;
        }
    };
    let handle = running.handle.clone();
    let mut stream = running.rows;
    let mut batch: Vec<super::query::ResultRow> = Vec::with_capacity(MESHDB_SERVER_BATCH_ROWS);
    loop {
        tokio::select! {
            // Check for cancellation before pulling the next row.
            biased;
            _ = cancel.notified() => {
                handle.cancel();
                let _ = sender
                    .send_frame(
                        peer,
                        MeshDbFrame::Response(MeshDbResponse::Error {
                            call_id,
                            error: MeshError::QueryCancelled,
                        }),
                    )
                    .await;
                return;
            }
            item = stream.next() => {
                match item {
                    Some(Ok(row)) => {
                        batch.push(row);
                        if batch.len() >= MESHDB_SERVER_BATCH_ROWS {
                            // Swap in a preallocated buffer so the next batch
                            // does not regrow from zero capacity.
                            let chunk = std::mem::replace(
                                &mut batch,
                                Vec::with_capacity(MESHDB_SERVER_BATCH_ROWS),
                            );
                            if sender
                                .send_frame(
                                    peer,
                                    MeshDbFrame::Response(MeshDbResponse::Batch {
                                        call_id,
                                        batch: ResultBatch::chunk(chunk),
                                    }),
                                )
                                .await
                                .is_err()
                            {
                                // Nobody can receive the rest; stop the query.
                                handle.cancel();
                                return;
                            }
                        }
                    }
                    Some(Err(err)) => {
                        let _ = sender
                            .send_frame(
                                peer,
                                MeshDbFrame::Response(MeshDbResponse::Error { call_id, error: err }),
                            )
                            .await;
                        return;
                    }
                    None => break,
                }
            }
        }
    }
    // The row stream is exhausted: flush leftovers as the final batch, then End.
    if !batch.is_empty()
        && sender
            .send_frame(
                peer,
                MeshDbFrame::Response(MeshDbResponse::Batch {
                    call_id,
                    batch: ResultBatch::last(batch),
                }),
            )
            .await
            .is_err()
    {
        return;
    }
    let _ = sender
        .send_frame(peer, MeshDbFrame::Response(MeshDbResponse::End { call_id }))
        .await;
}
/// [`MeshDbWireSender`] that delivers frames through a `MeshNode`'s meshdb
/// subprotocol.
///
/// Only a weak reference to the node is kept. The node ends up owning this
/// sender indirectly (via the dispatcher installed as its inbound router), so
/// a strong reference would form a cycle. Sends fail once the node is gone.
pub struct MeshNodeMeshDbSender {
    mesh: Weak<crate::adapter::net::MeshNode>,
}

impl MeshNodeMeshDbSender {
    /// Creates a sender bound to `mesh` without keeping it alive.
    pub fn new(mesh: &Arc<crate::adapter::net::MeshNode>) -> Self {
        let mesh = Arc::downgrade(mesh);
        Self { mesh }
    }
}
#[async_trait]
impl MeshDbWireSender for MeshNodeMeshDbSender {
async fn send_frame(&self, target_node: u64, frame: MeshDbFrame) -> Result<(), TransportError> {
let mesh = self
.mesh
.upgrade()
.ok_or_else(|| TransportError::Other("mesh node dropped".into()))?;
let peer_addr = mesh
.peer_addr(target_node)
.ok_or(TransportError::NoRoute(target_node))?;
let bytes = frame
.encode()
.map_err(|e| TransportError::Other(format!("frame encode: {e}")))?;
mesh.send_subprotocol(peer_addr, super::protocol::SUBPROTOCOL_MESHDB, &bytes)
.await
.map_err(|e| TransportError::Other(e.to_string()))
}
}
/// Wires meshdb into `mesh` and returns the pieces a caller needs.
///
/// Builds a dispatcher that sends through the node and installs `server` (if
/// any) to answer inbound requests. It then derives a client transport,
/// registers the dispatcher as the node's meshdb inbound router, and returns
/// `(dispatcher, transport)`.
pub fn enable_meshdb_on_mesh(
    mesh: &Arc<crate::adapter::net::MeshNode>,
    server: Option<Arc<MeshDbServer>>,
) -> (Arc<MeshDbWireDispatcher>, Arc<MeshDbWireTransport>) {
    let dispatcher = Arc::new(MeshDbWireDispatcher::new(Arc::new(
        MeshNodeMeshDbSender::new(mesh),
    )));
    dispatcher.set_server(server);
    let transport = dispatcher.transport();
    let router: Arc<dyn MeshDbInboundRouter> = Arc::clone(&dispatcher);
    mesh.set_meshdb_inbound_router(Some(router));
    (dispatcher, transport)
}
// Unit tests for the meshdb wire layer. Everything runs in-process: frames are
// encoded, handed straight to the target node's dispatcher via `InMemoryWire`,
// and decoded again, so no real `MeshNode` is involved.
#[cfg(test)]
mod tests {
#![allow(
clippy::disallowed_methods,
reason = "test code legitimately uses std::sync::{Mutex,RwLock} for SUT setup; tests have no real poison concern"
)]
use super::*;
use crate::adapter::net::behavior::meshdb::executor::{
ChainReader, LocalMeshQueryExecutor, MeshQueryExecutor,
};
use crate::adapter::net::behavior::meshdb::planner::{
CostEstimate, ExecutionPlan, OperatorNode, OperatorPlan,
};
use crate::adapter::net::behavior::meshdb::query::SeqNum;
use std::collections::BTreeMap;
// In-memory "network": node id -> that node's dispatcher.
// `from_node_of` is written by `set_local` but not read by any test here.
#[derive(Default)]
struct InMemoryWire {
dispatchers: parking_lot::Mutex<HashMap<u64, Arc<MeshDbWireDispatcher>>>,
from_node_of: parking_lot::Mutex<HashMap<u64, u64>>,
}
impl InMemoryWire {
fn register(&self, local_node: u64, dispatcher: Arc<MeshDbWireDispatcher>) {
self.dispatchers.lock().insert(local_node, dispatcher);
}
fn set_local(&self, target_node: u64, local_node: u64) {
self.from_node_of.lock().insert(target_node, local_node);
}
}
// Sender for node `local_node`: encodes the frame and synchronously calls the
// target dispatcher's `try_route`, so route errors surface as send errors.
struct SenderTo {
wire: Arc<InMemoryWire>,
local_node: u64,
}
#[async_trait]
impl MeshDbWireSender for SenderTo {
async fn send_frame(
&self,
target_node: u64,
frame: MeshDbFrame,
) -> Result<(), TransportError> {
let bytes = frame
.encode()
.map_err(|e| TransportError::Other(e.to_string()))?;
// Clone the Arc out so the wire's lock is not held while routing.
let dispatcher = self
.wire
.dispatchers
.lock()
.get(&target_node)
.cloned()
.ok_or(TransportError::NoRoute(target_node))?;
self.wire.set_local(target_node, self.local_node);
dispatcher
.try_route(self.local_node, &bytes)
.map_err(|e| TransportError::Other(e.to_string()))?;
Ok(())
}
}
// `ChainReader` backed by per-origin maps of sequence number -> payload.
#[derive(Default)]
struct InMemoryChainReader {
chains: std::sync::Mutex<BTreeMap<u64, BTreeMap<SeqNum, Vec<u8>>>>,
}
impl InMemoryChainReader {
fn append(&self, origin: u64, seq: SeqNum, payload: Vec<u8>) {
self.chains
.lock()
.unwrap()
.entry(origin)
.or_default()
.insert(seq, payload);
}
}
impl ChainReader for InMemoryChainReader {
fn read_one(&self, origin: u64, seq: SeqNum) -> Option<Vec<u8>> {
self.chains.lock().unwrap().get(&origin)?.get(&seq).cloned()
}
// Half-open range [start, end).
fn read_range(&self, origin: u64, start: SeqNum, end: SeqNum) -> Vec<(SeqNum, Vec<u8>)> {
self.chains
.lock()
.unwrap()
.get(&origin)
.map(|c| c.range(start..end).map(|(s, p)| (*s, p.clone())).collect())
.unwrap_or_default()
}
// Highest stored sequence number for `origin`, if any.
fn latest_seq(&self, origin: u64) -> Option<SeqNum> {
self.chains
.lock()
.unwrap()
.get(&origin)?
.keys()
.next_back()
.copied()
}
}
// Single-operator plan targeting node 0xB with default cost estimates.
fn atomic_plan(op: OperatorPlan) -> ExecutionPlan {
ExecutionPlan {
root: OperatorNode {
operator: op,
target_nodes: vec![0xB],
cost: CostEstimate::default(),
},
total_cost: CostEstimate::default(),
}
}
// End-to-end: A's federated executor queries B over the wire and gets B's
// single row back; B's server has no calls left registered afterwards.
#[tokio::test]
async fn wire_dispatcher_round_trips_a_latest_query_across_two_nodes() {
let wire = Arc::new(InMemoryWire::default());
let node_a: u64 = 0xA;
let node_b: u64 = 0xB;
let sender_a = Arc::new(SenderTo {
wire: wire.clone(),
local_node: node_a,
});
let dispatcher_a = Arc::new(MeshDbWireDispatcher::new(sender_a));
wire.register(node_a, dispatcher_a.clone());
let reader_b = Arc::new(InMemoryChainReader::default());
reader_b.append(0xCAFE, SeqNum(7), b"hello-wire".to_vec());
let executor_b: Arc<dyn MeshQueryExecutor> =
Arc::new(LocalMeshQueryExecutor::new(reader_b));
let server_b = MeshDbServer::new(executor_b);
let sender_b = Arc::new(SenderTo {
wire: wire.clone(),
local_node: node_b,
});
let dispatcher_b = Arc::new(MeshDbWireDispatcher::new(sender_b));
dispatcher_b.set_server(Some(server_b.clone()));
wire.register(node_b, dispatcher_b);
let transport_a = dispatcher_a.transport();
let fed_a =
crate::adapter::net::behavior::meshdb::federated::FederatedMeshQueryExecutor::new(
transport_a,
);
let plan = atomic_plan(OperatorPlan::LatestRead { origin: 0xCAFE });
let running = fed_a
.execute(plan)
.await
.expect("federated execute over the wire");
use futures::StreamExt;
let mut rows = Vec::new();
let mut stream = running.rows;
while let Some(item) = stream.next().await {
rows.push(item.expect("row"));
}
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].origin, 0xCAFE);
assert_eq!(rows[0].seq, SeqNum(7));
assert_eq!(rows[0].payload, b"hello-wire");
assert_eq!(server_b.inflight_calls(), 0);
}
// A request routed to a dispatcher with no server installed is rejected.
#[tokio::test]
async fn wire_dispatcher_no_server_returns_route_error() {
let wire = Arc::new(InMemoryWire::default());
let sender = Arc::new(SenderTo {
wire: wire.clone(),
local_node: 0xA,
});
let dispatcher = Arc::new(MeshDbWireDispatcher::new(sender));
let frame = MeshDbFrame::Request(MeshDbRequest::Execute {
call_id: 42,
plan: atomic_plan(OperatorPlan::LatestRead { origin: 1 }),
});
let err = dispatcher
.try_route(0xB, &frame.encode().unwrap())
.expect_err("no server -> error");
assert!(matches!(err, MeshDbRouteError::NoServer));
}
// Undecodable bytes surface as a `Decode` error.
#[tokio::test]
async fn wire_dispatcher_decode_error_surfaces_route_error() {
let wire = Arc::new(InMemoryWire::default());
let sender = Arc::new(SenderTo {
wire,
local_node: 0xA,
});
let dispatcher = Arc::new(MeshDbWireDispatcher::new(sender));
let err = dispatcher
.try_route(0xB, &[0xFFu8; 8])
.expect_err("garbage bytes -> decode error");
assert!(matches!(err, MeshDbRouteError::Decode(_)));
}
// A response whose call_id has no registered caller is rejected.
#[tokio::test]
async fn wire_response_to_unknown_call_id_drops() {
let wire = Arc::new(InMemoryWire::default());
let sender = Arc::new(SenderTo {
wire,
local_node: 0xA,
});
let dispatcher = Arc::new(MeshDbWireDispatcher::new(sender));
let frame = MeshDbFrame::Response(MeshDbResponse::End { call_id: 999 });
let err = dispatcher
.try_route(0xB, &frame.encode().unwrap())
.expect_err("no caller -> error");
assert!(matches!(err, MeshDbRouteError::UnknownCallId(999)));
}
// The server answers `Resume` with a `PlannerError` naming the operation.
#[tokio::test]
async fn server_rejects_resume_with_planner_error() {
let wire = Arc::new(InMemoryWire::default());
let node_a: u64 = 0xA;
let node_b: u64 = 0xB;
let sender_a = Arc::new(SenderTo {
wire: wire.clone(),
local_node: node_a,
});
let dispatcher_a = Arc::new(MeshDbWireDispatcher::new(sender_a));
wire.register(node_a, dispatcher_a.clone());
let reader = Arc::new(InMemoryChainReader::default());
let executor: Arc<dyn MeshQueryExecutor> = Arc::new(LocalMeshQueryExecutor::new(reader));
let server = MeshDbServer::new(executor);
let sender_b = Arc::new(SenderTo {
wire: wire.clone(),
local_node: node_b,
});
let dispatcher_b = Arc::new(MeshDbWireDispatcher::new(sender_b));
dispatcher_b.set_server(Some(server));
wire.register(node_b, dispatcher_b);
let call_id = 0xBEEFu64;
// Register a caller on A by hand so B's reply has somewhere to land.
let (tx, mut rx) = mpsc::channel(8);
dispatcher_a.inflight.insert(
call_id,
InflightCaller {
tx,
target_node: node_b,
},
);
let sender_to_b = SenderTo {
wire: wire.clone(),
local_node: node_a,
};
sender_to_b
.send_frame(
node_b,
MeshDbFrame::Request(MeshDbRequest::Resume {
call_id,
token: super::super::protocol::ContinuationToken::new(vec![1, 2, 3]),
}),
)
.await
.expect("Resume must route");
let resp = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
.await
.expect("timed out waiting for Resume rejection")
.expect("server closed channel");
match resp {
MeshDbResponse::Error {
call_id: got,
error: MeshError::PlannerError { detail },
} => {
assert_eq!(got, call_id);
assert!(
detail.contains("Resume"),
"detail should name the unsupported op; got {detail:?}",
);
}
other => panic!("expected PlannerError; got {other:?}"),
}
}
// The size gate fires one byte above the cap; at the cap the bytes reach the decoder.
#[tokio::test]
async fn wire_dispatcher_rejects_oversize_frame_before_decode() {
let wire = Arc::new(InMemoryWire::default());
let sender = Arc::new(SenderTo {
wire,
local_node: 0xA,
});
let dispatcher = Arc::new(MeshDbWireDispatcher::new(sender));
let oversize = vec![0u8; MESHDB_MAX_INBOUND_FRAME_BYTES + 1];
let err = dispatcher
.try_route(0xB, &oversize)
.expect_err("oversize -> error");
match err {
MeshDbRouteError::FrameTooLarge(n) => {
assert_eq!(n, MESHDB_MAX_INBOUND_FRAME_BYTES + 1);
}
other => panic!("expected FrameTooLarge; got {other:?}"),
}
let at_cap = vec![0xFFu8; MESHDB_MAX_INBOUND_FRAME_BYTES];
let err = dispatcher
.try_route(0xB, &at_cap)
.expect_err("garbage at cap -> decode error, not size error");
assert!(matches!(err, MeshDbRouteError::Decode(_)));
}
// A Cancel that arrives before its Execute is parked; the Execute then
// short-circuits to QueryCancelled without invoking the executor.
#[tokio::test]
async fn server_cancel_before_execute_short_circuits_to_query_cancelled() {
use std::sync::atomic::{AtomicUsize, Ordering};
let wire = Arc::new(InMemoryWire::default());
let sender = Arc::new(SenderTo {
wire: wire.clone(),
local_node: 0xB,
});
use crate::adapter::net::behavior::meshdb::executor::{ExecuteOptions, RunningQuery};
// Wraps an executor and counts how many times it is invoked.
struct CountingExecutor {
executor: Arc<dyn MeshQueryExecutor>,
calls: Arc<AtomicUsize>,
}
#[async_trait]
impl MeshQueryExecutor for CountingExecutor {
async fn execute(&self, plan: ExecutionPlan) -> Result<RunningQuery, MeshError> {
self.calls.fetch_add(1, Ordering::SeqCst);
self.executor.execute(plan).await
}
async fn execute_with(
&self,
plan: ExecutionPlan,
options: ExecuteOptions,
) -> Result<RunningQuery, MeshError> {
self.calls.fetch_add(1, Ordering::SeqCst);
self.executor.execute_with(plan, options).await
}
}
let reader = Arc::new(InMemoryChainReader::default());
reader.append(0xCAFE, SeqNum(7), b"never-read".to_vec());
let inner: Arc<dyn MeshQueryExecutor> = Arc::new(LocalMeshQueryExecutor::new(reader));
let calls = Arc::new(AtomicUsize::new(0));
let executor: Arc<dyn MeshQueryExecutor> = Arc::new(CountingExecutor {
executor: inner,
calls: calls.clone(),
});
let server = MeshDbServer::new(executor);
let node_a: u64 = 0xA;
let sender_a = Arc::new(SenderTo {
wire: wire.clone(),
local_node: node_a,
});
let dispatcher_a = Arc::new(MeshDbWireDispatcher::new(sender_a));
wire.register(node_a, dispatcher_a.clone());
let call_id = 0xC0FFEE_u64;
let (tx, mut rx) = mpsc::channel(8);
dispatcher_a.inflight.insert(
call_id,
InflightCaller {
tx,
target_node: 0xB,
},
);
let plan = atomic_plan(OperatorPlan::LatestRead { origin: 0xCAFE });
// The server is driven directly (no dispatcher for node B); its replies go
// to A through `sender`, which identifies itself as 0xB.
let server_sender: Arc<dyn MeshDbWireSender> = sender.clone();
server.dispatch_request(
node_a,
MeshDbRequest::Cancel { call_id },
server_sender.clone(),
);
assert_eq!(
server.pending_cancels(),
1,
"cancel without a matching inflight handle must be parked",
);
server.dispatch_request(
node_a,
MeshDbRequest::Execute { call_id, plan },
server_sender,
);
let resp = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
.await
.expect("server response timed out")
.expect("server response channel closed");
match resp {
MeshDbResponse::Error {
call_id: got_id,
error: MeshError::QueryCancelled,
} => assert_eq!(got_id, call_id),
other => panic!("expected QueryCancelled; got {other:?}"),
}
assert_eq!(
calls.load(Ordering::SeqCst),
0,
"executor must not be driven when cancel arrives first",
);
assert_eq!(server.pending_cancels(), 0);
}
// A response for A's call to B that arrives from C is rejected with
// WrongPeer and never appears in A's response stream.
#[tokio::test]
async fn wire_response_from_wrong_peer_rejected_without_injection() {
let wire = Arc::new(InMemoryWire::default());
let node_a: u64 = 0xA;
let node_b: u64 = 0xB;
let node_c: u64 = 0xC;
let sender_a = Arc::new(SenderTo {
wire: wire.clone(),
local_node: node_a,
});
let dispatcher_a = Arc::new(MeshDbWireDispatcher::new(sender_a));
wire.register(node_a, dispatcher_a.clone());
let reader_b = Arc::new(InMemoryChainReader::default());
reader_b.append(0xCAFE, SeqNum(7), b"real".to_vec());
let executor_b: Arc<dyn MeshQueryExecutor> =
Arc::new(LocalMeshQueryExecutor::new(reader_b));
let server_b = MeshDbServer::new(executor_b);
let sender_b = Arc::new(SenderTo {
wire: wire.clone(),
local_node: node_b,
});
let dispatcher_b = Arc::new(MeshDbWireDispatcher::new(sender_b));
dispatcher_b.set_server(Some(server_b));
wire.register(node_b, dispatcher_b);
let transport_a = dispatcher_a.transport();
let plan_call_id = 0xDEAD_BEEF;
let mut stream = transport_a
.send(
node_b,
MeshDbRequest::Execute {
call_id: plan_call_id,
plan: atomic_plan(OperatorPlan::LatestRead { origin: 0xCAFE }),
},
)
.await
.expect("send to B");
let spoof = MeshDbFrame::Response(MeshDbResponse::Error {
call_id: plan_call_id,
error: MeshError::ExecutorError {
node: node_c,
detail: "spoofed".to_string(),
},
});
let err = dispatcher_a
.try_route(node_c, &spoof.encode().unwrap())
.expect_err("spoofed response must be rejected");
match err {
MeshDbRouteError::WrongPeer {
call_id,
expected,
actual,
} => {
assert_eq!(call_id, plan_call_id);
assert_eq!(expected, node_b);
assert_eq!(actual, node_c);
}
other => panic!("expected WrongPeer; got {other:?}"),
}
// Drain B's genuine responses; the stream ends after the terminal one.
let mut got = Vec::new();
while let Some(resp) = stream.next().await {
got.push(resp);
}
assert!(
!got.iter().any(|r| matches!(
r,
MeshDbResponse::Error { error: MeshError::ExecutorError { detail, .. }, .. }
if detail == "spoofed"
)),
"spoofed response must not be visible to the caller; got {got:?}",
);
}
}