Skip to main content

rs_agent/
agent_tool.rs

//! UTCP tool integration and adapters
//!
//! This module provides the bridge between rs-agent and UTCP tools, matching the
//! structure from go-agent's agent_tool.go.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10use async_trait::async_trait;
11use futures::future::BoxFuture;
12use parking_lot::RwLock;
13use rs_utcp::providers::base::Provider;
14use rs_utcp::providers::cli::CliProvider;
15use rs_utcp::tools::Tool as UtcpTool;
16use rs_utcp::transports::stream::StreamResult;
17use rs_utcp::transports::CommunicationProtocol;
18use serde_json::Value;
19
20/// Handler type for in-process UTCP tools.
21pub type InProcessHandler =
22    Arc<dyn Fn(HashMap<String, Value>) -> BoxFuture<'static, Result<Value>> + Send + Sync>;
23
24/// UTCP tool paired with an in-process handler.
25#[derive(Clone)]
26pub struct InProcessTool {
27    pub spec: UtcpTool,
28    pub handler: InProcessHandler,
29}
30
31/// Transport shim that routes CLI providers to in-process handlers while
32/// delegating everything else to the original transport.
33///
34/// This matches go-agent's agentCLITransport structure.
35pub struct AgentCliTransport {
36    inner: Arc<dyn CommunicationProtocol>,
37    tools: RwLock<HashMap<String, Vec<InProcessTool>>>,
38}
39
40impl AgentCliTransport {
41    pub fn new(inner: Arc<dyn CommunicationProtocol>) -> Self {
42        Self {
43            inner,
44            tools: RwLock::new(HashMap::new()),
45        }
46    }
47
48    pub fn register(&self, provider: &str, tool: InProcessTool) {
49        let mut guard = self.tools.write();
50        guard.entry(provider.to_string()).or_default().push(tool);
51    }
52
53    fn lookup_handler(&self, provider: &str, tool_name: &str) -> Option<InProcessHandler> {
54        let guard = self.tools.read();
55        let list = guard.get(provider)?;
56        let handler = list.iter().find(|t| {
57            t.spec.name == tool_name
58                || t.spec
59                    .name
60                    .rsplit('.')
61                    .next()
62                    .map(|suffix| suffix == tool_name)
63                    .unwrap_or(false)
64        })?;
65        Some(handler.handler.clone())
66    }
67
68    fn specs_for(&self, provider: &str) -> Option<Vec<UtcpTool>> {
69        let guard = self.tools.read();
70        guard
71            .get(provider)
72            .map(|tools| tools.iter().map(|t| t.spec.clone()).collect())
73    }
74}
75
76#[async_trait]
77impl CommunicationProtocol for AgentCliTransport {
78    async fn register_tool_provider(&self, prov: &dyn Provider) -> Result<Vec<UtcpTool>> {
79        if let Some(cli) = prov.as_any().downcast_ref::<CliProvider>() {
80            if let Some(specs) = self.specs_for(&cli.base.name) {
81                return Ok(specs);
82            }
83        }
84        self.inner.register_tool_provider(prov).await
85    }
86
87    async fn deregister_tool_provider(&self, prov: &dyn Provider) -> Result<()> {
88        if let Some(cli) = prov.as_any().downcast_ref::<CliProvider>() {
89            if self.tools.write().remove(&cli.base.name).is_some() {
90                return Ok(());
91            }
92        }
93        self.inner.deregister_tool_provider(prov).await
94    }
95
96    async fn call_tool(
97        &self,
98        tool_name: &str,
99        args: HashMap<String, Value>,
100        prov: &dyn Provider,
101    ) -> Result<Value> {
102        if let Some(cli) = prov.as_any().downcast_ref::<CliProvider>() {
103            if let Some(handler) = self.lookup_handler(&cli.base.name, tool_name) {
104                return handler(args).await;
105            }
106        }
107        self.inner.call_tool(tool_name, args, prov).await
108    }
109
110    async fn call_tool_stream(
111        &self,
112        tool_name: &str,
113        args: HashMap<String, Value>,
114        prov: &dyn Provider,
115    ) -> Result<Box<dyn StreamResult>> {
116        if let Some(cli) = prov.as_any().downcast_ref::<CliProvider>() {
117            if self.tools.read().contains_key(&cli.base.name) {
118                return Err(anyhow!(
119                    "Streaming not supported for in-process tool {}",
120                    tool_name
121                ));
122            }
123        }
124        self.inner.call_tool_stream(tool_name, args, prov).await
125    }
126}
127
128/// Register (or retrieve) the global agent CLI transport, ensuring it replaces the default CLI transport.
129pub fn ensure_agent_cli_transport() -> Arc<AgentCliTransport> {
130    use std::sync::OnceLock;
131
132    static TRANSPORT: OnceLock<Arc<AgentCliTransport>> = OnceLock::new();
133
134    TRANSPORT
135        .get_or_init(|| {
136            let snapshot = rs_utcp::transports::communication_protocols_snapshot();
137            let fallback = snapshot
138                .get("cli")
139                .unwrap_or_else(|| Arc::new(rs_utcp::transports::cli::CliTransport::new()));
140
141            let shim = Arc::new(AgentCliTransport::new(fallback));
142            // Replace the global CLI transport so existing clients pick up the shim.
143            rs_utcp::transports::register_communication_protocol("cli", shim.clone());
144            shim
145        })
146        .clone()
147}
148
#[cfg(test)]
mod tests {
    use super::ensure_agent_cli_transport;

    #[test]
    fn agent_cli_transport_initializes() {
        let shim = ensure_agent_cli_transport();
        let specs = shim.specs_for("nonexistent");
        assert!(
            specs.is_none(),
            "an unregistered provider must not expose in-process specs"
        );
    }
}