Skip to main content

bamboo_tools/tools/
notebook_edit.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolResult};
3use serde::Deserialize;
4use serde_json::{json, Value};
5use std::path::Path;
6
7use super::file_change;
8
9#[derive(Debug, Deserialize)]
10#[serde(rename_all = "lowercase")]
11enum CellType {
12    Code,
13    Markdown,
14}
15
16impl CellType {
17    fn as_str(&self) -> &'static str {
18        match self {
19            CellType::Code => "code",
20            CellType::Markdown => "markdown",
21        }
22    }
23}
24
25#[derive(Debug, Deserialize, Default)]
26#[serde(rename_all = "lowercase")]
27enum EditMode {
28    #[default]
29    Replace,
30    Insert,
31    Delete,
32}
33
34#[derive(Debug, Deserialize)]
35struct NotebookEditArgs {
36    notebook_path: String,
37    #[serde(default)]
38    cell_id: Option<String>,
39    new_source: String,
40    #[serde(default)]
41    cell_type: Option<CellType>,
42    #[serde(default)]
43    edit_mode: Option<EditMode>,
44}
45
46pub struct NotebookEditTool;
47
48impl NotebookEditTool {
49    pub fn new() -> Self {
50        Self
51    }
52
53    fn source_to_lines(source: &str) -> Vec<Value> {
54        if source.is_empty() {
55            return vec![Value::String(String::new())];
56        }
57
58        source
59            .lines()
60            .map(|line| Value::String(format!("{}\n", line)))
61            .collect()
62    }
63
64    fn find_cell_index(cells: &[Value], cell_id: Option<&str>) -> Option<usize> {
65        let cell_id = cell_id?;
66
67        cells.iter().position(|cell| {
68            cell.get("id")
69                .and_then(|value| value.as_str())
70                .map(|value| value == cell_id)
71                .unwrap_or(false)
72        })
73    }
74}
75
76impl Default for NotebookEditTool {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82#[async_trait]
83impl Tool for NotebookEditTool {
84    fn name(&self) -> &str {
85        "NotebookEdit"
86    }
87
88    fn description(&self) -> &str {
89        "Replace, insert, or delete a Jupyter notebook cell"
90    }
91
92    fn parameters_schema(&self) -> serde_json::Value {
93        json!({
94            "type": "object",
95            "properties": {
96                "notebook_path": {
97                    "type": "string",
98                    "description": "Absolute path to the notebook file"
99                },
100                "cell_id": {
101                    "type": "string",
102                    "description": "Cell ID to edit"
103                },
104                "new_source": {
105                    "type": "string",
106                    "description": "New source content for the cell"
107                },
108                "cell_type": {
109                    "type": "string",
110                    "enum": ["code", "markdown"],
111                    "description": "Cell type when inserting"
112                },
113                "edit_mode": {
114                    "type": "string",
115                    "enum": ["replace", "insert", "delete"],
116                    "description": "Edit mode"
117                }
118            },
119            "required": ["notebook_path", "new_source"],
120            "additionalProperties": false
121        })
122    }
123
124    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
125        let parsed: NotebookEditArgs = serde_json::from_value(args).map_err(|e| {
126            ToolError::InvalidArguments(format!("Invalid NotebookEdit args: {}", e))
127        })?;
128
129        let path = Path::new(parsed.notebook_path.trim());
130        if !path.is_absolute() {
131            return Err(ToolError::InvalidArguments(
132                "notebook_path must be absolute".to_string(),
133            ));
134        }
135
136        let content = tokio::fs::read_to_string(path)
137            .await
138            .map_err(|e| ToolError::Execution(format!("Failed to read notebook: {}", e)))?;
139        let checkpoint = file_change::create_checkpoint(path, Some(content.as_bytes())).await?;
140
141        let mut notebook: Value = serde_json::from_str(&content)
142            .map_err(|e| ToolError::Execution(format!("Invalid notebook JSON: {}", e)))?;
143
144        let cells = notebook
145            .get_mut("cells")
146            .and_then(|value| value.as_array_mut())
147            .ok_or_else(|| ToolError::Execution("Notebook missing 'cells' array".to_string()))?;
148
149        let edit_mode = parsed.edit_mode.unwrap_or_default();
150        let cell_id = parsed
151            .cell_id
152            .as_deref()
153            .map(str::trim)
154            .filter(|value| !value.is_empty());
155        let target_index = Self::find_cell_index(cells, cell_id);
156
157        match edit_mode {
158            EditMode::Replace => {
159                if cell_id.is_none() {
160                    return Err(ToolError::InvalidArguments(
161                        "cell_id is required when edit_mode=replace".to_string(),
162                    ));
163                }
164                let idx = target_index.ok_or_else(|| {
165                    ToolError::Execution("Target cell not found for replace".to_string())
166                })?;
167                if let Some(cell) = cells.get_mut(idx) {
168                    cell["source"] = Value::Array(Self::source_to_lines(&parsed.new_source));
169                }
170            }
171            EditMode::Insert => {
172                let cell_type = parsed.cell_type.ok_or_else(|| {
173                    ToolError::InvalidArguments(
174                        "cell_type is required when edit_mode=insert".to_string(),
175                    )
176                })?;
177                let new_cell = json!({
178                    "id": uuid::Uuid::new_v4().to_string(),
179                    "cell_type": cell_type.as_str(),
180                    "metadata": {},
181                    "source": Self::source_to_lines(&parsed.new_source),
182                    "outputs": [],
183                    "execution_count": serde_json::Value::Null,
184                });
185
186                if let Some(cell_id) = cell_id {
187                    let idx = target_index.ok_or_else(|| {
188                        ToolError::Execution(format!(
189                            "Target cell '{}' not found for insert",
190                            cell_id
191                        ))
192                    })?;
193                    cells.insert(idx + 1, new_cell);
194                } else {
195                    cells.push(new_cell);
196                }
197            }
198            EditMode::Delete => {
199                if cell_id.is_none() {
200                    return Err(ToolError::InvalidArguments(
201                        "cell_id is required when edit_mode=delete".to_string(),
202                    ));
203                }
204                let idx = target_index.ok_or_else(|| {
205                    ToolError::Execution("Target cell not found for delete".to_string())
206                })?;
207                cells.remove(idx);
208            }
209        }
210
211        let updated = serde_json::to_string_pretty(&notebook)
212            .map_err(|e| ToolError::Execution(format!("Failed to serialize notebook: {}", e)))?;
213
214        file_change::atomic_write_text(path, &updated).await?;
215
216        let payload = file_change::build_file_change_payload(
217            "NotebookEdit",
218            path,
219            format!("Notebook updated: {}", parsed.notebook_path),
220            checkpoint,
221            &content,
222            &updated,
223        );
224
225        Ok(ToolResult {
226            success: true,
227            result: payload,
228            display_preference: Some("Default".to_string()),
229            images: Vec::new(),
230        })
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: an nbformat 4.5 notebook with one code cell (`cell-a`) and one markdown
    /// cell (`cell-b`).
    fn sample_notebook() -> serde_json::Value {
        let code_cell = json!({
            "id": "cell-a",
            "cell_type": "code",
            "metadata": {},
            "source": ["print('a')\n"],
            "outputs": [],
            "execution_count": 1
        });
        let markdown_cell = json!({
            "id": "cell-b",
            "cell_type": "markdown",
            "metadata": {},
            "source": ["# title\n"]
        });
        json!({
            "cells": [code_cell, markdown_cell],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        })
    }

    /// Writes the pretty-printed fixture notebook to `path`.
    async fn write_notebook(path: &Path) {
        let text = serde_json::to_string_pretty(&sample_notebook()).unwrap();
        tokio::fs::write(path, text).await.unwrap();
    }

    /// Creates a temporary file holding the fixture notebook.
    ///
    /// The returned handle owns the temporary file, so tests keep it bound for their whole body.
    async fn fixture_file() -> tempfile::NamedTempFile {
        let file = tempfile::NamedTempFile::new().unwrap();
        write_notebook(file.path()).await;
        file
    }

    /// Runs a fresh `NotebookEditTool` with `args`.
    async fn run_tool(args: Value) -> Result<ToolResult, ToolError> {
        NotebookEditTool::new().execute(args).await
    }

    #[tokio::test]
    async fn replace_requires_cell_id() {
        let file = fixture_file().await;

        let outcome = run_tool(json!({
            "notebook_path": file.path(),
            "edit_mode": "replace",
            "new_source": "updated"
        }))
        .await;

        assert!(matches!(outcome, Err(ToolError::InvalidArguments(msg)) if msg.contains("cell_id")));
    }

    #[tokio::test]
    async fn delete_requires_cell_id() {
        let file = fixture_file().await;

        let outcome = run_tool(json!({
            "notebook_path": file.path(),
            "edit_mode": "delete",
            "new_source": ""
        }))
        .await;

        assert!(matches!(outcome, Err(ToolError::InvalidArguments(msg)) if msg.contains("cell_id")));
    }

    #[tokio::test]
    async fn insert_without_cell_id_appends_cell() {
        let file = fixture_file().await;

        let outcome = run_tool(json!({
            "notebook_path": file.path(),
            "edit_mode": "insert",
            "cell_type": "markdown",
            "new_source": "appended cell"
        }))
        .await
        .unwrap();
        assert!(outcome.success);

        let raw = tokio::fs::read_to_string(file.path()).await.unwrap();
        let notebook: Value = serde_json::from_str(&raw).unwrap();
        let cells = notebook["cells"].as_array().unwrap();
        assert_eq!(cells.len(), 3);

        let appended = cells.last().unwrap();
        assert_eq!(appended["cell_type"], "markdown");
        let lines = appended["source"].as_array().unwrap();
        assert_eq!(lines[0], "appended cell\n");
    }

    #[tokio::test]
    async fn insert_with_unknown_cell_id_returns_error() {
        let file = fixture_file().await;

        let outcome = run_tool(json!({
            "notebook_path": file.path(),
            "edit_mode": "insert",
            "cell_id": "does-not-exist",
            "cell_type": "code",
            "new_source": "print('x')"
        }))
        .await;

        assert!(matches!(outcome, Err(ToolError::Execution(msg)) if msg.contains("not found")));
    }
}