use crate::services::process_hidden::HideWindow;
use crate::services::remote::channel::AgentChannel;
use crate::services::remote::protocol::AgentResponse;
use crate::services::remote::AGENT_SOURCE;
use std::path::PathBuf;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
/// Errors produced while starting or talking to an agent over SSH.
#[derive(Debug, thiserror::Error)]
pub enum SshError {
    /// An I/O error, converted automatically from `std::io::Error` (so any
    /// `?` applied to an I/O result in this module yields this variant).
    #[error("Failed to spawn SSH process ({0}). Is the `ssh` command installed and in your PATH?")]
    SpawnFailed(#[from] std::io::Error),

    /// The agent could not be brought up; the payload is a human-readable
    /// explanation.
    #[error("Agent failed to start: {0}")]
    AgentStartFailed(String),

    /// The agent reported a protocol version other than the one expected.
    #[error("Protocol version mismatch: expected {expected}, got {got}")]
    VersionMismatch { expected: u32, got: u32 },

    /// The connection was closed.
    #[error("Connection closed")]
    ConnectionClosed,

    /// Authentication failed.
    #[error("Authentication failed")]
    AuthenticationFailed,
}
/// Everything needed to address an SSH destination.
#[derive(Debug, Clone)]
pub struct ConnectionParams {
    /// Remote login name (the part before `@`).
    pub user: String,
    /// Remote host (the part after `@`, minus any recognised port suffix).
    pub host: String,
    /// Explicit port, if one was given.
    pub port: Option<u16>,
    /// Private key file passed to `ssh -i`, if any. `parse` never sets this.
    pub identity_file: Option<PathBuf>,
}

impl ConnectionParams {
    /// Parses `user@host` or `user@host:port`.
    ///
    /// The text after the *last* `:` is treated as a port only when it parses
    /// as a `u16`; otherwise the whole input is taken as `user@host` and the
    /// colon stays part of the host. The input is split at the first `@`.
    ///
    /// Returns `None` when there is no `@`, or when the user or host part is
    /// empty. `identity_file` is always `None` in the result.
    pub fn parse(s: &str) -> Option<Self> {
        // Peel off a trailing ":<u16>" if present; fall back to the full input.
        let (target, port) = s
            .rsplit_once(':')
            .and_then(|(head, tail)| tail.parse::<u16>().ok().map(|port| (head, Some(port))))
            .unwrap_or((s, None));

        match target.split_once('@') {
            Some((user, host)) if !user.is_empty() && !host.is_empty() => Some(Self {
                user: user.to_string(),
                host: host.to_string(),
                port,
                identity_file: None,
            }),
            _ => None,
        }
    }
}

impl std::fmt::Display for ConnectionParams {
    /// Formats as `user@host`, followed by `:port` when a port is set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}@{}", self.user, self.host)?;
        match self.port {
            Some(port) => write!(f, ":{}", port),
            None => Ok(()),
        }
    }
}
/// An established SSH session to a remote agent.
pub struct SshConnection {
    /// Handle to the spawned `ssh` child process.
    process: Child,
    /// Shared channel used to exchange messages with the agent.
    channel: std::sync::Arc<AgentChannel>,
    /// The parameters this connection was opened with.
    params: ConnectionParams,
}
impl SshConnection {
    /// Spawns `ssh`, bootstraps the Python agent on the remote host and
    /// performs the ready/version handshake.
    ///
    /// The remote command is `python3 -u -c` with a one-liner that reads
    /// `AGENT_SOURCE.len()` units from stdin and `exec`s them; the agent
    /// source is then written to the child's stdin. The first line the agent
    /// prints on stdout must be a JSON `AgentResponse` that is a ready message
    /// carrying the expected protocol version (a missing version counts as 0).
    ///
    /// `ssh` is run with `StrictHostKeyChecking=accept-new` and with stderr
    /// inherited from this process.
    ///
    /// # Errors
    ///
    /// * [`SshError::SpawnFailed`] — spawning `ssh` failed, or writing the
    ///   agent source to its stdin returned an I/O error.
    /// * [`SshError::AgentStartFailed`] — a stdio pipe was missing, stdout hit
    ///   EOF or a read error before the ready line, the line was not valid
    ///   JSON, or it was not a ready message.
    /// * [`SshError::VersionMismatch`] — the agent speaks a different
    ///   protocol version.
    pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
        let mut cmd = Command::new("ssh");
        cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
        if let Some(port) = params.port {
            cmd.arg("-p").arg(port.to_string());
        }
        if let Some(ref identity) = params.identity_file {
            cmd.arg("-i").arg(identity);
        }
        cmd.arg(format!("{}@{}", params.user, params.host));

        // NOTE(review): `len()` is a byte count, while Python's text-mode
        // `sys.stdin.read(n)` counts characters. The two agree only if
        // AGENT_SOURCE is pure ASCII — confirm.
        let agent_len = AGENT_SOURCE.len();
        let bootstrap = format!(
            "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
            agent_len
        );
        cmd.arg(bootstrap);

        cmd.stdin(Stdio::piped());
        cmd.stdout(Stdio::piped());
        cmd.stderr(Stdio::inherit());
        cmd.hide_window();
        // If the handshake below fails, `child` is dropped before an
        // `SshConnection` (whose `Drop` kills it) exists. Make sure the
        // `ssh` process does not outlive that failure.
        cmd.kill_on_drop(true);

        let mut child = cmd.spawn()?;
        let mut stdin = child
            .stdin
            .take()
            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;

        // Ship the agent; the remote bootstrap reads exactly this much.
        stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
        stdin.flush().await?;

        let mut reader = BufReader::new(stdout);
        let mut ready_line = String::new();
        match reader.read_line(&mut ready_line).await {
            // EOF before any output: ssh (or the remote command) exited.
            // Turn the exit status into a useful diagnostic.
            Ok(0) => {
                return Err(ssh_eof_error(&mut child, &params).await);
            }
            Ok(_) => {}
            Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
        }

        let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
            SshError::AgentStartFailed(format!(
                "invalid ready message '{}': {}",
                ready_line.trim(),
                e
            ))
        })?;
        if !ready.is_ready() {
            return Err(SshError::AgentStartFailed(
                "agent did not send ready message".to_string(),
            ));
        }

        let version = ready.version.unwrap_or(0);
        if version != crate::services::remote::protocol::PROTOCOL_VERSION {
            return Err(SshError::VersionMismatch {
                expected: crate::services::remote::protocol::PROTOCOL_VERSION,
                got: version,
            });
        }

        let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
        Ok(Self {
            process: child,
            channel,
            params,
        })
    }

    /// Returns a new shared handle to the agent channel.
    pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
        std::sync::Arc::clone(&self.channel)
    }

    /// Returns the parameters this connection was opened with.
    pub fn params(&self) -> &ConnectionParams {
        &self.params
    }

    /// Reports whether the underlying channel considers itself connected.
    pub fn is_connected(&self) -> bool {
        self.channel.is_connected()
    }

    /// Returns the connection parameters rendered via their `Display` impl.
    pub fn connection_string(&self) -> String {
        self.params.to_string()
    }
}
impl Drop for SshConnection {
    /// Asks the `ssh` child process to terminate when the connection is
    /// dropped.
    fn drop(&mut self) {
        // Best effort: a failure to kill the process is deliberately ignored.
        let _ = self.process.start_kill();
    }
}
/// Interval used by [`ReconnectConfig::default`].
const DEFAULT_RECONNECT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);

/// Settings for the background reconnect task.
pub struct ReconnectConfig {
    /// Sleep duration used both between connectivity checks and between
    /// reconnection attempts.
    pub interval: std::time::Duration,
}

impl Default for ReconnectConfig {
    /// Uses [`DEFAULT_RECONNECT_INTERVAL`] (five seconds).
    fn default() -> Self {
        let interval = DEFAULT_RECONNECT_INTERVAL;
        ReconnectConfig { interval }
    }
}
/// Spawns a background task that re-establishes the SSH transport for
/// `channel` whenever the channel reports it is disconnected.
///
/// Each reconnection attempt calls `establish_ssh_transport` with a clone of
/// `params` and hands the resulting reader/writer pair to the generic
/// [`spawn_reconnect_task_with`] loop, using the default [`ReconnectConfig`]
/// and the log label `"SSH remote"`.
///
/// Returns the `JoinHandle` of the spawned task.
pub fn spawn_reconnect_task(
    channel: std::sync::Arc<AgentChannel>,
    params: ConnectionParams,
) -> tokio::task::JoinHandle<()> {
    let connect_fn = move || {
        // Each attempt needs its own owned copy because the future is
        // `async move` and the closure may be called many times.
        let params = params.clone();
        async move {
            // Only the reader and writer are handed back; the `Child` handle
            // is not retained and is dropped when this future completes.
            let (reader, writer, _child) = establish_ssh_transport(&params).await?;
            let reader: Box<dyn tokio::io::AsyncBufRead + Unpin + Send> = Box::new(reader);
            let writer: Box<dyn tokio::io::AsyncWrite + Unpin + Send> = Box::new(writer);
            Ok::<_, SshError>((reader, writer))
        }
    };
    spawn_reconnect_task_with(
        channel,
        connect_fn,
        ReconnectConfig::default(),
        "SSH remote",
    )
}
/// Spawns a task that watches `channel` and swaps in a fresh transport
/// whenever the channel reports it is disconnected.
///
/// The task loops forever:
/// 1. While the channel is connected, sleep `config.interval` and re-check.
/// 2. Once it is not, log the loss and retry: sleep `config.interval`, stop
///    retrying if the channel has become connected in the meantime, otherwise
///    call `connect_fn`. On success the new reader/writer are installed via
///    `replace_transport`; on failure the error is logged at debug level and
///    the retry loop continues.
///
/// `label` prefixes every log message. Returns the task's `JoinHandle`.
pub fn spawn_reconnect_task_with<F, Fut>(
    channel: std::sync::Arc<AgentChannel>,
    connect_fn: F,
    config: ReconnectConfig,
    label: &'static str,
) -> tokio::task::JoinHandle<()>
where
    F: Fn() -> Fut + Send + 'static,
    Fut: std::future::Future<
            Output = Result<
                (
                    Box<dyn tokio::io::AsyncBufRead + Unpin + Send>,
                    Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
                ),
                SshError,
            >,
        > + Send,
{
    tokio::spawn(async move {
        let pause = config.interval;
        loop {
            // Idle while the channel is healthy.
            while channel.is_connected() {
                tokio::time::sleep(pause).await;
            }
            tracing::info!("{label}: connection lost, attempting reconnection...");

            loop {
                tokio::time::sleep(pause).await;
                // Something else may have restored the channel while we slept.
                if channel.is_connected() {
                    break;
                }
                let (reader, writer) = match connect_fn().await {
                    Ok(transport) => transport,
                    Err(e) => {
                        tracing::debug!("{label}: reconnection attempt failed: {e}");
                        continue;
                    }
                };
                tracing::info!("{label}: reconnected successfully");
                channel.replace_transport(reader, writer).await;
                break;
            }
        }
    })
}
/// Builds the error reported when the agent's stdout hit EOF before the
/// ready line arrived.
///
/// Waits up to five seconds for `child` to exit and turns the outcome into an
/// [`SshError::AgentStartFailed`] with a hint:
/// * exit code 255 and 127 get dedicated messages (connection failure and
///   missing `python3`, respectively);
/// * any other exit code, or termination without a code, is reported as such;
/// * a failure to obtain the status is reported with the underlying error;
/// * on timeout the process is asked to terminate (a failure to do so is
///   logged as a warning) and the timeout is reported.
async fn ssh_eof_error(child: &mut Child, params: &ConnectionParams) -> SshError {
    let grace_period = std::time::Duration::from_secs(5);
    let waited = tokio::time::timeout(grace_period, child.wait()).await;

    let exit_status = match waited {
        Ok(Ok(exit_status)) => exit_status,
        Ok(Err(e)) => {
            return SshError::AgentStartFailed(format!("failed to get SSH exit status: {}", e));
        }
        Err(_) => {
            // The process is still running after the grace period.
            if let Err(e) = child.start_kill() {
                tracing::warn!("Failed to kill timed-out SSH process: {}", e);
            }
            return SshError::AgentStartFailed(format!(
                "SSH process did not exit in time while connecting to {}",
                params
            ));
        }
    };

    let hint = match exit_status.code() {
        Some(255) => format!(
            "SSH could not connect to {}. Check that the host is \
             reachable, the hostname is correct, and your SSH \
             credentials are valid (exit code 255)",
            params
        ),
        Some(127) => format!(
            "python3 was not found on the remote host {}. \
             Ensure Python 3 is installed on the remote machine",
            params
        ),
        Some(code) => format!(
            "SSH process exited with code {} while connecting to {}",
            code, params
        ),
        None => format!(
            "SSH process was killed by a signal while connecting to {}",
            params
        ),
    };
    SshError::AgentStartFailed(hint)
}
async fn establish_ssh_transport(
params: &ConnectionParams,
) -> Result<
(
BufReader<tokio::process::ChildStdout>,
tokio::process::ChildStdin,
Child,
),
SshError,
> {
let mut cmd = Command::new("ssh");
cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
cmd.arg("-o").arg("BatchMode=yes");
if let Some(port) = params.port {
cmd.arg("-p").arg(port.to_string());
}
if let Some(ref identity) = params.identity_file {
cmd.arg("-i").arg(identity);
}
cmd.arg(format!("{}@{}", params.user, params.host));
let agent_len = AGENT_SOURCE.len();
let bootstrap = format!(
"python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
agent_len
);
cmd.arg(bootstrap);
cmd.stdin(Stdio::piped());
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::null()); cmd.hide_window();
let mut child = cmd.spawn()?;
let mut stdin = child
.stdin
.take()
.ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
stdin.flush().await?;
let mut reader = BufReader::new(stdout);
let mut ready_line = String::new();
match reader.read_line(&mut ready_line).await {
Ok(0) => {
return Err(ssh_eof_error(&mut child, params).await);
}
Ok(_) => {}
Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
}
let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
SshError::AgentStartFailed(format!(
"invalid ready message '{}': {}",
ready_line.trim(),
e
))
})?;
if !ready.is_ready() {
return Err(SshError::AgentStartFailed(
"agent did not send ready message".to_string(),
));
}
let version = ready.version.unwrap_or(0);
if version != crate::services::remote::protocol::PROTOCOL_VERSION {
return Err(SshError::VersionMismatch {
expected: crate::services::remote::protocol::PROTOCOL_VERSION,
got: version,
});
}
Ok((reader, stdin, child))
}
/// Runs the agent in a local `python3 -u -c <AGENT_SOURCE>` process and wraps
/// its stdio in an [`AgentChannel`].
///
/// The first stdout line must parse as a ready `AgentResponse`; no protocol
/// version check is performed here. stderr is piped but not read by this
/// function, and the `Child` handle is dropped on return.
///
/// # Errors
///
/// [`SshError::SpawnFailed`] for spawn or read I/O errors, and
/// [`SshError::AgentStartFailed`] for a missing pipe or an invalid or
/// non-ready first message.
#[doc(hidden)]
pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
    use tokio::process::Command as TokioCommand;

    let mut launcher = TokioCommand::new("python3");
    launcher.arg("-u").arg("-c").arg(AGENT_SOURCE);
    launcher
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());
    launcher.hide_window();
    let mut child = launcher.spawn()?;

    let missing_pipe =
        |name: &str| SshError::AgentStartFailed(format!("failed to get {}", name));
    let stdin = child.stdin.take().ok_or_else(|| missing_pipe("stdin"))?;
    let stdout = child.stdout.take().ok_or_else(|| missing_pipe("stdout"))?;

    let mut reader = BufReader::new(stdout);
    let mut first_line = String::new();
    reader.read_line(&mut first_line).await?;

    let ready: AgentResponse = serde_json::from_str(&first_line)
        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
    if ready.is_ready() {
        Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
    } else {
        Err(SshError::AgentStartFailed(
            "agent did not send ready message".to_string(),
        ))
    }
}
/// Like a plain local agent spawn, but builds the [`AgentChannel`] with
/// `AgentChannel::with_capacity`, passing `data_channel_capacity` through.
///
/// Runs `python3 -u -c <AGENT_SOURCE>` locally and requires the first stdout
/// line to parse as a ready `AgentResponse`; no protocol version check is
/// performed here. stderr is piped but not read by this function, and the
/// `Child` handle is dropped on return.
///
/// # Errors
///
/// [`SshError::SpawnFailed`] for spawn or read I/O errors, and
/// [`SshError::AgentStartFailed`] for a missing pipe or an invalid or
/// non-ready first message.
#[doc(hidden)]
pub async fn spawn_local_agent_with_capacity(
    data_channel_capacity: usize,
) -> Result<std::sync::Arc<AgentChannel>, SshError> {
    use tokio::process::Command as TokioCommand;

    let mut launcher = TokioCommand::new("python3");
    launcher.arg("-u").arg("-c").arg(AGENT_SOURCE);
    launcher
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());
    launcher.hide_window();
    let mut child = launcher.spawn()?;

    let missing_pipe =
        |name: &str| SshError::AgentStartFailed(format!("failed to get {}", name));
    let stdin = child.stdin.take().ok_or_else(|| missing_pipe("stdin"))?;
    let stdout = child.stdout.take().ok_or_else(|| missing_pipe("stdout"))?;

    let mut reader = BufReader::new(stdout);
    let mut first_line = String::new();
    reader.read_line(&mut first_line).await?;

    let ready: AgentResponse = serde_json::from_str(&first_line)
        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
    if ready.is_ready() {
        let channel = AgentChannel::with_capacity(reader, stdin, data_channel_capacity);
        Ok(std::sync::Arc::new(channel))
    } else {
        Err(SshError::AgentStartFailed(
            "agent did not send ready message".to_string(),
        ))
    }
}
/// Runs the agent in a local `python3 -u -c <AGENT_SOURCE>` process and
/// returns its raw transport: the buffered stdout reader and the stdin writer.
///
/// The first stdout line is consumed and must parse as a ready
/// `AgentResponse`; no protocol version check is performed here. stderr is
/// piped but not read by this function, and the `Child` handle is dropped on
/// return.
///
/// # Errors
///
/// [`SshError::SpawnFailed`] for spawn or read I/O errors, and
/// [`SshError::AgentStartFailed`] for a missing pipe or an invalid or
/// non-ready first message.
#[doc(hidden)]
pub async fn spawn_local_agent_transport() -> Result<
    (
        tokio::io::BufReader<tokio::process::ChildStdout>,
        tokio::process::ChildStdin,
    ),
    SshError,
> {
    use tokio::process::Command as TokioCommand;

    let mut launcher = TokioCommand::new("python3");
    launcher.arg("-u").arg("-c").arg(AGENT_SOURCE);
    launcher
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());
    launcher.hide_window();
    let mut child = launcher.spawn()?;

    let missing_pipe =
        |name: &str| SshError::AgentStartFailed(format!("failed to get {}", name));
    let stdin = child.stdin.take().ok_or_else(|| missing_pipe("stdin"))?;
    let stdout = child.stdout.take().ok_or_else(|| missing_pipe("stdout"))?;

    let mut reader = BufReader::new(stdout);
    let mut first_line = String::new();
    reader.read_line(&mut first_line).await?;

    let ready: AgentResponse = serde_json::from_str(&first_line)
        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
    if ready.is_ready() {
        Ok((reader, stdin))
    } else {
        Err(SshError::AgentStartFailed(
            "agent did not send ready message".to_string(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse` accepts `user@host` with an optional numeric `:port` and
    /// rejects inputs that lack a user or a host.
    #[test]
    fn test_parse_connection_params() {
        let bare = ConnectionParams::parse("user@host").unwrap();
        assert_eq!(bare.user, "user");
        assert_eq!(bare.host, "host");
        assert_eq!(bare.port, None);

        let with_port = ConnectionParams::parse("user@host:22").unwrap();
        assert_eq!(with_port.user, "user");
        assert_eq!(with_port.host, "host");
        assert_eq!(with_port.port, Some(22));

        for rejected in ["hostonly", "@host", "user@"] {
            assert!(ConnectionParams::parse(rejected).is_none());
        }
    }

    /// `Display` renders `user@host`, appending `:port` only when one is set.
    #[test]
    fn test_connection_string() {
        let build = |user: &str, host: &str, port: Option<u16>| ConnectionParams {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        };

        let alice = build("alice", "example.com", None);
        assert_eq!(alice.to_string(), "alice@example.com");

        let bob = build("bob", "server.local", Some(2222));
        assert_eq!(bob.to_string(), "bob@server.local:2222");
    }
}