ai-agent-sdk 0.4.0

Idiomatic agent SDK inspired by the Claude Code source leak
Documentation
use crate::types::{McpServerConfig, McpConnectionStatus, McpTool, ToolDefinition, ToolInputSchema};
use std::process::Stdio;

/// A record for one MCP server connection: its name, status, and discovered tools.
///
/// This type stores only the three fields below; it holds no process or
/// network handle of its own.
#[derive(Debug, Clone)]
pub struct McpConnection {
    /// Server name. When produced by `connect_mcp_server`, this is the name the
    /// caller supplied, which is also the `mcp__{name}__…` prefix of each tool.
    pub name: String,
    /// Current state of the connection.
    pub status: McpConnectionStatus,
    /// Tool definitions discovered from the server. Empty when nothing was
    /// listed, and cleared by `close`.
    pub tools: Vec<ToolDefinition>,
}

impl McpConnection {
    /// Mark this connection as disconnected and forget its tools.
    ///
    /// Only local state changes: `tools` is emptied (capacity is kept) and
    /// `status` becomes [`McpConnectionStatus::Disconnected`].
    pub async fn close(&mut self) {
        let Self { status, tools, .. } = self;
        tools.clear();
        *status = McpConnectionStatus::Disconnected;
    }
}

/// Connect to the MCP server described by `config` and fetch its tools.
///
/// Only the stdio transport is implemented. SSE and HTTP configurations are
/// accepted but never contacted. For those, the result is `Ok` with a connection
/// whose status is [`McpConnectionStatus::Error`] and whose tool list is empty.
///
/// # Errors
///
/// Propagates whatever error the stdio transport returns.
pub async fn connect_mcp_server(
    name: &str,
    config: &McpServerConfig,
) -> Result<McpConnection, crate::error::AgentError> {
    let stdio_config = match config {
        McpServerConfig::Stdio(stdio_config) => stdio_config,
        // Remote transports are not implemented: report an errored connection
        // instead of failing the caller.
        McpServerConfig::Sse(_) | McpServerConfig::Http(_) => {
            return Ok(McpConnection {
                name: name.to_string(),
                status: McpConnectionStatus::Error,
                tools: Vec::new(),
            });
        }
    };

    connect_mcp_stdio(name, stdio_config).await
}

/// Connect to an MCP server via stdio
async fn connect_mcp_stdio(
    name: &str,
    config: &crate::types::McpStdioConfig,
) -> Result<McpConnection, crate::error::AgentError> {
    // Use tokio's process module for async IO
    use std::collections::HashMap;
    use tokio::process::Command;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

    let mut env_vars: HashMap<String, String> = std::env::vars().collect();
    if let Some(custom_env) = &config.env {
        for (key, value) in custom_env {
            env_vars.insert(key.clone(), value.clone());
        }
    }

    let mut child = Command::new(&config.command)
        .args(config.args.as_deref().unwrap_or(&[]))
        .envs(&env_vars)
        .kill_on_drop(true)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .stdin(Stdio::piped())
        .spawn()
        .map_err(|e| crate::error::AgentError::Mcp(format!("Failed to spawn MCP server: {}", e)))?;

    let stdout = child.stdout.take().ok_or_else(|| {
        crate::error::AgentError::Mcp("Failed to take stdout".to_string())
    })?;
    let mut stdin = child.stdin.take().ok_or_else(|| {
        crate::error::AgentError::Mcp("Failed to take stdin".to_string())
    })?;

    let mut stdout_reader = BufReader::new(stdout).lines();

    // Send initialize request
    let initialize_request = serde_json::json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {
                "name": format!("agent-sdk-{}", name),
                "version": "1.0.0"
            }
        }
    });

    stdin.write_all(format!("{initialize_request}\n").as_bytes()).await
        .map_err(|e| crate::error::AgentError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?;
    stdin.flush().await
        .map_err(|e| crate::error::AgentError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?;

    // Read initialize response (drain it)
    let _ = stdout_reader.next_line().await;

    // Send tools/list request
    let list_tools_request = serde_json::json!({
        "jsonrpc": "2.0",
        "id": 2,
        "method": "tools/list"
    });

    stdin.write_all(format!("{list_tools_request}\n").as_bytes()).await
        .map_err(|e| crate::error::AgentError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?;
    stdin.flush().await
        .map_err(|e| crate::error::AgentError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?;

    // Read tools/list response
    let mut tools = vec![];
    if let Ok(Some(response)) = stdout_reader.next_line().await {
        if let Ok(resp) = serde_json::from_str::<serde_json::Value>(&response) {
            if let Some(result) = resp.get("result") {
                if let Some(tools_array) = result.get("tools").and_then(|t| t.as_array()) {
                    for tool_val in tools_array {
                        if let Ok(mcp_tool) = serde_json::from_value::<McpTool>(tool_val.clone()) {
                            let tool_def = create_mcp_tool_definition(name, &mcp_tool);
                            tools.push(tool_def);
                        }
                    }
                }
            }
        }
    }

    // Drop stdin to signal EOF to the server, but keep the process alive
    drop(stdin);

    // Don't wait for the child - let it run in the background
    // In a real implementation, we'd want to manage this more carefully

    Ok(McpConnection {
        name: name.to_string(),
        status: McpConnectionStatus::Connected,
        tools,
    })
}

/// Build the agent-facing [`ToolDefinition`] for one tool advertised by an MCP server.
///
/// Naming and fallbacks:
/// - The tool is named `mcp__{server_name}__{tool}`.
/// - A missing description becomes `"MCP tool: {tool}"`.
///
/// Schema handling:
/// - `type` defaults to `"object"`.
/// - `properties` defaults to `{}`.
/// - `required` keeps only its string entries and is `None` when absent or not
///   an array.
/// - A tool without any input schema gets the same defaults.
fn create_mcp_tool_definition(server_name: &str, mcp_tool: &McpTool) -> ToolDefinition {
    // Borrow the schema; only the pieces that are kept get cloned.
    let schema = mcp_tool.input_schema.as_ref();

    let schema_type = schema
        .and_then(|s| s.get("type"))
        .and_then(serde_json::Value::as_str)
        .unwrap_or("object")
        .to_string();

    let properties = schema
        .and_then(|s| s.get("properties"))
        .cloned()
        .unwrap_or_else(|| serde_json::json!({}));

    let required = schema
        .and_then(|s| s.get("required"))
        .and_then(serde_json::Value::as_array)
        .map(|entries| {
            entries
                .iter()
                .filter_map(serde_json::Value::as_str)
                .map(str::to_owned)
                .collect()
        });

    let description = match &mcp_tool.description {
        Some(text) => text.clone(),
        None => format!("MCP tool: {}", mcp_tool.name),
    };

    ToolDefinition {
        name: format!("mcp__{}__{}", server_name, mcp_tool.name),
        description,
        input_schema: ToolInputSchema {
            schema_type,
            properties,
            required,
        },
    }
}

/// Close every connection in `connections`, leaving the vector empty.
///
/// Each connection is closed in order before it is removed; the vector keeps
/// its allocated capacity.
pub async fn close_all_connections(connections: &mut Vec<McpConnection>) {
    for mut connection in connections.drain(..) {
        connection.close().await;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build an `McpTool` with the given optional description and input schema.
    fn sample_tool(
        name: &str,
        description: Option<&str>,
        input_schema: Option<serde_json::Value>,
    ) -> McpTool {
        McpTool {
            name: name.to_string(),
            description: description.map(str::to_string),
            input_schema,
        }
    }

    #[test]
    fn test_mcp_stdio_config() {
        let mut env = HashMap::new();
        env.insert("KEY".to_string(), "value".to_string());

        let config = McpServerConfig::Stdio(crate::types::McpStdioConfig {
            transport_type: Some("stdio".to_string()),
            command: "npx".to_string(),
            args: Some(vec!["-y".to_string(), "some-server".to_string()]),
            env: Some(env),
        });

        match config {
            McpServerConfig::Stdio(s) => {
                assert_eq!(s.command, "npx");
                assert_eq!(s.args.unwrap().len(), 2);
            }
            _ => panic!("Expected Stdio variant"),
        }
    }

    #[test]
    fn test_mcp_tool_definition_creation() {
        let mcp_tool = sample_tool(
            "test_tool",
            Some("A test tool"),
            Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "arg1": { "type": "string" }
                },
                "required": ["arg1"]
            })),
        );

        let tool_def = create_mcp_tool_definition("myserver", &mcp_tool);
        assert_eq!(tool_def.name, "mcp__myserver__test_tool");
        assert_eq!(tool_def.description, "A test tool");
        assert_eq!(tool_def.input_schema.schema_type, "object");
        assert_eq!(tool_def.input_schema.properties["arg1"]["type"], "string");
        assert_eq!(tool_def.input_schema.required, Some(vec!["arg1".to_string()]));
    }

    #[test]
    fn test_mcp_tool_definition_defaults() {
        // No description and no schema: every fallback applies.
        let tool_def = create_mcp_tool_definition("srv", &sample_tool("bare", None, None));

        assert_eq!(tool_def.name, "mcp__srv__bare");
        assert_eq!(tool_def.description, "MCP tool: bare");
        assert_eq!(tool_def.input_schema.schema_type, "object");
        assert_eq!(tool_def.input_schema.properties, serde_json::json!({}));
        assert!(tool_def.input_schema.required.is_none());
    }

    #[tokio::test]
    async fn test_mcp_connection_status() {
        let conn = McpConnection {
            name: "test".to_string(),
            status: McpConnectionStatus::Connected,
            tools: vec![],
        };

        assert_eq!(conn.status, McpConnectionStatus::Connected);
    }

    #[tokio::test]
    async fn test_close_resets_connection() {
        let mut conn = McpConnection {
            name: "test".to_string(),
            status: McpConnectionStatus::Connected,
            tools: vec![create_mcp_tool_definition("test", &sample_tool("t", None, None))],
        };

        conn.close().await;

        assert_eq!(conn.status, McpConnectionStatus::Disconnected);
        assert!(conn.tools.is_empty());
    }

    #[tokio::test]
    async fn test_close_all_connections_empties_list() {
        let mut connections = vec![
            McpConnection {
                name: "a".to_string(),
                status: McpConnectionStatus::Connected,
                tools: vec![],
            },
            McpConnection {
                name: "b".to_string(),
                status: McpConnectionStatus::Error,
                tools: vec![],
            },
        ];

        close_all_connections(&mut connections).await;

        assert!(connections.is_empty());
    }

    #[tokio::test]
    async fn test_sse_config() {
        let config = McpServerConfig::Sse(crate::types::McpSseConfig {
            transport_type: "sse".to_string(),
            url: "http://localhost:3000/sse".to_string(),
            headers: None,
        });

        match config {
            McpServerConfig::Sse(s) => {
                assert_eq!(s.url, "http://localhost:3000/sse");
            }
            _ => panic!("Expected Sse variant"),
        }
    }

    #[tokio::test]
    async fn test_http_config() {
        let config = McpServerConfig::Http(crate::types::McpHttpConfig {
            transport_type: "http".to_string(),
            url: "http://localhost:3000/mcp".to_string(),
            headers: Some({
                let mut h = HashMap::new();
                h.insert("Authorization".to_string(), "Bearer token".to_string());
                h
            }),
        });

        match config {
            McpServerConfig::Http(h) => {
                assert_eq!(h.url, "http://localhost:3000/mcp");
                assert!(h.headers.is_some());
            }
            _ => panic!("Expected Http variant"),
        }
    }

    #[tokio::test]
    async fn test_unsupported_transports_report_error_status() {
        let configs = [
            McpServerConfig::Sse(crate::types::McpSseConfig {
                transport_type: "sse".to_string(),
                url: "http://localhost:3000/sse".to_string(),
                headers: None,
            }),
            McpServerConfig::Http(crate::types::McpHttpConfig {
                transport_type: "http".to_string(),
                url: "http://localhost:3000/mcp".to_string(),
                headers: None,
            }),
        ];

        for config in &configs {
            match connect_mcp_server("remote", config).await {
                Ok(conn) => {
                    assert_eq!(conn.name, "remote");
                    assert_eq!(conn.status, McpConnectionStatus::Error);
                    assert!(conn.tools.is_empty());
                }
                Err(_) => panic!("unsupported transports should yield an Error-status connection"),
            }
        }
    }
}