uira-orchestration 0.1.1

Agent definitions, SDK, tool registry, and hook implementations for Uira
Documentation
use crate::tools::provider::ToolProvider;
use crate::tools::{ToolContext, ToolError};
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use uira_core::schema::NamedMcpServerConfig;
use uira_core::{ToolOutput, ToolSpec};
use uira_mcp_client::{McpRuntimeManager, McpServerConfig};

/// Destination of a namespaced tool name: which MCP server to call and which
/// of that server's tools to invoke.
#[derive(Debug, Clone)]
struct ToolRoute {
    /// Server segment of the namespaced name (the part between `mcp__` and the next `__`).
    server_name: String,
    /// Tool segment of the namespaced name (everything after the server separator).
    tool_name: String,
}

/// Tool provider that forwards `mcp__`-prefixed tool calls to MCP servers
/// through an [`McpRuntimeManager`].
pub struct McpToolProvider {
    /// Runtime used to issue the actual tool calls to the configured servers.
    runtime: McpRuntimeManager,
    /// Tool specs handed in at construction and returned verbatim by `specs()`.
    specs: Vec<ToolSpec>,
    /// Maps a full namespaced tool name to the server/tool pair it resolves to.
    routes: Arc<RwLock<HashMap<String, ToolRoute>>>,
}

impl McpToolProvider {
    /// Timeout applied to the runtime manager via `with_rpc_timeout`.
    const RPC_TIMEOUT: Duration = Duration::from_secs(20);

    /// Creates a provider for the given MCP servers and tool specs.
    ///
    /// Each entry in `servers` is converted with [`McpServerConfig::from_command`];
    /// `default_cwd` is passed through to [`McpRuntimeManager::new`]. A route is
    /// registered for every spec whose name parses as `mcp__<server>__<tool>`;
    /// specs that do not parse are still kept in `specs` but get no route.
    ///
    /// # Errors
    ///
    /// Returns [`ToolError::InvalidInput`] carrying the conversion message for
    /// the first server whose configuration is rejected.
    pub fn new(
        servers: Vec<NamedMcpServerConfig>,
        specs: Vec<ToolSpec>,
        default_cwd: std::path::PathBuf,
    ) -> Result<Self, ToolError> {
        // Convert server definitions first so an invalid one aborts construction
        // before any other work is done.
        let mut runtime_configs = Vec::with_capacity(servers.len());
        for server in servers {
            let converted = McpServerConfig::from_command(
                server.name,
                server.config.command,
                server.config.args,
                server.config.env,
            )
            .map_err(|message| ToolError::InvalidInput { message })?;
            runtime_configs.push(converted);
        }

        // Index every parseable spec name by its full namespaced form.
        let mut routes = HashMap::new();
        for spec in &specs {
            if let Some(route) = parse_namespaced_tool_name(&spec.name) {
                routes.insert(spec.name.clone(), route);
            }
        }

        let runtime = McpRuntimeManager::new(runtime_configs, default_cwd)
            .with_rpc_timeout(Self::RPC_TIMEOUT);

        Ok(Self {
            runtime,
            specs,
            routes: Arc::new(RwLock::new(routes)),
        })
    }
}

#[async_trait]
impl ToolProvider for McpToolProvider {
    /// Returns a copy of the tool specs this provider was constructed with.
    fn specs(&self) -> Vec<ToolSpec> {
        self.specs.clone()
    }

    /// Claims every tool name that carries the `mcp__` prefix.
    ///
    /// This is a prefix check only: a claimed name without a registered route
    /// makes `execute` fail with [`ToolError::NotFound`].
    fn handles(&self, name: &str) -> bool {
        name.starts_with("mcp__")
    }

    /// Runs the MCP tool registered under `name` and converts its result.
    ///
    /// The call is issued through the runtime manager with `input` as the
    /// arguments and `ctx.cwd` as the working directory. The `text` fields of
    /// all `content` entries whose `type` is `"text"` are joined with newlines.
    /// If that text is non-empty it becomes the output; otherwise the whole
    /// result is returned as JSON.
    ///
    /// # Errors
    ///
    /// * [`ToolError::NotFound`] if no route is registered for `name`.
    /// * [`ToolError::ExecutionFailed`] if the runtime call fails, or if the
    ///   result has `isError: true`. In the latter case the message is the
    ///   joined text, or a generic description when the result has no text.
    async fn execute(
        &self,
        name: &str,
        input: Value,
        ctx: &ToolContext,
    ) -> Result<ToolOutput, ToolError> {
        // The read guard is a temporary of this statement, so the lock is
        // released before the (potentially slow) tool call below.
        let route =
            self.routes
                .read()
                .await
                .get(name)
                .cloned()
                .ok_or_else(|| ToolError::NotFound {
                    name: name.to_string(),
                })?;

        let result = self
            .runtime
            .call_tool(&route.server_name, &route.tool_name, input, &ctx.cwd)
            .await
            .map_err(|e| ToolError::ExecutionFailed {
                message: e.to_string(),
            })?;

        let is_error = result
            .get("isError")
            .and_then(Value::as_bool)
            .unwrap_or(false);

        // Borrow the text fragments straight out of `result`; a missing or
        // non-array `content` simply yields no fragments.
        let text = result
            .get("content")
            .and_then(Value::as_array)
            .into_iter()
            .flatten()
            .filter(|entry| entry.get("type").and_then(Value::as_str) == Some("text"))
            .filter_map(|entry| entry.get("text").and_then(Value::as_str))
            .collect::<Vec<_>>()
            .join("\n");

        if is_error {
            // An error result without text content would otherwise surface as
            // an empty, undiagnosable message.
            let message = if text.is_empty() {
                format!(
                    "MCP tool `{}` on server `{}` reported an error without a text message",
                    route.tool_name, route.server_name
                )
            } else {
                text
            };
            return Err(ToolError::ExecutionFailed { message });
        }

        if text.is_empty() {
            Ok(ToolOutput::json(result))
        } else {
            Ok(ToolOutput::text(text))
        }
    }
}

/// Splits a namespaced tool name of the form `mcp__<server>__<tool>` into a route.
///
/// The server segment ends at the first `__` after the prefix, so the tool
/// segment may itself contain `__`.
///
/// Returns `None` when the `mcp__` prefix or the separator is missing, or when
/// either segment is empty (such a name cannot identify a server/tool pair).
fn parse_namespaced_tool_name(name: &str) -> Option<ToolRoute> {
    let (server_name, tool_name) = name.strip_prefix("mcp__")?.split_once("__")?;
    if server_name.is_empty() || tool_name.is_empty() {
        return None;
    }
    Some(ToolRoute {
        server_name: server_name.to_string(),
        tool_name: tool_name.to_string(),
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed name splits into its server and tool segments.
    #[test]
    fn parse_namespaced_name() {
        let route = parse_namespaced_tool_name("mcp__filesystem__read_file").unwrap();
        assert_eq!(route.server_name, "filesystem");
        assert_eq!(route.tool_name, "read_file");
    }

    /// Only the first `__` after the prefix separates server from tool.
    #[test]
    fn parse_keeps_later_separators_in_tool_name() {
        let route = parse_namespaced_tool_name("mcp__fs__read__file").unwrap();
        assert_eq!(route.server_name, "fs");
        assert_eq!(route.tool_name, "read__file");
    }

    /// Names without the `mcp__` prefix are not routed.
    #[test]
    fn parse_rejects_missing_prefix() {
        assert!(parse_namespaced_tool_name("filesystem__read_file").is_none());
    }

    /// Names with a prefix but no server/tool separator are not routed.
    #[test]
    fn parse_rejects_missing_separator() {
        assert!(parse_namespaced_tool_name("mcp__filesystem").is_none());
    }
}