strest 0.1.10

Blazing-fast async HTTP load tester in Rust - lock-free design, real-time stats, distributed runs, and optional chart exports for high-load API testing.
Documentation
use std::future::Future;
use std::time::Duration;

use clap::Parser;
use futures_util::{SinkExt, StreamExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, UdpSocket};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tokio_tungstenite::{accept_async, tungstenite::Message};

use crate::args::TesterArgs;
use crate::error::{AppError, AppResult, ValidationError};
use crate::metrics::Metrics;

use super::resolve::resolve_endpoint;
use super::setup_request_sender;

mod datagram_mqtt;
mod scheme_resolution;
mod transport_http_grpc;

const SHUTDOWN_CHANNEL_CAPACITY: usize = 16;
const TEST_TIMEOUT: Duration = Duration::from_secs(5);

fn run_async_test<F>(future: F) -> AppResult<()>
where
    F: Future<Output = AppResult<()>>,
{
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|err| AppError::validation(format!("Failed to build runtime: {}", err)))?;
    runtime.block_on(future)
}

fn permission_denied(err: &std::io::Error) -> bool {
    err.kind() == std::io::ErrorKind::PermissionDenied
}

fn parse_args(protocol: &str, load_mode: &str, url: &str) -> AppResult<TesterArgs> {
    TesterArgs::try_parse_from([
        "strest",
        "--url",
        url,
        "--protocol",
        protocol,
        "--load-mode",
        load_mode,
        "--requests",
        "1",
        "--max-tasks",
        "1",
        "--spawn-rate",
        "1",
        "--spawn-interval",
        "1",
        "--timeout",
        "3s",
        "--connect-timeout",
        "3s",
        "--data",
        "ping",
    ])
    .map_err(|err| AppError::validation(format!("Expected parse success: {}", err)))
}

fn validation_error(err: AppError) -> AppResult<ValidationError> {
    if let AppError::Validation(value) = err {
        Ok(value)
    } else {
        Err(AppError::validation(format!(
            "Expected validation error, got: {}",
            err
        )))
    }
}

async fn wait_metric(
    metrics_rx: &mut mpsc::Receiver<Metrics>,
    protocol: &str,
) -> AppResult<Metrics> {
    let received = timeout(TEST_TIMEOUT, metrics_rx.recv())
        .await
        .map_err(|_err| {
            AppError::validation(format!("Timed out waiting for metric for {}", protocol))
        })?;
    received.ok_or_else(|| AppError::validation(format!("Metric channel closed for {}", protocol)))
}

fn join_handle(handle: JoinHandle<()>, protocol: &str) -> impl Future<Output = AppResult<()>> {
    let protocol = protocol.to_owned();
    async move {
        timeout(TEST_TIMEOUT, handle)
            .await
            .map_err(|_err| AppError::validation(format!("Sender task timeout for {}", protocol)))?
            .map_err(|err| {
                AppError::validation(format!("Sender task failed for {}: {}", protocol, err))
            })
    }
}

fn join_result_handle(
    handle: JoinHandle<AppResult<()>>,
    protocol: &str,
) -> impl Future<Output = AppResult<()>> {
    let protocol = protocol.to_owned();
    async move {
        let output = timeout(TEST_TIMEOUT, handle)
            .await
            .map_err(|_err| AppError::validation(format!("Server task timeout for {}", protocol)))?
            .map_err(|err| {
                AppError::validation(format!("Server task failed for {}: {}", protocol, err))
            })?;
        output?;
        Ok(())
    }
}

async fn spawn_udp_echo_server(
    expected_packets: usize,
    protocol: &str,
) -> AppResult<(std::net::SocketAddr, JoinHandle<AppResult<()>>)> {
    let socket = UdpSocket::bind("127.0.0.1:0").await.map_err(|err| {
        AppError::validation(format!(
            "Failed to bind UDP server for {}: {}",
            protocol, err
        ))
    })?;
    let addr = socket.local_addr().map_err(|err| {
        AppError::validation(format!("Failed to read UDP addr for {}: {}", protocol, err))
    })?;
    let protocol_name = protocol.to_owned();

    let task = tokio::spawn(async move {
        let mut buf = [0_u8; 2048];
        for _ in 0..expected_packets {
            let (bytes, peer) = timeout(TEST_TIMEOUT, socket.recv_from(&mut buf))
                .await
                .map_err(|_err| {
                    AppError::validation(format!("UDP recv timeout for {}", protocol_name))
                })?
                .map_err(|err| {
                    AppError::validation(format!("UDP recv failed for {}: {}", protocol_name, err))
                })?;
            if bytes == 0 {
                return Err(AppError::validation(format!(
                    "UDP server received empty datagram for {}",
                    protocol_name
                )));
            }
            socket.send_to(b"ok", peer).await.map_err(|err| {
                AppError::validation(format!("UDP send failed for {}: {}", protocol_name, err))
            })?;
        }
        Ok(())
    });
    Ok((addr, task))
}

async fn spawn_mqtt_mock_server() -> AppResult<(std::net::SocketAddr, JoinHandle<AppResult<()>>)> {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .map_err(|err| AppError::validation(format!("Failed to bind MQTT server: {}", err)))?;
    let addr = listener
        .local_addr()
        .map_err(|err| AppError::validation(format!("Failed to read MQTT addr: {}", err)))?;

    let task = tokio::spawn(async move {
        for _ in 0..2 {
            let (mut stream, _) = timeout(TEST_TIMEOUT, listener.accept())
                .await
                .map_err(|_err| AppError::validation("MQTT accept timed out"))?
                .map_err(|err| AppError::validation(format!("MQTT accept failed: {}", err)))?;

            let mut connect_buf = [0_u8; 512];
            let connect_len = timeout(TEST_TIMEOUT, stream.read(&mut connect_buf))
                .await
                .map_err(|_err| AppError::validation("MQTT connect read timed out"))?
                .map_err(|err| {
                    AppError::validation(format!("MQTT connect read failed: {}", err))
                })?;
            if connect_len == 0 {
                return Err(AppError::validation("MQTT connect frame was empty"));
            }
            let packet_type = connect_buf.first().copied().unwrap_or(0) & 0xF0;
            if packet_type != 0x10 {
                return Err(AppError::validation(format!(
                    "Expected MQTT CONNECT packet, got 0x{packet_type:02x}"
                )));
            }

            stream
                .write_all(&[0x20, 0x02, 0x00, 0x00])
                .await
                .map_err(|err| {
                    AppError::validation(format!("MQTT connack write failed: {}", err))
                })?;

            let mut publish_buf = [0_u8; 512];
            let publish_len = timeout(TEST_TIMEOUT, stream.read(&mut publish_buf))
                .await
                .map_err(|_err| AppError::validation("MQTT publish read timed out"))?
                .map_err(|err| {
                    AppError::validation(format!("MQTT publish read failed: {}", err))
                })?;
            if publish_len == 0 {
                return Err(AppError::validation("MQTT publish frame was empty"));
            }
            let publish_type = publish_buf.first().copied().unwrap_or(0) & 0xF0;
            if publish_type != 0x30 {
                return Err(AppError::validation(format!(
                    "Expected MQTT PUBLISH packet, got 0x{publish_type:02x}"
                )));
            }
        }
        Ok(())
    });
    Ok((addr, task))
}

async fn spawn_tcp_echo_server(
    expected_connections: usize,
    protocol: &str,
) -> AppResult<(std::net::SocketAddr, JoinHandle<AppResult<()>>)> {
    let listener = TcpListener::bind("127.0.0.1:0").await.map_err(|err| {
        AppError::validation(format!(
            "Failed to bind TCP server for {}: {}",
            protocol, err
        ))
    })?;
    let addr = listener.local_addr().map_err(|err| {
        AppError::validation(format!("Failed to read TCP addr for {}: {}", protocol, err))
    })?;
    let protocol_name = protocol.to_owned();

    let task = tokio::spawn(async move {
        for _ in 0..expected_connections {
            let (mut stream, _) = timeout(TEST_TIMEOUT, listener.accept())
                .await
                .map_err(|_err| {
                    AppError::validation(format!("TCP accept timeout for {}", protocol_name))
                })?
                .map_err(|err| {
                    AppError::validation(format!(
                        "TCP accept failed for {}: {}",
                        protocol_name, err
                    ))
                })?;
            let mut buf = [0_u8; 2048];
            let _ = timeout(TEST_TIMEOUT, stream.read(&mut buf))
                .await
                .map_err(|_err| {
                    AppError::validation(format!("TCP read timeout for {}", protocol_name))
                })?
                .map_err(|err| {
                    AppError::validation(format!("TCP read failed for {}: {}", protocol_name, err))
                })?;
            timeout(TEST_TIMEOUT, stream.write_all(b"ok"))
                .await
                .map_err(|_err| {
                    AppError::validation(format!("TCP write timeout for {}", protocol_name))
                })?
                .map_err(|err| {
                    AppError::validation(format!("TCP write failed for {}: {}", protocol_name, err))
                })?;
        }
        Ok(())
    });
    Ok((addr, task))
}

async fn spawn_http_mock_server(
    expected_connections: usize,
) -> AppResult<(std::net::SocketAddr, JoinHandle<AppResult<()>>)> {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .map_err(|err| AppError::validation(format!("Failed to bind HTTP server: {}", err)))?;
    let addr = listener
        .local_addr()
        .map_err(|err| AppError::validation(format!("Failed to read HTTP addr: {}", err)))?;

    let task = tokio::spawn(async move {
        for _ in 0..expected_connections {
            let (mut stream, _) = timeout(TEST_TIMEOUT, listener.accept())
                .await
                .map_err(|_err| AppError::validation("HTTP accept timed out"))?
                .map_err(|err| AppError::validation(format!("HTTP accept failed: {}", err)))?;
            let mut req = Vec::with_capacity(1024);
            loop {
                let mut chunk = [0_u8; 1024];
                let read = timeout(TEST_TIMEOUT, stream.read(&mut chunk))
                    .await
                    .map_err(|_err| AppError::validation("HTTP read timed out"))?
                    .map_err(|err| AppError::validation(format!("HTTP read failed: {}", err)))?;
                if read == 0 {
                    break;
                }
                let chunk_prefix = chunk.get(..read).ok_or_else(|| {
                    AppError::validation("HTTP server failed to access read buffer prefix")
                })?;
                req.extend_from_slice(chunk_prefix);
                if req.windows(4).any(|bytes| bytes == b"\r\n\r\n") {
                    break;
                }
            }

            let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok";
            timeout(TEST_TIMEOUT, stream.write_all(response))
                .await
                .map_err(|_err| AppError::validation("HTTP write timed out"))?
                .map_err(|err| AppError::validation(format!("HTTP write failed: {}", err)))?;
        }
        Ok(())
    });
    Ok((addr, task))
}

async fn spawn_websocket_mock_server(
    expected_connections: usize,
) -> AppResult<(std::net::SocketAddr, JoinHandle<AppResult<()>>)> {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .map_err(|err| AppError::validation(format!("Failed to bind websocket server: {}", err)))?;
    let addr = listener.local_addr().map_err(|err| {
        AppError::validation(format!("Failed to read websocket server addr: {}", err))
    })?;

    let task = tokio::spawn(async move {
        for _ in 0..expected_connections {
            let (stream, _) = timeout(TEST_TIMEOUT, listener.accept())
                .await
                .map_err(|_err| AppError::validation("Websocket accept timed out"))?
                .map_err(|err| AppError::validation(format!("Websocket accept failed: {}", err)))?;

            let mut ws = timeout(TEST_TIMEOUT, accept_async(stream))
                .await
                .map_err(|_err| AppError::validation("Websocket handshake timed out"))?
                .map_err(|err| {
                    AppError::validation(format!("Websocket handshake failed: {}", err))
                })?;

            let incoming = timeout(TEST_TIMEOUT, ws.next())
                .await
                .map_err(|_err| AppError::validation("Websocket recv timed out"))?
                .ok_or_else(|| AppError::validation("Websocket stream closed unexpectedly"))?
                .map_err(|err| AppError::validation(format!("Websocket recv failed: {}", err)))?;
            if !incoming.is_text() && !incoming.is_binary() {
                return Err(AppError::validation("Unexpected websocket message type"));
            }

            timeout(TEST_TIMEOUT, ws.send(Message::Text("ok".to_owned())))
                .await
                .map_err(|_err| AppError::validation("Websocket send timed out"))?
                .map_err(|err| AppError::validation(format!("Websocket send failed: {}", err)))?;
        }
        Ok(())
    });
    Ok((addr, task))
}