use std::net::SocketAddr;
use std::time::Instant;
use async_std::io::{self};
use async_std::net::{TcpStream};
use async_std::prelude::*;
use futures::executor::LocalPool;
use futures::task::SpawnExt;
use futures_timer::Delay;
use log::{debug, error, info, warn};
use crate::{Config};
use byte_mutator::ByteMutator;
use byte_mutator::fuzz_config::FuzzConfig;
/// Runs the load test described by `config`, repeatedly sending `message` to
/// `config.target`.
///
/// Spawns `config.num_threads()` OS threads. Each one runs a single-threaded
/// `LocalPool` executor that drives `config.connections_per_thread()`
/// connection tasks. This function blocks until every worker thread has
/// finished.
///
/// When `config.fuzz_path` is set, the fuzz configuration at that path is
/// loaded and used to build the `ByteMutator` that supplies the payload.
///
/// # Errors
///
/// Returns any error produced while loading the fuzz configuration file.
///
/// # Panics
///
/// Panics if a worker thread panics. For example, a connection task that
/// returns an error hits its `expect`, and the panic propagates through `join`.
pub fn clobber(config: Config, message: Vec<u8>) -> std::io::Result<()> {
    info!("Starting: {:#?}", config);

    let message = match &config.fuzz_path {
        None => ByteMutator::new(&message),
        Some(path) => ByteMutator::new_from_config(&message, FuzzConfig::from_file(&path)?),
    };

    let mut threads = Vec::with_capacity(config.num_threads() as usize);
    for _ in 0..config.num_threads() {
        let message = message.clone();
        let config = config.clone();
        let thread = std::thread::spawn(move || {
            let mut pool = LocalPool::new();
            let mut spawner = pool.spawner();
            for i in 0..config.connections_per_thread() {
                let message = message.clone();
                let config = config.clone();
                spawner
                    .spawn(async move {
                        // Under a rate limit, stagger start-up so connection `i`
                        // begins `i` connection-delays after the first one
                        // instead of all connections firing at once. The delay
                        // must be awaited; a bare `Delay::new(..)` is a lazy
                        // future that does nothing. A timer failure just means
                        // no stagger, so its result is ignored.
                        if config.rate.is_some() {
                            let _ = Delay::new(i * config.connection_delay()).await;
                        }
                        connection(message, config)
                            .await
                            .expect("Failed to run connection");
                    })
                    .expect("LocalPool rejected a spawned connection task");
            }
            pool.run();
        });
        threads.push(thread);
    }

    for handle in threads {
        handle.join().expect("worker thread panicked");
    }
    Ok(())
}
/// Drives one load-generating connection loop until `config`'s stop condition
/// is met.
///
/// Each iteration opens a new TCP connection to `config.target`. It then
/// performs up to `config.repeat` write/read round-trips with the current
/// payload and advances the mutator via `message.next()` (only after a
/// successful connect). When `config.rate` is set, the iteration is padded
/// with a sleep so it lasts at least `config.connection_delay()`.
///
/// The loop stops when `config.duration` has elapsed since this call started,
/// or when `config.limit_per_connection()` iterations (connection attempts)
/// have completed.
async fn connection(mut message: ByteMutator, config: Config) -> io::Result<()> {
    let start = Instant::now();
    let mut count = 0;

    // Counts this iteration, then reports whether the run is over.
    let mut loop_complete = |config: &Config| {
        count += 1;
        if let Some(duration) = config.duration {
            if Instant::now() >= start + duration {
                return true;
            }
        }
        if let Some(limit) = config.limit_per_connection() {
            if count > limit {
                return true;
            }
        }
        false
    };

    // Sleep only if the iteration finished early. Otherwise the configured rate
    // cannot be met by this connection, so warn instead.
    let should_delay = |elapsed, config: &Config| match config.rate {
        Some(_) => {
            if elapsed < config.connection_delay() {
                true
            } else {
                warn!("running behind; consider adding more connections");
                false
            }
        }
        None => false,
    };

    let mut read_buffer = [0u8; 1024];
    while !loop_complete(&config) {
        let request_start = Instant::now();
        if let Ok(mut stream) = connect(&config.target).await {
            for _ in 0..config.repeat {
                // A write/read error (reset, broken pipe, ...) leaves the socket
                // unusable. Stop reusing it rather than failing and logging the
                // same error up to `repeat` times; the next iteration reconnects.
                if write(&mut stream, message.read()).await.is_err() {
                    break;
                }
                if read(&mut stream, &mut read_buffer).await.is_err() {
                    break;
                }
            }
            message.next();
        }

        if config.rate.is_some() {
            let elapsed = Instant::now() - request_start;
            if should_delay(elapsed, &config) {
                // `should_delay` guarantees elapsed < connection_delay(), so this
                // subtraction cannot underflow.
                Delay::new(config.connection_delay() - elapsed)
                    .await
                    .unwrap();
            }
        }
    }
    Ok(())
}
async fn connect(addr: &SocketAddr) -> io::Result<TcpStream> {
match TcpStream::connect(addr).await {
Ok(stream) => {
debug!("connected to {}", addr);
Ok(stream)
}
Err(e) => {
if e.kind() != io::ErrorKind::TimedOut {
error!("unknown connect error: '{}'", e);
}
Err(e)
}
}
}
async fn write(stream: &mut TcpStream, buf: &[u8]) -> io::Result<usize> {
match stream.write_all(buf).await {
Ok(_) => {
let n = buf.len();
debug!("{} bytes written", n);
Ok(n)
}
Err(e) => {
error!("write error: '{}'", e);
Err(e)
}
}
}
/// Performs a single read from `stream` into `read_buffer`.
///
/// Returns the number of bytes read, logged at debug level. Per the usual
/// `read` contract, `Ok(0)` indicates end of stream or an empty buffer. Errors
/// are logged at error level and returned.
async fn read(stream: &mut TcpStream, read_buffer: &mut [u8]) -> io::Result<usize> {
    // Pass the slice straight through. Taking `&mut read_buffer` would build a
    // needless `&mut &mut [u8]` that only compiles via deref coercion.
    match stream.read(read_buffer).await {
        Ok(n) => {
            debug!("{} bytes read ", n);
            Ok(n)
        }
        Err(e) => {
            error!("read error: '{}'", e);
            Err(e)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::echo_server;

    /// A freshly started echo server accepts a connection.
    #[test]
    fn test_connect() {
        let connected = async_std::task::block_on(async {
            let addr = echo_server().unwrap();
            connect(&addr).await
        });
        assert!(connected.is_ok());
    }

    /// `write` reports exactly the number of bytes in the payload.
    #[test]
    fn test_write() {
        let addr = echo_server().unwrap();
        let payload: &[u8] = b"test";
        let expected = payload.len();
        let outcome: io::Result<usize> = async_std::task::block_on(async move {
            let mut stream = connect(&addr).await?;
            write(&mut stream, payload).await
        });
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), expected);
    }

    /// The echo server sends back the whole payload in a single read.
    #[test]
    fn test_read() {
        let addr = echo_server().unwrap();
        let payload: &[u8] = b"test\n\r\n";
        let expected = payload.len();
        let outcome: io::Result<usize> = async_std::task::block_on(async move {
            let mut stream = connect(&addr).await?;
            let mut buffer = [0u8; 1024];
            write(&mut stream, payload).await?;
            read(&mut stream, &mut buffer).await
        });
        assert!(outcome.is_ok());
        assert_eq!(expected, outcome.unwrap());
    }
}