#![cfg(any(test, feature = "mock"))]
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use bytes::BytesMut;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// An in-process TCP server that answers framed requests with a user-supplied handler.
///
/// Create one with [`MockBroker::start`] and point clients at [`MockBroker::addr`]. Call
/// [`MockBroker::stop`] to shut it down; merely dropping the value does not cancel the
/// shutdown token, so the background accept and connection tasks keep running.
#[derive(Debug)]
pub struct MockBroker {
    /// Loopback address (`127.0.0.1` on an OS-assigned port) the broker is listening on.
    pub addr: SocketAddr,
    /// Cancelled by [`MockBroker::stop`] to end the accept loop and all connection tasks.
    shutdown: CancellationToken,
    /// Handle to the accept-loop task; it is held but never awaited.
    _task: JoinHandle<()>,
}
/// Type-erased request callback shared by every connection of a `MockBroker`.
///
/// Invoked as `(api_key, api_version, correlation_id, body)`; returning `Some(bytes)` sends a
/// reply whose payload follows the correlation id, while `None` sends no reply.
type Handler =
    Box<dyn for<'body> FnMut(i16, i16, i32, &'body [u8]) -> Option<Vec<u8>> + Send + 'static>;
impl MockBroker {
    /// Starts a mock broker listening on an OS-assigned port on `127.0.0.1`.
    ///
    /// Each accepted connection is served on its own task. For every request frame of at least
    /// 8 bytes, the header is decoded as big-endian `(api_key: i16, api_version: i16,
    /// correlation_id: i32)` and `handler` is called with those values plus the remaining
    /// bytes. When the handler returns `Some(body)`, the broker replies with the correlation id
    /// followed by `body`; `None` sends no reply. All connections share the one handler behind
    /// a mutex, so calls are serialised and a stateful `FnMut` is fine.
    ///
    /// # Panics
    ///
    /// Panics if binding `127.0.0.1:0` fails or the bound address cannot be read.
    pub async fn start<F>(handler: F) -> Self
    where
        F: FnMut(i16, i16, i32, &[u8]) -> Option<Vec<u8>> + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("MockBroker: failed to bind 127.0.0.1:0");
        let addr = listener
            .local_addr()
            .expect("MockBroker: failed to read the listener's local address");
        let handler: Arc<Mutex<Handler>> = Arc::new(Mutex::new(Box::new(handler)));
        let shutdown = CancellationToken::new();
        let task_shutdown = shutdown.clone();
        let task = tokio::spawn(async move {
            loop {
                tokio::select! {
                    // Check shutdown first so a stopped broker never takes another connection.
                    biased;
                    () = task_shutdown.cancelled() => break,
                    accepted = listener.accept() => {
                        // Match the result here rather than in the select pattern: a refuted
                        // `Ok(..)` pattern disables the branch, which would leave the broker
                        // unable to accept anything after a single transient error.
                        let stream = match accepted {
                            Ok((stream, _peer)) => stream,
                            Err(_) => {
                                // Yield so a persistent error (e.g. fd exhaustion) does not
                                // monopolise the worker while other tasks free resources.
                                tokio::task::yield_now().await;
                                continue;
                            }
                        };
                        let conn_handler = Arc::clone(&handler);
                        let conn_shutdown = task_shutdown.clone();
                        tokio::spawn(handle_connection(stream, conn_handler, conn_shutdown));
                    }
                }
            }
        });
        Self {
            addr,
            shutdown,
            _task: task,
        }
    }

    /// Signals the accept loop and every open connection to shut down.
    ///
    /// This only cancels the shared token and returns immediately; it does not wait for the
    /// spawned tasks to finish.
    pub fn stop(self) {
        self.shutdown.cancel();
    }
}
async fn handle_connection(
stream: tokio::net::TcpStream,
handler: Arc<Mutex<Handler>>,
shutdown: CancellationToken,
) {
use futures_util::{SinkExt, StreamExt};
let mut framed = crate::transport::frame(stream);
loop {
tokio::select! {
() = shutdown.cancelled() => break,
maybe_frame = framed.next() => {
let Some(frame) = maybe_frame else { break; };
let Ok(frame) = frame else { break; };
if frame.len() < 8 { continue; }
let api_key = i16::from_be_bytes([frame[0], frame[1]]);
let version = i16::from_be_bytes([frame[2], frame[3]]);
let corr_id = i32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]);
let body = &frame[8..];
let response_body_opt = {
let mut h = handler.lock().unwrap();
h(api_key, version, corr_id, body)
};
let Some(response_body) = response_body_opt else { continue; };
let mut resp = BytesMut::with_capacity(4 + response_body.len());
resp.extend_from_slice(&corr_id.to_be_bytes());
resp.extend_from_slice(&response_body);
if framed.send(resp.freeze()).await.is_err() {
break;
}
}
}
}
}