use crate::error::{FlecheError, Result};
use chrono::Utc;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::PathBuf;
use std::process::Stdio;
use std::time::Duration;
use tokio::process::Command;
/// Number of retries after the first attempt (the retry loops iterate over `0..=MAX_RETRIES`).
const MAX_RETRIES: u32 = 3;
/// Delay before the first retry; it is doubled for every further retry.
const RETRY_BASE_DELAY: Duration = Duration::from_secs(1);
/// Per-invocation timeout, in seconds, used by `SshClient::new`.
const DEFAULT_EXEC_TIMEOUT_SECS: u64 = 60;
/// Value, in seconds, passed to ssh as `ConnectTimeout` by `SshClient::new`.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 30;
/// Returns `<config dir>/fleche/ssh.log`, or `None` when `dirs::config_dir()`
/// cannot determine a configuration directory.
fn ssh_log_path() -> Option<PathBuf> {
    let config_root = dirs::config_dir()?;
    Some(config_root.join("fleche").join("ssh.log"))
}
/// Returns the per-user directory that holds SSH control sockets
/// (`/tmp/fleche-ssh-<uid>`), creating it and restricting it to mode `0o700`.
///
/// Both filesystem operations are best effort: their errors are ignored and
/// the path is returned regardless.
// NOTE(review): the directory name is predictable and ownership is not checked
// when it already exists — confirm whether that is acceptable on shared hosts.
#[cfg(unix)]
pub fn ssh_socket_dir() -> PathBuf {
    use std::os::unix::fs::PermissionsExt;

    let socket_dir = PathBuf::from(format!("/tmp/fleche-ssh-{}", nix::unistd::getuid()));
    let owner_only = std::fs::Permissions::from_mode(0o700);
    let _ = std::fs::create_dir_all(&socket_dir);
    let _ = std::fs::set_permissions(&socket_dir, owner_only);
    socket_dir
}
/// Returns the `fleche-ssh` directory inside the user's cache directory
/// (falling back to the current directory when no cache directory is known),
/// creating it on a best-effort basis.
#[cfg(not(unix))]
pub fn ssh_socket_dir() -> PathBuf {
    let cache_root = match dirs::cache_dir() {
        Some(root) => root,
        None => PathBuf::from("."),
    };
    let socket_dir = cache_root.join("fleche-ssh");
    // A creation failure is deliberately ignored; the path is returned either way.
    let _ = std::fs::create_dir_all(&socket_dir);
    socket_dir
}
/// Reports whether `stderr` contains one of the messages this module treats
/// as a retryable failure. Matching is a case-sensitive substring search.
fn is_retryable_error(stderr: &str) -> bool {
    const RETRYABLE_MARKERS: [&str; 6] = [
        "Permission denied",
        "Connection refused",
        "Connection reset",
        "Connection timed out",
        "No route to host",
        "Host is down",
    ];
    RETRYABLE_MARKERS
        .iter()
        .any(|&marker| stderr.contains(marker))
}
/// Builds the user-facing message for a command that exceeded `timeout`.
///
/// Commands containing the substring `sbatch` get a Slurm-specific hint; every
/// other command gets a generic list of likely causes.
fn format_timeout_error(command: &str, timeout: Duration) -> String {
    let hints: &[&str] = if command.contains("sbatch") {
        &[
            "\n\nThis usually means the Slurm scheduler is overloaded or down.",
            "\nRun 'fleche ping' to check cluster status.",
        ]
    } else {
        &[
            "\n\nThis may indicate:",
            "\n - The remote host is slow or overloaded",
            "\n - Network connectivity issues",
            "\n - A stale SSH connection",
        ]
    };
    let mut message = format!("Command timed out after {timeout:?}");
    message.extend(hints.iter().copied());
    message
}
/// Appends one timestamped entry (header line plus `stderr`) to the SSH log.
///
/// Logging is strictly best effort: if the log path is unknown or the file
/// cannot be opened the function returns silently, and write errors are
/// ignored. A log larger than 1,000,000 bytes is truncated before appending.
fn append_to_ssh_log(host: &str, command: &str, stderr: &str) {
    let log_path = match ssh_log_path() {
        Some(path) => path,
        None => return,
    };

    if let Some(log_dir) = log_path.parent() {
        let _ = std::fs::create_dir_all(log_dir);
    }

    // Crude size cap: start over with an empty file once the log grows too big.
    let oversized = std::fs::metadata(&log_path)
        .map(|meta| meta.len() > 1_000_000)
        .unwrap_or(false);
    if oversized {
        let _ = File::create(&log_path);
    }

    let mut log_file = match OpenOptions::new().create(true).append(true).open(&log_path) {
        Ok(file) => file,
        Err(_) => return,
    };

    let timestamp = Utc::now().format("%Y-%m-%d %H:%M:%S UTC");
    let _ = writeln!(log_file, "\n=== [{timestamp}] ssh {host} {command} ===");
    let _ = writeln!(log_file, "{stderr}");
}
/// Thin wrapper that runs commands on a remote host by spawning the `ssh` binary.
pub struct SshClient {
/// Destination passed verbatim to `ssh` as the host argument.
host: String,
/// When set, `-v` is passed to ssh and the captured stderr is echoed to this process's stderr.
debug: bool,
/// Upper bound applied to each individual `ssh` invocation made by the `exec*` methods.
exec_timeout: Duration,
/// Forwarded to ssh as `-o ConnectTimeout=<seconds>`.
connect_timeout: Duration,
}
impl SshClient {
/// Creates a client for `host` using the module's default exec and connect timeouts.
pub fn new(host: &str, debug: bool) -> Self {
    let (exec_secs, connect_secs) = (DEFAULT_EXEC_TIMEOUT_SECS, DEFAULT_CONNECT_TIMEOUT_SECS);
    Self::with_timeouts(host, debug, exec_secs, connect_secs)
}
pub fn with_timeouts(
host: &str,
debug: bool,
exec_timeout_secs: u64,
connect_timeout_secs: u64,
) -> Self {
SshClient {
host: host.to_string(),
debug,
exec_timeout: Duration::from_secs(exec_timeout_secs),
connect_timeout: Duration::from_secs(connect_timeout_secs),
}
}
/// Tears down the multiplexed connection for this host.
///
/// Asks ssh to close the master (`ssh -O exit`), then removes every file in
/// the socket directory whose name contains the host string, and finally
/// records the cleanup in the SSH log. Every step is best effort; failures
/// are ignored.
#[cfg(unix)]
async fn kill_control_socket(&self) {
    let socket_dir = ssh_socket_dir();
    let control_option = format!(
        "ControlPath=\"{}\"",
        socket_dir.join("%r@%h-%p").display()
    );

    let _ = Command::new("ssh")
        .arg("-O")
        .arg("exit")
        .arg("-o")
        .arg(&control_option)
        .arg(&self.host)
        .output()
        .await;

    // Sweep leftover socket files in case the master did not clean up after itself.
    if let Ok(mut listing) = tokio::fs::read_dir(&socket_dir).await {
        loop {
            let entry = match listing.next_entry().await {
                Ok(Some(entry)) => entry,
                _ => break,
            };
            let file_name = entry.file_name();
            if file_name.to_string_lossy().contains(self.host.as_str()) {
                let _ = tokio::fs::remove_file(entry.path()).await;
            }
        }
    }

    append_to_ssh_log(
        &self.host,
        "[socket cleanup]",
        "Killed stale control socket",
    );
}
/// No-op on non-Unix targets: `ssh_args` only adds the ControlMaster/ControlPath
/// options under `cfg(unix)`, so this build never creates a control socket to clean up.
#[cfg(not(unix))]
async fn kill_control_socket(&self) {}
/// Builds the option list placed before the host on every `ssh` command line.
///
/// The list starts with `-v` in debug mode, followed by `-o` pairs for the
/// connection settings and, on Unix, the connection-multiplexing settings.
fn ssh_args(&self) -> Vec<String> {
    let mut args: Vec<String> = Vec::new();
    if self.debug {
        args.push("-v".to_string());
    }

    // Every setting is emitted as the pair `-o <setting>`.
    let mut push_option = |setting: String| {
        args.push("-o".to_string());
        args.push(setting);
    };

    push_option("ClearAllForwardings=yes".to_string());
    push_option(format!("ConnectTimeout={}", self.connect_timeout.as_secs()));
    push_option("ServerAliveInterval=15".to_string());
    push_option("ServerAliveCountMax=3".to_string());
    push_option("BatchMode=yes".to_string());

    #[cfg(unix)]
    {
        let control_path = ssh_socket_dir().join("%r@%h-%p");
        push_option("ControlMaster=auto".to_string());
        push_option(format!("ControlPath=\"{}\"", control_path.display()));
        push_option("ControlPersist=600".to_string());
    }

    args
}
/// Runs `command` on the remote host and returns its stdout.
///
/// If the first run fails with `FlecheError::SshTimeout`, the control socket
/// is torn down and the command is run once more; the outcome of that second
/// run is returned as-is. Any other result is returned unchanged.
pub async fn exec(&self, command: &str) -> Result<String> {
    let first_run = self.exec_inner(command).await;
    if !matches!(first_run, Err(FlecheError::SshTimeout(_))) {
        return first_run;
    }
    self.kill_control_socket().await;
    self.exec_inner(command).await
}
/// Runs `command` over ssh, retrying transient connection failures with
/// exponential backoff (`RETRY_BASE_DELAY` doubled per retry, at most
/// `MAX_RETRIES` retries).
///
/// Returns the command's stdout (lossily decoded as UTF-8) on a zero exit
/// status.
///
/// # Errors
/// - `FlecheError::SshConnection` if the `ssh` process cannot be run.
/// - `FlecheError::SshTimeout` if one invocation exceeds `exec_timeout`;
///   timeouts are not retried here.
/// - `FlecheError::SshCommand` for a non-zero exit status that is not
///   retryable, or once the retries are exhausted.
async fn exec_inner(&self, command: &str) -> Result<String> {
    let mut last_error = None;
    for attempt in 0..=MAX_RETRIES {
        if attempt > 0 {
            let delay = RETRY_BASE_DELAY * 2_u32.pow(attempt - 1);
            append_to_ssh_log(
                &self.host,
                command,
                &format!("Retry attempt {attempt}/{MAX_RETRIES} after {delay:?}"),
            );
            tokio::time::sleep(delay).await;
        }

        // `kill_on_drop` makes sure a timed-out ssh process is terminated when
        // the future below is dropped, instead of lingering in the background.
        let output_future = Command::new("ssh")
            .args(self.ssh_args())
            .arg(&self.host)
            .arg(command)
            .kill_on_drop(true)
            .output();
        let output = match tokio::time::timeout(self.exec_timeout, output_future).await {
            Ok(Ok(output)) => output,
            Ok(Err(e)) => {
                return Err(FlecheError::SshConnection(format!(
                    "Failed to execute ssh: {e}"
                )));
            }
            Err(_) => {
                append_to_ssh_log(
                    &self.host,
                    command,
                    &format!("Command timed out after {:?}", self.exec_timeout),
                );
                return Err(FlecheError::SshTimeout(format_timeout_error(
                    command,
                    self.exec_timeout,
                )));
            }
        };

        let stderr = String::from_utf8_lossy(&output.stderr);
        append_to_ssh_log(&self.host, command, &stderr);
        if self.debug {
            eprint!("{stderr}");
        }
        if output.status.success() {
            return Ok(String::from_utf8_lossy(&output.stdout).into_owned());
        }

        let stdout = String::from_utf8_lossy(&output.stdout);
        let error = FlecheError::SshCommand(format!(
            "Command failed with exit code {:?}\nstdout: {}\nstderr: {}",
            output.status.code(),
            stdout,
            stderr
        ));

        // ssh reports its own failures with exit status 255. Only those are
        // retried, so a remote command that merely prints e.g. "Permission
        // denied" is not silently executed again.
        let ssh_level_failure = output.status.code() == Some(255);
        if !(ssh_level_failure && is_retryable_error(&stderr)) {
            return Err(error);
        }
        last_error = Some(error);
    }
    Err(last_error.expect("loop sets last_error on retryable failures"))
}
/// Runs `command` on the remote host and returns `(succeeded, stdout, stderr)`
/// without treating a non-zero exit status as an error.
///
/// If the first run fails with `FlecheError::SshTimeout`, the control socket
/// is torn down and the command is run once more; the outcome of that second
/// run is returned as-is. Any other result is returned unchanged.
pub async fn exec_allow_failure(&self, command: &str) -> Result<(bool, String, String)> {
    let first_run = self.exec_allow_failure_inner(command).await;
    if !matches!(first_run, Err(FlecheError::SshTimeout(_))) {
        return first_run;
    }
    self.kill_control_socket().await;
    self.exec_allow_failure_inner(command).await
}
/// Runs `command` over ssh and returns `(succeeded, stdout, stderr)`, with
/// both streams lossily decoded as UTF-8.
///
/// An ssh-level failure (exit status 255 with a retryable message on stderr)
/// is retried with exponential backoff, at most `MAX_RETRIES` times. The
/// result of the final attempt is returned whatever it is, so the command is
/// run at most `MAX_RETRIES + 1` times and every attempt is logged.
///
/// # Errors
/// - `FlecheError::SshConnection` if the `ssh` process cannot be run.
/// - `FlecheError::SshTimeout` if one invocation exceeds `exec_timeout`;
///   timeouts are not retried here.
async fn exec_allow_failure_inner(&self, command: &str) -> Result<(bool, String, String)> {
    let mut attempt: u32 = 0;
    loop {
        if attempt > 0 {
            let delay = RETRY_BASE_DELAY * 2_u32.pow(attempt - 1);
            append_to_ssh_log(
                &self.host,
                command,
                &format!("Retry attempt {attempt}/{MAX_RETRIES} after {delay:?}"),
            );
            tokio::time::sleep(delay).await;
        }

        // `kill_on_drop` makes sure a timed-out ssh process is terminated when
        // the future below is dropped, instead of lingering in the background.
        let output_future = Command::new("ssh")
            .args(self.ssh_args())
            .arg(&self.host)
            .arg(command)
            .kill_on_drop(true)
            .output();
        let output = match tokio::time::timeout(self.exec_timeout, output_future).await {
            Ok(Ok(output)) => output,
            Ok(Err(e)) => {
                return Err(FlecheError::SshConnection(format!(
                    "Failed to execute ssh: {e}"
                )));
            }
            Err(_) => {
                append_to_ssh_log(
                    &self.host,
                    command,
                    &format!("Command timed out after {:?}", self.exec_timeout),
                );
                return Err(FlecheError::SshTimeout(format_timeout_error(
                    command,
                    self.exec_timeout,
                )));
            }
        };

        let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
        let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
        append_to_ssh_log(&self.host, command, &stderr);
        if self.debug {
            eprint!("{stderr}");
        }

        // Exit status 255 is ssh's own failure code; anything else is the
        // remote command's result and is handed back to the caller untouched.
        let transient_ssh_failure =
            output.status.code() == Some(255) && is_retryable_error(&stderr);
        if transient_ssh_failure && attempt < MAX_RETRIES {
            attempt += 1;
            continue;
        }
        return Ok((output.status.success(), stdout, stderr));
    }
}
/// Creates `path` on the remote host, including missing parents (`mkdir -p`).
pub async fn mkdir(&self, path: &str) -> Result<()> {
    let remote_command = format!("mkdir -p {}", shell_escape(path));
    self.exec(&remote_command).await.map(|_| ())
}
/// Recursively and forcibly removes `path` on the remote host (`rm -rf`).
pub async fn rm_rf(&self, path: &str) -> Result<()> {
    let remote_command = format!("rm -rf {}", shell_escape(path));
    self.exec(&remote_command).await.map(|_| ())
}
/// Writes `content` to `path` on the remote host through a quoted
/// here-document fed to `cat`.
///
/// The here-document appends a newline after `content`. The delimiter is
/// quoted, so the remote shell performs no expansion inside the body.
///
/// # Errors
/// Propagates any error returned by `exec`.
pub async fn write_file(&self, path: &str, content: &str) -> Result<()> {
    // A body line equal to the delimiter would end the here-document early and
    // the remaining lines would be executed as shell commands. Lengthen the
    // delimiter until no line of `content` matches it. (`lines()` also strips
    // a trailing `\r`, which only makes this check more conservative.)
    let mut delimiter = String::from("RJOB_EOF");
    while content.lines().any(|line| line == delimiter) {
        delimiter.push('_');
    }

    let command = format!(
        "cat > {} << '{delimiter}'\n{}\n{delimiter}",
        shell_escape(path),
        content
    );
    self.exec(&command).await?;
    Ok(())
}
/// Returns the contents of the remote file at `path`: the last `n` lines when
/// `tail` is `Some(n)`, otherwise the whole file.
pub async fn cat_tail(&self, path: &str, tail: Option<usize>) -> Result<String> {
    let quoted_path = shell_escape(path);
    let remote_command = match tail {
        Some(n) => format!("tail -n {n} {quoted_path}"),
        None => format!("cat {quoted_path}"),
    };
    self.exec(&remote_command).await
}
/// Reports whether `test -d <path>` succeeds on the remote host.
pub async fn is_dir(&self, path: &str) -> Result<bool> {
    let probe = format!("test -d {}", shell_escape(path));
    let outcome = self.exec_allow_failure(&probe).await?;
    Ok(outcome.0)
}
/// Lists every regular file below `path` on the remote host, as paths
/// relative to `path`.
///
/// The listing is produced by `find .` run from inside the directory, so the
/// result does not depend on how the remote shell expands a leading `~/` or a
/// `${VAR}` reference in `path`. A missing or unreadable directory yields an
/// empty list.
///
/// # Errors
/// Propagates any error returned by `exec`.
pub async fn list_files_recursive(&self, path: &str) -> Result<Vec<String>> {
    // `cd ''` succeeds without changing directory in common shells, which
    // would list the login directory instead of nothing.
    if path.is_empty() {
        return Ok(Vec::new());
    }

    let output = self
        .exec(&format!(
            "cd {} 2>/dev/null && find . -type f 2>/dev/null || true",
            shell_escape(path)
        ))
        .await?;

    Ok(output
        .lines()
        // `find .` prefixes every hit with "./"; any other line is not a hit.
        .filter_map(|line| line.strip_prefix("./"))
        .filter(|relative| !relative.is_empty())
        .map(String::from)
        .collect())
}
/// Spawns `ssh <host> tail -F -n +1 -q <paths…>` and returns the child process.
///
/// The child's stdout is inherited from this process; its stderr is inherited
/// in debug mode and discarded otherwise.
///
/// # Errors
/// Returns `FlecheError::SshConnection` if the `ssh` process cannot be spawned.
pub fn tail_follow(&self, paths: &[&str]) -> Result<tokio::process::Child> {
    let remote_paths = paths
        .iter()
        .map(|path| shell_escape(path))
        .collect::<Vec<String>>()
        .join(" ");
    let remote_command = format!("tail -F -n +1 -q {remote_paths} 2>/dev/null");

    let mut ssh = Command::new("ssh");
    ssh.args(self.ssh_args())
        .arg(&self.host)
        .arg(remote_command)
        .stdout(Stdio::inherit());
    ssh.stderr(if self.debug {
        Stdio::inherit()
    } else {
        Stdio::null()
    });

    ssh.spawn()
        .map_err(|e| FlecheError::SshConnection(format!("Failed to spawn ssh: {e}")))
}
}
/// Quotes `s` for use as a single word in a remote shell command.
///
/// A leading `~/` is kept outside the quotes so the shell can still expand it;
/// the remainder is quoted by `quote_with_vars`.
pub fn shell_escape(s: &str) -> String {
    match s.strip_prefix("~/") {
        Some(after_home) => {
            let mut escaped = String::from("~/");
            escaped.push_str(&quote_with_vars(after_home));
            escaped
        }
        None => quote_with_vars(s),
    }
}
/// Single-quotes `s` while leaving every complete `${...}` span unquoted so
/// the remote shell can expand it.
///
/// Text between the spans is quoted with `quote_single`. A `${` without a
/// following `}` is treated as literal text, and an empty input yields `''`.
fn quote_with_vars(s: &str) -> String {
    let mut quoted = String::new();
    let mut remaining = s;

    loop {
        let Some(var_start) = remaining.find("${") else {
            break;
        };
        let (literal, from_var) = remaining.split_at(var_start);
        let Some(close) = from_var.find('}') else {
            // Unterminated reference: everything left is plain text.
            break;
        };
        if !literal.is_empty() {
            quoted.push_str(&quote_single(literal));
        }
        let (variable, after_var) = from_var.split_at(close + 1);
        quoted.push_str(variable);
        remaining = after_var;
    }

    if !remaining.is_empty() {
        quoted.push_str(&quote_single(remaining));
    }
    if quoted.is_empty() {
        "''".to_string()
    } else {
        quoted
    }
}
/// Wraps `s` in single quotes, rewriting each embedded `'` as `'\''` so the
/// result is one literal shell word.
fn quote_single(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str("'\\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
// Unit tests pinning the exact quoting produced by `quote_with_vars` and `shell_escape`.
#[cfg(test)]
mod tests {
use super::*;
// Plain text is wrapped in a single pair of single quotes.
#[test]
fn test_quote_with_vars_simple() {
assert_eq!(quote_with_vars("hello"), "'hello'");
assert_eq!(quote_with_vars("path/to/file"), "'path/to/file'");
}
// Whitespace stays inside the quotes.
#[test]
fn test_quote_with_vars_spaces() {
assert_eq!(quote_with_vars("hello world"), "'hello world'");
}
// An embedded single quote is rewritten as '\''.
#[test]
fn test_quote_with_vars_single_quotes() {
assert_eq!(quote_with_vars("it's"), "'it'\\''s'");
}
// Empty input still yields a (empty) shell word.
#[test]
fn test_quote_with_vars_empty() {
assert_eq!(quote_with_vars(""), "''");
}
// A `${VAR}` reference is left outside the quotes.
#[test]
fn test_quote_with_vars_variable() {
assert_eq!(
quote_with_vars("/scratch/${SSH_USER}/fleche"),
"'/scratch/'${SSH_USER}'/fleche'"
);
}
// Several references in one string are each left unquoted.
#[test]
fn test_quote_with_vars_multiple() {
assert_eq!(quote_with_vars("${A}/mid/${B}"), "${A}'/mid/'${B}");
}
// A `${` without a closing brace is quoted as literal text.
#[test]
fn test_quote_with_vars_unclosed_brace() {
assert_eq!(quote_with_vars("${NOPE"), "'${NOPE'");
}
// Without a `~/` prefix, `shell_escape` quotes the whole string.
#[test]
fn test_shell_escape_simple() {
assert_eq!(shell_escape("hello"), "'hello'");
assert_eq!(shell_escape("/path/to/file"), "'/path/to/file'");
}
// A leading `~/` is kept outside the quotes.
#[test]
fn test_shell_escape_tilde_expansion() {
assert_eq!(shell_escape("~/path"), "~/'path'");
assert_eq!(shell_escape("~/path/to/file"), "~/'path/to/file'");
}
// A tilde anywhere else is quoted like any other character.
#[test]
fn test_shell_escape_tilde_not_at_start() {
assert_eq!(shell_escape("/home/~user"), "'/home/~user'");
assert_eq!(shell_escape("some~path"), "'some~path'");
}
// Shell metacharacters, including a bare `$var`, end up inside the quotes.
#[test]
fn test_shell_escape_special_chars() {
assert_eq!(shell_escape("file with spaces"), "'file with spaces'");
assert_eq!(shell_escape("file$var"), "'file$var'");
assert_eq!(shell_escape("file;cmd"), "'file;cmd'");
}
// The `~/` prefix combines with quoting of the remainder.
#[test]
fn test_shell_escape_tilde_with_special_chars() {
assert_eq!(shell_escape("~/my files"), "~/'my files'");
assert_eq!(shell_escape("~/path's"), "~/'path'\\''s'");
}
// `${VAR}` inside an absolute path is left unquoted.
#[test]
fn test_shell_escape_variable_in_path() {
assert_eq!(
shell_escape("/scratch/users/${SSH_USER}/fleche"),
"'/scratch/users/'${SSH_USER}'/fleche'"
);
}
// `~/` and `${VAR}` can both appear in the same path.
#[test]
fn test_shell_escape_tilde_with_variable() {
assert_eq!(shell_escape("~/${USER}/fleche"), "~/${USER}'/fleche'");
}
}