1use async_trait::async_trait;
4use crate::{Connection, TransportError};
5use std::collections::HashMap;
6use std::path::PathBuf;
7use std::process::Stdio;
8use tokio::process::{Child, Command};
9use tracing::{debug, info};
10
11#[async_trait]
13pub trait Transport: Send + Sync {
14 async fn connect(&mut self) -> Result<Connection, TransportError>;
16
17 async fn bootstrap_agent(&mut self, agent_binary: &[u8]) -> Result<(), TransportError>;
19
20 fn connection_info(&self) -> ConnectionInfo;
22
23 async fn test_connection(&mut self) -> Result<(), TransportError>;
25}
26
/// Summary of the endpoint a transport is configured to reach, as returned by
/// `Transport::connection_info`.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    /// Remote host name or address.
    pub host: String,
    /// Remote port.
    pub port: u16,
    /// Login user on the remote host.
    pub username: String,
    /// Which transport implementation produced this description.
    pub transport_type: TransportType,
}

/// Identifies the kind of transport behind a [`ConnectionInfo`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportType {
    /// SSH performed by spawning the system `ssh` binary (used by `StdioTransport`).
    SshSubprocess,
    /// Presumably an in-process libssh2-based SSH client — not implemented in this file.
    SshLibssh2,
    /// Presumably a local (non-SSH) transport; the test mock reports this variant.
    Local,
}
50
/// Parameters for reaching a host through the system `ssh` client.
#[derive(Debug, Clone)]
pub struct SshConfig {
    /// Remote host name or address.
    pub host: String,
    /// Remote SSH port (passed as `ssh -p`).
    pub port: u16,
    /// Login user; combined with `host` as `user@host`.
    pub username: String,
    /// Optional private key file, passed as `ssh -i`.
    pub key_path: Option<PathBuf>,
    /// Extra ssh options, each passed as `-o Key=Value`.
    pub options: HashMap<String, String>,
    /// Value passed to ssh's `ConnectTimeout` option (seconds per ssh_config(5)).
    pub connect_timeout: u64,
    /// Per-command time limit; presumably seconds — confirm against the code that applies it.
    pub command_timeout: u64,
}

impl Default for SshConfig {
    /// `root@localhost:22`, no key, no extra options, 30 / 300 timeouts.
    fn default() -> Self {
        let host = String::from("localhost");
        let username = String::from("root");
        Self {
            host,
            username,
            port: 22,
            key_path: None,
            options: HashMap::new(),
            connect_timeout: 30,
            command_timeout: 300,
        }
    }
}
83
84pub struct StdioTransport {
86 config: SshConfig,
88 ssh_process: Option<Child>,
90 connected: bool,
92}
93
94impl StdioTransport {
95 pub fn new(config: SshConfig) -> Self {
97 Self {
98 config,
99 ssh_process: None,
100 connected: false,
101 }
102 }
103
104 fn build_ssh_args(&self) -> Vec<String> {
106 let mut args = vec![
107 "-o".to_string(), "BatchMode=yes".to_string(),
108 "-o".to_string(), "StrictHostKeyChecking=no".to_string(),
109 "-o".to_string(), format!("ConnectTimeout={}", self.config.connect_timeout),
110 "-p".to_string(), self.config.port.to_string(),
111 ];
112
113 if let Some(key_path) = &self.config.key_path {
115 args.push("-i".to_string());
116 args.push(key_path.to_string_lossy().to_string());
117 }
118
119 for (key, value) in &self.config.options {
121 args.push("-o".to_string());
122 args.push(format!("{}={}", key, value));
123 }
124
125 args.push(format!("{}@{}", self.config.username, self.config.host));
127
128 args
129 }
130
131 async fn execute_command(&mut self, command: &str) -> Result<String, TransportError> {
133 let mut ssh_args = self.build_ssh_args();
134 ssh_args.push(command.to_string());
135
136 debug!("Executing SSH command: ssh {}", ssh_args.join(" "));
137
138 let output = Command::new("ssh")
139 .args(&ssh_args)
140 .output()
141 .await
142 .map_err(|e| TransportError::Connection(format!("Failed to execute SSH: {}", e)))?;
143
144 if !output.status.success() {
145 let stderr = String::from_utf8_lossy(&output.stderr);
146 return Err(TransportError::Connection(format!("SSH command failed: {}", stderr)));
147 }
148
149 Ok(String::from_utf8_lossy(&output.stdout).to_string())
150 }
151
152 async fn start_interactive_session(&mut self) -> Result<Child, TransportError> {
154 let ssh_args = self.build_ssh_args();
155
156 debug!("Starting interactive SSH session: ssh {}", ssh_args.join(" "));
157
158 let child = Command::new("ssh")
159 .args(&ssh_args)
160 .stdin(Stdio::piped())
161 .stdout(Stdio::piped())
162 .stderr(Stdio::piped())
163 .spawn()
164 .map_err(|e| TransportError::Connection(format!("Failed to start SSH: {}", e)))?;
165
166 Ok(child)
167 }
168}
169
170#[async_trait]
171impl Transport for StdioTransport {
172 async fn connect(&mut self) -> Result<Connection, TransportError> {
173 if self.connected {
174 return Ok(Connection::new(self.ssh_process.take()));
175 }
176
177 info!("Connecting to {}@{}:{}", self.config.username, self.config.host, self.config.port);
178
179 self.test_connection().await?;
181
182 let child = self.start_interactive_session().await?;
184 self.ssh_process = Some(child);
185 self.connected = true;
186
187 info!("Successfully connected to {}@{}", self.config.username, self.config.host);
188 Ok(Connection::new(self.ssh_process.take()))
189 }
190
191 async fn bootstrap_agent(&mut self, agent_binary: &[u8]) -> Result<(), TransportError> {
192 info!("Bootstrapping agent on {}@{}", self.config.username, self.config.host);
193
194 let platform_info = self.execute_command("uname -m && uname -s").await?;
196 debug!("Remote platform: {}", platform_info.trim());
197
198 let bootstrap_script = format!(
200 r#"
201 set -e
202
203 # Try memfd_create approach first (Linux)
204 if command -v python3 >/dev/null 2>&1; then
205 python3 -c "
206import os, sys
207try:
208 import ctypes
209 libc = ctypes.CDLL('libc.so.6')
210 fd = libc.syscall(319, b'mitoxide-agent', 1) # memfd_create
211 if fd >= 0:
212 os.write(fd, sys.stdin.buffer.read())
213 os.fexecve(fd, ['/proc/self/fd/%d' % fd], os.environ)
214except:
215 pass
216# Fallback to temp file
217import tempfile
218with tempfile.NamedTemporaryFile(delete=False) as f:
219 f.write(sys.stdin.buffer.read())
220 f.flush()
221 os.chmod(f.name, 0o755)
222 os.execv(f.name, [f.name])
223"
224 elif [ -d /tmp ] && [ -w /tmp ]; then
225 # Fallback to /tmp
226 AGENT_PATH="/tmp/mitoxide-agent-$$"
227 cat > "$AGENT_PATH"
228 chmod +x "$AGENT_PATH"
229 exec "$AGENT_PATH"
230 rm -f "$AGENT_PATH" 2>/dev/null || true
231 else
232 echo "No suitable location for agent bootstrap" >&2
233 exit 1
234 fi
235 "#
236 );
237
238 let mut ssh_args = self.build_ssh_args();
240 ssh_args.push("bash".to_string());
241
242 let mut child = Command::new("ssh")
243 .args(&ssh_args)
244 .stdin(Stdio::piped())
245 .stdout(Stdio::piped())
246 .stderr(Stdio::piped())
247 .spawn()
248 .map_err(|e| TransportError::Bootstrap(format!("Failed to start SSH for bootstrap: {}", e)))?;
249
250 if let Some(stdin) = child.stdin.as_mut() {
252 use tokio::io::AsyncWriteExt;
253
254 stdin.write_all(bootstrap_script.as_bytes()).await
255 .map_err(|e| TransportError::Bootstrap(format!("Failed to write bootstrap script: {}", e)))?;
256
257 stdin.write_all(agent_binary).await
258 .map_err(|e| TransportError::Bootstrap(format!("Failed to write agent binary: {}", e)))?;
259
260 stdin.shutdown().await
261 .map_err(|e| TransportError::Bootstrap(format!("Failed to close stdin: {}", e)))?;
262 }
263
264 let output = child.wait_with_output().await
266 .map_err(|e| TransportError::Bootstrap(format!("Bootstrap process failed: {}", e)))?;
267
268 if !output.status.success() {
269 let stderr = String::from_utf8_lossy(&output.stderr);
270 return Err(TransportError::Bootstrap(format!("Agent bootstrap failed: {}", stderr)));
271 }
272
273 info!("Agent successfully bootstrapped on {}@{}", self.config.username, self.config.host);
274 Ok(())
275 }
276
277 fn connection_info(&self) -> ConnectionInfo {
278 ConnectionInfo {
279 host: self.config.host.clone(),
280 port: self.config.port,
281 username: self.config.username.clone(),
282 transport_type: TransportType::SshSubprocess,
283 }
284 }
285
286 async fn test_connection(&mut self) -> Result<(), TransportError> {
287 debug!("Testing connection to {}@{}", self.config.username, self.config.host);
288
289 let result = self.execute_command("echo 'connection_test'").await?;
291
292 if !result.trim().contains("connection_test") {
293 return Err(TransportError::Connection("Connection test failed".to_string()));
294 }
295
296 debug!("Connection test successful");
297 Ok(())
298 }
299}
300
301impl Drop for StdioTransport {
302 fn drop(&mut self) {
303 if let Some(mut child) = self.ssh_process.take() {
304 let _ = child.start_kill();
306 }
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config with the given host, port and user filled in.
    fn config_for(host: &str, port: u16, username: &str) -> SshConfig {
        SshConfig {
            host: host.to_string(),
            port,
            username: username.to_string(),
            ..SshConfig::default()
        }
    }

    #[test]
    fn test_ssh_config_default() {
        let SshConfig {
            host,
            port,
            username,
            connect_timeout,
            command_timeout,
            ..
        } = SshConfig::default();

        assert_eq!(host, "localhost");
        assert_eq!(port, 22);
        assert_eq!(username, "root");
        assert_eq!(connect_timeout, 30);
        assert_eq!(command_timeout, 300);
    }

    #[test]
    fn test_stdio_transport_creation() {
        let expected = SshConfig::default();
        let info = StdioTransport::new(expected.clone()).connection_info();

        assert_eq!(info.host, expected.host);
        assert_eq!(info.port, expected.port);
        assert_eq!(info.username, expected.username);
        assert_eq!(info.transport_type, TransportType::SshSubprocess);
    }

    #[test]
    fn test_ssh_args_building() {
        let mut config = config_for("example.com", 2222, "testuser");
        config.key_path = Some(PathBuf::from("/path/to/key"));
        config
            .options
            .insert("ServerAliveInterval".to_string(), "60".to_string());

        let args = StdioTransport::new(config).build_ssh_args();

        let expected = [
            "-p",
            "2222",
            "-i",
            "/path/to/key",
            "-o",
            "ServerAliveInterval=60",
            "testuser@example.com",
        ];
        for needle in expected.iter() {
            assert!(
                args.iter().any(|arg| arg.as_str() == *needle),
                "missing ssh argument {:?}",
                needle
            );
        }
    }

    #[test]
    fn test_connection_info() {
        let info = StdioTransport::new(config_for("test.example.com", 2222, "testuser"))
            .connection_info();

        assert_eq!(info.host, "test.example.com");
        assert_eq!(info.port, 2222);
        assert_eq!(info.username, "testuser");
        assert_eq!(info.transport_type, TransportType::SshSubprocess);
    }

    /// Test double whose every operation succeeds or fails according to
    /// `should_fail`.
    pub struct MockTransport {
        should_fail: bool,
        connection_info: ConnectionInfo,
    }

    impl MockTransport {
        pub fn new(should_fail: bool) -> Self {
            let connection_info = ConnectionInfo {
                host: "mock.example.com".to_string(),
                port: 22,
                username: "mockuser".to_string(),
                transport_type: TransportType::Local,
            };
            Self {
                should_fail,
                connection_info,
            }
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn connect(&mut self) -> Result<Connection, TransportError> {
            if self.should_fail {
                return Err(TransportError::Connection(
                    "Mock connection failed".to_string(),
                ));
            }
            Ok(Connection::new(None))
        }

        async fn bootstrap_agent(&mut self, _agent_binary: &[u8]) -> Result<(), TransportError> {
            if self.should_fail {
                return Err(TransportError::Bootstrap("Mock bootstrap failed".to_string()));
            }
            Ok(())
        }

        fn connection_info(&self) -> ConnectionInfo {
            self.connection_info.clone()
        }

        async fn test_connection(&mut self) -> Result<(), TransportError> {
            if self.should_fail {
                return Err(TransportError::Connection("Mock test failed".to_string()));
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_mock_transport_success() {
        let mut mock = MockTransport::new(false);

        assert!(mock.test_connection().await.is_ok());
        assert!(mock.connect().await.is_ok());
        assert!(mock.bootstrap_agent(b"test").await.is_ok());
    }

    #[tokio::test]
    async fn test_mock_transport_failure() {
        let mut mock = MockTransport::new(true);

        assert!(mock.test_connection().await.is_err());
        assert!(mock.connect().await.is_err());
        assert!(mock.bootstrap_agent(b"test").await.is_err());
    }
}
447}