use crate::config::{Config, Protocol};
use crate::measurements::{IntervalStats, MeasurementsCollector};
use crate::protocol::{deserialize_message, serialize_message, Message};
use crate::{Error, Result};
use log::{debug, error, info};
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::time;
pub struct Server {
config: Config,
measurements: MeasurementsCollector,
}
impl Server {
pub fn new(config: Config) -> Self {
Self {
config,
measurements: MeasurementsCollector::new(),
}
}
pub async fn run(&self) -> Result<()> {
let bind_addr = format!(
"{}:{}",
self.config
.bind_addr
.map(|a| a.to_string())
.unwrap_or_else(|| "0.0.0.0".to_string()),
self.config.port
);
info!("Starting rperf3 server on {}", bind_addr);
match self.config.protocol {
Protocol::Tcp => self.run_tcp(&bind_addr).await,
Protocol::Udp => self.run_udp(&bind_addr).await,
}
}
async fn run_tcp(&self, bind_addr: &str) -> Result<()> {
let listener = TcpListener::bind(bind_addr).await?;
info!("TCP server listening on {}", bind_addr);
loop {
match listener.accept().await {
Ok((stream, addr)) => {
info!("New connection from {}", addr);
let config = self.config.clone();
let measurements = self.measurements.clone();
tokio::spawn(async move {
if let Err(e) = handle_tcp_client(stream, addr, config, measurements).await
{
error!("Error handling client {}: {}", addr, e);
}
});
}
Err(e) => {
error!("Error accepting connection: {}", e);
}
}
}
}
async fn run_udp(&self, bind_addr: &str) -> Result<()> {
let socket = UdpSocket::bind(bind_addr).await?;
info!("UDP server listening on {}", bind_addr);
let mut buf = vec![0u8; 65536];
loop {
match socket.recv_from(&mut buf).await {
Ok((len, addr)) => {
debug!("Received {} bytes from {}", len, addr);
}
Err(e) => {
error!("Error receiving UDP packet: {}", e);
}
}
}
}
pub fn get_measurements(&self) -> crate::Measurements {
self.measurements.get()
}
}
async fn handle_tcp_client(
mut stream: TcpStream,
addr: SocketAddr,
config: Config,
measurements: MeasurementsCollector,
) -> Result<()> {
let setup_msg = deserialize_message(&mut stream).await?;
let (duration, reverse, _parallel) = match setup_msg {
Message::Setup {
version: _,
protocol,
duration,
reverse,
parallel,
..
} => {
info!(
"Client {} setup: protocol={}, duration={}s, reverse={}, parallel={}",
addr, protocol, duration, reverse, parallel
);
(Duration::from_secs(duration), reverse, parallel)
}
_ => {
return Err(Error::Protocol("Expected Setup message".to_string()));
}
};
let ack = Message::setup_ack(config.port, format!("{}", addr));
let ack_bytes = serialize_message(&ack)?;
stream.write_all(&ack_bytes).await?;
stream.flush().await?;
let start_msg = Message::start(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
);
let start_bytes = serialize_message(&start_msg)?;
stream.write_all(&start_bytes).await?;
stream.flush().await?;
measurements.set_start_time(Instant::now());
if reverse {
send_data(&mut stream, 0, duration, &measurements, &config).await?;
} else {
receive_data(&mut stream, 0, duration, &measurements, &config).await?;
}
let final_measurements = measurements.get();
if let Some(stream_stats) = final_measurements.streams.first() {
let result_msg = Message::result(
0,
stream_stats.bytes_sent,
stream_stats.bytes_received,
final_measurements.total_duration.as_secs_f64(),
final_measurements.total_bits_per_second(),
None,
);
let result_bytes = serialize_message(&result_msg)?;
stream.write_all(&result_bytes).await?;
stream.flush().await?;
}
let done_msg = Message::done();
let done_bytes = serialize_message(&done_msg)?;
stream.write_all(&done_bytes).await?;
stream.flush().await?;
info!(
"Test completed for {}: {:.2} Mbps",
addr,
final_measurements.total_bits_per_second() / 1_000_000.0
);
Ok(())
}
/// Writes `config.buffer_size`-byte chunks of zeros to `stream` until `duration` elapses.
///
/// Bytes written are credited to `stream_id` in `measurements`, and an interval report is
/// added each time at least `config.interval` has passed since the previous one. Bytes after
/// the last reported interval count toward the totals but are not emitted as an interval.
///
/// # Errors
/// Returns an error only if the final flush fails. A write error or a zero-length write
/// is logged or treated as end-of-test, and the bytes sent so far remain recorded.
async fn send_data(
    stream: &mut TcpStream,
    stream_id: usize,
    duration: Duration,
    measurements: &MeasurementsCollector,
    config: &Config,
) -> Result<()> {
    let buffer = vec![0u8; config.buffer_size];
    let start = Instant::now();
    let mut last_interval = start;
    let mut interval_bytes = 0u64;

    while start.elapsed() < duration {
        // Bound each write so a client that stops reading cannot hold the server past
        // `duration`. `AsyncWriteExt::write` is cancel safe: a timed-out write wrote nothing.
        match time::timeout(next_io_wait(start, duration), stream.write(&buffer)).await {
            Ok(Ok(0)) => {
                // The socket accepted nothing (peer gone, or an empty buffer); retrying
                // would just spin until the deadline.
                break;
            }
            Ok(Ok(n)) => {
                measurements.record_bytes_sent(stream_id, n as u64);
                interval_bytes += n as u64;
                let now = Instant::now();
                if now.saturating_duration_since(last_interval) >= config.interval {
                    record_interval_stats(measurements, start, last_interval, now, interval_bytes);
                    interval_bytes = 0;
                    // Start the next interval exactly where this one ended, so no time is lost.
                    last_interval = now;
                }
            }
            Ok(Err(e)) => {
                error!("Error sending data: {}", e);
                break;
            }
            // No progress within this wait; the loop condition decides whether to retry.
            Err(_) => {}
        }
    }
    measurements.set_duration(start.elapsed());
    stream.flush().await?;
    Ok(())
}

/// Reads from `stream` until `duration` elapses or the client closes its side.
///
/// Bytes read are credited to `stream_id` in `measurements`, and an interval report is
/// added each time at least `config.interval` has passed since the previous one. Bytes after
/// the last reported interval count toward the totals but are not emitted as an interval.
///
/// # Errors
/// Currently always returns `Ok`. Read errors are logged and end the test early, with the
/// bytes received so far still recorded.
async fn receive_data(
    stream: &mut TcpStream,
    stream_id: usize,
    duration: Duration,
    measurements: &MeasurementsCollector,
    config: &Config,
) -> Result<()> {
    let mut buffer = vec![0u8; config.buffer_size];
    let start = Instant::now();
    let mut last_interval = start;
    let mut interval_bytes = 0u64;

    while start.elapsed() < duration {
        match time::timeout(next_io_wait(start, duration), stream.read(&mut buffer)).await {
            // EOF: the client closed its sending side.
            Ok(Ok(0)) => break,
            Ok(Ok(n)) => {
                measurements.record_bytes_received(stream_id, n as u64);
                interval_bytes += n as u64;
                let now = Instant::now();
                if now.saturating_duration_since(last_interval) >= config.interval {
                    record_interval_stats(measurements, start, last_interval, now, interval_bytes);
                    interval_bytes = 0;
                    // Start the next interval exactly where this one ended, so no time is lost.
                    last_interval = now;
                }
            }
            Ok(Err(e)) => {
                error!("Error receiving data: {}", e);
                break;
            }
            // Nothing arrived within this wait; the loop condition decides whether to keep going.
            Err(_) => {}
        }
    }
    measurements.set_duration(start.elapsed());
    Ok(())
}

/// How long the next read/write may block: the time left in the test, capped at 100 ms.
///
/// The cap makes the data loops re-check the deadline regularly. It also avoids handing the
/// timer an arbitrarily large, client-supplied duration.
fn next_io_wait(start: Instant, duration: Duration) -> Duration {
    duration
        .saturating_sub(start.elapsed())
        .min(Duration::from_millis(100))
}

/// Adds an interval report covering `[interval_start, now)`, expressed relative to `test_start`.
///
/// Both endpoints and the rate come from the same clock readings, so `end - start` equals
/// the span the rate is computed over, and consecutive intervals tile the timeline.
fn record_interval_stats(
    measurements: &MeasurementsCollector,
    test_start: Instant,
    interval_start: Instant,
    now: Instant,
    bytes: u64,
) {
    let span_secs = now.saturating_duration_since(interval_start).as_secs_f64();
    // A zero-length span (e.g. a zero `config.interval`) would otherwise yield inf/NaN.
    let bits_per_second = if span_secs > 0.0 {
        bytes as f64 * 8.0 / span_secs
    } else {
        0.0
    };
    measurements.add_interval(IntervalStats {
        start: interval_start.saturating_duration_since(test_start),
        end: now.saturating_duration_since(test_start),
        bytes,
        bits_per_second,
        packets: None,
    });
}