use futures::future;
use rand::prelude::*;
use rand_xoshiro::Xoroshiro128StarStar;
use std::{
io::{Error, ErrorKind, Result},
time::Duration,
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
join, select,
sync::{mpsc, oneshot, watch},
time::Instant,
};
use tracing::Instrument;
use aggligator::exec;
/// Size in bytes of each chunk of test data that is generated, sent and verified.
const BUF_SIZE: usize = 8 * 1024;
/// Bytes per mebibyte, used to format speeds as MB/s in log output.
const MB: f64 = (1024 * 1024) as f64;
/// A ten-second duration exported for use as a report interval.
pub const INTERVAL: Duration = Duration::from_millis(10_000);
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(name =% name))]
pub async fn speed_test(
name: &str, mut read: impl AsyncRead + Unpin + Send + 'static,
mut write: impl AsyncWrite + Unpin + Send + 'static, limit: Option<usize>, duration: Option<Duration>,
send: bool, receive: bool, recv_block: bool, report_interval: Duration,
speed_tx: Option<watch::Sender<(f64, f64)>>,
) -> Result<(f64, f64)> {
let start = Instant::now();
let (stop_tx, _stop_rx) = mpsc::channel::<()>(1);
let (send_tx, recv_tx) = match speed_tx {
Some(speed_tx) => {
let (send_tx, send_rx) = watch::channel(0.);
let (recv_tx, recv_rx) = watch::channel(0.);
let mut send_rx = Some(send_rx);
let mut recv_rx = Some(recv_rx);
exec::spawn(async move {
while send_rx.is_some() || recv_rx.is_some() {
let send = send_rx.as_mut().map(|rx| *rx.borrow_and_update()).unwrap_or_default();
let recv = recv_rx.as_mut().map(|rx| *rx.borrow_and_update()).unwrap_or_default();
if speed_tx.send((send, recv)).is_err() {
break;
}
select! {
res = async {
match &mut send_rx {
Some(rx) => rx.changed().await,
None => future::pending().await,
}
} => if res.is_err() { send_rx = None},
res = async {
match &mut recv_rx {
Some(rx) => rx.changed().await,
None => future::pending().await,
}
} => if res.is_err() { recv_rx = None},
}
}
});
(Some(send_tx), Some(recv_tx))
}
None => (None, None),
};
tracing::info!("Starting speed test");
#[cfg(debug_assertions)]
tracing::warn!("debug build, speed test will not be accurate");
let sender_stop_tx = stop_tx.clone();
let (stop_sender_tx, mut stop_sender_rx) = oneshot::channel();
let sender = exec::spawn(
async move {
if !send {
return Ok((0, Duration::ZERO));
}
let seed = rand::random();
write.write_u64(seed).await?;
let mut rng = Xoroshiro128StarStar::seed_from_u64(seed);
let mut sent_total = 0;
let mut sent_interval = 0;
let mut interval_start = Instant::now();
#[allow(clippy::assertions_on_constants)]
while limit.map(|limit| sent_total <= limit).unwrap_or(true)
&& !sender_stop_tx.is_closed()
&& start.elapsed() < duration.unwrap_or(Duration::MAX)
{
assert!(BUF_SIZE % 8 == 0);
let mut buf = [0; BUF_SIZE];
rng.fill_bytes(&mut buf);
write.write_all(&buf).await?;
sent_total += BUF_SIZE;
sent_interval += BUF_SIZE;
if interval_start.elapsed() >= report_interval {
let speed = sent_interval as f64 / interval_start.elapsed().as_secs_f64();
tracing::info!("Send speed: {:.1} MB/s", speed / MB);
if let Some(tx) = &send_tx {
if tx.send(speed).is_err() {
break;
}
}
sent_interval = 0;
interval_start = Instant::now();
}
if let Ok(()) = stop_sender_rx.try_recv() {
break;
}
}
tracing::info!("Sender exited");
Ok::<_, Error>((sent_total, start.elapsed()))
}
.in_current_span(),
);
let receiver = exec::spawn(
async move {
if !receive {
return Ok((0, Duration::ZERO));
}
let remote_seed = read.read_u64().await?;
let mut rng = Xoroshiro128StarStar::seed_from_u64(remote_seed);
if recv_block {
stop_tx.closed().await;
return Ok((0, Duration::ZERO));
}
let mut recved_total = 0;
let mut recved_interval = 0;
let mut interval_start = Instant::now();
while !stop_tx.is_closed() && start.elapsed() < duration.unwrap_or(Duration::MAX) {
let mut buf = [0; BUF_SIZE];
let mut n = read.read(&mut buf).await?;
if n == 0 {
break;
}
match n % 8 {
0 => (),
rem => {
n += read.read_exact(&mut buf[n..(n + 8 - rem)]).await?;
if n % 8 != 0 {
break;
}
}
}
let buf = &buf[..n];
let mut chk_buf = vec![0; n];
assert!(n % 8 == 0);
rng.fill_bytes(&mut chk_buf);
if chk_buf != buf {
let _ = stop_sender_tx.send(());
return Err(Error::new(ErrorKind::InvalidData, "received data is malformed"));
}
recved_total += n;
recved_interval += n;
if interval_start.elapsed() >= report_interval {
let speed = recved_interval as f64 / interval_start.elapsed().as_secs_f64();
tracing::info!("Receive speed: {:.1} MB/s", speed / MB);
if let Some(tx) = &recv_tx {
if tx.send(speed).is_err() {
break;
}
}
recved_interval = 0;
interval_start = Instant::now();
}
}
tracing::info!("Receiver exited");
Ok((recved_total, start.elapsed()))
}
.in_current_span(),
);
let (Ok(sender), Ok(receiver)) = join!(sender, receiver) else { unreachable!() };
if let Err(err) = &sender {
tracing::warn!(%err, "Sender error");
}
if let Err(err) = &receiver {
tracing::warn!(%err, "Receiver error");
}
let (tx_total, tx_dur) = sender?;
let tx_speed = tx_total as f64 / tx_dur.as_secs_f64().max(1e-10);
let (rx_total, rx_dur) = receiver?;
let rx_speed = rx_total as f64 / rx_dur.as_secs_f64().max(1e-10);
tracing::info!("Upstream: {tx_speed:.0} bytes/s Downstream: {rx_speed:.0} bytes/s");
Ok((tx_speed, rx_speed))
}