use crate::config::GlobalExecutor;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio::time::timeout;
use crate::contract::executor::{OpRequestError, OpRequestReceiver, OpRequestSender};
use crate::message::Transaction;
use crate::operations::OpEnum;
use crate::operations::get::GetMsg;
/// Minimal stand-in for the op-request mediator, used to exercise request
/// bookkeeping without a real event loop.
///
/// Requests arrive on `op_request_rx`, their transaction ids are forwarded on
/// `to_event_loop_tx`, and completed operations come back on
/// `from_event_loop_rx`, where they are matched to the waiting reply channel
/// stored in `pending`.
struct TestMediator {
    /// Incoming `(transaction, reply channel)` pairs from executors.
    op_request_rx: OpRequestReceiver,
    /// Outgoing transaction ids for requests that were accepted.
    to_event_loop_tx: mpsc::Sender<Transaction>,
    /// Completed operations coming back from the event loop.
    from_event_loop_rx: mpsc::Receiver<OpEnum>,
    /// Reply channels of accepted requests that have not been answered yet.
    pending: std::collections::HashMap<
        Transaction,
        oneshot::Sender<Result<OpEnum, OpRequestError>>,
    >,
    /// Upper bound on `pending.len()`; requests arriving at this size are rejected.
    max_pending: usize,
}
impl TestMediator {
fn new(
op_request_rx: OpRequestReceiver,
to_event_loop_tx: mpsc::Sender<Transaction>,
from_event_loop_rx: mpsc::Receiver<OpEnum>,
) -> Self {
Self {
op_request_rx,
to_event_loop_tx,
from_event_loop_rx,
pending: std::collections::HashMap::new(),
max_pending: 100, }
}
async fn process_one_request(&mut self) -> bool {
tokio::select! {
Some((tx, response_tx)) = self.op_request_rx.recv() => {
if self.pending.len() >= self.max_pending {
drop(response_tx.send(Err(OpRequestError::Failed(
"mediator at capacity".to_string()
))));
return true;
}
self.pending.insert(tx, response_tx);
if self.to_event_loop_tx.send(tx).await.is_err() {
if let Some(pending) = self.pending.remove(&tx) {
drop(pending.send(Err(OpRequestError::ChannelClosed)));
}
}
true
}
Some(op_result) = self.from_event_loop_rx.recv() => {
let tx = *op_result.id();
if let Some(pending) = self.pending.remove(&tx) {
drop(pending.send(Ok(op_result)));
}
true
}
else => false
}
}
fn pending_count(&self) -> usize {
self.pending.len()
}
}
/// Creates the three channel pairs a `TestMediator` needs, each buffering up
/// to 100 messages.
///
/// Returned in order:
/// - the op-request sender/receiver;
/// - the mediator-to-event-loop sender/receiver;
/// - the event-loop-to-mediator sender/receiver.
fn create_test_channels() -> (
    OpRequestSender,
    OpRequestReceiver,
    mpsc::Sender<Transaction>,
    mpsc::Receiver<Transaction>,
    mpsc::Sender<OpEnum>,
    mpsc::Receiver<OpEnum>,
) {
    const CAPACITY: usize = 100;
    let (request_tx, request_rx) = mpsc::channel(CAPACITY);
    let (forward_tx, forward_rx) = mpsc::channel(CAPACITY);
    let (result_tx, result_rx) = mpsc::channel(CAPACITY);
    (request_tx, request_rx, forward_tx, forward_rx, result_tx, result_rx)
}
/// An executor that abandons its request (drops the receiver) before the
/// mediator runs still leaves an entry in `pending`. The mediator does not
/// check whether the waiter is still listening.
#[tokio::test]
async fn test_executor_drops_before_response() {
    let (op_tx, op_rx, to_el_tx, _to_el_rx, _from_el_tx, from_el_rx) = create_test_channels();
    let mut mediator = TestMediator::new(op_rx, to_el_tx, from_el_rx);

    let tx = Transaction::new::<GetMsg>();
    let (response_tx, response_rx) = oneshot::channel();
    op_tx
        .send((tx, response_tx))
        .await
        .expect("mediator request receiver is alive");

    // The executor gives up before the mediator ever sees the request.
    drop(response_rx);

    mediator.process_one_request().await;
    let still_tracked = mediator.pending_count();
    assert_eq!(still_tracked, 1);
}
/// Requests whose executors drop their receivers at different times are all
/// still admitted.
///
/// Odd-indexed executors drop their receivers immediately. The others keep
/// them alive in a task that sleeps for 100ms. The mediator keeps every
/// request pending either way.
#[tokio::test]
async fn test_multiple_executors_drop_independently() {
    let (op_tx, op_rx, to_el_tx, _to_el_rx, _from_el_tx, from_el_rx) = create_test_channels();
    let mut mediator = TestMediator::new(op_rx, to_el_tx, from_el_rx);
    let dropped = Arc::new(AtomicU32::new(0));
    for i in 0..5 {
        let tx = Transaction::new::<GetMsg>();
        let (response_tx, response_rx) = oneshot::channel();
        op_tx.send((tx, response_tx)).await.unwrap();
        if i % 2 == 1 {
            drop(response_rx);
            dropped.fetch_add(1, Ordering::SeqCst);
        } else {
            GlobalExecutor::spawn(async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                drop(response_rx);
            });
        }
    }
    // Indices 1 and 3 dropped their receivers up front; check the bookkeeping
    // so the counter is not dead state.
    assert_eq!(dropped.load(Ordering::SeqCst), 2);

    for _ in 0..5 {
        mediator.process_one_request().await;
    }
    // Whether or not its waiter is gone, every request remains pending.
    assert_eq!(mediator.pending_count(), 5);
}
/// Pending entries older than the staleness threshold are the ones selected
/// for cleanup.
#[tokio::test]
async fn test_stale_requests_identified() {
    use std::time::Instant;

    let base = Instant::now();
    // Staleness is evaluated at a fixed instant four minutes after `base`, not
    // with `elapsed()`. This makes every age exact, so no entry can land on the
    // threshold depending on clock resolution. It also avoids subtracting
    // from `Instant::now()`, which can panic on platforms whose monotonic
    // clock starts near zero.
    let checked_at = base + Duration::from_secs(4 * 60);

    let mut pending: std::collections::HashMap<
        Transaction,
        (oneshot::Sender<Result<OpEnum, OpRequestError>>, Instant),
    > = std::collections::HashMap::new();
    for i in 0..5u64 {
        let tx = Transaction::new::<GetMsg>();
        let (response_tx, _response_rx) = oneshot::channel();
        // Ages at `checked_at`: 240s, 180s, 120s, 60s, 0s.
        let created_at = base + Duration::from_secs(i * 60);
        pending.insert(tx, (response_tx, created_at));
    }

    // Strictly between the 60s and 120s ages, so no age sits on the boundary.
    let threshold = Duration::from_secs(90);
    let stale: Vec<_> = pending
        .iter()
        .filter(|(_, (_, created))| checked_at.duration_since(*created) > threshold)
        .map(|(tx, _)| *tx)
        .collect();
    assert_eq!(stale.len(), 3);
}
/// A stale-cleanup failure reaches the waiter as `OpRequestError::Failed`
/// with a message mentioning "stale".
#[tokio::test]
async fn test_stale_cleanup_notifies_waiters() {
    let (waiter_tx, waiter_rx) = oneshot::channel::<Result<OpEnum, OpRequestError>>();
    let reason = "request exceeded stale threshold".to_string();
    let _ = waiter_tx.send(Err(OpRequestError::Failed(reason)));

    let outcome = waiter_rx
        .await
        .expect("a value was sent before the sender was dropped");
    match outcome {
        Err(OpRequestError::Failed(msg)) => assert!(msg.contains("stale")),
        _ => panic!("expected a stale failure"),
    }
}
/// If the event loop's receiving end is gone, forwarding fails and the waiter
/// is answered with `ChannelClosed`.
#[tokio::test]
async fn test_event_loop_channel_closes() {
    let (op_tx, op_rx, to_el_tx, to_el_rx, _from_el_tx, from_el_rx) = create_test_channels();
    // Close the event-loop side before the mediator forwards anything.
    drop(to_el_rx);
    let mut mediator = TestMediator::new(op_rx, to_el_tx, from_el_rx);

    let tx = Transaction::new::<GetMsg>();
    let (response_tx, response_rx) = oneshot::channel();
    op_tx.send((tx, response_tx)).await.unwrap();
    mediator.process_one_request().await;

    let response = timeout(Duration::from_millis(100), response_rx)
        .await
        .expect("waiter should be answered promptly")
        .expect("mediator replied before dropping the sender");
    assert!(matches!(response, Err(OpRequestError::ChannelClosed)));
}
/// With both input senders dropped, `process_one_request` reports that the
/// mediator should stop.
#[tokio::test]
async fn test_all_senders_dropped_mediator_exits() {
    let (op_tx, op_rx, to_el_tx, _to_el_rx, from_el_tx, from_el_rx) = create_test_channels();
    let mut mediator = TestMediator::new(op_rx, to_el_tx, from_el_rx);

    // Dropped in the same order as before: op sender first, then result sender.
    drop((op_tx, from_el_tx));

    let keep_running = mediator.process_one_request().await;
    assert!(!keep_running);
}
/// Once `max_pending` requests are outstanding, the next one is rejected with
/// a "capacity" failure and the pending table does not grow.
#[tokio::test]
async fn test_mediator_rejects_at_capacity() {
    const LIMIT: usize = 5;

    let (op_tx, op_rx, to_el_tx, _to_el_rx, _from_el_tx, from_el_rx) = create_test_channels();
    let mut mediator = TestMediator::new(op_rx, to_el_tx, from_el_rx);
    mediator.max_pending = LIMIT;

    // Hold on to every receiver for the duration of the test.
    let mut receivers = Vec::with_capacity(LIMIT);
    for _ in 0..LIMIT {
        let tx = Transaction::new::<GetMsg>();
        let (response_tx, response_rx) = oneshot::channel();
        op_tx.send((tx, response_tx)).await.unwrap();
        receivers.push(response_rx);
        mediator.process_one_request().await;
    }
    assert_eq!(mediator.pending_count(), LIMIT);

    let overflow_tx = Transaction::new::<GetMsg>();
    let (response_tx, response_rx) = oneshot::channel();
    op_tx.send((overflow_tx, response_tx)).await.unwrap();
    mediator.process_one_request().await;

    let response = timeout(Duration::from_millis(100), response_rx)
        .await
        .expect("rejection should arrive promptly")
        .expect("mediator replied before dropping the sender");
    assert!(matches!(response, Err(OpRequestError::Failed(msg)) if msg.contains("capacity")));
    assert_eq!(mediator.pending_count(), LIMIT);
}
#[tokio::test]
async fn test_capacity_recovers_after_responses() {
let (op_tx, op_rx, to_el_tx, to_el_rx, _from_el_tx, from_el_rx) = create_test_channels();
let mut mediator = TestMediator::new(op_rx, to_el_tx, from_el_rx);
mediator.max_pending = 3;
let mut transactions = Vec::new();
for _ in 0..3 {
let tx = Transaction::new::<GetMsg>();
let (response_tx, _response_rx) = oneshot::channel();
op_tx.send((tx, response_tx)).await.unwrap();
transactions.push(tx);
mediator.process_one_request().await;
}
assert_eq!(mediator.pending_count(), 3);
drop(to_el_rx);
assert_eq!(mediator.pending_count(), 3);
}
/// A reply sent through one transaction's sender reaches only that
/// transaction's receiver; the others see nothing and time out.
#[tokio::test]
async fn test_response_routed_to_correct_executor() {
    let mut pending: std::collections::HashMap<
        Transaction,
        oneshot::Sender<Result<OpEnum, OpRequestError>>,
    > = std::collections::HashMap::new();

    let (transactions, receivers): (Vec<Transaction>, Vec<_>) = (0..3)
        .map(|_| {
            let tx = Transaction::new::<GetMsg>();
            let (response_tx, response_rx) = oneshot::channel();
            pending.insert(tx, response_tx);
            (tx, response_rx)
        })
        .unzip();

    // Answer only the middle request; the other senders stay alive in
    // `pending`, so their receivers block rather than close.
    let target_index = 1;
    if let Some(sender) = pending.remove(&transactions[target_index]) {
        let _ = sender.send(Err(OpRequestError::Failed("test response".to_string())));
    }

    for (index, rx) in receivers.into_iter().enumerate() {
        let outcome = timeout(Duration::from_millis(10), rx).await;
        if index == target_index {
            assert!(outcome.is_ok(), "Middle request should have response");
        } else {
            assert!(outcome.is_err(), "Other requests should timeout");
        }
    }
}
/// Only submitted transactions are tracked.
///
/// NOTE(review): the name suggests coverage of a result for an unknown
/// transaction arriving on the event-loop channel. No `OpEnum` is sent here,
/// so that routing branch is not exercised. What the test does pin: the real
/// request is tracked, and an id that was never submitted is not.
#[tokio::test]
async fn test_unknown_response_handled_gracefully() {
    let (op_tx, op_rx, to_el_tx, _to_el_rx, _from_el_tx, from_el_rx) = create_test_channels();
    let mut mediator = TestMediator::new(op_rx, to_el_tx, from_el_rx);
    let unknown_tx = Transaction::new::<GetMsg>();
    let real_tx = Transaction::new::<GetMsg>();
    let (response_tx, _response_rx) = oneshot::channel();
    op_tx.send((real_tx, response_tx)).await.unwrap();
    mediator.process_one_request().await;

    // Without the positive check, the negative one would also pass on an
    // empty table and prove nothing.
    assert!(mediator.pending.contains_key(&real_tx));
    assert!(!mediator.pending.contains_key(&unknown_tx));
    assert_eq!(mediator.pending_count(), 1);
}
/// A forwarded request that never gets a result leaves its waiter blocked
/// (the receive times out), and the entry remains pending.
#[tokio::test]
async fn test_executor_timeout_handling() {
    let (op_tx, op_rx, to_el_tx, _to_el_rx, _from_el_tx, from_el_rx) = create_test_channels();
    let mut mediator = TestMediator::new(op_rx, to_el_tx, from_el_rx);

    let tx = Transaction::new::<GetMsg>();
    let (response_tx, response_rx) = oneshot::channel();
    op_tx.send((tx, response_tx)).await.unwrap();
    mediator.process_one_request().await;

    let wait = Duration::from_millis(50);
    let timed_out = timeout(wait, response_rx).await.is_err();
    assert!(timed_out, "Should timeout waiting for response");
    assert_eq!(mediator.pending_count(), 1);
}
/// Several outstanding requests each time out on their own; the mediator
/// neither answers nor removes any of them.
#[tokio::test]
async fn test_concurrent_requests_independent_timeouts() {
    const REQUESTS: usize = 5;

    let (op_tx, op_rx, to_el_tx, _to_el_rx, _from_el_tx, from_el_rx) = create_test_channels();
    let mut mediator = TestMediator::new(op_rx, to_el_tx, from_el_rx);

    let mut receivers = Vec::with_capacity(REQUESTS);
    for _ in 0..REQUESTS {
        let tx = Transaction::new::<GetMsg>();
        let (response_tx, response_rx) = oneshot::channel();
        op_tx.send((tx, response_tx)).await.unwrap();
        receivers.push(response_rx);
        mediator.process_one_request().await;
    }

    for rx in receivers {
        let waited = timeout(Duration::from_millis(10), rx).await;
        assert!(waited.is_err(), "Each request should timeout independently");
    }
    assert_eq!(mediator.pending_count(), REQUESTS);
}