1use crate::error::AgentError;
7use crate::types::*;
8use std::fs;
9use std::path::Path;
10
/// Name of the notebook-editing tool, as returned by `NotebookEditTool::name`.
pub const NOTEBOOK_EDIT_TOOL_NAME: &str = "NotebookEdit";
12
/// Interprets a positional cell reference of the form `cell-N` as the index `N`.
///
/// Returns `None` when the `cell-` prefix is absent or the remainder is not a
/// valid `usize` (e.g. empty or negative).
fn parse_cell_id(cell_id: &str) -> Option<usize> {
    cell_id
        .strip_prefix("cell-")
        .and_then(|index_text| index_text.parse().ok())
}
21
22pub struct NotebookEditTool;
24
25impl NotebookEditTool {
26 pub fn new() -> Self {
27 Self
28 }
29
30 pub fn name(&self) -> &str {
31 NOTEBOOK_EDIT_TOOL_NAME
32 }
33
34 pub fn description(&self) -> &str {
35 "Edit Jupyter notebook (.ipynb) cells: replace, insert, or delete cell content"
36 }
37
38 pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
39 "NotebookEdit".to_string()
40 }
41
42 pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
43 input.and_then(|inp| inp["notebook_path"].as_str().map(String::from))
44 }
45
46 pub fn render_tool_result_message(
47 &self,
48 content: &serde_json::Value,
49 ) -> Option<String> {
50 content["content"].as_str().map(|s| s.to_string())
51 }
52
53 pub fn input_schema(&self) -> ToolInputSchema {
54 ToolInputSchema {
55 schema_type: "object".to_string(),
56 properties: serde_json::json!({
57 "notebook_path": {
58 "type": "string",
59 "description": "The absolute path to the Jupyter notebook file to edit (must be absolute, not relative)"
60 },
61 "cell_id": {
62 "type": "string",
63 "description": "The ID of the cell to edit. When inserting a new cell, the new cell will be inserted after the cell with this ID, or at the beginning if not specified."
64 },
65 "new_source": {
66 "type": "string",
67 "description": "The new source for the cell"
68 },
69 "cell_type": {
70 "type": "string",
71 "enum": ["code", "markdown"],
72 "description": "The type of the cell (code or markdown). If not specified, it defaults to the current cell type. If using edit_mode=insert, this is required."
73 },
74 "edit_mode": {
75 "type": "string",
76 "enum": ["replace", "insert", "delete"],
77 "description": "The type of edit to make (replace, insert, delete). Defaults to replace."
78 }
79 }),
80 required: Some(vec!["notebook_path".to_string(), "new_source".to_string()]),
81 }
82 }
83
84 pub async fn execute(
85 &self,
86 input: serde_json::Value,
87 context: &ToolContext,
88 ) -> Result<ToolResult, AgentError> {
89 let notebook_path = input["notebook_path"]
90 .as_str()
91 .ok_or_else(|| AgentError::Tool("notebook_path is required".to_string()))?;
92
93 let new_source = input["new_source"]
94 .as_str()
95 .ok_or_else(|| AgentError::Tool("new_source is required".to_string()))?;
96
97 let cell_id = input["cell_id"].as_str();
98 let cell_type = input["cell_type"].as_str();
99 let edit_mode = input["edit_mode"].as_str().unwrap_or("replace");
100
101 if !["replace", "insert", "delete"].contains(&edit_mode) {
103 return Ok(ToolResult {
104 result_type: "text".to_string(),
105 tool_use_id: "".to_string(),
106 content: "Error: Edit mode must be replace, insert, or delete.".to_string(),
107 is_error: Some(true),
108 was_persisted: None,
109 });
110 }
111
112 if edit_mode == "insert" && cell_type.is_none() {
114 return Ok(ToolResult {
115 result_type: "text".to_string(),
116 tool_use_id: "".to_string(),
117 content: "Error: Cell type is required when using edit_mode=insert.".to_string(),
118 is_error: Some(true),
119 was_persisted: None,
120 });
121 }
122
123 let path_buf = if Path::new(notebook_path).is_absolute() {
125 Path::new(notebook_path).to_path_buf()
126 } else {
127 Path::new(&context.cwd).join(notebook_path)
128 };
129
130 if path_buf.extension().map(|e| e.to_str()) != Some(Some("ipynb")) {
132 return Ok(ToolResult {
133 result_type: "text".to_string(),
134 tool_use_id: "".to_string(),
135 content: "Error: File must be a Jupyter notebook (.ipynb file). For editing other file types, use the FileEdit tool.".to_string(),
136 is_error: Some(true),
137 was_persisted: None,
138 });
139 }
140
141 if !path_buf.exists() {
143 return Ok(ToolResult {
144 result_type: "text".to_string(),
145 tool_use_id: "".to_string(),
146 content: "Error: Notebook file does not exist.".to_string(),
147 is_error: Some(true),
148 was_persisted: None,
149 });
150 }
151
152 let content = fs::read_to_string(&path_buf)
154 .map_err(|e| AgentError::Tool(format!("Failed to read notebook: {}", e)))?;
155
156 let mut notebook: serde_json::Value = match serde_json::from_str(&content) {
158 Ok(v) => v,
159 Err(_) => {
160 return Ok(ToolResult {
161 result_type: "text".to_string(),
162 tool_use_id: "".to_string(),
163 content: "Error: Notebook is not valid JSON.".to_string(),
164 is_error: Some(true),
165 was_persisted: None,
166 });
167 }
168 };
169
170 let language = notebook["metadata"]["language_info"]["name"]
172 .as_str()
173 .unwrap_or("python")
174 .to_string();
175
176 let nbformat = notebook["nbformat"].as_i64().unwrap_or(4);
177 let nbformat_minor = notebook["nbformat_minor"].as_i64().unwrap_or(0);
178
179 let cells = notebook["cells"]
180 .as_array_mut()
181 .ok_or_else(|| AgentError::Tool("Invalid notebook: no cells array".to_string()))?;
182
183 let original_content = content.clone();
184
185 let cell_index = if cell_id.is_none() {
187 if edit_mode != "insert" {
188 return Ok(ToolResult {
189 result_type: "text".to_string(),
190 tool_use_id: "".to_string(),
191 content: "Error: Cell ID must be specified when not inserting a new cell."
192 .to_string(),
193 is_error: Some(true),
194 was_persisted: None,
195 });
196 }
197 0 } else {
199 let cid = cell_id.unwrap();
200 let idx = cells
202 .iter()
203 .position(|c| c.get("id").and_then(|v| v.as_str()) == Some(cid));
204 if let Some(i) = idx {
205 i
206 } else {
207 if let Some(parsed) = parse_cell_id(cid) {
209 if parsed >= cells.len() {
210 return Ok(ToolResult {
211 result_type: "text".to_string(),
212 tool_use_id: "".to_string(),
213 content: format!(
214 "Error: Cell with index {} does not exist in notebook.",
215 parsed
216 ),
217 is_error: Some(true),
218 was_persisted: None,
219 });
220 }
221 parsed
222 } else {
223 return Ok(ToolResult {
224 result_type: "text".to_string(),
225 tool_use_id: "".to_string(),
226 content: format!("Error: Cell with ID \"{}\" not found in notebook.", cid),
227 is_error: Some(true),
228 was_persisted: None,
229 });
230 }
231 }
232 };
233
234 let actual_cell_index = if edit_mode == "insert" {
235 cell_index + 1 } else {
237 cell_index
238 };
239
240 let mut actual_edit_mode = edit_mode.to_string();
242 let mut actual_cell_type = cell_type.map(|s| s.to_string());
243
244 if actual_edit_mode == "replace" && actual_cell_index == cells.len() {
245 actual_edit_mode = "insert".to_string();
246 if actual_cell_type.is_none() {
247 actual_cell_type = Some("code".to_string());
248 }
249 }
250
251 let mut new_cell_id: Option<String> = None;
252
253 let needs_cell_ids = nbformat > 4 || (nbformat == 4 && nbformat_minor >= 5);
255
256 if needs_cell_ids {
257 if actual_edit_mode == "insert" {
258 new_cell_id = Some(
260 (0..13)
261 .map(|_| {
262 let c = "abcdefghijklmnopqrstuvwxyz0123456789".as_bytes()
263 [rand::random::<u8>() as usize % 36];
264 c as char
265 })
266 .collect(),
267 );
268 } else if let Some(cid) = cell_id {
269 new_cell_id = Some(cid.to_string());
270 }
271 }
272
273 match actual_edit_mode.as_str() {
274 "delete" => {
275 if actual_cell_index >= cells.len() {
276 return Ok(ToolResult {
277 result_type: "text".to_string(),
278 tool_use_id: "".to_string(),
279 content: format!("Error: Cell index {} out of bounds", actual_cell_index),
280 is_error: Some(true),
281 was_persisted: None,
282 });
283 }
284 cells.remove(actual_cell_index);
285 }
286 "insert" => {
287 let ct = actual_cell_type.as_deref().unwrap_or("code");
288 let mut new_cell = serde_json::json!({
289 "cell_type": ct,
290 "source": new_source,
291 "metadata": serde_json::json!({})
292 });
293 if needs_cell_ids {
294 if let Some(id) = &new_cell_id {
295 new_cell["id"] = serde_json::json!(id);
296 }
297 }
298 if ct != "markdown" {
299 new_cell["execution_count"] = serde_json::json!(null);
300 new_cell["outputs"] = serde_json::json!([]);
301 }
302 cells.insert(actual_cell_index, new_cell);
303 }
304 "replace" => {
305 if actual_cell_index >= cells.len() {
306 return Ok(ToolResult {
307 result_type: "text".to_string(),
308 tool_use_id: "".to_string(),
309 content: format!("Error: Cell index {} out of bounds", actual_cell_index),
310 is_error: Some(true),
311 was_persisted: None,
312 });
313 }
314 let target_cell = &mut cells[actual_cell_index];
315 let source_lines: Vec<String> = new_source
317 .lines()
318 .enumerate()
319 .map(|(i, l)| {
320 if i < new_source.lines().count() - 1 {
321 format!("{}\n", l)
322 } else {
323 l.to_string()
324 }
325 })
326 .collect();
327 target_cell["source"] = serde_json::json!(source_lines);
328 if target_cell.get("cell_type").and_then(|v| v.as_str()) == Some("code") {
329 target_cell["execution_count"] = serde_json::json!(null);
331 target_cell["outputs"] = serde_json::json!([]);
332 }
333 if let Some(ct) = &actual_cell_type {
334 if target_cell.get("cell_type").and_then(|v| v.as_str()) != Some(ct.as_str()) {
335 target_cell["cell_type"] = serde_json::json!(ct);
336 }
337 }
338 }
339 _ => {
340 return Ok(ToolResult {
341 result_type: "text".to_string(),
342 tool_use_id: "".to_string(),
343 content: format!("Error: Unknown edit mode: {}", actual_edit_mode),
344 is_error: Some(true),
345 was_persisted: None,
346 });
347 }
348 }
349
350 let updated_content = serde_json::to_string_pretty(¬ebook)
352 .map_err(|e| AgentError::Tool(format!("Failed to serialize notebook: {}", e)))?;
353
354 fs::write(&path_buf, &updated_content)
355 .map_err(|e| AgentError::Tool(format!("Failed to write notebook: {}", e)))?;
356
357 let result_cell_id = new_cell_id.or_else(|| cell_id.map(|s| s.to_string()));
358
359 let display_cell_id = result_cell_id.as_deref().unwrap_or("unknown");
360
361 let message = match actual_edit_mode.as_str() {
362 "replace" => format!("Updated cell {} with {}", display_cell_id, new_source),
363 "insert" => format!("Inserted cell {} with {}", display_cell_id, new_source),
364 "delete" => format!("Deleted cell {}", display_cell_id),
365 _ => "Unknown edit mode".to_string(),
366 };
367
368 Ok(ToolResult {
369 result_type: "text".to_string(),
370 tool_use_id: "".to_string(),
371 content: message,
372 is_error: None,
373 was_persisted: None,
374 })
375 }
376}
377
378impl Default for NotebookEditTool {
379 fn default() -> Self {
380 Self::new()
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an nbformat 4.5 notebook with a code cell `abc123` (which has an
    /// execution count and a stream output) followed by a markdown cell `def456`.
    fn create_test_notebook() -> serde_json::Value {
        serde_json::json!({
            "nbformat": 4,
            "nbformat_minor": 5,
            "metadata": {
                "language_info": { "name": "python" }
            },
            "cells": [
                {
                    "cell_type": "code",
                    "execution_count": 1,
                    "metadata": {},
                    "outputs": [{"name": "stdout", "output_type": "stream", "text": ["hello\n"]}],
                    "source": ["print('hello')\n"],
                    "id": "abc123"
                },
                {
                    "cell_type": "markdown",
                    "metadata": {},
                    "source": ["# Title\n"],
                    "id": "def456"
                }
            ]
        })
    }

    /// Fixture notebook written to a temp path that is unique per test name and
    /// process, and removed on drop. This keeps concurrent test runs from
    /// clobbering each other's files and cleans up even when an assertion panics.
    struct TempNotebook {
        path: std::path::PathBuf,
    }

    impl TempNotebook {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("{}_{}.ipynb", name, std::process::id()));
            std::fs::write(
                &path,
                serde_json::to_string_pretty(&create_test_notebook()).unwrap(),
            )
            .unwrap();
            Self { path }
        }

        fn path_str(&self) -> &str {
            self.path.to_str().unwrap()
        }

        fn read(&self) -> serde_json::Value {
            let content = std::fs::read_to_string(&self.path).unwrap();
            serde_json::from_str(&content).unwrap()
        }
    }

    impl Drop for TempNotebook {
        fn drop(&mut self) {
            // Best-effort cleanup; a file that is already gone is fine.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Asserts that the tool call succeeded and did not report a tool-level error.
    fn assert_success(result: Result<ToolResult, AgentError>) {
        assert!(result.is_ok());
        assert_ne!(result.ok().unwrap().is_error, Some(true));
    }

    #[test]
    fn test_notebook_edit_tool_name() {
        let tool = NotebookEditTool::new();
        assert_eq!(tool.name(), NOTEBOOK_EDIT_TOOL_NAME);
    }

    #[test]
    fn test_parse_cell_id() {
        assert_eq!(parse_cell_id("cell-5"), Some(5));
        assert_eq!(parse_cell_id("cell-0"), Some(0));
        assert_eq!(parse_cell_id("abc123"), None);
        assert_eq!(parse_cell_id("cell-"), None);
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_replace_cell() {
        let notebook = TempNotebook::new("test_nb_replace");

        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": notebook.path_str(),
            "cell_id": "abc123",
            "new_source": "print('replaced')",
            "edit_mode": "replace"
        });
        let context = ToolContext::default();

        assert_success(tool.execute(input, &context).await);

        let nb = notebook.read();
        assert_eq!(
            nb["cells"][0]["source"].as_array().unwrap()[0],
            "print('replaced')"
        );
        // Replacing a code cell's source invalidates its previous execution.
        assert!(nb["cells"][0]["execution_count"].is_null());
        assert!(nb["cells"][0]["outputs"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_insert_cell() {
        let notebook = TempNotebook::new("test_nb_insert");

        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": notebook.path_str(),
            "cell_id": "abc123",
            "new_source": "x = 1",
            "cell_type": "code",
            "edit_mode": "insert"
        });
        let context = ToolContext::default();

        assert_success(tool.execute(input, &context).await);

        let nb = notebook.read();
        assert_eq!(nb["cells"].as_array().unwrap().len(), 3);
        // The new cell lands directly after the anchor cell.
        assert_eq!(nb["cells"][1]["source"].as_str().unwrap(), "x = 1");
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_delete_cell() {
        let notebook = TempNotebook::new("test_nb_delete");

        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": notebook.path_str(),
            "cell_id": "def456",
            "new_source": "",
            "edit_mode": "delete"
        });
        let context = ToolContext::default();

        assert_success(tool.execute(input, &context).await);

        let nb = notebook.read();
        assert_eq!(nb["cells"].as_array().unwrap().len(), 1);
        assert_eq!(nb["cells"][0]["cell_type"], "code");
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_cell_id_numeric_fallback() {
        let notebook = TempNotebook::new("test_nb_numeric");

        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": notebook.path_str(),
            "cell_id": "cell-1",
            "new_source": "# Updated markdown",
            "edit_mode": "replace"
        });
        let context = ToolContext::default();

        assert_success(tool.execute(input, &context).await);

        let nb = notebook.read();
        assert!(
            nb["cells"][1]["source"].as_array().unwrap()[0]
                .to_string()
                .contains("Updated markdown")
        );
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_rejects_non_ipynb() {
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "notebook_path": "/tmp/test.txt",
            "new_source": "test",
            "edit_mode": "replace"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());
        let tool_result = result.ok().unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        assert!(tool_result.content.contains(".ipynb"));
    }
}