use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::time::timeout;
use zeroize::Zeroizing;
use super::tokio_client::AuthMethod;
/// Upper bound on how long an interactive password or passphrase prompt waits
/// for the user before giving up.
const AUTH_PROMPT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Upper bound on one SSH agent round-trip (connecting to the agent socket and
/// listing its identities).
#[cfg(not(target_os = "windows"))]
const AGENT_TIMEOUT: Duration = Duration::from_millis(5_000);
#[cfg(not(target_os = "windows"))]
/// Asks the SSH agent named by the environment whether it holds any identities.
///
/// Connects with `AgentClient::connect_env` and lists identities, bounded by
/// `AGENT_TIMEOUT`. Any connection/protocol error or a timeout is logged and
/// reported as `false`, so callers can treat the agent as unusable.
async fn agent_has_identities() -> bool {
    use russh::keys::agent::client::AgentClient;

    let query = async {
        let mut client = AgentClient::connect_env().await?;
        client.request_identities().await
    };

    let identities = match timeout(AGENT_TIMEOUT, query).await {
        Err(_) => {
            tracing::warn!("SSH agent operation timed out after {:?}", AGENT_TIMEOUT);
            return false;
        }
        Ok(Err(e)) => {
            tracing::warn!("Failed to communicate with SSH agent: {e}");
            return false;
        }
        Ok(Ok(found)) => found,
    };

    if identities.is_empty() {
        tracing::debug!("SSH agent is running but has no loaded identities");
        false
    } else {
        tracing::debug!("SSH agent has {} loaded identities", identities.len());
        true
    }
}
/// Longest username accepted by `AuthContext::new`, measured in bytes
/// (`str::len`).
const MAX_USERNAME_LENGTH: usize = 256;
/// Longest hostname accepted by `AuthContext::new`, measured in bytes.
/// Presumably chosen to match the 253-character limit on DNS names.
const MAX_HOSTNAME_LENGTH: usize = 253;
/// Settings that decide how to authenticate one `username@host` SSH target.
///
/// Create it with `AuthContext::new`, which validates `username` and `host`.
/// Adjust it with the `with_*` builder methods, then call `determine_method`
/// to pick an `AuthMethod`.
///
/// NOTE(review): every field is `pub`. A value built with a struct literal
/// therefore bypasses the validation in `new`. Confirm this is intended.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// Explicit private key file. `with_key_path` stores it as a
    /// canonicalized path that was verified to be a regular file.
    pub key_path: Option<PathBuf>,
    /// The SSH agent was explicitly requested. It is used only if it holds
    /// identities.
    pub use_agent: bool,
    /// Go straight to an interactive password prompt, skipping key methods.
    pub use_password: bool,
    /// Controls what happens when no default key is usable in an interactive
    /// session: fall back to a password prompt without first asking for
    /// consent.
    pub allow_password_fallback: bool,
    /// macOS only: read key passphrases from, and store them into, the
    /// Keychain.
    #[cfg(target_os = "macos")]
    pub use_keychain: bool,
    /// Remote user name (non-empty, length- and character-checked by `new`).
    pub username: String,
    /// Remote host name (non-empty, length- and character-checked by `new`).
    pub host: String,
}
impl AuthContext {
pub fn new(username: String, host: String) -> Result<Self> {
if username.is_empty() {
anyhow::bail!("Username cannot be empty");
}
if username.len() > MAX_USERNAME_LENGTH {
anyhow::bail!("Username too long (max {MAX_USERNAME_LENGTH} characters)");
}
if username.contains(['/', '\0', '\n', '\r']) {
anyhow::bail!("Username contains invalid characters");
}
if host.is_empty() {
anyhow::bail!("Hostname cannot be empty");
}
if host.len() > MAX_HOSTNAME_LENGTH {
anyhow::bail!("Hostname too long (max {MAX_HOSTNAME_LENGTH} characters)");
}
if host.contains(['\0', '\n', '\r']) {
anyhow::bail!("Hostname contains invalid characters");
}
Ok(Self {
key_path: None,
use_agent: false,
use_password: false,
allow_password_fallback: false,
#[cfg(target_os = "macos")]
use_keychain: false,
username,
host,
})
}
pub fn with_key_path(mut self, key_path: Option<PathBuf>) -> Result<Self> {
if let Some(path) = key_path {
let canonical_path = path
.canonicalize()
.with_context(|| format!("Failed to resolve SSH key path: {path:?}"))?;
if !canonical_path.is_file() {
anyhow::bail!("SSH key path is not a file: {canonical_path:?}");
}
self.key_path = Some(canonical_path);
} else {
self.key_path = None;
}
Ok(self)
}
pub fn with_agent(mut self, use_agent: bool) -> Self {
self.use_agent = use_agent;
self
}
pub fn with_password(mut self, use_password: bool) -> Self {
self.use_password = use_password;
self
}
pub fn with_password_fallback(mut self, allow: bool) -> Self {
self.allow_password_fallback = allow;
self
}
#[cfg(target_os = "macos")]
pub fn with_keychain(mut self, use_keychain: bool) -> Self {
self.use_keychain = use_keychain;
self
}
pub async fn determine_method(&self) -> Result<AuthMethod> {
let start_time = std::time::Instant::now();
let result = self.determine_method_internal().await;
let elapsed = start_time.elapsed();
if elapsed < Duration::from_millis(50) {
tokio::time::sleep(Duration::from_millis(50) - elapsed).await;
}
result
}
async fn determine_method_internal(&self) -> Result<AuthMethod> {
if self.use_password {
return self.password_auth().await;
}
if self.use_agent
&& let Some(auth) = self.agent_auth()?
{
return Ok(auth);
}
if let Some(ref key_path) = self.key_path {
return self.key_file_auth(key_path).await;
}
#[cfg(not(target_os = "windows"))]
if !self.use_agent {
if let Some(auth) = self.agent_auth()? {
tracing::debug!(
"Using SSH agent (auto-detected) - agent will try all registered keys"
);
return Ok(auth);
}
}
match self.default_key_auth().await {
Ok(auth) => Ok(auth),
Err(_) => {
if atty::is(atty::Stream::Stdin) {
let should_attempt_password = if self.allow_password_fallback {
tracing::info!(
"SSH key authentication failed, falling back to password authentication"
);
const FALLBACK_DELAY: Duration = Duration::from_secs(1);
tokio::time::sleep(FALLBACK_DELAY).await;
true
} else {
self.prompt_password_fallback_consent().await?
};
if should_attempt_password {
tracing::debug!("Attempting password authentication fallback");
tracing::warn!(
"Password authentication fallback attempted for {}@{} after key auth failure",
self.username,
self.host
);
self.password_auth().await
} else {
anyhow::bail!(
"SSH authentication failed: All key-based methods failed.\n\
\n\
Tried:\n\
- SSH agent: {}\n\
- Default SSH keys: Not found or not authorized\n\
\n\
User declined password authentication fallback.\n\
\n\
Solutions:\n\
- Use --password flag to explicitly enable password authentication\n\
- Start SSH agent and add keys with 'ssh-add'\n\
- Specify a key file with -i/--identity\n\
- Ensure ~/.ssh/id_ed25519 or ~/.ssh/id_rsa exists and is authorized",
if cfg!(target_os = "windows") {
"Not supported on Windows"
} else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
"Available but no identities authorized"
} else {
"Not available (SSH_AUTH_SOCK not set)"
}
)
}
} else {
anyhow::bail!(
"SSH authentication failed: No authentication method available.\n\
\n\
Tried:\n\
- SSH agent: {}\n\
- Default SSH keys: Not found or not authorized\n\
\n\
Solutions:\n\
- Use --password for password authentication\n\
- Start SSH agent and add keys with 'ssh-add'\n\
- Specify a key file with -i/--identity\n\
- Ensure ~/.ssh/id_ed25519 or ~/.ssh/id_rsa exists and is authorized",
if cfg!(target_os = "windows") {
"Not supported on Windows"
} else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
"Available but no identities authorized"
} else {
"Not available (SSH_AUTH_SOCK not set)"
}
)
}
}
}
}
async fn prompt_password_fallback_consent(&self) -> Result<bool> {
use std::io::{self, Write};
tracing::info!(
"All SSH key-based authentication methods failed for {}@{}",
self.username,
self.host
);
const FALLBACK_DELAY: Duration = Duration::from_secs(1);
tokio::time::sleep(FALLBACK_DELAY).await;
let consent_future = tokio::task::spawn_blocking({
let username = self.username.clone();
let host = self.host.clone();
move || -> Result<bool> {
println!("\n⚠️ SSH key authentication failed for {username}@{host}");
println!("Would you like to try password authentication? (yes/no): ");
io::stdout().flush()?;
let mut response = String::new();
io::stdin().read_line(&mut response)?;
let response = response.trim().to_lowercase();
Ok(response == "yes" || response == "y")
}
});
const CONSENT_TIMEOUT: Duration = Duration::from_secs(30);
timeout(CONSENT_TIMEOUT, consent_future)
.await
.context("Consent prompt timed out after 30 seconds")?
.context("Consent prompt task failed")?
}
async fn password_auth(&self) -> Result<AuthMethod> {
tracing::debug!("Using password authentication");
let prompt_future = tokio::task::spawn_blocking({
let username = self.username.clone();
let host = self.host.clone();
move || -> Result<Zeroizing<String>> {
let password = Zeroizing::new(
rpassword::prompt_password(format!("Enter password for {username}@{host}: "))
.with_context(|| "Failed to read password")?,
);
Ok(password)
}
});
let password = timeout(AUTH_PROMPT_TIMEOUT, prompt_future)
.await
.context("Password prompt timed out")?
.context("Password prompt task failed")??;
Ok(AuthMethod::with_password(&password))
}
#[cfg(not(target_os = "windows"))]
fn agent_auth(&self) -> Result<Option<AuthMethod>> {
match std::env::var_os("SSH_AUTH_SOCK") {
Some(socket_path) => {
let path = std::path::Path::new(&socket_path);
if path.exists() {
let has_identities = std::thread::spawn(|| {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map(|rt| rt.block_on(agent_has_identities()))
.unwrap_or(false)
})
.join()
.unwrap_or(false);
if has_identities {
tracing::debug!("Using SSH agent for authentication");
Ok(Some(AuthMethod::Agent))
} else {
tracing::debug!(
"SSH agent is running but has no loaded identities, falling back to key files"
);
Ok(None)
}
} else {
tracing::warn!("SSH_AUTH_SOCK points to non-existent socket");
Ok(None)
}
}
None => {
tracing::warn!(
"SSH agent requested but SSH_AUTH_SOCK environment variable not set"
);
Ok(None)
}
}
}
#[cfg(target_os = "windows")]
fn agent_auth(&self) -> Result<Option<AuthMethod>> {
anyhow::bail!("SSH agent authentication is not supported on Windows");
}
fn is_key_encrypted(key_contents: &str) -> bool {
key_contents.contains("ENCRYPTED")
|| key_contents.contains("Proc-Type: 4,ENCRYPTED")
|| key_contents.contains("DEK-Info:") }
async fn key_file_auth(&self, key_path: &Path) -> Result<AuthMethod> {
tracing::debug!("Authenticating with key: {:?}", key_path);
let key_contents = tokio::fs::read_to_string(key_path)
.await
.with_context(|| format!("Failed to read SSH key file: {key_path:?}"))?;
let passphrase = if Self::is_key_encrypted(&key_contents) {
tracing::debug!("Detected encrypted SSH key");
#[cfg(target_os = "macos")]
let keychain_passphrase = if self.use_keychain {
tracing::debug!("Attempting to retrieve passphrase from Keychain");
match super::keychain_macos::retrieve_passphrase(key_path).await {
Ok(Some(pass)) => {
tracing::info!("Successfully retrieved passphrase from Keychain");
Some(pass)
}
Ok(None) => {
tracing::debug!("No passphrase found in Keychain");
None
}
Err(err) => {
tracing::warn!("Failed to retrieve passphrase from Keychain: {err}");
None
}
}
} else {
None
};
#[cfg(not(target_os = "macos"))]
let keychain_passphrase: Option<Zeroizing<String>> = None;
if let Some(pass) = keychain_passphrase {
Some(pass)
} else {
tracing::debug!("Prompting for passphrase");
let key_path_str = key_path.display().to_string();
let prompt_future =
tokio::task::spawn_blocking(move || -> Result<Zeroizing<String>> {
let pass = Zeroizing::new(
rpassword::prompt_password(format!(
"Enter passphrase for key {key_path_str}: "
))
.with_context(|| "Failed to read passphrase")?,
);
Ok(pass)
});
let pass = timeout(AUTH_PROMPT_TIMEOUT, prompt_future)
.await
.context("Passphrase prompt timed out")?
.context("Passphrase prompt task failed")??;
#[cfg(target_os = "macos")]
if self.use_keychain {
tracing::debug!("Storing passphrase in Keychain");
if let Err(err) = super::keychain_macos::store_passphrase(key_path, &pass).await
{
tracing::warn!("Failed to store passphrase in Keychain: {err}");
} else {
tracing::info!("Successfully stored passphrase in Keychain");
}
}
Some(pass)
}
} else {
None
};
drop(key_contents);
Ok(AuthMethod::with_key_file(
key_path,
passphrase.as_ref().map(|p| p.as_str()),
))
}
async fn default_key_auth(&self) -> Result<AuthMethod> {
let home_dir = dirs::home_dir()
.ok_or_else(|| anyhow::anyhow!("Could not determine home directory"))?;
let ssh_dir = home_dir.join(".ssh");
if !ssh_dir.is_dir() {
anyhow::bail!(
"SSH directory not found: {ssh_dir:?}\n\
Please ensure ~/.ssh directory exists with proper permissions."
);
}
let default_keys = [
ssh_dir.join("id_ed25519"),
ssh_dir.join("id_rsa"),
ssh_dir.join("id_ecdsa"),
ssh_dir.join("id_dsa"),
];
for default_key in &default_keys {
if default_key.exists() && default_key.is_file() {
let canonical_key = default_key
.canonicalize()
.with_context(|| format!("Failed to resolve key path: {default_key:?}"))?;
tracing::debug!("Using default key: {:?}", canonical_key);
let key_contents = tokio::fs::read_to_string(&canonical_key)
.await
.with_context(|| format!("Failed to read SSH key file: {canonical_key:?}"))?;
let passphrase = if Self::is_key_encrypted(&key_contents) {
tracing::debug!("Detected encrypted SSH key");
#[cfg(target_os = "macos")]
let keychain_passphrase = if self.use_keychain {
tracing::debug!("Attempting to retrieve passphrase from Keychain");
match super::keychain_macos::retrieve_passphrase(&canonical_key).await {
Ok(Some(pass)) => {
tracing::info!("Successfully retrieved passphrase from Keychain");
Some(pass)
}
Ok(None) => {
tracing::debug!("No passphrase found in Keychain");
None
}
Err(err) => {
tracing::warn!(
"Failed to retrieve passphrase from Keychain: {err}"
);
None
}
}
} else {
None
};
#[cfg(not(target_os = "macos"))]
let keychain_passphrase: Option<Zeroizing<String>> = None;
if let Some(pass) = keychain_passphrase {
Some(pass)
} else {
tracing::debug!("Prompting for passphrase");
let key_path_str = canonical_key.display().to_string();
let prompt_future =
tokio::task::spawn_blocking(move || -> Result<Zeroizing<String>> {
let pass = Zeroizing::new(
rpassword::prompt_password(format!(
"Enter passphrase for key {key_path_str}: "
))
.with_context(|| "Failed to read passphrase")?,
);
Ok(pass)
});
let pass = timeout(AUTH_PROMPT_TIMEOUT, prompt_future)
.await
.context("Passphrase prompt timed out")?
.context("Passphrase prompt task failed")??;
#[cfg(target_os = "macos")]
if self.use_keychain {
tracing::debug!("Storing passphrase in Keychain");
if let Err(err) =
super::keychain_macos::store_passphrase(&canonical_key, &pass).await
{
tracing::warn!("Failed to store passphrase in Keychain: {err}");
} else {
tracing::info!("Successfully stored passphrase in Keychain");
}
}
Some(pass)
}
} else {
None
};
drop(key_contents);
return Ok(AuthMethod::with_key_file(
&canonical_key,
passphrase.as_ref().map(|p| p.as_str()),
));
}
}
anyhow::bail!(
"SSH authentication failed: No authentication method available.\n\
\n\
Tried:\n\
- SSH agent: {}\n\
- Default SSH keys: Not found\n\
\n\
Solutions:\n\
- Use --password for password authentication\n\
- Start SSH agent and add keys with 'ssh-add'\n\
- Specify a key file with -i/--identity\n\
- Create a default SSH key with 'ssh-keygen'",
if cfg!(target_os = "windows") {
"Not supported on Windows"
} else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
"Available but no identities"
} else {
"Not available (SSH_AUTH_SOCK not set)"
}
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::EnvGuard;
    use serial_test::serial;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_auth_context_creation() {
        let ctx = AuthContext::new("testuser".to_string(), "testhost".to_string()).unwrap();
        assert_eq!(ctx.username, "testuser");
        assert_eq!(ctx.host, "testhost");
        assert_eq!(ctx.key_path, None);
        assert!(!ctx.use_agent);
        assert!(!ctx.use_password);
    }

    #[tokio::test]
    async fn test_auth_context_validation() {
        let result = AuthContext::new("".to_string(), "host".to_string());
        assert!(result.is_err());
        let result = AuthContext::new("user/name".to_string(), "host".to_string());
        assert!(result.is_err());
        let result = AuthContext::new("user".to_string(), "".to_string());
        assert!(result.is_err());
        let long_username = "a".repeat(MAX_USERNAME_LENGTH + 1);
        let result = AuthContext::new(long_username, "host".to_string());
        assert!(result.is_err());
        // Hostname length and control-character checks.
        let long_host = "h".repeat(MAX_HOSTNAME_LENGTH + 1);
        let result = AuthContext::new("user".to_string(), long_host);
        assert!(result.is_err());
        let result = AuthContext::new("user".to_string(), "host\nname".to_string());
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_auth_context_with_key_path() {
        let temp_dir = TempDir::new().unwrap();
        let key_path = temp_dir.path().join("test_key");
        std::fs::write(&key_path, "fake key content").unwrap();
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(key_path.clone()))
            .unwrap();
        assert!(ctx.key_path.is_some());
        assert!(ctx.key_path.unwrap().is_absolute());
    }

    #[tokio::test]
    async fn test_auth_context_with_invalid_key_path() {
        let temp_dir = TempDir::new().unwrap();
        let result = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_auth_context_with_agent() {
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_agent(true);
        assert!(ctx.use_agent);
    }

    #[tokio::test]
    async fn test_auth_context_with_password() {
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_password(true);
        assert!(ctx.use_password);
    }

    #[tokio::test]
    async fn test_is_key_encrypted() {
        assert!(AuthContext::is_key_encrypted(
            "-----BEGIN ENCRYPTED PRIVATE KEY-----"
        ));
        assert!(AuthContext::is_key_encrypted("Proc-Type: 4,ENCRYPTED"));
        assert!(AuthContext::is_key_encrypted("DEK-Info: AES-128-CBC"));
        assert!(!AuthContext::is_key_encrypted(
            "-----BEGIN PRIVATE KEY-----"
        ));
        assert!(!AuthContext::is_key_encrypted("ssh-rsa AAAAB3..."));
    }

    #[tokio::test]
    async fn test_determine_method_with_key_file() {
        let temp_dir = TempDir::new().unwrap();
        let key_path = temp_dir.path().join("test_key");
        std::fs::write(
            &key_path,
            "-----BEGIN PRIVATE KEY-----\nfake key content\n-----END PRIVATE KEY-----",
        )
        .unwrap();
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(key_path.clone()))
            .unwrap();
        let auth = ctx.determine_method().await.unwrap();
        match auth {
            AuthMethod::PrivateKeyFile { key_file_path, .. } => {
                assert!(key_file_path.is_absolute());
            }
            _ => panic!("Expected PrivateKeyFile auth method"),
        }
    }

    #[cfg(not(target_os = "windows"))]
    #[tokio::test]
    #[serial]
    async fn test_agent_auth_with_invalid_socket() {
        let _sock = EnvGuard::set("SSH_AUTH_SOCK", "/tmp/nonexistent-ssh-agent.sock");
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_agent(true);
        let auth = ctx.agent_auth().unwrap();
        assert!(auth.is_none());
    }

    #[tokio::test]
    async fn test_timing_attack_mitigation() {
        // Use an explicit, unencrypted key file so the test never reads the
        // real ~/.ssh, queries a real agent, or blocks on an interactive prompt.
        // This path returns almost instantly, so the minimum duration comes
        // from the padding alone.
        let temp_dir = TempDir::new().unwrap();
        let key_path = temp_dir.path().join("test_key");
        std::fs::write(
            &key_path,
            "-----BEGIN PRIVATE KEY-----\nfake key content\n-----END PRIVATE KEY-----",
        )
        .unwrap();
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(key_path))
            .unwrap();
        let start = std::time::Instant::now();
        let result = ctx.determine_method().await;
        let duration = start.elapsed();
        assert!(result.is_ok());
        assert!(duration >= Duration::from_millis(50));
    }

    #[tokio::test]
    #[serial]
    async fn test_password_fallback_in_non_interactive() {
        // With a terminal on stdin, the code under test would ask for password
        // fallback consent and block waiting for input. This scenario is
        // therefore only exercised when stdin is not a TTY (e.g. in CI).
        if atty::is(atty::Stream::Stdin) {
            return;
        }
        let temp_dir = TempDir::new().unwrap();
        let ssh_dir = temp_dir.path().join(".ssh");
        std::fs::create_dir_all(&ssh_dir).unwrap();
        let _home = EnvGuard::set("HOME", temp_dir.path().to_str().unwrap());
        let _sock = EnvGuard::remove("SSH_AUTH_SOCK");
        let ctx = AuthContext::new("user".to_string(), "host".to_string()).unwrap();
        let result = ctx.determine_method().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("authentication"));
    }
}