pulzr 0.3.2

An HTTP load-testing tool for performance testing.
Documentation
use crate::auth::AuthMethod;
use crate::network::{Http2Config, TlsConfig};
use crate::stats::{RequestResult, StatsCollector};
use crate::user_agent::UserAgentManager;
use anyhow::Result;
use chrono::Utc;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::{Client, Method};
use std::sync::Arc;
use std::time::{Duration, Instant};

/// HTTP client that repeatedly sends one pre-configured request and records
/// each outcome in a shared [`StatsCollector`].
pub struct HttpClient {
    /// `reqwest` client built in `new` with the optional timeout, redirect
    /// policy, and HTTP/2 and TLS settings applied.
    client: Client,
    /// Target URL; an auth query parameter may be appended per request.
    url: String,
    /// HTTP method used for every request.
    method: Method,
    /// Extra headers parsed from `"Name: value"` strings in `new`;
    /// entries that fail to parse are not included.
    headers: HeaderMap,
    /// Request body; only attached for POST, PUT and PATCH requests.
    payload: Option<String>,
    /// Supplies the User-Agent value, read once per `send_request` call.
    user_agent_manager: Arc<UserAgentManager>,
    /// Receives the `RequestResult` of each sent request.
    stats_collector: Arc<StatsCollector>,
    /// Source of the optional auth header (fetched on every attempt) and the
    /// optional auth query parameter.
    auth_method: Arc<AuthMethod>,
    /// HTTP/2 settings applied to `client`; also exposed through getters.
    http2_config: Arc<Http2Config>,
    /// TLS settings applied to `client`; also exposed through getters.
    tls_config: Arc<TlsConfig>,
    /// Additional attempts allowed after a transport-level failure
    /// (responses with HTTP error statuses are not retried).
    retry_count: u32,
    /// Substring that the body of a non-error (status < 400) response must
    /// contain, if set; otherwise the request is recorded as failed.
    expect_body: Option<String>,
}

impl HttpClient {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        url: String,
        method: Method,
        headers: Vec<String>,
        payload: Option<String>,
        user_agent_manager: Arc<UserAgentManager>,
        stats_collector: Arc<StatsCollector>,
        timeout: Option<Duration>,
        auth_method: Arc<AuthMethod>,
        http2_config: Arc<Http2Config>,
        tls_config: Arc<TlsConfig>,
        follow_redirects: bool,
        retry_count: u32,
        expect_body: Option<String>,
    ) -> Result<Self> {
        let mut client_builder = Client::builder();

        if let Some(timeout) = timeout {
            client_builder = client_builder.timeout(timeout);
        }

        if !follow_redirects {
            client_builder = client_builder.redirect(reqwest::redirect::Policy::none());
        }

        client_builder = http2_config.apply_to_client_builder(client_builder);
        client_builder = tls_config.apply_to_client_builder(client_builder)?;

        let client = client_builder.build()?;

        let mut header_map = HeaderMap::new();
        for header in headers {
            if let Some((key, value)) = header.split_once(':') {
                let key = key.trim();
                let value = value.trim();
                if let (Ok(name), Ok(val)) = (
                    HeaderName::from_bytes(key.as_bytes()),
                    HeaderValue::from_str(value),
                ) {
                    header_map.insert(name, val);
                }
            }
        }

        Ok(Self {
            client,
            url,
            method,
            headers: header_map,
            payload,
            user_agent_manager,
            stats_collector,
            auth_method,
            http2_config,
            tls_config,
            retry_count,
            expect_body,
        })
    }

    pub async fn send_request(&self) -> Result<()> {
        let timestamp = Utc::now();
        let user_agent = self.user_agent_manager.get_user_agent().to_string();

        let mut url = self.url.clone();
        if let Some((key, value)) = self.auth_method.get_query_params().await {
            let separator = if url.contains('?') { "&" } else { "?" };
            url = format!("{url}{separator}{key}={value}");
        }

        let max_attempts = self.retry_count + 1;

        for attempt in 0..max_attempts {
            let mut builder = self.client.request(self.method.clone(), &url);
            builder = builder.header("User-Agent", &user_agent);

            if let Some((key, value)) = self.auth_method.get_auth_header().await {
                builder = builder.header(key, value);
            }

            for (key, value) in &self.headers {
                builder = builder.header(key, value);
            }

            if let Some(payload) = &self.payload {
                if self.method == Method::POST
                    || self.method == Method::PUT
                    || self.method == Method::PATCH
                {
                    builder = builder.body(payload.clone());
                    if !self.headers.contains_key("content-type")
                        && (payload.trim_start().starts_with('{')
                            || payload.trim_start().starts_with('['))
                    {
                        builder = builder.header("Content-Type", "application/json");
                    }
                }
            }

            let start = Instant::now();
            let request = builder.build()?;

            match self.client.execute(request).await {
                Ok(response) => {
                    let status_code = response.status().as_u16();
                    let duration = start.elapsed().as_millis() as u64;

                    let mut error = if status_code >= 400 {
                        Some(format!(
                            "HTTP {}: {}",
                            status_code,
                            response.status().canonical_reason().unwrap_or("Unknown")
                        ))
                    } else {
                        None
                    };

                    let body_bytes = response.bytes().await.unwrap_or_default();
                    let bytes_received = body_bytes.len() as u64;

                    if error.is_none() {
                        if let Some(expected) = &self.expect_body {
                            let body_str = String::from_utf8_lossy(&body_bytes);
                            if !body_str.contains(expected.as_str()) {
                                error = Some(format!(
                                    "body validation failed: expected {:?} in response",
                                    expected
                                ));
                            }
                        }
                    }

                    let result = RequestResult {
                        timestamp,
                        duration_ms: duration,
                        status_code: Some(status_code),
                        error,
                        user_agent: Some(user_agent.clone()),
                        bytes_received,
                    };
                    self.stats_collector.record_request(result).await;
                    return Ok(());
                }
                Err(e) => {
                    if attempt + 1 < max_attempts {
                        continue;
                    }
                    let duration = start.elapsed().as_millis() as u64;
                    let err_str = e.to_string();
                    let result = RequestResult {
                        timestamp,
                        duration_ms: duration,
                        status_code: None,
                        error: Some(err_str),
                        user_agent: Some(user_agent.clone()),
                        bytes_received: 0,
                    };
                    self.stats_collector.record_request(result).await;
                    return Err(e.into());
                }
            }
        }

        unreachable!()
    }

    pub fn get_http2_config(&self) -> &Http2Config {
        &self.http2_config
    }

    pub fn get_tls_config(&self) -> &TlsConfig {
        &self.tls_config
    }

    pub fn get_protocol_info(&self) -> crate::network::ProtocolInfo {
        self.http2_config.get_protocol_info()
    }

    pub fn get_tls_info(&self) -> crate::network::TlsInfo {
        self.tls_config.get_summary()
    }
}