use anyhow::{Context, Result};
use crossterm::terminal;
use russh::Channel;
use russh::client::Msg;
use std::io::{self, Write};
use tokio::time::{Duration, timeout};
use zeroize::Zeroizing;
use crate::jump::{JumpHostChain, parse_jump_hosts};
use crate::node::Node;
use crate::ssh::{
known_hosts::get_check_method,
tokio_client::{AuthMethod, Client, Error as SshError, ServerCheckMethod, SshConnectionConfig},
};
use super::types::{InteractiveCommand, NodeSession};
impl InteractiveCommand {
#[allow(clippy::too_many_arguments)]
async fn establish_connection(
addr: (&str, u16),
username: &str,
auth_method: AuthMethod,
check_method: ServerCheckMethod,
host: &str,
port: u16,
allow_password_fallback: bool,
ssh_config: &SshConnectionConfig,
) -> Result<Client> {
const SSH_CONNECT_TIMEOUT_SECS: u64 = 30;
let connect_timeout = Duration::from_secs(SSH_CONNECT_TIMEOUT_SECS);
const RATE_LIMIT_DELAY: Duration = Duration::from_millis(100);
tokio::time::sleep(RATE_LIMIT_DELAY).await;
let start_time = std::time::Instant::now();
let result = timeout(
connect_timeout,
Client::connect_with_ssh_config(addr, username, auth_method, check_method.clone(), ssh_config),
)
.await
.with_context(|| {
format!(
"Connection timeout: Failed to connect to {host}:{port} after {SSH_CONNECT_TIMEOUT_SECS} seconds"
)
})?;
let result = match result {
Err(ref err)
if allow_password_fallback
&& atty::is(atty::Stream::Stdin)
&& is_auth_error_for_password_fallback(err) =>
{
tracing::debug!(
"SSH authentication failed for {username}@{host}:{port} ({err}), attempting password fallback"
);
let password = Self::prompt_password(username, host).await?;
let password_auth = AuthMethod::with_password(&password);
tokio::time::sleep(Duration::from_millis(500)).await;
timeout(
connect_timeout,
Client::connect_with_ssh_config(addr, username, password_auth, check_method, ssh_config),
)
.await
.with_context(|| {
format!(
"Connection timeout: Failed to connect to {host}:{port} after {SSH_CONNECT_TIMEOUT_SECS} seconds"
)
})?
.with_context(|| format!("SSH connection failed to {host}:{port}"))
}
other => other.with_context(|| format!("SSH connection failed to {host}:{port}")),
};
const MIN_AUTH_DURATION: Duration = Duration::from_millis(500);
let elapsed = start_time.elapsed();
if elapsed < MIN_AUTH_DURATION {
tokio::time::sleep(MIN_AUTH_DURATION - elapsed).await;
}
result
}
async fn prompt_password(username: &str, host: &str) -> Result<Zeroizing<String>> {
let username = username.to_string();
let host = host.to_string();
tokio::task::spawn_blocking(move || {
let password = Zeroizing::new(
rpassword::prompt_password(format!("{username}@{host}'s password: "))
.with_context(|| "Failed to read password")?,
);
Ok(password)
})
.await
.with_context(|| "Password prompt task failed")?
}
pub(super) async fn determine_auth_method(&self, node: &Node) -> Result<AuthMethod> {
let mut auth_ctx = crate::ssh::AuthContext::new(node.username.clone(), node.host.clone())
.with_context(|| {
format!("Invalid credentials for {}@{}", node.username, node.host)
})?;
if let Some(ref path) = self.key_path {
auth_ctx = auth_ctx
.with_key_path(Some(path.clone()))
.with_context(|| format!("Invalid SSH key path: {path:?}"))?;
}
auth_ctx = auth_ctx
.with_agent(self.use_agent)
.with_password(self.use_password)
.with_password_fallback(!self.use_password);
#[cfg(target_os = "macos")]
{
auth_ctx = auth_ctx.with_keychain(self.use_keychain);
}
auth_ctx.determine_method().await
}
pub(super) fn select_nodes_to_connect(&self) -> Result<Vec<Node>> {
if self.single_node {
if self.nodes.is_empty() {
anyhow::bail!("No nodes available for connection");
}
if self.nodes.len() == 1 {
Ok(vec![self.nodes[0].clone()])
} else {
println!("Available nodes:");
for (i, node) in self.nodes.iter().enumerate() {
println!(" [{}] {}", i + 1, node);
}
print!("Select node (1-{}): ", self.nodes.len());
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
let selection: usize = input.trim().parse().context("Invalid node selection")?;
if selection == 0 || selection > self.nodes.len() {
anyhow::bail!("Invalid node selection");
}
Ok(vec![self.nodes[selection - 1].clone()])
}
} else {
Ok(self.nodes.clone())
}
}
pub(super) async fn connect_to_node(&self, node: Node) -> Result<NodeSession> {
let auth_method = self.determine_auth_method(&node).await?;
let check_method = get_check_method(self.strict_mode);
let addr = (node.host.as_str(), node.port);
let client = if let Some(ref jump_spec) = self.jump_hosts {
let jump_hosts = parse_jump_hosts(jump_spec).with_context(|| {
format!("Failed to parse jump host specification: '{jump_spec}'")
})?;
if jump_hosts.is_empty() {
tracing::debug!("No valid jump hosts found, using direct connection");
Self::establish_connection(
addr,
&node.username,
auth_method.clone(),
check_method.clone(),
&node.host,
node.port,
!self.use_password, &self.ssh_connection_config,
)
.await?
} else {
tracing::info!(
"Connecting to {}:{} via {} jump host(s) for interactive session",
node.host,
node.port,
jump_hosts.len()
);
const MAX_TIMEOUT_SECS: u64 = 600; const BASE_TIMEOUT: u64 = 30;
const PER_HOP_TIMEOUT: u64 = 15;
let hop_count = jump_hosts.len();
let adjusted_timeout = Duration::from_secs(
BASE_TIMEOUT
.saturating_add(PER_HOP_TIMEOUT.saturating_mul(hop_count as u64))
.min(MAX_TIMEOUT_SECS),
);
let chain = JumpHostChain::new(jump_hosts)
.with_connect_timeout(adjusted_timeout)
.with_command_timeout(Duration::from_secs(300))
.with_ssh_connection_config(self.ssh_connection_config.clone());
let connection = timeout(
adjusted_timeout,
chain.connect(
&node.host,
node.port,
&node.username,
auth_method.clone(),
self.key_path.as_deref(),
Some(self.strict_mode),
self.use_agent,
self.use_password,
),
)
.await
.with_context(|| {
format!(
"Connection timeout: Failed to connect to {}:{} via jump hosts after {} seconds",
node.host, node.port, adjusted_timeout.as_secs()
)
})?
.with_context(|| {
format!(
"Failed to establish jump host connection to {}:{}",
node.host, node.port
)
})?;
tracing::info!(
"Jump host connection established for interactive session: {}",
connection.jump_info.path_description()
);
connection.client
}
} else {
tracing::debug!("Using direct connection (no jump hosts)");
Self::establish_connection(
addr,
&node.username,
auth_method,
check_method,
&node.host,
node.port,
!self.use_password, &self.ssh_connection_config,
)
.await?
};
let (width, height) = terminal::size().unwrap_or((80, 24));
let channel = client
.request_interactive_shell("xterm-256color", u32::from(width), u32::from(height))
.await
.context("Failed to request interactive shell")?;
let working_dir = if let Some(ref dir) = self.work_dir {
let cmd = format!("cd {dir} && pwd\n");
channel.data(cmd.as_bytes()).await?;
dir.clone()
} else {
let pwd_cmd = b"pwd\n";
channel.data(&pwd_cmd[..]).await?;
String::from("~")
};
Ok(NodeSession::new(node, client, channel, working_dir))
}
pub(super) async fn connect_to_node_pty(&self, node: Node) -> Result<Channel<Msg>> {
let auth_method = self.determine_auth_method(&node).await?;
let check_method = get_check_method(self.strict_mode);
let addr = (node.host.as_str(), node.port);
let client = if let Some(ref jump_spec) = self.jump_hosts {
let jump_hosts = parse_jump_hosts(jump_spec).with_context(|| {
format!("Failed to parse jump host specification: '{jump_spec}'")
})?;
if jump_hosts.is_empty() {
tracing::debug!("No valid jump hosts found, using direct connection for PTY");
Self::establish_connection(
addr,
&node.username,
auth_method.clone(),
check_method.clone(),
&node.host,
node.port,
!self.use_password, &self.ssh_connection_config,
)
.await?
} else {
tracing::info!(
"Connecting to {}:{} via {} jump host(s) for PTY session",
node.host,
node.port,
jump_hosts.len()
);
const MAX_TIMEOUT_SECS: u64 = 600; const BASE_TIMEOUT: u64 = 30;
const PER_HOP_TIMEOUT: u64 = 15;
let hop_count = jump_hosts.len();
let adjusted_timeout = Duration::from_secs(
BASE_TIMEOUT
.saturating_add(PER_HOP_TIMEOUT.saturating_mul(hop_count as u64))
.min(MAX_TIMEOUT_SECS),
);
let chain = JumpHostChain::new(jump_hosts)
.with_connect_timeout(adjusted_timeout)
.with_command_timeout(Duration::from_secs(300))
.with_ssh_connection_config(self.ssh_connection_config.clone());
let connection = timeout(
adjusted_timeout,
chain.connect(
&node.host,
node.port,
&node.username,
auth_method.clone(),
self.key_path.as_deref(),
Some(self.strict_mode),
self.use_agent,
self.use_password,
),
)
.await
.with_context(|| {
format!(
"Connection timeout: Failed to connect to {}:{} via jump hosts after {} seconds",
node.host, node.port, adjusted_timeout.as_secs()
)
})?
.with_context(|| {
format!(
"Failed to establish jump host connection to {}:{}",
node.host, node.port
)
})?;
tracing::info!(
"Jump host connection established for PTY session: {}",
connection.jump_info.path_description()
);
connection.client
}
} else {
tracing::debug!("Using direct connection for PTY (no jump hosts)");
Self::establish_connection(
addr,
&node.username,
auth_method,
check_method,
&node.host,
node.port,
!self.use_password, &self.ssh_connection_config,
)
.await?
};
let (width, height) = crate::pty::utils::get_terminal_size().unwrap_or((80, 24));
let channel = client
.request_interactive_shell(&self.pty_config.term_type, width, height)
.await
.context("Failed to request interactive shell with PTY")?;
Ok(channel)
}
}
pub fn is_auth_error_for_password_fallback(error: &SshError) -> bool {
match error {
SshError::KeyAuthFailed
| SshError::AgentAuthenticationFailed
| SshError::AgentNoIdentities
| SshError::AgentConnectionFailed
| SshError::AgentRequestIdentitiesFailed => true,
SshError::SshError(russh::Error::Disconnect) => {
tracing::debug!(
"Treating SshError(Disconnect) as auth failure - server likely \
disconnected after key authentication rejection"
);
true
}
SshError::SshError(russh::Error::RecvError) => {
tracing::debug!(
"Treating SshError(RecvError) as auth failure - server likely \
closed connection during authentication"
);
true
}
_ => false,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `error` is classified as eligible for password fallback.
    fn assert_triggers_fallback(error: SshError, reason: &str) {
        assert!(is_auth_error_for_password_fallback(&error), "{reason}");
    }

    /// Asserts that `error` is NOT classified as eligible for password fallback.
    fn assert_no_fallback(error: SshError, reason: &str) {
        assert!(!is_auth_error_for_password_fallback(&error), "{reason}");
    }

    // --- Errors that should trigger a password prompt -------------------------------

    #[test]
    fn test_key_auth_failed_triggers_password_fallback() {
        assert_triggers_fallback(
            SshError::KeyAuthFailed,
            "KeyAuthFailed should trigger password fallback",
        );
    }

    #[test]
    fn test_agent_auth_failed_triggers_password_fallback() {
        assert_triggers_fallback(
            SshError::AgentAuthenticationFailed,
            "AgentAuthenticationFailed should trigger password fallback",
        );
    }

    #[test]
    fn test_agent_no_identities_triggers_password_fallback() {
        assert_triggers_fallback(
            SshError::AgentNoIdentities,
            "AgentNoIdentities should trigger password fallback",
        );
    }

    #[test]
    fn test_agent_connection_failed_triggers_password_fallback() {
        assert_triggers_fallback(
            SshError::AgentConnectionFailed,
            "AgentConnectionFailed should trigger password fallback",
        );
    }

    #[test]
    fn test_agent_request_identities_failed_triggers_password_fallback() {
        assert_triggers_fallback(
            SshError::AgentRequestIdentitiesFailed,
            "AgentRequestIdentitiesFailed should trigger password fallback",
        );
    }

    #[test]
    fn test_ssh_disconnect_triggers_password_fallback() {
        assert_triggers_fallback(
            SshError::SshError(russh::Error::Disconnect),
            "SshError(Disconnect) should trigger password fallback - \
             server may disconnect after key auth rejection",
        );
    }

    #[test]
    fn test_ssh_recv_error_triggers_password_fallback() {
        assert_triggers_fallback(
            SshError::SshError(russh::Error::RecvError),
            "SshError(RecvError) should trigger password fallback - \
             server may close connection during authentication",
        );
    }

    // --- Errors that must NOT trigger a password prompt ---------------------------

    #[test]
    fn test_password_wrong_does_not_trigger_fallback() {
        assert_no_fallback(
            SshError::PasswordWrong,
            "PasswordWrong should NOT trigger password fallback (already tried password)",
        );
    }

    #[test]
    fn test_server_check_failed_does_not_trigger_fallback() {
        assert_no_fallback(
            SshError::ServerCheckFailed,
            "ServerCheckFailed should NOT trigger password fallback (host key issue)",
        );
    }

    #[test]
    fn test_io_error_does_not_trigger_fallback() {
        let refused = std::io::Error::new(
            std::io::ErrorKind::ConnectionRefused,
            "connection refused",
        );
        assert_no_fallback(
            SshError::IoError(refused),
            "IoError should NOT trigger password fallback (network issue)",
        );
    }

    #[test]
    fn test_keyboard_interactive_auth_failed_does_not_trigger_fallback() {
        assert_no_fallback(
            SshError::KeyboardInteractiveAuthFailed,
            "KeyboardInteractiveAuthFailed should NOT trigger password fallback",
        );
    }

    #[test]
    fn test_ssh_hup_does_not_trigger_fallback() {
        assert_no_fallback(
            SshError::SshError(russh::Error::HUP),
            "SshError(HUP) should NOT trigger password fallback - \
             this indicates remote closed connection, not auth failure",
        );
    }

    #[test]
    fn test_ssh_connection_timeout_does_not_trigger_fallback() {
        assert_no_fallback(
            SshError::SshError(russh::Error::ConnectionTimeout),
            "SshError(ConnectionTimeout) should NOT trigger password fallback - \
             this is a network issue, not auth failure",
        );
    }

    #[test]
    fn test_ssh_not_authenticated_does_not_trigger_fallback() {
        assert_no_fallback(
            SshError::SshError(russh::Error::NotAuthenticated),
            "SshError(NotAuthenticated) should NOT trigger password fallback - \
             this means auth hasn't been attempted yet",
        );
    }
}