Skip to main content

vdsl_sync/infra/
shell.rs

1//! Remote shell abstraction for executing commands on different hosts.
2//!
3//! - [`LocalShell`]: runs via `tokio::process::Command` on the local machine
4//! - `PodShell`: runs via RunPod exec API on a GPU pod (downstream crate)
5//! - `SshShell`: runs via SSH (future)
6//!
7//! [`StorageBackend`](super::backend::StorageBackend) implementations compose
8//! a `RemoteShell` to run transfer commands (rclone, rsync, etc.) on the
9//! appropriate host.
10
11use async_trait::async_trait;
12
13use crate::infra::error::InfraError;
14
/// Captured result of running one command through a remote-shell implementation.
#[derive(Debug, Clone)]
pub struct ShellOutput {
    /// Text the command wrote to standard output.
    pub stdout: String,
    /// Text the command wrote to standard error.
    pub stderr: String,
    /// Whether the command reported success.
    pub success: bool,
    /// Numeric exit code, when the implementation can report one
    /// (a locally signal-terminated process, for example, has none).
    pub exit_code: Option<i32>,
}
23
/// Hash and size of one file, as reported by `RemoteShell::batch_inspect`.
#[derive(Debug, Clone)]
pub struct FileInspection {
    /// Path relative to the inspected root; identical to the key that was requested.
    pub relative_path: String,
    /// Lowercase hex SHA-256 digest of the file's bytes.
    pub sha256: String,
    /// Length of the file in bytes.
    pub size: u64,
}
34
35/// Abstract shell for executing commands on a location's host.
36#[async_trait]
37pub trait RemoteShell: Send + Sync {
38    /// Execute a command on this host.
39    ///
40    /// `args[0]` is the program name, `args[1..]` are arguments.
41    async fn exec(
42        &self,
43        args: &[&str],
44        timeout_secs: Option<u64>,
45    ) -> Result<ShellOutput, InfraError>;
46
47    /// Execute a shell script on this host.
48    ///
49    /// Default: `exec(&["sh", "-c", script])`.
50    /// Remote shells may override to use file-based transfer (SCP)
51    /// to avoid shell escaping issues with SSH.
52    async fn exec_script(
53        &self,
54        script: &str,
55        timeout_secs: Option<u64>,
56    ) -> Result<ShellOutput, InfraError> {
57        self.exec(&["sh", "-c", script], timeout_secs).await
58    }
59
60    /// Batch inspect files: get sha256 + size for ALL paths in one exec call.
61    ///
62    /// Constructs a single shell script that processes every file in the list
63    /// and outputs `<sha256> <size> <relative_path>` per line. Parsed on return.
64    ///
65    /// Timeout scales with file count: base 30s + 2s per file.
66    async fn batch_inspect(
67        &self,
68        root: &str,
69        relative_paths: &[String],
70    ) -> Result<Vec<FileInspection>, InfraError> {
71        if relative_paths.is_empty() {
72            return Ok(Vec::new());
73        }
74
75        // Build heredoc file list embedded in a single sh -c script.
76        // Each file is read line-by-line, sha256sum + stat in one pass.
77        let mut script = format!(
78            "cd '{}' && while IFS= read -r f; do \
79             h=$(sha256sum \"$f\" 2>/dev/null | cut -d' ' -f1); \
80             s=$(stat --format=%s \"$f\" 2>/dev/null || echo 0); \
81             [ -n \"$h\" ] && printf '%s %s %s\\n' \"$h\" \"$s\" \"$f\"; \
82             done <<'__VDSL_FILELIST__'\n",
83            root.replace('\'', "'\\''")
84        );
85        for rel in relative_paths {
86            script.push_str(rel);
87            script.push('\n');
88        }
89        script.push_str("__VDSL_FILELIST__");
90
91        let timeout = 30 + (relative_paths.len() as u64 * 2);
92        let output = self.exec(&["sh", "-c", &script], Some(timeout)).await?;
93
94        if !output.success {
95            return Err(InfraError::Transfer {
96                reason: format!("batch_inspect failed: {}", output.stderr.trim()),
97            });
98        }
99
100        let mut results = Vec::with_capacity(relative_paths.len());
101        for line in output.stdout.lines() {
102            // Format: <sha256_hex> <size> <relative_path>
103            let mut parts = line.splitn(3, ' ');
104            let sha256 = match parts.next() {
105                Some(h) if h.len() == 64 => h.to_string(),
106                _ => continue,
107            };
108            let size = parts
109                .next()
110                .and_then(|s| s.parse::<u64>().ok())
111                .unwrap_or(0);
112            let relative_path = match parts.next() {
113                Some(p) if !p.is_empty() => p.to_string(),
114                _ => continue,
115            };
116            results.push(FileInspection {
117                relative_path,
118                sha256,
119                size,
120            });
121        }
122
123        Ok(results)
124    }
125}
126
127/// Execute commands on the local machine via `tokio::process::Command`.
128pub struct LocalShell;
129
130const LOCAL_DEFAULT_TIMEOUT_SECS: u64 = 600;
131
132#[async_trait]
133impl RemoteShell for LocalShell {
134    async fn exec(
135        &self,
136        args: &[&str],
137        timeout_secs: Option<u64>,
138    ) -> Result<ShellOutput, InfraError> {
139        if args.is_empty() {
140            return Err(InfraError::Transfer {
141                reason: "empty command".into(),
142            });
143        }
144
145        let mut cmd = tokio::process::Command::new(args[0]);
146        if args.len() > 1 {
147            cmd.args(&args[1..]);
148        }
149
150        let timeout =
151            std::time::Duration::from_secs(timeout_secs.unwrap_or(LOCAL_DEFAULT_TIMEOUT_SECS));
152
153        let output = tokio::time::timeout(timeout, cmd.output())
154            .await
155            .map_err(|_| -> InfraError {
156                InfraError::Transfer {
157                    reason: format!(
158                        "command timed out after {}s: {}",
159                        timeout.as_secs(),
160                        args.join(" ")
161                    ),
162                }
163            })?
164            .map_err(|e| -> InfraError {
165                InfraError::Transfer {
166                    reason: format!("exec failed ({}): {e}", args[0]),
167                }
168            })?;
169
170        Ok(ShellOutput {
171            stdout: String::from_utf8_lossy(&output.stdout).to_string(),
172            stderr: String::from_utf8_lossy(&output.stderr).to_string(),
173            success: output.status.success(),
174            exit_code: output.status.code(),
175        })
176    }
177}
178
/// Mock shell for testing — returns configurable responses.
#[cfg(any(test, feature = "test-utils"))]
// Presumably not every item below is referenced under every cfg combination
// (`test` vs `feature = "test-utils"`); confirm before removing this allow.
#[allow(dead_code)]
pub mod mock {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// What the mock reports for one file: the hash string and the size in bytes.
    #[derive(Clone)]
    pub struct MockFile {
        /// Echoed back as the digest by the simulated `sha256sum`.
        pub sha256: String,
        /// Echoed back by the simulated `stat`.
        pub size: u64,
    }

    impl MockFile {
        /// Build a mock file entry from a hash string and a size.
        pub fn new(sha256: impl Into<String>, size: u64) -> Self {
            let sha256 = sha256.into();
            Self { sha256, size }
        }
    }

    /// A mock RemoteShell that simulates file operations on a remote host.
    ///
    /// Recognised commands (anything else succeeds with empty output):
    /// - `test -f <path>` — succeeds iff `<path>` is registered
    /// - `sha256sum <path>` — prints the registered hash
    /// - `stat <flag> ... <path>` — prints the registered size (last arg is the path)
    ///
    /// Every call's argv is appended to `exec_log` for later assertions.
    pub struct MockShell {
        /// Registered files, keyed by the exact path string used in commands.
        files: Mutex<HashMap<String, MockFile>>,
        /// Argument vectors of every `exec` call, in call order.
        pub exec_log: Mutex<Vec<Vec<String>>>,
    }

    impl MockShell {
        /// Create with a set of files (path → MockFile).
        pub fn with_files(files: impl IntoIterator<Item = (impl Into<String>, MockFile)>) -> Self {
            let table: HashMap<String, MockFile> = files
                .into_iter()
                .map(|(path, file)| (path.into(), file))
                .collect();
            Self {
                files: Mutex::new(table),
                exec_log: Mutex::new(Vec::new()),
            }
        }

        /// Create with paths only, for existence checks; each gets a
        /// placeholder hash and size 0.
        pub fn new(existing: impl IntoIterator<Item = impl Into<String>>) -> Self {
            let placeholder = |path| (path, MockFile::new("0000000000000000", 0));
            Self::with_files(existing.into_iter().map(placeholder))
        }
    }

    /// Build a [`ShellOutput`] whose exit code (0 or 1) mirrors `success`.
    fn reply(stdout: String, stderr: String, success: bool) -> ShellOutput {
        let exit_code = Some(if success { 0 } else { 1 });
        ShellOutput {
            stdout,
            stderr,
            success,
            exit_code,
        }
    }

    #[async_trait]
    impl RemoteShell for MockShell {
        async fn exec(
            &self,
            args: &[&str],
            _timeout_secs: Option<u64>,
        ) -> Result<ShellOutput, InfraError> {
            self.exec_log
                .lock()
                .await
                .push(args.iter().map(|arg| arg.to_string()).collect());

            let response = match *args {
                ["test", "-f", path, ..] => {
                    let found = self.files.lock().await.contains_key(path);
                    reply(String::new(), String::new(), found)
                }
                ["sha256sum", path, ..] => {
                    let files = self.files.lock().await;
                    match files.get(path) {
                        Some(file) => {
                            reply(format!("{}  {}\n", file.sha256, path), String::new(), true)
                        }
                        None => reply(
                            String::new(),
                            format!("sha256sum: {path}: No such file or directory\n"),
                            false,
                        ),
                    }
                }
                // GNU `stat --format=%s <path>` or BSD `stat -f%z <path>`:
                // whatever the flags, the path is the final argument.
                ["stat", _, .., path] => {
                    let files = self.files.lock().await;
                    match files.get(path) {
                        Some(file) => reply(format!("{}\n", file.size), String::new(), true),
                        None => reply(
                            String::new(),
                            format!("stat: cannot stat '{path}': No such file or directory\n"),
                            false,
                        ),
                    }
                }
                _ => reply(String::new(), String::new(), true),
            };
            Ok(response)
        }
    }
}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn local_shell_echo() {
        let result = LocalShell
            .exec(&["echo", "hello"], None)
            .await
            .expect("`echo` should be runnable");
        assert_eq!(result.exit_code, Some(0));
        assert!(result.success);
        assert_eq!(result.stdout.trim(), "hello");
    }

    #[tokio::test]
    async fn local_shell_empty_args() {
        let outcome = LocalShell.exec(&[], None).await;
        assert!(outcome.is_err(), "an empty argv must be rejected");
    }

    #[tokio::test]
    async fn local_shell_nonexistent_command() {
        let outcome = LocalShell
            .exec(&["__nonexistent_command_12345__"], None)
            .await;
        assert!(outcome.is_err(), "spawning a missing program must fail");
    }

    #[tokio::test]
    async fn local_shell_exit_code() {
        let result = LocalShell
            .exec(&["false"], None)
            .await
            .expect("`false` should be runnable");
        assert_ne!(result.exit_code, Some(0));
        assert!(!result.success);
    }
}