nexo_driver_loop/workspace/
manager.rs1use std::path::{Path, PathBuf};
6
7use nexo_driver_types::Goal;
8use regex::Regex;
9
10use crate::acceptance::ShellRunner;
11use crate::error::DriverError;
12use crate::workspace::git;
13
/// Selects whether workspaces are plain directories or git worktrees.
#[derive(Debug, Clone)]
pub enum GitWorktreeMode {
    /// Workspaces are ordinary directories; checkpoint/rollback/diff become no-ops.
    Disabled,
    /// Workspaces are worktrees added from the repository at `path`, starting at `base_ref`.
    SourceRepo { path: PathBuf, base_ref: String },
}
19
20pub struct WorkspaceManager {
21 root: PathBuf,
22 git: GitWorktreeMode,
23 shell: ShellRunner,
24}
25
26impl WorkspaceManager {
27 pub const NO_GIT_SENTINEL: &'static str = "<no-git>";
29
30 pub fn new(root: impl Into<PathBuf>) -> Self {
31 Self {
32 root: root.into(),
33 git: GitWorktreeMode::Disabled,
34 shell: ShellRunner::default(),
35 }
36 }
37
38 pub fn with_git(mut self, mode: GitWorktreeMode) -> Self {
39 self.git = mode;
40 self
41 }
42
43 pub fn with_shell(mut self, shell: ShellRunner) -> Self {
44 self.shell = shell;
45 self
46 }
47
48 pub fn root(&self) -> &Path {
49 &self.root
50 }
51
52 pub fn git_mode(&self) -> &GitWorktreeMode {
53 &self.git
54 }
55
56 pub async fn ensure(&self, goal: &Goal) -> Result<PathBuf, DriverError> {
64 tokio::fs::create_dir_all(&self.root).await?;
65 let canonical_root = tokio::fs::canonicalize(&self.root).await?;
66
67 let override_source: Option<PathBuf> = goal
75 .metadata
76 .get("worktree.source_repo")
77 .and_then(|v| v.as_str())
78 .filter(|s| !s.is_empty())
79 .map(PathBuf::from)
80 .filter(|p| p.join(".git").exists());
81
82 let resolved = if let Some(repo) = override_source {
83 let base_ref = match &self.git {
87 GitWorktreeMode::SourceRepo { base_ref, .. } => base_ref.clone(),
88 GitWorktreeMode::Disabled => "HEAD".to_string(),
89 };
90 GitWorktreeMode::SourceRepo {
91 path: repo,
92 base_ref,
93 }
94 } else {
95 self.git.clone()
96 };
97
98 match &resolved {
99 GitWorktreeMode::Disabled => {
100 let candidate = match &goal.workspace {
101 Some(p) => PathBuf::from(p),
102 None => canonical_root.join(goal.id.0.to_string()),
103 };
104 tokio::fs::create_dir_all(&candidate).await?;
105 let canonical = tokio::fs::canonicalize(&candidate).await?;
106 if !canonical.starts_with(&canonical_root) {
107 return Err(DriverError::WorkspaceTraversal {
108 path: canonical.display().to_string(),
109 });
110 }
111 Ok(canonical)
112 }
113 GitWorktreeMode::SourceRepo { path, base_ref } => {
114 let target = canonical_root.join(goal.id.0.to_string());
115 let branch = format!("nexo-driver/{}", goal.id.0);
116 let parent = target.parent().unwrap_or(&canonical_root);
117 tokio::fs::create_dir_all(parent).await?;
118 git::worktree_add(&self.shell, path, &branch, &target, base_ref).await?;
122 let canonical = tokio::fs::canonicalize(&target).await?;
123 if !canonical.starts_with(&canonical_root) {
124 return Err(DriverError::WorkspaceTraversal {
125 path: canonical.display().to_string(),
126 });
127 }
128 Ok(canonical)
129 }
130 }
131 }
132
133 pub async fn cleanup(&self, path: &Path) -> Result<(), DriverError> {
136 if let GitWorktreeMode::SourceRepo {
137 path: source_repo, ..
138 } = &self.git
139 {
140 let _ = git::worktree_remove(&self.shell, source_repo, path).await;
141 }
142 match tokio::fs::remove_dir_all(path).await {
143 Ok(()) => Ok(()),
144 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
145 Err(e) => Err(DriverError::Io(e)),
146 }
147 }
148
149 pub async fn checkpoint(&self, workspace: &Path, label: &str) -> Result<String, DriverError> {
150 match &self.git {
151 GitWorktreeMode::Disabled => Ok(Self::NO_GIT_SENTINEL.to_string()),
152 GitWorktreeMode::SourceRepo { .. } => {
153 git::commit_all_with_label(&self.shell, workspace, label).await
154 }
155 }
156 }
157
158 pub async fn rollback(&self, workspace: &Path, sha: &str) -> Result<(), DriverError> {
159 match &self.git {
160 GitWorktreeMode::Disabled => Ok(()),
161 GitWorktreeMode::SourceRepo { .. } => {
162 if !is_valid_sha(sha) {
163 return Err(DriverError::Workspace(format!(
164 "rollback: sha {sha:?} is not 7..=40 hex chars"
165 )));
166 }
167 git::reset_hard(&self.shell, workspace, sha).await
168 }
169 }
170 }
171
172 pub async fn diff_stat(
173 &self,
174 workspace: &Path,
175 since_sha: &str,
176 ) -> Result<String, DriverError> {
177 match &self.git {
178 GitWorktreeMode::Disabled => Ok(String::new()),
179 GitWorktreeMode::SourceRepo { .. } => {
180 if since_sha == Self::NO_GIT_SENTINEL || !is_valid_sha(since_sha) {
181 return Ok(String::new());
182 }
183 let raw = git::diff_stat(&self.shell, workspace, since_sha).await?;
184 Ok(truncate_to(&raw, 1024))
185 }
186 }
187 }
188}
189
190fn is_valid_sha(s: &str) -> bool {
191 static_re().is_match(s)
192}
193
194fn static_re() -> &'static Regex {
195 use std::sync::OnceLock;
196 static RE: OnceLock<Regex> = OnceLock::new();
197 RE.get_or_init(|| Regex::new(r"^[0-9a-fA-F]{7,40}$").unwrap())
198}
199
/// Caps `s` at `limit` bytes for display.
///
/// Input that already fits is returned unchanged. Longer input is cut at the largest
/// UTF-8 character boundary not exceeding `limit`, then followed by a
/// `"\n... (truncated)"` line.
fn truncate_to(s: &str, limit: usize) -> String {
    if s.len() > limit {
        // Offset 0 is always a char boundary, so the search never comes up empty.
        let cut = (0..=limit)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        format!("{}\n... (truncated)", &s[..cut])
    } else {
        s.to_owned()
    }
}
212
#[cfg(test)]
mod tests {
    use super::*;
    use nexo_driver_types::{AcceptanceCriterion, BudgetGuards, GoalId};
    use std::time::Duration;
    use uuid::Uuid;

    /// Builds a minimal goal with a fresh id and the given `workspace` override.
    fn goal(workspace: Option<String>) -> Goal {
        Goal {
            id: GoalId(Uuid::new_v4()),
            description: "test".into(),
            acceptance: vec![AcceptanceCriterion::shell("true")],
            budget: BudgetGuards {
                max_turns: 1,
                max_wall_time: Duration::from_secs(60),
                max_tokens: 1_000,
                max_consecutive_denies: 1,
                max_consecutive_errors: 5,
                max_consecutive_413: 2,
            },
            workspace,
            metadata: serde_json::Map::new(),
        }
    }

    /// Manager in source-repo mode. Only paths that never reach git are exercised with it.
    fn source_repo_manager(dir: &Path) -> WorkspaceManager {
        WorkspaceManager::new(dir).with_git(GitWorktreeMode::SourceRepo {
            path: dir.to_path_buf(),
            base_ref: "HEAD".into(),
        })
    }

    #[tokio::test]
    async fn ensure_disabled_creates_default_subdir() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path());
        let g = goal(None);
        let path = mgr.ensure(&g).await.unwrap();
        assert!(path.is_dir());
        assert!(path.starts_with(dir.path().canonicalize().unwrap()));
    }

    #[tokio::test]
    async fn ensure_disabled_rejects_path_traversal() {
        let root = tempfile::tempdir().unwrap();
        let outside = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(root.path());
        let g = goal(Some(outside.path().display().to_string()));
        let err = mgr.ensure(&g).await.unwrap_err();
        assert!(matches!(err, DriverError::WorkspaceTraversal { .. }));
    }

    #[tokio::test]
    async fn cleanup_disabled_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path());
        let nonexistent = dir.path().join("does/not/exist");
        mgr.cleanup(&nonexistent).await.unwrap();
        let sub = dir.path().join("sub");
        tokio::fs::create_dir_all(&sub).await.unwrap();
        mgr.cleanup(&sub).await.unwrap();
        mgr.cleanup(&sub).await.unwrap();
    }

    #[tokio::test]
    async fn disabled_checkpoint_returns_sentinel() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path());
        let sha = mgr.checkpoint(dir.path(), "x").await.unwrap();
        assert_eq!(sha, WorkspaceManager::NO_GIT_SENTINEL);
    }

    #[tokio::test]
    async fn disabled_rollback_is_noop() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path());
        mgr.rollback(dir.path(), "deadbeef").await.unwrap();
    }

    #[tokio::test]
    async fn invalid_sha_in_sourcerepo_mode_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = source_repo_manager(dir.path());
        let err = mgr.rollback(dir.path(), "zzz").await.unwrap_err();
        assert!(matches!(err, DriverError::Workspace(_)));
    }

    #[tokio::test]
    async fn disabled_diff_stat_is_empty() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = WorkspaceManager::new(dir.path());
        let out = mgr.diff_stat(dir.path(), "deadbeef").await.unwrap();
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn sourcerepo_diff_stat_skips_sentinel_and_invalid_sha() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = source_repo_manager(dir.path());
        // Neither value is a valid SHA, so git is never invoked and no repo is needed.
        let out = mgr
            .diff_stat(dir.path(), WorkspaceManager::NO_GIT_SENTINEL)
            .await
            .unwrap();
        assert!(out.is_empty());
        let out = mgr.diff_stat(dir.path(), "not-a-sha").await.unwrap();
        assert!(out.is_empty());
    }

    // Pure function: no async runtime needed.
    #[test]
    fn truncate_to_appends_marker() {
        let s: String = "x".repeat(2000);
        let out = truncate_to(&s, 100);
        assert!(out.starts_with(&"x".repeat(100)));
        assert!(out.contains("(truncated)"));
    }

    #[test]
    fn truncate_to_keeps_short_input_and_respects_char_boundaries() {
        assert_eq!(truncate_to("abc", 3), "abc");
        // "é" is two bytes in UTF-8; a limit of 3 falls inside the second one.
        assert_eq!(truncate_to(&"é".repeat(10), 3), "é\n... (truncated)");
    }

    #[test]
    fn is_valid_sha_accepts_only_7_to_40_hex_chars() {
        assert!(is_valid_sha("abcdef0"));
        assert!(is_valid_sha("ABCDEF0123"));
        assert!(is_valid_sha(&"a".repeat(40)));
        assert!(!is_valid_sha("abcdef"));
        assert!(!is_valid_sha(&"a".repeat(41)));
        assert!(!is_valid_sha("abcdefg"));
        assert!(!is_valid_sha("abcdef0\n"));
        assert!(!is_valid_sha(""));
        assert!(!is_valid_sha(WorkspaceManager::NO_GIT_SENTINEL));
    }
}