strest 0.1.10

Blazing-fast async HTTP load tester in Rust - lock-free design, real-time stats, distributed runs, and optional chart exports for high-load API testing.
Documentation
use std::time::Duration;

use tokio::io::BufReader;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::time::MissedTickBehavior;
use tracing::{debug, info};

use crate::args::TesterArgs;
use crate::error::{AppError, AppResult, DistributedError};

use super::command::AgentCommand;
use super::run_exec::{AgentLocalRunPort, run_agent_run};
use super::wire::{build_agent_id, build_hello, send_wire};
use crate::distributed::protocol::{HeartbeatMessage, WireMessage, read_message, send_message};
use crate::distributed::utils::current_time_ms;

pub(super) async fn run_agent_session<TLocalRunPort>(
    base_args: &TesterArgs,
    local_run_port: &TLocalRunPort,
) -> AppResult<()>
where
    TLocalRunPort: AgentLocalRunPort + Sync,
{
    let join = base_args.agent_join.as_deref().ok_or_else(|| {
        AppError::distributed(DistributedError::MissingOption {
            option: "--agent-join",
        })
    })?;

    info!("Connecting to controller {}", join);
    let stream = TcpStream::connect(join).await.map_err(|err| {
        AppError::distributed(DistributedError::Connection {
            addr: join.to_owned(),
            source: err,
        })
    })?;
    info!("Connected to controller {}", join);
    let (read_half, mut write_half) = stream.into_split();
    let (out_tx, mut out_rx) = mpsc::unbounded_channel::<WireMessage>();
    let writer_handle = tokio::spawn(async move {
        while let Some(message) = out_rx.recv().await {
            if send_message(&mut write_half, &message).await.is_err() {
                break;
            }
        }
    });
    let mut reader = BufReader::new(read_half);

    let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel::<AgentCommand>();
    let reader_handle = tokio::spawn(async move {
        loop {
            let message = match read_message(&mut reader).await {
                Ok(message) => message,
                Err(err) => {
                    if cmd_tx.send(AgentCommand::Disconnected(err)).is_err() {
                        break;
                    }
                    break;
                }
            };

            let command = match message {
                WireMessage::Config(message) => AgentCommand::Config(message),
                WireMessage::Start(message) => AgentCommand::Start(message),
                WireMessage::Stop(message) => AgentCommand::Stop(message),
                WireMessage::Error(message) => {
                    AgentCommand::Error(AppError::distributed(DistributedError::Remote {
                        message: message.message,
                    }))
                }
                WireMessage::Heartbeat(_) => continue,
                WireMessage::Hello(_) | WireMessage::Stream(_) | WireMessage::Report(_) => {
                    AgentCommand::Error(AppError::distributed(
                        DistributedError::UnexpectedMessageFromController,
                    ))
                }
            };

            if cmd_tx.send(command).is_err() {
                break;
            }
        }
    });

    let agent_id = build_agent_id(base_args);
    let hello = build_hello(base_args, &agent_id);
    send_wire(&out_tx, WireMessage::Hello(hello))?;
    debug!("Sent hello as {}", agent_id);

    let heartbeat_interval = Duration::from_millis(base_args.agent_heartbeat_interval_ms.get());
    let heartbeat_tx = out_tx.clone();
    let heartbeat_handle = tokio::spawn(async move {
        let mut interval = tokio::time::interval(heartbeat_interval);
        interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
        loop {
            interval.tick().await;
            let sent_at_ms = u64::try_from(current_time_ms()).unwrap_or(u64::MAX);
            let message = WireMessage::Heartbeat(HeartbeatMessage { sent_at_ms });
            if send_wire(&heartbeat_tx, message).is_err() {
                break;
            }
        }
    });

    let session_result = loop {
        let config = match wait_for_config(&mut cmd_rx).await {
            Ok(config) => config,
            Err(err) => break Err(err),
        };
        info!("Received config for run {}", config.run_id);
        let start = match wait_for_start(&mut cmd_rx, &config.run_id).await {
            Ok(start) => start,
            Err(err) => break Err(err),
        };
        info!(
            "Received start for run {} (start_after_ms={})",
            start.run_id, start.start_after_ms
        );

        if start.run_id != config.run_id {
            break Err(AppError::distributed(DistributedError::RunIdMismatch {
                expected: config.run_id.clone(),
                actual: start.run_id.clone(),
            }));
        }

        if start.start_after_ms > 0 {
            tokio::time::sleep(Duration::from_millis(start.start_after_ms)).await;
        }

        let run_result = run_agent_run(
            base_args,
            config,
            agent_id.clone(),
            &out_tx,
            &mut cmd_rx,
            local_run_port,
        )
        .await;

        if let Err(err) = run_result {
            if !base_args.agent_standby {
                break Err(err);
            }
            tracing::warn!("Agent run error: {}", err);
        }

        if !base_args.agent_standby {
            break Ok(());
        }
    };

    heartbeat_handle.abort();
    drop(out_tx);
    if writer_handle.await.is_err() {
        // ignore
    }
    reader_handle.abort();
    session_result
}

async fn wait_for_config(
    cmd_rx: &mut mpsc::UnboundedReceiver<AgentCommand>,
) -> AppResult<crate::distributed::protocol::ConfigMessage> {
    while let Some(command) = cmd_rx.recv().await {
        match command {
            AgentCommand::Config(config) => return Ok(*config),
            AgentCommand::Disconnected(err) => return Err(err),
            AgentCommand::Error(err) => return Err(err),
            AgentCommand::Stop(_) => continue,
            AgentCommand::Start(_) => {
                return Err(AppError::distributed(DistributedError::StartBeforeConfig));
            }
        }
    }
    Err(AppError::distributed(
        DistributedError::ControllerConnectionClosed,
    ))
}

async fn wait_for_start(
    cmd_rx: &mut mpsc::UnboundedReceiver<AgentCommand>,
    run_id: &str,
) -> AppResult<crate::distributed::protocol::StartMessage> {
    while let Some(command) = cmd_rx.recv().await {
        match command {
            AgentCommand::Start(start) => {
                if start.run_id != run_id {
                    return Err(AppError::distributed(DistributedError::RunIdMismatch {
                        expected: run_id.to_owned(),
                        actual: start.run_id,
                    }));
                }
                return Ok(start);
            }
            AgentCommand::Stop(stop) => {
                if stop.run_id == run_id {
                    return Err(AppError::distributed(DistributedError::StopBeforeStart));
                }
            }
            AgentCommand::Config(_) => {
                return Err(AppError::distributed(
                    DistributedError::ConfigWhileWaitingForStart,
                ));
            }
            AgentCommand::Disconnected(err) => return Err(err),
            AgentCommand::Error(err) => return Err(err),
        }
    }
    Err(AppError::distributed(
        DistributedError::ControllerConnectionClosed,
    ))
}