#![cfg(all(feature = "net", feature = "cortex"))]
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use futures::StreamExt;
use net::adapter::net::cortex::{
RequestStream, RpcDuplexHandler, RpcHandlerError, RpcResponseSink, RpcStreamingContext,
};
use net::adapter::net::mesh_rpc::{CallOptions, CodecDirection, RpcError};
use net::adapter::net::{EntityKeypair, MeshNode, MeshNodeConfig, SocketBufferConfig};
/// Send and receive socket buffer size applied to every test node (256 KiB).
const TEST_BUFFER_SIZE: usize = 256 << 10;
/// Pre-shared key passed to `MeshNodeConfig::new` for every node built here.
const PSK: [u8; 32] = [0x42; 32];
/// Builds the configuration shared by every test node.
///
/// The node binds to loopback with an OS-assigned port and uses the common
/// [`PSK`]. Heartbeat, session, handshake and capability-GC timers are set to
/// short values, and both socket buffers use [`TEST_BUFFER_SIZE`].
fn test_config() -> MeshNodeConfig {
    // Port 0: the OS picks a free port for each node.
    let bind = SocketAddr::from(([127, 0, 0, 1], 0));
    let mut config = MeshNodeConfig::new(bind, PSK)
        .with_heartbeat_interval(Duration::from_millis(200))
        .with_session_timeout(Duration::from_secs(5))
        // NOTE(review): `3` is presumably a handshake attempt count — confirm
        // against `MeshNodeConfig::with_handshake`.
        .with_handshake(3, Duration::from_secs(2))
        .with_capability_gc_interval(Duration::from_millis(250));
    config.socket_buffers = SocketBufferConfig {
        recv_buffer_size: TEST_BUFFER_SIZE,
        send_buffer_size: TEST_BUFFER_SIZE,
    };
    config
}
/// Creates a fresh `MeshNode` with a newly generated identity and the shared
/// test configuration, wrapped in an `Arc` so it can be handed to spawned tasks.
async fn build_node() -> Arc<MeshNode> {
    let config = test_config();
    let node = MeshNode::new(EntityKeypair::generate(), config)
        .await
        .expect("MeshNode::new");
    Arc::new(node)
}
async fn handshake_pair(a: &Arc<MeshNode>, b: &Arc<MeshNode>) {
let a_id = a.node_id();
let b_id = b.node_id();
let b_pub = *b.public_key();
let b_addr = b.local_addr();
let b_clone = b.clone();
let accept = tokio::spawn(async move { b_clone.accept(a_id).await });
a.connect(b_addr, &b_pub, b_id)
.await
.expect("connect failed");
accept
.await
.expect("accept task panicked")
.expect("accept failed");
a.start();
b.start();
}
/// Duplex handler that answers each request with `echo:` followed by the
/// request bytes.
///
/// Once the request stream ends, it sends a final `total:<n>` frame with the
/// number of requests it echoed.
struct EchoDuplexHandler;

#[async_trait::async_trait]
impl RpcDuplexHandler for EchoDuplexHandler {
    async fn call(
        &self,
        _ctx: RpcStreamingContext,
        mut requests: RequestStream,
        responses: RpcResponseSink,
    ) -> Result<(), RpcHandlerError> {
        let mut echoed: u64 = 0;
        while let Some(payload) = requests.next().await {
            responses.send([&b"echo:"[..], &payload[..]].concat());
            echoed += 1;
        }
        // Request half is closed: report how many requests were seen.
        let summary = format!("total:{echoed}");
        responses.send(summary.into_bytes());
        Ok(())
    }
}
/// Duplex handler that sends `pre-0` … `pre-{n-1}` before reading any request.
///
/// It then drains the request stream and returns once that stream ends.
struct EmitNThenWaitForEosHandler {
    /// Number of `pre-<i>` frames to emit up front.
    n: usize,
}

#[async_trait::async_trait]
impl RpcDuplexHandler for EmitNThenWaitForEosHandler {
    async fn call(
        &self,
        _ctx: RpcStreamingContext,
        mut requests: RequestStream,
        responses: RpcResponseSink,
    ) -> Result<(), RpcHandlerError> {
        let labels = (0..self.n).map(|i| format!("pre-{i}"));
        for label in labels {
            responses.send(label.into_bytes());
        }
        // Discard every request; the loop exits when the request stream ends.
        while let Some(_discarded) = requests.next().await {}
        Ok(())
    }
}
/// Duplex handler that never reads its request stream.
///
/// It sends a single `only` frame and returns immediately, so the server side
/// finishes without waiting for the caller to stop sending.
struct ServerTerminatesFirstHandler;

#[async_trait::async_trait]
impl RpcDuplexHandler for ServerTerminatesFirstHandler {
    async fn call(
        &self,
        _ctx: RpcStreamingContext,
        _ignored_requests: RequestStream,
        responses: RpcResponseSink,
    ) -> Result<(), RpcHandlerError> {
        let sole_response = Bytes::from_static(b"only");
        responses.send(sole_response);
        Ok(())
    }
}
struct ForeverHandler {
observed_cancel: Arc<AtomicBool>,
consumed: Arc<AtomicUsize>,
}
#[async_trait::async_trait]
impl RpcDuplexHandler for ForeverHandler {
async fn call(
&self,
ctx: RpcStreamingContext,
mut requests: RequestStream,
responses: RpcResponseSink,
) -> Result<(), RpcHandlerError> {
loop {
tokio::select! {
_ = ctx.cancellation.cancelled() => {
self.observed_cancel.store(true, Ordering::SeqCst);
return Ok(());
}
maybe = requests.next() => {
match maybe {
Some(body) => {
self.consumed.fetch_add(1, Ordering::SeqCst);
responses.send(body);
}
None => {
ctx.cancellation.cancelled().await;
self.observed_cancel.store(true, Ordering::SeqCst);
return Ok(());
}
}
}
}
}
}
}
#[tokio::test]
async fn duplex_interleaves_send_and_recv() {
let server = build_node().await;
let caller = build_node().await;
handshake_pair(&caller, &server).await;
let _serve = server
.serve_rpc_duplex("echo", Arc::new(EchoDuplexHandler))
.expect("serve_rpc_duplex");
let mut call = caller
.call_duplex(server.node_id(), "echo", CallOptions::default())
.await
.expect("call_duplex");
for i in 0..5u8 {
call.send(Bytes::copy_from_slice(&[b'a' + i]))
.await
.expect("send");
}
call.finish_sending().await.expect("finish_sending");
let mut collected: Vec<String> = Vec::new();
while let Some(item) = call.next().await {
let chunk = item.expect("chunk must be Ok");
collected.push(String::from_utf8(chunk.to_vec()).unwrap());
}
assert_eq!(collected.len(), 6, "5 echoes + 1 summary");
for (i, label) in (0..5u8).zip(["a", "b", "c", "d", "e"]) {
assert_eq!(collected[i as usize], format!("echo:{label}"));
}
assert_eq!(collected[5], "total:5");
}
/// Half-closing the request side with `finish_sending` must leave the response
/// side readable.
///
/// All three eager `pre-*` frames from the handler must be delivered, in
/// order, before the stream ends.
#[tokio::test]
async fn duplex_finish_sending_keeps_response_stream_open() {
    let server = build_node().await;
    let caller = build_node().await;
    handshake_pair(&caller, &server).await;
    let handler = Arc::new(EmitNThenWaitForEosHandler { n: 3 });
    let _serve = server
        .serve_rpc_duplex("emit_n", handler)
        .expect("serve_rpc_duplex");
    let mut call = caller
        .call_duplex(server.node_id(), "emit_n", CallOptions::default())
        .await
        .expect("call_duplex");
    // Close the request half before reading a single response.
    call.finish_sending().await.expect("finish_sending");

    let mut frames: Vec<String> = Vec::new();
    while let Some(next) = call.next().await {
        let bytes = next.expect("Ok chunk").to_vec();
        frames.push(String::from_utf8(bytes).unwrap());
    }
    assert_eq!(frames, vec!["pre-0", "pre-1", "pre-2"]);
}
/// The server handler may return while the caller's request half is still
/// open.
///
/// Even then, the caller must receive the single `only` frame before its
/// response stream stops.
#[tokio::test]
async fn duplex_server_terminates_first() {
    let server = build_node().await;
    let caller = build_node().await;
    handshake_pair(&caller, &server).await;
    let _serve = server
        .serve_rpc_duplex("term_first", Arc::new(ServerTerminatesFirstHandler))
        .expect("serve_rpc_duplex");
    let mut call = caller
        .call_duplex(server.node_id(), "term_first", CallOptions::default())
        .await
        .expect("call_duplex");
    // `finish_sending` is deliberately never called: the server ends the call.
    call.send(Bytes::from_static(b"hello"))
        .await
        .expect("send hello");

    // Collect frames until the stream ends or yields its first error; such an
    // error is tolerated (not asserted on) here.
    let mut collected: Vec<Bytes> = Vec::new();
    while let Some(Ok(body)) = call.next().await {
        collected.push(body);
    }
    assert_eq!(collected.len(), 1);
    assert_eq!(&collected[0][..], b"only");
}
/// Splits a duplex call into its sink and stream halves and drives each from
/// its own spawned task.
///
/// The receiver must see every echo, in send order, followed by the handler's
/// `total:5` summary.
#[tokio::test]
async fn duplex_into_split_lets_halves_run_independently() {
    let server = build_node().await;
    let caller = build_node().await;
    handshake_pair(&caller, &server).await;
    let _serve = server
        .serve_rpc_duplex("echo_split", Arc::new(EchoDuplexHandler))
        .expect("serve_rpc_duplex");
    let call = caller
        .call_duplex(server.node_id(), "echo_split", CallOptions::default())
        .await
        .expect("call_duplex");
    let (mut sink, mut stream) = call.into_split();
    let sender = tokio::spawn(async move {
        for i in 0..5u8 {
            sink.send(Bytes::copy_from_slice(&[b'A' + i]))
                .await
                .expect("send");
        }
        sink.finish_sending().await.expect("finish_sending");
    });
    let receiver = tokio::spawn(async move {
        // Decode every frame so the assertion checks content and order, not
        // just how many frames arrived.
        let mut received: Vec<String> = Vec::new();
        while let Some(item) = stream.next().await {
            let chunk = item.expect("Ok");
            received.push(String::from_utf8(chunk.to_vec()).expect("utf-8 response"));
        }
        received
    });
    sender.await.expect("sender task");
    let received = receiver.await.expect("receiver task");
    assert_eq!(
        received,
        vec!["echo:A", "echo:B", "echo:C", "echo:D", "echo:E", "total:5"],
        "5 echoes + 1 summary"
    );
}
/// Dropping an unsplit duplex handle must cancel the call on the server.
///
/// `ForeverHandler` must observe `ctx.cancellation` within roughly three
/// seconds of the drop.
#[tokio::test]
async fn duplex_cancel_from_caller_closes_both_halves() {
    let server = build_node().await;
    let caller = build_node().await;
    handshake_pair(&caller, &server).await;

    let observed = Arc::new(AtomicBool::new(false));
    let handler = ForeverHandler {
        observed_cancel: Arc::clone(&observed),
        consumed: Arc::new(AtomicUsize::new(0)),
    };
    let _serve = server
        .serve_rpc_duplex("forever", Arc::new(handler))
        .expect("serve_rpc_duplex");
    let mut call = caller
        .call_duplex(server.node_id(), "forever", CallOptions::default())
        .await
        .expect("call_duplex");
    for (payload, context) in [(&b"first"[..], "send 1"), (&b"second"[..], "send 2")] {
        call.send(Bytes::from_static(payload))
            .await
            .expect(context);
    }

    // Dropping the whole handle drops both halves at once.
    drop(call);

    // Poll for up to 3s; the final assert reports the outcome either way.
    let _ = tokio::time::timeout(Duration::from_secs(3), async {
        while !observed.load(Ordering::SeqCst) {
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    })
    .await;
    assert!(
        observed.load(Ordering::SeqCst),
        "handler must observe ctx.cancellation after caller drops the handle"
    );
}
/// After `into_split`, dropping only the sink must not cancel the server-side
/// handler.
///
/// Cancellation is expected once the stream half is dropped as well.
#[tokio::test]
async fn duplex_into_split_one_half_drop_does_not_cancel() {
    const METHOD: &str = "no_cancel_on_one_drop";

    let server = build_node().await;
    let caller = build_node().await;
    handshake_pair(&caller, &server).await;

    let observed = Arc::new(AtomicBool::new(false));
    let handler = ForeverHandler {
        observed_cancel: Arc::clone(&observed),
        consumed: Arc::new(AtomicUsize::new(0)),
    };
    let _serve = server
        .serve_rpc_duplex(METHOD, Arc::new(handler))
        .expect("serve_rpc_duplex");
    let (mut sink, stream) = caller
        .call_duplex(server.node_id(), METHOD, CallOptions::default())
        .await
        .expect("call_duplex")
        .into_split();
    sink.send(Bytes::from_static(b"keepalive"))
        .await
        .expect("send");

    // Phase 1: only the request half is gone; the call must stay alive.
    drop(sink);
    tokio::time::sleep(Duration::from_millis(200)).await;
    assert!(
        !observed.load(Ordering::SeqCst),
        "server must NOT observe CANCEL while the stream half is alive",
    );

    // Phase 2: both halves dropped; poll up to 3s for the server to notice.
    drop(stream);
    let _ = tokio::time::timeout(Duration::from_secs(3), async {
        while !observed.load(Ordering::SeqCst) {
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    })
    .await;
    assert!(
        observed.load(Ordering::SeqCst),
        "server must observe CANCEL once BOTH halves drop",
    );
}
/// `call_duplex` must reject `request_window_initial: Some(0)` up front.
///
/// The expected error is an encode-direction codec error whose message names
/// the offending option. The caller is never handshaken with any peer.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn call_duplex_rejects_zero_request_window() {
    let caller = build_node().await;
    // Arbitrary id: this caller has not handshaken with anyone.
    let unknown_target = 0xC0DE_u64;
    let zero_window = CallOptions {
        request_window_initial: Some(0),
        ..CallOptions::default()
    };
    let outcome = caller
        .call_duplex(unknown_target, "anything", zero_window)
        .await;
    let err = if let Err(e) = outcome {
        e
    } else {
        panic!("Some(0) must be rejected before any wire traffic")
    };
    match err {
        RpcError::Codec { direction, message } => {
            assert_eq!(direction, CodecDirection::Encode);
            assert!(
                message.contains("request_window_initial"),
                "diagnostic must name the offending option: {message}",
            );
        }
        other => panic!("expected RpcError::Codec(Encode), got {other:?}"),
    }
}