Skip to main content

fresh/services/remote/
connection.rs

1//! SSH connection management
2//!
//! Handles spawning the SSH process and bootstrapping the Python agent.
4
5use crate::services::remote::channel::AgentChannel;
6use crate::services::remote::protocol::AgentResponse;
7use crate::services::remote::AGENT_SOURCE;
8use std::path::PathBuf;
9use std::process::Stdio;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12
13/// Error type for SSH connection
14#[derive(Debug, thiserror::Error)]
15pub enum SshError {
16    #[error("Failed to spawn SSH process: {0}")]
17    SpawnFailed(#[from] std::io::Error),
18
19    #[error("Agent failed to start: {0}")]
20    AgentStartFailed(String),
21
22    #[error("Protocol version mismatch: expected {expected}, got {got}")]
23    VersionMismatch { expected: u32, got: u32 },
24
25    #[error("Connection closed")]
26    ConnectionClosed,
27
28    #[error("Authentication failed")]
29    AuthenticationFailed,
30}
31
/// SSH connection parameters.
///
/// Usually obtained from [`ConnectionParams::parse`] and rendered back to a
/// connection string through its `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionParams {
    /// Remote login name.
    pub user: String,
    /// Remote host name or address (IPv6 literals are stored without brackets).
    pub host: String,
    /// SSH port; `None` means no `-p` flag is passed, so ssh picks its default.
    pub port: Option<u16>,
    /// Private key handed to ssh via `-i`, if any.
    pub identity_file: Option<PathBuf>,
}

impl ConnectionParams {
    /// Parse a connection string like `"user@host"` or `"user@host:port"`.
    ///
    /// - The user/host split happens at the **last** `@` (as OpenSSH does), so
    ///   logins that contain `@` themselves, e.g. `alice@corp.example@bastion`,
    ///   are handled.
    /// - A bare IPv6 literal (`user@fe80::1`) is a host without a port; a
    ///   bracketed one (`user@[fe80::1]` or `user@[fe80::1]:2222`) may carry a port.
    /// - A trailing `:suffix` that is not a valid `u16` is kept as part of the host.
    ///
    /// Returns `None` when there is no `@`, the user or host is empty, a
    /// bracketed host is malformed, or the user or host starts with `-`
    /// (ssh would otherwise interpret the destination as a command-line option).
    pub fn parse(s: &str) -> Option<Self> {
        let (user, host_port) = s.rsplit_once('@')?;
        let (host, port) = Self::split_host_port(host_port)?;

        if user.is_empty() || host.is_empty() {
            return None;
        }
        // Guard against option injection such as `-oProxyCommand=...`.
        if user.starts_with('-') || host.starts_with('-') {
            return None;
        }

        Some(Self {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        })
    }

    /// Split `host`, `host:port`, `[addr]` or `[addr]:port` into host and port.
    ///
    /// Returns `None` only for malformed bracketed input.
    fn split_host_port(s: &str) -> Option<(&str, Option<u16>)> {
        if let Some(rest) = s.strip_prefix('[') {
            let (host, after) = rest.split_once(']')?;
            let port = if after.is_empty() {
                None
            } else {
                Some(after.strip_prefix(':')?.parse::<u16>().ok()?)
            };
            return Some((host, port));
        }

        // An unbracketed string with several ':' is an IPv6 literal, which
        // cannot carry a port in this form.
        if s.matches(':').count() > 1 {
            return Some((s, None));
        }

        match s.split_once(':') {
            Some((host, port)) => match port.parse::<u16>() {
                Ok(port) => Some((host, Some(port))),
                // Not a port: keep the whole string as the host.
                Err(_) => Some((s, None)),
            },
            None => Some((s, None)),
        }
    }
}

impl std::fmt::Display for ConnectionParams {
    /// Formats as a connection string in the syntax `parse` reads. IPv6 hosts
    /// are bracketed when a port follows so the port separator is unambiguous.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.port {
            Some(port) if self.host.contains(':') => {
                write!(f, "{}@[{}]:{}", self.user, self.host, port)
            }
            Some(port) => write!(f, "{}@{}:{}", self.user, self.host, port),
            None => write!(f, "{}@{}", self.user, self.host),
        }
    }
}
76
77/// Active SSH connection with bootstrapped agent
78pub struct SshConnection {
79    /// SSH child process
80    process: Child,
81    /// Communication channel with agent (wrapped in Arc for sharing)
82    channel: std::sync::Arc<AgentChannel>,
83    /// Connection parameters
84    params: ConnectionParams,
85}
86
87impl SshConnection {
88    /// Establish a new SSH connection and bootstrap the agent
89    pub async fn connect(params: ConnectionParams) -> Result<Self, SshError> {
90        let mut cmd = Command::new("ssh");
91
92        // Don't check host key strictly for ease of use
93        cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
94        // Allow password prompts - SSH will use the terminal for this
95        // Note: We inherit stderr so SSH can prompt for password if needed
96
97        if let Some(port) = params.port {
98            cmd.arg("-p").arg(port.to_string());
99        }
100
101        if let Some(ref identity) = params.identity_file {
102            cmd.arg("-i").arg(identity);
103        }
104
105        cmd.arg(format!("{}@{}", params.user, params.host));
106
107        // Bootstrap the agent using Python itself to read the exact byte count.
108        // This avoids requiring bash or other shell utilities on the remote.
109        // Python reads exactly N bytes (the agent code), execs it, and the agent
110        // then continues reading from stdin for protocol messages.
111        //
112        // Note: SSH passes the remote command through a shell, so we need to
113        // properly quote the Python code. We use double quotes for the outer
114        // shell and avoid problematic characters in the Python code.
115        let agent_len = AGENT_SOURCE.len();
116        let bootstrap = format!(
117            "python3 -u -c \"import sys;exec(sys.stdin.read({}))\"",
118            agent_len
119        );
120        cmd.arg(bootstrap);
121
122        cmd.stdin(Stdio::piped());
123        cmd.stdout(Stdio::piped());
124        // Inherit stderr so SSH can prompt for password on the terminal
125        cmd.stderr(Stdio::inherit());
126
127        let mut child = cmd.spawn()?;
128
129        // Get handles
130        let mut stdin = child
131            .stdin
132            .take()
133            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
134        let stdout = child
135            .stdout
136            .take()
137            .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
138        // Note: stderr is inherited so SSH can prompt for password on the terminal
139
140        // Send the agent code (exact byte count)
141        stdin.write_all(AGENT_SOURCE.as_bytes()).await?;
142        stdin.flush().await?;
143
144        // Create buffered reader for stdout
145        let mut reader = BufReader::new(stdout);
146
147        // Wait for ready message from agent
148        // No timeout needed - all failure modes (auth failure, network issues, etc.)
149        // result in SSH exiting and us getting EOF. User can Ctrl+C if needed.
150        let mut ready_line = String::new();
151        match reader.read_line(&mut ready_line).await {
152            Ok(0) => {
153                return Err(SshError::AgentStartFailed(
154                    "connection closed (check terminal for SSH errors)".to_string(),
155                ));
156            }
157            Ok(_) => {}
158            Err(e) => return Err(SshError::AgentStartFailed(format!("read error: {}", e))),
159        }
160
161        let ready: AgentResponse = serde_json::from_str(&ready_line).map_err(|e| {
162            SshError::AgentStartFailed(format!(
163                "invalid ready message '{}': {}",
164                ready_line.trim(),
165                e
166            ))
167        })?;
168
169        if !ready.is_ready() {
170            return Err(SshError::AgentStartFailed(
171                "agent did not send ready message".to_string(),
172            ));
173        }
174
175        // Check protocol version
176        let version = ready.version.unwrap_or(0);
177        if version != crate::services::remote::protocol::PROTOCOL_VERSION {
178            return Err(SshError::VersionMismatch {
179                expected: crate::services::remote::protocol::PROTOCOL_VERSION,
180                got: version,
181            });
182        }
183
184        // Create channel (takes ownership of stdin for writing)
185        let channel = std::sync::Arc::new(AgentChannel::new(reader, stdin));
186
187        Ok(Self {
188            process: child,
189            channel,
190            params,
191        })
192    }
193
194    /// Get the communication channel as an Arc for sharing
195    pub fn channel(&self) -> std::sync::Arc<AgentChannel> {
196        self.channel.clone()
197    }
198
199    /// Get connection parameters
200    pub fn params(&self) -> &ConnectionParams {
201        &self.params
202    }
203
204    /// Check if the connection is still alive
205    pub fn is_connected(&self) -> bool {
206        self.channel.is_connected()
207    }
208
209    /// Get the connection string for display
210    pub fn connection_string(&self) -> String {
211        self.params.to_string()
212    }
213}
214
215impl Drop for SshConnection {
216    fn drop(&mut self) {
217        // Best-effort kill of the SSH process during cleanup.
218        // If it fails (process already exited, permission error, etc.)
219        // there's nothing we can do in a Drop impl — the OS will clean
220        // up the zombie when our process exits.
221        if let Ok(()) = self.process.start_kill() {}
222    }
223}
224
225/// Spawn a local agent process for testing (no SSH)
226///
227/// This is used by integration tests to test the full stack without SSH.
228/// Not intended for production use.
229#[doc(hidden)]
230pub async fn spawn_local_agent() -> Result<std::sync::Arc<AgentChannel>, SshError> {
231    use tokio::process::Command as TokioCommand;
232
233    let mut child = TokioCommand::new("python3")
234        .arg("-u")
235        .arg("-c")
236        .arg(AGENT_SOURCE)
237        .stdin(Stdio::piped())
238        .stdout(Stdio::piped())
239        .stderr(Stdio::piped())
240        .spawn()?;
241
242    let stdin = child
243        .stdin
244        .take()
245        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
246    let stdout = child
247        .stdout
248        .take()
249        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
250
251    let mut reader = BufReader::new(stdout);
252
253    // Wait for ready message
254    let mut ready_line = String::new();
255    reader.read_line(&mut ready_line).await?;
256
257    let ready: AgentResponse = serde_json::from_str(&ready_line)
258        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
259
260    if !ready.is_ready() {
261        return Err(SshError::AgentStartFailed(
262            "agent did not send ready message".to_string(),
263        ));
264    }
265
266    Ok(std::sync::Arc::new(AgentChannel::new(reader, stdin)))
267}
268
269/// Spawn a local Python agent with a custom data channel capacity.
270///
271/// Same as `spawn_local_agent` but allows overriding the channel capacity
272/// for stress-testing backpressure handling.
273#[doc(hidden)]
274pub async fn spawn_local_agent_with_capacity(
275    data_channel_capacity: usize,
276) -> Result<std::sync::Arc<AgentChannel>, SshError> {
277    use tokio::process::Command as TokioCommand;
278
279    let mut child = TokioCommand::new("python3")
280        .arg("-u")
281        .arg("-c")
282        .arg(AGENT_SOURCE)
283        .stdin(Stdio::piped())
284        .stdout(Stdio::piped())
285        .stderr(Stdio::piped())
286        .spawn()?;
287
288    let stdin = child
289        .stdin
290        .take()
291        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdin".to_string()))?;
292    let stdout = child
293        .stdout
294        .take()
295        .ok_or_else(|| SshError::AgentStartFailed("failed to get stdout".to_string()))?;
296
297    let mut reader = BufReader::new(stdout);
298
299    // Wait for ready message
300    let mut ready_line = String::new();
301    reader.read_line(&mut ready_line).await?;
302
303    let ready: AgentResponse = serde_json::from_str(&ready_line)
304        .map_err(|e| SshError::AgentStartFailed(format!("invalid ready message: {}", e)))?;
305
306    if !ready.is_ready() {
307        return Err(SshError::AgentStartFailed(
308            "agent did not send ready message".to_string(),
309        ));
310    }
311
312    Ok(std::sync::Arc::new(AgentChannel::with_capacity(
313        reader,
314        stdin,
315        data_channel_capacity,
316    )))
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `ConnectionParams` without an identity file.
    fn make_params(user: &str, host: &str, port: Option<u16>) -> ConnectionParams {
        ConnectionParams {
            user: user.to_string(),
            host: host.to_string(),
            port,
            identity_file: None,
        }
    }

    #[test]
    fn test_parse_connection_params() {
        // (input, expected user, expected host, expected port)
        let accepted: [(&str, &str, &str, Option<u16>); 2] = [
            ("user@host", "user", "host", None),
            ("user@host:22", "user", "host", Some(22)),
        ];
        for (input, user, host, port) in accepted {
            let parsed = ConnectionParams::parse(input).unwrap();
            assert_eq!(parsed.user, user);
            assert_eq!(parsed.host, host);
            assert_eq!(parsed.port, port);
        }

        // No '@', empty user, and empty host are all rejected.
        for input in ["hostonly", "@host", "user@"] {
            assert!(ConnectionParams::parse(input).is_none());
        }
    }

    #[test]
    fn test_connection_string() {
        let plain = make_params("alice", "example.com", None);
        assert_eq!(plain.to_string(), "alice@example.com");

        let with_port = make_params("bob", "server.local", Some(2222));
        assert_eq!(with_port.to_string(), "bob@server.local:2222");
    }
}