use crate::commands::cloudflared_access::{CloudflaredTcpOptions, CloudflaredTunnel};
use crate::commands::ssh_helpers::{resolve_or_prompt, resolve_or_prompt_password};
use crate::config::SshConfig;
use async_ssh2_tokio::client::{AuthMethod, Client, ServerCheckMethod};
use crossterm::terminal::{disable_raw_mode, enable_raw_mode, size as terminal_size};
use russh::{ChannelMsg, Sig};
use std::env;
use std::net::Ipv4Addr;
use std::path::PathBuf;
use tokio::io::{stderr, stdin, stdout, AsyncReadExt, AsyncWriteExt};
use tokio::signal;
use tracing::debug;
/// `TERM` value sent with the remote PTY request when neither an explicit
/// `term` option nor a non-blank `$TERM` environment variable is available.
const DEFAULT_TERM: &str = "xterm-256color";

/// Settings for `run_interactive_shell`.
///
/// `Debug` is implemented by hand (instead of derived) so that
/// `ssh_password` and `private_key_passphrase` are never written to logs.
#[derive(Clone)]
pub struct InteractiveShellOptions {
    /// SSH host, passed to `resolve_or_prompt` together with the saved config
    /// value. When tunnelling through cloudflared, the resolved host is only
    /// logged; the connection goes to the local tunnel port.
    pub ssh_host: Option<String>,
    /// Port for direct connections; not used when `cloudflared_hostname` is set
    /// (the tunnel's local port is used instead).
    pub ssh_port: u16,
    /// SSH user name, passed to `resolve_or_prompt` together with the saved
    /// config value.
    pub ssh_username: Option<String>,
    /// Password for password authentication; unused when `private_key` is set.
    /// Redacted in `Debug` output.
    pub ssh_password: Option<String>,
    /// Private key file; when set, key-file authentication is used instead of
    /// a password.
    pub private_key: Option<PathBuf>,
    /// Passphrase for `private_key`. Redacted in `Debug` output.
    pub private_key_passphrase: Option<String>,
    /// Remote command to run inside the PTY; `None` starts a login shell.
    pub command: Option<String>,
    /// Explicit `TERM` for the remote PTY. Blank values are ignored in favour
    /// of `$TERM`, then `DEFAULT_TERM`.
    pub term: Option<String>,
    /// Disables host-key verification (a non-blank `host_key` still wins).
    pub no_host_key_check: bool,
    /// Server public key to pin; takes precedence over every other host-key
    /// option. Blank values are ignored.
    pub host_key: Option<String>,
    /// Custom known_hosts file, used when no pin is given and
    /// `no_host_key_check` is false.
    pub known_hosts_file: Option<PathBuf>,
    /// When set, SSH is carried over a local cloudflared TCP tunnel to this
    /// hostname and connects to 127.0.0.1. Without `host_key` or
    /// `known_hosts_file`, host-key verification is then skipped.
    pub cloudflared_hostname: Option<String>,
    /// Optional path to the cloudflared binary used to start the tunnel.
    pub cloudflared_binary: Option<PathBuf>,
    /// Optional tunnel destination forwarded to the cloudflared options.
    pub cloudflared_destination: Option<String>,
}

impl std::fmt::Debug for InteractiveShellOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a secret is present without revealing its value.
        fn redact(secret: Option<&str>) -> Option<&'static str> {
            secret.map(|_| "<redacted>")
        }
        f.debug_struct("InteractiveShellOptions")
            .field("ssh_host", &self.ssh_host)
            .field("ssh_port", &self.ssh_port)
            .field("ssh_username", &self.ssh_username)
            .field("ssh_password", &redact(self.ssh_password.as_deref()))
            .field("private_key", &self.private_key)
            .field(
                "private_key_passphrase",
                &redact(self.private_key_passphrase.as_deref()),
            )
            .field("command", &self.command)
            .field("term", &self.term)
            .field("no_host_key_check", &self.no_host_key_check)
            .field("host_key", &self.host_key)
            .field("known_hosts_file", &self.known_hosts_file)
            .field("cloudflared_hostname", &self.cloudflared_hostname)
            .field("cloudflared_binary", &self.cloudflared_binary)
            .field("cloudflared_destination", &self.cloudflared_destination)
            .finish()
    }
}
/// Resolves connection settings, optionally starts a cloudflared tunnel,
/// connects over SSH and runs an interactive PTY session (a login shell, or
/// `options.command` when given).
///
/// Missing host, username and password values go through the
/// `resolve_or_prompt*` helpers. If any of them marks the config dirty, it is
/// saved before connecting. When `options.cloudflared_hostname` is set, SSH
/// connects to `127.0.0.1:<tunnel local port>` and `options.ssh_port` is not
/// used. The tunnel, if any, is shut down on both the success path and the
/// connection-failure path.
///
/// # Errors
/// Returns a formatted message if the config cannot be loaded or saved, a
/// prompt fails, the tunnel cannot start, the SSH connection fails, or the
/// remote session fails.
pub async fn run_interactive_shell(
    options: InteractiveShellOptions,
    debug_mode: bool,
) -> Result<(), String> {
    let mut config = SshConfig::load().map_err(|e| format!("Failed to load SSH config: {}", e))?;
    let mut config_dirty = false;
    let resolved_host = resolve_or_prompt(
        options.ssh_host.clone(),
        &mut config.host,
        "Enter SSH host: ",
        &mut config_dirty,
    )?;
    let resolved_username = resolve_or_prompt(
        options.ssh_username.clone(),
        &mut config.username,
        "Enter SSH username: ",
        &mut config_dirty,
    )?;
    let auth = resolve_auth_method(&options, &mut config, &mut config_dirty)?;

    // Persist any values that were filled in before attempting to connect.
    if config_dirty {
        config
            .save()
            .map_err(|e| format!("Failed to save SSH config: {}", e))?;
    }

    let mut cloudflared = match options.cloudflared_hostname.as_deref() {
        None => None,
        Some(hostname) => {
            let tunnel_options = CloudflaredTcpOptions {
                hostname: hostname.to_string(),
                listener: None,
                destination: options.cloudflared_destination.clone(),
                binary_path: options.cloudflared_binary.clone(),
            };
            Some(CloudflaredTunnel::start(tunnel_options, debug_mode).await?)
        }
    };

    let server_check = resolve_server_check(&options, cloudflared.is_some());
    let term = resolve_term_value(options.term.clone());

    let connect_result = match cloudflared.as_ref() {
        Some(tunnel) => {
            if debug_mode {
                debug!(
                    "SSH shell via cloudflared => tunnel host: {}, local port: {}, remote identity: {}, user: {}",
                    tunnel.hostname, tunnel.local_port, resolved_host, resolved_username
                );
            }
            // The tunnel exposes the remote SSH server on a loopback port.
            Client::connect(
                (Ipv4Addr::LOCALHOST, tunnel.local_port),
                resolved_username.as_str(),
                auth,
                server_check,
            )
            .await
        }
        None => {
            if debug_mode {
                debug!(
                    "SSH shell direct => host: {}, port: {}, user: {}",
                    resolved_host, options.ssh_port, resolved_username
                );
            }
            Client::connect(
                (resolved_host.as_str(), options.ssh_port),
                resolved_username.as_str(),
                auth,
                server_check,
            )
            .await
        }
    }
    .map_err(|e| format!("SSH connection failed: {}", e));

    let client = match connect_result {
        Ok(connected) => connected,
        Err(connect_error) => {
            // Do not leave the tunnel process running when the connect fails.
            if let Some(tunnel) = cloudflared.as_mut() {
                tunnel.shutdown().await;
            }
            return Err(connect_error);
        }
    };

    let session_outcome =
        run_channel_session(&client, &term, options.command.as_deref(), debug_mode).await;

    // Tear down in reverse order of setup: SSH session first, then the tunnel.
    if let Err(disconnect_error) = client.disconnect().await {
        debug!("Failed to cleanly disconnect SSH session: {}", disconnect_error);
    }
    if let Some(tunnel) = cloudflared.as_mut() {
        tunnel.shutdown().await;
    }
    session_outcome
}
/// Picks the SSH authentication method for this session.
///
/// A configured `options.private_key` selects key-file authentication, with
/// the optional `options.private_key_passphrase`; the password option is then
/// ignored. Otherwise a password is obtained through
/// `resolve_or_prompt_password`, which receives `config.password` and the
/// `config_dirty` flag.
///
/// # Errors
/// Propagates any error returned by `resolve_or_prompt_password`.
fn resolve_auth_method(
    options: &InteractiveShellOptions,
    config: &mut SshConfig,
    config_dirty: &mut bool,
) -> Result<AuthMethod, String> {
    match options.private_key.as_deref() {
        Some(key_path) => {
            let passphrase = options.private_key_passphrase.as_deref();
            Ok(AuthMethod::with_key_file(key_path, passphrase))
        }
        None => {
            let password = resolve_or_prompt_password(
                options.ssh_password.clone(),
                &mut config.password,
                "Enter SSH password: ",
                config_dirty,
            )?;
            Ok(AuthMethod::with_password(&password))
        }
    }
}
/// Chooses how the server's host key is verified.
///
/// Precedence, highest first:
/// 1. a non-blank pinned `options.host_key` (surrounding whitespace trimmed);
/// 2. `options.no_host_key_check` disables verification;
/// 3. an explicit `options.known_hosts_file`;
/// 4. no verification when tunnelling through cloudflared;
/// 5. otherwise the default known_hosts file.
fn resolve_server_check(
    options: &InteractiveShellOptions,
    using_cloudflared: bool,
) -> ServerCheckMethod {
    let pinned_key = options
        .host_key
        .as_deref()
        .map(str::trim)
        .filter(|key| !key.is_empty());
    if let Some(key) = pinned_key {
        return ServerCheckMethod::with_public_key(key);
    }

    match (
        options.no_host_key_check,
        options.known_hosts_file.as_deref(),
        using_cloudflared,
    ) {
        (true, _, _) => ServerCheckMethod::NoCheck,
        (false, Some(path), _) => {
            ServerCheckMethod::with_known_hosts_file(&path.to_string_lossy())
        }
        (false, None, true) => ServerCheckMethod::NoCheck,
        (false, None, false) => ServerCheckMethod::DefaultKnownHostsFile,
    }
}
/// Picks the `TERM` value for the remote PTY.
///
/// Order: the trimmed explicit value if non-blank, then the trimmed `$TERM`
/// if set and non-blank, then `DEFAULT_TERM`. `$TERM` is only read when no
/// usable explicit value was given.
fn resolve_term_value(explicit_term: Option<String>) -> String {
    if let Some(chosen) = explicit_term.and_then(normalize_optional_string) {
        return chosen;
    }
    let inherited = env::var("TERM").ok().and_then(normalize_optional_string);
    inherited.unwrap_or_else(|| DEFAULT_TERM.to_string())
}
/// Trims surrounding whitespace from `value`, returning `None` when nothing
/// is left.
fn normalize_optional_string(value: String) -> Option<String> {
    Some(value.trim())
        .filter(|trimmed| !trimmed.is_empty())
        .map(str::to_owned)
}
/// Opens a session channel on `client`, requests a remote PTY sized to the
/// local terminal, starts `command` (or a login shell when `None`), and then
/// relays terminal I/O until the channel finishes.
///
/// The local terminal is switched to raw mode only after the remote side has
/// been set up. The guard is held until `stream_interactive_channel` returns,
/// and raw mode is restored when it drops.
///
/// # Errors
/// Returns a formatted message if the channel cannot be opened, the PTY or
/// exec/shell request cannot be sent, raw mode cannot be enabled, or the
/// streaming loop fails.
async fn run_channel_session(
    client: &Client,
    term: &str,
    command: Option<&str>,
    debug_mode: bool,
) -> Result<(), String> {
    let channel = client
        .get_channel()
        .await
        .map_err(|e| format!("Failed to open SSH channel: {}", e))?;

    // Mirror the local terminal size; fall back to 120x32 if it is unavailable.
    let (cols, rows) = match terminal_size() {
        Ok(dimensions) => dimensions,
        Err(_) => (120, 32),
    };
    let (width, height) = (u32::from(cols), u32::from(rows));
    channel
        .request_pty(false, term, width, height, 0, 0, &[])
        .await
        .map_err(|e| format!("Failed to request remote PTY: {}", e))?;

    match command {
        Some(remote_command) => channel
            .exec(true, remote_command)
            .await
            .map_err(|e| format!("Failed to execute remote command: {}", e))?,
        None => channel
            .request_shell(true)
            .await
            .map_err(|e| format!("Failed to start remote shell: {}", e))?,
    }

    if debug_mode {
        debug!(
            "SSH channel ready => term: {}, cols: {}, rows: {}, command: {}",
            term,
            cols,
            rows,
            command.unwrap_or("<login-shell>")
        );
    }

    // Bound to a named variable (not `_`) so raw mode lasts for the whole stream.
    let _raw_mode = RawModeGuard::enable()?;
    stream_interactive_channel(channel).await
}
async fn stream_interactive_channel(
mut channel: russh::Channel<russh::client::Msg>,
) -> Result<(), String> {
let mut stdin = stdin();
let mut stdout = stdout();
let mut stderr = stderr();
let mut read_buf = [0_u8; 8192];
let mut exit_status: Option<u32> = None;
let mut stdin_closed = false;
let ctrl_c = signal::ctrl_c();
tokio::pin!(ctrl_c);
loop {
tokio::select! {
read_result = stdin.read(&mut read_buf), if !stdin_closed => {
match read_result {
Ok(0) => {
stdin_closed = true;
channel
.eof()
.await
.map_err(|e| format!("Failed to close remote stdin: {}", e))?;
}
Ok(read_len) => {
channel
.data(&read_buf[..read_len])
.await
.map_err(|e| format!("Failed to send SSH input: {}", e))?;
}
Err(err) => return Err(format!("Failed to read terminal input: {}", err)),
}
}
msg = channel.wait() => match msg {
Some(ChannelMsg::Data { ref data }) => {
stdout
.write_all(data)
.await
.map_err(|e| format!("Failed to write remote stdout: {}", e))?;
stdout
.flush()
.await
.map_err(|e| format!("Failed to flush stdout: {}", e))?;
}
Some(ChannelMsg::ExtendedData { ref data, ext }) => {
if ext == 1 {
stderr
.write_all(data)
.await
.map_err(|e| format!("Failed to write remote stderr: {}", e))?;
stderr
.flush()
.await
.map_err(|e| format!("Failed to flush stderr: {}", e))?;
}
}
Some(ChannelMsg::ExitStatus { exit_status: status }) => {
exit_status = Some(status);
}
Some(ChannelMsg::ExitSignal { signal_name, .. }) => {
if exit_status.is_none() {
exit_status = Some(signal_to_exit_status(&signal_name));
}
}
Some(ChannelMsg::Eof) | Some(ChannelMsg::Close) | None => {
break;
}
Some(_) => {}
},
_ = &mut ctrl_c => {
channel
.signal(Sig::INT)
.await
.map_err(|e| format!("Failed to send interrupt signal: {}", e))?;
}
}
}
match exit_status {
Some(0) | None => Ok(()),
Some(status) => Err(format!("Remote shell exited with status: {}", status)),
}
}
/// Maps a remote termination signal to the exit code a POSIX shell reports
/// for it, `128 + signal number`.
///
/// Only signals whose numbers are identical on Linux and BSD/macOS are
/// mapped. Platform-dependent signals (e.g. `USR1`) and custom signal names
/// yield a bare `128`.
fn signal_to_exit_status(signal: &Sig) -> u32 {
    const SIGNAL_EXIT_BASE: u32 = 128;
    let signal_number = match signal {
        Sig::HUP => 1,
        Sig::INT => 2,
        Sig::QUIT => 3,
        Sig::ILL => 4,
        Sig::ABRT => 6,
        Sig::FPE => 8,
        Sig::KILL => 9,
        Sig::SEGV => 11,
        Sig::PIPE => 13,
        Sig::ALRM => 14,
        Sig::TERM => 15,
        _ => 0,
    };
    SIGNAL_EXIT_BASE + signal_number
}
/// Guard that keeps the local terminal in raw mode while it is alive.
///
/// `RawModeGuard::enable` switches raw mode on. Dropping the guard switches it
/// back off on a best-effort basis, with failures ignored.
struct RawModeGuard;

impl RawModeGuard {
    /// Switches the local terminal into raw mode and returns the guard that
    /// undoes it on drop.
    ///
    /// # Errors
    /// Returns a formatted message if raw mode cannot be enabled.
    fn enable() -> Result<Self, String> {
        match enable_raw_mode() {
            Ok(()) => Ok(RawModeGuard),
            Err(e) => Err(format!("Failed to enable raw terminal mode: {}", e)),
        }
    }
}

impl Drop for RawModeGuard {
    fn drop(&mut self) {
        // `drop` cannot report errors; restoring the terminal is best effort.
        disable_raw_mode().ok();
    }
}
#[cfg(test)]
mod tests {
    use super::{
        normalize_optional_string, resolve_server_check, resolve_term_value,
        signal_to_exit_status, InteractiveShellOptions, DEFAULT_TERM,
    };
    use async_ssh2_tokio::client::ServerCheckMethod;
    use russh::Sig;
    use std::env;
    use std::path::PathBuf;

    /// Direct-SSH options with password auth and no host-key overrides.
    fn base_options() -> InteractiveShellOptions {
        InteractiveShellOptions {
            ssh_host: Some("prod.example.com".to_string()),
            ssh_port: 22,
            ssh_username: Some("deploy".to_string()),
            ssh_password: Some("secret".to_string()),
            private_key: None,
            private_key_passphrase: None,
            command: None,
            term: None,
            no_host_key_check: false,
            host_key: None,
            known_hosts_file: None,
            cloudflared_hostname: None,
            cloudflared_binary: None,
            cloudflared_destination: None,
        }
    }

    #[test]
    fn explicit_host_key_wins() {
        let mut options = base_options();
        options.host_key = Some("AAAAB3NzaC1yc2EAAAADAQABAAABAQDc".to_string());
        options.no_host_key_check = true;
        let server_check = resolve_server_check(&options, true);
        assert!(matches!(server_check, ServerCheckMethod::PublicKey(_)));
    }

    #[test]
    fn blank_host_key_is_ignored() {
        let mut options = base_options();
        options.host_key = Some("   ".to_string());
        let server_check = resolve_server_check(&options, false);
        assert!(matches!(server_check, ServerCheckMethod::DefaultKnownHostsFile));
    }

    #[test]
    fn no_host_key_check_disables_verification_for_direct_ssh() {
        let mut options = base_options();
        options.no_host_key_check = true;
        options.known_hosts_file = Some(PathBuf::from("/tmp/known_hosts"));
        let server_check = resolve_server_check(&options, false);
        assert!(matches!(server_check, ServerCheckMethod::NoCheck));
    }

    #[test]
    fn cloudflared_defaults_to_no_check_without_pin() {
        let options = base_options();
        let server_check = resolve_server_check(&options, true);
        assert!(matches!(server_check, ServerCheckMethod::NoCheck));
    }

    #[test]
    fn known_hosts_file_is_used_for_direct_ssh() {
        let mut options = base_options();
        options.known_hosts_file = Some(PathBuf::from("C:/Users/floris/.ssh/known_hosts"));
        let server_check = resolve_server_check(&options, false);
        assert!(matches!(server_check, ServerCheckMethod::KnownHostsFile(_)));
    }

    #[test]
    fn known_hosts_file_overrides_cloudflared_default() {
        let mut options = base_options();
        options.known_hosts_file = Some(PathBuf::from("/tmp/known_hosts"));
        let server_check = resolve_server_check(&options, true);
        assert!(matches!(server_check, ServerCheckMethod::KnownHostsFile(_)));
    }

    #[test]
    fn direct_ssh_defaults_to_user_known_hosts() {
        let options = base_options();
        let server_check = resolve_server_check(&options, false);
        assert!(matches!(server_check, ServerCheckMethod::DefaultKnownHostsFile));
    }

    #[test]
    fn resolve_term_prefers_explicit_value() {
        let term = resolve_term_value(Some("screen-256color".to_string()));
        assert_eq!(term, "screen-256color");
    }

    #[test]
    fn resolve_term_trims_explicit_value() {
        let term = resolve_term_value(Some("  vt100 ".to_string()));
        assert_eq!(term, "vt100");
    }

    #[test]
    fn resolve_term_falls_back_to_default() {
        // `var_os` (not `var`) so a non-UTF-8 TERM is still restored afterwards.
        let env_term = env::var_os("TERM");
        // SAFETY: NOTE(review) mutating the process environment is only sound
        // while no other thread reads or writes it concurrently. Tests run in
        // parallel, so this relies on no other test touching the environment.
        unsafe {
            env::remove_var("TERM");
        }
        let term = resolve_term_value(None);
        // Restore before asserting so a failure cannot leak the change.
        if let Some(value) = env_term {
            // SAFETY: same caveat as the `remove_var` call above.
            unsafe {
                env::set_var("TERM", value);
            }
        }
        assert_eq!(term, DEFAULT_TERM);
    }

    #[test]
    fn normalize_optional_string_trims_and_filters_empty_values() {
        assert_eq!(
            normalize_optional_string(" cloudflared.example.com ".to_string()),
            Some("cloudflared.example.com".to_string())
        );
        assert_eq!(normalize_optional_string(" ".to_string()), None);
    }

    #[test]
    fn signal_exit_status_follows_shell_convention() {
        assert_eq!(signal_to_exit_status(&Sig::INT), 130);
        assert_eq!(signal_to_exit_status(&Sig::TERM), 143);
        assert_eq!(signal_to_exit_status(&Sig::KILL), 137);
    }
}