roboticus-agent 0.11.3

Agent core with ReAct loop, policy engine, injection defense, memory system, and skill loader
Documentation
//! Bridge: wraps MCP tools as `Capability` impls for the CapabilityRegistry.
//!
//! Each discovered MCP tool becomes an `McpCapability` that routes
//! `execute()` calls to `tools/call` on the originating server.

use std::sync::Arc;

use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::RwLock;

use roboticus_core::RiskLevel;
use roboticus_core::config::McpTransport;

use crate::capability::{Capability, CapabilitySource};
use crate::tools::{ToolContext, ToolError, ToolResult};

use super::client::{DiscoveredTool, LiveMcpConnection};

/// A `Capability` backed by a remote MCP tool.
///
/// Holds a shared reference to the `LiveMcpConnection` so that tool calls
/// are routed to the correct server. Multiple `McpCapability` instances
/// may share the same connection (one per discovered tool on that server).
pub struct McpCapability {
    /// Fully qualified name: `{server}::{tool}`.
    prefixed_name: String,
    /// Name of the originating server; reported in `source()`, in error
    /// messages, and in the `mcp_server` metadata of successful results.
    server_name: String,
    /// Un-prefixed tool name, i.e. the name passed to `call_tool` on the connection.
    tool_name: String,
    /// Description copied from the discovered tool.
    description: String,
    /// Input schema copied from the discovered tool; returned by `parameters_schema()`.
    input_schema: Value,
    /// Transport reported in `CapabilitySource::Mcp`.
    transport: McpTransport,
    /// Risk level reported to callers; `new` sets `RiskLevel::Caution` and
    /// `with_risk_level` overrides it.
    risk_level: RiskLevel,
    /// Shared connection to the originating server.
    connection: Arc<RwLock<LiveMcpConnection>>,
}

impl McpCapability {
    /// Builds the capability for `tool` as discovered on the server `server_name`.
    ///
    /// The registry-facing name is `{server_name}::{tool name}`. The description
    /// and input schema are cloned from `tool`, and the risk level starts at
    /// `RiskLevel::Caution` (see [`Self::with_risk_level`] to change it).
    /// `connection` is the shared handle later used to route `execute()` calls.
    pub fn new(
        server_name: &str,
        tool: &DiscoveredTool,
        transport: McpTransport,
        connection: Arc<RwLock<LiveMcpConnection>>,
    ) -> Self {
        let tool_name = tool.name.clone();
        let prefixed_name = format!("{server_name}::{tool_name}");

        Self {
            prefixed_name,
            server_name: server_name.to_owned(),
            tool_name,
            description: tool.description.clone(),
            input_schema: tool.input_schema.clone(),
            transport,
            risk_level: RiskLevel::Caution,
            connection,
        }
    }

    /// Returns the capability with its risk level replaced by `level`.
    pub fn with_risk_level(self, level: RiskLevel) -> Self {
        Self {
            risk_level: level,
            ..self
        }
    }
}

#[async_trait]
impl Capability for McpCapability {
    fn name(&self) -> &str {
        &self.prefixed_name
    }

    fn description(&self) -> &str {
        &self.description
    }

    fn risk_level(&self) -> RiskLevel {
        self.risk_level
    }

    fn parameters_schema(&self) -> Value {
        self.input_schema.clone()
    }

    fn source(&self) -> CapabilitySource {
        CapabilitySource::Mcp {
            server: self.server_name.clone(),
            transport: self.transport.clone(),
        }
    }

    async fn execute(&self, params: Value, _ctx: &ToolContext) -> Result<ToolResult, ToolError> {
        let conn = self.connection.read().await;
        if !conn.is_alive() {
            return Err(ToolError {
                message: format!("MCP server '{}' is not connected", self.server_name),
            });
        }

        let result = conn
            .call_tool(&self.tool_name, params)
            .await
            .map_err(|e| ToolError {
                message: format!("MCP tool '{}' call failed: {e}", self.prefixed_name),
            })?;

        // result is a JSON value like {"content": "...", "is_error": false}
        let is_error = result
            .get("is_error")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);
        let content = result
            .get("content")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();

        if is_error {
            Err(ToolError {
                message: format!("MCP tool error: {content}"),
            })
        } else {
            Ok(ToolResult {
                output: content,
                metadata: Some(serde_json::json!({
                    "mcp_server": self.server_name,
                    "mcp_tool": self.tool_name,
                })),
            })
        }
    }
}

/// Wraps every tool in `tools` as an `McpCapability` bound to `connection`.
///
/// Each capability gets its own clone of `transport` and of the connection
/// handle, and is named `{server_name}::{tool name}`. The returned `Vec`
/// follows the order of `tools` and is ready to be registered with the
/// `CapabilityRegistry`.
pub fn bridge_tools(
    server_name: &str,
    tools: &[DiscoveredTool],
    transport: McpTransport,
    connection: Arc<RwLock<LiveMcpConnection>>,
) -> Vec<McpCapability> {
    let to_capability = |discovered: &DiscoveredTool| {
        let shared = Arc::clone(&connection);
        McpCapability::new(server_name, discovered, transport.clone(), shared)
    };

    tools.iter().map(to_capability).collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    use crate::mcp::client::test_support;
    use crate::tools::{ToolContext, ToolSandboxSnapshot};
    use roboticus_core::InputAuthority;

    /// Builds a `DiscoveredTool` with the given name and description and a
    /// minimal `{"type": "object"}` input schema.
    fn make_tool(name: &str, desc: &str) -> DiscoveredTool {
        DiscoveredTool {
            name: name.into(),
            description: desc.into(),
            input_schema: serde_json::json!({"type": "object"}),
        }
    }

    /// Builds a `ToolContext` for tests: `Creator` authority, workspace rooted
    /// at the current directory, no channel or database, default sandbox.
    fn test_ctx() -> ToolContext {
        ToolContext {
            session_id: "test-session".into(),
            agent_id: "test-agent".into(),
            agent_name: "test-agent".into(),
            authority: InputAuthority::Creator,
            workspace_root: std::env::current_dir().unwrap(),
            tool_allowed_paths: vec![],
            channel: None,
            db: None,
            sandbox: ToolSandboxSnapshot::default(),
        }
    }

    // The next two tests need no connection: they only pin the `{server}::{tool}`
    // name format. The async tests further down exercise the real bridge code
    // against the connection returned by `test_support::echo_connection`.

    #[test]
    fn bridge_tools_produces_correct_names() {
        // Build the DiscoveredTool list and check the expected prefixed names
        // for it.
        let tools = [
            make_tool("create_issue", "Create a GitHub issue"),
            make_tool("list_repos", "List repositories"),
        ];

        // NOTE: this mirrors the format string used by `McpCapability::new`
        // rather than calling `bridge_tools` itself.
        let name = format!("github::{}", tools[0].name);
        assert_eq!(name, "github::create_issue");

        let name2 = format!("github::{}", tools[1].name);
        assert_eq!(name2, "github::list_repos");
    }

    #[test]
    fn prefixed_name_uses_double_colon() {
        // Pins the `::` separator of the prefixed-name format.
        let name = format!("{}::{}", "linear", "create_ticket");
        assert!(name.contains("::"));
        assert_eq!(name, "linear::create_ticket");
    }

    #[tokio::test]
    async fn bridge_tools_builds_capabilities_with_expected_metadata() {
        let (conn, server_handle) = test_support::echo_connection("remote-test").await.unwrap();
        let conn = Arc::new(RwLock::new(conn));
        // Scope the read guard so it is released before the assertions run.
        let caps = {
            let read = conn.read().await;
            bridge_tools(
                "remote-test",
                read.tools(),
                McpTransport::Sse,
                Arc::clone(&conn),
            )
        };

        // The echo connection is expected to expose exactly one tool, `echo`.
        assert_eq!(caps.len(), 1);
        let cap = &caps[0];
        assert_eq!(cap.name(), "remote-test::echo");
        assert_eq!(cap.description(), "Echo back the provided text");
        assert_eq!(cap.parameters_schema()["type"], "object");
        match cap.source() {
            CapabilitySource::Mcp { server, transport } => {
                assert_eq!(server, "remote-test");
                assert!(matches!(transport, McpTransport::Sse));
            }
            other => panic!("expected MCP source, got {other:?}"),
        }

        // Shut the test server down; the join result (presumably a
        // cancellation error) is intentionally ignored.
        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn mcp_capability_executes_remote_tool_and_returns_metadata() {
        let (conn, server_handle) = test_support::echo_connection("remote-test").await.unwrap();
        let conn = Arc::new(RwLock::new(conn));
        // Clone the tool out so the read guard is dropped before `execute`
        // takes its own read lock.
        let tool = {
            let read = conn.read().await;
            read.tools()[0].clone()
        };
        let cap = McpCapability::new("remote-test", &tool, McpTransport::Sse, Arc::clone(&conn))
            .with_risk_level(RiskLevel::Dangerous);

        let result = cap
            .execute(serde_json::json!({ "text": "hello bridge" }), &test_ctx())
            .await
            .unwrap();
        assert_eq!(cap.risk_level(), RiskLevel::Dangerous);
        assert_eq!(result.output, "hello bridge");
        assert_eq!(
            result.metadata.as_ref().unwrap()["mcp_server"],
            "remote-test"
        );
        assert_eq!(result.metadata.as_ref().unwrap()["mcp_tool"], "echo");

        // Shut the test server down; the join result (presumably a
        // cancellation error) is intentionally ignored.
        server_handle.abort();
        let _ = server_handle.await;
    }
}