use super::core::SshClient;
use crate::jump::{parse_jump_hosts, JumpHostChain};
use crate::ssh::known_hosts::StrictHostKeyChecking;
use crate::ssh::tokio_client::{AuthMethod, Client, SshConnectionConfig};
use anyhow::{Context, Result};
use std::path::Path;
use std::time::Duration;
const SSH_CONNECT_TIMEOUT_SECS: u64 = 30;
impl SshClient {
pub(super) async fn determine_auth_method(
&self,
key_path: Option<&Path>,
use_agent: bool,
use_password: bool,
#[cfg(target_os = "macos")] use_keychain: bool,
) -> Result<AuthMethod> {
let mut auth_ctx =
crate::ssh::auth::AuthContext::new(self.username.clone(), self.host.clone())
.with_context(|| {
format!("Invalid credentials for {}@{}", self.username, self.host)
})?;
if let Some(path) = key_path {
auth_ctx = auth_ctx
.with_key_path(Some(path.to_path_buf()))
.with_context(|| format!("Invalid SSH key path: {path:?}"))?;
}
auth_ctx = auth_ctx.with_agent(use_agent).with_password(use_password);
#[cfg(target_os = "macos")]
{
auth_ctx = auth_ctx.with_keychain(use_keychain);
}
auth_ctx.determine_method().await
}
pub(super) async fn connect_direct(
&self,
auth_method: &AuthMethod,
strict_mode: StrictHostKeyChecking,
connect_timeout_seconds: Option<u64>,
ssh_connection_config: Option<&SshConnectionConfig>,
) -> Result<Client> {
const RATE_LIMIT_DELAY: Duration = Duration::from_millis(100);
tokio::time::sleep(RATE_LIMIT_DELAY).await;
let start_time = std::time::Instant::now();
let addr = (self.host.as_str(), self.port);
let check_method = crate::ssh::known_hosts::get_check_method(strict_mode);
let connect_timeout =
Duration::from_secs(connect_timeout_seconds.unwrap_or(SSH_CONNECT_TIMEOUT_SECS));
let default_conn_cfg;
let conn_cfg = match ssh_connection_config {
Some(c) => c,
None => {
default_conn_cfg = SshConnectionConfig::default();
&default_conn_cfg
}
};
let result = match tokio::time::timeout(
connect_timeout,
Client::connect_with_ssh_config(
addr,
&self.username,
auth_method.clone(),
check_method,
conn_cfg,
),
)
.await
{
Ok(Ok(client)) => Ok(client),
Ok(Err(e)) => {
let error_msg = match &e {
crate::ssh::tokio_client::Error::KeyAuthFailed => {
"Authentication failed. The private key was rejected by the server.".to_string()
}
crate::ssh::tokio_client::Error::PasswordWrong => {
"Password authentication failed.".to_string()
}
crate::ssh::tokio_client::Error::ServerCheckFailed => {
"Host key verification failed. The server's host key was not recognized or has changed.".to_string()
}
crate::ssh::tokio_client::Error::KeyInvalid(key_err) => {
format!("Failed to load SSH key: {key_err}. Please check the key file format and passphrase.")
}
crate::ssh::tokio_client::Error::AgentConnectionFailed => {
"Failed to connect to SSH agent. Please ensure SSH_AUTH_SOCK is set and the agent is running.".to_string()
}
crate::ssh::tokio_client::Error::AgentNoIdentities => {
"SSH agent has no identities. Please add your key to the agent using 'ssh-add'.".to_string()
}
crate::ssh::tokio_client::Error::AgentAuthenticationFailed => {
"SSH agent authentication failed.".to_string()
}
crate::ssh::tokio_client::Error::SshError(ssh_err) => {
format!("SSH connection error: {ssh_err}")
}
_ => {
format!("Failed to connect: {e}")
}
};
Err(anyhow::anyhow!(error_msg).context(e))
}
Err(_) => Err(anyhow::anyhow!(
"Connection timeout after {} seconds. \
Please check if the host is reachable and SSH service is running.",
connect_timeout.as_secs()
)),
};
const MIN_AUTH_DURATION: Duration = Duration::from_millis(500);
let elapsed = start_time.elapsed();
if elapsed < MIN_AUTH_DURATION {
tokio::time::sleep(MIN_AUTH_DURATION - elapsed).await;
}
result
}
#[allow(clippy::too_many_arguments)]
pub(super) async fn connect_via_jump_hosts(
&self,
jump_hosts: &[crate::jump::parser::JumpHost],
auth_method: &AuthMethod,
strict_mode: StrictHostKeyChecking,
key_path: Option<&Path>,
use_agent: bool,
use_password: bool,
connect_timeout_seconds: Option<u64>,
ssh_connection_config: Option<&SshConnectionConfig>,
) -> Result<Client> {
let connect_timeout =
Duration::from_secs(connect_timeout_seconds.unwrap_or(SSH_CONNECT_TIMEOUT_SECS));
let mut chain = JumpHostChain::new(jump_hosts.to_vec())
.with_connect_timeout(connect_timeout)
.with_command_timeout(Duration::from_secs(300));
if let Some(cfg) = ssh_connection_config {
chain = chain.with_ssh_connection_config(cfg.clone());
}
let connection = chain
.connect(
&self.host,
self.port,
&self.username,
auth_method.clone(),
key_path,
Some(strict_mode),
use_agent,
use_password,
)
.await
.with_context(|| {
format!(
"Failed to establish jump host connection to {}:{}",
self.host, self.port
)
})?;
tracing::info!(
"Jump host connection established: {}",
connection.jump_info.path_description()
);
Ok(connection.client)
}
#[allow(clippy::too_many_arguments)]
pub(super) async fn establish_connection(
&self,
auth_method: &AuthMethod,
strict_mode: StrictHostKeyChecking,
jump_hosts_spec: Option<&str>,
key_path: Option<&Path>,
use_agent: bool,
use_password: bool,
connect_timeout_seconds: Option<u64>,
ssh_connection_config: Option<&SshConnectionConfig>,
) -> Result<Client> {
if let Some(jump_spec) = jump_hosts_spec {
let jump_hosts = parse_jump_hosts(jump_spec).with_context(|| {
format!("Failed to parse jump host specification: '{jump_spec}'")
})?;
if jump_hosts.is_empty() {
tracing::debug!("No valid jump hosts found, using direct connection");
self.connect_direct(
auth_method,
strict_mode,
connect_timeout_seconds,
ssh_connection_config,
)
.await
} else {
tracing::info!(
"Connecting to {}:{} via {} jump host(s): {}",
self.host,
self.port,
jump_hosts.len(),
jump_hosts
.iter()
.map(|j| j.to_string())
.collect::<Vec<_>>()
.join(" -> ")
);
self.connect_via_jump_hosts(
&jump_hosts,
auth_method,
strict_mode,
key_path,
use_agent,
use_password,
connect_timeout_seconds,
ssh_connection_config,
)
.await
}
} else {
tracing::debug!("Using direct connection (no jump hosts)");
self.connect_direct(
auth_method,
strict_mode,
connect_timeout_seconds,
ssh_connection_config,
)
.await
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use tempfile::TempDir;

    /// Snapshot of selected environment variables that is restored on drop.
    ///
    /// Tests in this module mutate process-wide variables such as `HOME` and
    /// `SSH_AUTH_SOCK`. Restoring them from `Drop` brings the original values
    /// back even when an `unwrap` or assertion panics mid-test. A variable that
    /// was originally unset is removed again, rather than left pointing at a
    /// deleted temporary directory.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    impl EnvGuard {
        /// Records the current value (or absence) of each variable in `keys`.
        fn capture(keys: &[&'static str]) -> Self {
            let saved = keys
                .iter()
                .map(|&key| (key, std::env::var_os(key)))
                .collect();
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, value) in &self.saved {
                match value {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    #[tokio::test]
    async fn test_determine_auth_method_with_key() {
        let temp_dir = TempDir::new().unwrap();
        let key_path = temp_dir.path().join("test_key");
        std::fs::write(&key_path, "fake key content").unwrap();

        let client = SshClient::new("test.com".to_string(), 22, "user".to_string());
        let auth = client
            .determine_auth_method(
                Some(&key_path),
                false,
                false,
                #[cfg(target_os = "macos")]
                false,
            )
            .await
            .unwrap();

        match auth {
            AuthMethod::PrivateKeyFile { key_file_path, .. } => {
                assert!(key_file_path.is_absolute());
            }
            _ => panic!("Expected PrivateKeyFile auth method"),
        }
    }

    /// With a (fake) agent socket and a default key present, either the agent or
    /// the key file is an acceptable choice.
    #[cfg(any(target_os = "macos", target_os = "linux"))]
    #[tokio::test]
    #[serial]
    async fn test_determine_auth_method_with_agent() {
        use std::os::unix::net::UnixListener;

        // Declared first so it is dropped last, after the temp dir is gone.
        let _env = EnvGuard::capture(&["SSH_AUTH_SOCK", "HOME"]);

        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("ssh-agent.sock");
        let _listener = UnixListener::bind(&socket_path).unwrap();

        let ssh_dir = temp_dir.path().join(".ssh");
        std::fs::create_dir_all(&ssh_dir).unwrap();
        let key_content =
            "-----BEGIN PRIVATE KEY-----\nfake key content\n-----END PRIVATE KEY-----";
        std::fs::write(ssh_dir.join("id_rsa"), key_content).unwrap();

        std::env::set_var("SSH_AUTH_SOCK", socket_path.to_str().unwrap());
        std::env::set_var("HOME", temp_dir.path());

        let client = SshClient::new("test.com".to_string(), 22, "user".to_string());
        let auth = client
            .determine_auth_method(
                None,
                true,
                false,
                #[cfg(target_os = "macos")]
                false,
            )
            .await
            .unwrap();

        assert!(
            matches!(auth, AuthMethod::Agent | AuthMethod::PrivateKeyFile { .. }),
            "Expected Agent or PrivateKeyFile auth method"
        );
    }

    /// NOTE(review): password authentication is not exercised here (presumably
    /// because it needs an interactive prompt). This only checks that a client
    /// can be constructed.
    #[test]
    fn test_determine_auth_method_with_password() {
        let _client = SshClient::new("test.com".to_string(), 22, "user".to_string());
    }

    /// With no explicit key and no agent, a default key under `$HOME/.ssh` is used.
    #[tokio::test]
    #[serial]
    async fn test_determine_auth_method_fallback_to_default() {
        // Declared first so it is dropped last, after the temp dir is gone.
        let _env = EnvGuard::capture(&["HOME", "SSH_AUTH_SOCK"]);

        let temp_dir = TempDir::new().unwrap();
        let ssh_dir = temp_dir.path().join(".ssh");
        std::fs::create_dir_all(&ssh_dir).unwrap();
        let default_key = ssh_dir.join("id_rsa");
        std::fs::write(&default_key, "fake key").unwrap();

        std::env::set_var("HOME", temp_dir.path().to_str().unwrap());
        std::env::remove_var("SSH_AUTH_SOCK");

        let client = SshClient::new("test.com".to_string(), 22, "user".to_string());
        let auth = client
            .determine_auth_method(
                None,
                false,
                false,
                #[cfg(target_os = "macos")]
                false,
            )
            .await
            .unwrap();

        match auth {
            AuthMethod::PrivateKeyFile { key_file_path, .. } => {
                assert!(key_file_path.is_absolute());
            }
            _ => panic!("Expected PrivateKeyFile auth method"),
        }
    }
}