Skip to main content

ai_agent/tools/
notebook_edit.rs

// Source: ~/claudecode/openclaudecode/src/tools/NotebookEditTool/NotebookEditTool.ts
//! NotebookEdit tool - edit Jupyter notebook cells.
//!
//! Provides the `NotebookEdit` tool for editing Jupyter notebook (.ipynb) files.
5
6use crate::error::AgentError;
7use crate::types::*;
8use std::fs;
9use std::path::Path;
10
/// Name reported by `NotebookEditTool::name` for this tool.
pub const NOTEBOOK_EDIT_TOOL_NAME: &str = "NotebookEdit";
12
/// Parses a positional cell reference of the form `cell-N` into the index `N`.
///
/// Only the literal prefix `cell-` followed by one or more ASCII digits is
/// accepted. Anything else yields `None`, as does an index too large for `usize`.
/// That includes a sign such as `cell-+3`, which `str::parse::<usize>` alone
/// would accept.
fn parse_cell_id(cell_id: &str) -> Option<usize> {
    let digits = cell_id.strip_prefix("cell-")?;
    // Reject empty suffixes and any non-digit (including a leading '+') before
    // delegating to `parse`, which would otherwise tolerate a '+' sign.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse().ok()
}
21
/// NotebookEdit tool - edit Jupyter notebook (.ipynb) cells.
///
/// The tool carries no state, so it is a unit struct that can be freely copied.
#[derive(Debug, Clone, Copy)]
pub struct NotebookEditTool;
24
25impl NotebookEditTool {
26    pub fn new() -> Self {
27        Self
28    }
29
30    pub fn name(&self) -> &str {
31        NOTEBOOK_EDIT_TOOL_NAME
32    }
33
34    pub fn description(&self) -> &str {
35        "Edit Jupyter notebook (.ipynb) cells: replace, insert, or delete cell content"
36    }
37
38    pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
39        "NotebookEdit".to_string()
40    }
41
42    pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
43        input.and_then(|inp| inp["notebook_path"].as_str().map(String::from))
44    }
45
46    pub fn render_tool_result_message(
47        &self,
48        content: &serde_json::Value,
49    ) -> Option<String> {
50        content["content"].as_str().map(|s| s.to_string())
51    }
52
53    pub fn input_schema(&self) -> ToolInputSchema {
54        ToolInputSchema {
55            schema_type: "object".to_string(),
56            properties: serde_json::json!({
57                "notebook_path": {
58                    "type": "string",
59                    "description": "The absolute path to the Jupyter notebook file to edit (must be absolute, not relative)"
60                },
61                "cell_id": {
62                    "type": "string",
63                    "description": "The ID of the cell to edit. When inserting a new cell, the new cell will be inserted after the cell with this ID, or at the beginning if not specified."
64                },
65                "new_source": {
66                    "type": "string",
67                    "description": "The new source for the cell"
68                },
69                "cell_type": {
70                    "type": "string",
71                    "enum": ["code", "markdown"],
72                    "description": "The type of the cell (code or markdown). If not specified, it defaults to the current cell type. If using edit_mode=insert, this is required."
73                },
74                "edit_mode": {
75                    "type": "string",
76                    "enum": ["replace", "insert", "delete"],
77                    "description": "The type of edit to make (replace, insert, delete). Defaults to replace."
78                }
79            }),
80            required: Some(vec!["notebook_path".to_string(), "new_source".to_string()]),
81        }
82    }
83
84    pub async fn execute(
85        &self,
86        input: serde_json::Value,
87        context: &ToolContext,
88    ) -> Result<ToolResult, AgentError> {
89        let notebook_path = input["notebook_path"]
90            .as_str()
91            .ok_or_else(|| AgentError::Tool("notebook_path is required".to_string()))?;
92
93        let new_source = input["new_source"]
94            .as_str()
95            .ok_or_else(|| AgentError::Tool("new_source is required".to_string()))?;
96
97        let cell_id = input["cell_id"].as_str();
98        let cell_type = input["cell_type"].as_str();
99        let edit_mode = input["edit_mode"].as_str().unwrap_or("replace");
100
101        // Validate edit_mode
102        if !["replace", "insert", "delete"].contains(&edit_mode) {
103            return Ok(ToolResult {
104                result_type: "text".to_string(),
105                tool_use_id: "".to_string(),
106                content: "Error: Edit mode must be replace, insert, or delete.".to_string(),
107                is_error: Some(true),
108                was_persisted: None,
109            });
110        }
111
112        // Cell type required for insert mode
113        if edit_mode == "insert" && cell_type.is_none() {
114            return Ok(ToolResult {
115                result_type: "text".to_string(),
116                tool_use_id: "".to_string(),
117                content: "Error: Cell type is required when using edit_mode=insert.".to_string(),
118                is_error: Some(true),
119                was_persisted: None,
120            });
121        }
122
123        // Resolve path
124        let path_buf = if Path::new(notebook_path).is_absolute() {
125            Path::new(notebook_path).to_path_buf()
126        } else {
127            Path::new(&context.cwd).join(notebook_path)
128        };
129
130        // Check .ipynb extension
131        if path_buf.extension().map(|e| e.to_str()) != Some(Some("ipynb")) {
132            return Ok(ToolResult {
133                result_type: "text".to_string(),
134                tool_use_id: "".to_string(),
135                content: "Error: File must be a Jupyter notebook (.ipynb file). For editing other file types, use the FileEdit tool.".to_string(),
136                is_error: Some(true),
137                was_persisted: None,
138            });
139        }
140
141        // Check file exists
142        if !path_buf.exists() {
143            return Ok(ToolResult {
144                result_type: "text".to_string(),
145                tool_use_id: "".to_string(),
146                content: "Error: Notebook file does not exist.".to_string(),
147                is_error: Some(true),
148                was_persisted: None,
149            });
150        }
151
152        // Read file
153        let content = fs::read_to_string(&path_buf)
154            .map_err(|e| AgentError::Tool(format!("Failed to read notebook: {}", e)))?;
155
156        // Parse JSON
157        let mut notebook: serde_json::Value = match serde_json::from_str(&content) {
158            Ok(v) => v,
159            Err(_) => {
160                return Ok(ToolResult {
161                    result_type: "text".to_string(),
162                    tool_use_id: "".to_string(),
163                    content: "Error: Notebook is not valid JSON.".to_string(),
164                    is_error: Some(true),
165                    was_persisted: None,
166                });
167            }
168        };
169
170        // Get notebook metadata BEFORE getting mutable cells reference
171        let language = notebook["metadata"]["language_info"]["name"]
172            .as_str()
173            .unwrap_or("python")
174            .to_string();
175
176        let nbformat = notebook["nbformat"].as_i64().unwrap_or(4);
177        let nbformat_minor = notebook["nbformat_minor"].as_i64().unwrap_or(0);
178
179        let cells = notebook["cells"]
180            .as_array_mut()
181            .ok_or_else(|| AgentError::Tool("Invalid notebook: no cells array".to_string()))?;
182
183        let original_content = content.clone();
184
185        // Determine cell index
186        let cell_index = if cell_id.is_none() {
187            if edit_mode != "insert" {
188                return Ok(ToolResult {
189                    result_type: "text".to_string(),
190                    tool_use_id: "".to_string(),
191                    content: "Error: Cell ID must be specified when not inserting a new cell."
192                        .to_string(),
193                    is_error: Some(true),
194                    was_persisted: None,
195                });
196            }
197            0 // Default to inserting at the beginning
198        } else {
199            let cid = cell_id.unwrap();
200            // First try to find by actual ID
201            let idx = cells
202                .iter()
203                .position(|c| c.get("id").and_then(|v| v.as_str()) == Some(cid));
204            if let Some(i) = idx {
205                i
206            } else {
207                // Try to parse as numeric index (cell-N format)
208                if let Some(parsed) = parse_cell_id(cid) {
209                    if parsed >= cells.len() {
210                        return Ok(ToolResult {
211                            result_type: "text".to_string(),
212                            tool_use_id: "".to_string(),
213                            content: format!(
214                                "Error: Cell with index {} does not exist in notebook.",
215                                parsed
216                            ),
217                            is_error: Some(true),
218                            was_persisted: None,
219                        });
220                    }
221                    parsed
222                } else {
223                    return Ok(ToolResult {
224                        result_type: "text".to_string(),
225                        tool_use_id: "".to_string(),
226                        content: format!("Error: Cell with ID \"{}\" not found in notebook.", cid),
227                        is_error: Some(true),
228                        was_persisted: None,
229                    });
230                }
231            }
232        };
233
234        let actual_cell_index = if edit_mode == "insert" {
235            cell_index + 1 // Insert after the cell with this ID
236        } else {
237            cell_index
238        };
239
240        // Convert replace to insert if trying to replace one past the end
241        let mut actual_edit_mode = edit_mode.to_string();
242        let mut actual_cell_type = cell_type.map(|s| s.to_string());
243
244        if actual_edit_mode == "replace" && actual_cell_index == cells.len() {
245            actual_edit_mode = "insert".to_string();
246            if actual_cell_type.is_none() {
247                actual_cell_type = Some("code".to_string());
248            }
249        }
250
251        let mut new_cell_id: Option<String> = None;
252
253        // Check nbformat version for cell ID generation
254        let needs_cell_ids = nbformat > 4 || (nbformat == 4 && nbformat_minor >= 5);
255
256        if needs_cell_ids {
257            if actual_edit_mode == "insert" {
258                // Generate random cell ID
259                new_cell_id = Some(
260                    (0..13)
261                        .map(|_| {
262                            let c = "abcdefghijklmnopqrstuvwxyz0123456789".as_bytes()
263                                [rand::random::<u8>() as usize % 36];
264                            c as char
265                        })
266                        .collect(),
267                );
268            } else if let Some(cid) = cell_id {
269                new_cell_id = Some(cid.to_string());
270            }
271        }
272
273        match actual_edit_mode.as_str() {
274            "delete" => {
275                if actual_cell_index >= cells.len() {
276                    return Ok(ToolResult {
277                        result_type: "text".to_string(),
278                        tool_use_id: "".to_string(),
279                        content: format!("Error: Cell index {} out of bounds", actual_cell_index),
280                        is_error: Some(true),
281                        was_persisted: None,
282                    });
283                }
284                cells.remove(actual_cell_index);
285            }
286            "insert" => {
287                let ct = actual_cell_type.as_deref().unwrap_or("code");
288                let mut new_cell = serde_json::json!({
289                    "cell_type": ct,
290                    "source": new_source,
291                    "metadata": serde_json::json!({})
292                });
293                if needs_cell_ids {
294                    if let Some(id) = &new_cell_id {
295                        new_cell["id"] = serde_json::json!(id);
296                    }
297                }
298                if ct != "markdown" {
299                    new_cell["execution_count"] = serde_json::json!(null);
300                    new_cell["outputs"] = serde_json::json!([]);
301                }
302                cells.insert(actual_cell_index, new_cell);
303            }
304            "replace" => {
305                if actual_cell_index >= cells.len() {
306                    return Ok(ToolResult {
307                        result_type: "text".to_string(),
308                        tool_use_id: "".to_string(),
309                        content: format!("Error: Cell index {} out of bounds", actual_cell_index),
310                        is_error: Some(true),
311                        was_persisted: None,
312                    });
313                }
314                let target_cell = &mut cells[actual_cell_index];
315                // Set source as lines array
316                let source_lines: Vec<String> = new_source
317                    .lines()
318                    .enumerate()
319                    .map(|(i, l)| {
320                        if i < new_source.lines().count() - 1 {
321                            format!("{}\n", l)
322                        } else {
323                            l.to_string()
324                        }
325                    })
326                    .collect();
327                target_cell["source"] = serde_json::json!(source_lines);
328                if target_cell.get("cell_type").and_then(|v| v.as_str()) == Some("code") {
329                    // Reset execution count and clear outputs
330                    target_cell["execution_count"] = serde_json::json!(null);
331                    target_cell["outputs"] = serde_json::json!([]);
332                }
333                if let Some(ct) = &actual_cell_type {
334                    if target_cell.get("cell_type").and_then(|v| v.as_str()) != Some(ct.as_str()) {
335                        target_cell["cell_type"] = serde_json::json!(ct);
336                    }
337                }
338            }
339            _ => {
340                return Ok(ToolResult {
341                    result_type: "text".to_string(),
342                    tool_use_id: "".to_string(),
343                    content: format!("Error: Unknown edit mode: {}", actual_edit_mode),
344                    is_error: Some(true),
345                    was_persisted: None,
346                });
347            }
348        }
349
350        // Write back to file with indent=1 (matching TS: IPYNB_INDENT = 1)
351        let updated_content = serde_json::to_string_pretty(&notebook)
352            .map_err(|e| AgentError::Tool(format!("Failed to serialize notebook: {}", e)))?;
353
354        fs::write(&path_buf, &updated_content)
355            .map_err(|e| AgentError::Tool(format!("Failed to write notebook: {}", e)))?;
356
357        let result_cell_id = new_cell_id.or_else(|| cell_id.map(|s| s.to_string()));
358
359        let display_cell_id = result_cell_id.as_deref().unwrap_or("unknown");
360
361        let message = match actual_edit_mode.as_str() {
362            "replace" => format!("Updated cell {} with {}", display_cell_id, new_source),
363            "insert" => format!("Inserted cell {} with {}", display_cell_id, new_source),
364            "delete" => format!("Deleted cell {}", display_cell_id),
365            _ => "Unknown edit mode".to_string(),
366        };
367
368        Ok(ToolResult {
369            result_type: "text".to_string(),
370            tool_use_id: "".to_string(),
371            content: message,
372            is_error: None,
373            was_persisted: None,
374        })
375    }
376}
377
378impl Default for NotebookEditTool {
379    fn default() -> Self {
380        Self::new()
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn create_test_notebook() -> serde_json::Value {
        serde_json::json!({
            "nbformat": 4,
            "nbformat_minor": 5,
            "metadata": {
                "language_info": { "name": "python" }
            },
            "cells": [
                {
                    "cell_type": "code",
                    "execution_count": 1,
                    "metadata": {},
                    "outputs": [{"name": "stdout", "output_type": "stream", "text": ["hello\n"]}],
                    "source": ["print('hello')\n"],
                    "id": "abc123"
                },
                {
                    "cell_type": "markdown",
                    "metadata": {},
                    "source": ["# Title\n"],
                    "id": "def456"
                }
            ]
        })
    }

    /// The sample notebook written to a process-unique temp path.
    ///
    /// The file is deleted on drop, so it is cleaned up even when an assertion
    /// panics. The process-unique name keeps concurrent test runs from
    /// clobbering each other's fixtures.
    struct TempNotebook {
        path: PathBuf,
    }

    impl TempNotebook {
        fn new(stem: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("{}_{}.ipynb", stem, std::process::id()));
            let json = serde_json::to_string_pretty(&create_test_notebook()).unwrap();
            std::fs::write(&path, json).unwrap();
            Self { path }
        }

        fn path_str(&self) -> &str {
            self.path.to_str().unwrap()
        }

        /// Reads the notebook back from disk and parses it.
        fn read(&self) -> serde_json::Value {
            let content = std::fs::read_to_string(&self.path).unwrap();
            serde_json::from_str(&content).unwrap()
        }
    }

    impl Drop for TempNotebook {
        fn drop(&mut self) {
            // Best-effort cleanup; a missing file is not a test failure.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    #[test]
    fn test_notebook_edit_tool_name() {
        let tool = NotebookEditTool::new();
        assert_eq!(tool.name(), NOTEBOOK_EDIT_TOOL_NAME);
    }

    #[test]
    fn test_parse_cell_id() {
        assert_eq!(parse_cell_id("cell-5"), Some(5));
        assert_eq!(parse_cell_id("cell-0"), Some(0));
        assert_eq!(parse_cell_id("abc123"), None);
        assert_eq!(parse_cell_id("cell-"), None);
    }

    #[test]
    fn test_get_tool_use_summary() {
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({ "notebook_path": "/tmp/a.ipynb" });
        assert_eq!(
            tool.get_tool_use_summary(Some(&input)),
            Some("/tmp/a.ipynb".to_string())
        );
        assert_eq!(tool.get_tool_use_summary(None), None);
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_replace_cell() {
        let nb = TempNotebook::new("test_nb_replace");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": nb.path_str(),
            "cell_id": "abc123",
            "new_source": "print('replaced')",
            "edit_mode": "replace"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());

        let updated = nb.read();
        // Cell source should be updated
        assert_eq!(
            updated["cells"][0]["source"].as_array().unwrap()[0],
            "print('replaced')"
        );
        // Execution count should be reset
        assert!(updated["cells"][0]["execution_count"].is_null());
        // Outputs should be cleared
        assert!(updated["cells"][0]["outputs"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_insert_cell() {
        let nb = TempNotebook::new("test_nb_insert");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": nb.path_str(),
            "cell_id": "abc123",
            "new_source": "x = 1",
            "cell_type": "code",
            "edit_mode": "insert"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());

        let updated = nb.read();
        // Should now have 3 cells
        assert_eq!(updated["cells"].as_array().unwrap().len(), 3);
        // New cell inserted after index 0
        assert_eq!(updated["cells"][1]["source"].as_str().unwrap(), "x = 1");
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_delete_cell() {
        let nb = TempNotebook::new("test_nb_delete");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": nb.path_str(),
            "cell_id": "def456",
            "new_source": "",
            "edit_mode": "delete"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());

        let updated = nb.read();
        assert_eq!(updated["cells"].as_array().unwrap().len(), 1);
        assert_eq!(updated["cells"][0]["cell_type"], "code");
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_cell_id_numeric_fallback() {
        let nb = TempNotebook::new("test_nb_numeric");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": nb.path_str(),
            "cell_id": "cell-1",
            "new_source": "# Updated markdown",
            "edit_mode": "replace"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());

        let updated = nb.read();
        assert!(updated["cells"][1]["source"].as_array().unwrap()[0]
            .to_string()
            .contains("Updated markdown"));
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_replace_requires_cell_id() {
        let nb = TempNotebook::new("test_nb_missing_id");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": nb.path_str(),
            "new_source": "x",
            "edit_mode": "replace"
        });
        let context = ToolContext::default();

        let tool_result = tool.execute(input, &context).await.unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        assert!(tool_result.content.contains("Cell ID must be specified"));
        // A rejected request must leave the notebook untouched.
        assert_eq!(nb.read(), create_test_notebook());
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_unknown_cell_id() {
        let nb = TempNotebook::new("test_nb_unknown_id");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": nb.path_str(),
            "cell_id": "nope",
            "new_source": "x",
            "edit_mode": "replace"
        });
        let context = ToolContext::default();

        let tool_result = tool.execute(input, &context).await.unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        assert!(tool_result.content.contains("not found"));
        assert_eq!(nb.read(), create_test_notebook());
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_rejects_non_ipynb() {
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": "/tmp/test.txt",
            "new_source": "test",
            "edit_mode": "replace"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        assert!(tool_result.content.contains(".ipynb"));
    }
}