use std::{
collections::VecDeque,
fmt,
path::Path,
time::{Duration, Instant},
};
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use log::{debug, info, log_enabled, trace, warn, Level};
use rand::{
distr::{Distribution, Uniform},
rng,
};
use russh::{client, ChannelMsg};
use russh_sftp::client::SftpSession;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::{
summary::{EchoTestSummary, SpeedTestResult, SpeedTestSummary},
util::Formatter,
};
/// Errors produced by the echo-latency and speed tests in this module.
#[derive(Debug)]
pub enum TestError {
    /// An SSH/SFTP operation failed, or a test aborted with a free-form message.
    Ssh(String),
    /// The channel reported EOF/close, or stopped yielding messages, mid-test.
    ChannelClosed,
    /// The remote file path could not be represented as UTF-8 text.
    InvalidRemotePath,
    /// The echo test ended without recording a single round trip.
    EmptyEchoResult,
    /// The remote file used for the download test has zero length.
    EmptyRemoteFile,
    /// Building a summary from the raw measurements failed; carries the cause.
    SummaryCreation(String),
}

impl fmt::Display for TestError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants carrying a payload format it directly; the rest map to a
        // fixed message written below.
        let fixed = match self {
            Self::Ssh(detail) => return f.write_str(detail),
            Self::SummaryCreation(detail) => {
                return write!(f, "Failed to summarize test result: {detail}");
            }
            Self::ChannelClosed => "Channel closed unexpectedly",
            Self::InvalidRemotePath => "Invalid remote file path",
            Self::EmptyEchoResult => "Unable to get any echos in given time",
            Self::EmptyRemoteFile => "Remote file is empty",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for TestError {}

/// Wraps a bare message string as [`TestError::Ssh`].
impl From<String> for TestError {
    fn from(message: String) -> Self {
        TestError::Ssh(message)
    }
}
/// Builds the shared progress-bar style, prefixing the bar with `test_name`.
///
/// The `eta` key is overridden to render the remaining time in seconds with one
/// decimal place (for example `3.2s`).
///
/// # Panics
///
/// Panics if indicatif rejects the assembled template.
fn get_progress_bar_style(test_name: &str) -> ProgressStyle {
    // Doubled braces survive `format!` as single braces, i.e. indicatif keys.
    let template = format!(
        "{test_name} {{spinner:.green}} [{{elapsed_precise}}] [{{wide_bar:.cyan/blue}}] {{bytes}}/{{total_bytes}} ({{eta}})"
    );

    ProgressStyle::default_bar()
        .template(&template)
        .unwrap()
        .with_key(
            "eta",
            |progress: &ProgressState, out: &mut dyn std::fmt::Write| {
                write!(out, "{:.1}s", progress.eta().as_secs_f64()).unwrap()
            },
        )
        .progress_chars("#>-")
}
/// Measures SSH round-trip latency by echoing single bytes through a remote command.
///
/// Opens a fresh session channel and runs `echo_cmd` on it. It then repeatedly sends
/// one byte, cycling through the ASCII letters, and waits until that same byte appears
/// in the channel's data stream. Each send-to-receive interval is one latency sample,
/// so the remote command must copy its stdin to stdout. Received bytes that do not
/// match the byte currently awaited are discarded.
///
/// * `char_count` — maximum number of bytes to send.
/// * `time_limit` — optional wall-clock budget in seconds. When it expires, the test
///   stops and summarizes the samples collected so far. A negative or NaN limit is
///   treated as zero. A limit too large to represent is treated as no limit.
/// * `formatter` — renders durations for the summary and the log output.
///
/// # Errors
///
/// * [`TestError::Ssh`] if any of these fails: opening the channel, starting the
///   command, or sending a byte. Also returned if the command reports an exit status.
/// * [`TestError::ChannelClosed`] if the remote side sends EOF/close or the channel ends.
/// * [`TestError::EmptyEchoResult`] if no byte was echoed before the time limit.
/// * [`TestError::SummaryCreation`] if the latencies cannot be summarized.
pub async fn run_echo_test<H: client::Handler>(
    session: &client::Handle<H>,
    echo_cmd: &str,
    char_count: usize,
    time_limit: Option<f64>,
    formatter: &Formatter,
) -> Result<EchoTestSummary, TestError> {
    info!("Running echo latency test");
    debug!("Running echo test with command: {echo_cmd:?}");
    debug!("Number of characters to echo: {char_count:?}");
    debug!("Time limit for echo: {time_limit:?} seconds");
    trace!("Preparing channel session");
    let mut channel = session
        .channel_open_session()
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    channel
        .exec(false, echo_cmd)
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    trace!("Testing echo latency");
    let write_buffer = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
    let mut pending_data = VecDeque::new();
    let mut latencies = Vec::with_capacity(char_count);
    // `Duration::from_secs_f64` panics on negative/NaN input and `Instant + Duration`
    // panics on overflow. Clamp bad limits to zero and treat unrepresentable ones
    // as "no deadline".
    let deadline = time_limit.and_then(|limit| {
        Duration::try_from_secs_f64(limit.max(0.0))
            .ok()
            .and_then(|budget| Instant::now().checked_add(budget))
    });
    let progress_bar = ProgressBar::new(char_count as u64);
    progress_bar.set_style(get_progress_bar_style("Echo test"));
    let mut discarded_mismatch = 0usize;
    'echo_loop: for (n, &byte) in write_buffer.iter().cycle().enumerate().take(char_count) {
        if let Some(deadline) = deadline
            && Instant::now() >= deadline
        {
            break;
        }
        let start = Instant::now();
        channel
            .data(&[byte][..])
            .await
            .map_err(|e| TestError::Ssh(e.to_string()))?;
        // Drain buffered bytes first, then wait for more data until the byte just
        // sent comes back; stray bytes are dropped.
        loop {
            if let Some(received_byte) = pending_data.pop_front() {
                if received_byte == byte {
                    break;
                }
                discarded_mismatch += 1;
                if discarded_mismatch == 1 || discarded_mismatch.is_multiple_of(64) {
                    trace!(
                        "Discarding unexpected echo byte (expected {:?}, got {:?}, discarded={discarded_mismatch})",
                        byte as char,
                        received_byte as char
                    );
                }
                continue;
            }
            let msg = if let Some(deadline) = deadline {
                let now = Instant::now();
                if now >= deadline {
                    break 'echo_loop;
                }
                match tokio::time::timeout(deadline - now, channel.wait()).await {
                    Ok(msg) => msg,
                    Err(_) => break 'echo_loop,
                }
            } else {
                channel.wait().await
            };
            if let Some(msg) = msg {
                match msg {
                    ChannelMsg::Data { data } => pending_data.extend(data),
                    ChannelMsg::Eof | ChannelMsg::Close => return Err(TestError::ChannelClosed),
                    ChannelMsg::ExitStatus { exit_status } => {
                        return Err(TestError::Ssh(format!(
                            "Echo command exited unexpectedly with status {exit_status}"
                        )));
                    }
                    _ => {}
                }
            } else {
                return Err(TestError::ChannelClosed);
            }
        }
        let latency = start.elapsed().as_nanos();
        latencies.push(latency);
        progress_bar.set_position((n as u64) + 1);
    }
    progress_bar.finish_and_clear();
    // Close the exec channel explicitly so the remote echo command is not left
    // attached for the rest of the SSH session (NOTE(review): confirm whether russh
    // already does this on drop). This is best-effort: the samples are already taken.
    if let Err(e) = channel.close().await {
        debug!("Failed to close echo channel: {e}");
    }
    if latencies.is_empty() {
        return Err(TestError::EmptyEchoResult);
    }
    latencies.sort_unstable();
    let result = EchoTestSummary::from_latencies(&latencies, formatter)
        .map_err(TestError::SummaryCreation)?;
    if result.char_sent < 20 {
        warn!("Insufficient data points for accurate latency measurement");
    }
    if log_enabled!(Level::Info) {
        // `latencies` is sorted ascending and non-empty. The "k% high" value is the
        // sample ranked `len / divisor` places below the slowest one. Indexing by the
        // vector's own length keeps the lookup in bounds for every len >= 1.
        let count = latencies.len();
        let high = |divisor: usize| {
            let nanos = latencies[count - 1 - count / divisor];
            Duration::from_nanos(u64::try_from(nanos).unwrap_or(u64::MAX))
        };
        let p1_latency = high(100);
        let p5_latency = high(20);
        let p10_latency = high(10);
        info!(
            "Sent {}/{char_count}, Latency:\n\tMean:\t{}\n\tStd:\t{}\n\tMin:\t{}\n\tMedian:\t{}\n\tMax:\t{}\n\t1% High:\t{}\n\t5% High:\t{}\n\t10% High:\t{}",
            result.char_sent,
            result.avg_latency,
            result.std_latency,
            result.min_latency,
            result.med_latency,
            result.max_latency,
            formatter.format_duration(p1_latency),
            formatter.format_duration(p5_latency),
            formatter.format_duration(p10_latency)
        );
    }
    Ok(result)
}
/// Uploads `size` bytes of generated data to `remote_file` over SFTP and times the transfer.
///
/// The payload is random printable ASCII (byte values 32..=95). It is generated and
/// written in pieces of at most `chunk_size` bytes. Only the `write_all` calls are
/// timed, so data generation and the final `shutdown` of the remote file are excluded
/// from the measured transfer time. The remote file is opened with SFTP `create` first.
///
/// # Errors
///
/// * [`TestError::Ssh`] if `chunk_size` is zero, or if any SSH/SFTP operation fails.
/// * [`TestError::InvalidRemotePath`] if `remote_file` is not valid UTF-8.
async fn run_upload_test<H: client::Handler>(
    session: &client::Handle<H>,
    size: u64,
    chunk_size: u64,
    remote_file: &Path,
    formatter: &Formatter,
) -> Result<SpeedTestResult, TestError> {
    info!("Running upload speed test");
    // A zero chunk size sends nothing per iteration and would loop forever.
    if chunk_size == 0 {
        return Err(TestError::Ssh(
            "Chunk size must be greater than zero".to_string(),
        ));
    }
    // Validate the path before opening any channel.
    let remote_path = remote_file.to_str().ok_or(TestError::InvalidRemotePath)?;
    trace!("Establishing SFTP channel");
    let channel = session
        .channel_open_session()
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    channel
        .request_subsystem(true, "sftp")
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    let sftp = SftpSession::new(channel.into_stream())
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    trace!("Generating random data in chunks");
    let dist = Uniform::try_from(0..128_u8).expect("0..128 is a valid non-empty range");
    let mut file = sftp
        .create(remote_path)
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    let mut total_bytes_sent = 0;
    let mut transfer_time = Duration::ZERO;
    // A single reusable buffer, sized for the largest chunk that will be sent.
    let mut chunk: Vec<u8> = Vec::with_capacity(chunk_size.min(size) as usize);
    let progress_bar = ProgressBar::new(size);
    progress_bar.set_style(get_progress_bar_style("Upload test"));
    trace!("Sending file in chunks");
    while total_bytes_sent < size {
        let to_send = chunk_size.min(size - total_bytes_sent) as usize;
        // Mask 0..128 down to 6 bits, then offset by 32, giving printable ASCII 32..=95.
        chunk.clear();
        chunk.extend(
            dist.sample_iter(rng())
                .take(to_send)
                .map(|v| (v & 0x3f) + 32),
        );
        let start = Instant::now();
        file.write_all(&chunk)
            .await
            .map_err(|e| TestError::Ssh(e.to_string()))?;
        transfer_time += start.elapsed();
        total_bytes_sent += chunk.len() as u64;
        progress_bar.set_position(total_bytes_sent);
    }
    progress_bar.finish_and_clear();
    file.shutdown()
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    let result = SpeedTestResult::new(total_bytes_sent, transfer_time, formatter);
    info!(
        "Sent {}, Time Elapsed: {}, Average Speed: {}",
        result.size, result.time, result.speed
    );
    Ok(result)
}
/// Downloads the whole of `remote_file` over SFTP and times the transfer.
///
/// The file length is taken from its SFTP metadata. The file is then read in pieces
/// of at most `chunk_size` bytes into one reusable buffer. Only the `read_exact`
/// calls are timed.
///
/// # Errors
///
/// * [`TestError::Ssh`] if `chunk_size` is zero, or if any SSH/SFTP operation fails
///   (including the file ending before its reported length).
/// * [`TestError::InvalidRemotePath`] if `remote_file` is not valid UTF-8.
/// * [`TestError::EmptyRemoteFile`] if the remote file has zero length.
async fn run_download_test<H: client::Handler>(
    session: &client::Handle<H>,
    chunk_size: u64,
    remote_file: &Path,
    formatter: &Formatter,
) -> Result<SpeedTestResult, TestError> {
    info!("Running download speed test");
    // A zero chunk size reads nothing per iteration and would loop forever.
    if chunk_size == 0 {
        return Err(TestError::Ssh(
            "Chunk size must be greater than zero".to_string(),
        ));
    }
    // Validate the path before opening any channel.
    let remote_path = remote_file.to_str().ok_or(TestError::InvalidRemotePath)?;
    trace!("Establishing SFTP channel");
    let channel = session
        .channel_open_session()
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    channel
        .request_subsystem(true, "sftp")
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    let sftp = SftpSession::new(channel.into_stream())
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    let metadata = sftp
        .metadata(remote_path)
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    let size = metadata.len();
    if size == 0 {
        return Err(TestError::EmptyRemoteFile);
    }
    let mut file = sftp
        .open(remote_path)
        .await
        .map_err(|e| TestError::Ssh(e.to_string()))?;
    trace!("Preparing buffer for downloading");
    // Never allocate more than the file needs, even when chunk_size is larger.
    let mut buffer = vec![0; chunk_size.min(size) as usize];
    let mut total_bytes_recv = 0;
    let mut transfer_time = Duration::ZERO;
    let progress_bar = ProgressBar::new(size);
    progress_bar.set_style(get_progress_bar_style("Download test"));
    trace!("Receiving file in chunks");
    while total_bytes_recv < size {
        // `to_read <= buffer.len()` because both are bounded by chunk_size and size.
        let to_read = chunk_size.min(size - total_bytes_recv) as usize;
        let start = Instant::now();
        file.read_exact(&mut buffer[..to_read])
            .await
            .map_err(|e| TestError::Ssh(e.to_string()))?;
        transfer_time += start.elapsed();
        total_bytes_recv += to_read as u64;
        progress_bar.set_position(total_bytes_recv);
    }
    progress_bar.finish_and_clear();
    let result = SpeedTestResult::new(total_bytes_recv, transfer_time, formatter);
    info!(
        "Received {}, Time Elapsed: {}, Average Speed: {}",
        result.size, result.time, result.speed
    );
    Ok(result)
}
/// Runs the upload test and then the download test against `remote_file`.
///
/// The upload writes `size` bytes of generated data to `remote_file` in pieces of
/// `chunk_size` bytes. The download then reads the whole of that file back. Both
/// halves share `session`. The first failure is returned unchanged, and the download
/// is skipped if the upload fails.
///
/// # Errors
///
/// Propagates any [`TestError`] raised by either half.
pub async fn run_speed_test<H: client::Handler>(
    session: &client::Handle<H>,
    size: u64,
    chunk_size: u64,
    remote_file: &Path,
    formatter: &Formatter,
) -> Result<SpeedTestSummary, TestError> {
    info!("Running speed test");
    debug!("Running speed test with file size: {}", formatter.format_size(size));
    debug!("Remote file path: {remote_file:?}");

    let upload = run_upload_test(session, size, chunk_size, remote_file, formatter).await?;
    let download = run_download_test(session, chunk_size, remote_file, formatter).await?;

    Ok(SpeedTestSummary { upload, download })
}