mod attack;
mod encode;
mod models;
mod plot;
mod report;
mod utils;
pub use models::{AttackConfig, Header, Metrics, Result as AttackResult, Target};
use anyhow::Result;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use url::Url;
/// Fluent builder that configures and launches an HTTP load-test attack.
///
/// Obtain one with [`AttackBuilder::new`], chain setters, add at least one
/// target (an empty target list is rejected), then call [`AttackBuilder::run`].
pub struct AttackBuilder {
    // --- What to send ---
    /// Requests to issue, cycled through round-robin.
    targets: Vec<Target>,
    /// Extra headers handed to every request.
    headers: Vec<Header>,

    // --- Pacing and concurrency ---
    /// Requests per second; consecutive requests are `1 / rate` seconds apart.
    rate: f64,
    /// How long to keep issuing requests; `None` means no deadline.
    duration: Option<Duration>,
    /// Initial number of requests allowed in flight at once.
    workers: u64,
    /// Worker count ramped up to over the attack duration, when larger than `workers`.
    max_workers: Option<u64>,

    // --- HTTP client settings ---
    /// Per-request timeout of the HTTP client.
    timeout: Duration,
    /// Whether HTTP keep-alive (connection reuse) is requested.
    keepalive: bool,
    /// Per-host idle connection pool size of the HTTP client.
    connections: usize,
    /// When set, replaces `connections` as the per-host idle pool size.
    max_connections: Option<usize>,
    /// Whether HTTP/2 is requested (when `h2c` is off).
    http2: bool,
    /// Use HTTP/2 over cleartext with prior knowledge.
    h2c: bool,
    /// Accept invalid TLS certificates.
    insecure: bool,
    /// Maximum number of redirects to follow; a negative value disables following.
    redirects: i32,
    /// Local IP address to bind outgoing connections to; `"0.0.0.0"` means no explicit binding.
    laddr: String,

    // --- Forwarded to each request via the attack configuration ---
    /// Optional attack name.
    name: Option<String>,
    /// Body size limit; `-1` by default (presumably "unlimited" — confirm in `attack`).
    max_body: i64,
    /// DNS cache TTL setting.
    dns_ttl: Duration,
    /// Lazy-mode flag.
    lazy: bool,
    /// Optional OpenTelemetry collector address.
    opentelemetry_addr: Option<String>,
}
impl Default for AttackBuilder {
    /// Default attack: 50 req/s for 30 s with 10 workers, a 30 s request
    /// timeout, keep-alive and HTTP/2 requested, up to 10 redirects, and no
    /// targets or headers yet.
    fn default() -> Self {
        Self {
            // What to send
            targets: Vec::new(),
            headers: Vec::new(),
            // Pacing and concurrency
            rate: 50.0,
            duration: Some(Duration::from_secs(30)),
            workers: 10,
            max_workers: None,
            // HTTP client settings
            timeout: Duration::from_secs(30),
            keepalive: true,
            connections: 10_000,
            max_connections: None,
            http2: true,
            h2c: false,
            insecure: false,
            redirects: 10,
            laddr: String::from("0.0.0.0"),
            // Forwarded to each request
            name: None,
            max_body: -1,
            dns_ttl: Duration::ZERO,
            lazy: false,
            opentelemetry_addr: None,
        }
    }
}
impl AttackBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn rate(mut self, rate: f64) -> Self {
self.rate = rate;
self
}
pub fn duration(mut self, duration: Duration) -> Self {
self.duration = Some(duration);
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn workers(mut self, workers: u64) -> Self {
self.workers = workers;
self
}
pub fn max_workers(mut self, max_workers: u64) -> Self {
self.max_workers = Some(max_workers);
self
}
pub fn keepalive(mut self, keepalive: bool) -> Self {
self.keepalive = keepalive;
self
}
pub fn connections(mut self, connections: usize) -> Self {
self.connections = connections;
self
}
pub fn max_connections(mut self, max_connections: usize) -> Self {
self.max_connections = Some(max_connections);
self
}
pub fn http2(mut self, http2: bool) -> Self {
self.http2 = http2;
self
}
pub fn name(mut self, name: String) -> Self {
self.name = Some(name);
self
}
pub fn max_body(mut self, max_body: i64) -> Self {
self.max_body = max_body;
self
}
pub fn dns_ttl(mut self, dns_ttl: Duration) -> Self {
self.dns_ttl = dns_ttl;
self
}
pub fn laddr(mut self, laddr: String) -> Self {
self.laddr = laddr;
self
}
pub fn lazy(mut self, lazy: bool) -> Self {
self.lazy = lazy;
self
}
pub fn opentelemetry_addr(mut self, addr: String) -> Self {
self.opentelemetry_addr = Some(addr);
self
}
pub fn targets(mut self, targets: Vec<Target>) -> Self {
self.targets = targets;
self
}
pub fn add_target(mut self, target: Target) -> Self {
self.targets.push(target);
self
}
pub fn headers(mut self, headers: Vec<Header>) -> Self {
self.headers = headers;
self
}
pub fn add_header(mut self, name: &str, value: &str) -> Self {
self.headers.push(Header {
name: name.to_string(),
value: value.to_string(),
});
self
}
pub fn insecure(mut self, insecure: bool) -> Self {
self.insecure = insecure;
self
}
pub fn h2c(mut self, h2c: bool) -> Self {
self.h2c = h2c;
self
}
pub fn redirects(mut self, redirects: i32) -> Self {
self.redirects = redirects;
self
}
pub async fn run(self) -> Result<Vec<AttackResult>> {
if self.targets.is_empty() {
anyhow::bail!("No targets specified");
}
let config = AttackConfig {
rate: self.rate,
duration: self.duration,
timeout: self.timeout,
workers: self.workers,
max_workers: self.max_workers,
keepalive: self.keepalive,
connections: self.connections,
max_connections: self.max_connections,
http2: self.http2,
name: self.name,
max_body: self.max_body,
dns_ttl: self.dns_ttl,
laddr: self.laddr,
lazy: self.lazy,
opentelemetry_addr: self.opentelemetry_addr,
};
let mut client_builder = reqwest::Client::builder()
.timeout(config.timeout)
.pool_max_idle_per_host(config.connections);
if let Some(max_conns) = config.max_connections {
client_builder = client_builder.pool_max_idle_per_host(max_conns);
}
if !config.keepalive {
client_builder = client_builder.pool_idle_timeout(None);
}
if self.insecure {
client_builder = client_builder.danger_accept_invalid_certs(true);
}
if self.h2c {
client_builder = client_builder.http2_prior_knowledge();
} else if config.http2 {
client_builder = client_builder.http2_adaptive_window(true);
}
if config.laddr != "0.0.0.0" {
let local_addr = config.laddr.parse::<std::net::IpAddr>()?;
client_builder = client_builder.local_address(local_addr);
}
if self.redirects >= 0 {
client_builder = client_builder.redirect(reqwest::redirect::Policy::limited(self.redirects as usize));
} else {
client_builder = client_builder.redirect(reqwest::redirect::Policy::none());
}
let client = Arc::new(client_builder.build()?);
let (tx, mut rx) = mpsc::channel::<AttackResult>(1000);
let attack_handle = {
let targets = Arc::new(self.targets);
let headers = Arc::new(self.headers);
let config = Arc::new(config);
let tx = tx.clone();
tokio::spawn(async move {
let delay = if config.rate > 0.0 {
Duration::from_secs_f64(1.0 / config.rate)
} else {
Duration::from_secs(0)
};
let start_time = std::time::Instant::now();
let mut request_count = 0;
let end_time = config.duration.map(|d| start_time + d);
let mut interval = tokio::time::interval(delay);
let worker_semaphore = Arc::new(tokio::sync::Semaphore::new(config.workers as usize));
if let Some(max_workers) = config.max_workers {
if max_workers > config.workers {
let semaphore_clone = worker_semaphore.clone();
let duration_clone = config.duration.clone();
let workers = config.workers;
tokio::spawn(async move {
let worker_diff = max_workers - workers;
let total_duration = duration_clone.unwrap_or(Duration::from_secs(60));
let interval = total_duration.div_f64(worker_diff as f64);
for _ in 0..worker_diff {
tokio::time::sleep(interval).await;
semaphore_clone.add_permits(1);
}
});
}
}
loop {
interval.tick().await;
if let Some(end) = end_time {
if std::time::Instant::now() >= end {
break;
}
}
let target_index = request_count % targets.len();
let target = targets[target_index].clone();
let client = client.clone();
let headers = headers.clone();
let config_clone = config.clone();
let tx = tx.clone();
let semaphore = worker_semaphore.clone();
let permit = match semaphore.clone().try_acquire_owned() {
Ok(permit) => permit,
Err(_) => {
match semaphore.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => continue,
}
}
};
tokio::spawn(async move {
let result = attack::make_request(client, target, &headers, &config_clone).await;
let _ = tx.send(result).await;
drop(permit);
});
request_count += 1;
}
})
};
let mut results = Vec::new();
let collector_handle = tokio::spawn(async move {
let mut collected_results = Vec::new();
while let Some(result) = rx.recv().await {
collected_results.push(result);
}
collected_results
});
attack_handle.await?;
drop(tx);
results = collector_handle.await?;
Ok(results)
}
}
/// Builds a [`Target`] for `method` and `url`, with no extra headers and no body.
///
/// # Errors
///
/// Returns an error if `url` cannot be parsed by [`Url::parse`].
pub fn target(method: &str, url: &str) -> Result<Target> {
    let parsed_url = Url::parse(url)?;
    Ok(Target {
        url: parsed_url,
        method: String::from(method),
        body: None,
        headers: vec![],
    })
}
/// Shorthand for [`target`] with the `GET` method.
///
/// # Errors
///
/// Returns an error if `url` cannot be parsed.
pub fn get(url: &str) -> Result<Target> {
    let request = target("GET", url)?;
    Ok(request)
}
/// Shorthand for [`target`] with the `POST` method and the given request body.
///
/// # Errors
///
/// Returns an error if `url` cannot be parsed.
pub fn post(url: &str, body: Vec<u8>) -> Result<Target> {
    target("POST", url).map(|mut request| {
        request.body = Some(body);
        request
    })
}
/// Summarizes attack results into aggregate [`Metrics`].
///
/// Returns `None` when `results` is empty. The results may be in any order
/// (`AttackBuilder::run` collects them in completion order). The attack span
/// runs from the earliest to the latest timestamp. A response counts as a
/// success when its status code is in the 2xx range.
pub fn calculate_metrics(results: &[AttackResult]) -> Option<Metrics> {
    // Use the extreme timestamps rather than first()/last(): out-of-order
    // results would otherwise yield a negative span, silently zeroing
    // `duration` and `rate`.
    let first_timestamp = results.iter().map(|r| r.timestamp).min()?;
    let last_timestamp = results.iter().map(|r| r.timestamp).max()?;
    let duration = (last_timestamp - first_timestamp)
        .to_std()
        .unwrap_or(Duration::from_secs(0));

    let requests = results.len();
    let success = results
        .iter()
        .filter(|r| r.status_code >= 200 && r.status_code < 300)
        .count();
    let success_rate = success as f64 / requests as f64;

    let mut latencies: Vec<Duration> = results.iter().map(|r| r.latency).collect();
    latencies.sort_unstable();
    // `results` is non-empty past the `?` above, so `latencies` is too.
    let min = latencies[0];
    let max = latencies[latencies.len() - 1];
    // Averaging in f64 avoids the overflow panic of `Sum<Duration>`.
    let total_secs: f64 = latencies.iter().map(Duration::as_secs_f64).sum();
    let mean = Duration::from_secs_f64(total_secs / requests as f64);
    let p50 = percentile(&latencies, 0.5);
    let p90 = percentile(&latencies, 0.9);
    let p95 = percentile(&latencies, 0.95);
    let p99 = percentile(&latencies, 0.99);

    let rate = if duration.as_secs_f64() > 0.0 {
        requests as f64 / duration.as_secs_f64()
    } else {
        0.0
    };
    let bytes_in: usize = results.iter().map(|r| r.bytes_in).sum();
    let bytes_out: usize = results.iter().map(|r| r.bytes_out).sum();
    Some(Metrics {
        requests,
        success,
        duration,
        min,
        max,
        mean,
        p50,
        p90,
        p95,
        p99,
        rate,
        bytes_in,
        bytes_out,
        success_rate,
    })
}
/// Returns the `percentile`-th value of `sorted_latencies` by the nearest-rank method.
///
/// `sorted_latencies` must be sorted ascending. `percentile` is a fraction
/// (e.g. `0.99` for p99); values at or below `0.0` (and NaN) yield the
/// minimum, and values at or above `1.0` yield the maximum. An empty slice
/// yields a zero duration.
fn percentile(sorted_latencies: &[Duration], percentile: f64) -> Duration {
    let len = sorted_latencies.len();
    if len == 0 {
        return Duration::from_secs(0);
    }
    // Nearest rank: the smallest sample with at least `percentile` of all
    // samples at or below it. The rank is 1-based. The float-to-usize cast
    // saturates (negative/NaN -> 0), and the clamp keeps the rank in range.
    let rank = (percentile * len as f64).ceil() as usize;
    sorted_latencies[rank.clamp(1, len) - 1]
}