Skip to main content

bamboo_tools/tools/
workspace.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use serde_json::json;
4use std::path::{Path, PathBuf};
5
6use super::workspace_state;
7
/// Unified workspace tool: get or set the session working directory.
///
/// - When called **without** `path` → returns the session's workspace directory, or the
///   process working directory when the session has none recorded.
/// - When called **with** `path` → checks that it names an existing directory (relative
///   paths are joined onto the base from `workspace_state::workspace_or_process_cwd`),
///   stores its canonical form as the session workspace and returns it.
///
/// This replaces the previous `GetCurrentDir` + `SetWorkspace` pair.
///
/// The type carries no state, so every instance behaves identically.
#[derive(Debug, Clone, Copy)]
pub struct WorkspaceTool;
15
16impl WorkspaceTool {
17    pub fn new() -> Self {
18        Self
19    }
20}
21
22impl Default for WorkspaceTool {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28#[async_trait]
29impl Tool for WorkspaceTool {
30    fn name(&self) -> &str {
31        "Workspace"
32    }
33
34    fn description(&self) -> &str {
35        "Get or set the current session workspace directory. Call without 'path' to get the current workspace; call with 'path' to change it."
36    }
37
38    fn mutability(&self) -> crate::ToolMutability {
39        crate::ToolMutability::Mutating
40    }
41
42    fn call_mutability(&self, args: &serde_json::Value) -> crate::ToolMutability {
43        let has_path = args
44            .get("path")
45            .and_then(|v| v.as_str())
46            .map(str::trim)
47            .is_some_and(|v| !v.is_empty());
48        if has_path {
49            crate::ToolMutability::Mutating
50        } else {
51            crate::ToolMutability::ReadOnly
52        }
53    }
54
55    fn call_concurrency_safe(&self, args: &serde_json::Value) -> bool {
56        self.call_mutability(args) == crate::ToolMutability::ReadOnly
57    }
58
59    fn parameters_schema(&self) -> serde_json::Value {
60        json!({
61            "type": "object",
62            "properties": {
63                "path": {
64                    "type": "string",
65                    "description": "Path of the workspace directory to set. Omit to just read the current workspace."
66                }
67            },
68            "additionalProperties": false
69        })
70    }
71
72    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
73        self.execute_with_context(args, ToolExecutionContext::none("Workspace"))
74            .await
75    }
76
77    async fn execute_with_context(
78        &self,
79        args: serde_json::Value,
80        ctx: ToolExecutionContext<'_>,
81    ) -> Result<ToolResult, ToolError> {
82        let path_arg = args
83            .get("path")
84            .and_then(|v| v.as_str())
85            .map(|s| s.trim())
86            .filter(|s| !s.is_empty());
87
88        match path_arg {
89            // ── SET mode ──────────────────────────────────────────────
90            Some(path) => {
91                let session_id = ctx.session_id.ok_or_else(|| {
92                    ToolError::Execution(
93                        "Workspace(set) requires a session_id in tool context".to_string(),
94                    )
95                })?;
96
97                let base = workspace_state::workspace_or_process_cwd(Some(session_id));
98                let raw_path = Path::new(path);
99                let path_obj: PathBuf = if raw_path.is_absolute() {
100                    raw_path.to_path_buf()
101                } else {
102                    base.join(raw_path)
103                };
104
105                if !path_obj.exists() {
106                    return Ok(ToolResult {
107                        success: false,
108                        result: format!("Path does not exist: {}", path_obj.display()),
109                        display_preference: Some("error".to_string()),
110                    });
111                }
112                if !path_obj.is_dir() {
113                    return Ok(ToolResult {
114                        success: false,
115                        result: format!("Path is not a directory: {}", path_obj.display()),
116                        display_preference: Some("error".to_string()),
117                    });
118                }
119
120                let absolute_path = path_obj.canonicalize().map_err(|e| {
121                    ToolError::Execution(format!("Failed to canonicalize path: {e}"))
122                })?;
123
124                workspace_state::set_workspace(session_id, absolute_path.clone());
125
126                Ok(ToolResult {
127                    success: true,
128                    result: json!({
129                        "session_id": session_id,
130                        "workspace": bamboo_infrastructure::paths::path_to_display_string(&absolute_path)
131                    })
132                    .to_string(),
133                    display_preference: Some("json".to_string()),
134                })
135            }
136
137            // ── GET mode ──────────────────────────────────────────────
138            None => {
139                if let Some(session_id) = ctx.session_id {
140                    if let Some(workspace) = workspace_state::get_workspace(session_id) {
141                        return Ok(ToolResult {
142                            success: true,
143                            result: bamboo_infrastructure::paths::path_to_display_string(
144                                &workspace,
145                            ),
146                            display_preference: None,
147                        });
148                    }
149                }
150
151                match std::env::current_dir() {
152                    Ok(dir) => Ok(ToolResult {
153                        success: true,
154                        result: bamboo_infrastructure::paths::path_to_display_string(&dir),
155                        display_preference: None,
156                    }),
157                    Err(error) => Ok(ToolResult {
158                        success: false,
159                        result: format!("Failed to get current directory: {error}"),
160                        display_preference: Some("error".to_string()),
161                    }),
162                }
163            }
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool context bound to `session_id`, with no event channel or schemas.
    fn session_ctx<'a>(
        session_id: &'a str,
        tool_call_id: &'static str,
    ) -> ToolExecutionContext<'a> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id,
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// Returns a fresh session id, so tests never observe each other's
    /// `workspace_state` entries (the registry is presumably shared; verify).
    fn unique_session() -> String {
        format!("session_{}", uuid::Uuid::new_v4())
    }

    #[tokio::test]
    async fn workspace_get_returns_non_empty_path() {
        let tool = WorkspaceTool::new();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(result.success);
        assert!(!result.result.trim().is_empty());
    }

    #[tokio::test]
    async fn workspace_get_prefers_session_workspace() {
        let dir = tempfile::tempdir().unwrap();
        let workspace = dir.path().join("workspace");
        tokio::fs::create_dir_all(&workspace).await.unwrap();
        let session = unique_session();
        workspace_state::set_workspace(&session, workspace.clone());

        let tool = WorkspaceTool::new();
        let result = tool
            .execute_with_context(json!({}), session_ctx(&session, "call_1"))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(
            result.result,
            bamboo_infrastructure::paths::path_to_display_string(&workspace)
        );
    }

    #[tokio::test]
    async fn workspace_set_changes_session_workspace() {
        let dir = tempfile::tempdir().unwrap();
        let workspace = dir.path().join("ws");
        tokio::fs::create_dir_all(&workspace).await.unwrap();
        let session = unique_session();

        let tool = WorkspaceTool::new();
        let result = tool
            .execute_with_context(
                json!({"path": workspace.to_string_lossy()}),
                session_ctx(&session, "call_1"),
            )
            .await
            .unwrap();
        assert!(result.success);

        // GET mode must now report the canonical form of the new workspace.
        let get_result = tool
            .execute_with_context(json!({}), session_ctx(&session, "call_2"))
            .await
            .unwrap();
        assert!(get_result.success);
        let expected = workspace.canonicalize().unwrap();
        assert_eq!(
            get_result.result,
            bamboo_infrastructure::paths::path_to_display_string(&expected)
        );
    }

    #[tokio::test]
    async fn workspace_set_resolves_relative_path_against_session_workspace() {
        let dir = tempfile::tempdir().unwrap();
        let child = dir.path().join("child");
        tokio::fs::create_dir_all(&child).await.unwrap();
        let session = unique_session();
        workspace_state::set_workspace(&session, dir.path().to_path_buf());

        let tool = WorkspaceTool::new();
        let result = tool
            .execute_with_context(json!({"path": "child"}), session_ctx(&session, "call_1"))
            .await
            .unwrap();
        assert!(result.success);

        let payload: serde_json::Value = serde_json::from_str(&result.result).unwrap();
        let expected = child.canonicalize().unwrap();
        assert_eq!(
            payload["workspace"],
            bamboo_infrastructure::paths::path_to_display_string(&expected)
        );
        assert_eq!(payload["session_id"], session.as_str());
    }

    #[tokio::test]
    async fn workspace_set_rejects_missing_path() {
        // Derive the missing path from a fresh temp dir instead of hard-coding a
        // Unix location, so the test is portable and the path is guaranteed absent.
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("bamboo-no-such-workspace");
        let session = unique_session();

        let tool = WorkspaceTool::new();
        let result = tool
            .execute_with_context(
                json!({"path": missing.to_string_lossy()}),
                session_ctx(&session, "call_1"),
            )
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.result.contains("does not exist"));
    }

    #[tokio::test]
    async fn workspace_set_rejects_regular_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("not_a_dir.txt");
        tokio::fs::write(&file, b"x").await.unwrap();
        let session = unique_session();

        let tool = WorkspaceTool::new();
        let result = tool
            .execute_with_context(
                json!({"path": file.to_string_lossy()}),
                session_ctx(&session, "call_1"),
            )
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.result.contains("not a directory"));
    }

    #[tokio::test]
    async fn workspace_set_requires_session_context() {
        let tool = WorkspaceTool::new();
        let err = tool
            .execute(json!({"path": "/"}))
            .await
            .expect_err("missing session should fail");
        assert!(matches!(err, ToolError::Execution(msg) if msg.contains("session_id")));
    }

    #[test]
    fn call_mutability_tracks_path_argument() {
        let tool = WorkspaceTool::new();

        // Absent or blank `path` means GET mode: read-only and concurrency-safe.
        for args in [json!({}), json!({"path": ""}), json!({"path": "   "})] {
            assert!(tool.call_mutability(&args) == crate::ToolMutability::ReadOnly);
            assert!(tool.call_concurrency_safe(&args));
        }

        let set_args = json!({"path": "/some/dir"});
        assert!(tool.call_mutability(&set_args) == crate::ToolMutability::Mutating);
        assert!(!tool.call_concurrency_safe(&set_args));
    }
}