use crate::sftp_client::{
FileEntryType, RemoteFileEntry, format_permissions, format_unix_timestamp,
};
use anyhow::Result;
use russh::*;
use russh_keys::PublicKeyBase64;
use russh_keys::*;
use russh_sftp::client::SftpSession;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
pub mod host_keys;
pub mod shell;
pub use host_keys::{
HostKeyMismatch, HostKeyStore, HostKeyStoreAccessError, HostKeyVerificationFailure, Verdict,
VerificationFailureSlot,
};
/// Buffer size, in bytes (32 KiB), used for each SFTP read/write performed by
/// the upload/download helpers on `SshClient`.
pub const SFTP_CHUNK_SIZE: usize = 32 * 1024;
/// Host-key algorithms offered during negotiation, in preference order (first =
/// most preferred); passed as `Preferred::key` by `connect_authenticated`.
///
/// NOTE(review): `SSH_RSA` (SHA-1 signatures) is listed last, presumably only for
/// compatibility with older servers. There is no NIST P-384 ECDSA entry, so a
/// server whose only host key is `ecdsa-sha2-nistp384` may fail negotiation —
/// confirm whether that omission is intentional.
pub static PREFERRED_HOST_KEY_ALGOS: &[russh_keys::key::Name] = &[
russh_keys::key::ED25519,
russh_keys::key::ECDSA_SHA2_NISTP256,
russh_keys::key::ECDSA_SHA2_NISTP521,
russh_keys::key::RSA_SHA2_256,
russh_keys::key::RSA_SHA2_512,
russh_keys::key::SSH_RSA,
];
/// Key-exchange algorithms offered during negotiation, in preference order;
/// passed as `Preferred::kex` by `connect_authenticated`.
///
/// The SHA-1 Diffie-Hellman groups come last, presumably for compatibility with
/// older servers. Judging by their names, the final two entries are not real
/// key-exchange methods but client-side extension markers (RFC 8308 extension
/// negotiation and OpenSSH "strict KEX") — TODO confirm against russh docs.
pub static PREFERRED_KEX_ALGOS: &[russh::kex::Name] = &[
russh::kex::CURVE25519,
russh::kex::CURVE25519_PRE_RFC_8731,
russh::kex::DH_G16_SHA512,
russh::kex::DH_G14_SHA256,
russh::kex::DH_G14_SHA1,
russh::kex::DH_G1_SHA1,
russh::kex::EXTENSION_SUPPORT_AS_CLIENT,
russh::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT,
];
/// Connection parameters consumed by `SshClient::connect`.
///
/// `Debug` is implemented manually (not derived) so that credentials carried in
/// `auth_method` are redacted from debug output.
#[derive(Clone, Serialize, Deserialize)]
pub struct SshConfig {
/// Server host name or address, passed to the TCP connect together with `port`.
pub host: String,
/// Server TCP port.
pub port: u16,
/// Remote account to authenticate as.
pub username: String,
/// How to authenticate (password, key file or SSH agent).
pub auth_method: AuthMethod,
}
impl std::fmt::Debug for SshConfig {
    /// Prints every field of the config. Secrets stay hidden because
    /// `auth_method` is rendered through `AuthMethod`'s redacting `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            host,
            port,
            username,
            auth_method,
        } = self;
        let mut builder = f.debug_struct("SshConfig");
        builder.field("host", host);
        builder.field("port", port);
        builder.field("username", username);
        builder.field("auth_method", auth_method);
        builder.finish()
    }
}
/// Source of credentials used to authenticate a connection.
///
/// Serialized as an internally tagged enum (`#[serde(tag = "type")]`): the
/// variant name is stored in a `type` field next to the variant's own fields.
/// `Debug` is implemented manually so secrets are redacted.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum AuthMethod {
/// Password authentication.
Password {
/// Account password; rendered as `<redacted>` by `Debug`.
password: String,
},
/// Public-key authentication with a private key file.
PublicKey {
/// Path to the private key; a leading `~` / `~/` is expanded by `load_private_key`.
key_path: String,
/// Passphrase for an encrypted key, if any; rendered as `<redacted>` by `Debug`.
passphrase: Option<String>,
},
/// Authentication through the running SSH agent.
Agent {
/// Optional selector for the agent identity to offer (see `select_agent_identity`);
/// `None` or blank means the agent's first identity.
identity_hint: Option<String>,
},
}
impl std::fmt::Debug for AuthMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthMethod::Password { .. } => f
.debug_struct("AuthMethod::Password")
.field("password", &"<redacted>")
.finish(),
AuthMethod::PublicKey {
key_path,
passphrase,
} => f
.debug_struct("AuthMethod::PublicKey")
.field("key_path", key_path)
.field(
"passphrase",
&passphrase
.as_ref()
.map(|_| "<redacted>")
.unwrap_or("<none>"),
)
.finish(),
AuthMethod::Agent { identity_hint } => f
.debug_struct("AuthMethod::Agent")
.field("identity_hint", identity_hint)
.finish(),
}
}
}
/// High-level SSH client: at most one authenticated session plus an SFTP
/// session that is opened lazily on top of it.
pub struct SshClient {
/// Authenticated connection; `None` until `connect` succeeds and again after `disconnect`.
session: Option<Arc<client::Handle<Client>>>,
/// Trust store used to verify the server's host key on every `connect`.
host_keys: Arc<HostKeyStore>,
/// SFTP session, created on first use by `sftp_session` and cleared by `disconnect`.
sftp: tokio::sync::OnceCell<Arc<SftpSession>>,
}
/// Captured result of a remote command: decoded stdout and stderr plus the
/// exit status reported by the server (if any).
#[derive(Debug, Clone, Default)]
pub struct CommandOutput {
    pub stdout: String,
    pub stderr: String,
    pub exit_code: Option<u32>,
}

impl CommandOutput {
    /// `true` only when the server reported exit status 0; an unknown status
    /// (`None`) counts as failure.
    pub fn is_success(&self) -> bool {
        self.exit_code == Some(0)
    }

    /// Returns stdout followed by stderr.
    ///
    /// When both are non-empty they are joined with a single `\n`, unless
    /// stdout already ends in a newline. When either is empty, the other is
    /// returned unchanged.
    pub fn combined(&self) -> String {
        match (self.stdout.is_empty(), self.stderr.is_empty()) {
            (_, true) => self.stdout.clone(),
            (true, false) => self.stderr.clone(),
            (false, false) => {
                let separator = if self.stdout.ends_with('\n') { "" } else { "\n" };
                [self.stdout.as_str(), separator, self.stderr.as_str()].concat()
            }
        }
    }
}
/// Handles to an interactive shell created by `SshClient::create_pty_session`.
pub struct PtySession {
/// Bytes to write to the remote shell (e.g. keystrokes); each chunk is written and flushed.
pub input_tx: mpsc::Sender<Vec<u8>>,
/// Output chunks from the shell; both channel data and extended data are forwarded here.
pub output_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<Vec<u8>>>>,
/// Terminal resize requests as `(cols, rows)`; dropping this sender stops the output task.
pub resize_tx: mpsc::Sender<(u32, u32)>,
/// Cancelling this token stops both background I/O tasks.
pub cancel: CancellationToken,
}
/// russh client handler that checks the server's host key against a
/// `HostKeyStore` (trusting unknown hosts on first use) and records why
/// verification failed.
pub struct Client {
/// Host being connected to; used as the host-key store lookup key together with `port`.
host: String,
port: u16,
store: Arc<HostKeyStore>,
/// Shared with the caller of `Client::new`; filled when verification fails so the
/// caller can report a specific error after the handshake is aborted.
verification_failure_slot: VerificationFailureSlot,
}
impl Client {
    /// Creates a handler for `host:port` that verifies keys against `store`.
    ///
    /// Also returns a second handle to the handler's verification-failure slot.
    /// The handler itself is moved into russh, so the caller keeps this clone to
    /// find out why host-key verification failed.
    pub fn new(
        host: impl Into<String>,
        port: u16,
        store: Arc<HostKeyStore>,
    ) -> (Self, VerificationFailureSlot) {
        let verification_failure_slot: VerificationFailureSlot =
            Arc::new(std::sync::Mutex::new(None));
        let caller_slot = Arc::clone(&verification_failure_slot);
        let handler = Self {
            host: host.into(),
            port,
            store,
            verification_failure_slot,
        };
        (handler, caller_slot)
    }
}
#[async_trait::async_trait]
impl client::Handler for Client {
type Error = russh::Error;
async fn check_server_key(
&mut self,
server_public_key: &key::PublicKey,
) -> Result<bool, Self::Error> {
match self
.store
.verify(&self.host, self.port, server_public_key)
.await
{
Ok(Verdict::Known) => {
tracing::debug!(
"host key for {}:{} matches known_hosts",
self.host,
self.port
);
Ok(true)
}
Ok(Verdict::Unknown) => {
tracing::warn!(
"TOFU: trusting new host key for {}:{} (fingerprint SHA256:{})",
self.host,
self.port,
server_public_key.fingerprint()
);
if let Err(e) = self
.store
.trust(&self.host, self.port, server_public_key)
.await
{
tracing::error!("failed to persist host key: {}", e);
if let Ok(mut slot) = self.verification_failure_slot.lock() {
*slot = Some(HostKeyVerificationFailure::StoreAccess(
HostKeyStoreAccessError {
host: self.host.clone(),
port: self.port,
store_path: self.store.path().to_path_buf(),
operation: "write",
source: e.to_string(),
},
));
}
return Err(
std::io::Error::other("failed to persist trusted SSH host key").into(),
);
}
Ok(true)
}
Ok(Verdict::Mismatch {
expected_fingerprint,
got_fingerprint,
}) => {
tracing::error!(
"host key mismatch for {}:{} — expected SHA256:{}, got SHA256:{}",
self.host,
self.port,
expected_fingerprint,
got_fingerprint
);
if let Ok(mut slot) = self.verification_failure_slot.lock() {
*slot = Some(HostKeyVerificationFailure::Mismatch(HostKeyMismatch {
host: self.host.clone(),
port: self.port,
expected_fingerprint,
got_fingerprint,
store_path: self.store.path().to_path_buf(),
}));
}
Ok(false)
}
Err(e) => {
tracing::error!("failed to access host-key store: {}", e);
if let Ok(mut slot) = self.verification_failure_slot.lock() {
*slot = Some(HostKeyVerificationFailure::StoreAccess(
HostKeyStoreAccessError {
host: self.host.clone(),
port: self.port,
store_path: self.store.path().to_path_buf(),
operation: "read",
source: e.to_string(),
},
));
}
Err(std::io::Error::other("failed to access SSH host-key store").into())
}
}
}
}
/// Credentials ready for `connect_authenticated`: key material is already
/// loaded, and string fields borrow from the caller's configuration.
pub(crate) enum ResolvedAuth<'a> {
/// Password authentication.
Password {
password: &'a str,
},
/// Public-key authentication with an already loaded key pair.
Key {
key: Box<key::KeyPair>,
/// Key file path; only used to describe the key in the authentication-failure message.
key_path_hint: Option<&'a str>,
},
/// Authentication through the SSH agent reached via `SSH_AUTH_SOCK`.
Agent {
/// Optional selector for which agent identity to offer (see `select_agent_identity`).
identity_hint: Option<&'a str>,
},
}
/// Opens an SSH connection to `host:port`, verifies the server's host key
/// against `host_keys`, and authenticates `username` with `auth`.
///
/// `timeout` wraps only `client::connect` (TCP connect + handshake, including
/// host-key verification); the authentication calls are not bounded by it.
///
/// # Errors
/// * the connection does not complete within `timeout`;
/// * host-key verification fails — the message is built with
///   [`format_verification_failure`] from the failure recorded by [`Client`];
/// * the peer resets the connection during the handshake (commonly an IP
///   allowlist or firewall) — reported with troubleshooting hints;
/// * any other connection error;
/// * the authentication method errors out, or the server rejects it.
pub(crate) async fn connect_authenticated(
    host: &str,
    port: u16,
    username: &str,
    auth: ResolvedAuth<'_>,
    timeout: Duration,
    host_keys: Arc<HostKeyStore>,
) -> Result<client::Handle<Client>> {
    let ssh_config = client::Config {
        preferred: russh::Preferred {
            key: PREFERRED_HOST_KEY_ALGOS,
            kex: PREFERRED_KEX_ALGOS,
            ..russh::Preferred::DEFAULT
        },
        keepalive_interval: Some(Duration::from_secs(60)),
        keepalive_max: 3,
        ..client::Config::default()
    };
    let (handler, verification_failure_slot) = Client::new(host, port, host_keys);
    let mut session = tokio::time::timeout(
        timeout,
        client::connect(Arc::new(ssh_config), (host, port), handler),
    )
    .await
    .map_err(|_| {
        anyhow::anyhow!(
            "Connection timed out after {}s. Please check the host address and network.",
            timeout.as_secs()
        )
    })?
    .map_err(|e| {
        // A recorded host-key failure explains the error better than the
        // generic russh error that the rejected handshake produces.
        if let Ok(mut guard) = verification_failure_slot.lock()
            && let Some(failure) = guard.take()
        {
            return anyhow::anyhow!(format_verification_failure(&failure));
        }
        let msg = e.to_string();
        let looks_like_reset = msg.contains("reset by peer")
            || msg.contains("ConnectionReset")
            || msg.contains("kex_exchange_identification");
        if looks_like_reset {
            return anyhow::anyhow!(
                "The SSH server at {}:{} accepted the TCP connection but then \
                 reset it during the handshake ({}).\n\n\
                 This usually means the server is rejecting your source IP \
                 or SSH client via a firewall / access list. Try:\n\
                 - Confirm your public IP is on the server's allowlist (ask \
                 the service operator).\n\
                 - Connect over a VPN that terminates inside the allowed \
                 network.\n\
                 - Verify the host and port are correct for external access \
                 (some services publish a different SFTP endpoint).",
                host,
                port,
                e
            );
        }
        anyhow::anyhow!("Failed to connect to {}:{}: {}", host, port, e)
    })?;

    // Describe the credential before `auth` is consumed below. Each kind gets
    // its own wording: a key supplied without a path hint must not be reported
    // as password authentication.
    let auth_failure_detail = match &auth {
        ResolvedAuth::Password { .. } => "with password authentication".to_string(),
        ResolvedAuth::Key {
            key_path_hint: Some(path),
            ..
        } => format!("using public key {}", path),
        ResolvedAuth::Key {
            key_path_hint: None,
            ..
        } => "using public key authentication".to_string(),
        ResolvedAuth::Agent { identity_hint } => format!(
            "using public key {}",
            identity_hint
                .filter(|hint| !hint.is_empty())
                .unwrap_or("SSH agent")
        ),
    };

    let authenticated = match auth {
        ResolvedAuth::Password { password } => session
            .authenticate_password(username, password)
            .await
            .map_err(|e| anyhow::anyhow!("Password authentication failed: {}", e))?,
        ResolvedAuth::Key { key, .. } => session
            .authenticate_publickey(username, Arc::new(*key))
            .await
            .map_err(|e| {
                anyhow::anyhow!(
                    "Public key authentication failed: {}. The key may not be authorized on the server.",
                    e
                )
            })?,
        ResolvedAuth::Agent { identity_hint } => {
            let mut agent = russh_keys::agent::client::AgentClient::connect_env()
                .await
                .map_err(|e| {
                    anyhow::anyhow!(
                        "SSH agent authentication is enabled, but r-shell could not connect to SSH_AUTH_SOCK: {}",
                        e
                    )
                })?;
            let identities = agent.request_identities().await.map_err(|e| {
                anyhow::anyhow!("SSH agent did not return identities: {}", e)
            })?;
            // NOTE: only one identity is offered. If the server rejects it,
            // other identities held by the agent are not tried.
            let key = select_agent_identity(identities, identity_hint).ok_or_else(|| {
                if let Some(hint) = identity_hint.filter(|hint| !hint.is_empty()) {
                    anyhow::anyhow!(
                        "SSH agent has no identity matching '{}'. Add the key to your agent or clear the identity hint.",
                        hint
                    )
                } else {
                    anyhow::anyhow!(
                        "SSH agent has no identities. Add a key to your agent and try again."
                    )
                }
            })?;
            let (_agent, result) = session
                .authenticate_future(username.to_string(), key, agent)
                .await;
            result.map_err(|e| anyhow::anyhow!("SSH agent authentication failed: {}", e))?
        }
    };

    if !authenticated {
        return Err(anyhow::anyhow!(
            "Authentication failed for {}@{} {}.",
            username,
            host,
            auth_failure_detail
        ));
    }
    Ok(session)
}
/// Picks the agent identity to offer to the server.
///
/// The hint is trimmed first. With no hint (or a blank one), the agent's
/// first identity is returned. Otherwise the first identity is returned that
/// satisfies any of the following:
/// * its base64 public-key blob contains the hint;
/// * the hint contains the blob (e.g. a pasted `ssh-ed25519 AAAA… comment`
///   line);
/// * its SHA-256 fingerprint equals the hint, with or without the `SHA256:`
///   prefix that `ssh-add -l` prints.
///
/// Returns `None` when nothing matches or the agent holds no identities.
fn select_agent_identity(
    identities: Vec<key::PublicKey>,
    identity_hint: Option<&str>,
) -> Option<key::PublicKey> {
    let Some(hint) = identity_hint.map(str::trim).filter(|hint| !hint.is_empty()) else {
        return identities.into_iter().next();
    };
    // `fingerprint()` yields the bare base64 digest (callers elsewhere prepend
    // "SHA256:" themselves), so accept the hint with or without that prefix.
    let fingerprint_hint = hint.strip_prefix("SHA256:").unwrap_or(hint);
    identities.into_iter().find(|identity| {
        let encoded = identity.public_key_base64();
        encoded.contains(hint)
            || hint.contains(&encoded)
            || identity.fingerprint() == fingerprint_hint
    })
}
pub fn format_mismatch(m: &HostKeyMismatch) -> String {
format!(
"Host key verification failed for {}:{}.\n\
Expected fingerprint (stored): SHA256:{}\n\
Offered fingerprint (server): SHA256:{}\n\
If the remote host legitimately rotated its key, remove the entry from:\n {}",
m.host,
m.port,
m.expected_fingerprint,
m.got_fingerprint,
m.store_path.display()
)
}
fn format_store_access_error(err: &HostKeyStoreAccessError) -> String {
format!(
"Host key verification could not complete for {}:{}.\n\
r-shell could not {} the trusted host-key store at:\n {}\n\
Underlying error: {}\n\
Connection refused to avoid trusting a host key without a durable trust store.",
err.host,
err.port,
err.operation,
err.store_path.display(),
err.source
)
}
/// Renders any recorded host-key verification failure as a user-facing
/// message by dispatching to the formatter for its kind.
pub fn format_verification_failure(failure: &HostKeyVerificationFailure) -> String {
    match failure {
        HostKeyVerificationFailure::StoreAccess(access_error) => {
            format_store_access_error(access_error)
        }
        HostKeyVerificationFailure::Mismatch(key_mismatch) => format_mismatch(key_mismatch),
    }
}
/// Expands a leading `~` (the whole path) or `~/` prefix to the user's home
/// directory.
///
/// Any other path, including `~user/...`, is returned unchanged. Returns
/// `None` only when expansion is needed but the home directory is unknown.
/// Non-UTF-8 home paths are converted lossily.
pub(crate) fn expand_home_path(path: &str) -> Option<String> {
    if path == "~" {
        return dirs::home_dir().map(|home| home.to_string_lossy().into_owned());
    }
    match path.strip_prefix("~/") {
        Some(rest) => dirs::home_dir().map(|home| home.join(rest).to_string_lossy().into_owned()),
        None => Some(path.to_owned()),
    }
}
/// Loads an SSH private key from `key_path`, decrypting it with `passphrase`
/// when one is supplied.
///
/// A leading `~` / `~/` is expanded to the home directory. The path is checked
/// with `metadata()` before parsing, so that the following cases each get a
/// targeted message instead of a generic parse error:
/// * a missing file;
/// * a permission problem (with macOS guidance);
/// * a directory.
///
/// # Errors
/// Returns a user-facing error when:
/// * the home directory cannot be resolved;
/// * the path is missing, inaccessible, or a directory;
/// * the key cannot be decrypted or parsed.
///
/// Every message after path resolution names the path that was tried, including
/// the original spelling when `~` was expanded.
pub(crate) fn load_private_key(key_path: &str, passphrase: Option<&str>) -> Result<key::KeyPair> {
    let expanded = expand_home_path(key_path).ok_or_else(|| {
        anyhow::anyhow!(
            "Cannot resolve '~' in SSH key path '{}': home directory unknown.",
            key_path
        )
    })?;
    let path = Path::new(&expanded);
    // Show the user's original spelling too when `~` was expanded, so they
    // can recognise the path they typed.
    let location = if expanded != key_path {
        format!("{} (expanded from {})", expanded, key_path)
    } else {
        expanded.clone()
    };
    match path.metadata() {
        Ok(meta) if meta.is_dir() => {
            return Err(anyhow::anyhow!(
                "SSH key path {} is a directory, not a private key file.",
                location
            ));
        }
        Ok(_) => {}
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            return Err(anyhow::anyhow!(
                "SSH key file not found: {}. Please check the file path and try again.",
                location
            ));
        }
        Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => {
            return Err(anyhow::anyhow!(
                "Permission denied reading SSH key at {}.\n\
                 On macOS this usually means r-shell hasn't been granted access to this file. \
                 Open System Settings → Privacy & Security → Full Disk Access (or App Management / Files and Folders), \
                 add r-shell to the list, then try again.",
                location
            ));
        }
        Err(e) => {
            return Err(anyhow::anyhow!(
                "Cannot access SSH key at {}: {}",
                location,
                e
            ));
        }
    }
    load_secret_key(path, passphrase).map_err(|e| {
        let msg = e.to_string();
        if msg.contains("encrypted") || msg.contains("passphrase") {
            // Distinguish "no passphrase given" from "wrong passphrase given"
            // so the user knows which action to take.
            if passphrase.is_some() {
                anyhow::anyhow!(
                    "Failed to decrypt SSH key at {}. The supplied passphrase appears to be incorrect. Please provide the correct passphrase.",
                    location
                )
            } else {
                anyhow::anyhow!(
                    "Failed to decrypt SSH key at {}. The key may be encrypted. Please provide the correct passphrase.",
                    location
                )
            }
        } else {
            anyhow::anyhow!(
                "Failed to load SSH key from {}: {}. Ensure the file is a valid SSH private key (RSA, Ed25519, or ECDSA).",
                location,
                e
            )
        }
    })
}
impl SshClient {
    /// Creates a client with no active connection.
    ///
    /// `host_keys` is the trust store consulted (and, for first-seen hosts,
    /// written to) by every subsequent [`SshClient::connect`].
    pub fn new(host_keys: Arc<HostKeyStore>) -> Self {
        Self {
            session: None,
            host_keys,
            sftp: tokio::sync::OnceCell::new(),
        }
    }

    /// Returns the SFTP session for the current connection. The first time it
    /// is needed, a channel is opened and the `sftp` subsystem is requested.
    ///
    /// The session is cached in `self.sftp` via `OnceCell::get_or_try_init`. A
    /// failed initialisation leaves the cell empty, so a later call retries.
    ///
    /// # Errors
    /// Returns "Not connected" when there is no session. Otherwise propagates
    /// any error from opening the channel, requesting the subsystem, or starting
    /// the SFTP session.
    async fn sftp_session(&self) -> Result<Arc<SftpSession>> {
        let session = self
            .session
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Not connected"))?
            .clone();
        let sftp = self
            .sftp
            .get_or_try_init(|| async move {
                let channel = session.channel_open_session().await?;
                channel.request_subsystem(true, "sftp").await?;
                let session = SftpSession::new(channel.into_stream()).await?;
                Ok::<_, anyhow::Error>(Arc::new(session))
            })
            .await?;
        Ok(sftp.clone())
    }

    /// Connects to `config.host:config.port` with a 10 s connection timeout.
    /// Verifies the server's host key and authenticates with the configured
    /// method.
    ///
    /// Key files are loaded before any network I/O, so an unreadable key fails
    /// fast. On success, the new session replaces any existing one and any SFTP
    /// session cached for the previous connection is discarded. On failure, the
    /// client's current state is left untouched.
    pub async fn connect(&mut self, config: &SshConfig) -> Result<()> {
        let auth = match &config.auth_method {
            AuthMethod::Password { password } => ResolvedAuth::Password { password },
            AuthMethod::PublicKey {
                key_path,
                passphrase,
            } => ResolvedAuth::Key {
                key: Box::new(load_private_key(key_path, passphrase.as_deref())?),
                key_path_hint: Some(key_path),
            },
            AuthMethod::Agent { identity_hint } => ResolvedAuth::Agent {
                identity_hint: identity_hint.as_deref(),
            },
        };
        let session = connect_authenticated(
            &config.host,
            config.port,
            &config.username,
            auth,
            Duration::from_secs(10),
            self.host_keys.clone(),
        )
        .await?;
        // A cached SFTP session runs over a channel of the *previous*
        // connection. Reusing it after a reconnect would send SFTP requests to
        // a stale (possibly closed) session, so drop it. The next SFTP call then
        // opens a fresh channel on the new connection.
        self.sftp.take();
        self.session = Some(Arc::new(session));
        Ok(())
    }

    /// Runs `command` and returns stdout and stderr merged by
    /// [`CommandOutput::combined`]. The exit status is discarded.
    pub async fn execute_command(&self, command: &str) -> Result<String> {
        let out = self.execute_command_full(command).await?;
        Ok(out.combined())
    }

    /// Runs `command` on a new session channel and collects its output.
    ///
    /// Returns once both EOF and the exit status have arrived (in either
    /// order), or when the channel closes. `exit_code` is `None` if the server
    /// sent no exit status before that. Output is decoded as UTF-8 lossily,
    /// after all bytes are received.
    pub async fn execute_command_full(&self, command: &str) -> Result<CommandOutput> {
        let Some(session) = &self.session else {
            return Err(anyhow::anyhow!("Not connected"));
        };
        let mut channel = session.channel_open_session().await?;
        channel.exec(true, command).await?;
        // Accumulate raw bytes and decode once at the end. Decoding each packet
        // separately would turn a UTF-8 character split across two packets
        // into replacement characters.
        let mut stdout: Vec<u8> = Vec::new();
        let mut stderr: Vec<u8> = Vec::new();
        let mut exit_code: Option<u32> = None;
        let mut eof_received = false;
        loop {
            match channel.wait().await {
                Some(ChannelMsg::Data { ref data }) => stdout.extend_from_slice(data),
                Some(ChannelMsg::ExtendedData { ref data, .. }) => {
                    stderr.extend_from_slice(data)
                }
                Some(ChannelMsg::ExitStatus { exit_status }) => {
                    exit_code = Some(exit_status);
                    if eof_received {
                        break;
                    }
                }
                Some(ChannelMsg::Eof) => {
                    eof_received = true;
                    if exit_code.is_some() {
                        break;
                    }
                }
                Some(ChannelMsg::Close) | None => break,
                _ => {}
            }
        }
        Ok(CommandOutput {
            stdout: String::from_utf8_lossy(&stdout).into_owned(),
            stderr: String::from_utf8_lossy(&stderr).into_owned(),
            exit_code,
        })
    }

    /// Decodes one line of remote output (lossily), after stripping any
    /// trailing `\r` / `\n` bytes.
    fn decode_output_line(bytes: &[u8]) -> String {
        let end = bytes
            .iter()
            .rposition(|&b| b != b'\r' && b != b'\n')
            .map_or(0, |idx| idx + 1);
        String::from_utf8_lossy(&bytes[..end]).into_owned()
    }

    /// Sends every complete (`\n`-terminated) line held in `buf` to `tx`,
    /// prefixed with `prefix`, and removes those bytes from `buf`. A trailing
    /// partial line stays buffered until the next chunk arrives.
    ///
    /// Lines are split on raw bytes before decoding, so multi-byte UTF-8
    /// characters that the server splits across packets stay intact (`\n`
    /// never occurs inside a multi-byte UTF-8 sequence).
    ///
    /// Returns `false` if the receiver has been dropped.
    async fn forward_complete_lines(
        buf: &mut Vec<u8>,
        prefix: &str,
        tx: &mpsc::Sender<String>,
    ) -> bool {
        let Some(last_newline) = buf.iter().rposition(|&b| b == b'\n') else {
            return true;
        };
        let remainder = buf.split_off(last_newline + 1);
        let complete = std::mem::replace(buf, remainder);
        for line in complete[..last_newline].split(|&b| b == b'\n') {
            let text = format!("{}{}", prefix, Self::decode_output_line(line));
            if tx.send(text).await.is_err() {
                return false;
            }
        }
        true
    }

    /// Runs `command` and streams its output line by line from a spawned task.
    ///
    /// Stdout lines are sent as-is and stderr lines are prefixed with `!`.
    /// Trailing `\r`/`\n` is removed, and any unterminated final line is
    /// flushed when the remote side sends EOF or closes. The exit status is not
    /// reported. Cancelling the returned token sends EOF and closes the channel.
    /// Dropping the receiver has the same effect once the task next tries to
    /// forward a line.
    pub async fn execute_command_streaming(
        &self,
        command: &str,
    ) -> Result<(mpsc::Receiver<String>, CancellationToken)> {
        let Some(session) = &self.session else {
            return Err(anyhow::anyhow!("Not connected"));
        };
        let mut channel = session.channel_open_session().await?;
        channel.exec(true, command).await?;
        let (tx, rx) = mpsc::channel::<String>(256);
        let cancel = CancellationToken::new();
        let cancel_task = cancel.clone();
        tokio::spawn(async move {
            // Raw bytes are buffered (not decoded per packet) so that UTF-8
            // sequences split across packets decode correctly.
            let mut stdout_buf: Vec<u8> = Vec::new();
            let mut stderr_buf: Vec<u8> = Vec::new();
            loop {
                tokio::select! {
                    // Check cancellation first, so a cancelled stream stops
                    // promptly instead of racing further channel reads.
                    biased;
                    _ = cancel_task.cancelled() => {
                        let _ = channel.eof().await;
                        let _ = channel.close().await;
                        break;
                    }
                    msg = channel.wait() => {
                        match msg {
                            Some(ChannelMsg::Data { ref data }) => {
                                stdout_buf.extend_from_slice(data);
                                if !Self::forward_complete_lines(&mut stdout_buf, "", &tx).await {
                                    cancel_task.cancel();
                                }
                            }
                            Some(ChannelMsg::ExtendedData { ref data, .. }) => {
                                stderr_buf.extend_from_slice(data);
                                if !Self::forward_complete_lines(&mut stderr_buf, "!", &tx).await {
                                    cancel_task.cancel();
                                }
                            }
                            Some(ChannelMsg::Eof) | Some(ChannelMsg::Close) | None => {
                                if !stdout_buf.is_empty() {
                                    let _ = tx.send(Self::decode_output_line(&stdout_buf)).await;
                                }
                                if !stderr_buf.is_empty() {
                                    let _ = tx
                                        .send(format!("!{}", Self::decode_output_line(&stderr_buf)))
                                        .await;
                                }
                                break;
                            }
                            _ => {}
                        }
                    }
                }
            }
        });
        Ok((rx, cancel))
    }

    /// Drops the cached SFTP session and closes the connection.
    ///
    /// A protocol-level disconnect is sent only when this client holds the last
    /// reference to the session handle; otherwise the handle is just dropped.
    /// Disconnect errors are logged, not returned, so this always returns `Ok`.
    pub async fn disconnect(&mut self) -> Result<()> {
        self.sftp.take();
        if let Some(session) = self.session.take() {
            match Arc::try_unwrap(session) {
                Ok(session) => {
                    if let Err(e) = session.disconnect(Disconnect::ByApplication, "", "").await {
                        tracing::warn!("SSH disconnect failed cleanly: {}", e);
                    }
                }
                Err(arc_session) => {
                    tracing::debug!("SSH disconnect: other refs still alive, dropping handle");
                    drop(arc_session);
                }
            }
        }
        Ok(())
    }

    /// Opens a `direct-tcpip` channel (local port forwarding) to `host:port` as
    /// reached from the server. The originator is reported as `127.0.0.1:0`.
    pub async fn open_direct_tcpip(
        &self,
        host: &str,
        port: u16,
    ) -> Result<russh::Channel<russh::client::Msg>> {
        let Some(session) = &self.session else {
            return Err(anyhow::anyhow!("Not connected"));
        };
        let channel = session
            .channel_open_direct_tcpip(host.to_string(), port as u32, "127.0.0.1", 0)
            .await?;
        Ok(channel)
    }

    /// Opens an interactive shell on an `xterm-256color` PTY of `cols` x `rows`
    /// and spawns two background tasks:
    /// * an input task that writes each chunk from `input_tx` to the channel
    ///   and flushes it;
    /// * an output task that forwards channel data (stdout and stderr) to
    ///   `output_rx` and applies resize requests from `resize_tx`.
    ///
    /// Cancelling `PtySession::cancel` stops both tasks.
    pub async fn create_pty_session(&self, cols: u32, rows: u32) -> Result<PtySession> {
        let Some(session) = &self.session else {
            return Err(anyhow::anyhow!("Not connected"));
        };
        let mut channel = session.channel_open_session().await?;
        // Pixel dimensions are left at 0 (unspecified); no terminal modes.
        channel
            .request_pty(true, "xterm-256color", cols, rows, 0, 0, &[])
            .await?;
        channel.request_shell(true).await?;
        let (input_tx, mut input_rx) = mpsc::channel::<Vec<u8>>(1000);
        let (output_tx, output_rx) = mpsc::channel::<Vec<u8>>(2000);
        let input_channel = channel.make_writer();
        let (resize_tx, mut resize_rx) = mpsc::channel::<(u32, u32)>(16);
        let cancel = CancellationToken::new();

        let input_cancel = cancel.clone();
        tokio::spawn(async move {
            let mut writer = input_channel;
            loop {
                tokio::select! {
                    biased;
                    _ = input_cancel.cancelled() => {
                        tracing::debug!("[PTY] input task cancelled");
                        break;
                    }
                    maybe_data = input_rx.recv() => {
                        let Some(data) = maybe_data else {
                            break;
                        };
                        if let Err(e) = writer.write_all(&data).await {
                            tracing::error!("[PTY] failed to send data to SSH: {}", e);
                            break;
                        }
                        if let Err(e) = writer.flush().await {
                            tracing::error!("[PTY] failed to flush data to SSH: {}", e);
                            break;
                        }
                    }
                }
            }
        });

        let output_cancel = cancel.clone();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    biased;
                    _ = output_cancel.cancelled() => {
                        tracing::debug!("[PTY] output task cancelled");
                        break;
                    }
                    msg = channel.wait() => {
                        match msg {
                            // A PTY merges the remote streams anyway, so stdout and
                            // extended data both go to the same output queue.
                            Some(ChannelMsg::Data { data })
                            | Some(ChannelMsg::ExtendedData { data, .. }) => {
                                if output_tx.send(data.to_vec()).await.is_err() {
                                    break;
                                }
                            }
                            Some(ChannelMsg::Eof) | Some(ChannelMsg::Close) | None => {
                                tracing::debug!("[PTY] channel closed");
                                break;
                            }
                            Some(ChannelMsg::ExitStatus { exit_status }) => {
                                tracing::info!("[PTY] process exited with status: {}", exit_status);
                            }
                            _ => {}
                        }
                    }
                    resize = resize_rx.recv() => {
                        match resize {
                            Some((cols, rows)) => {
                                if let Err(e) = channel.window_change(cols, rows, 0, 0).await {
                                    tracing::error!("[PTY] failed to send window change: {}", e);
                                } else {
                                    tracing::debug!("[PTY] window changed to {}x{}", cols, rows);
                                }
                            }
                            // The resize sender lives in `PtySession`; once it is
                            // dropped, the session is gone and this task ends.
                            None => {
                                break;
                            }
                        }
                    }
                }
            }
        });

        Ok(PtySession {
            input_tx,
            output_rx: Arc::new(tokio::sync::Mutex::new(output_rx)),
            resize_tx,
            cancel,
        })
    }

    /// Lists the entries of the remote directory `path` over SFTP.
    ///
    /// `.` and `..` are skipped. Missing attributes become size 0 or `None`.
    /// Owner and group are the numeric uid/gid as strings, not names. Entries
    /// are ordered directories first, then by case-insensitive name.
    pub async fn list_dir(&self, path: &str) -> Result<Vec<RemoteFileEntry>> {
        let sftp = self.sftp_session().await?;
        let entries = sftp
            .read_dir(path)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to list directory '{}': {}", path, e))?;
        let mut result = Vec::new();
        for entry in entries {
            let name = entry.file_name();
            if name == "." || name == ".." {
                continue;
            }
            let attrs = entry.metadata();
            let size = attrs.size.unwrap_or(0);
            let mtime_secs = attrs.mtime.map(|t| t as i64);
            let modified = mtime_secs.map(format_unix_timestamp);
            let permissions = attrs.permissions.map(format_permissions);
            let owner = attrs.uid.map(|u| u.to_string());
            let group = attrs.gid.map(|g| g.to_string());
            let file_type = if attrs.is_dir() {
                FileEntryType::Directory
            } else if attrs.is_symlink() {
                FileEntryType::Symlink
            } else {
                FileEntryType::File
            };
            result.push(RemoteFileEntry {
                name,
                size,
                modified,
                modified_unix: mtime_secs,
                permissions,
                owner,
                group,
                file_type,
            });
        }
        // Directories first (`false` sorts before `true`), then case-insensitive
        // name. The lowercase key is computed once per entry instead of twice
        // per comparison. The sort is stable, so names that are equal ignoring
        // case keep the server's order.
        result.sort_by_cached_key(|entry| {
            (
                !matches!(entry.file_type, FileEntryType::Directory),
                entry.name.to_lowercase(),
            )
        });
        Ok(result)
    }

    /// Downloads `remote_path` to `local_path` (created or truncated) and
    /// returns the number of bytes copied.
    pub async fn download_file(&self, remote_path: &str, local_path: &str) -> Result<u64> {
        self.download_file_with_progress(remote_path, local_path, |_| {}, None)
            .await
    }

    /// Downloads `remote_path` to `local_path` in `SFTP_CHUNK_SIZE` reads.
    ///
    /// `progress` receives the cumulative byte count after each chunk. `cancel`
    /// is checked before each read. On cancellation or error, the partially
    /// written local file is left in place.
    pub async fn download_file_with_progress(
        &self,
        remote_path: &str,
        local_path: &str,
        mut progress: impl FnMut(u64),
        cancel: Option<&CancellationToken>,
    ) -> Result<u64> {
        let sftp = self.sftp_session().await?;
        let mut remote_file = sftp.open(remote_path).await?;
        let mut local_file = tokio::fs::File::create(local_path).await?;
        let mut buf = vec![0u8; SFTP_CHUNK_SIZE];
        let mut total_bytes = 0u64;
        loop {
            if let Some(token) = cancel
                && token.is_cancelled()
            {
                return Err(anyhow::anyhow!("Transfer cancelled"));
            }
            let n = remote_file.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            local_file.write_all(&buf[..n]).await?;
            total_bytes += n as u64;
            progress(total_bytes);
        }
        local_file.flush().await?;
        Ok(total_bytes)
    }

    /// Reads the whole of `remote_path` into memory.
    pub async fn download_file_to_memory(&self, remote_path: &str) -> Result<Vec<u8>> {
        let sftp = self.sftp_session().await?;
        let mut remote_file = sftp.open(remote_path).await?;
        let mut buffer = Vec::new();
        let mut temp_buf = vec![0u8; SFTP_CHUNK_SIZE];
        loop {
            let n = remote_file.read(&mut temp_buf).await?;
            if n == 0 {
                break;
            }
            buffer.extend_from_slice(&temp_buf[..n]);
        }
        Ok(buffer)
    }

    /// Uploads `local_path` to `remote_path` and returns the number of bytes
    /// copied.
    pub async fn upload_file(&self, local_path: &str, remote_path: &str) -> Result<u64> {
        self.upload_file_with_progress(local_path, remote_path, |_| {}, None)
            .await
    }

    /// Uploads `local_path` to `remote_path` in `SFTP_CHUNK_SIZE` writes. The
    /// remote file is opened with `sftp.create`, which presumably truncates an
    /// existing file.
    ///
    /// `progress` receives the cumulative byte count after each chunk, and
    /// `cancel` is checked before each read.
    pub async fn upload_file_with_progress(
        &self,
        local_path: &str,
        remote_path: &str,
        mut progress: impl FnMut(u64),
        cancel: Option<&CancellationToken>,
    ) -> Result<u64> {
        let sftp = self.sftp_session().await?;
        let mut local_file = tokio::fs::File::open(local_path).await?;
        let mut remote_file = sftp.create(remote_path).await?;
        let mut buf = vec![0u8; SFTP_CHUNK_SIZE];
        let mut total_bytes = 0u64;
        loop {
            if let Some(token) = cancel
                && token.is_cancelled()
            {
                return Err(anyhow::anyhow!("Transfer cancelled"));
            }
            let n = local_file.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            remote_file.write_all(&buf[..n]).await?;
            total_bytes += n as u64;
            progress(total_bytes);
        }
        remote_file.flush().await?;
        Ok(total_bytes)
    }

    /// Writes `data` to `remote_path` in `SFTP_CHUNK_SIZE` pieces and returns
    /// `data.len()`.
    pub async fn upload_file_from_bytes(&self, data: &[u8], remote_path: &str) -> Result<u64> {
        let sftp = self.sftp_session().await?;
        let mut remote_file = sftp.create(remote_path).await?;
        for chunk in data.chunks(SFTP_CHUNK_SIZE) {
            remote_file.write_all(chunk).await?;
        }
        remote_file.flush().await?;
        Ok(data.len() as u64)
    }

    /// Creates the remote directory `path`.
    pub async fn create_dir(&self, path: &str) -> Result<()> {
        let sftp = self.sftp_session().await?;
        sftp.create_dir(path)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to create directory '{}': {}", path, e))?;
        Ok(())
    }

    /// Renames (moves) the remote path `old_path` to `new_path`.
    pub async fn rename(&self, old_path: &str, new_path: &str) -> Result<()> {
        let sftp = self.sftp_session().await?;
        sftp.rename(old_path, new_path).await.map_err(|e| {
            anyhow::anyhow!("Failed to rename '{}' to '{}': {}", old_path, new_path, e)
        })?;
        Ok(())
    }

    /// Deletes the remote file `path`.
    pub async fn delete_file(&self, path: &str) -> Result<()> {
        let sftp = self.sftp_session().await?;
        sftp.remove_file(path)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to delete file '{}': {}", path, e))?;
        Ok(())
    }

    /// Deletes the remote directory `path`. SFTP servers typically require it
    /// to be empty.
    pub async fn delete_dir(&self, path: &str) -> Result<()> {
        let sftp = self.sftp_session().await?;
        sftp.remove_dir(path)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to delete directory '{}': {}", path, e))?;
        Ok(())
    }
}
#[cfg(test)]
mod expand_home_tests {
    use super::expand_home_path;
    use std::path::Path;

    /// Paths without a leading `~` / `~/` are returned unchanged. This
    /// includes the unsupported `~user` form and tildes later in the path.
    #[test]
    fn returns_non_tilde_paths_unchanged() {
        assert_eq!(
            expand_home_path("/absolute/path").as_deref(),
            Some("/absolute/path")
        );
        assert_eq!(
            expand_home_path("relative/dir").as_deref(),
            Some("relative/dir")
        );
        assert_eq!(expand_home_path("").as_deref(), Some(""));
        assert_eq!(
            expand_home_path("~other/.ssh/id_rsa").as_deref(),
            Some("~other/.ssh/id_rsa")
        );
        assert_eq!(
            expand_home_path("/tmp/~/key").as_deref(),
            Some("/tmp/~/key")
        );
    }

    /// `~/...` is expanded when a home directory is known. The suffix is
    /// compared component-wise rather than as a `/`-joined string, so the test
    /// also holds on Windows, where the home part uses `\` separators.
    #[test]
    fn expands_tilde_slash_prefix_when_home_is_known() {
        if let Some(result) = expand_home_path("~/.ssh/id_rsa") {
            assert!(
                !result.starts_with("~/"),
                "tilde must be expanded: {}",
                result
            );
            assert!(
                Path::new(&result).ends_with(".ssh/id_rsa"),
                "suffix preserved: {}",
                result
            );
        }
    }

    /// A bare `~` expands to exactly the home directory (no trailing
    /// separator), or `None` when the home directory is unknown.
    #[test]
    fn expands_bare_tilde_to_home_dir() {
        assert_eq!(
            expand_home_path("~"),
            dirs::home_dir().map(|home| home.to_string_lossy().into_owned())
        );
    }
}
#[cfg(test)]
mod command_output_tests {
    use super::CommandOutput;

    /// Builds a `CommandOutput` fixture from string slices.
    fn output(stdout: &str, stderr: &str, exit_code: Option<u32>) -> CommandOutput {
        CommandOutput {
            stdout: stdout.to_string(),
            stderr: stderr.to_string(),
            exit_code,
        }
    }

    /// Only an explicit zero exit status counts as success.
    #[test]
    fn is_success_requires_zero_exit() {
        assert!(output("x", "", Some(0)).is_success());
        assert!(!output("x", "", Some(1)).is_success());
        assert!(!output("x", "", None).is_success());
    }

    /// Non-empty stdout and stderr are joined with a newline.
    #[test]
    fn combined_merges_streams_with_separator() {
        assert_eq!(output("out", "err", Some(0)).combined(), "out\nerr");
    }

    /// No extra newline is inserted when stdout already ends with one.
    #[test]
    fn combined_preserves_trailing_newline() {
        assert_eq!(output("out\n", "err", Some(0)).combined(), "out\nerr");
    }

    /// When one stream is empty, the other is returned verbatim.
    #[test]
    fn combined_returns_single_stream_when_other_empty() {
        assert_eq!(output("only", "", Some(0)).combined(), "only");
        assert_eq!(output("", "only-err", Some(1)).combined(), "only-err");
    }
}
#[cfg(test)]
mod redaction_tests {
    use super::{AuthMethod, SshConfig};

    /// Renders `value` with its `Debug` implementation.
    fn debug_of<T: std::fmt::Debug>(value: &T) -> String {
        format!("{:?}", value)
    }

    /// A password nested inside `SshConfig` never reaches debug output.
    #[test]
    fn debug_redacts_password() {
        let secret = "super-secret-123";
        let rendered = debug_of(&SshConfig {
            host: "h".into(),
            port: 22,
            username: "u".into(),
            auth_method: AuthMethod::Password {
                password: secret.into(),
            },
        });
        assert!(
            !rendered.contains(secret),
            "password must not appear in Debug output: {}",
            rendered
        );
        assert!(rendered.contains("<redacted>"), "expected redaction marker");
    }

    /// A key passphrase is redacted while the key path stays visible.
    #[test]
    fn debug_redacts_passphrase() {
        let rendered = debug_of(&AuthMethod::PublicKey {
            key_path: "/tmp/id".into(),
            passphrase: Some("xyz-passphrase".into()),
        });
        assert!(!rendered.contains("xyz-passphrase"));
        assert!(rendered.contains("<redacted>"));
        assert!(rendered.contains("/tmp/id"));
    }

    /// An absent passphrase is shown as `<none>`.
    #[test]
    fn debug_shows_none_when_no_passphrase() {
        let rendered = debug_of(&AuthMethod::PublicKey {
            key_path: "/tmp/id".into(),
            passphrase: None,
        });
        assert!(rendered.contains("<none>"));
    }
}
// Checks that `load_private_key` reads and parses an unencrypted OpenSSH
// private key from a file on disk.
#[cfg(test)]
mod key_loading_tests {
use super::load_private_key;
use std::io::Write;
use tempfile::NamedTempFile;
// Test-only Ed25519 key in OpenSSH format; its header declares cipher and KDF
// "none", i.e. the key is unencrypted, so it loads with no passphrase.
const TEST_OPENSSH_PRIVATE_KEY: &str = "\
-----BEGIN OPENSSH PRIVATE KEY-----\n\
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\n\
QyNTUxOQAAACCzPq7zfqLffKoBDe/eo04kH2XxtSmk9D7RQyf1xUqrYgAAAJgAIAxdACAM\n\
XQAAAAtzc2gtZWQyNTUxOQAAACCzPq7zfqLffKoBDe/eo04kH2XxtSmk9D7RQyf1xUqrYg\n\
AAAEC2BsIi0QwW2uFscKTUUXNHLsYX4FxlaSDSblbAj7WR7bM+rvN+ot98qgEN796jTiQf\n\
ZfG1KaT0PtFDJ/XFSqtiAAAAEHVzZXJAZXhhbXBsZS5jb20BAgMEBQ==\n\
-----END OPENSSH PRIVATE KEY-----\n";
// Writes the fixture to a temporary file, loads it without a passphrase and
// checks that the result is recognised as an Ed25519 key.
#[test]
fn load_private_key_reads_key_file_contents() {
let mut key_file = NamedTempFile::new().expect("failed to create temp key file");
key_file
.write_all(TEST_OPENSSH_PRIVATE_KEY.as_bytes())
.expect("failed to write temp key file");
let key = load_private_key(
key_file
.path()
.to_str()
.expect("temp key path must be valid UTF-8"),
None,
)
.expect("expected key file to load successfully");
assert_eq!(key.name(), "ssh-ed25519");
}
}
#[cfg(test)]
mod tests;