sofos 0.2.1

An interactive AI coding agent for your terminal
use crate::error::{Result, SofosError};
use crate::mcp::config::McpServerConfig;
use crate::mcp::protocol::*;
use serde_json::Value;
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Write};
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};

/// Builds the payload for the MCP `initialize` request.
///
/// The request advertises protocol version `2024-11-05`, the `roots`
/// capability with `list_changed` enabled, no `sampling` capability, and
/// identifies this client as `sofos` at the crate's compile-time version.
fn create_init_request() -> InitializeRequest {
    let capabilities = ClientCapabilities {
        roots: Some(RootsCapability { list_changed: true }),
        sampling: None,
    };

    let client_info = ClientInfo {
        name: String::from("sofos"),
        // Resolved by Cargo at compile time from the package manifest.
        version: String::from(env!("CARGO_PKG_VERSION")),
    };

    InitializeRequest {
        protocol_version: String::from("2024-11-05"),
        capabilities,
        client_info,
    }
}

fn parse_list_tools_response(result: Value) -> Result<Vec<McpTool>> {
    let list_result: ListToolsResult = serde_json::from_value(result)?;
    Ok(list_result.tools)
}

/// Decodes the `result` value of a `tools/call` response.
///
/// # Errors
///
/// Returns an error when `result` cannot be deserialized into a
/// [`CallToolResult`].
fn parse_call_tool_response(result: Value) -> Result<CallToolResult> {
    let decoded = serde_json::from_value::<CallToolResult>(result)?;
    Ok(decoded)
}

/// Builds the parameters of a `tools/call` request for the tool `name`,
/// forwarding `arguments` unchanged.
fn create_call_tool_request(name: &str, arguments: Option<Value>) -> CallToolRequest {
    let name = name.to_owned();
    CallToolRequest { name, arguments }
}

// Shared body of `list_tools` for every transport: issues `tools/list`
// through the receiver's `send_request` and decodes the tool list.
// Must be expanded inside an `async` function returning `Result<Vec<McpTool>>`.
macro_rules! impl_list_tools {
    ($self:ident) => {
        parse_list_tools_response($self.send_request("tools/list", None).await?)
    };
}

// Shared body of `call_tool` for every transport: serializes the call
// parameters, issues `tools/call` through the receiver's `send_request`
// and decodes the tool result.
// Must be expanded inside an `async` function returning `Result<CallToolResult>`.
macro_rules! impl_call_tool {
    ($self:ident, $name:ident, $arguments:ident) => {{
        let call = create_call_tool_request($name, $arguments);
        let params = serde_json::to_value(call)?;
        let raw = $self.send_request("tools/call", Some(params)).await?;
        parse_call_tool_response(raw)
    }};
}

/// An MCP client bound to one server, over one of the supported transports.
pub enum McpClient {
    /// Client backed by a [`StdioClient`].
    Stdio(StdioClient),
    /// Client backed by an [`HttpClient`].
    Http(HttpClient),
}

impl McpClient {
    /// Creates a client for the server called `name`, picking the transport
    /// from `config`.
    ///
    /// A stdio configuration takes precedence over an HTTP one.
    ///
    /// # Errors
    ///
    /// Returns [`SofosError::McpError`] when `config` describes neither a
    /// stdio nor an HTTP server, and propagates any error raised while the
    /// chosen transport is being set up.
    pub async fn connect(name: String, config: McpServerConfig) -> Result<Self> {
        if config.is_stdio() {
            return StdioClient::new(name, config)
                .await
                .map(McpClient::Stdio);
        }

        if config.is_http() {
            return HttpClient::new(name, config).await.map(McpClient::Http);
        }

        Err(SofosError::McpError(String::from(
            "Invalid MCP server configuration",
        )))
    }

    /// Lists the tools offered by the server, delegating to the underlying
    /// transport.
    pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
        match self {
            Self::Stdio(stdio) => stdio.list_tools().await,
            Self::Http(http) => http.list_tools().await,
        }
    }

    /// Invokes the tool `name` with the optional `arguments`, delegating to
    /// the underlying transport.
    pub async fn call_tool(&self, name: &str, arguments: Option<Value>) -> Result<CallToolResult> {
        match self {
            Self::Stdio(stdio) => stdio.call_tool(name, arguments).await,
            Self::Http(http) => http.call_tool(name, arguments).await,
        }
    }
}

/// MCP client that talks to a server running as a child process, exchanging
/// newline-delimited JSON-RPC messages over the child's stdin and stdout.
pub struct StdioClient {
    /// Name of the server, used in error messages.
    server_name: String,
    /// Handle to the spawned server process, owned for as long as the client
    /// exists.
    _process: Arc<Mutex<Child>>,
    /// Write end of the pipe connected to the server's stdin.
    stdin: Arc<Mutex<ChildStdin>>,
    /// Buffered read end of the pipe connected to the server's stdout.
    stdout: Arc<Mutex<BufReader<ChildStdout>>>,
    /// Source of JSON-RPC request ids; the first id handed out is 1.
    next_id: Arc<AtomicU64>,
}

impl StdioClient {
    /// Spawns the MCP server described by `config` and performs the
    /// initialization handshake with it.
    ///
    /// The child is started with piped stdin/stdout; its stderr is discarded.
    /// Environment variables from `config.env` are added to the child's
    /// environment.
    ///
    /// # Errors
    ///
    /// Returns [`SofosError::McpError`] when `config.command` is missing, the
    /// process cannot be spawned, or its pipes are unavailable, and
    /// propagates any handshake failure. If the handshake fails the child
    /// process is killed and reaped before the error is returned.
    pub async fn new(server_name: String, config: McpServerConfig) -> Result<Self> {
        let command = config
            .command
            .ok_or_else(|| SofosError::McpError("Missing command for stdio server".to_string()))?;

        let args = config.args.unwrap_or_default();
        let env_vars = config.env.unwrap_or_default();

        let mut cmd = Command::new(&command);
        cmd.args(&args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::null());

        for (key, value) in env_vars {
            cmd.env(key, value);
        }

        let mut process = cmd.spawn().map_err(|e| {
            SofosError::McpError(format!(
                "Failed to start MCP server '{}': {}",
                server_name, e
            ))
        })?;

        let stdin = process.stdin.take().ok_or_else(|| {
            SofosError::McpError(format!(
                "Failed to get stdin for MCP server '{}'",
                server_name
            ))
        })?;

        let stdout = process.stdout.take().ok_or_else(|| {
            SofosError::McpError(format!(
                "Failed to get stdout for MCP server '{}'",
                server_name
            ))
        })?;

        let client = Self {
            server_name,
            _process: Arc::new(Mutex::new(process)),
            stdin: Arc::new(Mutex::new(stdin)),
            stdout: Arc::new(Mutex::new(BufReader::new(stdout))),
            next_id: Arc::new(AtomicU64::new(1)),
        };

        if let Err(err) = client.initialize().await {
            // Dropping a `Child` neither kills nor waits for it, so a server
            // that failed the handshake would otherwise be left behind.
            // Cleanup is best effort: the handshake error is what matters.
            if let Ok(mut child) = client._process.lock() {
                let _ = child.kill();
                let _ = child.wait();
            }
            return Err(err);
        }

        Ok(client)
    }

    /// Runs the MCP handshake: sends `initialize`, checks that the reply
    /// deserializes as an [`InitializeResult`], then sends the
    /// `notifications/initialized` notification.
    async fn initialize(&self) -> Result<()> {
        let response = self
            .send_request(
                "initialize",
                Some(serde_json::to_value(create_init_request())?),
            )
            .await?;

        let _init_result: InitializeResult = serde_json::from_value(response)?;

        self.send_notification("notifications/initialized", None)
            .await?;

        Ok(())
    }

    /// Sends a JSON-RPC request as a single line on the server's stdin and
    /// returns the `result` of the response read from its stdout.
    ///
    /// Blank lines and server-initiated messages (anything carrying a
    /// `method` member, i.e. notifications and requests) that arrive before
    /// the response are skipped.
    ///
    /// # Errors
    ///
    /// Returns [`SofosError::McpError`] when a lock is poisoned, the pipes
    /// fail, the server closes its output, a line is not valid JSON-RPC, the
    /// response carries an `error`, or it carries no `result`.
    async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value> {
        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        let request = JsonRpcRequest::new(id, method.to_string(), params);

        let request_json = serde_json::to_string(&request)?;

        // Take the reader lock *before* writing and keep it until the
        // response has been read. This makes each write/read pair atomic, so
        // concurrent callers cannot consume each other's responses.
        // Lock order is always stdout -> stdin; `send_notification` only ever
        // takes stdin, so no cycle is possible.
        let mut stdout = self
            .stdout
            .lock()
            .map_err(|e| SofosError::McpError(format!("Failed to lock stdout: {}", e)))?;

        {
            let mut stdin = self
                .stdin
                .lock()
                .map_err(|e| SofosError::McpError(format!("Failed to lock stdin: {}", e)))?;
            writeln!(stdin, "{}", request_json).map_err(|e| {
                SofosError::McpError(format!(
                    "Failed to write to MCP server '{}': {}",
                    self.server_name, e
                ))
            })?;
            stdin.flush().map_err(|e| {
                SofosError::McpError(format!(
                    "Failed to flush stdin for MCP server '{}': {}",
                    self.server_name, e
                ))
            })?;
        }

        let parse_error = |e: serde_json::Error| {
            SofosError::McpError(format!(
                "Failed to parse response from MCP server '{}': {}",
                self.server_name, e
            ))
        };

        let mut response_line = String::new();
        let response: JsonRpcResponse = loop {
            response_line.clear();
            let bytes_read = stdout.read_line(&mut response_line).map_err(|e| {
                SofosError::McpError(format!(
                    "Failed to read from MCP server '{}': {}",
                    self.server_name, e
                ))
            })?;

            // Zero bytes means EOF: the server exited or closed its stdout.
            if bytes_read == 0 {
                return Err(SofosError::McpError(format!(
                    "MCP server '{}' closed its output before responding",
                    self.server_name
                )));
            }

            let line = response_line.trim();
            if line.is_empty() {
                continue;
            }

            let message: Value = serde_json::from_str(line).map_err(&parse_error)?;

            // A JSON-RPC response never has a `method` member; a message
            // that does is a server-initiated notification or request and is
            // not the reply being waited for.
            if message.get("method").is_some() {
                continue;
            }

            break serde_json::from_value(message).map_err(&parse_error)?;
        };

        if let Some(error) = response.error {
            return Err(SofosError::McpError(format!(
                "MCP server '{}' returned error: {}",
                self.server_name, error.message
            )));
        }

        response.result.ok_or_else(|| {
            SofosError::McpError(format!(
                "MCP server '{}' returned no result",
                self.server_name
            ))
        })
    }

    /// Sends a JSON-RPC notification (a message without an `id`, for which no
    /// response is read) as a single line on the server's stdin.
    ///
    /// `params` is included in the message only when it is `Some`.
    ///
    /// # Errors
    ///
    /// Returns [`SofosError::McpError`] when the stdin lock is poisoned or
    /// writing to the pipe fails.
    async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
        let mut notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
        });
        if let Some(params) = params {
            notification["params"] = params;
        }

        let notification_json = serde_json::to_string(&notification)?;

        let mut stdin = self
            .stdin
            .lock()
            .map_err(|e| SofosError::McpError(format!("Failed to lock stdin: {}", e)))?;
        writeln!(stdin, "{}", notification_json).map_err(|e| {
            SofosError::McpError(format!(
                "Failed to write notification to MCP server '{}': {}",
                self.server_name, e
            ))
        })?;
        stdin.flush().map_err(|e| {
            SofosError::McpError(format!(
                "Failed to flush stdin for MCP server '{}': {}",
                self.server_name, e
            ))
        })?;

        Ok(())
    }

    /// Lists the tools offered by the server via `tools/list`.
    pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
        impl_list_tools!(self)
    }

    /// Invokes the tool `name` with the optional `arguments` via `tools/call`.
    pub async fn call_tool(&self, name: &str, arguments: Option<Value>) -> Result<CallToolResult> {
        impl_call_tool!(self, name, arguments)
    }
}

/// MCP client that talks to a server by POSTing JSON-RPC requests to a URL.
pub struct HttpClient {
    /// Name of the server, used in error messages.
    server_name: String,
    /// Endpoint every request is POSTed to.
    url: String,
    /// Extra HTTP headers attached to every request.
    headers: HashMap<String, String>,
    /// HTTP client used to send the requests.
    client: reqwest::Client,
    /// Source of JSON-RPC request ids; the first id handed out is 1.
    next_id: Arc<AtomicU64>,
}

impl HttpClient {
    pub async fn new(server_name: String, config: McpServerConfig) -> Result<Self> {
        let url = config
            .url
            .ok_or_else(|| SofosError::McpError("Missing URL for HTTP server".to_string()))?;

        let headers = config.headers.unwrap_or_default();

        let client = reqwest::Client::new();

        let http_client = Self {
            server_name: server_name.clone(),
            url,
            headers,
            client,
            next_id: Arc::new(AtomicU64::new(1)),
        };

        http_client.initialize().await?;

        Ok(http_client)
    }

    async fn initialize(&self) -> Result<()> {
        let response = self
            .send_request(
                "initialize",
                Some(serde_json::to_value(create_init_request())?),
            )
            .await?;

        let _init_result: InitializeResult = serde_json::from_value(response)?;

        Ok(())
    }

    async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value> {
        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        let request = JsonRpcRequest::new(id, method.to_string(), params);

        let mut req = self.client.post(&self.url).json(&request);

        for (key, value) in &self.headers {
            req = req.header(key, value);
        }

        let response = req.send().await.map_err(|e| {
            SofosError::McpError(format!(
                "Failed to send request to MCP server '{}': {}",
                self.server_name, e
            ))
        })?;

        let response_json: JsonRpcResponse = response.json().await.map_err(|e| {
            SofosError::McpError(format!(
                "Failed to parse response from MCP server '{}': {}",
                self.server_name, e
            ))
        })?;

        if let Some(error) = response_json.error {
            return Err(SofosError::McpError(format!(
                "MCP server '{}' returned error: {}",
                self.server_name, error.message
            )));
        }

        response_json.result.ok_or_else(|| {
            SofosError::McpError(format!(
                "MCP server '{}' returned no result",
                self.server_name
            ))
        })
    }

    pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
        impl_list_tools!(self)
    }

    pub async fn call_tool(&self, name: &str, arguments: Option<Value>) -> Result<CallToolResult> {
        impl_call_tool!(self, name, arguments)
    }
}