use std::sync::Arc;
use std::time::{Duration, Instant};
use russh::ChannelMsg;
use russh::keys::PrivateKeyWithHashAlg;
use tokio::time;
use tracing::{debug, warn};
use crate::error::AgentError;
use crate::provider::{AgentConfig, AgentProvider, InvokeFuture};
use super::common::{self, DEFAULT_TIMEOUT};
/// Credentials used to authenticate the SSH session (consumed by `SshProvider::authenticate`).
///
/// NOTE(review): this type derives only `Clone`, not `Debug`, so the secrets it holds cannot
/// be printed with `{:?}`. Presumably deliberate — keep it that way or write a redacting impl.
#[derive(Clone)]
enum SshAuth {
    /// Password authentication (`authenticate_password`).
    Password(String),
    /// Public-key authentication (`authenticate_publickey`).
    PrivateKey {
        /// Private key text, decoded with `russh::keys::decode_secret_key`
        /// (presumably PEM / OpenSSH format — confirm against the russh version in use).
        key_data: String,
        /// Passphrase used to decrypt `key_data`; `None` for an unencrypted key.
        passphrase: Option<String>,
    },
}
/// How the SSH client treats the host key presented by the server.
///
/// The key itself is never inspected: there is no known-hosts or fingerprint check, only an
/// all-or-nothing decision.
#[derive(Clone, Default, Debug)]
pub enum HostKeyPolicy {
    /// Accept any server key without verification. This is the default.
    ///
    /// It gives no protection against man-in-the-middle attacks; only use it on networks
    /// where the remote host's identity is otherwise trusted.
    #[default]
    AcceptAll,
    /// Reject every server key, so the handshake is refused and no connection is made.
    RejectAll,
}
/// `russh` client handler whose only job is to apply a [`HostKeyPolicy`] when the server
/// presents its host key.
struct SshHandler {
    /// Decision applied in `check_server_key`.
    policy: HostKeyPolicy,
}
impl russh::client::Handler for SshHandler {
    type Error = russh::Error;

    /// Decides whether to trust the server's host key.
    ///
    /// The verdict depends only on the configured [`HostKeyPolicy`]; the key is ignored.
    /// Acceptance is logged at debug level and rejection at warn level.
    async fn check_server_key(
        &mut self,
        _server_public_key: &russh::keys::PublicKey,
    ) -> Result<bool, Self::Error> {
        let trusted = match self.policy {
            HostKeyPolicy::RejectAll => {
                warn!("rejecting SSH server key (HostKeyPolicy::RejectAll)");
                false
            }
            HostKeyPolicy::AcceptAll => {
                debug!("accepting SSH server key without verification (HostKeyPolicy::AcceptAll)");
                true
            }
        };
        Ok(trusted)
    }
}
/// [`AgentProvider`] that runs the Claude CLI on a remote machine over SSH.
///
/// Construct with [`SshProvider::new`] and configure through the builder methods. An
/// authentication method must be set before invoking.
#[derive(Clone)]
pub struct SshProvider {
    /// Remote host name or address passed to `russh::client::connect`.
    host: String,
    /// Remote SSH port (22 unless changed with `port`).
    port: u16,
    /// User name to authenticate as.
    username: String,
    /// Configured credentials; `None` until a credential builder method is called.
    auth: Option<SshAuth>,
    /// Path or command name of the Claude CLI on the remote host.
    claude_path: String,
    /// Default remote working directory; a directory in `AgentConfig` overrides it.
    working_dir: Option<String>,
    /// How long to wait for the remote command's output once it is running.
    timeout: Duration,
    /// Host-key verification policy applied during the handshake.
    host_key_policy: HostKeyPolicy,
}
impl SshProvider {
    /// Creates a provider that connects to `host` as `username`.
    ///
    /// Defaults:
    /// - port 22
    /// - no authentication method
    /// - remote executable `claude`
    /// - no working directory
    /// - [`DEFAULT_TIMEOUT`]
    /// - the default [`HostKeyPolicy`] (`AcceptAll`)
    ///
    /// Call [`Self::password`], [`Self::private_key`] or
    /// [`Self::private_key_with_passphrase`] before invoking; otherwise authentication
    /// fails with an error.
    pub fn new(host: &str, username: &str) -> Self {
        Self {
            host: host.to_string(),
            port: 22,
            username: username.to_string(),
            auth: None,
            claude_path: "claude".to_string(),
            working_dir: None,
            timeout: DEFAULT_TIMEOUT,
            host_key_policy: HostKeyPolicy::default(),
        }
    }

    /// Sets the remote SSH port.
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Authenticates with a password, replacing any previously configured credentials.
    pub fn password(mut self, password: &str) -> Self {
        self.auth = Some(SshAuth::Password(password.to_string()));
        self
    }

    /// Authenticates with an unencrypted private key (`key_data` is the key text),
    /// replacing any previously configured credentials.
    pub fn private_key(mut self, key_data: &str) -> Self {
        self.auth = Some(SshAuth::PrivateKey {
            key_data: key_data.to_string(),
            passphrase: None,
        });
        self
    }

    /// Authenticates with a passphrase-protected private key, replacing any previously
    /// configured credentials.
    pub fn private_key_with_passphrase(mut self, key_data: &str, passphrase: &str) -> Self {
        self.auth = Some(SshAuth::PrivateKey {
            key_data: key_data.to_string(),
            passphrase: Some(passphrase.to_string()),
        });
        self
    }

    /// Sets the path (or command name) of the Claude CLI on the remote host.
    pub fn claude_path(mut self, path: &str) -> Self {
        self.claude_path = path.to_string();
        self
    }

    /// Sets the default remote working directory.
    ///
    /// A working directory in the per-invocation `AgentConfig` takes precedence.
    pub fn working_dir(mut self, dir: &str) -> Self {
        self.working_dir = Some(dir.to_string());
        self
    }

    /// Sets how long to wait for the remote command's output once it is running.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets how the server's host key is verified; see [`HostKeyPolicy`].
    pub fn host_key_policy(mut self, policy: HostKeyPolicy) -> Self {
        self.host_key_policy = policy;
        self
    }

    /// Authenticates `session` as `self.username` using the configured credentials.
    ///
    /// For private keys the RSA signature hash is negotiated with the server, so RSA keys
    /// use `rsa-sha2-*` when the server advertises it instead of legacy SHA-1 `ssh-rsa`.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ProcessFailed`] with `exit_code: -1` in any of these cases:
    /// - no credentials are configured;
    /// - the private key cannot be decoded;
    /// - the hash negotiation or the authentication request fails;
    /// - the server rejects the credentials.
    async fn authenticate(
        &self,
        session: &mut russh::client::Handle<SshHandler>,
    ) -> Result<(), AgentError> {
        let auth = self
            .auth
            .as_ref()
            .ok_or_else(|| AgentError::ProcessFailed {
                exit_code: -1,
                stderr:
                    "no SSH authentication method configured - call .password() or .private_key()"
                        .to_string(),
            })?;
        let authenticated = match auth {
            SshAuth::Password(pw) => session
                .authenticate_password(&self.username, pw)
                .await
                .map_err(|e| AgentError::ProcessFailed {
                    exit_code: -1,
                    stderr: format!("SSH password auth failed: {e}"),
                })?
                .success(),
            SshAuth::PrivateKey {
                key_data,
                passphrase,
            } => {
                let key = russh::keys::decode_secret_key(key_data, passphrase.as_deref()).map_err(
                    |e| AgentError::ProcessFailed {
                        exit_code: -1,
                        stderr: format!("failed to parse SSH private key: {e}"),
                    },
                )?;
                // Passing `None` here would make RSA keys sign with SHA-1 (`ssh-rsa`), which
                // OpenSSH 8.8+ rejects by default. Ask the server (via `server-sig-algs`)
                // which RSA hash it accepts. When it advertises nothing, `flatten()` yields
                // `None`, i.e. the previous behaviour. The value is ignored for non-RSA keys.
                let hash_alg = session
                    .best_supported_rsa_hash()
                    .await
                    .map_err(|e| AgentError::ProcessFailed {
                        exit_code: -1,
                        stderr: format!("failed to negotiate SSH RSA signature hash: {e}"),
                    })?
                    .flatten();
                let key_with_alg = PrivateKeyWithHashAlg::new(Arc::new(key), hash_alg);
                session
                    .authenticate_publickey(&self.username, key_with_alg)
                    .await
                    .map_err(|e| AgentError::ProcessFailed {
                        exit_code: -1,
                        stderr: format!("SSH public key auth failed: {e}"),
                    })?
                    .success()
            }
        };
        if !authenticated {
            return Err(AgentError::ProcessFailed {
                exit_code: -1,
                stderr: "SSH authentication rejected by server".to_string(),
            });
        }
        Ok(())
    }
}
impl AgentProvider for SshProvider {
fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
Box::pin(async move {
common::validate_prompt_size(config)?;
let args = common::build_args(config)?;
let claude_cmd = common::build_shell_command(&self.claude_path, &args);
let env_prefix = common::env_unset_shell_prefix();
let remote_cmd = match (&self.working_dir, &config.working_dir) {
(_, Some(dir)) | (Some(dir), None) => {
format!(
"{env_prefix}cd {} && {}",
common::build_shell_command(dir, &[]),
claude_cmd
)
}
(None, None) => format!("{env_prefix}{claude_cmd}"),
};
debug!(
host = %self.host,
port = self.port,
username = %self.username,
model = %config.model,
"connecting via SSH"
);
let start = Instant::now();
let ssh_config = Arc::new(russh::client::Config::default());
let handler = SshHandler {
policy: self.host_key_policy.clone(),
};
let mut session = time::timeout(
Duration::from_secs(30),
russh::client::connect(ssh_config, (&*self.host, self.port), handler),
)
.await
.map_err(|_| AgentError::Timeout {
limit: Duration::from_secs(30),
})?
.map_err(|e| AgentError::ProcessFailed {
exit_code: -1,
stderr: format!("SSH connection failed: {e}"),
})?;
self.authenticate(&mut session).await?;
let mut channel =
session
.channel_open_session()
.await
.map_err(|e| AgentError::ProcessFailed {
exit_code: -1,
stderr: format!("failed to open SSH session channel: {e}"),
})?;
debug!(
remote_cmd_len = remote_cmd.len(),
"executing remote command"
);
channel
.exec(true, remote_cmd.as_bytes())
.await
.map_err(|e| AgentError::ProcessFailed {
exit_code: -1,
stderr: format!("failed to exec remote command: {e}"),
})?;
channel.eof().await.map_err(|e| AgentError::ProcessFailed {
exit_code: -1,
stderr: format!("failed to send EOF on SSH channel: {e}"),
})?;
let mut stdout_buf = Vec::new();
let mut stderr_buf = Vec::new();
let mut exit_code: Option<u32> = None;
let collect_result = time::timeout(self.timeout, async {
loop {
let msg = channel.wait().await;
let Some(msg) = msg else { break };
match msg {
ChannelMsg::Data { ref data } => {
stdout_buf.extend_from_slice(data);
}
ChannelMsg::ExtendedData { ref data, ext } => {
if ext == 1 {
stderr_buf.extend_from_slice(data);
}
}
ChannelMsg::ExitStatus { exit_status } => {
exit_code = Some(exit_status);
}
_ => {}
}
}
})
.await;
let _ = session
.disconnect(russh::Disconnect::ByApplication, "", "")
.await;
if collect_result.is_err() {
warn!(timeout = ?self.timeout, "SSH command timed out");
return Err(AgentError::Timeout {
limit: self.timeout,
});
}
let duration_ms = start.elapsed().as_millis() as u64;
let code = exit_code.unwrap_or(1) as i32;
let stdout = String::from_utf8_lossy(&stdout_buf).to_string();
let stderr = String::from_utf8_lossy(&stderr_buf).to_string();
if code != 0 {
return common::handle_nonzero_exit(
code,
&stdout,
&stderr,
config,
duration_ms,
"ssh",
);
}
debug!(stdout_len = stdout.len(), "remote claude process completed");
common::parse_output(&stdout, config, duration_ms)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Arbitrary ed25519 public key in OpenSSH format, standing in for a server host key.
    const TEST_PUBKEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBOZMtGiPyW0pMN+JJuYjIGJfqyO5MHBsFkzseVSp60M test@example";

    /// Parses [`TEST_PUBKEY`] into a `russh` public key.
    fn test_public_key() -> russh::keys::PublicKey {
        russh::keys::PublicKey::from_openssh(TEST_PUBKEY).expect("parse test public key")
    }

    /// Asks a handler configured with `policy` whether it trusts the fixture key.
    async fn verdict_for(policy: HostKeyPolicy) -> bool {
        use russh::client::Handler;
        let server_key = test_public_key();
        let mut handler = SshHandler { policy };
        handler.check_server_key(&server_key).await.unwrap()
    }

    #[test]
    fn ssh_provider_defaults() {
        let p = SshProvider::new("host.example.com", "user");
        assert_eq!(p.host, "host.example.com");
        assert_eq!(p.username, "user");
        assert_eq!(p.port, 22);
        assert_eq!(p.claude_path, "claude");
        assert_eq!(p.timeout, DEFAULT_TIMEOUT);
        assert!(p.working_dir.is_none());
        assert!(p.auth.is_none());
        assert!(matches!(p.host_key_policy, HostKeyPolicy::AcceptAll));
    }

    #[test]
    fn ssh_provider_builder_chain() {
        let limit = Duration::from_secs(600);
        let p = SshProvider::new("host", "user")
            .port(2222)
            .password("pw")
            .claude_path("/usr/local/bin/claude")
            .working_dir("/opt/project")
            .timeout(limit);
        assert_eq!(p.port, 2222);
        assert_eq!(p.claude_path, "/usr/local/bin/claude");
        assert_eq!(p.working_dir.as_deref(), Some("/opt/project"));
        assert_eq!(p.timeout, limit);
        assert!(matches!(p.auth, Some(SshAuth::Password(_))));
    }

    #[test]
    fn ssh_provider_private_key_auth() {
        let p = SshProvider::new("host", "user").private_key("-----BEGIN KEY-----");
        let passphrase_absent = matches!(
            p.auth,
            Some(SshAuth::PrivateKey {
                passphrase: None,
                ..
            })
        );
        assert!(passphrase_absent);
    }

    #[test]
    fn ssh_provider_private_key_with_passphrase() {
        let p = SshProvider::new("host", "user")
            .private_key_with_passphrase("-----BEGIN KEY-----", "secret");
        let passphrase_present = matches!(
            p.auth,
            Some(SshAuth::PrivateKey {
                passphrase: Some(_),
                ..
            })
        );
        assert!(passphrase_present);
    }

    #[test]
    fn ssh_provider_clone() {
        let original = SshProvider::new("host", "user").port(2222).password("pw");
        let copy = original.clone();
        assert_eq!((copy.host.as_str(), copy.port), ("host", 2222));
    }

    #[test]
    fn host_key_policy_default_is_accept_all() {
        assert!(matches!(HostKeyPolicy::default(), HostKeyPolicy::AcceptAll));
    }

    #[test]
    fn host_key_policy_builder_method() {
        let p = SshProvider::new("host", "user").host_key_policy(HostKeyPolicy::RejectAll);
        assert!(matches!(p.host_key_policy, HostKeyPolicy::RejectAll));
    }

    #[tokio::test]
    async fn host_key_policy_accept_all_returns_true() {
        assert!(verdict_for(HostKeyPolicy::AcceptAll).await);
    }

    #[tokio::test]
    async fn host_key_policy_reject_all_returns_false() {
        assert!(!verdict_for(HostKeyPolicy::RejectAll).await);
    }
}