// cersei_tools/notebook_edit.rs
use super::*;
use serde::Deserialize;

/// Stateless tool that rewrites a single cell of a Jupyter notebook (`.ipynb`)
/// file, replacing its source and optionally its cell type.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotebookEditTool;

8#[async_trait]
9impl Tool for NotebookEditTool {
10 fn name(&self) -> &str {
11 "NotebookEdit"
12 }
13 fn description(&self) -> &str {
14 "Edit a Jupyter notebook (.ipynb) cell by index. Can replace cell source or change cell type."
15 }
16 fn permission_level(&self) -> PermissionLevel {
17 PermissionLevel::Write
18 }
19 fn category(&self) -> ToolCategory {
20 ToolCategory::FileSystem
21 }
22
23 fn input_schema(&self) -> Value {
24 serde_json::json!({
25 "type": "object",
26 "properties": {
27 "file_path": { "type": "string", "description": "Path to .ipynb file" },
28 "cell_index": { "type": "integer", "description": "0-based cell index to edit" },
29 "new_source": { "type": "string", "description": "New cell source content" },
30 "cell_type": { "type": "string", "description": "Optional: 'code' or 'markdown'" }
31 },
32 "required": ["file_path", "cell_index", "new_source"]
33 })
34 }
35
36 async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
37 #[derive(Deserialize)]
38 struct Input {
39 file_path: String,
40 cell_index: usize,
41 new_source: String,
42 cell_type: Option<String>,
43 }
44
45 let input: Input = match serde_json::from_value(input) {
46 Ok(i) => i,
47 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
48 };
49
50 let content = match tokio::fs::read_to_string(&input.file_path).await {
51 Ok(c) => c,
52 Err(e) => return ToolResult::error(format!("Failed to read notebook: {}", e)),
53 };
54
55 let mut notebook: Value = match serde_json::from_str(&content) {
56 Ok(n) => n,
57 Err(e) => return ToolResult::error(format!("Invalid notebook JSON: {}", e)),
58 };
59
60 let cells = match notebook.get_mut("cells").and_then(|c| c.as_array_mut()) {
61 Some(c) => c,
62 None => return ToolResult::error("Notebook has no 'cells' array"),
63 };
64
65 if input.cell_index >= cells.len() {
66 return ToolResult::error(format!(
67 "Cell index {} out of range (notebook has {} cells)",
68 input.cell_index,
69 cells.len()
70 ));
71 }
72
73 let source_lines: Vec<Value> = input
75 .new_source
76 .lines()
77 .enumerate()
78 .map(|(i, line)| {
79 if i < input.new_source.lines().count() - 1 {
80 Value::String(format!("{}\n", line))
81 } else {
82 Value::String(line.to_string())
83 }
84 })
85 .collect();
86
87 cells[input.cell_index]["source"] = Value::Array(source_lines);
88
89 if let Some(ct) = &input.cell_type {
90 cells[input.cell_index]["cell_type"] = Value::String(ct.clone());
91 }
92
93 if cells[input.cell_index]["cell_type"].as_str() == Some("code") {
95 cells[input.cell_index]["outputs"] = Value::Array(vec![]);
96 cells[input.cell_index]["execution_count"] = Value::Null;
97 }
98
99 let output = serde_json::to_string_pretty(¬ebook).unwrap_or_default();
100 match tokio::fs::write(&input.file_path, output).await {
101 Ok(()) => ToolResult::success(format!(
102 "Updated cell {} in {}",
103 input.cell_index, input.file_path
104 )),
105 Err(e) => ToolResult::error(format!("Failed to write notebook: {}", e)),
106 }
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::AllowAll;
    use std::path::{Path, PathBuf};
    use std::sync::Arc;

    /// Context rooted in the system temp directory that allows every operation.
    fn test_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "nb-test".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    /// Writes a notebook with a code cell at index 0 and a markdown cell at
    /// index 1 into a fresh temp dir. Keep the returned `TempDir` alive while
    /// the path is in use.
    fn write_fixture() -> (tempfile::TempDir, PathBuf) {
        let tmp = tempfile::tempdir().unwrap();
        let nb_path = tmp.path().join("test.ipynb");
        let notebook = serde_json::json!({
            "nbformat": 4,
            "nbformat_minor": 5,
            "metadata": {},
            "cells": [
                {"cell_type": "code", "source": ["print('hello')\n"], "outputs": [], "metadata": {}},
                {"cell_type": "markdown", "source": ["# Title\n"], "metadata": {}}
            ]
        });
        std::fs::write(&nb_path, serde_json::to_string(&notebook).unwrap()).unwrap();
        (tmp, nb_path)
    }

    /// Parses the notebook currently on disk at `path`.
    fn read_notebook(path: &Path) -> Value {
        serde_json::from_str(&std::fs::read_to_string(path).unwrap()).unwrap()
    }

    #[tokio::test]
    async fn test_notebook_edit() {
        let (_tmp, nb_path) = write_fixture();

        let result = NotebookEditTool
            .execute(
                serde_json::json!({
                    "file_path": nb_path.display().to_string(),
                    "cell_index": 0,
                    "new_source": "print('updated')"
                }),
                &test_ctx(),
            )
            .await;

        assert!(!result.is_error);
        assert!(result.content.contains("Updated cell 0"));

        let content = read_notebook(&nb_path);
        let source = content["cells"][0]["source"][0].as_str().unwrap();
        assert!(source.contains("updated"));
    }

    #[tokio::test]
    async fn test_multiline_source_is_split_into_lines() {
        let (_tmp, nb_path) = write_fixture();

        let result = NotebookEditTool
            .execute(
                serde_json::json!({
                    "file_path": nb_path.display().to_string(),
                    "cell_index": 0,
                    "new_source": "a = 1\nb = 2"
                }),
                &test_ctx(),
            )
            .await;

        assert!(!result.is_error);
        let content = read_notebook(&nb_path);
        assert_eq!(
            content["cells"][0]["source"],
            serde_json::json!(["a = 1\n", "b = 2"])
        );
    }

    #[tokio::test]
    async fn test_markdown_to_code_adds_empty_outputs() {
        let (_tmp, nb_path) = write_fixture();

        let result = NotebookEditTool
            .execute(
                serde_json::json!({
                    "file_path": nb_path.display().to_string(),
                    "cell_index": 1,
                    "new_source": "x = 1",
                    "cell_type": "code"
                }),
                &test_ctx(),
            )
            .await;

        assert!(!result.is_error);
        let content = read_notebook(&nb_path);
        let cell = &content["cells"][1];
        assert_eq!(cell["cell_type"], "code");
        assert_eq!(cell["outputs"], serde_json::json!([]));
        assert!(cell["execution_count"].is_null());
    }

    #[tokio::test]
    async fn test_out_of_range_index_leaves_file_untouched() {
        let (_tmp, nb_path) = write_fixture();
        let before = std::fs::read_to_string(&nb_path).unwrap();

        let result = NotebookEditTool
            .execute(
                serde_json::json!({
                    "file_path": nb_path.display().to_string(),
                    "cell_index": 5,
                    "new_source": "ignored"
                }),
                &test_ctx(),
            )
            .await;

        assert!(result.is_error);
        assert!(result.content.contains("out of range"));
        assert_eq!(std::fs::read_to_string(&nb_path).unwrap(), before);
    }
}