use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::time::timeout;
use zeroize::Zeroizing;
use super::tokio_client::AuthMethod;
/// Maximum time to wait for the user to type a password or a key passphrase.
const AUTH_PROMPT_TIMEOUT: Duration = Duration::from_secs(30);
/// Upper bound on `AuthContext` usernames; compared against `str::len`, i.e. in bytes.
const MAX_USERNAME_LENGTH: usize = 256;
/// Upper bound on `AuthContext` hostnames; compared against `str::len`, i.e. in bytes.
/// Presumably chosen to match the 253-character DNS text-form name limit — confirm.
const MAX_HOSTNAME_LENGTH: usize = 253;
/// Inputs used to choose an SSH authentication method for one `username@host` target.
///
/// Construct with `AuthContext::new` (which validates `username` and `host`) and the
/// `with_*` builder methods, then call `determine_method`. The fields are public, so
/// building the struct literally bypasses that validation.
#[derive(Debug, Clone)]
pub struct AuthContext {
/// Explicit private key file; `with_key_path` stores it canonicalized.
pub key_path: Option<PathBuf>,
/// Whether SSH agent authentication was requested.
pub use_agent: bool,
/// Whether password authentication was requested.
pub use_password: bool,
/// Remote login name.
pub username: String,
/// Remote host name.
pub host: String,
}
impl AuthContext {
    /// Creates a context for `username@host` with no authentication method selected.
    ///
    /// # Errors
    ///
    /// Fails when `username` is empty, longer than `MAX_USERNAME_LENGTH` bytes, or
    /// contains `/`, NUL, `\n` or `\r`; or when `host` is empty, longer than
    /// `MAX_HOSTNAME_LENGTH` bytes, or contains NUL, `\n` or `\r`.
    pub fn new(username: String, host: String) -> Result<Self> {
        if username.is_empty() {
            anyhow::bail!("Username cannot be empty");
        }
        if username.len() > MAX_USERNAME_LENGTH {
            anyhow::bail!("Username too long (max {MAX_USERNAME_LENGTH} characters)");
        }
        if username.contains(['/', '\0', '\n', '\r']) {
            anyhow::bail!("Username contains invalid characters");
        }
        if host.is_empty() {
            anyhow::bail!("Hostname cannot be empty");
        }
        if host.len() > MAX_HOSTNAME_LENGTH {
            anyhow::bail!("Hostname too long (max {MAX_HOSTNAME_LENGTH} characters)");
        }
        if host.contains(['\0', '\n', '\r']) {
            anyhow::bail!("Hostname contains invalid characters");
        }
        Ok(Self {
            key_path: None,
            use_agent: false,
            use_password: false,
            username,
            host,
        })
    }

    /// Sets (or clears, when `None`) the explicit private key file.
    ///
    /// The path is canonicalized, so the stored value is absolute with symlinks resolved.
    ///
    /// # Errors
    ///
    /// Fails if the path cannot be canonicalized (for example, it does not exist) or if
    /// it does not refer to a regular file.
    pub fn with_key_path(mut self, key_path: Option<PathBuf>) -> Result<Self> {
        if let Some(path) = key_path {
            let canonical_path = path
                .canonicalize()
                .with_context(|| format!("Failed to resolve SSH key path: {path:?}"))?;
            if !canonical_path.is_file() {
                anyhow::bail!("SSH key path is not a file: {canonical_path:?}");
            }
            self.key_path = Some(canonical_path);
        } else {
            self.key_path = None;
        }
        Ok(self)
    }

    /// Requests (or not) SSH agent authentication.
    pub fn with_agent(mut self, use_agent: bool) -> Self {
        self.use_agent = use_agent;
        self
    }

    /// Requests (or not) interactive password authentication.
    pub fn with_password(mut self, use_password: bool) -> Self {
        self.use_password = use_password;
        self
    }

    /// Chooses the authentication method to use.
    ///
    /// The call always takes at least 50 ms, presumably to blur timing differences
    /// between the non-interactive paths. NOTE(review): interactive prompts dominate
    /// timing anyway.
    ///
    /// # Errors
    ///
    /// Propagates any error from the selected method (prompt failure or timeout,
    /// unreadable key file, agent unsupported on Windows, no method available).
    pub async fn determine_method(&self) -> Result<AuthMethod> {
        const MIN_DURATION: Duration = Duration::from_millis(50);
        let start_time = std::time::Instant::now();
        let result = self.determine_method_internal().await;
        let elapsed = start_time.elapsed();
        if elapsed < MIN_DURATION {
            tokio::time::sleep(MIN_DURATION - elapsed).await;
        }
        result
    }

    /// Selection order: password (if requested) → agent (if requested) → explicit key
    /// file → default keys in `~/.ssh`.
    async fn determine_method_internal(&self) -> Result<AuthMethod> {
        if self.use_password {
            return self.password_auth().await;
        }
        // The agent is consulted exactly once. On Windows `agent_auth` returns an error,
        // which is propagated.
        if self.use_agent {
            if let Some(auth) = self.agent_auth()? {
                return Ok(auth);
            }
        }
        if let Some(key_path) = &self.key_path {
            return self.key_file_auth(key_path).await;
        }
        self.default_key_auth().await
    }

    /// Prompts for the account password and wraps it in an `AuthMethod`.
    async fn password_auth(&self) -> Result<AuthMethod> {
        tracing::debug!("Using password authentication");
        let prompt = format!("Enter password for {}@{}: ", self.username, self.host);
        let password = Self::prompt_secret(prompt, "password", "Password").await?;
        Ok(AuthMethod::with_password(&password))
    }

    /// Reads a secret from the terminal without echo, bounded by `AUTH_PROMPT_TIMEOUT`.
    ///
    /// `noun` and `title` are the lower-case and capitalised names of the secret, used in
    /// error messages (for example `"password"` / `"Password"`).
    ///
    /// NOTE: `rpassword` blocks, so the read runs on tokio's blocking pool. Tokio cannot
    /// abort a running `spawn_blocking` closure. When the timeout fires, the thread stays
    /// parked on the terminal read until input arrives.
    async fn prompt_secret(
        prompt: String,
        noun: &'static str,
        title: &'static str,
    ) -> Result<Zeroizing<String>> {
        let read_task = tokio::task::spawn_blocking(move || -> Result<Zeroizing<String>> {
            rpassword::prompt_password(prompt)
                .map(Zeroizing::new)
                .with_context(|| format!("Failed to read {noun}"))
        });
        timeout(AUTH_PROMPT_TIMEOUT, read_task)
            .await
            .with_context(|| format!("{title} prompt timed out"))?
            .with_context(|| format!("{title} prompt task failed"))?
    }

    /// Returns `Some(AuthMethod::Agent)` when `SSH_AUTH_SOCK` names an existing path,
    /// otherwise logs a warning and returns `None`.
    #[cfg(not(target_os = "windows"))]
    fn agent_auth(&self) -> Result<Option<AuthMethod>> {
        match std::env::var_os("SSH_AUTH_SOCK") {
            Some(socket_path) => {
                let path = std::path::Path::new(&socket_path);
                if path.exists() {
                    tracing::debug!("Using SSH agent for authentication");
                    Ok(Some(AuthMethod::Agent))
                } else {
                    tracing::warn!("SSH_AUTH_SOCK points to non-existent socket");
                    Ok(None)
                }
            }
            None => {
                tracing::warn!(
                    "SSH agent requested but SSH_AUTH_SOCK environment variable not set"
                );
                Ok(None)
            }
        }
    }

    /// Agent authentication is not implemented on Windows; always an error.
    #[cfg(target_os = "windows")]
    fn agent_auth(&self) -> Result<Option<AuthMethod>> {
        anyhow::bail!("SSH agent authentication is not supported on Windows");
    }

    /// Heuristically decides whether a private key file is passphrase-protected.
    ///
    /// Recognises PKCS#8 `ENCRYPTED PRIVATE KEY` blocks, legacy PEM keys with
    /// `Proc-Type: 4,ENCRYPTED` / `DEK-Info:` headers, and OpenSSH-format keys whose
    /// cipher is not `none`.
    fn is_key_encrypted(key_contents: &str) -> bool {
        // "ENCRYPTED" also covers the legacy "Proc-Type: 4,ENCRYPTED" header.
        key_contents.contains("ENCRYPTED")
            || key_contents.contains("DEK-Info:")
            || Self::is_openssh_key_encrypted(key_contents)
    }

    /// Checks the cipher name of an `openssh-key-v1` private key.
    ///
    /// In this format the cipher name lives inside the base64 body, so no plain-text
    /// marker exists. An unencrypted key's body always begins with
    /// base64(`"openssh-key-v1\0"` ‖ u32be(4) ‖ `"none"`). Any other cipher after a valid
    /// magic means the key is encrypted. Bodies without the magic are treated as not
    /// encrypted and left for the SSH library to reject.
    fn is_openssh_key_encrypted(key_contents: &str) -> bool {
        const BEGIN_MARKER: &str = "-----BEGIN OPENSSH PRIVATE KEY-----";
        // base64("openssh-key-v1\0")
        const MAGIC_B64: &str = "b3BlbnNzaC1rZXktdjEA";
        // MAGIC_B64 followed by base64 of u32be(4) + "none". The final char also carries
        // the high bits of the following length field, which are always zero.
        const UNENCRYPTED_PREFIX_B64: &str = "b3BlbnNzaC1rZXktdjEAAAAABG5vbmU";

        let body = match key_contents.split_once(BEGIN_MARKER) {
            Some((_, rest)) => rest,
            None => return false,
        };
        // The body is line-wrapped; drop whitespace (including '\r') before matching.
        let prefix: String = body
            .chars()
            .filter(|c| !c.is_whitespace())
            .take(UNENCRYPTED_PREFIX_B64.len())
            .collect();
        prefix.starts_with(MAGIC_B64) && !prefix.starts_with(UNENCRYPTED_PREFIX_B64)
    }

    /// Builds key-file authentication for `key_path`, prompting for a passphrase when
    /// the key looks encrypted.
    async fn key_file_auth(&self, key_path: &Path) -> Result<AuthMethod> {
        tracing::debug!("Authenticating with key: {:?}", key_path);
        // The file holds private key material. Keep it only long enough to inspect it and
        // wipe the buffer on drop (best effort), before any interactive prompt.
        let encrypted = {
            let key_contents = Zeroizing::new(
                tokio::fs::read_to_string(key_path)
                    .await
                    .with_context(|| format!("Failed to read SSH key file: {key_path:?}"))?,
            );
            Self::is_key_encrypted(&key_contents)
        };
        let passphrase = if encrypted {
            tracing::debug!("Detected encrypted SSH key, prompting for passphrase");
            let prompt = format!("Enter passphrase for key {}: ", key_path.display());
            Some(Self::prompt_secret(prompt, "passphrase", "Passphrase").await?)
        } else {
            None
        };
        Ok(AuthMethod::with_key_file(
            key_path,
            passphrase.as_ref().map(|p| p.as_str()),
        ))
    }

    /// Falls back to the first existing default key in `~/.ssh`.
    ///
    /// Candidates are tried in this order: `id_ed25519`, `id_rsa`, `id_ecdsa`, `id_dsa`.
    ///
    /// # Errors
    ///
    /// Fails if the home directory or `~/.ssh` is missing, if the chosen key cannot be
    /// resolved or read, or with a diagnostic listing remedies when no default key exists.
    async fn default_key_auth(&self) -> Result<AuthMethod> {
        const DEFAULT_KEY_NAMES: [&str; 4] = ["id_ed25519", "id_rsa", "id_ecdsa", "id_dsa"];

        let home_dir = dirs::home_dir()
            .ok_or_else(|| anyhow::anyhow!("Could not determine home directory"))?;
        let ssh_dir = home_dir.join(".ssh");
        if !ssh_dir.is_dir() {
            anyhow::bail!(
                "SSH directory not found: {ssh_dir:?}\n\
                 Please ensure ~/.ssh directory exists with proper permissions."
            );
        }

        // `is_file` is false for missing paths, so no separate `exists` check is needed.
        if let Some(default_key) = DEFAULT_KEY_NAMES
            .iter()
            .map(|name| ssh_dir.join(name))
            .find(|candidate| candidate.is_file())
        {
            let canonical_key = default_key
                .canonicalize()
                .with_context(|| format!("Failed to resolve key path: {default_key:?}"))?;
            tracing::debug!("Using default key: {:?}", canonical_key);
            return self.key_file_auth(&canonical_key).await;
        }

        // Describe what actually happened with the agent. When `use_agent` is set, we only
        // get here because `agent_auth` found no usable socket.
        let agent_status = if cfg!(target_os = "windows") {
            "Not supported on Windows"
        } else if !self.use_agent {
            "Not requested"
        } else if std::env::var_os("SSH_AUTH_SOCK").is_some() {
            "Not available (SSH_AUTH_SOCK points to non-existent socket)"
        } else {
            "Not available (SSH_AUTH_SOCK not set)"
        };
        anyhow::bail!(
            "SSH authentication failed: No authentication method available.\n\
             \n\
             Tried:\n\
             - SSH agent: {agent_status}\n\
             - Default SSH keys: Not found\n\
             \n\
             Solutions:\n\
             - Use --password for password authentication\n\
             - Start SSH agent and add keys with 'ssh-add'\n\
             - Specify a key file with -i/--identity\n\
             - Create a default SSH key with 'ssh-keygen'"
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes a small unencrypted, PEM-shaped dummy key into `dir` and returns its path.
    fn write_plain_key(dir: &TempDir) -> PathBuf {
        let key_path = dir.path().join("test_key");
        std::fs::write(
            &key_path,
            "-----BEGIN PRIVATE KEY-----\nfake key content\n-----END PRIVATE KEY-----",
        )
        .unwrap();
        key_path
    }

    /// Sets an environment variable and restores its previous value (or absence) on drop,
    /// so a failing assertion cannot leak the override into other tests.
    #[cfg(not(target_os = "windows"))]
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    #[cfg(not(target_os = "windows"))]
    impl EnvVarGuard {
        fn set(key: &'static str, value: &Path) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    #[cfg(not(target_os = "windows"))]
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn test_auth_context_creation() {
        let ctx = AuthContext::new("testuser".to_string(), "testhost".to_string()).unwrap();
        assert_eq!(ctx.username, "testuser");
        assert_eq!(ctx.host, "testhost");
        assert_eq!(ctx.key_path, None);
        assert!(!ctx.use_agent);
        assert!(!ctx.use_password);
    }

    #[test]
    fn test_auth_context_validation() {
        assert!(AuthContext::new("".to_string(), "host".to_string()).is_err());
        assert!(AuthContext::new("user/name".to_string(), "host".to_string()).is_err());
        assert!(AuthContext::new("user\n".to_string(), "host".to_string()).is_err());
        assert!(AuthContext::new("user".to_string(), "".to_string()).is_err());
        assert!(AuthContext::new("user".to_string(), "ho\0st".to_string()).is_err());

        let long_username = "a".repeat(MAX_USERNAME_LENGTH + 1);
        assert!(AuthContext::new(long_username, "host".to_string()).is_err());

        let long_host = "h".repeat(MAX_HOSTNAME_LENGTH + 1);
        assert!(AuthContext::new("user".to_string(), long_host).is_err());

        // Exactly-at-limit values are accepted.
        let max_username = "a".repeat(MAX_USERNAME_LENGTH);
        let max_host = "h".repeat(MAX_HOSTNAME_LENGTH);
        assert!(AuthContext::new(max_username, max_host).is_ok());
    }

    #[test]
    fn test_auth_context_with_key_path() {
        let temp_dir = TempDir::new().unwrap();
        let key_path = temp_dir.path().join("test_key");
        std::fs::write(&key_path, "fake key content").unwrap();
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(key_path.clone()))
            .unwrap();
        assert!(ctx.key_path.is_some());
        assert!(ctx.key_path.unwrap().is_absolute());
    }

    #[test]
    fn test_auth_context_with_invalid_key_path() {
        let temp_dir = TempDir::new().unwrap();
        // A directory is not a valid key file.
        let result = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(temp_dir.path().to_path_buf()));
        assert!(result.is_err());
    }

    #[test]
    fn test_auth_context_with_agent() {
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_agent(true);
        assert!(ctx.use_agent);
    }

    #[test]
    fn test_auth_context_with_password() {
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_password(true);
        assert!(ctx.use_password);
    }

    #[test]
    fn test_is_key_encrypted() {
        assert!(AuthContext::is_key_encrypted(
            "-----BEGIN ENCRYPTED PRIVATE KEY-----"
        ));
        assert!(AuthContext::is_key_encrypted("Proc-Type: 4,ENCRYPTED"));
        assert!(AuthContext::is_key_encrypted("DEK-Info: AES-128-CBC"));
        assert!(!AuthContext::is_key_encrypted(
            "-----BEGIN PRIVATE KEY-----"
        ));
        assert!(!AuthContext::is_key_encrypted("ssh-rsa AAAAB3..."));
    }

    #[tokio::test]
    async fn test_determine_method_with_key_file() {
        let temp_dir = TempDir::new().unwrap();
        let key_path = write_plain_key(&temp_dir);
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(key_path))
            .unwrap();
        let auth = ctx.determine_method().await.unwrap();
        match auth {
            AuthMethod::PrivateKeyFile { key_file_path, .. } => {
                assert!(key_file_path.is_absolute());
            }
            _ => panic!("Expected PrivateKeyFile auth method"),
        }
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_agent_auth_with_invalid_socket() {
        // Point at a path inside a fresh temp dir so it is guaranteed not to exist.
        let temp_dir = TempDir::new().unwrap();
        let missing_socket = temp_dir.path().join("nonexistent-ssh-agent.sock");
        let _env = EnvVarGuard::set("SSH_AUTH_SOCK", &missing_socket);
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_agent(true);
        let auth = ctx.agent_auth().unwrap();
        assert!(auth.is_none());
    }

    #[tokio::test]
    async fn test_timing_attack_mitigation() {
        // Use an explicit unencrypted key so the test never falls back to the developer's
        // real ~/.ssh keys, which may be encrypted and trigger an interactive prompt.
        let temp_dir = TempDir::new().unwrap();
        let ctx = AuthContext::new("user".to_string(), "host".to_string())
            .unwrap()
            .with_key_path(Some(write_plain_key(&temp_dir)))
            .unwrap();
        let start = std::time::Instant::now();
        let _ = ctx.determine_method().await;
        assert!(start.elapsed() >= Duration::from_millis(50));
    }
}