use bssh::ssh::tokio_client::{AuthMethod, Client, CommandOutput, ServerCheckMethod};
use tokio::sync::mpsc::channel;
/// What `build_test_output_buffer` hands back: the sender that is passed to
/// `execute_streaming`, and the spawned task that resolves to the collected
/// `(stdout, stderr)` bytes.
type OutputBuffer = (
    tokio::sync::mpsc::Sender<CommandOutput>,
    tokio::task::JoinHandle<(Vec<u8>, Vec<u8>)>,
);
/// Creates an output channel for `execute_streaming` plus a background task
/// that drains it.
///
/// The task appends each `StdOut` chunk to one buffer and each `StdErr` chunk
/// to another. It ignores `ExitCode` messages and resolves to
/// `(stdout, stderr)` once every sender has been dropped.
///
/// Must be called from inside a Tokio runtime, because it spawns a task.
fn build_test_output_buffer() -> OutputBuffer {
    let (sender, mut receiver) = channel(100);
    let collector = tokio::task::spawn(async move {
        let mut collected: (Vec<u8>, Vec<u8>) = (Vec::new(), Vec::new());
        while let Some(message) = receiver.recv().await {
            match message {
                CommandOutput::StdOut(chunk) => collected.0.extend_from_slice(&chunk),
                CommandOutput::StdErr(chunk) => collected.1.extend_from_slice(&chunk),
                // The exit status is checked via the value returned by the
                // streaming call, so it is not recorded here.
                CommandOutput::ExitCode(_) => {}
            }
        }
        collected
    });
    (sender, collector)
}
/// Probes whether the system `ssh` binary can run `echo test` on `localhost`
/// without any interaction.
///
/// Returns `false` if the `ssh` process cannot be spawned or exits unsuccessfully.
/// The tests use this to skip, rather than fail, on machines without a usable
/// local SSH setup.
///
/// This is a blocking call: it waits for the child process to exit. Only the
/// connection attempt is bounded, by `ConnectTimeout=2`.
fn can_ssh_to_localhost() -> bool {
    use std::process::Command;

    // Batch mode, no password auth and no host-key bookkeeping, so the probe
    // never stops to prompt and never touches the user's known_hosts file.
    const SSH_OPTIONS: [&str; 10] = [
        "-o",
        "ConnectTimeout=2",
        "-o",
        "StrictHostKeyChecking=no",
        "-o",
        "UserKnownHostsFile=/dev/null",
        "-o",
        "PasswordAuthentication=no",
        "-o",
        "BatchMode=yes",
    ];

    Command::new("ssh")
        .args(SSH_OPTIONS)
        .args(["localhost", "echo", "test"])
        .output()
        .map(|probe| probe.status.success())
        .unwrap_or(false)
}
/// Streams a single `echo` over SSH to `localhost` and checks three things:
/// its stdout arrives through the output channel, stderr stays empty, and the
/// exit status is 0.
///
/// Skips (prints a note to stderr and returns) when localhost is not reachable
/// over SSH, or when the agent-authenticated connection fails.
#[tokio::test]
async fn test_localhost_execute_streaming_output() {
    if !can_ssh_to_localhost() {
        eprintln!("Skipping streaming test: Cannot SSH to localhost");
        return;
    }
    let username = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
    let client = match Client::connect(
        ("localhost", 22),
        &username,
        AuthMethod::Agent,
        ServerCheckMethod::NoCheck,
    )
    .await
    {
        Ok(client) => client,
        Err(err) => {
            // Report why the connection failed so a silent skip can be diagnosed.
            eprintln!("Skipping streaming test: Cannot connect to localhost ({err:?})");
            return;
        }
    };
    let (sender, receiver_task) = build_test_output_buffer();
    // Panic with the real error instead of a bare "is_ok() was false".
    let exit_status = client
        .execute_streaming("echo 'Hello from streaming test'", sender)
        .await
        .unwrap_or_else(|err| panic!("Command should execute successfully: {err:?}"));
    assert_eq!(exit_status, 0, "Command should exit with status 0");
    let (stdout_bytes, stderr_bytes) = receiver_task
        .await
        .expect("output collector task should not panic");
    let stdout = String::from_utf8_lossy(&stdout_bytes);
    let stderr = String::from_utf8_lossy(&stderr_bytes);
    assert!(
        stdout.contains("Hello from streaming test"),
        "Stdout should contain test message, got: {stdout}"
    );
    assert_eq!(stderr, "", "Stderr should be empty, got: {stderr}");
}
/// Checks that the non-streaming `execute` API still works. It runs an `echo`
/// on `localhost`, then expects exit status 0, the message in `stdout`, and an
/// empty `stderr`.
///
/// Skips (prints a note to stderr and returns) when localhost is not reachable
/// over SSH, or when the agent-authenticated connection fails.
#[tokio::test]
async fn test_backward_compatibility_execute() {
    if !can_ssh_to_localhost() {
        eprintln!("Skipping backward compatibility test: Cannot SSH to localhost");
        return;
    }
    let username = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
    let client = match Client::connect(
        ("localhost", 22),
        &username,
        AuthMethod::Agent,
        ServerCheckMethod::NoCheck,
    )
    .await
    {
        Ok(client) => client,
        Err(err) => {
            // Report why the connection failed so a silent skip can be diagnosed.
            eprintln!(
                "Skipping backward compatibility test: Cannot connect to localhost ({err:?})"
            );
            return;
        }
    };
    // Panic with the real error instead of a bare "is_ok() was false".
    let result = client
        .execute("echo 'Backward compatibility test'")
        .await
        .unwrap_or_else(|err| panic!("Command should execute successfully: {err:?}"));
    assert_eq!(result.exit_status, 0, "Command should exit with status 0");
    assert!(
        result.stdout.contains("Backward compatibility test"),
        "Stdout should contain test message, got: {}",
        result.stdout
    );
    assert_eq!(
        result.stderr, "",
        "Stderr should be empty, got: {}",
        result.stderr
    );
}
/// Runs a command that writes to both stdout and stderr, and checks that each
/// stream lands in its own buffer and that the command exits with status 0.
///
/// Skips (prints a note to stderr and returns) when localhost is not reachable
/// over SSH, or when the agent-authenticated connection fails.
#[tokio::test]
async fn test_streaming_with_stderr() {
    if !can_ssh_to_localhost() {
        eprintln!("Skipping stderr streaming test: Cannot SSH to localhost");
        return;
    }
    let username = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
    let client = match Client::connect(
        ("localhost", 22),
        &username,
        AuthMethod::Agent,
        ServerCheckMethod::NoCheck,
    )
    .await
    {
        Ok(client) => client,
        Err(err) => {
            // Report why the connection failed so a silent skip can be diagnosed.
            eprintln!("Skipping stderr streaming test: Cannot connect to localhost ({err:?})");
            return;
        }
    };
    let (sender, receiver_task) = build_test_output_buffer();
    let exit_status = client
        .execute_streaming("echo 'stdout message' && echo 'stderr message' >&2", sender)
        .await
        .unwrap_or_else(|err| panic!("Command should execute successfully: {err:?}"));
    // Every sibling test pins the exit status; writing to stderr must not
    // change it.
    assert_eq!(exit_status, 0, "Command should exit with status 0");
    let (stdout_bytes, stderr_bytes) = receiver_task
        .await
        .expect("output collector task should not panic");
    let stdout = String::from_utf8_lossy(&stdout_bytes);
    let stderr = String::from_utf8_lossy(&stderr_bytes);
    assert!(
        stdout.contains("stdout message"),
        "Stdout should contain stdout message, got: {stdout}"
    );
    assert!(
        stderr.contains("stderr message"),
        "Stderr should contain stderr message, got: {stderr}"
    );
}
/// Streams 10,000 numbered lines through the bounded output channel created by
/// `build_test_output_buffer`. Checks that every line arrives complete and in
/// order, and that the command still reports exit status 0.
///
/// The generator loop uses POSIX `sh` syntax because the remote command is
/// typically run by the account's login shell, which is not guaranteed to be
/// bash. A `{1..N}` brace expansion is a bash/zsh extension; under `dash` or
/// `ash` it stays literal and yields a single line.
///
/// Skips (prints a note to stderr and returns) when localhost is not reachable
/// over SSH, or when the agent-authenticated connection fails.
#[tokio::test]
async fn test_streaming_large_output_backpressure() {
    if !can_ssh_to_localhost() {
        eprintln!("Skipping large output test: Cannot SSH to localhost");
        return;
    }
    let username = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
    let client = match Client::connect(
        ("localhost", 22),
        &username,
        AuthMethod::Agent,
        ServerCheckMethod::NoCheck,
    )
    .await
    {
        Ok(client) => client,
        Err(err) => {
            // Report why the connection failed so a silent skip can be diagnosed.
            eprintln!("Skipping large output test: Cannot connect to localhost ({err:?})");
            return;
        }
    };
    let (sender, receiver_task) = build_test_output_buffer();
    let exit_status = client
        .execute_streaming(
            "i=1; while [ \"$i\" -le 10000 ]; do echo \"Line $i\"; i=$((i + 1)); done",
            sender,
        )
        .await
        .unwrap_or_else(|err| {
            panic!("Large output command should execute successfully: {err:?}")
        });
    assert_eq!(exit_status, 0, "Command should exit with status 0");
    let (stdout_bytes, _stderr_bytes) = receiver_task
        .await
        .expect("output collector task should not panic");
    let stdout = String::from_utf8_lossy(&stdout_bytes);
    let lines: Vec<&str> = stdout.lines().collect();
    // Exact matches: a substring check for "Line 1" would also be satisfied by
    // "Line 10", "Line 100", and so on.
    assert_eq!(lines.first().copied(), Some("Line 1"), "Should contain first line");
    assert_eq!(lines.last().copied(), Some("Line 10000"), "Should contain last line");
    let line_count = lines.len();
    assert_eq!(
        line_count, 10000,
        "Should have exactly 10000 lines, got: {line_count}"
    );
    // Under backpressure each chunk must still be delivered exactly once and
    // in order, so every line must match its position.
    for (number, line) in (1..).zip(&lines) {
        assert_eq!(
            *line,
            format!("Line {number}"),
            "Line {number} arrived corrupted or out of order"
        );
    }
}
#[tokio::test]
async fn test_streaming_receiver_drop_handling() {
if !can_ssh_to_localhost() {
eprintln!("Skipping receiver drop test: Cannot SSH to localhost");
return;
}
let username = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
let client = Client::connect(
("localhost", 22),
&username,
AuthMethod::Agent,
ServerCheckMethod::NoCheck,
)
.await;
if client.is_err() {
eprintln!("Skipping receiver drop test: Cannot connect to localhost");
return;
}
let client = client.unwrap();
let (sender, receiver) = channel(100);
drop(receiver);
let exit_status = client.execute_streaming("echo 'test output'", sender).await;
assert!(
exit_status.is_ok(),
"Command should handle receiver drop gracefully"
);
let exit_status = exit_status.unwrap();
assert_eq!(
exit_status, 0,
"Command should still report correct exit status"
);
}