Skip to main content

cersei_tools/
notebook_edit.rs

1//! NotebookEdit tool: edit Jupyter/IPython notebook cells.
2
3use super::*;
4use serde::Deserialize;
5
6pub struct NotebookEditTool;
7
8#[async_trait]
9impl Tool for NotebookEditTool {
10    fn name(&self) -> &str {
11        "NotebookEdit"
12    }
13    fn description(&self) -> &str {
14        "Edit a Jupyter notebook (.ipynb) cell by index. Can replace cell source or change cell type."
15    }
16    fn permission_level(&self) -> PermissionLevel {
17        PermissionLevel::Write
18    }
19    fn category(&self) -> ToolCategory {
20        ToolCategory::FileSystem
21    }
22
23    fn input_schema(&self) -> Value {
24        serde_json::json!({
25            "type": "object",
26            "properties": {
27                "file_path": { "type": "string", "description": "Path to .ipynb file" },
28                "cell_index": { "type": "integer", "description": "0-based cell index to edit" },
29                "new_source": { "type": "string", "description": "New cell source content" },
30                "cell_type": { "type": "string", "description": "Optional: 'code' or 'markdown'" }
31            },
32            "required": ["file_path", "cell_index", "new_source"]
33        })
34    }
35
36    async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
37        #[derive(Deserialize)]
38        struct Input {
39            file_path: String,
40            cell_index: usize,
41            new_source: String,
42            cell_type: Option<String>,
43        }
44
45        let input: Input = match serde_json::from_value(input) {
46            Ok(i) => i,
47            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
48        };
49
50        let content = match tokio::fs::read_to_string(&input.file_path).await {
51            Ok(c) => c,
52            Err(e) => return ToolResult::error(format!("Failed to read notebook: {}", e)),
53        };
54
55        let mut notebook: Value = match serde_json::from_str(&content) {
56            Ok(n) => n,
57            Err(e) => return ToolResult::error(format!("Invalid notebook JSON: {}", e)),
58        };
59
60        let cells = match notebook.get_mut("cells").and_then(|c| c.as_array_mut()) {
61            Some(c) => c,
62            None => return ToolResult::error("Notebook has no 'cells' array"),
63        };
64
65        if input.cell_index >= cells.len() {
66            return ToolResult::error(format!(
67                "Cell index {} out of range (notebook has {} cells)",
68                input.cell_index,
69                cells.len()
70            ));
71        }
72
73        // Update cell source (notebooks store source as array of lines)
74        let source_lines: Vec<Value> = input
75            .new_source
76            .lines()
77            .enumerate()
78            .map(|(i, line)| {
79                if i < input.new_source.lines().count() - 1 {
80                    Value::String(format!("{}\n", line))
81                } else {
82                    Value::String(line.to_string())
83                }
84            })
85            .collect();
86
87        cells[input.cell_index]["source"] = Value::Array(source_lines);
88
89        if let Some(ct) = &input.cell_type {
90            cells[input.cell_index]["cell_type"] = Value::String(ct.clone());
91        }
92
93        // Clear outputs for code cells
94        if cells[input.cell_index]["cell_type"].as_str() == Some("code") {
95            cells[input.cell_index]["outputs"] = Value::Array(vec![]);
96            cells[input.cell_index]["execution_count"] = Value::Null;
97        }
98
99        let output = serde_json::to_string_pretty(&notebook).unwrap_or_default();
100        match tokio::fs::write(&input.file_path, output).await {
101            Ok(()) => ToolResult::success(format!(
102                "Updated cell {} in {}",
103                input.cell_index, input.file_path
104            )),
105            Err(e) => ToolResult::error(format!("Failed to write notebook: {}", e)),
106        }
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::AllowAll;
    use std::path::{Path, PathBuf};
    use std::sync::Arc;

    fn test_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "nb-test".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    /// Writes a notebook with one code cell and one markdown cell to `dir/test.ipynb`.
    fn write_sample_notebook(dir: &Path) -> PathBuf {
        let nb_path = dir.join("test.ipynb");
        let notebook = serde_json::json!({
            "nbformat": 4,
            "nbformat_minor": 5,
            "metadata": {},
            "cells": [
                {"cell_type": "code", "source": ["print('hello')\n"], "outputs": [], "metadata": {}},
                {"cell_type": "markdown", "source": ["# Title\n"], "metadata": {}}
            ]
        });
        std::fs::write(&nb_path, serde_json::to_string(&notebook).unwrap()).unwrap();
        nb_path
    }

    /// Reads the notebook at `path` back as JSON.
    fn read_notebook(path: &Path) -> Value {
        serde_json::from_str(&std::fs::read_to_string(path).unwrap()).unwrap()
    }

    /// Runs `NotebookEditTool` on `input` with a fresh test context.
    async fn run_tool(input: Value) -> ToolResult {
        let tool = NotebookEditTool;
        let ctx = test_ctx();
        tool.execute(input, &ctx).await
    }

    #[tokio::test]
    async fn test_notebook_edit() {
        let tmp = tempfile::tempdir().unwrap();
        let nb_path = write_sample_notebook(tmp.path());

        let result = run_tool(serde_json::json!({
            "file_path": nb_path.display().to_string(),
            "cell_index": 0,
            "new_source": "print('updated')"
        }))
        .await;

        assert!(!result.is_error);
        assert!(result.content.contains("Updated cell 0"));

        let content = read_notebook(&nb_path);
        assert_eq!(
            content["cells"][0]["source"],
            serde_json::json!(["print('updated')"])
        );
        // Editing a code cell invalidates its previous results.
        assert_eq!(content["cells"][0]["outputs"], serde_json::json!([]));
        assert_eq!(
            content["cells"][0].get("execution_count"),
            Some(&Value::Null)
        );
        // Other cells are left untouched.
        assert_eq!(content["cells"][1]["cell_type"], "markdown");
        assert_eq!(content["cells"][1]["source"], serde_json::json!(["# Title\n"]));
    }

    #[tokio::test]
    async fn test_multiline_source_is_split_into_lines() {
        let tmp = tempfile::tempdir().unwrap();
        let nb_path = write_sample_notebook(tmp.path());

        let result = run_tool(serde_json::json!({
            "file_path": nb_path.display().to_string(),
            "cell_index": 0,
            "new_source": "a = 1\nb = 2"
        }))
        .await;

        assert!(!result.is_error);
        let content = read_notebook(&nb_path);
        assert_eq!(
            content["cells"][0]["source"],
            serde_json::json!(["a = 1\n", "b = 2"])
        );
    }

    #[tokio::test]
    async fn test_change_cell_type() {
        let tmp = tempfile::tempdir().unwrap();
        let nb_path = write_sample_notebook(tmp.path());

        let result = run_tool(serde_json::json!({
            "file_path": nb_path.display().to_string(),
            "cell_index": 0,
            "new_source": "# Heading",
            "cell_type": "markdown"
        }))
        .await;

        assert!(!result.is_error);
        let content = read_notebook(&nb_path);
        assert_eq!(content["cells"][0]["cell_type"], "markdown");
        assert_eq!(content["cells"][0]["source"], serde_json::json!(["# Heading"]));
    }

    #[tokio::test]
    async fn test_out_of_range_index_is_rejected() {
        let tmp = tempfile::tempdir().unwrap();
        let nb_path = write_sample_notebook(tmp.path());

        let result = run_tool(serde_json::json!({
            "file_path": nb_path.display().to_string(),
            "cell_index": 5,
            "new_source": "x"
        }))
        .await;

        assert!(result.is_error);
        assert!(result.content.contains("out of range"));
        // The notebook on disk must be unchanged.
        let content = read_notebook(&nb_path);
        assert_eq!(
            content["cells"][0]["source"],
            serde_json::json!(["print('hello')\n"])
        );
    }

    #[tokio::test]
    async fn test_invalid_json_is_rejected() {
        let tmp = tempfile::tempdir().unwrap();
        let nb_path = tmp.path().join("broken.ipynb");
        std::fs::write(&nb_path, "not a notebook").unwrap();

        let result = run_tool(serde_json::json!({
            "file_path": nb_path.display().to_string(),
            "cell_index": 0,
            "new_source": "x"
        }))
        .await;

        assert!(result.is_error);
        assert!(result.content.contains("Invalid notebook JSON"));
    }

    #[tokio::test]
    async fn test_missing_file_is_rejected() {
        let tmp = tempfile::tempdir().unwrap();
        let nb_path = tmp.path().join("missing.ipynb");

        let result = run_tool(serde_json::json!({
            "file_path": nb_path.display().to_string(),
            "cell_index": 0,
            "new_source": "x"
        }))
        .await;

        assert!(result.is_error);
        assert!(result.content.contains("Failed to read notebook"));
    }
}