use std::sync::Arc;
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio::time::error::Elapsed;
use tokio_tungstenite::tungstenite::{Bytes, Message};
use atomic_websocket::client_sender::ClientSenders;
use atomic_websocket::server_sender::SenderStatus;
use atomic_websocket::types::{RwClientSenders, RwServerSender};
/// Asks the OS for a TCP port on loopback that is free right now.
///
/// Binds an ephemeral listener on `127.0.0.1:0`, reads back the port the OS assigned,
/// and releases it again when the probe listener is dropped at the end of this call.
/// Another process could claim the port before the caller binds it, so this is best-effort.
///
/// # Panics
/// Panics if the probe listener cannot be bound or its local address cannot be read.
pub async fn find_available_port() -> u16 {
    let probe = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("Failed to bind to random port");
    let assigned = probe.local_addr().unwrap();
    assigned.port()
}
/// Runs `future` to completion, or gives up once `duration` has passed.
///
/// Thin wrapper over `tokio::time::timeout`: yields `Ok(output)` when the future
/// finishes in time and `Err(Elapsed)` otherwise.
pub async fn with_timeout<T, F>(duration: Duration, future: F) -> Result<T, Elapsed>
where
    F: std::future::Future<Output = T>,
{
    let bounded = tokio::time::timeout(duration, future);
    bounded.await
}
/// Builds a fresh `ClientSenders` (via `ClientSenders::new`) wrapped in an `Arc`,
/// ready to be shared between test tasks.
pub fn create_test_rw_client_senders() -> RwClientSenders {
    let senders = ClientSenders::new();
    Arc::new(senders)
}
/// Sleeps asynchronously for 50 ms.
// NOTE(review): dead_code is allowed presumably because not every test binary that
// includes this module calls the helper.
#[allow(dead_code)]
pub async fn short_delay() {
    const PAUSE: Duration = Duration::from_millis(50);
    tokio::time::sleep(PAUSE).await;
}
/// Sleeps asynchronously for 200 ms.
// NOTE(review): dead_code is allowed presumably because not every test binary that
// includes this module calls the helper.
#[allow(dead_code)]
pub async fn medium_delay() {
    const PAUSE: Duration = Duration::from_millis(200);
    tokio::time::sleep(PAUSE).await;
}
/// The default timeout for tests: five seconds.
pub fn default_timeout() -> Duration {
    const FIVE_SECONDS: Duration = Duration::from_millis(5_000);
    FIVE_SECONDS
}
// Message category ids understood by the test server. A category is read from the
// first two bytes of a message as a little-endian `u16` (see `handle_test_connection`).

/// Category the test server answers with a `CATEGORY_PONG` message.
const CATEGORY_PING: u16 = 10_000;
/// Category of the reply the test server sends for a `CATEGORY_PING` message.
const CATEGORY_PONG: u16 = 10_001;
/// Category that makes the test server drop the connection without replying.
const CATEGORY_DISCONNECT: u16 = 10_003;
/// A minimal WebSocket server bound to `127.0.0.1:<port>` for integration tests.
///
/// Each accepted TCP connection is served by `handle_test_connection` on its own task.
/// `TestServer::shutdown` stops the accept loop and aborts every live connection task.
/// Dropping the server without calling `shutdown` also ends the accept loop, because the
/// loop treats a dropped shutdown sender the same as an explicit signal.
pub struct TestServer {
    /// The TCP port the server is listening on.
    pub port: u16,
    /// One-shot trigger for the accept loop; `None` once shutdown has been requested.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
    /// Handle to the spawned accept-loop task.
    handle: JoinHandle<()>,
}

impl TestServer {
    /// Binds a listener on `127.0.0.1:port` (with `SO_REUSEADDR`) and spawns the accept loop.
    ///
    /// Waits 50 ms after spawning before returning.
    ///
    /// # Panics
    /// Panics if the socket cannot be created, configured, bound, or put into listening
    /// mode (for example when `port` is already in use). The panic message includes the address.
    pub async fn start(port: u16) -> Self {
        let addr = std::net::SocketAddr::from(([127, 0, 0, 1], port));
        let socket = tokio::net::TcpSocket::new_v4().expect("failed to create IPv4 TCP socket");
        // Allow rebinding a port that a previous server in the same test run just released.
        socket
            .set_reuseaddr(true)
            .expect("failed to set SO_REUSEADDR on test server socket");
        socket
            .bind(addr)
            .unwrap_or_else(|e| panic!("failed to bind test server to {addr}: {e}"));
        let listener = socket
            .listen(128)
            .unwrap_or_else(|e| panic!("failed to listen on {addr}: {e}"));

        let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        // Connection tasks are tracked so that shutdown can abort them.
        let connection_handles: Arc<Mutex<Vec<JoinHandle<()>>>> = Arc::new(Mutex::new(Vec::new()));
        let handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    // Matches both an explicit `send(())` and a dropped sender.
                    _ = &mut shutdown_rx => {
                        for h in connection_handles.lock().await.iter() {
                            h.abort();
                        }
                        break;
                    }
                    result = listener.accept() => {
                        match result {
                            Ok((stream, _)) => {
                                let mut handles = connection_handles.lock().await;
                                // Forget connections that already ended so the list does not
                                // grow without bound over a long-running test.
                                handles.retain(|h| !h.is_finished());
                                handles.push(tokio::spawn(handle_test_connection(stream)));
                            }
                            Err(_) => {
                                // Back off briefly so a persistent accept error (e.g. file
                                // descriptor exhaustion) does not turn this loop into a busy spin.
                                tokio::time::sleep(Duration::from_millis(10)).await;
                            }
                        }
                    }
                }
            }
        });
        // Give the spawned accept loop a moment to be scheduled before tests connect.
        tokio::time::sleep(Duration::from_millis(50)).await;
        Self {
            port,
            shutdown_tx: Some(shutdown_tx),
            handle,
        }
    }

    /// Stops the server.
    ///
    /// Signals the accept loop, waits for it to exit, and then sleeps 100 ms before
    /// returning. The accept loop aborts every tracked connection task as it exits.
    pub async fn shutdown(mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            // The receiver is gone only if the accept task already ended; nothing to do then.
            let _ = tx.send(());
        }
        let _ = self.handle.await;
        // Grace period after the listener and connection tasks are torn down.
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}
/// Serves one WebSocket connection as an echo server that speaks the test category protocol.
///
/// For text and binary messages, the first two bytes (when present) are read as a
/// little-endian `u16` category:
/// - `CATEGORY_PING`: answered with a two-byte `CATEGORY_PONG` message instead of an echo;
/// - `CATEGORY_DISCONNECT`: the connection is dropped without a reply;
/// - anything else (or a message shorter than two bytes): echoed back verbatim as binary.
///
/// WebSocket control frames (ping/pong/close) are not echoed; tungstenite answers them
/// itself. Returns when the handshake fails, the stream ends or errors, or a send fails.
async fn handle_test_connection(stream: tokio::net::TcpStream) {
    let Ok(ws_stream) = tokio_tungstenite::accept_async(stream).await else {
        return;
    };
    let (mut write, mut read) = ws_stream.split();
    while let Some(Ok(msg)) = read.next().await {
        // Control frames are handled by tungstenite: it queues the pong/close reply and
        // flushes it on a later read. Echoing them as Binary would send bogus payloads.
        // After a peer Close it would also fail with send-after-close and exit before
        // that reply is flushed.
        if !msg.is_binary() && !msg.is_text() {
            continue;
        }
        let data = msg.into_data();
        if let Some(&[lo, hi]) = data.get(..2) {
            match u16::from_le_bytes([lo, hi]) {
                CATEGORY_PING => {
                    let pong =
                        Message::Binary(Bytes::copy_from_slice(&CATEGORY_PONG.to_le_bytes()));
                    if write.send(pong).await.is_err() {
                        break;
                    }
                    continue;
                }
                CATEGORY_DISCONNECT => break,
                _ => {}
            }
        }
        // `into_data` already yields an owned buffer, so echo it without copying.
        if write.send(Message::Binary(data)).await.is_err() {
            break;
        }
    }
}
/// Drains `rx` until a status equal to `expected` arrives, skipping any others.
///
/// Returns `true` as soon as the expected status is received. Returns `false` if the
/// channel closes or `timeout_duration` elapses first. The timeout bounds the whole
/// wait, not each individual receive.
// NOTE(review): dead_code is allowed presumably because not every test binary that
// includes this module calls the helper.
#[allow(dead_code)]
pub async fn wait_for_status(
    rx: &mut tokio::sync::mpsc::Receiver<SenderStatus>,
    expected: SenderStatus,
    timeout_duration: Duration,
) -> bool {
    let give_up_at = tokio::time::Instant::now() + timeout_duration;
    while let Ok(Some(received)) = tokio::time::timeout_at(give_up_at, rx.recv()).await {
        if received == expected {
            return true;
        }
    }
    false
}
/// Backdates `server_received_times` on `server_sender` to `seconds_ago` seconds before now.
///
/// "Now" is whole seconds since the Unix epoch (sub-second precision is dropped).
///
/// # Panics
/// Panics if the system clock reads earlier than the Unix epoch.
// NOTE(review): dead_code is allowed presumably because not every test binary that
// includes this module calls the helper.
#[allow(dead_code)]
pub async fn simulate_long_downtime(server_sender: &RwServerSender, seconds_ago: i64) {
    let since_epoch = std::time::UNIX_EPOCH.elapsed().unwrap();
    let backdated = since_epoch.as_secs() as i64 - seconds_ago;
    let mut sender = server_sender.write().await;
    sender.server_received_times = backdated;
}