1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolResult};
3use serde::Deserialize;
4use serde_json::{json, Value};
5use std::path::Path;
6
7use super::file_change;
8
9#[derive(Debug, Deserialize)]
10#[serde(rename_all = "lowercase")]
11enum CellType {
12 Code,
13 Markdown,
14}
15
16impl CellType {
17 fn as_str(&self) -> &'static str {
18 match self {
19 CellType::Code => "code",
20 CellType::Markdown => "markdown",
21 }
22 }
23}
24
25#[derive(Debug, Deserialize, Default)]
26#[serde(rename_all = "lowercase")]
27enum EditMode {
28 #[default]
29 Replace,
30 Insert,
31 Delete,
32}
33
34#[derive(Debug, Deserialize)]
35struct NotebookEditArgs {
36 notebook_path: String,
37 #[serde(default)]
38 cell_id: Option<String>,
39 new_source: String,
40 #[serde(default)]
41 cell_type: Option<CellType>,
42 #[serde(default)]
43 edit_mode: Option<EditMode>,
44}
45
/// Tool that replaces, inserts, or deletes a single cell of a Jupyter
/// notebook file.
///
/// The type carries no state, so it is cheap to copy and trivially printable.
#[derive(Debug, Clone, Copy)]
pub struct NotebookEditTool;
47
48impl NotebookEditTool {
49 pub fn new() -> Self {
50 Self
51 }
52
53 fn source_to_lines(source: &str) -> Vec<Value> {
54 if source.is_empty() {
55 return vec![Value::String(String::new())];
56 }
57
58 source
59 .lines()
60 .map(|line| Value::String(format!("{}\n", line)))
61 .collect()
62 }
63
64 fn find_cell_index(cells: &[Value], cell_id: Option<&str>) -> Option<usize> {
65 let cell_id = cell_id?;
66
67 cells.iter().position(|cell| {
68 cell.get("id")
69 .and_then(|value| value.as_str())
70 .map(|value| value == cell_id)
71 .unwrap_or(false)
72 })
73 }
74}
75
76impl Default for NotebookEditTool {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82#[async_trait]
83impl Tool for NotebookEditTool {
84 fn name(&self) -> &str {
85 "NotebookEdit"
86 }
87
88 fn description(&self) -> &str {
89 "Replace, insert, or delete a Jupyter notebook cell"
90 }
91
92 fn parameters_schema(&self) -> serde_json::Value {
93 json!({
94 "type": "object",
95 "properties": {
96 "notebook_path": {
97 "type": "string",
98 "description": "Absolute path to the notebook file"
99 },
100 "cell_id": {
101 "type": "string",
102 "description": "Cell ID to edit"
103 },
104 "new_source": {
105 "type": "string",
106 "description": "New source content for the cell"
107 },
108 "cell_type": {
109 "type": "string",
110 "enum": ["code", "markdown"],
111 "description": "Cell type when inserting"
112 },
113 "edit_mode": {
114 "type": "string",
115 "enum": ["replace", "insert", "delete"],
116 "description": "Edit mode"
117 }
118 },
119 "required": ["notebook_path", "new_source"],
120 "additionalProperties": false
121 })
122 }
123
124 async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
125 let parsed: NotebookEditArgs = serde_json::from_value(args).map_err(|e| {
126 ToolError::InvalidArguments(format!("Invalid NotebookEdit args: {}", e))
127 })?;
128
129 let path = Path::new(parsed.notebook_path.trim());
130 if !path.is_absolute() {
131 return Err(ToolError::InvalidArguments(
132 "notebook_path must be absolute".to_string(),
133 ));
134 }
135
136 let content = tokio::fs::read_to_string(path)
137 .await
138 .map_err(|e| ToolError::Execution(format!("Failed to read notebook: {}", e)))?;
139 let checkpoint = file_change::create_checkpoint(path, Some(content.as_bytes())).await?;
140
141 let mut notebook: Value = serde_json::from_str(&content)
142 .map_err(|e| ToolError::Execution(format!("Invalid notebook JSON: {}", e)))?;
143
144 let cells = notebook
145 .get_mut("cells")
146 .and_then(|value| value.as_array_mut())
147 .ok_or_else(|| ToolError::Execution("Notebook missing 'cells' array".to_string()))?;
148
149 let edit_mode = parsed.edit_mode.unwrap_or_default();
150 let cell_id = parsed
151 .cell_id
152 .as_deref()
153 .map(str::trim)
154 .filter(|value| !value.is_empty());
155 let target_index = Self::find_cell_index(cells, cell_id);
156
157 match edit_mode {
158 EditMode::Replace => {
159 if cell_id.is_none() {
160 return Err(ToolError::InvalidArguments(
161 "cell_id is required when edit_mode=replace".to_string(),
162 ));
163 }
164 let idx = target_index.ok_or_else(|| {
165 ToolError::Execution("Target cell not found for replace".to_string())
166 })?;
167 if let Some(cell) = cells.get_mut(idx) {
168 cell["source"] = Value::Array(Self::source_to_lines(&parsed.new_source));
169 }
170 }
171 EditMode::Insert => {
172 let cell_type = parsed.cell_type.ok_or_else(|| {
173 ToolError::InvalidArguments(
174 "cell_type is required when edit_mode=insert".to_string(),
175 )
176 })?;
177 let new_cell = json!({
178 "id": uuid::Uuid::new_v4().to_string(),
179 "cell_type": cell_type.as_str(),
180 "metadata": {},
181 "source": Self::source_to_lines(&parsed.new_source),
182 "outputs": [],
183 "execution_count": serde_json::Value::Null,
184 });
185
186 if let Some(cell_id) = cell_id {
187 let idx = target_index.ok_or_else(|| {
188 ToolError::Execution(format!(
189 "Target cell '{}' not found for insert",
190 cell_id
191 ))
192 })?;
193 cells.insert(idx + 1, new_cell);
194 } else {
195 cells.push(new_cell);
196 }
197 }
198 EditMode::Delete => {
199 if cell_id.is_none() {
200 return Err(ToolError::InvalidArguments(
201 "cell_id is required when edit_mode=delete".to_string(),
202 ));
203 }
204 let idx = target_index.ok_or_else(|| {
205 ToolError::Execution("Target cell not found for delete".to_string())
206 })?;
207 cells.remove(idx);
208 }
209 }
210
211 let updated = serde_json::to_string_pretty(¬ebook)
212 .map_err(|e| ToolError::Execution(format!("Failed to serialize notebook: {}", e)))?;
213
214 file_change::atomic_write_text(path, &updated).await?;
215
216 let payload = file_change::build_file_change_payload(
217 "NotebookEdit",
218 path,
219 format!("Notebook updated: {}", parsed.notebook_path),
220 checkpoint,
221 &content,
222 &updated,
223 );
224
225 Ok(ToolResult {
226 success: true,
227 result: payload,
228 display_preference: Some("Default".to_string()),
229 images: Vec::new(),
230 })
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-cell notebook fixture: code cell `cell-a` followed by markdown
    /// cell `cell-b`.
    fn sample_notebook() -> serde_json::Value {
        json!({
            "cells": [
                {
                    "id": "cell-a",
                    "cell_type": "code",
                    "metadata": {},
                    "source": ["print('a')\n"],
                    "outputs": [],
                    "execution_count": 1
                },
                {
                    "id": "cell-b",
                    "cell_type": "markdown",
                    "metadata": {},
                    "source": ["# title\n"]
                }
            ],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        })
    }

    /// Writes the sample notebook to `path`.
    async fn write_notebook(path: &Path) {
        tokio::fs::write(
            path,
            serde_json::to_string_pretty(&sample_notebook()).unwrap(),
        )
        .await
        .unwrap();
    }

    /// Creates a temporary file that already contains the sample notebook.
    async fn notebook_file() -> tempfile::NamedTempFile {
        let file = tempfile::NamedTempFile::new().unwrap();
        write_notebook(file.path()).await;
        file
    }

    /// Reads the notebook at `path` back from disk and returns its cells.
    async fn read_cells(path: &Path) -> Vec<Value> {
        let text = tokio::fs::read_to_string(path).await.unwrap();
        let notebook: Value = serde_json::from_str(&text).unwrap();
        notebook["cells"].as_array().unwrap().clone()
    }

    #[tokio::test]
    async fn relative_path_is_rejected() {
        let tool = NotebookEditTool::new();
        let result = tool
            .execute(json!({
                "notebook_path": "relative/notebook.ipynb",
                "cell_id": "cell-a",
                "new_source": "updated"
            }))
            .await;

        assert!(
            matches!(result, Err(ToolError::InvalidArguments(msg)) if msg.contains("absolute"))
        );
    }

    #[tokio::test]
    async fn replace_requires_cell_id() {
        let file = notebook_file().await;

        let tool = NotebookEditTool::new();
        let result = tool
            .execute(json!({
                "notebook_path": file.path(),
                "edit_mode": "replace",
                "new_source": "updated"
            }))
            .await;

        assert!(matches!(result, Err(ToolError::InvalidArguments(msg)) if msg.contains("cell_id")));
    }

    #[tokio::test]
    async fn replace_updates_source_of_target_cell() {
        let file = notebook_file().await;

        let tool = NotebookEditTool::new();
        let result = tool
            .execute(json!({
                "notebook_path": file.path(),
                "cell_id": "cell-a",
                "new_source": "x = 1\ny = 2"
            }))
            .await
            .unwrap();
        assert!(result.success);

        let cells = read_cells(file.path()).await;
        assert_eq!(cells.len(), 2);
        assert_eq!(cells[0]["id"], "cell-a");
        assert_eq!(cells[0]["cell_type"], "code");
        assert_eq!(cells[0]["source"], json!(["x = 1\n", "y = 2\n"]));
        // The other cell must be left untouched.
        assert_eq!(cells[1]["source"], json!(["# title\n"]));
    }

    #[tokio::test]
    async fn delete_requires_cell_id() {
        let file = notebook_file().await;

        let tool = NotebookEditTool::new();
        let result = tool
            .execute(json!({
                "notebook_path": file.path(),
                "edit_mode": "delete",
                "new_source": ""
            }))
            .await;

        assert!(matches!(result, Err(ToolError::InvalidArguments(msg)) if msg.contains("cell_id")));
    }

    #[tokio::test]
    async fn delete_removes_target_cell() {
        let file = notebook_file().await;

        let tool = NotebookEditTool::new();
        let result = tool
            .execute(json!({
                "notebook_path": file.path(),
                "edit_mode": "delete",
                "cell_id": "cell-a",
                "new_source": ""
            }))
            .await
            .unwrap();
        assert!(result.success);

        let cells = read_cells(file.path()).await;
        assert_eq!(cells.len(), 1);
        assert_eq!(cells[0]["id"], "cell-b");
    }

    #[tokio::test]
    async fn insert_requires_cell_type() {
        let file = notebook_file().await;

        let tool = NotebookEditTool::new();
        let result = tool
            .execute(json!({
                "notebook_path": file.path(),
                "edit_mode": "insert",
                "new_source": "print('x')"
            }))
            .await;

        assert!(
            matches!(result, Err(ToolError::InvalidArguments(msg)) if msg.contains("cell_type"))
        );
    }

    #[tokio::test]
    async fn insert_without_cell_id_appends_cell() {
        let file = notebook_file().await;

        let tool = NotebookEditTool::new();
        let result = tool
            .execute(json!({
                "notebook_path": file.path(),
                "edit_mode": "insert",
                "cell_type": "markdown",
                "new_source": "appended cell"
            }))
            .await
            .unwrap();
        assert!(result.success);

        let cells = read_cells(file.path()).await;
        assert_eq!(cells.len(), 3);
        let last = cells.last().unwrap();
        assert_eq!(last["cell_type"], "markdown");
        let source = last["source"].as_array().unwrap();
        assert_eq!(source[0], "appended cell\n");
    }

    #[tokio::test]
    async fn insert_with_cell_id_places_cell_after_target() {
        let file = notebook_file().await;

        let tool = NotebookEditTool::new();
        let result = tool
            .execute(json!({
                "notebook_path": file.path(),
                "edit_mode": "insert",
                "cell_id": "cell-a",
                "cell_type": "code",
                "new_source": "print('x')"
            }))
            .await
            .unwrap();
        assert!(result.success);

        let cells = read_cells(file.path()).await;
        assert_eq!(cells.len(), 3);
        assert_eq!(cells[0]["id"], "cell-a");
        assert_eq!(cells[1]["cell_type"], "code");
        assert_eq!(cells[1]["source"], json!(["print('x')\n"]));
        assert_eq!(cells[2]["id"], "cell-b");
    }

    #[tokio::test]
    async fn insert_with_unknown_cell_id_returns_error() {
        let file = notebook_file().await;

        let tool = NotebookEditTool::new();
        let result = tool
            .execute(json!({
                "notebook_path": file.path(),
                "edit_mode": "insert",
                "cell_id": "does-not-exist",
                "cell_type": "code",
                "new_source": "print('x')"
            }))
            .await;

        assert!(matches!(result, Err(ToolError::Execution(msg)) if msg.contains("not found")));
    }
}