Skip to main content

fresh/services/remote/
connection.rs

//! SSH connection management
//!
//! Handles spawning the `ssh` process and bootstrapping the Python agent over it.
4
5use crate::services::remote::channel::AgentChannel;
6use crate::services::remote::protocol::AgentResponse;
7use crate::services::remote::AGENT_SOURCE;
8use std::path::PathBuf;
9use std::process::Stdio;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12
13/// Error type for SSH connection
14#[derive(Debug, thiserror::Error)]
15pub enum SshError {
16    #[error("Failed to spawn SSH process: {0}")]
17    SpawnFailed(#[from] std::io::Error),
18
19    #[error("Agent failed to start: {0}")]
20    AgentStartFailed(String),
21
22    #[error("Protocol version mismatch: expected {expected}, got {got}")]
23    VersionMismatch { expected: u32, got: u32 },
24
25    #[error("Connection closed")]
26    ConnectionClosed,
27
28    #[error("Authentication failed")]
29    AuthenticationFailed,
30}
31
/// SSH connection parameters.
///
/// Usually built by [`ConnectionParams::parse`] from strings such as
/// `user@host`, `user@host:2222`, `user@fe80::1` or `user@[::1]:2222`, and
/// rendered back into that form by the [`std::fmt::Display`] impl (which also
/// provides `to_string()`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionParams {
    /// Remote login name. May itself contain `@`; ssh splits on the last one.
    pub user: String,
    /// Hostname or IP address. IPv6 literals are stored without brackets.
    pub host: String,
    /// Explicit port, if one was given.
    pub port: Option<u16>,
    /// Private key file to use, if any.
    pub identity_file: Option<PathBuf>,
}

impl ConnectionParams {
    /// Parse a connection string like `user@host` or `user@host:port`.
    ///
    /// - The user is everything before the **last** `@`, matching how ssh
    ///   splits its own destination argument.
    /// - A trailing `:port` is only recognised when the rest of the host part
    ///   contains no other `:`, so bare IPv6 literals such as `fe80::1` stay
    ///   intact. To give an IPv6 literal a port, bracket it: `user@[::1]:2222`.
    /// - A suffix that is not a valid `u16` (e.g. `host:abc`) is kept as part
    ///   of the host.
    ///
    /// Returns `None` if the user or host is empty, or if either begins with
    /// `-` (ssh could otherwise interpret it as a command-line option).
    pub fn parse(s: &str) -> Option<Self> {
        let (user, host_port) = s.rsplit_once('@')?;
        let (host, port) = Self::split_host_port(host_port);

        let looks_like_option = user.starts_with('-') || host.starts_with('-');
        if user.is_empty() || host.is_empty() || looks_like_option {
            return None;
        }

        Some(Self {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        })
    }

    /// Split `host[:port]` or `[ipv6]:port` into the host and optional port.
    ///
    /// Anything that does not parse cleanly is returned whole as the host.
    fn split_host_port(s: &str) -> (&str, Option<u16>) {
        if let Some(rest) = s.strip_prefix('[') {
            if let Some((inner, after)) = rest.split_once(']') {
                if after.is_empty() {
                    return (inner, None);
                }
                if let Some(port) = after.strip_prefix(':').and_then(|p| p.parse::<u16>().ok()) {
                    return (inner, Some(port));
                }
            }
            // Malformed bracket form: keep it verbatim, like other bad ports.
            return (s, None);
        }

        match s.rsplit_once(':') {
            // A host that still contains ':' is an IPv6 literal, not host:port.
            Some((host, port)) if !host.contains(':') => match port.parse::<u16>() {
                Ok(port) => (host, Some(port)),
                Err(_) => (s, None),
            },
            _ => (s, None),
        }
    }
}

impl std::fmt::Display for ConnectionParams {
    /// Format as `user@host` or `user@host:port`.
    ///
    /// An IPv6 host is bracketed when a port follows, so the output parses back
    /// to the same value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.port {
            Some(port) if self.host.contains(':') && !self.host.starts_with('[') => {
                write!(f, "{}@[{}]:{}", self.user, self.host, port)
            }
            Some(port) => write!(f, "{}@{}:{}", self.user, self.host, port),
            None => write!(f, "{}@{}", self.user, self.host),
        }
    }
}
76
77/// Active SSH connection with bootstrapped agent
78pub struct SshConnection {
79    /// SSH child process
80    process: Child,
81    /// Communication channel with agent (wrapped in Arc for sharing)
82    channel: std::sync::Arc<AgentChannel>,
83    /// Connection parameters
84    params: ConnectionParams,
85}
86
87impl SshConnection {
88    /// Establish a new SSH connection and bootstrap the agent
89    pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
90        let mut cmd = Command::new("ssh");
91
92        // Don't check host key strictly for ease of use
93        cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
94        // Allow password prompts - SSH will use the terminal for this
95        // Note: We inherit stderr so SSH can prompt for password if needed
96
97        if let Some(port) = params.port {
98            cmd.arg("-p").arg(port.to_string());
99        }
100
101        if let Some(ref identity) = params.identity_file {
102            cmd.arg("-i").arg(identity);
103        }
104
105        cmd.arg(format!("{}@{}", params.user, params.host));
106
107        // Bootstrap the agent using Python itself to read the exact byte count.
108        // This avoids requiring bash or other shell utilities on the remote.
109        // Python reads exactly N bytes (the agent code), execs it, and the agent
110        // then continues reading from stdin for protocol messages.
111        //
112        // Note: SSH passes the remote command through a shell, so we need to
113        // properly quote the Python code. We use double quotes for the outer
114        // shell and avoid problematic characters in the Python code.
115        let agent_len = AGENT_SOURCE.len();
116        let bootstrap = format!(
117            "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
118            agent_len
119        );
120        cmd.arg(bootstrap);
121
122        cmd.stdin(Stdio::piped());
123        cmd.stdout(Stdio::piped());
124        // Inherit stderr so SSH can prompt for password on the terminal
125        cmd.stderr(Stdio::inherit());
126
127        let mut child = cmd.spawn()?;
128
129        // Get handles
130        let mut stdin = child
131            .stdin
132            .take()
133            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
134        let stdout = child
135            .stdout
136            .take()
137            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
138        // Note: stderr is inherited so SSH can prompt for password on the terminal
139
140        // Send the agent code (exact byte count)
141        stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
142        stdin.flush().await?;
143
144        // Create buffered reader for stdout
145        let mut reader = BufReader::new(stdout);
146
147        // Wait for ready message from agent
148        // No timeout needed - all failure modes (auth failure, network issues, etc.)
149        // result in SSH exiting and us getting EOF. User can Ctrl+C if needed.
150        let mut ready_line = String::new();
151        match reader.read_line(&mut ready_line).await {
152            Ok(0) => {
153                return Err(SshError::AgentStartFailed(
154                    "connection closed (check terminal for SSH errors)".to_string(),
155                ));
156            }
157            Ok(_) => {}
158            Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
159        }
160
161        let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
162            SshError::AgentStartFailed(format!(
163                "invalid ready message '{}': {}",
164                ready_line.trim(),
165                e
166            ))
167        })?;
168
169        if !ready.is_ready() {
170            return Err(SshError::AgentStartFailed(
171                "agent did not send ready message".to_string(),
172            ));
173        }
174
175        // Check protocol version
176        let version = ready.version.unwrap_or(0);
177        if version != crate::services::remote::protocol::PROTOCOL_VERSION {
178            return Err(SshError::VersionMismatch {
179                expected: crate::services::remote::protocol::PROTOCOL_VERSION,
180                got: version,
181            });
182        }
183
184        // Create channel (takes ownership of stdin for writing)
185        let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
186
187        Ok(Self {
188            process: child,
189            channel,
190            params,
191        })
192    }
193
194    /// Get the communication channel as an Arc for sharing
195    pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
196        self.channel.clone()
197    }
198
199    /// Get connection parameters
200    pub fn params(&self) -> &ConnectionParams {
201        &self.params
202    }
203
204    /// Check if the connection is still alive
205    pub fn is_connected(&self) -> bool {
206        self.channel.is_connected()
207    }
208
209    /// Get the connection string for display
210    pub fn connection_string(&self) -> String {
211        self.params.to_string()
212    }
213}
214
215impl Drop for SshConnection {
216    fn drop(&mut self) {
217        // Try to kill the SSH process gracefully
218        let _ = self.process.start_kill();
219    }
220}
221
222/// Spawn a local agent process for testing (no SSH)
223///
224/// This is used by integration tests to test the full stack without SSH.
225/// Not intended for production use.
226#[doc(hidden)]
227pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
228    use tokio::process::Command as TokioCommand;
229
230    let mut child = TokioCommand::new("python3")
231        .arg("-u")
232        .arg("-c")
233        .arg(AGENT_SOURCE)
234        .stdin(Stdio::piped())
235        .stdout(Stdio::piped())
236        .stderr(Stdio::piped())
237        .spawn()?;
238
239    let stdin = child
240        .stdin
241        .take()
242        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
243    let stdout = child
244        .stdout
245        .take()
246        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
247
248    let mut reader = BufReader::new(stdout);
249
250    // Wait for ready message
251    let mut ready_line = String::new();
252    reader.read_line(&mut ready_line).await?;
253
254    let ready: AgentResponse = serde_json::from_str(&ready_line)
255        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
256
257    if !ready.is_ready() {
258        return Err(SshError::AgentStartFailed(
259            "agent did not send ready message".to_string(),
260        ));
261    }
262
263    Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_connection_params() {
        // (input, expected user, expected host, expected port)
        let accepted: [(&str, &str, &str, Option<u16>); 2] = [
            ("user@host", "user", "host", None),
            ("user@host:22", "user", "host", Some(22)),
        ];
        for (input, user, host, port) in accepted {
            let parsed = ConnectionParams::parse(input)
                .unwrap_or_else(|| panic!("expected {:?} to parse", input));
            assert_eq!(parsed.user, user, "user of {:?}", input);
            assert_eq!(parsed.host, host, "host of {:?}", input);
            assert_eq!(parsed.port, port, "port of {:?}", input);
        }

        for input in ["hostonly", "@host", "user@"] {
            assert!(
                ConnectionParams::parse(input).is_none(),
                "{:?} should be rejected",
                input
            );
        }
    }

    #[test]
    fn test_connection_string() {
        // (user, host, port, expected rendering)
        let cases: [(&str, &str, Option<u16>, &str); 2] = [
            ("alice", "example.com", None, "alice@example.com"),
            ("bob", "server.local", Some(2222), "bob@server.local:2222"),
        ];
        for (user, host, port, expected) in cases {
            let params = ConnectionParams {
                user: user.to_string(),
                host: host.to_string(),
                port,
                identity_file: None,
            };
            assert_eq!(params.to_string(), expected);
        }
    }
}