Skip to main content

aster/tools/
notebook_edit_tool.rs

1//! Notebook Edit Tool Implementation
2//!
3//! 此模块实现了 `NotebookEditTool`,用于编辑 Jupyter Notebook 单元格:
4//! - 支持 replace, insert, delete 三种编辑模式
5//! - 自动清理单元格输出(code 类型)
6//! - Jupyter notebook 格式验证
7//! - 增强的错误处理和路径验证
8//! - 保留单元格元数据
9//!
10//! Requirements: 基于 Claude Agent SDK notebook.ts 中的 NotebookEditTool 实现
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use std::fs;
15use std::path::PathBuf;
16
17use super::base::{PermissionCheckResult, Tool};
18use super::context::{ToolContext, ToolOptions, ToolResult};
19use super::error::ToolError;
20
21/// Jupyter Notebook 单元格
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct NotebookCell {
24    /// 单元格 ID(可选,nbformat 4.5+ 支持)
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub id: Option<String>,
27    /// 单元格类型
28    pub cell_type: String,
29    /// 源代码内容
30    pub source: serde_json::Value,
31    /// 单元格元数据
32    #[serde(default)]
33    pub metadata: serde_json::Value,
34    /// 输出(仅 code 单元格)
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub outputs: Option<Vec<serde_json::Value>>,
37    /// 执行计数(仅 code 单元格)
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub execution_count: Option<serde_json::Value>,
40}
41
42/// Jupyter Notebook 内容
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct NotebookContent {
45    /// 单元格列表
46    pub cells: Vec<NotebookCell>,
47    /// Notebook 元数据
48    pub metadata: serde_json::Value,
49    /// Notebook 格式版本
50    pub nbformat: u32,
51    /// Notebook 格式次版本
52    pub nbformat_minor: u32,
53}
54
55/// NotebookEdit 工具输入参数
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct NotebookEditInput {
58    /// Notebook 文件的绝对路径
59    pub notebook_path: String,
60    /// 要编辑的单元格 ID
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub cell_id: Option<String>,
63    /// 新的源代码内容
64    pub new_source: String,
65    /// 单元格类型(可选)
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub cell_type: Option<String>,
68    /// 编辑模式
69    #[serde(default = "default_edit_mode")]
70    pub edit_mode: String,
71}
72
/// Serde default for [`NotebookEditInput::edit_mode`]: the `"replace"` mode.
fn default_edit_mode() -> String {
    String::from("replace")
}
76
/// Tool that edits cells of a Jupyter notebook (`.ipynb`) in place.
///
/// Supported operations:
/// - `replace`: overwrite a cell's source (and optionally its type);
/// - `insert`: add a new cell;
/// - `delete`: remove a cell.
///
/// Edited code cells have their outputs and execution count reset, and the
/// notebook structure is validated before any change is written back.
#[derive(Debug)]
pub struct NotebookEditTool {
    /// Name under which the tool is registered.
    name: String,
}
90
91impl Default for NotebookEditTool {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl NotebookEditTool {
98    /// Create a new NotebookEditTool
99    pub fn new() -> Self {
100        Self {
101            name: "NotebookEdit".to_string(),
102        }
103    }
104
105    /// 验证 Jupyter notebook 格式
106    fn validate_notebook_format(&self, notebook: &serde_json::Value) -> Result<(), String> {
107        // 检查必需字段
108        let cells = notebook
109            .get("cells")
110            .and_then(|c| c.as_array())
111            .ok_or("Invalid notebook structure: missing or invalid cells array")?;
112
113        let nbformat = notebook
114            .get("nbformat")
115            .and_then(|n| n.as_u64())
116            .ok_or("Invalid notebook structure: missing or invalid nbformat")?;
117
118        let nbformat_minor = notebook
119            .get("nbformat_minor")
120            .and_then(|n| n.as_u64())
121            .ok_or("Invalid notebook structure: missing or invalid nbformat_minor")?;
122
123        // 验证 nbformat 版本(支持 v4.x)
124        if nbformat < 4 {
125            return Err(format!(
126                "Unsupported notebook format version: {}.{} (only v4.x is supported)",
127                nbformat, nbformat_minor
128            ));
129        }
130
131        // 验证 metadata
132        if !notebook.get("metadata").is_some_and(|m| m.is_object()) {
133            return Err("Invalid notebook structure: missing or invalid metadata".to_string());
134        }
135
136        // 验证每个单元格的基本结构
137        for (i, cell) in cells.iter().enumerate() {
138            let cell_type = cell
139                .get("cell_type")
140                .and_then(|t| t.as_str())
141                .ok_or_else(|| {
142                    format!("Invalid cell at index {}: missing or invalid cell_type", i)
143                })?;
144
145            if !["code", "markdown", "raw"].contains(&cell_type) {
146                return Err(format!(
147                    "Invalid cell at index {}: unknown cell_type '{}'",
148                    i, cell_type
149                ));
150            }
151
152            if cell.get("source").is_none() {
153                return Err(format!("Invalid cell at index {}: missing source", i));
154            }
155        }
156
157        Ok(())
158    }
159
160    /// 查找单元格索引
161    fn find_cell_index(&self, cells: &[serde_json::Value], cell_id: &str) -> i32 {
162        // 首先尝试按 ID 精确匹配
163        for (i, cell) in cells.iter().enumerate() {
164            if let Some(id) = cell.get("id").and_then(|id| id.as_str()) {
165                if id == cell_id {
166                    return i as i32;
167                }
168            }
169        }
170
171        // 尝试解析为数字索引
172        if let Ok(num_index) = cell_id.parse::<i32>() {
173            // 支持负数索引(从末尾开始)
174            if num_index < 0 {
175                let positive_index = cells.len() as i32 + num_index;
176                if positive_index >= 0 && positive_index < cells.len() as i32 {
177                    return positive_index;
178                }
179            } else if num_index >= 0 && (num_index as usize) < cells.len() {
180                return num_index;
181            }
182        }
183
184        -1
185    }
186
187    /// 清除单元格输出
188    fn clear_cell_outputs(&self, cell: &mut serde_json::Value) {
189        if let Some(cell_type) = cell.get("cell_type").and_then(|t| t.as_str()) {
190            if cell_type == "code" {
191                cell.as_object_mut()
192                    .unwrap()
193                    .insert("outputs".to_string(), serde_json::json!([]));
194                cell.as_object_mut()
195                    .unwrap()
196                    .insert("execution_count".to_string(), serde_json::Value::Null);
197            } else {
198                // 对于非 code 单元格,移除 outputs 和 execution_count 字段
199                if let Some(obj) = cell.as_object_mut() {
200                    obj.remove("outputs");
201                    obj.remove("execution_count");
202                }
203            }
204        }
205    }
206
207    /// 生成唯一的单元格 ID
208    fn generate_cell_id(&self) -> String {
209        use rand::Rng;
210        const CHARS: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
211        let mut rng = rand::thread_rng();
212        (0..8)
213            .map(|_| {
214                let idx = rng.gen_range(0..CHARS.len());
215                CHARS[idx] as char
216            })
217            .collect()
218    }
219}
220
221#[async_trait]
222impl Tool for NotebookEditTool {
223    /// Returns the tool name
224    fn name(&self) -> &str {
225        &self.name
226    }
227
228    /// Returns the tool description
229    fn description(&self) -> &str {
230        "Replace the contents of a specific cell in a Jupyter notebook. \
231         Completely replaces the contents of a specific cell in a Jupyter notebook (.ipynb file) with new source. \
232         Jupyter notebooks are interactive documents that combine code, text, and visualizations, \
233         commonly used for data analysis and scientific computing. \
234         The notebook_path parameter must be an absolute path, not a relative path. \
235         The cell_id can be a cell ID or numeric index (0-indexed). \
236         Use edit_mode=insert to add a new cell at the index specified by cell_id. \
237         Use edit_mode=delete to delete the cell at the index specified by cell_id."
238    }
239
240    /// Returns the JSON Schema for input parameters
241    fn input_schema(&self) -> serde_json::Value {
242        serde_json::json!({
243            "type": "object",
244            "properties": {
245                "notebook_path": {
246                    "type": "string",
247                    "description": "The absolute path to the Jupyter notebook file to edit (must be absolute, not relative)"
248                },
249                "cell_id": {
250                    "type": "string",
251                    "description": "The ID of the cell to edit. When inserting a new cell, the new cell will be inserted after the cell with this ID, or at the beginning if not specified."
252                },
253                "new_source": {
254                    "type": "string",
255                    "description": "The new source for the cell"
256                },
257                "cell_type": {
258                    "type": "string",
259                    "enum": ["code", "markdown"],
260                    "description": "The type of the cell (code or markdown). If not specified, it defaults to the current cell type. If using edit_mode=insert, this is required."
261                },
262                "edit_mode": {
263                    "type": "string",
264                    "enum": ["replace", "insert", "delete"],
265                    "description": "The type of edit to make (replace, insert, delete). Defaults to replace."
266                }
267            },
268            "required": ["notebook_path", "new_source"]
269        })
270    }
271
272    /// Execute the notebook edit command
273    async fn execute(
274        &self,
275        params: serde_json::Value,
276        _context: &ToolContext,
277    ) -> Result<ToolResult, ToolError> {
278        // Extract input parameters
279        let input: NotebookEditInput = serde_json::from_value(params)
280            .map_err(|e| ToolError::invalid_params(format!("Invalid input format: {}", e)))?;
281
282        let notebook_path = PathBuf::from(&input.notebook_path);
283        let edit_mode = input.edit_mode.as_str();
284
285        // 验证路径是否为绝对路径
286        if !notebook_path.is_absolute() {
287            return Ok(ToolResult::error(format!(
288                "notebook_path must be an absolute path, got: {}",
289                input.notebook_path
290            )));
291        }
292
293        // 检查文件是否存在
294        if !notebook_path.exists() {
295            return Ok(ToolResult::error(format!(
296                "Notebook file not found: {}",
297                notebook_path.display()
298            )));
299        }
300
301        // 检查是否是文件(不是目录)
302        let metadata = fs::metadata(&notebook_path).map_err(|e| {
303            ToolError::execution_failed(format!("Failed to read file metadata: {}", e))
304        })?;
305
306        if !metadata.is_file() {
307            return Ok(ToolResult::error(format!(
308                "Path is not a file: {}",
309                notebook_path.display()
310            )));
311        }
312
313        // 检查文件扩展名
314        if notebook_path.extension().is_none_or(|ext| ext != "ipynb") {
315            return Ok(ToolResult::error(format!(
316                "File must be a Jupyter notebook (.ipynb), got: {}",
317                notebook_path
318                    .extension()
319                    .unwrap_or_default()
320                    .to_string_lossy()
321            )));
322        }
323
324        // 读取并解析 notebook
325        let content = fs::read_to_string(&notebook_path).map_err(|e| {
326            ToolError::execution_failed(format!("Failed to read notebook file: {}", e))
327        })?;
328
329        let mut notebook: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
330            ToolError::execution_failed(format!("Failed to parse notebook JSON: {}", e))
331        })?;
332
333        // 验证 notebook 格式
334        if let Err(error) = self.validate_notebook_format(&notebook) {
335            return Ok(ToolResult::error(error));
336        }
337
338        // 提前获取 nbformat 信息,避免借用冲突
339        let nbformat = notebook
340            .get("nbformat")
341            .and_then(|n| n.as_u64())
342            .unwrap_or(4);
343        let nbformat_minor = notebook
344            .get("nbformat_minor")
345            .and_then(|n| n.as_u64())
346            .unwrap_or(0);
347
348        // 获取单元格数组
349        let cells = notebook
350            .get_mut("cells")
351            .and_then(|c| c.as_array_mut())
352            .ok_or_else(|| {
353                ToolError::execution_failed("Invalid notebook format: missing cells".to_string())
354            })?;
355
356        // 找到目标单元格索引
357        let mut cell_index: i32 = 0;
358        if let Some(cell_id) = &input.cell_id {
359            cell_index = self.find_cell_index(cells, cell_id);
360
361            // 如果按 ID 找不到,对于非 insert 模式报错
362            if cell_index == -1 {
363                if edit_mode == "insert" {
364                    // insert 模式下找不到 cell_id,也应该报错
365                    return Ok(ToolResult::error(format!(
366                        "Cell with ID \"{}\" not found in notebook.",
367                        cell_id
368                    )));
369                } else if edit_mode == "replace" {
370                    // replace 模式下,如果是数字索引且超出范围,转为 insert 到末尾
371                    if let Ok(num_index) = cell_id.parse::<usize>() {
372                        cell_index = num_index as i32;
373                    } else {
374                        return Ok(ToolResult::error(format!(
375                            "Cell not found with ID: {}. Available cells: {}",
376                            cell_id,
377                            cells.len()
378                        )));
379                    }
380                } else {
381                    return Ok(ToolResult::error(format!(
382                        "Cell not found with ID: {}. Available cells: {}",
383                        cell_id,
384                        cells.len()
385                    )));
386                }
387            }
388
389            // insert 模式:在找到的单元格之后插入
390            if edit_mode == "insert" && cell_index != -1 {
391                cell_index += 1;
392            }
393        }
394
395        // delete 模式必须指定 cell_id
396        if edit_mode == "delete" && input.cell_id.is_none() {
397            return Ok(ToolResult::error(
398                "cell_id is required for delete mode".to_string(),
399            ));
400        }
401
402        // 执行编辑操作
403        let result_message = match edit_mode {
404            "replace" => {
405                let cell_index = cell_index as usize;
406
407                // 特殊处理:如果索引超出范围,自动转为 insert
408                if cell_index >= cells.len() {
409                    let final_cell_type = input.cell_type.as_deref().unwrap_or("code");
410                    let mut new_cell = serde_json::json!({
411                        "cell_type": final_cell_type,
412                        "source": input.new_source,
413                        "metadata": {}
414                    });
415
416                    // 初始化 code 单元格的输出
417                    if final_cell_type == "code" {
418                        new_cell
419                            .as_object_mut()
420                            .unwrap()
421                            .insert("outputs".to_string(), serde_json::json!([]));
422                        new_cell
423                            .as_object_mut()
424                            .unwrap()
425                            .insert("execution_count".to_string(), serde_json::Value::Null);
426                    }
427
428                    // 只在 nbformat 4.5+ 生成 ID
429                    if nbformat > 4 || (nbformat == 4 && nbformat_minor >= 5) {
430                        new_cell
431                            .as_object_mut()
432                            .unwrap()
433                            .insert("id".to_string(), serde_json::json!(self.generate_cell_id()));
434                    }
435
436                    cells.push(new_cell);
437                    format!(
438                        "Inserted new {} cell at position {} (converted from replace)",
439                        final_cell_type, cell_index
440                    )
441                } else {
442                    let cell = &mut cells[cell_index];
443                    let old_type = cell
444                        .get("cell_type")
445                        .and_then(|t| t.as_str())
446                        .unwrap_or("unknown")
447                        .to_string();
448
449                    // 更新源代码
450                    cell.as_object_mut()
451                        .unwrap()
452                        .insert("source".to_string(), serde_json::json!(input.new_source));
453
454                    // 如果指定了 cell_type,更新类型
455                    if let Some(new_type) = &input.cell_type {
456                        cell.as_object_mut()
457                            .unwrap()
458                            .insert("cell_type".to_string(), serde_json::json!(new_type));
459                    }
460
461                    // 清除输出(对于 code 单元格)
462                    self.clear_cell_outputs(cell);
463
464                    let current_type = cell
465                        .get("cell_type")
466                        .and_then(|t| t.as_str())
467                        .unwrap_or("unknown");
468                    if old_type != current_type {
469                        format!(
470                            "Replaced cell {} (changed type from {} to {})",
471                            cell_index, old_type, current_type
472                        )
473                    } else {
474                        format!("Replaced cell {}", cell_index)
475                    }
476                }
477            }
478            "insert" => {
479                let cell_index = cell_index as usize;
480                let final_cell_type = input.cell_type.as_deref().unwrap_or("code");
481
482                let mut new_cell = serde_json::json!({
483                    "cell_type": final_cell_type,
484                    "source": input.new_source,
485                    "metadata": {}
486                });
487
488                // 初始化 code 单元格的输出
489                if final_cell_type == "code" {
490                    new_cell
491                        .as_object_mut()
492                        .unwrap()
493                        .insert("outputs".to_string(), serde_json::json!([]));
494                    new_cell
495                        .as_object_mut()
496                        .unwrap()
497                        .insert("execution_count".to_string(), serde_json::Value::Null);
498                }
499
500                // 只在 nbformat 4.5+ 生成 ID
501                if nbformat > 4 || (nbformat == 4 && nbformat_minor >= 5) {
502                    new_cell
503                        .as_object_mut()
504                        .unwrap()
505                        .insert("id".to_string(), serde_json::json!(self.generate_cell_id()));
506                }
507
508                // 插入新单元格
509                if cell_index <= cells.len() {
510                    cells.insert(cell_index, new_cell);
511                } else {
512                    cells.push(new_cell);
513                }
514
515                format!(
516                    "Inserted new {} cell at position {}",
517                    final_cell_type, cell_index
518                )
519            }
520            "delete" => {
521                let cell_index = cell_index as usize;
522                if cell_index >= cells.len() {
523                    return Ok(ToolResult::error(format!(
524                        "Cell index out of range: {} (total cells: {})",
525                        cell_index,
526                        cells.len()
527                    )));
528                }
529
530                let deleted_type = cells[cell_index]
531                    .get("cell_type")
532                    .and_then(|t| t.as_str())
533                    .unwrap_or("unknown")
534                    .to_string();
535                cells.remove(cell_index);
536
537                format!(
538                    "Deleted {} cell at position {} ({} cells remaining)",
539                    deleted_type,
540                    cell_index,
541                    cells.len()
542                )
543            }
544            _ => {
545                return Ok(ToolResult::error(format!(
546                    "Invalid edit_mode: {}. Must be 'replace', 'insert', or 'delete'",
547                    edit_mode
548                )));
549            }
550        };
551
552        // 写回文件(使用美化的 JSON 格式,缩进 1 空格)
553        let formatted_json = serde_json::to_string_pretty(&notebook).map_err(|e| {
554            ToolError::execution_failed(format!("Failed to serialize notebook: {}", e))
555        })?;
556
557        // 手动调整缩进为 1 空格(serde_json::to_string_pretty 使用 2 空格)
558        let formatted_json = formatted_json
559            .lines()
560            .map(|line| {
561                let leading_spaces = line.len() - line.trim_start().len();
562                let adjusted_spaces = leading_spaces / 2; // 从 2 空格缩进改为 1 空格
563                format!("{}{}", " ".repeat(adjusted_spaces), line.trim_start())
564            })
565            .collect::<Vec<_>>()
566            .join("\n");
567
568        fs::write(&notebook_path, format!("{}\n", formatted_json)).map_err(|e| {
569            ToolError::execution_failed(format!("Failed to write notebook file: {}", e))
570        })?;
571
572        let filename = notebook_path
573            .file_name()
574            .unwrap_or_default()
575            .to_string_lossy();
576        Ok(ToolResult::success(format!(
577            "{} in {}",
578            result_message, filename
579        )))
580    }
581
582    /// Check permissions before execution
583    async fn check_permissions(
584        &self,
585        params: &serde_json::Value,
586        _context: &ToolContext,
587    ) -> PermissionCheckResult {
588        // Validate input format
589        match serde_json::from_value::<NotebookEditInput>(params.clone()) {
590            Ok(input) => {
591                let notebook_path = PathBuf::from(&input.notebook_path);
592
593                // Check if path is absolute
594                if !notebook_path.is_absolute() {
595                    return PermissionCheckResult::deny(format!(
596                        "notebook_path must be an absolute path, got: {}",
597                        input.notebook_path
598                    ));
599                }
600
601                // Check if it's a notebook file
602                if notebook_path.extension().is_none_or(|ext| ext != "ipynb") {
603                    return PermissionCheckResult::deny(format!(
604                        "File must be a Jupyter notebook (.ipynb), got: {}",
605                        notebook_path
606                            .extension()
607                            .unwrap_or_default()
608                            .to_string_lossy()
609                    ));
610                }
611
612                // Validate edit_mode
613                if !["replace", "insert", "delete"].contains(&input.edit_mode.as_str()) {
614                    return PermissionCheckResult::deny(format!(
615                        "Invalid edit_mode: {}. Must be 'replace', 'insert', or 'delete'",
616                        input.edit_mode
617                    ));
618                }
619
620                // delete 模式必须指定 cell_id
621                if input.edit_mode == "delete" && input.cell_id.is_none() {
622                    return PermissionCheckResult::deny(
623                        "cell_id is required for delete mode".to_string(),
624                    );
625                }
626
627                PermissionCheckResult::allow()
628            }
629            Err(e) => PermissionCheckResult::deny(format!("Invalid input format: {}", e)),
630        }
631    }
632
633    /// Get tool options
634    fn options(&self) -> ToolOptions {
635        ToolOptions::new()
636            .with_max_retries(0) // Don't retry notebook operations
637            .with_base_timeout(std::time::Duration::from_secs(30)) // Longer timeout for file operations
638            .with_dynamic_timeout(false)
639    }
640}
641
642// =============================================================================
643// Unit Tests
644// =============================================================================
645
646#[cfg(test)]
647mod tests {
648    use super::*;
649    use std::path::PathBuf;
650    use tempfile::TempDir;
651
652    fn create_test_context() -> ToolContext {
653        ToolContext::new(PathBuf::from("/tmp"))
654            .with_session_id("test-session")
655            .with_user("test-user")
656    }
657
658    fn create_test_notebook() -> serde_json::Value {
659        serde_json::json!({
660            "cells": [
661                {
662                    "id": "cell-1",
663                    "cell_type": "code",
664                    "source": "print('Hello, World!')",
665                    "metadata": {},
666                    "outputs": [],
667                    "execution_count": null
668                },
669                {
670                    "id": "cell-2",
671                    "cell_type": "markdown",
672                    "source": "# Test Markdown",
673                    "metadata": {}
674                }
675            ],
676            "metadata": {
677                "kernelspec": {
678                    "display_name": "Python 3",
679                    "language": "python",
680                    "name": "python3"
681                }
682            },
683            "nbformat": 4,
684            "nbformat_minor": 5
685        })
686    }
687
688    #[test]
689    fn test_tool_name() {
690        let tool = NotebookEditTool::new();
691        assert_eq!(tool.name(), "NotebookEdit");
692    }
693
694    #[test]
695    fn test_tool_description() {
696        let tool = NotebookEditTool::new();
697        assert!(!tool.description().is_empty());
698        assert!(tool.description().contains("Jupyter notebook"));
699        assert!(tool.description().contains("cell"));
700    }
701
702    #[test]
703    fn test_tool_input_schema() {
704        let tool = NotebookEditTool::new();
705        let schema = tool.input_schema();
706        assert_eq!(schema["type"], "object");
707        assert!(schema["properties"]["notebook_path"].is_object());
708        assert!(schema["properties"]["new_source"].is_object());
709        assert!(schema["required"]
710            .as_array()
711            .unwrap()
712            .contains(&serde_json::json!("notebook_path")));
713        assert!(schema["required"]
714            .as_array()
715            .unwrap()
716            .contains(&serde_json::json!("new_source")));
717    }
718
719    #[test]
720    fn test_tool_options() {
721        let tool = NotebookEditTool::new();
722        let options = tool.options();
723        assert_eq!(options.max_retries, 0);
724        assert_eq!(options.base_timeout, std::time::Duration::from_secs(30));
725        assert!(!options.enable_dynamic_timeout);
726    }
727
728    #[test]
729    fn test_validate_notebook_format_valid() {
730        let tool = NotebookEditTool::new();
731        let notebook = create_test_notebook();
732        assert!(tool.validate_notebook_format(&notebook).is_ok());
733    }
734
735    #[test]
736    fn test_validate_notebook_format_missing_cells() {
737        let tool = NotebookEditTool::new();
738        let notebook = serde_json::json!({
739            "metadata": {},
740            "nbformat": 4,
741            "nbformat_minor": 5
742        });
743        let result = tool.validate_notebook_format(&notebook);
744        assert!(result.is_err());
745        assert!(result.unwrap_err().contains("missing or invalid cells"));
746    }
747
748    #[test]
749    fn test_validate_notebook_format_invalid_nbformat() {
750        let tool = NotebookEditTool::new();
751        let notebook = serde_json::json!({
752            "cells": [],
753            "metadata": {},
754            "nbformat": 3,
755            "nbformat_minor": 0
756        });
757        let result = tool.validate_notebook_format(&notebook);
758        assert!(result.is_err());
759        assert!(result
760            .unwrap_err()
761            .contains("Unsupported notebook format version"));
762    }
763
764    #[test]
765    fn test_validate_notebook_format_invalid_cell_type() {
766        let tool = NotebookEditTool::new();
767        let notebook = serde_json::json!({
768            "cells": [
769                {
770                    "cell_type": "invalid_type",
771                    "source": "test"
772                }
773            ],
774            "metadata": {},
775            "nbformat": 4,
776            "nbformat_minor": 5
777        });
778        let result = tool.validate_notebook_format(&notebook);
779        assert!(result.is_err());
780        assert!(result.unwrap_err().contains("unknown cell_type"));
781    }
782
783    #[test]
784    fn test_find_cell_index_by_id() {
785        let tool = NotebookEditTool::new();
786        let notebook = create_test_notebook();
787        let cells = notebook.get("cells").unwrap().as_array().unwrap();
788
789        assert_eq!(tool.find_cell_index(cells, "cell-1"), 0);
790        assert_eq!(tool.find_cell_index(cells, "cell-2"), 1);
791        assert_eq!(tool.find_cell_index(cells, "nonexistent"), -1);
792    }
793
794    #[test]
795    fn test_find_cell_index_by_number() {
796        let tool = NotebookEditTool::new();
797        let notebook = create_test_notebook();
798        let cells = notebook.get("cells").unwrap().as_array().unwrap();
799
800        assert_eq!(tool.find_cell_index(cells, "0"), 0);
801        assert_eq!(tool.find_cell_index(cells, "1"), 1);
802        assert_eq!(tool.find_cell_index(cells, "2"), -1); // Out of range
803        assert_eq!(tool.find_cell_index(cells, "-1"), 1); // Last cell
804        assert_eq!(tool.find_cell_index(cells, "-2"), 0); // Second to last
805    }
806
807    #[test]
808    fn test_clear_cell_outputs() {
809        let tool = NotebookEditTool::new();
810        let mut cell = serde_json::json!({
811            "cell_type": "code",
812            "source": "print('test')",
813            "outputs": [{"output_type": "stream", "text": "test"}],
814            "execution_count": 5
815        });
816
817        tool.clear_cell_outputs(&mut cell);
818
819        assert_eq!(cell["outputs"], serde_json::json!([]));
820        assert_eq!(cell["execution_count"], serde_json::Value::Null);
821    }
822
823    #[test]
824    fn test_clear_cell_outputs_markdown() {
825        let tool = NotebookEditTool::new();
826        let mut cell = serde_json::json!({
827            "cell_type": "markdown",
828            "source": "# Test"
829        });
830
831        tool.clear_cell_outputs(&mut cell);
832
833        // Markdown cells should not have outputs added
834        assert!(!cell.as_object().unwrap().contains_key("outputs"));
835        assert!(!cell.as_object().unwrap().contains_key("execution_count"));
836    }
837
838    #[test]
839    fn test_generate_cell_id() {
840        let tool = NotebookEditTool::new();
841        let id1 = tool.generate_cell_id();
842        let id2 = tool.generate_cell_id();
843
844        assert_eq!(id1.len(), 8);
845        assert_eq!(id2.len(), 8);
846        assert_ne!(id1, id2); // Should be different
847
848        // Should only contain lowercase letters and numbers
849        for c in id1.chars() {
850            assert!(c.is_ascii_lowercase() || c.is_ascii_digit());
851        }
852    }
853
854    // Permission Check Tests
855
856    #[tokio::test]
857    async fn test_check_permissions_valid_input() {
858        let tool = NotebookEditTool::new();
859        let context = create_test_context();
860        let params = serde_json::json!({
861            "notebook_path": "/tmp/test.ipynb",
862            "new_source": "print('hello')",
863            "edit_mode": "replace"
864        });
865
866        let result = tool.check_permissions(&params, &context).await;
867        assert!(result.is_allowed());
868    }
869
870    #[tokio::test]
871    async fn test_check_permissions_relative_path() {
872        let tool = NotebookEditTool::new();
873        let context = create_test_context();
874        let params = serde_json::json!({
875            "notebook_path": "test.ipynb",
876            "new_source": "print('hello')"
877        });
878
879        let result = tool.check_permissions(&params, &context).await;
880        assert!(result.is_denied());
881        assert!(result.message.unwrap().contains("must be an absolute path"));
882    }
883
884    #[tokio::test]
885    async fn test_check_permissions_not_notebook() {
886        let tool = NotebookEditTool::new();
887        let context = create_test_context();
888        let params = serde_json::json!({
889            "notebook_path": "/tmp/test.py",
890            "new_source": "print('hello')"
891        });
892
893        let result = tool.check_permissions(&params, &context).await;
894        assert!(result.is_denied());
895        assert!(result
896            .message
897            .unwrap()
898            .contains("must be a Jupyter notebook"));
899    }
900
901    #[tokio::test]
902    async fn test_check_permissions_invalid_edit_mode() {
903        let tool = NotebookEditTool::new();
904        let context = create_test_context();
905        let params = serde_json::json!({
906            "notebook_path": "/tmp/test.ipynb",
907            "new_source": "print('hello')",
908            "edit_mode": "invalid"
909        });
910
911        let result = tool.check_permissions(&params, &context).await;
912        assert!(result.is_denied());
913        assert!(result.message.unwrap().contains("Invalid edit_mode"));
914    }
915
916    #[tokio::test]
917    async fn test_check_permissions_delete_without_cell_id() {
918        let tool = NotebookEditTool::new();
919        let context = create_test_context();
920        let params = serde_json::json!({
921            "notebook_path": "/tmp/test.ipynb",
922            "new_source": "print('hello')",
923            "edit_mode": "delete"
924        });
925
926        let result = tool.check_permissions(&params, &context).await;
927        assert!(result.is_denied());
928        assert!(result
929            .message
930            .unwrap()
931            .contains("cell_id is required for delete mode"));
932    }
933
934    #[tokio::test]
935    async fn test_check_permissions_invalid_format() {
936        let tool = NotebookEditTool::new();
937        let context = create_test_context();
938        let params = serde_json::json!({"invalid": "format"});
939
940        let result = tool.check_permissions(&params, &context).await;
941        assert!(result.is_denied());
942        assert!(result.message.unwrap().contains("Invalid input format"));
943    }
944
945    // Execution Tests
946
947    #[tokio::test]
948    async fn test_execute_file_not_found() {
949        let tool = NotebookEditTool::new();
950        let context = create_test_context();
951        let params = serde_json::json!({
952            "notebook_path": "/tmp/nonexistent.ipynb",
953            "new_source": "print('hello')"
954        });
955
956        let result = tool.execute(params, &context).await;
957        assert!(result.is_ok());
958        let tool_result = result.unwrap();
959        assert!(tool_result.is_error());
960        assert!(tool_result
961            .error
962            .unwrap()
963            .contains("Notebook file not found"));
964    }
965
966    #[tokio::test]
967    async fn test_execute_not_a_file() {
968        let tool = NotebookEditTool::new();
969        let context = create_test_context();
970        let temp_dir = TempDir::new().unwrap();
971        let dir_path = temp_dir.path().join("test.ipynb");
972        fs::create_dir(&dir_path).unwrap();
973
974        let params = serde_json::json!({
975            "notebook_path": dir_path.to_string_lossy(),
976            "new_source": "print('hello')"
977        });
978
979        let result = tool.execute(params, &context).await;
980        assert!(result.is_ok());
981        let tool_result = result.unwrap();
982        assert!(tool_result.is_error());
983        assert!(tool_result.error.unwrap().contains("Path is not a file"));
984    }
985
986    #[tokio::test]
987    async fn test_execute_invalid_json() {
988        let tool = NotebookEditTool::new();
989        let context = create_test_context();
990        let temp_dir = TempDir::new().unwrap();
991        let file_path = temp_dir.path().join("test.ipynb");
992        fs::write(&file_path, "invalid json").unwrap();
993
994        let params = serde_json::json!({
995            "notebook_path": file_path.to_string_lossy(),
996            "new_source": "print('hello')"
997        });
998
999        let result = tool.execute(params, &context).await;
1000        assert!(result.is_err());
1001        assert!(result
1002            .unwrap_err()
1003            .to_string()
1004            .contains("Failed to parse notebook JSON"));
1005    }
1006
1007    #[tokio::test]
1008    async fn test_execute_replace_cell() {
1009        let tool = NotebookEditTool::new();
1010        let context = create_test_context();
1011        let temp_dir = TempDir::new().unwrap();
1012        let file_path = temp_dir.path().join("test.ipynb");
1013
1014        let notebook = create_test_notebook();
1015        fs::write(&file_path, serde_json::to_string(&notebook).unwrap()).unwrap();
1016
1017        let params = serde_json::json!({
1018            "notebook_path": file_path.to_string_lossy(),
1019            "cell_id": "cell-1",
1020            "new_source": "print('Hello, Rust!')",
1021            "edit_mode": "replace"
1022        });
1023
1024        let result = tool.execute(params, &context).await;
1025        assert!(result.is_ok());
1026        let tool_result = result.unwrap();
1027        assert!(tool_result.is_success());
1028        assert!(tool_result.output.unwrap().contains("Replaced cell 0"));
1029
1030        // Verify the file was updated
1031        let updated_content = fs::read_to_string(&file_path).unwrap();
1032        let updated_notebook: serde_json::Value = serde_json::from_str(&updated_content).unwrap();
1033        let cells = updated_notebook["cells"].as_array().unwrap();
1034        assert_eq!(cells[0]["source"], "print('Hello, Rust!')");
1035        assert_eq!(cells[0]["outputs"], serde_json::json!([]));
1036        assert_eq!(cells[0]["execution_count"], serde_json::Value::Null);
1037    }
1038
1039    #[tokio::test]
1040    async fn test_execute_insert_cell() {
1041        let tool = NotebookEditTool::new();
1042        let context = create_test_context();
1043        let temp_dir = TempDir::new().unwrap();
1044        let file_path = temp_dir.path().join("test.ipynb");
1045
1046        let notebook = create_test_notebook();
1047        fs::write(&file_path, serde_json::to_string(&notebook).unwrap()).unwrap();
1048
1049        let params = serde_json::json!({
1050            "notebook_path": file_path.to_string_lossy(),
1051            "cell_id": "cell-1",
1052            "new_source": "# New markdown cell",
1053            "cell_type": "markdown",
1054            "edit_mode": "insert"
1055        });
1056
1057        let result = tool.execute(params, &context).await;
1058        assert!(result.is_ok());
1059        let tool_result = result.unwrap();
1060        assert!(tool_result.is_success());
1061        assert!(tool_result
1062            .output
1063            .unwrap()
1064            .contains("Inserted new markdown cell at position 1"));
1065
1066        // Verify the file was updated
1067        let updated_content = fs::read_to_string(&file_path).unwrap();
1068        let updated_notebook: serde_json::Value = serde_json::from_str(&updated_content).unwrap();
1069        let cells = updated_notebook["cells"].as_array().unwrap();
1070        assert_eq!(cells.len(), 3); // Original 2 + 1 inserted
1071        assert_eq!(cells[1]["source"], "# New markdown cell");
1072        assert_eq!(cells[1]["cell_type"], "markdown");
1073    }
1074
1075    #[tokio::test]
1076    async fn test_execute_delete_cell() {
1077        let tool = NotebookEditTool::new();
1078        let context = create_test_context();
1079        let temp_dir = TempDir::new().unwrap();
1080        let file_path = temp_dir.path().join("test.ipynb");
1081
1082        let notebook = create_test_notebook();
1083        fs::write(&file_path, serde_json::to_string(&notebook).unwrap()).unwrap();
1084
1085        let params = serde_json::json!({
1086            "notebook_path": file_path.to_string_lossy(),
1087            "cell_id": "cell-1",
1088            "new_source": "", // Not used for delete
1089            "edit_mode": "delete"
1090        });
1091
1092        let result = tool.execute(params, &context).await;
1093        assert!(result.is_ok());
1094        let tool_result = result.unwrap();
1095        assert!(tool_result.is_success());
1096        assert!(tool_result
1097            .output
1098            .unwrap()
1099            .contains("Deleted code cell at position 0"));
1100
1101        // Verify the file was updated
1102        let updated_content = fs::read_to_string(&file_path).unwrap();
1103        let updated_notebook: serde_json::Value = serde_json::from_str(&updated_content).unwrap();
1104        let cells = updated_notebook["cells"].as_array().unwrap();
1105        assert_eq!(cells.len(), 1); // Original 2 - 1 deleted
1106        assert_eq!(cells[0]["id"], "cell-2"); // cell-2 should remain
1107    }
1108
1109    #[tokio::test]
1110    async fn test_execute_cell_not_found() {
1111        let tool = NotebookEditTool::new();
1112        let context = create_test_context();
1113        let temp_dir = TempDir::new().unwrap();
1114        let file_path = temp_dir.path().join("test.ipynb");
1115
1116        let notebook = create_test_notebook();
1117        fs::write(&file_path, serde_json::to_string(&notebook).unwrap()).unwrap();
1118
1119        let params = serde_json::json!({
1120            "notebook_path": file_path.to_string_lossy(),
1121            "cell_id": "nonexistent",
1122            "new_source": "print('hello')",
1123            "edit_mode": "replace"
1124        });
1125
1126        let result = tool.execute(params, &context).await;
1127        assert!(result.is_ok());
1128        let tool_result = result.unwrap();
1129        assert!(tool_result.is_error());
1130        assert!(tool_result
1131            .error
1132            .unwrap()
1133            .contains("Cell not found with ID: nonexistent"));
1134    }
1135
1136    #[tokio::test]
1137    async fn test_execute_replace_out_of_range_converts_to_insert() {
1138        let tool = NotebookEditTool::new();
1139        let context = create_test_context();
1140        let temp_dir = TempDir::new().unwrap();
1141        let file_path = temp_dir.path().join("test.ipynb");
1142
1143        let notebook = create_test_notebook();
1144        fs::write(&file_path, serde_json::to_string(&notebook).unwrap()).unwrap();
1145
1146        let params = serde_json::json!({
1147            "notebook_path": file_path.to_string_lossy(),
1148            "cell_id": "5", // Out of range index
1149            "new_source": "print('new cell')",
1150            "edit_mode": "replace"
1151        });
1152
1153        let result = tool.execute(params, &context).await;
1154        assert!(result.is_ok());
1155        let tool_result = result.unwrap();
1156        assert!(tool_result.is_success());
1157        assert!(tool_result
1158            .output
1159            .unwrap()
1160            .contains("Inserted new code cell"));
1161
1162        // Verify the file was updated
1163        let updated_content = fs::read_to_string(&file_path).unwrap();
1164        let updated_notebook: serde_json::Value = serde_json::from_str(&updated_content).unwrap();
1165        let cells = updated_notebook["cells"].as_array().unwrap();
1166        assert_eq!(cells.len(), 3); // Original 2 + 1 inserted
1167        assert_eq!(cells[2]["source"], "print('new cell')");
1168        assert_eq!(cells[2]["cell_type"], "code");
1169    }
1170
1171    #[tokio::test]
1172    async fn test_execute_invalid_input_format() {
1173        let tool = NotebookEditTool::new();
1174        let context = create_test_context();
1175        let params = serde_json::json!({"invalid": "format"});
1176
1177        let result = tool.execute(params, &context).await;
1178        assert!(result.is_err());
1179        assert!(matches!(result.unwrap_err(), ToolError::InvalidParams(_)));
1180    }
1181
1182    #[tokio::test]
1183    async fn test_execute_change_cell_type() {
1184        let tool = NotebookEditTool::new();
1185        let context = create_test_context();
1186        let temp_dir = TempDir::new().unwrap();
1187        let file_path = temp_dir.path().join("test.ipynb");
1188
1189        let notebook = create_test_notebook();
1190        fs::write(&file_path, serde_json::to_string(&notebook).unwrap()).unwrap();
1191
1192        let params = serde_json::json!({
1193            "notebook_path": file_path.to_string_lossy(),
1194            "cell_id": "cell-1",
1195            "new_source": "# Now a markdown cell",
1196            "cell_type": "markdown",
1197            "edit_mode": "replace"
1198        });
1199
1200        let result = tool.execute(params, &context).await;
1201        assert!(result.is_ok());
1202        let tool_result = result.unwrap();
1203        assert!(tool_result.is_success());
1204        assert!(tool_result
1205            .output
1206            .unwrap()
1207            .contains("changed type from code to markdown"));
1208
1209        // Verify the file was updated
1210        let updated_content = fs::read_to_string(&file_path).unwrap();
1211        let updated_notebook: serde_json::Value = serde_json::from_str(&updated_content).unwrap();
1212        let cells = updated_notebook["cells"].as_array().unwrap();
1213        assert_eq!(cells[0]["source"], "# Now a markdown cell");
1214        assert_eq!(cells[0]["cell_type"], "markdown");
1215        // Should not have outputs or execution_count for markdown
1216        assert!(!cells[0].as_object().unwrap().contains_key("outputs"));
1217        assert!(!cells[0]
1218            .as_object()
1219            .unwrap()
1220            .contains_key("execution_count"));
1221    }
1222}