1use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::{anyhow, Result};
10use async_trait::async_trait;
11use futures::future::BoxFuture;
12use parking_lot::RwLock;
13use rs_utcp::providers::base::Provider;
14use rs_utcp::providers::cli::CliProvider;
15use rs_utcp::tools::Tool as UtcpTool;
16use rs_utcp::transports::stream::StreamResult;
17use rs_utcp::transports::CommunicationProtocol;
18use serde_json::Value;
19
20pub type InProcessHandler =
22 Arc<dyn Fn(HashMap<String, Value>) -> BoxFuture<'static, Result<Value>> + Send + Sync>;
23
24#[derive(Clone)]
26pub struct InProcessTool {
27 pub spec: UtcpTool,
28 pub handler: InProcessHandler,
29}
30
31pub struct AgentCliTransport {
36 inner: Arc<dyn CommunicationProtocol>,
37 tools: RwLock<HashMap<String, Vec<InProcessTool>>>,
38}
39
40impl AgentCliTransport {
41 pub fn new(inner: Arc<dyn CommunicationProtocol>) -> Self {
42 Self {
43 inner,
44 tools: RwLock::new(HashMap::new()),
45 }
46 }
47
48 pub fn register(&self, provider: &str, tool: InProcessTool) {
49 let mut guard = self.tools.write();
50 guard.entry(provider.to_string()).or_default().push(tool);
51 }
52
53 fn lookup_handler(&self, provider: &str, tool_name: &str) -> Option<InProcessHandler> {
54 let guard = self.tools.read();
55 let list = guard.get(provider)?;
56 let handler = list.iter().find(|t| {
57 t.spec.name == tool_name
58 || t.spec
59 .name
60 .rsplit('.')
61 .next()
62 .map(|suffix| suffix == tool_name)
63 .unwrap_or(false)
64 })?;
65 Some(handler.handler.clone())
66 }
67
68 fn specs_for(&self, provider: &str) -> Option<Vec<UtcpTool>> {
69 let guard = self.tools.read();
70 guard
71 .get(provider)
72 .map(|tools| tools.iter().map(|t| t.spec.clone()).collect())
73 }
74}
75
76#[async_trait]
77impl CommunicationProtocol for AgentCliTransport {
78 async fn register_tool_provider(&self, prov: &dyn Provider) -> Result<Vec<UtcpTool>> {
79 if let Some(cli) = prov.as_any().downcast_ref::<CliProvider>() {
80 if let Some(specs) = self.specs_for(&cli.base.name) {
81 return Ok(specs);
82 }
83 }
84 self.inner.register_tool_provider(prov).await
85 }
86
87 async fn deregister_tool_provider(&self, prov: &dyn Provider) -> Result<()> {
88 if let Some(cli) = prov.as_any().downcast_ref::<CliProvider>() {
89 if self.tools.write().remove(&cli.base.name).is_some() {
90 return Ok(());
91 }
92 }
93 self.inner.deregister_tool_provider(prov).await
94 }
95
96 async fn call_tool(
97 &self,
98 tool_name: &str,
99 args: HashMap<String, Value>,
100 prov: &dyn Provider,
101 ) -> Result<Value> {
102 if let Some(cli) = prov.as_any().downcast_ref::<CliProvider>() {
103 if let Some(handler) = self.lookup_handler(&cli.base.name, tool_name) {
104 return handler(args).await;
105 }
106 }
107 self.inner.call_tool(tool_name, args, prov).await
108 }
109
110 async fn call_tool_stream(
111 &self,
112 tool_name: &str,
113 args: HashMap<String, Value>,
114 prov: &dyn Provider,
115 ) -> Result<Box<dyn StreamResult>> {
116 if let Some(cli) = prov.as_any().downcast_ref::<CliProvider>() {
117 if self.tools.read().contains_key(&cli.base.name) {
118 return Err(anyhow!(
119 "Streaming not supported for in-process tool {}",
120 tool_name
121 ));
122 }
123 }
124 self.inner.call_tool_stream(tool_name, args, prov).await
125 }
126}
127
128pub fn ensure_agent_cli_transport() -> Arc<AgentCliTransport> {
130 use std::sync::OnceLock;
131
132 static TRANSPORT: OnceLock<Arc<AgentCliTransport>> = OnceLock::new();
133
134 TRANSPORT
135 .get_or_init(|| {
136 let snapshot = rs_utcp::transports::communication_protocols_snapshot();
137 let fallback = snapshot
138 .get("cli")
139 .unwrap_or_else(|| Arc::new(rs_utcp::transports::cli::CliTransport::new()));
140
141 let shim = Arc::new(AgentCliTransport::new(fallback));
142 rs_utcp::transports::register_communication_protocol("cli", shim.clone());
144 shim
145 })
146 .clone()
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Installing the shim succeeds.
    ///
    /// A provider that never registered in-process tools yields no local specs.
    #[test]
    fn agent_cli_transport_initializes() {
        let shim = ensure_agent_cli_transport();
        let local_specs = shim.specs_for("nonexistent");
        assert!(local_specs.is_none());
    }
}
158}