//! `NotebookEdit` tool: inserts, replaces, and deletes cells in Jupyter
//! notebook (`.ipynb`) files.
//!
//! Path: `ai_agent/tools/notebook_edit.rs`.

1use crate::types::*;
2use std::fs;
3
/// Tool that edits cells of Jupyter notebook (`.ipynb`) files.
///
/// The tool carries no state, so `NotebookEditTool::new()` and
/// `NotebookEditTool::default()` produce equivalent values.
#[derive(Debug, Default, Clone, Copy)]
pub struct NotebookEditTool;
5
6impl NotebookEditTool {
7    pub fn new() -> Self {
8        Self
9    }
10
11    pub fn name(&self) -> &str {
12        "NotebookEdit"
13    }
14
15    pub fn description(&self) -> &str {
16        "Edit Jupyter notebook (.ipynb) cells"
17    }
18
19    pub fn input_schema(&self) -> ToolInputSchema {
20        ToolInputSchema {
21            schema_type: "object".to_string(),
22            properties: serde_json::json!({
23                "file_path": {
24                    "type": "string",
25                    "description": "Path to the .ipynb file"
26                },
27                "command": {
28                    "type": "string",
29                    "enum": ["insert", "replace", "delete"],
30                    "description": "The edit operation to perform"
31                },
32                "cell_number": {
33                    "type": "number",
34                    "description": "Cell index (0-based) to operate on"
35                },
36                "cell_type": {
37                    "type": "string",
38                    "enum": ["code", "markdown"],
39                    "description": "Type of cell (for insert/replace)"
40                },
41                "source": {
42                    "type": "string",
43                    "description": "Cell content (for insert/replace)"
44                }
45            }),
46            required: Some(vec![
47                "file_path".to_string(),
48                "command".to_string(),
49                "cell_number".to_string(),
50            ]),
51        }
52    }
53
54    pub async fn execute(
55        &self,
56        input: serde_json::Value,
57        _context: &ToolContext,
58    ) -> Result<ToolResult, crate::error::AgentError> {
59        let file_path = input["file_path"]
60            .as_str()
61            .ok_or_else(|| crate::error::AgentError::Tool("file_path is required".to_string()))?;
62
63        let command = input["command"]
64            .as_str()
65            .ok_or_else(|| crate::error::AgentError::Tool("command is required".to_string()))?;
66
67        let cell_number = input["cell_number"]
68            .as_u64()
69            .ok_or_else(|| crate::error::AgentError::Tool("cell_number is required".to_string()))?
70            as usize;
71
72        let cell_type = input["cell_type"].as_str();
73        let source = input["source"].as_str();
74
75        let content = fs::read_to_string(file_path).map_err(|e| crate::error::AgentError::Io(e))?;
76
77        let mut notebook: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
78            crate::error::AgentError::Tool(format!("Invalid notebook format: {}", e))
79        })?;
80
81        let cells = notebook["cells"].as_array_mut().ok_or_else(|| {
82            crate::error::AgentError::Tool("Invalid notebook format: no cells array".to_string())
83        })?;
84
85        match command {
86            "insert" => {
87                let source_lines: Vec<String> = match source {
88                    Some(s) => s
89                        .lines()
90                        .enumerate()
91                        .map(|(i, l)| {
92                            if i < s.lines().count() - 1 {
93                                format!("{}\n", l)
94                            } else {
95                                l.to_string()
96                            }
97                        })
98                        .collect(),
99                    None => vec![],
100                };
101
102                let cell_type_str = cell_type.unwrap_or("code");
103
104                let mut new_cell = serde_json::json!({
105                    "cell_type": cell_type_str,
106                    "source": source_lines,
107                    "metadata": serde_json::json!({}),
108                });
109
110                if cell_type_str != "markdown" {
111                    new_cell["outputs"] = serde_json::json!([]);
112                    new_cell["execution_count"] = serde_json::json!(null);
113                }
114
115                cells.insert(cell_number, new_cell);
116            }
117            "replace" => {
118                if cell_number >= cells.len() {
119                    return Ok(ToolResult {
120                        result_type: "text".to_string(),
121                        tool_use_id: "".to_string(),
122                        content: format!("Error: Cell {} does not exist", cell_number),
123                        is_error: Some(true),
124                    });
125                }
126
127                let source_lines: Vec<String> = match source {
128                    Some(s) => s
129                        .lines()
130                        .enumerate()
131                        .map(|(i, l)| {
132                            if i < s.lines().count() - 1 {
133                                format!("{}\n", l)
134                            } else {
135                                l.to_string()
136                            }
137                        })
138                        .collect(),
139                    None => vec![],
140                };
141
142                cells[cell_number]["source"] = serde_json::json!(source_lines);
143
144                if let Some(ct) = cell_type {
145                    cells[cell_number]["cell_type"] = serde_json::json!(ct);
146                }
147            }
148            "delete" => {
149                if cell_number >= cells.len() {
150                    return Ok(ToolResult {
151                        result_type: "text".to_string(),
152                        tool_use_id: "".to_string(),
153                        content: format!("Error: Cell {} does not exist", cell_number),
154                        is_error: Some(true),
155                    });
156                }
157                cells.remove(cell_number);
158            }
159            _ => {
160                return Ok(ToolResult {
161                    result_type: "text".to_string(),
162                    tool_use_id: "".to_string(),
163                    content: format!("Error: Unknown command: {}", command),
164                    is_error: Some(true),
165                });
166            }
167        }
168
169        let new_content = serde_json::to_string_pretty(&notebook).map_err(|e| {
170            crate::error::AgentError::Tool(format!("Failed to serialize notebook: {}", e))
171        })?;
172
173        fs::write(file_path, new_content).map_err(|e| crate::error::AgentError::Io(e))?;
174
175        Ok(ToolResult {
176            result_type: "text".to_string(),
177            tool_use_id: "".to_string(),
178            content: format!(
179                "Notebook {}: cell {} in {}",
180                command, cell_number, file_path
181            ),
182            is_error: None,
183        })
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-cell fixture: a code cell followed by a markdown cell.
    fn create_test_notebook() -> serde_json::Value {
        serde_json::json!({
            "nbformat": 4,
            "nbformat_minor": 5,
            "metadata": {},
            "cells": [
                {
                    "cell_type": "code",
                    "execution_count": null,
                    "metadata": {},
                    "outputs": [],
                    "source": ["print('hello')\n"]
                },
                {
                    "cell_type": "markdown",
                    "metadata": {},
                    "source": ["# Title\n"]
                }
            ]
        })
    }

    /// Fixture notebook written to the system temp dir and deleted on drop.
    ///
    /// Cleanup therefore also runs when an assertion fails part-way through.
    struct TempNotebook {
        path: std::path::PathBuf,
    }

    impl TempNotebook {
        /// Writes `create_test_notebook()` to `<temp>/<name>_<pid>.ipynb`.
        ///
        /// The process id keeps concurrent test runs from clobbering each
        /// other's fixture files.
        fn new(name: &str) -> Self {
            let path =
                std::env::temp_dir().join(format!("{}_{}.ipynb", name, std::process::id()));
            let content = serde_json::to_string_pretty(&create_test_notebook()).unwrap();
            std::fs::write(&path, content).unwrap();
            Self { path }
        }

        fn path_str(&self) -> &str {
            self.path.to_str().unwrap()
        }

        /// Parses the current on-disk contents of the fixture.
        fn read(&self) -> serde_json::Value {
            let content = std::fs::read_to_string(&self.path).unwrap();
            serde_json::from_str(&content).unwrap()
        }
    }

    impl Drop for TempNotebook {
        fn drop(&mut self) {
            // Best effort: a file that is already gone is not a test failure.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    #[test]
    fn test_notebook_edit_tool_name() {
        let tool = NotebookEditTool::new();
        assert_eq!(tool.name(), "NotebookEdit");
    }

    #[test]
    fn test_notebook_edit_tool_description_contains_notebook() {
        let tool = NotebookEditTool::new();
        assert!(tool.description().to_lowercase().contains("notebook"));
    }

    #[test]
    fn test_notebook_edit_tool_has_file_path_in_schema() {
        let tool = NotebookEditTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("file_path").is_some());
    }

    #[test]
    fn test_notebook_edit_tool_has_command_in_schema() {
        let tool = NotebookEditTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("command").is_some());
    }

    #[test]
    fn test_notebook_edit_tool_has_cell_number_in_schema() {
        let tool = NotebookEditTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("cell_number").is_some());
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_insert_code_cell() {
        let nb = TempNotebook::new("test_notebook");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "file_path": nb.path_str(),
            "command": "insert",
            "cell_number": 0,
            "cell_type": "code",
            "source": "x = 1"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());

        let notebook = nb.read();
        let cells = notebook["cells"].as_array().unwrap();
        assert_eq!(cells.len(), 3);
        assert_eq!(cells[0]["cell_type"], "code");
        assert_eq!(cells[0]["source"], serde_json::json!(["x = 1"]));
        assert_eq!(cells[0]["outputs"], serde_json::json!([]));
        assert!(cells[0]["execution_count"].is_null());
        // The previous first cell shifts down by one.
        assert_eq!(cells[1]["source"], serde_json::json!(["print('hello')\n"]));
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_replace_cell() {
        let nb = TempNotebook::new("test_notebook_replace");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "file_path": nb.path_str(),
            "command": "replace",
            "cell_number": 0,
            "source": "print('replaced')"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());

        let notebook = nb.read();
        let cells = notebook["cells"].as_array().unwrap();
        assert_eq!(cells.len(), 2);
        assert_eq!(cells[0]["source"].as_array().unwrap()[0], "print('replaced')");
        // Without a cell_type argument the cell keeps its type.
        assert_eq!(cells[0]["cell_type"], "code");
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_delete_cell() {
        let nb = TempNotebook::new("test_notebook_delete");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "file_path": nb.path_str(),
            "command": "delete",
            "cell_number": 1
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());

        let notebook = nb.read();
        let cells = notebook["cells"].as_array().unwrap();
        assert_eq!(cells.len(), 1);
        // The markdown cell at index 1 is gone; the code cell remains.
        assert_eq!(cells[0]["cell_type"], "code");
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_returns_error_for_invalid_cell() {
        let nb = TempNotebook::new("test_notebook_invalid");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "file_path": nb.path_str(),
            "command": "replace",
            "cell_number": 100,
            "source": "test"
        });
        let context = ToolContext::default();

        let tool_result = tool.execute(input, &context).await.unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        // A rejected edit must not rewrite the file.
        assert_eq!(nb.read(), create_test_notebook());
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_delete_returns_error_for_invalid_cell() {
        let nb = TempNotebook::new("test_notebook_delete_invalid");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "file_path": nb.path_str(),
            "command": "delete",
            "cell_number": 2
        });
        let context = ToolContext::default();

        let tool_result = tool.execute(input, &context).await.unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        assert_eq!(nb.read(), create_test_notebook());
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_returns_error_for_unknown_command() {
        let nb = TempNotebook::new("test_notebook_unknown_command");
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "file_path": nb.path_str(),
            "command": "move",
            "cell_number": 0
        });
        let context = ToolContext::default();

        let tool_result = tool.execute(input, &context).await.unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        assert_eq!(nb.read(), create_test_notebook());
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_returns_error_for_nonexistent_file() {
        let tool = NotebookEditTool::new();
        let input = serde_json::json!({
            "file_path": "/nonexistent/notebook.ipynb",
            "command": "insert",
            "cell_number": 0
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_err());
    }
}
364}