use rperf3::{Client, Config, ProgressCallback, ProgressEvent, Protocol, Server};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::sleep;
/// Test double for [`ProgressCallback`] that appends every event it receives
/// to a shared, mutex-protected log.
///
/// The log lives behind an `Arc` so a test can clone the handle before the
/// callback is moved into a `Client`, then inspect the events after the run.
struct TestCallback {
    /// Events in the order `on_progress` received them.
    events: Arc<Mutex<Vec<ProgressEvent>>>,
}
impl TestCallback {
    /// Creates a callback whose event log starts out empty.
    fn new() -> Self {
        let events = Arc::new(Mutex::new(Vec::new()));
        Self { events }
    }

    /// Returns a copy of every event recorded so far.
    ///
    /// # Panics
    /// Panics if the event-log mutex has been poisoned.
    // The tests in this file read the log through a cloned `Arc` instead, so
    // this accessor is currently unused; it is kept as a convenience.
    #[allow(dead_code)]
    fn get_events(&self) -> Vec<ProgressEvent> {
        let log = self.events.lock().unwrap();
        log.to_vec()
    }
}
impl ProgressCallback for TestCallback {
    /// Appends `event` to the shared log.
    ///
    /// # Panics
    /// Panics if the event-log mutex has been poisoned.
    fn on_progress(&self, event: ProgressEvent) {
        let mut log = self.events.lock().unwrap();
        log.push(event);
    }
}
#[tokio::test]
async fn test_custom_callback_struct() {
let server_config = Config::server(15201).with_protocol(Protocol::Tcp);
let server = Server::new(server_config);
tokio::spawn(async move {
let _ = server.run().await;
});
sleep(Duration::from_millis(100)).await;
let callback = TestCallback::new();
let events_ref = callback.events.clone();
let client_config = Config::client("127.0.0.1".to_string(), 15201)
.with_protocol(Protocol::Tcp)
.with_duration(Duration::from_secs(2));
let client = Client::new(client_config).unwrap().with_callback(callback);
let _ = client.run().await;
let events = events_ref.lock().unwrap();
assert!(!events.is_empty(), "Should have received events");
assert!(
events
.iter()
.any(|e| matches!(e, ProgressEvent::TestStarted)),
"Should have received TestStarted event"
);
let interval_updates: Vec<_> = events
.iter()
.filter(|e| matches!(e, ProgressEvent::IntervalUpdate { .. }))
.collect();
assert!(
!interval_updates.is_empty(),
"Should have received at least one IntervalUpdate event"
);
assert!(
events
.iter()
.any(|e| matches!(e, ProgressEvent::TestCompleted { .. })),
"Should have received TestCompleted event"
);
}
#[tokio::test]
async fn test_closure_callback() {
let server_config = Config::server(15202).with_protocol(Protocol::Tcp);
let server = Server::new(server_config);
tokio::spawn(async move {
let _ = server.run().await;
});
sleep(Duration::from_millis(100)).await;
let events = Arc::new(Mutex::new(Vec::new()));
let events_clone = events.clone();
let client_config = Config::client("127.0.0.1".to_string(), 15202)
.with_protocol(Protocol::Tcp)
.with_duration(Duration::from_secs(2));
let client = Client::new(client_config)
.unwrap()
.with_callback(move |event: ProgressEvent| {
events_clone.lock().unwrap().push(event);
});
let _ = client.run().await;
let captured_events = events.lock().unwrap();
assert!(!captured_events.is_empty(), "Should have captured events");
}
/// Checks that interval and completion events are delivered to a closure
/// callback, and that those variants expose the UDP-oriented metric fields.
///
/// The compiler checks that the fields `packets`, `jitter_ms`,
/// `lost_packets`, `lost_percent` and `retransmits` exist on `IntervalUpdate`.
/// It also checks that `total_packets`, `jitter_ms`, `lost_packets`,
/// `lost_percent` and `out_of_order` exist on `TestCompleted`.
///
/// NOTE(review): despite its name, this test runs over TCP. Both the server and
/// the client use `Protocol::Tcp`, so no UDP metric *values* are exercised.
/// Switching to `Protocol::Udp` is probably the intent; confirm the UDP path
/// works end-to-end before changing it.
#[tokio::test]
async fn test_udp_metrics_in_callbacks() {
    let server_config = Config::server(15203).with_protocol(Protocol::Tcp);
    let server = Server::new(server_config);
    tokio::spawn(async move {
        let _ = server.run().await;
    });
    // Give the server a moment to bind before the client connects.
    sleep(Duration::from_millis(100)).await;

    let has_interval = Arc::new(Mutex::new(false));
    let interval_ref = has_interval.clone();
    let has_completion = Arc::new(Mutex::new(false));
    let completion_ref = has_completion.clone();

    let client_config = Config::client("127.0.0.1".to_string(), 15203)
        .with_protocol(Protocol::Tcp)
        .with_duration(Duration::from_secs(2));
    let client = Client::new(client_config)
        .unwrap()
        .with_callback(move |event: ProgressEvent| match event {
            // Naming each metric field with a `_` pattern makes the compiler
            // verify the field exists on the variant, without unused bindings.
            // Only flags are recorded here; all assertions run after the client
            // finishes, so a failure is always reported by the test itself.
            ProgressEvent::IntervalUpdate {
                packets: _,
                jitter_ms: _,
                lost_packets: _,
                lost_percent: _,
                retransmits: _,
                ..
            } => *interval_ref.lock().unwrap() = true,
            ProgressEvent::TestCompleted {
                total_packets: _,
                jitter_ms: _,
                lost_packets: _,
                lost_percent: _,
                out_of_order: _,
                ..
            } => *completion_ref.lock().unwrap() = true,
            _ => {}
        });
    let _ = client.run().await;

    assert!(
        *has_interval.lock().unwrap(),
        "Should have received interval updates"
    );
    assert!(
        *has_completion.lock().unwrap(),
        "Should have received completion event"
    );
}
/// Checks two things about a TCP run:
/// - a `TestCompleted` event is delivered;
/// - none of its UDP-only metrics (`total_packets`, `jitter_ms`,
///   `lost_packets`, `lost_percent`) are populated.
#[tokio::test]
async fn test_tcp_callback_no_udp_metrics() {
    let server_config = Config::server(15204).with_protocol(Protocol::Tcp);
    let server = Server::new(server_config);
    tokio::spawn(async move {
        let _ = server.run().await;
    });
    // Give the server a moment to bind before the client connects.
    sleep(Duration::from_millis(100)).await;

    // `None` until a completion event arrives. After that it is `Some(true)` if
    // any completion event carried a UDP-only metric, `Some(false)` otherwise.
    let completion: Arc<Mutex<Option<bool>>> = Arc::new(Mutex::new(None));
    let completion_ref = completion.clone();

    let client_config = Config::client("127.0.0.1".to_string(), 15204)
        .with_protocol(Protocol::Tcp)
        .with_duration(Duration::from_secs(2));
    let client = Client::new(client_config)
        .unwrap()
        .with_callback(move |event: ProgressEvent| {
            if let ProgressEvent::TestCompleted {
                total_packets,
                jitter_ms,
                lost_packets,
                lost_percent,
                ..
            } = event
            {
                // Record rather than assert here. The callback may be invoked
                // off the test's own task, where a panic might not fail the test.
                let has_udp_metrics = total_packets.is_some()
                    || jitter_ms.is_some()
                    || lost_packets.is_some()
                    || lost_percent.is_some();
                let mut seen = completion_ref.lock().unwrap();
                *seen = Some(seen.unwrap_or(false) || has_udp_metrics);
            }
        });
    let _ = client.run().await;

    let completion = *completion.lock().unwrap();
    assert!(
        completion.is_some(),
        "Should have received TestCompleted event for TCP"
    );
    assert_eq!(
        completion,
        Some(false),
        "TCP TestCompleted should not carry UDP metrics (packets/jitter/loss)"
    );
}
/// Checks that the three callback shapes implement [`ProgressCallback`] and
/// are actually invoked through `on_progress`:
/// - a closure;
/// - a plain `fn` item;
/// - a custom struct.
///
/// Each callback is passed to a generic helper bounded on the trait, so the
/// test fails to compile if any shape stops implementing it. Each callback also
/// counts its invocations, so the test fails at runtime if dispatch is broken.
#[test]
fn test_callback_trait_implementation() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Delivers a single event through the trait method.
    fn deliver<C: ProgressCallback>(callback: &C) {
        callback.on_progress(ProgressEvent::TestStarted);
    }

    let closure_calls = Arc::new(AtomicUsize::new(0));
    let closure_counter = Arc::clone(&closure_calls);
    let closure_callback = move |event: ProgressEvent| {
        println!("Event: {:?}", event);
        closure_counter.fetch_add(1, Ordering::SeqCst);
    };

    // A `fn` item cannot capture state, so it counts via a function-local static.
    static FN_CALLS: AtomicUsize = AtomicUsize::new(0);
    fn handle_event(event: ProgressEvent) {
        println!("Event: {:?}", event);
        FN_CALLS.fetch_add(1, Ordering::SeqCst);
    }

    struct MyCallback {
        calls: AtomicUsize,
    }
    impl ProgressCallback for MyCallback {
        fn on_progress(&self, _event: ProgressEvent) {
            self.calls.fetch_add(1, Ordering::SeqCst);
        }
    }
    let custom_callback = MyCallback {
        calls: AtomicUsize::new(0),
    };

    deliver(&closure_callback);
    deliver(&handle_event);
    deliver(&custom_callback);

    assert_eq!(closure_calls.load(Ordering::SeqCst), 1);
    assert_eq!(FN_CALLS.load(Ordering::SeqCst), 1);
    assert_eq!(custom_callback.calls.load(Ordering::SeqCst), 1);
}