use crate::jump::parser::JumpHost;
use crate::ssh::tokio_client::{AuthMethod, ClientHandler};
use anyhow::{Context, Result};
use std::path::Path;
use tokio::sync::Mutex;
use tracing::{debug, warn};
use zeroize::Zeroizing;
/// Upper bound on how long the SSH agent probe in `agent_has_identities`
/// (connect + identity listing) may take before the agent is treated as unavailable.
#[cfg(not(target_os = "windows"))]
const AGENT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
/// Reports whether the SSH agent reachable through the environment currently
/// holds at least one identity.
///
/// The connect + list round trip is bounded by `AGENT_TIMEOUT`. A timeout or a
/// communication error is logged at `warn` level and yields `false`; an agent
/// that answers with an empty identity list is logged at `debug` level and
/// also yields `false`.
#[cfg(not(target_os = "windows"))]
async fn agent_has_identities() -> bool {
    use russh::keys::agent::client::AgentClient;
    use tokio::time::timeout;

    // Connect and list identities as one future so a single timeout covers both steps.
    let query = async {
        let mut client = AgentClient::connect_env().await?;
        client.request_identities().await
    };

    let loaded = match timeout(AGENT_TIMEOUT, query).await {
        Err(_) => {
            warn!("SSH agent operation timed out after {:?}", AGENT_TIMEOUT);
            return false;
        }
        Ok(Err(e)) => {
            warn!("Failed to communicate with SSH agent: {e}");
            return false;
        }
        Ok(Ok(list)) => list,
    };

    if loaded.is_empty() {
        debug!("SSH agent is running but has no loaded identities");
        return false;
    }

    debug!("SSH agent has {} loaded identities", loaded.len());
    true
}
pub(super) async fn determine_auth_method(
jump_host: &JumpHost,
key_path: Option<&Path>,
use_agent: bool,
use_password: bool,
auth_mutex: &Mutex<()>,
) -> Result<AuthMethod> {
let effective_key_path = if let Some(ref jump_key) = jump_host.ssh_key {
use crate::config::{expand_env_vars, expand_tilde};
let expanded_path = expand_env_vars(jump_key);
let path = Path::new(&expanded_path);
let expanded_tilde = if expanded_path.starts_with('~') {
expand_tilde(path)
} else {
path.to_path_buf()
};
Some(expanded_tilde)
} else {
key_path.map(|p| p.to_path_buf())
};
#[cfg(not(target_os = "windows"))]
let agent_available = {
if let Ok(socket_path) = std::env::var("SSH_AUTH_SOCK") {
let path = std::path::Path::new(&socket_path);
if path.exists() {
agent_has_identities().await
} else {
debug!(
"SSH_AUTH_SOCK points to non-existent socket: {}, falling back to key files",
socket_path
);
false
}
} else {
false
}
};
#[cfg(target_os = "windows")]
let agent_available = false;
if use_password {
let _guard = auth_mutex.lock().await;
let prompt = format!(
"Enter password for jump host {} ({}@{}): ",
jump_host.to_connection_string(),
jump_host.effective_user(),
jump_host.host
);
let password = Zeroizing::new(
rpassword::prompt_password(prompt).with_context(|| "Failed to read password")?,
);
return Ok(AuthMethod::with_password(&password));
}
if use_agent && agent_available {
#[cfg(not(target_os = "windows"))]
{
return Ok(AuthMethod::Agent);
}
}
if let Some(key_path) = effective_key_path.as_deref() {
let key_contents = Zeroizing::new(
std::fs::read_to_string(key_path)
.with_context(|| format!("Failed to read SSH key file: {key_path:?}"))?,
);
let passphrase = if key_contents.contains("ENCRYPTED")
|| key_contents.contains("Proc-Type: 4,ENCRYPTED")
{
let _guard = auth_mutex.lock().await;
let prompt = format!(
"Enter passphrase for key {key_path:?} (jump host {}): ",
jump_host.to_connection_string()
);
let pass = Zeroizing::new(
rpassword::prompt_password(prompt).with_context(|| "Failed to read passphrase")?,
);
Some(pass)
} else {
None
};
return Ok(AuthMethod::with_key_file(
key_path,
passphrase.as_ref().map(|p| p.as_str()),
));
}
#[cfg(not(target_os = "windows"))]
if agent_available {
return Ok(AuthMethod::Agent);
}
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
let home_path = Path::new(&home).join(".ssh");
let default_keys = [
home_path.join("id_ed25519"),
home_path.join("id_rsa"),
home_path.join("id_ecdsa"),
home_path.join("id_dsa"),
];
for default_key in &default_keys {
if default_key.exists() {
let key_contents = Zeroizing::new(
std::fs::read_to_string(default_key)
.with_context(|| format!("Failed to read SSH key file: {default_key:?}"))?,
);
let passphrase = if key_contents.contains("ENCRYPTED")
|| key_contents.contains("Proc-Type: 4,ENCRYPTED")
{
let _guard = auth_mutex.lock().await;
let prompt = format!(
"Enter passphrase for key {default_key:?} (jump host {}): ",
jump_host.to_connection_string()
);
let pass = Zeroizing::new(
rpassword::prompt_password(prompt)
.with_context(|| "Failed to read passphrase")?,
);
Some(pass)
} else {
None
};
return Ok(AuthMethod::with_key_file(
default_key,
passphrase.as_ref().map(|p| p.as_str()),
));
}
}
anyhow::bail!(
"No authentication method available for jump host '{}' (user: {}). \
Please specify -i <key_file> or ensure SSH agent is running with loaded keys.",
jump_host.to_connection_string(),
jump_host.effective_user()
)
}
pub(super) async fn authenticate_connection(
handle: &mut russh::client::Handle<ClientHandler>,
username: &str,
auth_method: AuthMethod,
host_description: &str,
) -> Result<()> {
use crate::ssh::tokio_client::AuthMethod;
debug!(
"Authenticating to {} as user '{}' using {:?}",
host_description,
username,
match &auth_method {
AuthMethod::Password(_) => "password".to_string(),
AuthMethod::PrivateKey { .. } => "private key".to_string(),
AuthMethod::PrivateKeyFile { key_file_path, .. } =>
format!("key file {:?}", key_file_path),
#[cfg(not(target_os = "windows"))]
AuthMethod::Agent => "SSH agent".to_string(),
#[allow(unreachable_patterns)]
_ => "unknown".to_string(),
}
);
match auth_method {
AuthMethod::Password(password) => {
let auth_result = handle
.authenticate_password(username, &**password)
.await
.map_err(|e| {
anyhow::anyhow!(
"Password authentication failed for {} (user: {}): {}",
host_description,
username,
e
)
})?;
if !auth_result.success() {
anyhow::bail!(
"Password authentication rejected by {} for user '{}'. \
Please check the password is correct.",
host_description,
username
);
}
}
AuthMethod::PrivateKey { key_data, key_pass } => {
let private_key =
russh::keys::decode_secret_key(&key_data, key_pass.as_ref().map(|p| &***p))
.map_err(|e| {
anyhow::anyhow!(
"Failed to decode private key for {}: {}",
host_description,
e
)
})?;
let auth_result = handle
.authenticate_publickey(
username,
russh::keys::PrivateKeyWithHashAlg::new(
std::sync::Arc::new(private_key),
handle.best_supported_rsa_hash().await?.flatten(),
),
)
.await
.map_err(|e| {
anyhow::anyhow!(
"Private key authentication failed for {} (user: {}): {}",
host_description,
username,
e
)
})?;
if !auth_result.success() {
anyhow::bail!(
"Private key authentication rejected by {} for user '{}'. \
The key may not be authorized on this host.",
host_description,
username
);
}
}
AuthMethod::PrivateKeyFile {
key_file_path,
key_pass,
} => {
let private_key =
russh::keys::load_secret_key(&key_file_path, key_pass.as_ref().map(|p| &***p))
.map_err(|e| {
anyhow::anyhow!(
"Failed to load private key {:?} for {}: {}",
key_file_path,
host_description,
e
)
})?;
let auth_result = handle
.authenticate_publickey(
username,
russh::keys::PrivateKeyWithHashAlg::new(
std::sync::Arc::new(private_key),
handle.best_supported_rsa_hash().await?.flatten(),
),
)
.await
.map_err(|e| {
anyhow::anyhow!(
"Private key file authentication failed for {} (user: {}, key: {:?}): {}",
host_description,
username,
key_file_path,
e
)
})?;
if !auth_result.success() {
anyhow::bail!(
"Private key file authentication rejected by {} for user '{}' (key: {:?}). \
The key may not be authorized on this host.",
host_description,
username,
key_file_path
);
}
}
#[cfg(not(target_os = "windows"))]
AuthMethod::Agent => {
let mut agent = russh::keys::agent::client::AgentClient::connect_env()
.await
.map_err(|e| {
anyhow::anyhow!(
"Failed to connect to SSH agent for {}: {}. \
Check that SSH_AUTH_SOCK is set and the agent is running.",
host_description,
e
)
})?;
let identities = agent.request_identities().await.map_err(|e| {
anyhow::anyhow!(
"Failed to get identities from SSH agent for {}: {}",
host_description,
e
)
})?;
if identities.is_empty() {
anyhow::bail!(
"SSH agent has no loaded keys for {} authentication. \
Please add keys using 'ssh-add' or specify -i <key_file>.",
host_description
);
}
let mut auth_success = false;
let identity_count = identities.len();
for identity in identities {
let result = handle
.authenticate_publickey_with(
username,
identity.public_key().into_owned(),
handle.best_supported_rsa_hash().await?.flatten(),
&mut agent,
)
.await;
if let Ok(auth_result) = result {
if auth_result.success() {
auth_success = true;
break;
}
}
}
if !auth_success {
anyhow::bail!(
"SSH agent authentication rejected by {} for user '{}'. \
Tried {} key(s) from agent. None were authorized on this host.",
host_description,
username,
identity_count
);
}
}
_ => {
anyhow::bail!("Unsupported authentication method for {}", host_description);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::{OsStr, OsString};
    use tempfile::TempDir;

    // Every test that reads or writes `SSH_AUTH_SOCK` / `HOME` is marked
    // `#[serial_test::serial]`: the process environment is shared between test
    // threads, so unserialized tests would race with each other.

    /// Placeholder key text written to disk by the key-file tests.
    const TEST_KEY_CONTENT: &str = r#"-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
QyNTUxOQAAACBUZXN0IGtleSBmb3IgdW5pdCB0ZXN0cyAtIG5vdCByZWFsAAAAWBAAAABU
ZXN0IGtleSBmb3IgdW5pdCB0ZXN0cyAtIG5vdCByZWFsVGVzdCBrZXkgZm9yIHVuaXQgdG
VzdHMgLSBub3QgcmVhbAAAAAtzczNoLWVkMjU1MTkAAAAgVGVzdCBrZXkgZm9yIHVuaXQg
dGVzdHMgLSBub3QgcmVhbAECAwQ=
-----END OPENSSH PRIVATE KEY-----"#;

    /// Overrides one environment variable and puts the previous state back on
    /// drop, including when the test panics and when the variable was
    /// originally unset.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Sets `key` to `value`, remembering what was there before.
        fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
            let previous = env::var_os(key);
            env::set_var(key, value);
            Self { key, previous }
        }

        /// Removes `key`, remembering what was there before.
        fn unset(key: &'static str) -> Self {
            let previous = env::var_os(key);
            env::remove_var(key);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => env::set_var(self.key, value),
                None => env::remove_var(self.key),
            }
        }
    }

    fn create_test_jump_host() -> JumpHost {
        JumpHost::new(
            "test.example.com".to_string(),
            Some("testuser".to_string()),
            Some(22),
        )
    }

    fn create_test_ssh_key(dir: &TempDir, name: &str) -> std::path::PathBuf {
        let key_path = dir.path().join(name);
        std::fs::write(&key_path, TEST_KEY_CONTENT).expect("Failed to write test key");
        key_path
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn test_agent_timeout_constant() {
        assert_eq!(AGENT_TIMEOUT, std::time::Duration::from_secs(5));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_agent_available_false_when_no_socket() {
        let _sock = EnvVarGuard::unset("SSH_AUTH_SOCK");
        assert!(env::var("SSH_AUTH_SOCK").is_err());

        let agent_available = env::var("SSH_AUTH_SOCK").is_ok();
        assert!(
            !agent_available,
            "agent_available should be false when SSH_AUTH_SOCK is not set"
        );
    }

    #[tokio::test]
    #[serial_test::serial]
    #[cfg(not(target_os = "windows"))]
    async fn test_agent_has_identities_invalid_socket() {
        let _sock = EnvVarGuard::set("SSH_AUTH_SOCK", "/tmp/nonexistent_ssh_agent_socket_12345");

        let result = agent_has_identities().await;
        assert!(
            !result,
            "agent_has_identities should return false for invalid socket"
        );
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_determine_auth_method_fallback_to_key_file() {
        let _sock = EnvVarGuard::unset("SSH_AUTH_SOCK");
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let key_path = create_test_ssh_key(&temp_dir, "id_test");
        let jump_host = create_test_jump_host();
        let auth_mutex = Mutex::new(());

        let result = determine_auth_method(
            &jump_host,
            Some(key_path.as_path()),
            true,
            false,
            &auth_mutex,
        )
        .await;

        assert!(result.is_ok(), "Should succeed with key file fallback");
        match result.unwrap() {
            AuthMethod::PrivateKeyFile { .. } => {}
            // Gated like the non-test code, which only names this variant off Windows.
            #[cfg(not(target_os = "windows"))]
            AuthMethod::Agent => {
                panic!("Should not use Agent when SSH_AUTH_SOCK is not set");
            }
            other => {
                panic!("Unexpected auth method: {:?}", other);
            }
        }
    }

    #[tokio::test]
    #[serial_test::serial]
    #[cfg(not(target_os = "windows"))]
    async fn test_determine_auth_method_prefers_agent_when_available() {
        // Only meaningful on machines with a running agent that holds keys.
        if env::var("SSH_AUTH_SOCK").is_err() {
            return;
        }
        if !agent_has_identities().await {
            return;
        }

        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let key_path = create_test_ssh_key(&temp_dir, "id_test");
        let jump_host = create_test_jump_host();
        let auth_mutex = Mutex::new(());

        let result = determine_auth_method(
            &jump_host,
            Some(key_path.as_path()),
            true,
            false,
            &auth_mutex,
        )
        .await;

        assert!(result.is_ok());
        match result.unwrap() {
            AuthMethod::Agent => {}
            AuthMethod::PrivateKeyFile { .. } => {}
            other => {
                panic!("Unexpected auth method: {:?}", other);
            }
        }
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_determine_auth_method_tries_default_keys() {
        let _sock = EnvVarGuard::unset("SSH_AUTH_SOCK");
        let temp_home = TempDir::new().expect("Failed to create temp home");
        let ssh_dir = temp_home.path().join(".ssh");
        std::fs::create_dir_all(&ssh_dir).expect("Failed to create .ssh dir");
        std::fs::write(ssh_dir.join("id_ed25519"), TEST_KEY_CONTENT).expect("Failed to write key");
        // Declared after `temp_home` so HOME is restored before the directory is removed.
        let _home = EnvVarGuard::set("HOME", temp_home.path());

        let jump_host = create_test_jump_host();
        let auth_mutex = Mutex::new(());
        let result = determine_auth_method(&jump_host, None, false, false, &auth_mutex).await;

        assert!(
            result.is_ok(),
            "Should find default key at ~/.ssh/id_ed25519"
        );
        match result.unwrap() {
            AuthMethod::PrivateKeyFile { key_file_path, .. } => {
                let path_str = key_file_path.to_string_lossy();
                assert!(
                    path_str.contains("id_ed25519"),
                    "Should use id_ed25519 from default location, got: {path_str}"
                );
            }
            other => {
                panic!("Expected PrivateKeyFile, got {:?}", other);
            }
        }
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_determine_auth_method_fails_when_no_method_available() {
        let _sock = EnvVarGuard::set(
            "SSH_AUTH_SOCK",
            "/nonexistent/path/to/agent/socket/test_12345",
        );
        let temp_home = TempDir::new().expect("Failed to create temp home");
        let ssh_dir = temp_home.path().join(".ssh");
        std::fs::create_dir_all(&ssh_dir).expect("Failed to create .ssh dir");
        let _home = EnvVarGuard::set("HOME", temp_home.path());

        let jump_host = create_test_jump_host();
        let auth_mutex = Mutex::new(());
        let result = determine_auth_method(&jump_host, None, false, false, &auth_mutex).await;

        match result {
            Err(e) => {
                let error_msg = e.to_string();
                assert!(
                    error_msg.contains("No authentication method available"),
                    "Error should mention no auth method available: {error_msg}"
                );
            }
            // Environment-dependent outcomes are tolerated rather than failed.
            Ok(AuthMethod::PrivateKeyFile { .. }) => {}
            #[cfg(not(target_os = "windows"))]
            Ok(AuthMethod::Agent) => {}
            Ok(other) => {
                panic!(
                    "Expected error, Agent, or PrivateKeyFile auth method, got {:?}",
                    other
                );
            }
        }
    }

    #[test]
    fn test_agent_caching_design() {
        // Intentionally empty placeholder; it asserts nothing.
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn test_timeout_design() {
        assert_eq!(
            AGENT_TIMEOUT,
            std::time::Duration::from_secs(5),
            "Timeout should be 5 seconds"
        );
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_jump_host_ssh_key_priority() {
        let _sock = EnvVarGuard::unset("SSH_AUTH_SOCK");
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let jump_key_path = create_test_ssh_key(&temp_dir, "jump_host_key");
        let jump_key_str = jump_key_path.to_string_lossy().to_string();
        let cluster_key_path = create_test_ssh_key(&temp_dir, "cluster_key");
        let jump_host = JumpHost::with_ssh_key(
            "test.example.com".to_string(),
            Some("testuser".to_string()),
            Some(22),
            Some(jump_key_str),
        );
        let auth_mutex = Mutex::new(());

        let result = determine_auth_method(
            &jump_host,
            Some(cluster_key_path.as_path()),
            false,
            false,
            &auth_mutex,
        )
        .await;

        assert!(result.is_ok(), "Should succeed with jump host's key");
        match result.unwrap() {
            AuthMethod::PrivateKeyFile { key_file_path, .. } => {
                let path_str = key_file_path.to_string_lossy();
                assert!(
                    path_str.contains("jump_host_key"),
                    "Should use jump host's key (jump_host_key), got: {path_str}"
                );
                assert!(
                    !path_str.contains("cluster_key"),
                    "Should NOT use cluster key, got: {path_str}"
                );
            }
            other => {
                panic!("Expected PrivateKeyFile, got {:?}", other);
            }
        }
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_fallback_to_cluster_key() {
        let _sock = EnvVarGuard::unset("SSH_AUTH_SOCK");
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let cluster_key_path = create_test_ssh_key(&temp_dir, "cluster_key");
        let jump_host = create_test_jump_host();
        let auth_mutex = Mutex::new(());

        let result = determine_auth_method(
            &jump_host,
            Some(cluster_key_path.as_path()),
            false,
            false,
            &auth_mutex,
        )
        .await;

        assert!(result.is_ok(), "Should succeed with cluster key");
        match result.unwrap() {
            AuthMethod::PrivateKeyFile { key_file_path, .. } => {
                let path_str = key_file_path.to_string_lossy();
                assert!(
                    path_str.contains("cluster_key"),
                    "Should use cluster key, got: {path_str}"
                );
            }
            other => {
                panic!("Expected PrivateKeyFile, got {:?}", other);
            }
        }
    }
}