// neocurl 2.0.5
//
// A command line tool to test servers.
// Documentation
use super::{PyAsyncResponses, PyMethod, PyRequest, PyResponse, async_responses::ResponseStats};
use indicatif::{ProgressBar, ProgressStyle};
use pyo3::{prelude::*, types::PyDict};
use reqwest::Client;
use std::{collections::HashMap, sync::Arc};
use tokio::{
    sync::{Semaphore, mpsc},
    task,
};

/// HTTP client object exposed to Python under the class name `Client`.
///
/// The struct carries no fields; every request method builds what it needs
/// per call.
#[pyclass(name = "Client")]
#[derive(Default)]
pub struct PyClient {}

impl PyClient {
    fn send_request(&self, request: PyRequest) -> PyResult<PyResponse> {
        let request_builder = request.to_reqwest_blocking();

        let start = std::time::Instant::now();
        let response = request_builder.send().map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Request failed: {}", e))
        })?;
        let duration = start.elapsed();

        let status_code = response.status().as_u16();
        let status = response.status().to_string();

        let headers: HashMap<String, String> = response
            .headers()
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
            .collect();

        let body_raw = response
            .bytes()
            .map_err(|e| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                    "Failed to read response body: {}",
                    e
                ))
            })?
            .to_vec();

        let body = String::from_utf8(body_raw.to_vec()).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyUnicodeDecodeError, _>(format!(
                "Failed to decode response body: {}",
                e
            ))
        })?;

        Ok(PyResponse {
            status_code,
            status,
            headers,
            body,
            body_raw,
            duration: duration.as_millis() as u64,
        })
    }

    fn send_requests_async(
        &self,
        request: PyRequest,
        amount: u32,
        threads: u32,
    ) -> PyResult<PyAsyncResponses> {
        let progress_bar = ProgressBar::new(amount.into());
        let style = ProgressStyle::with_template(
            "[{elapsed_precise}] {bar:40.cyan/blue} {pos:>5}/{len:5} {msg}",
        )
        .unwrap()
        .progress_chars("##-");
        progress_bar.set_style(style.clone());
        progress_bar.set_message("Processing");

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(threads as usize)
            .enable_all()
            .build()
            .map_err(|e| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Runtime error: {}", e))
            })?;

        let mut total_duration = 0;

        let results = rt.block_on(async {
            let request_template = Arc::new(
                request
                    .to_reqwest()
                    .build()
                    .expect("Failed to build request"),
            );
            let semaphore = Arc::new(Semaphore::new(threads as usize));
            let (tx, mut rx) = mpsc::unbounded_channel();

            let sending_start = std::time::Instant::now();
            let per_thread = amount / threads;

            let mut handles = Vec::new();

            for _ in 0..threads {
                let tx = tx.clone();
                let semaphore = semaphore.clone();
                let request = request_template.clone();
                let progress_bar = progress_bar.clone();

                let handle = task::spawn(async move {
                    let client = Client::new();

                    for _ in 0..per_thread {
                        let _permit = semaphore.acquire().await.unwrap();
                        let req = request.try_clone().expect("Failed to clone request");

                        let start = std::time::Instant::now();

                        match client.execute(req).await {
                            Ok(response) => {
                                let duration = start.elapsed();
                                let status_code = response.status().as_u16();
                                let status = response.status().to_string();
                                let headers: HashMap<String, String> = response
                                    .headers()
                                    .iter()
                                    .map(|(k, v)| {
                                        (k.to_string(), v.to_str().unwrap_or("").to_string())
                                    })
                                    .collect();

                                let body_raw = match response.bytes().await {
                                    Ok(bytes) => bytes.to_vec(),
                                    Err(e) => {
                                        eprintln!("Failed to read response body: {}", e);
                                        Vec::new()
                                    }
                                };

                                let body = String::from_utf8(body_raw.to_vec())
                                    .map_err(|e| {
                                        eprintln!("Failed to decode response body: {}", e);
                                    })
                                    .unwrap_or("".to_string());

                                let response = PyResponse {
                                    status_code,
                                    status,
                                    headers,
                                    body,
                                    body_raw,
                                    duration: duration.as_millis() as u64,
                                };

                                if let Err(e) = tx.send(response) {
                                    eprintln!("Failed to send response: {}", e);
                                }
                            }
                            Err(e) => {
                                let status = e
                                    .status()
                                    .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                                if let Err(e) = tx.send(PyResponse {
                                    status_code: status.as_u16(),
                                    status: status.to_string(),
                                    headers: HashMap::new(),
                                    body: "".to_string(),
                                    body_raw: Vec::new(),
                                    duration: start.elapsed().as_millis() as u64,
                                }) {
                                    eprintln!("Failed to send error response: {}", e);
                                }
                            }
                        };

                        progress_bar.inc(1);
                    }
                });

                handles.push(handle);
            }

            for handle in handles {
                handle.await.expect("Thread failed");
            }

            drop(tx);

            let mut responses = Vec::with_capacity(amount as usize);
            while let Some(res) = rx.recv().await {
                responses.push(res);
            }

            progress_bar.finish_and_clear();

            total_duration = sending_start.elapsed().as_millis() as u64;

            responses
        });

        println!("[{}] Responses received", results.len());

        let durations: Vec<u64> = results.iter().map(|r| r.duration).collect();
        let response_codes: Vec<u16> = results.iter().map(|r| r.status_code).collect();
        let async_responses = PyAsyncResponses {
            responses: results.clone(),
            responses_stats: {
                ResponseStats {
                    durations,
                    responses: response_codes,
                    total_duration,
                }
            },
        };

        Ok(async_responses)
    }
}

/// Reads the `method` keyword argument, defaulting to `PyMethod::Get` when
/// kwargs are absent, the key is missing, or the value cannot be extracted.
fn kwarg_method(kwargs: Option<&Bound<'_, PyDict>>) -> PyMethod {
    kwargs
        .and_then(|d| d.get_item("method").ok()?)
        .and_then(|m| m.extract::<PyMethod>().ok())
        .unwrap_or(PyMethod::Get)
}

/// Reads the keyword argument `key` as a `u32`, returning `default` when
/// kwargs are absent, the key is missing, or the value cannot be extracted.
fn kwarg_u32(kwargs: Option<&Bound<'_, PyDict>>, key: &str, default: u32) -> u32 {
    kwargs
        .and_then(|d| d.get_item(key).ok()?)
        .and_then(|v| v.extract::<u32>().ok())
        .unwrap_or(default)
}

/// Reads the `(amount, threads)` pair used by the `*_async` methods; both
/// default to 1.
fn load_params(kwargs: Option<&Bound<'_, PyDict>>) -> (u32, u32) {
    (
        kwarg_u32(kwargs, "amount", 1),
        kwarg_u32(kwargs, "threads", 1),
    )
}

#[pymethods]
impl PyClient {
    /// Send one request to `url`. The HTTP method comes from the `method`
    /// keyword argument and defaults to GET.
    #[pyo3(signature = (url, **kwargs))]
    fn send(&self, url: String, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<PyResponse> {
        let method = kwarg_method(kwargs);
        let request = PyRequest::from_args(url, method, kwargs)?;
        self.send_request(request)
    }

    /// Send `amount` requests to `url` using `threads` workers. The HTTP
    /// method comes from the `method` keyword argument and defaults to GET;
    /// `amount` and `threads` both default to 1.
    #[pyo3(signature = (url, **kwargs))]
    fn send_async(
        &self,
        url: String,
        kwargs: Option<&Bound<'_, PyDict>>,
    ) -> PyResult<PyAsyncResponses> {
        let method = kwarg_method(kwargs);
        let (amount, threads) = load_params(kwargs);
        let request = PyRequest::from_args(url, method, kwargs)?;
        self.send_requests_async(request, amount, threads)
    }

    /// Send one GET request to `url`.
    #[pyo3(signature = (url, **kwargs))]
    fn get(&self, url: String, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<PyResponse> {
        let request = PyRequest::from_args(url, PyMethod::Get, kwargs)?;
        self.send_request(request)
    }

    /// Send `amount` GET requests to `url` using `threads` workers; both
    /// keyword arguments default to 1.
    #[pyo3(signature = (url, **kwargs))]
    fn get_async(
        &self,
        url: String,
        kwargs: Option<&Bound<'_, PyDict>>,
    ) -> PyResult<PyAsyncResponses> {
        let request = PyRequest::from_args(url, PyMethod::Get, kwargs)?;
        let (amount, threads) = load_params(kwargs);
        self.send_requests_async(request, amount, threads)
    }

    /// Send one POST request to `url`.
    #[pyo3(signature = (url, **kwargs))]
    fn post(&self, url: String, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<PyResponse> {
        let request = PyRequest::from_args(url, PyMethod::Post, kwargs)?;
        self.send_request(request)
    }

    /// Send `amount` POST requests to `url` using `threads` workers; both
    /// keyword arguments default to 1.
    #[pyo3(signature = (url, **kwargs))]
    fn post_async(
        &self,
        url: String,
        kwargs: Option<&Bound<'_, PyDict>>,
    ) -> PyResult<PyAsyncResponses> {
        let request = PyRequest::from_args(url, PyMethod::Post, kwargs)?;
        let (amount, threads) = load_params(kwargs);
        self.send_requests_async(request, amount, threads)
    }
}

#[pyfunction()]
fn client() -> PyClient {
    PyClient::default()
}

/// Adds the `PyClient` class and the `client` factory function to `module`.
///
/// # Errors
///
/// Propagates any error from adding the class or wrapping/adding the function.
pub fn register(module: &Bound<'_, PyModule>) -> PyResult<()> {
    module.add_class::<PyClient>()?;

    let factory = wrap_pyfunction!(client, module)?;
    module.add_function(factory)
}