use std::{
path::{Path, PathBuf},
time::{Duration, Instant},
};
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use log::{debug, info, log_enabled, trace, warn, Level};
use rand::{
distr::{Distribution, Uniform},
rng,
};
use russh::{client, ChannelMsg};
use russh_sftp::client::SftpSession;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::{
summary::{EchoTestSummary, SpeedTestResult, SpeedTestSummary},
util::Formatter,
};
/// Builds the progress-bar style shared by the echo and speed tests.
///
/// The line starts with `test_name`, followed by a green spinner, the elapsed
/// time, a cyan/blue bar, bytes done / total bytes, and a custom `eta` key that
/// renders the estimated remaining time in seconds with one decimal place.
///
/// # Panics
///
/// Panics if `indicatif` rejects the generated template. Because `test_name`
/// is spliced into the template verbatim, this could presumably happen if it
/// contains template syntax such as braces (callers here pass fixed names).
fn get_progress_bar_style(test_name: &str) -> ProgressStyle {
    let template = format!(
        "{test_name} {{spinner:.green}} [{{elapsed_precise}}] [{{wide_bar:.cyan/blue}}] {{bytes}}/{{total_bytes}} ({{eta}})"
    );
    let style = ProgressStyle::default_bar().template(&template).unwrap();
    style
        .with_key("eta", |state: &ProgressState, w: &mut dyn std::fmt::Write| {
            let remaining_secs = state.eta().as_secs_f64();
            write!(w, "{remaining_secs:.1}s").unwrap();
        })
        .progress_chars("#>-")
}
/// Measures interactive echo latency over an SSH shell channel.
///
/// Opens a session channel, requests a small PTY (terminal name `"sshping"`)
/// and a shell, writes `echo_cmd` followed by a newline, and waits for the
/// shell's first output. It then sends one character at a time (cycling
/// through `a-z` and `A-Z`) and records, in nanoseconds, how long it takes
/// until the next non-empty `Data` message arrives.
///
/// * `char_count` - maximum number of characters to send.
/// * `time_limit` - optional limit in seconds. It is checked after each echo,
///   and the test stops once the total elapsed time exceeds it.
/// * `formatter` - used to render durations in the summary and log output.
///
/// # Errors
///
/// Returns `Err` with a message in any of these cases:
/// * `time_limit` is not a valid duration (negative, NaN or too large).
/// * Opening or configuring the channel fails.
/// * Sending data fails.
/// * The channel reports EOF or closes while output is awaited.
/// * No latency sample was collected (e.g. `char_count == 0`).
pub async fn run_echo_test<H: client::Handler>(
    session: &mut client::Handle<H>,
    echo_cmd: &str,
    char_count: usize,
    time_limit: Option<f64>,
    formatter: &Formatter,
) -> Result<EchoTestSummary, String> {
    info!("Running echo latency test");
    debug!("Running echo test with command: {echo_cmd:?}");
    debug!("Number of characters to echo: {char_count:?}");
    debug!("Time limit for echo: {time_limit:?} seconds");
    // `Duration::from_secs_f64` panics on negative/NaN input; report it instead.
    let timeout = time_limit
        .map(Duration::try_from_secs_f64)
        .transpose()
        .map_err(|e| format!("Invalid time limit: {e}"))?;

    trace!("Preparing channel session");
    let mut channel = session
        .channel_open_session()
        .await
        .map_err(|e| e.to_string())?;
    channel
        .request_pty(true, "sshping", 10, 5, 0, 0, &[])
        .await
        .map_err(|e| e.to_string())?;
    channel
        .request_shell(false)
        .await
        .map_err(|e| e.to_string())?;

    trace!("Starting echo command");
    let echo_cmd_bytes = format!("{echo_cmd}\n").into_bytes();
    channel
        .data(&echo_cmd_bytes[..])
        .await
        .map_err(|e| e.to_string())?;
    tokio::time::sleep(Duration::from_millis(100)).await;
    // Wait for the shell's first output so it is not mistaken for an echo.
    // `None` means the channel is gone, which is treated like EOF.
    loop {
        match channel.wait().await {
            Some(ChannelMsg::Data { .. }) => break,
            Some(ChannelMsg::Eof) | None => {
                return Err("Channel closed unexpectedly".to_string());
            }
            Some(_) => {}
        }
    }

    trace!("Testing echo latency");
    let write_buffer = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
    let mut latencies = Vec::with_capacity(char_count);
    let start_time = Instant::now();
    let progress_bar = ProgressBar::new(char_count as u64);
    progress_bar.set_style(get_progress_bar_style("Echo test"));
    for (n, idx) in (0..char_count).zip((0..write_buffer.len()).cycle()) {
        let start = Instant::now();
        let byte_slice = &write_buffer[idx..idx + 1];
        channel.data(byte_slice).await.map_err(|e| e.to_string())?;
        // Wait for the echoed character. The explicit `None` arm matters:
        // once the channel closes, `wait()` keeps returning `None`, and
        // ignoring it would spin here forever.
        loop {
            match channel.wait().await {
                Some(ChannelMsg::Data { data }) if !data.is_empty() => break,
                Some(ChannelMsg::Eof) | None => {
                    return Err("Channel closed unexpectedly".to_string());
                }
                Some(_) => {}
            }
        }
        latencies.push(start.elapsed().as_nanos());
        if let Some(timeout) = timeout
            && start_time.elapsed() > timeout
        {
            break;
        }
        progress_bar.set_position((n as u64) + 1);
    }
    progress_bar.finish_and_clear();

    // Check before summarising so the summary never receives an empty slice.
    if latencies.is_empty() {
        return Err("Unable to get any echos in given time".to_string());
    }
    latencies.sort_unstable();
    let result = EchoTestSummary::from_latencies(&latencies, formatter);
    if latencies.len() < 20 {
        warn!("Insufficient data points for accurate latency measurement");
    }
    if log_enabled!(Level::Info) {
        let p1_latency = echo_tail_latency(&latencies, 100);
        let p5_latency = echo_tail_latency(&latencies, 20);
        let p10_latency = echo_tail_latency(&latencies, 10);
        info!(
            "Sent {}/{char_count}, Latency:\n\tMean:\t{}\n\tStd:\t{}\n\tMin:\t{}\n\tMedian:\t{}\n\tMax:\t{}\n\t1% High:\t{}\n\t5% High:\t{}\n\t10% High:\t{}",
            result.char_sent,
            result.avg_latency,
            result.std_latency,
            result.min_latency,
            result.med_latency,
            result.max_latency,
            formatter.format_duration(p1_latency),
            formatter.format_duration(p5_latency),
            formatter.format_duration(p10_latency)
        );
    }
    Ok(result)
}

/// Returns the sample at rank `len / divisor`, counted from the largest value,
/// of an ascending-sorted slice of nanosecond latencies.
///
/// For example, `divisor = 100` gives the "1% high" latency. `divisor` must be
/// non-zero. Returns zero for an empty slice and saturates values that do not
/// fit in `u64` nanoseconds.
fn echo_tail_latency(sorted_nanos: &[u128], divisor: usize) -> Duration {
    let nanos = sorted_nanos
        .iter()
        .rev()
        .nth(sorted_nanos.len() / divisor)
        .copied()
        .unwrap_or(0);
    Duration::from_nanos(u64::try_from(nanos).unwrap_or(u64::MAX))
}
/// Uploads `size` bytes of random printable data to `remote_file` over SFTP
/// and measures the throughput.
///
/// Opens a fresh session channel with the `sftp` subsystem and generates the
/// whole payload in memory first. Each random byte is mapped into `32..=95`,
/// i.e. printable ASCII. The remote file is then created with
/// `SftpSession::create`, and the payload is written in `chunk_size`-byte
/// pieces. Timing starts after the file is created and covers all writes plus
/// the final `shutdown` of the remote handle.
///
/// # Errors
///
/// Returns `Err` in any of these cases:
/// * `chunk_size` is zero.
/// * `size` does not fit in `usize` on this platform.
/// * `remote_file` is not valid UTF-8.
/// * Any SSH/SFTP operation fails.
async fn run_upload_test<H: client::Handler>(
    session: &mut client::Handle<H>,
    size: u64,
    chunk_size: u64,
    remote_file: &Path,
    formatter: &Formatter,
) -> Result<SpeedTestResult, String> {
    info!("Running upload speed test");
    // `<[T]>::chunks(0)` panics, so reject a zero chunk size before doing any work.
    if chunk_size == 0 {
        return Err("Chunk size must be greater than zero".to_string());
    }
    // Avoid silently truncating the payload size on 32-bit targets.
    let byte_count = usize::try_from(size)
        .map_err(|_| format!("File size {size} is too large for this platform"))?;
    // A chunk larger than the whole payload simply yields a single chunk.
    let chunk_len = usize::try_from(chunk_size).unwrap_or(usize::MAX);

    trace!("Establishing SFTP channel");
    let channel = session
        .channel_open_session()
        .await
        .map_err(|e| e.to_string())?;
    channel
        .request_subsystem(true, "sftp")
        .await
        .map_err(|e| e.to_string())?;
    let sftp = SftpSession::new(channel.into_stream())
        .await
        .map_err(|e| e.to_string())?;

    trace!("Generating random data");
    let dist = Uniform::try_from(0..128_u8).expect("0..128 is a valid non-empty range");
    let buffer: Vec<u8> = dist
        .sample_iter(rng())
        .take(byte_count)
        // Fold into 32..=95 so the payload stays printable ASCII.
        .map(|v| (v & 0x3f) + 32)
        .collect();

    let remote_path = remote_file.to_str().ok_or("Invalid remote file path")?;
    let mut file = sftp.create(remote_path).await.map_err(|e| e.to_string())?;
    let mut total_bytes_sent = 0;
    let start_time: Instant = Instant::now();
    let progress_bar = ProgressBar::new(size);
    progress_bar.set_style(get_progress_bar_style("Upload test"));
    trace!("Sending file in chunks");
    for chunk in buffer.chunks(chunk_len) {
        file.write_all(chunk).await.map_err(|e| e.to_string())?;
        total_bytes_sent += chunk.len();
        progress_bar.set_position(total_bytes_sent as u64);
    }
    progress_bar.finish_and_clear();
    file.shutdown().await.map_err(|e| e.to_string())?;

    let result = SpeedTestResult::new(total_bytes_sent as u64, start_time.elapsed(), formatter);
    info!(
        "Sent {}, Time Elapsed: {}, Average Speed: {}",
        result.size, result.time, result.speed
    );
    Ok(result)
}
/// Downloads `remote_file` over SFTP and measures the throughput.
///
/// Opens a fresh session channel with the `sftp` subsystem and reads the
/// file's size from its metadata. It then reads exactly that many bytes in
/// pieces of at most `chunk_size` bytes into a single reused buffer. The data
/// itself is discarded; only the byte count and the elapsed time (measured
/// from after the file is opened) are reported.
///
/// # Errors
///
/// Returns `Err` in any of these cases:
/// * `chunk_size` is zero.
/// * `remote_file` is not valid UTF-8.
/// * The remote file is empty.
/// * The buffer size does not fit in `usize`.
/// * Any SSH/SFTP operation fails, including a read that ends before `size`
///   bytes have been received.
async fn run_download_test<H: client::Handler>(
    session: &mut client::Handle<H>,
    chunk_size: u64,
    remote_file: &Path,
    formatter: &Formatter,
) -> Result<SpeedTestResult, String> {
    info!("Running download speed test");
    // A zero-length buffer would never make progress and loop forever.
    if chunk_size == 0 {
        return Err("Chunk size must be greater than zero".to_string());
    }

    trace!("Establishing SFTP channel");
    let channel = session
        .channel_open_session()
        .await
        .map_err(|e| e.to_string())?;
    channel
        .request_subsystem(true, "sftp")
        .await
        .map_err(|e| e.to_string())?;
    let sftp = SftpSession::new(channel.into_stream())
        .await
        .map_err(|e| e.to_string())?;
    let remote_path = remote_file.to_str().ok_or("Invalid remote file path")?;
    let metadata = sftp
        .metadata(remote_path)
        .await
        .map_err(|e| e.to_string())?;
    let size = metadata.len();
    if size == 0 {
        return Err("Remote file is empty".to_string());
    }
    let mut file = sftp.open(remote_path).await.map_err(|e| e.to_string())?;

    trace!("Preparing buffer for downloading");
    // Never allocate more than the file can fill.
    let buffer_len = usize::try_from(chunk_size.min(size))
        .map_err(|_| format!("Chunk size {chunk_size} is too large for this platform"))?;
    let mut buffer = vec![0_u8; buffer_len];
    let mut total_bytes_recv: u64 = 0;
    let start_time: Instant = Instant::now();
    let progress_bar = ProgressBar::new(size);
    progress_bar.set_style(get_progress_bar_style("Download test"));
    trace!("Receiving file in chunks");
    while total_bytes_recv < size {
        // Full chunks first, then whatever remains. `want` never exceeds
        // `buffer.len()`, so the cast back to `usize` cannot truncate.
        let want = (size - total_bytes_recv).min(buffer.len() as u64) as usize;
        file.read_exact(&mut buffer[..want])
            .await
            .map_err(|e| e.to_string())?;
        total_bytes_recv += want as u64;
        progress_bar.set_position(total_bytes_recv);
    }
    progress_bar.finish_and_clear();

    let result = SpeedTestResult::new(total_bytes_recv, start_time.elapsed(), formatter);
    info!(
        "Received {}, Time Elapsed: {}, Average Speed: {}",
        result.size, result.time, result.speed
    );
    Ok(result)
}
/// Runs the SFTP upload test and then the download test against the same
/// remote file, returning both results together.
///
/// The upload writes `size` bytes to `remote_file` in `chunk_size` pieces. The
/// download then reads that same file back with the same chunk size.
///
/// # Errors
///
/// Propagates the first error from either phase. If the upload fails, the
/// download is not attempted.
pub async fn run_speed_test<H: client::Handler>(
    session: &mut client::Handle<H>,
    size: u64,
    chunk_size: u64,
    remote_file: &PathBuf,
    formatter: &Formatter,
) -> Result<SpeedTestSummary, String> {
    info!("Running speed test");
    debug!("Running speed test with file size: {}", formatter.format_size(size));
    debug!("Remote file path: {remote_file:?}");

    let target = remote_file.as_path();
    let upload = run_upload_test(session, size, chunk_size, target, formatter).await?;
    let download = run_download_test(session, chunk_size, target, formatter).await?;

    Ok(SpeedTestSummary { upload, download })
}