// nexo_driver_loop/workspace/manager.rs

//! `WorkspaceManager` — resolve, mkdir, optionally checkout into a
//! git worktree, and provide checkpoint / rollback / diff_stat
//! helpers used by the driver loop.

use std::path::{Path, PathBuf};
use std::sync::OnceLock;

use nexo_driver_types::Goal;
use regex::Regex;

use crate::acceptance::ShellRunner;
use crate::error::DriverError;
use crate::workspace::git;

/// How `WorkspaceManager` provisions a goal's workspace directory.
#[derive(Clone, Debug)]
pub enum GitWorktreeMode {
    /// Plain directories only; the git-backed helpers (`checkpoint`,
    /// `rollback`, `diff_stat`) short-circuit without running git.
    Disabled,
    /// Each workspace is added as a git worktree of an existing repo.
    SourceRepo {
        /// Path of the repository the worktrees are attached to.
        path: PathBuf,
        /// Ref handed to `git::worktree_add` as the starting point of
        /// the per-goal branch.
        base_ref: String,
    },
}

20pub struct WorkspaceManager {
21    root: PathBuf,
22    git: GitWorktreeMode,
23    shell: ShellRunner,
24}
25
26impl WorkspaceManager {
27    /// Sentinel returned by `checkpoint` when git mode is `Disabled`.
28    pub const NO_GIT_SENTINEL: &'static str = "<no-git>";
29
30    pub fn new(root: impl Into<PathBuf>) -> Self {
31        Self {
32            root: root.into(),
33            git: GitWorktreeMode::Disabled,
34            shell: ShellRunner::default(),
35        }
36    }
37
38    pub fn with_git(mut self, mode: GitWorktreeMode) -> Self {
39        self.git = mode;
40        self
41    }
42
43    pub fn with_shell(mut self, shell: ShellRunner) -> Self {
44        self.shell = shell;
45        self
46    }
47
48    pub fn root(&self) -> &Path {
49        &self.root
50    }
51
52    pub fn git_mode(&self) -> &GitWorktreeMode {
53        &self.git
54    }
55
56    /// Resolve the goal's workspace path, mkdir/worktree-add it, and
57    /// verify it stays inside `root`. Returns the absolute path the
58    /// harness will `cwd` into.
59    ///
60    /// In `SourceRepo` mode, `goal.workspace` (operator-supplied) is
61    /// IGNORED — we always use `<root>/<goal_id>` so the worktree
62    /// branch name (`nexo-driver/<goal_id>`) lines up with the path.
63    pub async fn ensure(&self, goal: &Goal) -> Result<PathBuf, DriverError> {
64        tokio::fs::create_dir_all(&self.root).await?;
65        let canonical_root = tokio::fs::canonicalize(&self.root).await?;
66
67        // Per-goal source repo override. When
68        // `program_phase_dispatch` detects the active tracker is a
69        // standalone git repo (typical after `init_project`), it
70        // stamps `goal.metadata["worktree.source_repo"]` with that
71        // path. The override takes precedence over `self.git` so
72        // the per-goal worktree clones the right repo even when
73        // the daemon was booted against a different one.
74        let override_source: Option<PathBuf> = goal
75            .metadata
76            .get("worktree.source_repo")
77            .and_then(|v| v.as_str())
78            .filter(|s| !s.is_empty())
79            .map(PathBuf::from)
80            .filter(|p| p.join(".git").exists());
81
82        let resolved = if let Some(repo) = override_source {
83            // Use the operator-provided source repo with the same
84            // base_ref the boot config picked, falling back to a
85            // sane default when the boot mode is `Disabled`.
86            let base_ref = match &self.git {
87                GitWorktreeMode::SourceRepo { base_ref, .. } => base_ref.clone(),
88                GitWorktreeMode::Disabled => "HEAD".to_string(),
89            };
90            GitWorktreeMode::SourceRepo {
91                path: repo,
92                base_ref,
93            }
94        } else {
95            self.git.clone()
96        };
97
98        match &resolved {
99            GitWorktreeMode::Disabled => {
100                let candidate = match &goal.workspace {
101                    Some(p) => PathBuf::from(p),
102                    None => canonical_root.join(goal.id.0.to_string()),
103                };
104                tokio::fs::create_dir_all(&candidate).await?;
105                let canonical = tokio::fs::canonicalize(&candidate).await?;
106                if !canonical.starts_with(&canonical_root) {
107                    return Err(DriverError::WorkspaceTraversal {
108                        path: canonical.display().to_string(),
109                    });
110                }
111                Ok(canonical)
112            }
113            GitWorktreeMode::SourceRepo { path, base_ref } => {
114                let target = canonical_root.join(goal.id.0.to_string());
115                let branch = format!("nexo-driver/{}", goal.id.0);
116                let parent = target.parent().unwrap_or(&canonical_root);
117                tokio::fs::create_dir_all(parent).await?;
118                // If target already exists and is a worktree, the
119                // `worktree add -B` form moves the branch pointer
120                // back to base_ref while keeping the worktree.
121                git::worktree_add(&self.shell, path, &branch, &target, base_ref).await?;
122                let canonical = tokio::fs::canonicalize(&target).await?;
123                if !canonical.starts_with(&canonical_root) {
124                    return Err(DriverError::WorkspaceTraversal {
125                        path: canonical.display().to_string(),
126                    });
127                }
128                Ok(canonical)
129            }
130        }
131    }
132
133    /// Best-effort recursive remove. In `SourceRepo` mode, also
134    /// unregisters the worktree.
135    pub async fn cleanup(&self, path: &Path) -> Result<(), DriverError> {
136        if let GitWorktreeMode::SourceRepo {
137            path: source_repo, ..
138        } = &self.git
139        {
140            let _ = git::worktree_remove(&self.shell, source_repo, path).await;
141        }
142        match tokio::fs::remove_dir_all(path).await {
143            Ok(()) => Ok(()),
144            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
145            Err(e) => Err(DriverError::Io(e)),
146        }
147    }
148
149    pub async fn checkpoint(&self, workspace: &Path, label: &str) -> Result<String, DriverError> {
150        match &self.git {
151            GitWorktreeMode::Disabled => Ok(Self::NO_GIT_SENTINEL.to_string()),
152            GitWorktreeMode::SourceRepo { .. } => {
153                git::commit_all_with_label(&self.shell, workspace, label).await
154            }
155        }
156    }
157
158    pub async fn rollback(&self, workspace: &Path, sha: &str) -> Result<(), DriverError> {
159        match &self.git {
160            GitWorktreeMode::Disabled => Ok(()),
161            GitWorktreeMode::SourceRepo { .. } => {
162                if !is_valid_sha(sha) {
163                    return Err(DriverError::Workspace(format!(
164                        "rollback: sha {sha:?} is not 7..=40 hex chars"
165                    )));
166                }
167                git::reset_hard(&self.shell, workspace, sha).await
168            }
169        }
170    }
171
172    pub async fn diff_stat(
173        &self,
174        workspace: &Path,
175        since_sha: &str,
176    ) -> Result<String, DriverError> {
177        match &self.git {
178            GitWorktreeMode::Disabled => Ok(String::new()),
179            GitWorktreeMode::SourceRepo { .. } => {
180                if since_sha == Self::NO_GIT_SENTINEL || !is_valid_sha(since_sha) {
181                    return Ok(String::new());
182                }
183                let raw = git::diff_stat(&self.shell, workspace, since_sha).await?;
184                Ok(truncate_to(&raw, 1024))
185            }
186        }
187    }
188}
189
190fn is_valid_sha(s: &str) -> bool {
191    static_re().is_match(s)
192}
193
194fn static_re() -> &'static Regex {
195    use std::sync::OnceLock;
196    static RE: OnceLock<Regex> = OnceLock::new();
197    RE.get_or_init(|| Regex::new(r"^[0-9a-fA-F]{7,40}$").unwrap())
198}
199
/// Cut `s` down to at most `limit` bytes and append a truncation
/// marker.
///
/// Strings of `limit` bytes or fewer are returned unchanged. Longer
/// strings are cut at the last UTF-8 character boundary at or before
/// `limit`, then `"\n... (truncated)"` is appended — so the result
/// can be longer than `limit` by the length of the marker.
fn truncate_to(s: &str, limit: usize) -> String {
    const MARKER: &str = "\n... (truncated)";

    if s.len() <= limit {
        return s.to_string();
    }
    // Here `limit < s.len()`, so `end` is always in range. Index 0 is
    // a character boundary, which guarantees the loop terminates.
    let mut end = limit;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    let mut out = String::with_capacity(end + MARKER.len());
    out.push_str(&s[..end]);
    out.push_str(MARKER);
    out
}

#[cfg(test)]
mod tests {
    use super::*;
    use nexo_driver_types::{AcceptanceCriterion, BudgetGuards, GoalId};
    use std::time::Duration;
    use uuid::Uuid;

    /// Minimal goal with a fresh random id and the given optional
    /// operator-supplied workspace path.
    fn goal(workspace: Option<String>) -> Goal {
        Goal {
            id: GoalId(Uuid::new_v4()),
            description: "test".into(),
            acceptance: vec![AcceptanceCriterion::shell("true")],
            budget: BudgetGuards {
                max_turns: 1,
                max_wall_time: Duration::from_secs(60),
                max_tokens: 1_000,
                max_consecutive_denies: 1,
                max_consecutive_errors: 5,
                max_consecutive_413: 2,
            },
            workspace,
            metadata: serde_json::Map::new(),
        }
    }

    #[tokio::test]
    async fn ensure_disabled_creates_default_subdir() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path());
        let g = goal(None);
        let path = mgr.ensure(&g).await.unwrap();
        assert!(path.is_dir());
        assert!(path.starts_with(dir.path().canonicalize().unwrap()));
    }

    #[tokio::test]
    async fn ensure_disabled_rejects_path_traversal() {
        let root = tempfile::tempdir().unwrap();
        let outside = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(root.path());
        let g = goal(Some(outside.path().display().to_string()));
        let err = mgr.ensure(&g).await.unwrap_err();
        assert!(matches!(err, DriverError::WorkspaceTraversal { .. }));
    }

    #[tokio::test]
    async fn cleanup_disabled_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path());
        let nonexistent = dir.path().join("does/not/exist");
        mgr.cleanup(&nonexistent).await.unwrap();
        let sub = dir.path().join("sub");
        tokio::fs::create_dir_all(&sub).await.unwrap();
        mgr.cleanup(&sub).await.unwrap();
        mgr.cleanup(&sub).await.unwrap();
    }

    #[tokio::test]
    async fn disabled_checkpoint_returns_sentinel() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path());
        let sha = mgr.checkpoint(dir.path(), "x").await.unwrap();
        assert_eq!(sha, WorkspaceManager::NO_GIT_SENTINEL);
    }

    #[tokio::test]
    async fn disabled_rollback_is_noop() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path());
        // The sha is well-formed but refers to nothing (the directory
        // is not even a repo); Disabled mode short-circuits before
        // any git call.
        mgr.rollback(dir.path(), "deadbeef").await.unwrap();
    }

    #[tokio::test]
    async fn invalid_sha_in_sourcerepo_mode_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path()).with_git(GitWorktreeMode::SourceRepo {
            path: dir.path().to_path_buf(),
            base_ref: "HEAD".into(),
        });
        let err = mgr.rollback(dir.path(), "zzz").await.unwrap_err();
        assert!(matches!(err, DriverError::Workspace(_)));
    }

    // `truncate_to` is synchronous, so no async runtime is needed.
    #[test]
    fn truncate_to_appends_marker() {
        let s: String = "x".repeat(2000);
        let out = truncate_to(&s, 100);
        assert!(out.starts_with(&"x".repeat(100)));
        assert!(out.contains("(truncated)"));
    }

    #[test]
    fn truncate_to_respects_char_boundaries() {
        // 'é' spans bytes 1..3, so a limit of 2 must back off to 1.
        let out = truncate_to("héllo", 2);
        assert_eq!(out, "h\n... (truncated)");
    }
}