use crate::config::{Config, Protocol};
use crate::measurements::{
get_connection_info, get_system_info, get_tcp_stats, IntervalStats, MeasurementsCollector,
TestConfig,
};
use crate::protocol::{deserialize_message, serialize_message, Message};
use crate::{Error, Result};
use log::{debug, error, info};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use tokio::time;
/// Progress notification delivered to a [`ProgressCallback`] while a test runs.
///
/// The `Option` fields are only filled in where the client measures them: the
/// UDP paths populate packet, jitter and loss figures, while the TCP paths
/// leave them `None`.
#[derive(Debug, Clone, PartialEq)]
pub enum ProgressEvent {
    /// The server acknowledged the setup and sent its start signal.
    TestStarted,
    /// Statistics for one completed reporting interval.
    IntervalUpdate {
        /// Interval start, measured from the beginning of the data transfer.
        interval_start: Duration,
        /// Interval end, measured from the beginning of the data transfer.
        interval_end: Duration,
        /// Bytes transferred during the interval.
        bytes: u64,
        /// Average throughput over the interval, in bits per second.
        bits_per_second: f64,
        /// Datagrams transferred during the interval (UDP only).
        packets: Option<u64>,
        /// Jitter in milliseconds as reported by the measurements collector (UDP only).
        jitter_ms: Option<f64>,
        /// Lost packets as computed by the collector at update time (UDP only).
        lost_packets: Option<u64>,
        /// `lost_packets` as a percentage of the expected packets (UDP only).
        lost_percent: Option<f64>,
        /// TCP retransmissions; not populated by the client code in this module.
        retransmits: Option<u64>,
    },
    /// Final totals, emitted once the data transfer has finished.
    TestCompleted {
        /// Sum of bytes sent and bytes received.
        total_bytes: u64,
        /// Test duration recorded by the measurements collector.
        duration: Duration,
        /// Overall throughput, in bits per second.
        bits_per_second: f64,
        /// Total datagrams counted (UDP only).
        total_packets: Option<u64>,
        /// Final jitter in milliseconds (UDP only).
        jitter_ms: Option<f64>,
        /// Final lost-packet count (UDP only).
        lost_packets: Option<u64>,
        /// Final loss as a percentage of the expected packets (UDP only).
        lost_percent: Option<f64>,
        /// Datagrams that arrived out of sequence (UDP only).
        out_of_order: Option<u64>,
    },
    /// An error described by a message.
    /// NOTE(review): not emitted by the client code in this module.
    Error(String),
}
/// Observer that receives [`ProgressEvent`]s while a test runs.
///
/// The `Send + Sync` bound lets one callback instance be shared behind an
/// `Arc` with the data-transfer helpers.
pub trait ProgressCallback: Send + Sync {
    /// Handles a single progress event.
    fn on_progress(&self, event: ProgressEvent);
}

/// Any thread-safe closure that accepts a [`ProgressEvent`] can act as a callback.
impl<F: Fn(ProgressEvent) + Send + Sync> ProgressCallback for F {
    fn on_progress(&self, event: ProgressEvent) {
        (self)(event)
    }
}

/// Shared, type-erased handle to a progress callback.
type CallbackRef = Arc<dyn ProgressCallback>;
/// Client side of an rperf3 test.
///
/// Built with [`Client::new`]; a progress observer can be attached with
/// [`Client::with_callback`] before the test is driven with [`Client::run`].
pub struct Client {
    /// Test parameters. `new` rejects configurations without `server_addr`.
    config: Config,
    /// Accumulates byte, packet and interval statistics for the run.
    measurements: MeasurementsCollector,
    /// Optional observer that receives [`ProgressEvent`]s.
    callback: Option<CallbackRef>,
}
impl Client {
pub fn new(config: Config) -> Result<Self> {
if config.server_addr.is_none() {
return Err(Error::Config(
"Server address is required for client mode".to_string(),
));
}
Ok(Self {
config,
measurements: MeasurementsCollector::new(),
callback: None,
})
}
pub fn with_callback<C: ProgressCallback + 'static>(mut self, callback: C) -> Self {
self.callback = Some(Arc::new(callback));
self
}
fn notify(&self, event: ProgressEvent) {
if let Some(callback) = &self.callback {
callback.on_progress(event);
}
}
pub async fn run(&self) -> Result<()> {
let server_addr = self
.config
.server_addr
.as_ref()
.ok_or_else(|| Error::Config("Server address not set".to_string()))?;
let full_addr = format!("{}:{}", server_addr, self.config.port);
info!("Connecting to rperf3 server at {}", full_addr);
match self.config.protocol {
Protocol::Tcp => self.run_tcp(&full_addr).await,
Protocol::Udp => self.run_udp(&full_addr).await,
}
}
async fn run_tcp(&self, server_addr: &str) -> Result<()> {
let mut stream = TcpStream::connect(server_addr).await?;
info!("Connected to {}", server_addr);
let connection_info = get_connection_info(&stream).ok();
let system_info = Some(get_system_info());
let setup = Message::setup(
format!("{:?}", self.config.protocol),
self.config.duration,
self.config.bandwidth,
self.config.buffer_size,
self.config.parallel,
self.config.reverse,
);
let setup_bytes = serialize_message(&setup)?;
stream.write_all(&setup_bytes).await?;
stream.flush().await?;
let ack_msg = deserialize_message(&mut stream).await?;
match ack_msg {
Message::SetupAck { port, cookie } => {
debug!("Received setup ack: port={}, cookie={}", port, cookie);
}
Message::Error { message } => {
return Err(Error::Protocol(format!("Server error: {}", message)));
}
_ => {
return Err(Error::Protocol("Expected SetupAck message".to_string()));
}
}
let start_msg = deserialize_message(&mut stream).await?;
match start_msg {
Message::Start { .. } => {
info!("Test started");
self.notify(ProgressEvent::TestStarted);
}
_ => {
return Err(Error::Protocol("Expected Start message".to_string()));
}
}
self.measurements.set_start_time(Instant::now());
if self.config.reverse {
receive_data(
&mut stream,
0,
&self.measurements,
&self.config,
&self.callback,
)
.await?;
} else {
send_data(
&mut stream,
0,
&self.measurements,
&self.config,
&self.callback,
)
.await?;
}
let _tcp_stats = get_tcp_stats(&stream).ok();
match deserialize_message(&mut stream).await {
Ok(result_msg) => match result_msg {
Message::Result {
stream_id,
bytes_sent,
bytes_received,
duration: _,
bits_per_second,
..
} => {
info!(
"Stream {}: {} bytes sent, {} bytes received, {:.2} Mbps",
stream_id,
bytes_sent,
bytes_received,
bits_per_second / 1_000_000.0
);
}
_ => {
debug!("Unexpected message, continuing");
}
},
Err(e) => {
debug!(
"Could not read result message (connection may be closed): {}",
e
);
}
}
match deserialize_message(&mut stream).await {
Ok(done_msg) => match done_msg {
Message::Done => {
info!("Test completed");
}
_ => {
debug!("Expected Done message");
}
},
Err(e) => {
debug!(
"Could not read done message (connection may be closed): {}",
e
);
info!("Test completed");
}
}
let final_measurements = self.measurements.get();
self.notify(ProgressEvent::TestCompleted {
total_bytes: final_measurements.total_bytes_sent
+ final_measurements.total_bytes_received,
duration: final_measurements.total_duration,
bits_per_second: final_measurements.total_bits_per_second(),
total_packets: None, jitter_ms: None,
lost_packets: None,
lost_percent: None,
out_of_order: None,
});
if !self.config.json {
print_results(&final_measurements);
} else {
let test_config = TestConfig {
protocol: format!("{:?}", self.config.protocol),
num_streams: self.config.parallel,
blksize: self.config.buffer_size,
omit: 0,
duration: self.config.duration.as_secs(),
reverse: self.config.reverse,
};
let detailed_results =
self.measurements
.get_detailed_results(connection_info, system_info, test_config);
let json = serde_json::to_string_pretty(&detailed_results)?;
println!("{}", json);
}
Ok(())
}
async fn run_udp(&self, server_addr: &str) -> Result<()> {
let mut control_stream = TcpStream::connect(server_addr).await?;
let setup = Message::setup(
format!("{:?}", self.config.protocol),
self.config.duration,
self.config.bandwidth,
self.config.buffer_size,
self.config.parallel,
self.config.reverse,
);
let setup_bytes = serialize_message(&setup)?;
control_stream.write_all(&setup_bytes).await?;
control_stream.flush().await?;
let ack_msg = deserialize_message(&mut control_stream).await?;
match ack_msg {
Message::SetupAck { port, cookie } => {
debug!("Received setup ack: port={}, cookie={}", port, cookie);
}
Message::Error { message } => {
return Err(Error::Protocol(format!("Server error: {}", message)));
}
_ => {
return Err(Error::Protocol("Expected SetupAck message".to_string()));
}
}
let start_msg = deserialize_message(&mut control_stream).await?;
match start_msg {
Message::Start { .. } => {
info!("Test started");
self.notify(ProgressEvent::TestStarted);
}
_ => {
return Err(Error::Protocol("Expected Start message".to_string()));
}
}
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket.connect(server_addr).await?;
info!("UDP client connected to {}", server_addr);
let result = if self.config.reverse {
let init_packet = crate::udp_packet::create_packet(u64::MAX, 0);
socket.send(&init_packet).await?;
self.run_udp_receive(socket).await
} else {
self.run_udp_send(socket).await
};
drop(control_stream);
result
}
async fn run_udp_send(&self, socket: UdpSocket) -> Result<()> {
let start = Instant::now();
let mut last_interval = start;
let mut interval_bytes = 0u64;
let mut interval_packets = 0u64;
let mut sequence = 0u64;
let payload_size = if self.config.buffer_size > crate::udp_packet::UdpPacketHeader::SIZE {
self.config.buffer_size - crate::udp_packet::UdpPacketHeader::SIZE
} else {
1024
};
let target_bytes_per_sec = self.config.bandwidth.map(|bw| bw / 8);
let mut total_bytes_sent = 0u64;
let mut last_bandwidth_check = start;
while start.elapsed() < self.config.duration {
let packet = crate::udp_packet::create_packet(sequence, payload_size);
match socket.send(&packet).await {
Ok(n) => {
self.measurements.record_bytes_sent(0, n as u64);
self.measurements.record_udp_packet(0);
interval_bytes += n as u64;
interval_packets += 1;
sequence += 1;
total_bytes_sent += n as u64;
if let Some(target_bps) = target_bytes_per_sec {
let elapsed = last_bandwidth_check.elapsed().as_secs_f64();
if elapsed >= 0.001 {
let expected_bytes = (target_bps as f64 * elapsed) as u64;
let bytes_sent_in_period = total_bytes_sent;
if bytes_sent_in_period > expected_bytes {
let bytes_ahead = (bytes_sent_in_period - expected_bytes) as f64;
let sleep_time = bytes_ahead / target_bps as f64;
if sleep_time > 0.0001 {
time::sleep(Duration::from_secs_f64(sleep_time)).await;
}
}
last_bandwidth_check = Instant::now();
total_bytes_sent = 0;
}
}
if last_interval.elapsed() >= self.config.interval {
let elapsed = start.elapsed();
let interval_duration = last_interval.elapsed();
let bps = (interval_bytes as f64 * 8.0) / interval_duration.as_secs_f64();
let interval_start = if elapsed > interval_duration {
elapsed - interval_duration
} else {
Duration::ZERO
};
self.measurements.add_interval(IntervalStats {
start: interval_start,
end: elapsed,
bytes: interval_bytes,
bits_per_second: bps,
packets: Some(interval_packets),
});
let (lost, expected) = self.measurements.calculate_udp_loss();
let loss_percent = if expected > 0 {
(lost as f64 / expected as f64) * 100.0
} else {
0.0
};
let measurements = self.measurements.get();
self.notify(ProgressEvent::IntervalUpdate {
interval_start,
interval_end: elapsed,
bytes: interval_bytes,
bits_per_second: bps,
packets: Some(interval_packets),
jitter_ms: Some(measurements.jitter_ms),
lost_packets: Some(lost),
lost_percent: Some(loss_percent),
retransmits: None,
});
if !self.config.json {
println!(
"[{:4.1}-{:4.1} sec] {} bytes {:.2} Mbps ({} packets)",
interval_start.as_secs_f64(),
elapsed.as_secs_f64(),
interval_bytes,
bps / 1_000_000.0,
interval_packets
);
}
interval_bytes = 0;
interval_packets = 0;
last_interval = Instant::now();
}
}
Err(e) => {
error!("Error sending UDP packet: {}", e);
break;
}
}
}
self.measurements.set_duration(start.elapsed());
let final_measurements = self.measurements.get();
let (lost, expected) = self.measurements.calculate_udp_loss();
let loss_percent = if expected > 0 {
(lost as f64 / expected as f64) * 100.0
} else {
0.0
};
self.notify(ProgressEvent::TestCompleted {
total_bytes: final_measurements.total_bytes_sent
+ final_measurements.total_bytes_received,
duration: final_measurements.total_duration,
bits_per_second: final_measurements.total_bits_per_second(),
total_packets: Some(final_measurements.total_packets),
jitter_ms: Some(final_measurements.jitter_ms),
lost_packets: Some(lost),
lost_percent: Some(loss_percent),
out_of_order: Some(final_measurements.out_of_order_packets),
});
if !self.config.json {
print_results(&final_measurements);
} else {
let system_info = Some(get_system_info());
let test_config = TestConfig {
protocol: format!("{:?}", self.config.protocol),
num_streams: self.config.parallel,
blksize: self.config.buffer_size,
omit: 0,
duration: self.config.duration.as_secs(),
reverse: self.config.reverse,
};
let detailed_results = self.measurements.get_detailed_results(
None, system_info,
test_config,
);
let json = serde_json::to_string_pretty(&detailed_results)?;
println!("{}", json);
}
Ok(())
}
async fn run_udp_receive(&self, socket: UdpSocket) -> Result<()> {
let start = Instant::now();
let mut last_interval = start;
let mut interval_bytes = 0u64;
let mut interval_packets = 0u64;
let mut buffer = vec![0u8; 65536];
while start.elapsed() < self.config.duration {
let timeout =
tokio::time::timeout(Duration::from_millis(100), socket.recv(&mut buffer));
match timeout.await {
Ok(Ok(n)) => {
if let Some((header, _payload)) = crate::udp_packet::parse_packet(&buffer[..n])
{
let recv_timestamp_us = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("Time went backwards")
.as_micros() as u64;
self.measurements.record_udp_packet_received(
header.sequence,
header.timestamp_us,
recv_timestamp_us,
);
}
self.measurements.record_bytes_received(0, n as u64);
interval_bytes += n as u64;
interval_packets += 1;
if last_interval.elapsed() >= self.config.interval {
let elapsed = start.elapsed();
let interval_duration = last_interval.elapsed();
let bps = (interval_bytes as f64 * 8.0) / interval_duration.as_secs_f64();
let interval_start = if elapsed > interval_duration {
elapsed - interval_duration
} else {
Duration::ZERO
};
self.measurements.add_interval(IntervalStats {
start: interval_start,
end: elapsed,
bytes: interval_bytes,
bits_per_second: bps,
packets: Some(interval_packets),
});
let (lost, expected) = self.measurements.calculate_udp_loss();
let loss_percent = if expected > 0 {
(lost as f64 / expected as f64) * 100.0
} else {
0.0
};
let measurements = self.measurements.get();
self.notify(ProgressEvent::IntervalUpdate {
interval_start,
interval_end: elapsed,
bytes: interval_bytes,
bits_per_second: bps,
packets: Some(interval_packets),
jitter_ms: Some(measurements.jitter_ms),
lost_packets: Some(lost),
lost_percent: Some(loss_percent),
retransmits: None,
});
if !self.config.json {
println!(
"[{:4.1}-{:4.1} sec] {} bytes {:.2} Mbps ({} packets)",
interval_start.as_secs_f64(),
elapsed.as_secs_f64(),
interval_bytes,
bps / 1_000_000.0,
interval_packets
);
}
interval_bytes = 0;
interval_packets = 0;
last_interval = Instant::now();
}
}
Ok(Err(e)) => {
error!("Error receiving UDP packet: {}", e);
break;
}
Err(_) => {
continue;
}
}
}
self.measurements.set_duration(start.elapsed());
let final_measurements = self.measurements.get();
let (lost, expected) = self.measurements.calculate_udp_loss();
let loss_percent = if expected > 0 {
(lost as f64 / expected as f64) * 100.0
} else {
0.0
};
self.notify(ProgressEvent::TestCompleted {
total_bytes: final_measurements.total_bytes_sent
+ final_measurements.total_bytes_received,
duration: final_measurements.total_duration,
bits_per_second: final_measurements.total_bits_per_second(),
total_packets: Some(final_measurements.total_packets),
jitter_ms: Some(final_measurements.jitter_ms),
lost_packets: Some(lost),
lost_percent: Some(loss_percent),
out_of_order: Some(final_measurements.out_of_order_packets),
});
if !self.config.json {
print_results(&final_measurements);
} else {
let system_info = Some(get_system_info());
let test_config = TestConfig {
protocol: format!("{:?}", self.config.protocol),
num_streams: self.config.parallel,
blksize: self.config.buffer_size,
omit: 0,
duration: self.config.duration.as_secs(),
reverse: self.config.reverse,
};
let detailed_results = self.measurements.get_detailed_results(
None, system_info,
test_config,
);
let json = serde_json::to_string_pretty(&detailed_results)?;
println!("{}", json);
}
Ok(())
}
pub fn get_measurements(&self) -> crate::Measurements {
self.measurements.get()
}
}
async fn send_data(
stream: &mut TcpStream,
stream_id: usize,
measurements: &MeasurementsCollector,
config: &Config,
callback: &Option<CallbackRef>,
) -> Result<()> {
let buffer = vec![0u8; config.buffer_size];
let start = Instant::now();
let mut last_interval = start;
let mut interval_bytes = 0u64;
while start.elapsed() < config.duration {
match stream.write(&buffer).await {
Ok(n) => {
measurements.record_bytes_sent(stream_id, n as u64);
interval_bytes += n as u64;
if last_interval.elapsed() >= config.interval {
let elapsed = start.elapsed();
let interval_duration = last_interval.elapsed();
let bps = (interval_bytes as f64 * 8.0) / interval_duration.as_secs_f64();
let interval_start = if elapsed > interval_duration {
elapsed - interval_duration
} else {
Duration::ZERO
};
measurements.add_interval(IntervalStats {
start: interval_start,
end: elapsed,
bytes: interval_bytes,
bits_per_second: bps,
packets: None,
});
if let Some(cb) = callback {
cb.on_progress(ProgressEvent::IntervalUpdate {
interval_start,
interval_end: elapsed,
bytes: interval_bytes,
bits_per_second: bps,
packets: None,
jitter_ms: None,
lost_packets: None,
lost_percent: None,
retransmits: None,
});
}
if !config.json {
println!(
"[{:4.1}-{:4.1} sec] {} bytes {:.2} Mbps",
interval_start.as_secs_f64(),
elapsed.as_secs_f64(),
interval_bytes,
bps / 1_000_000.0
);
}
interval_bytes = 0;
last_interval = Instant::now();
}
}
Err(e) => {
error!("Error sending data: {}", e);
break;
}
}
}
measurements.set_duration(start.elapsed());
stream.flush().await?;
Ok(())
}
async fn receive_data(
stream: &mut TcpStream,
stream_id: usize,
measurements: &MeasurementsCollector,
config: &Config,
callback: &Option<CallbackRef>,
) -> Result<()> {
let mut buffer = vec![0u8; config.buffer_size];
let start = Instant::now();
let mut last_interval = start;
let mut interval_bytes = 0u64;
while start.elapsed() < config.duration {
match time::timeout(Duration::from_millis(100), stream.read(&mut buffer)).await {
Ok(Ok(0)) => {
break;
}
Ok(Ok(n)) => {
measurements.record_bytes_received(stream_id, n as u64);
interval_bytes += n as u64;
if last_interval.elapsed() >= config.interval {
let elapsed = start.elapsed();
let interval_duration = last_interval.elapsed();
let bps = (interval_bytes as f64 * 8.0) / interval_duration.as_secs_f64();
let interval_start = if elapsed > interval_duration {
elapsed - interval_duration
} else {
Duration::ZERO
};
measurements.add_interval(IntervalStats {
start: interval_start,
end: elapsed,
bytes: interval_bytes,
bits_per_second: bps,
packets: None,
});
if let Some(cb) = callback {
cb.on_progress(ProgressEvent::IntervalUpdate {
interval_start,
interval_end: elapsed,
bytes: interval_bytes,
bits_per_second: bps,
packets: None,
jitter_ms: None,
lost_packets: None,
lost_percent: None,
retransmits: None,
});
}
if !config.json {
println!(
"[{:4.1}-{:4.1} sec] {} bytes {:.2} Mbps",
interval_start.as_secs_f64(),
elapsed.as_secs_f64(),
interval_bytes,
bps / 1_000_000.0
);
}
interval_bytes = 0;
last_interval = Instant::now();
}
}
Ok(Err(e)) => {
error!("Error receiving data: {}", e);
break;
}
Err(_) => {
if start.elapsed() >= config.duration {
break;
}
}
}
}
measurements.set_duration(start.elapsed());
Ok(())
}
/// Prints a human-readable summary of `measurements` to stdout.
///
/// Per-stream lines are followed by totals and overall bandwidth. When UDP
/// packets were counted, the receiving side (some bytes received) reports
/// loss against the expected packet count plus jitter. The sending side only
/// reports how many packets it sent, because loss is measured at the receiver.
fn print_results(measurements: &crate::Measurements) {
    println!("\n- - - - - - - - - - - - - - - - - - - - - - - - -");
    println!("Test Complete");
    println!("- - - - - - - - - - - - - - - - - - - - - - - - -");
    for (i, stream) in measurements.streams.iter().enumerate() {
        println!(
            "Stream {}: {:.2} seconds, {} bytes, {:.2} Mbps",
            i,
            stream.duration.as_secs_f64(),
            stream.bytes_sent + stream.bytes_received,
            stream.bits_per_second() / 1_000_000.0
        );
    }
    println!("- - - - - - - - - - - - - - - - - - - - - - - - -");
    println!(
        "Total: {:.2} seconds, {} bytes sent, {} bytes received",
        measurements.total_duration.as_secs_f64(),
        measurements.total_bytes_sent,
        measurements.total_bytes_received
    );
    println!(
        "Bandwidth: {:.2} Mbps",
        measurements.total_bits_per_second() / 1_000_000.0
    );
    if measurements.total_packets > 0 {
        if measurements.total_bytes_received > 0 {
            let (lost, expected) = measurements.calculate_udp_loss();
            let loss_percent = if expected > 0 {
                (lost as f64 / expected as f64) * 100.0
            } else {
                0.0
            };
            // `expected` is the denominator of the loss ratio, not the number
            // of packets that arrived, so it is labelled accordingly.
            println!(
                "UDP: {} packets expected, {} lost ({:.2}%), {:.3} ms jitter",
                expected, lost, loss_percent, measurements.jitter_ms
            );
            if measurements.out_of_order_packets > 0 {
                println!(
                    "  {} out-of-order packets",
                    measurements.out_of_order_packets
                );
            }
        } else {
            println!(
                "UDP: {} packets sent (loss and jitter measured at receiver)",
                measurements.total_packets
            );
        }
    }
    println!("- - - - - - - - - - - - - - - - - - - - - - - - -\n");
}