//! ai-agent-sdk 0.5.0
//!
//! Idiomatic agent SDK inspired by the Claude Code source leak.
//! Documentation
use crate::types::*;
use std::fs;

/// Tool that edits the cells of a Jupyter notebook (`.ipynb`) file in place.
///
/// The tool carries no state; build it with [`NotebookEditTool::new`] or `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotebookEditTool;

impl NotebookEditTool {
    pub fn new() -> Self {
        Self
    }

    pub fn name(&self) -> &str {
        "NotebookEdit"
    }

    pub fn description(&self) -> &str {
        "Edit Jupyter notebook (.ipynb) cells"
    }

    pub fn input_schema(&self) -> ToolInputSchema {
        ToolInputSchema {
            schema_type: "object".to_string(),
            properties: serde_json::json!({
                "file_path": {
                    "type": "string",
                    "description": "Path to the .ipynb file"
                },
                "command": {
                    "type": "string",
                    "enum": ["insert", "replace", "delete"],
                    "description": "The edit operation to perform"
                },
                "cell_number": {
                    "type": "number",
                    "description": "Cell index (0-based) to operate on"
                },
                "cell_type": {
                    "type": "string",
                    "enum": ["code", "markdown"],
                    "description": "Type of cell (for insert/replace)"
                },
                "source": {
                    "type": "string",
                    "description": "Cell content (for insert/replace)"
                }
            }),
            required: Some(vec!["file_path".to_string(), "command".to_string(), "cell_number".to_string()]),
        }
    }

    pub async fn execute(
        &self,
        input: serde_json::Value,
        _context: &ToolContext,
    ) -> Result<ToolResult, crate::error::AgentError> {
        let file_path = input["file_path"]
            .as_str()
            .ok_or_else(|| crate::error::AgentError::Tool("file_path is required".to_string()))?;

        let command = input["command"]
            .as_str()
            .ok_or_else(|| crate::error::AgentError::Tool("command is required".to_string()))?;

        let cell_number = input["cell_number"]
            .as_u64()
            .ok_or_else(|| crate::error::AgentError::Tool("cell_number is required".to_string()))? as usize;

        let cell_type = input["cell_type"].as_str();
        let source = input["source"].as_str();

        let content = fs::read_to_string(file_path)
            .map_err(|e| crate::error::AgentError::Io(e))?;

        let mut notebook: serde_json::Value = serde_json::from_str(&content)
            .map_err(|e| crate::error::AgentError::Tool(format!("Invalid notebook format: {}", e)))?;

        let cells = notebook["cells"].as_array_mut()
            .ok_or_else(|| crate::error::AgentError::Tool("Invalid notebook format: no cells array".to_string()))?;

        match command {
            "insert" => {
                let source_lines: Vec<String> = match source {
                    Some(s) => s.lines().enumerate().map(|(i, l)| {
                        if i < s.lines().count() - 1 {
                            format!("{}\n", l)
                        } else {
                            l.to_string()
                        }
                    }).collect(),
                    None => vec![],
                };

                let cell_type_str = cell_type.unwrap_or("code");

                let mut new_cell = serde_json::json!({
                    "cell_type": cell_type_str,
                    "source": source_lines,
                    "metadata": serde_json::json!({}),
                });

                if cell_type_str != "markdown" {
                    new_cell["outputs"] = serde_json::json!([]);
                    new_cell["execution_count"] = serde_json::json!(null);
                }

                cells.insert(cell_number, new_cell);
            }
            "replace" => {
                if cell_number >= cells.len() {
                    return Ok(ToolResult {
                        result_type: "text".to_string(),
                        tool_use_id: "".to_string(),
                        content: format!("Error: Cell {} does not exist", cell_number),
                        is_error: Some(true),
                    });
                }

                let source_lines: Vec<String> = match source {
                    Some(s) => s.lines().enumerate().map(|(i, l)| {
                        if i < s.lines().count() - 1 {
                            format!("{}\n", l)
                        } else {
                            l.to_string()
                        }
                    }).collect(),
                    None => vec![],
                };

                cells[cell_number]["source"] = serde_json::json!(source_lines);

                if let Some(ct) = cell_type {
                    cells[cell_number]["cell_type"] = serde_json::json!(ct);
                }
            }
            "delete" => {
                if cell_number >= cells.len() {
                    return Ok(ToolResult {
                        result_type: "text".to_string(),
                        tool_use_id: "".to_string(),
                        content: format!("Error: Cell {} does not exist", cell_number),
                        is_error: Some(true),
                    });
                }
                cells.remove(cell_number);
            }
            _ => {
                return Ok(ToolResult {
                    result_type: "text".to_string(),
                    tool_use_id: "".to_string(),
                    content: format!("Error: Unknown command: {}", command),
                    is_error: Some(true),
                });
            }
        }

        let new_content = serde_json::to_string_pretty(&notebook)
            .map_err(|e| crate::error::AgentError::Tool(format!("Failed to serialize notebook: {}", e)))?;

        fs::write(file_path, new_content)
            .map_err(|e| crate::error::AgentError::Io(e))?;

        Ok(ToolResult {
            result_type: "text".to_string(),
            tool_use_id: "".to_string(),
            content: format!("Notebook {}: cell {} in {}", command, cell_number, file_path),
            is_error: None,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    /// Builds a two-cell nbformat 4.5 notebook: a code cell, then a markdown cell.
    fn create_test_notebook() -> serde_json::Value {
        serde_json::json!({
            "nbformat": 4,
            "nbformat_minor": 5,
            "metadata": {},
            "cells": [
                {
                    "cell_type": "code",
                    "execution_count": null,
                    "metadata": {},
                    "outputs": [],
                    "source": ["print('hello')\n"]
                },
                {
                    "cell_type": "markdown",
                    "metadata": {},
                    "source": ["# Title\n"]
                }
            ]
        })
    }

    /// Serializes the fixture notebook to `file_name` under the OS temp dir and
    /// returns the resulting path.
    fn write_fixture(file_name: &str) -> PathBuf {
        let path = std::env::temp_dir().join(file_name);
        let text = serde_json::to_string_pretty(&create_test_notebook()).unwrap();
        std::fs::write(&path, text).unwrap();
        path
    }

    /// Parses the notebook currently stored at `path`.
    fn load_notebook(path: &Path) -> serde_json::Value {
        let text = std::fs::read_to_string(path).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    /// Number of entries in the notebook's `cells` array.
    fn cell_count(notebook: &serde_json::Value) -> usize {
        notebook["cells"].as_array().unwrap().len()
    }

    #[test]
    fn test_notebook_edit_tool_name() {
        assert_eq!(NotebookEditTool::new().name(), "NotebookEdit");
    }

    #[test]
    fn test_notebook_edit_tool_description_contains_notebook() {
        let description = NotebookEditTool::new().description().to_lowercase();
        assert!(description.contains("notebook"));
    }

    #[test]
    fn test_notebook_edit_tool_has_file_path_in_schema() {
        let schema = NotebookEditTool::new().input_schema();
        assert!(schema.properties.get("file_path").is_some());
    }

    #[test]
    fn test_notebook_edit_tool_has_command_in_schema() {
        let schema = NotebookEditTool::new().input_schema();
        assert!(schema.properties.get("command").is_some());
    }

    #[test]
    fn test_notebook_edit_tool_has_cell_number_in_schema() {
        let schema = NotebookEditTool::new().input_schema();
        assert!(schema.properties.get("cell_number").is_some());
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_insert_code_cell() {
        let path = write_fixture("test_notebook.ipynb");
        let request = serde_json::json!({
            "file_path": path.to_str().unwrap(),
            "command": "insert",
            "cell_number": 0,
            "cell_type": "code",
            "source": "x = 1"
        });

        let outcome = NotebookEditTool::new()
            .execute(request, &ToolContext::default())
            .await;
        assert!(outcome.is_ok());
        assert_eq!(cell_count(&load_notebook(&path)), 3);

        std::fs::remove_file(path).ok();
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_replace_cell() {
        let path = write_fixture("test_notebook_replace.ipynb");
        let request = serde_json::json!({
            "file_path": path.to_str().unwrap(),
            "command": "replace",
            "cell_number": 0,
            "source": "print('replaced')"
        });

        let outcome = NotebookEditTool::new()
            .execute(request, &ToolContext::default())
            .await;
        assert!(outcome.is_ok());

        let edited = load_notebook(&path);
        let first_line = &edited["cells"][0]["source"].as_array().unwrap()[0];
        assert_eq!(*first_line, "print('replaced')");

        std::fs::remove_file(path).ok();
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_delete_cell() {
        let path = write_fixture("test_notebook_delete.ipynb");
        let request = serde_json::json!({
            "file_path": path.to_str().unwrap(),
            "command": "delete",
            "cell_number": 1
        });

        let outcome = NotebookEditTool::new()
            .execute(request, &ToolContext::default())
            .await;
        assert!(outcome.is_ok());
        assert_eq!(cell_count(&load_notebook(&path)), 1);

        std::fs::remove_file(path).ok();
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_returns_error_for_invalid_cell() {
        let path = write_fixture("test_notebook_invalid.ipynb");
        let request = serde_json::json!({
            "file_path": path.to_str().unwrap(),
            "command": "replace",
            "cell_number": 100,
            "source": "test"
        });

        let outcome = NotebookEditTool::new()
            .execute(request, &ToolContext::default())
            .await;
        assert!(outcome.is_ok());
        let tool_result = outcome.unwrap();
        assert_eq!(tool_result.is_error, Some(true));

        std::fs::remove_file(path).ok();
    }

    #[tokio::test]
    async fn test_notebook_edit_tool_returns_error_for_nonexistent_file() {
        let request = serde_json::json!({
            "file_path": "/nonexistent/notebook.ipynb",
            "command": "insert",
            "cell_number": 0
        });

        let outcome = NotebookEditTool::new()
            .execute(request, &ToolContext::default())
            .await;
        assert!(outcome.is_err());
    }
}