rs_utcp/
lib.rs

1pub mod auth;
2pub mod call_templates;
3pub mod config;
4pub mod errors;
5pub mod grpcpb;
6pub mod loader;
7pub mod migration;
8pub mod openapi;
9pub mod plugins;
10pub mod providers;
11pub mod repository;
12pub mod security;
13pub mod spec;
14pub mod tag;
15pub mod tools;
16pub mod transports;
17
18#[cfg(test)]
19mod allowed_protocols_tests;
20
21use anyhow::{anyhow, Result};
22use async_trait::async_trait;
23use std::collections::HashMap;
24use std::sync::Arc;
25use tokio::sync::RwLock;
26
27use crate::config::UtcpClientConfig;
28use crate::errors::UtcpError;
29use crate::openapi::OpenApiConverter;
30use crate::providers::base::{Provider, ProviderType};
31use crate::providers::http::HttpProvider;
32use crate::repository::ToolRepository;
33use crate::tools::{Tool, ToolSearchStrategy};
34use crate::transports::registry::{
35    communication_protocols_snapshot, CommunicationProtocolRegistry,
36};
37use crate::transports::stream::StreamResult;
38use crate::transports::CommunicationProtocol;
39
40/// UtcpClientInterface defines the core operations for a UTCP client.
41/// It allows registering/deregistering tool providers, calling tools, and searching for tools.
42#[async_trait]
43pub trait UtcpClientInterface: Send + Sync {
44    /// Registers a new tool provider and returns the list of tools it offers.
45    async fn register_tool_provider(&self, prov: Arc<dyn Provider>) -> Result<Vec<Tool>>;
46
47    /// Registers a tool provider with a specific set of tools, overriding automatic discovery.
48    async fn register_tool_provider_with_tools(
49        &self,
50        prov: Arc<dyn Provider>,
51        tools: Vec<Tool>,
52    ) -> Result<Vec<Tool>>;
53
54    /// Deregisters an existing tool provider by its name.
55    async fn deregister_tool_provider(&self, provider_name: &str) -> Result<()>;
56
57    /// Calls a specific tool by name with the provided arguments.
58    async fn call_tool(
59        &self,
60        tool_name: &str,
61        args: HashMap<String, serde_json::Value>,
62    ) -> Result<serde_json::Value>;
63
64    /// Searches for tools matching the query string, limited by the count.
65    async fn search_tools(&self, query: &str, limit: usize) -> Result<Vec<Tool>>;
66
67    /// Returns a map of available transports (communication protocols).
68    fn get_transports(&self) -> HashMap<String, Arc<dyn CommunicationProtocol>>;
69
70    /// Alias for get_transports.
71    fn get_communication_protocols(&self) -> HashMap<String, Arc<dyn CommunicationProtocol>> {
72        self.get_transports()
73    }
74
75    /// Calls a tool and returns a stream of results (e.g., for SSE).
76    async fn call_tool_stream(
77        &self,
78        tool_name: &str,
79        args: HashMap<String, serde_json::Value>,
80    ) -> Result<Box<dyn StreamResult>>;
81}
82
83/// UtcpClient is the main entry point for the UTCP library.
84/// It manages tool providers, communication protocols, and tool execution.
85pub struct UtcpClient {
86    config: UtcpClientConfig,
87    communication_protocols: CommunicationProtocolRegistry,
88    tool_repository: Arc<dyn ToolRepository>,
89    search_strategy: Arc<dyn ToolSearchStrategy>,
90
91    provider_tools_cache: RwLock<HashMap<String, Vec<Tool>>>,
92    resolved_tools_cache: RwLock<HashMap<String, ResolvedTool>>,
93}
94
95/// ResolvedTool represents a tool that has been resolved to a specific provider and protocol.
96#[derive(Clone)]
97struct ResolvedTool {
98    provider: Arc<dyn Provider>,
99    protocol: Arc<dyn CommunicationProtocol>,
100    call_name: String,
101}
102
103impl UtcpClient {
104    /// v1.0-style async factory for symmetry with other language SDKs
105    pub async fn create(
106        config: UtcpClientConfig,
107        repo: Arc<dyn ToolRepository>,
108        strat: Arc<dyn ToolSearchStrategy>,
109    ) -> Result<Self> {
110        Self::new(config, repo, strat).await
111    }
112
113    /// Create a new UtcpClient and automatically load providers from the JSON file specified in config
114    pub async fn new(
115        config: UtcpClientConfig,
116        repo: Arc<dyn ToolRepository>,
117        strat: Arc<dyn ToolSearchStrategy>,
118    ) -> Result<Self> {
119        let communication_protocols = communication_protocols_snapshot();
120
121        let client = Self {
122            config,
123            communication_protocols,
124            tool_repository: repo,
125            search_strategy: strat,
126            provider_tools_cache: RwLock::new(HashMap::new()),
127            resolved_tools_cache: RwLock::new(HashMap::new()),
128        };
129
130        // Load providers if file path is specified
131        if let Some(providers_path) = &client.config.providers_file_path {
132            let providers =
133                crate::loader::load_providers_with_tools_from_file(providers_path, &client.config)
134                    .await?;
135
136            for loaded in providers {
137                let result = if let Some(tools) = loaded.tools {
138                    client
139                        .register_tool_provider_with_tools(loaded.provider.clone(), tools)
140                        .await
141                } else {
142                    client.register_tool_provider(loaded.provider.clone()).await
143                };
144
145                match result {
146                    Ok(tools) => {
147                        println!("✓ Loaded provider with {} tools", tools.len());
148                    }
149                    Err(e) => {
150                        eprintln!("✗ Failed to load provider: {}", e);
151                    }
152                }
153            }
154        }
155
156        Ok(client)
157    }
158
159    /// Determines the correct call name for a tool based on its provider type.
160    fn call_name_for_provider(tool_name: &str, provider_type: &ProviderType) -> String {
161        match provider_type {
162            ProviderType::Mcp | ProviderType::Text => tool_name
163                .splitn(2, '.')
164                .nth(1)
165                .unwrap_or(tool_name)
166                .to_string(),
167            _ => tool_name.to_string(),
168        }
169    }
170
171    /// Validates that the protocol is allowed by the provider.
172    fn validate_allowed_protocol(resolved: &ResolvedTool, tool_name: &str) -> Result<()> {
173        let provider_allowed_protocols = resolved.provider.allowed_protocols();
174        let tool_protocol = resolved.provider.type_().as_key();
175
176        if !provider_allowed_protocols.contains(&tool_protocol.to_string()) {
177            return Err(anyhow!(
178                "Tool '{}' uses communication protocol '{}' which is not allowed by its provider. Allowed protocols: {:?}",
179                tool_name,
180                tool_protocol,
181                provider_allowed_protocols
182            ));
183        }
184
185        Ok(())
186    }
187
188    /// Resolves a tool name to a `ResolvedTool` containing the provider and protocol.
189    /// Handles both fully qualified names (provider.tool) and bare names.
190    async fn resolve_tool(&self, tool_name: &str) -> Result<ResolvedTool> {
191        {
192            let cache = self.resolved_tools_cache.read().await;
193            if let Some(resolved) = cache.get(tool_name) {
194                return Ok(resolved.clone());
195            }
196        }
197
198        // Legacy qualified name flow
199        if let Some((provider_name, suffix)) = tool_name.split_once('.') {
200            if provider_name.is_empty() {
201                return Err(UtcpError::Config(format!("Invalid tool name: {}", tool_name)).into());
202            }
203
204            let prov = self
205                .tool_repository
206                .get_provider(provider_name)
207                .await?
208                .ok_or_else(|| UtcpError::ToolNotFound(provider_name.to_string()))?;
209            let provider_type = prov.type_();
210
211            let protocol_key = provider_type.as_key().to_string();
212            let protocol = self
213                .communication_protocols
214                .get(&protocol_key)
215                .ok_or_else(|| {
216                    UtcpError::Config(format!(
217                        "No communication protocol found for provider type: {:?}",
218                        provider_type
219                    ))
220                })?
221                .clone();
222
223            let call_name = Self::call_name_for_provider(tool_name, &provider_type);
224            let resolved = ResolvedTool {
225                provider: prov.clone(),
226                protocol: protocol.clone(),
227                call_name,
228            };
229
230            let mut cache = self.resolved_tools_cache.write().await;
231            cache.insert(tool_name.to_string(), resolved.clone());
232            cache.insert(suffix.to_string(), resolved.clone());
233            return Ok(resolved);
234        }
235
236        // v1.0 bare tool names: search cached provider tools
237        {
238            let cache = self.provider_tools_cache.read().await;
239            for (prov_name, tools) in cache.iter() {
240                if tools.iter().any(|t| {
241                    t.name
242                        .split_once('.')
243                        .map(|(_, suffix)| suffix == tool_name)
244                        .unwrap_or(false)
245                }) {
246                    let prov = self
247                        .tool_repository
248                        .get_provider(prov_name)
249                        .await?
250                        .ok_or_else(|| UtcpError::ToolNotFound(prov_name.clone()))?;
251                    let provider_type = prov.type_();
252                    let protocol_key = provider_type.as_key().to_string();
253                    let protocol = self
254                        .communication_protocols
255                        .get(&protocol_key)
256                        .ok_or_else(|| {
257                            UtcpError::Config(format!(
258                                "No communication protocol found for provider type: {:?}",
259                                provider_type
260                            ))
261                        })?
262                        .clone();
263
264                    let full_name = format!("{}.{}", prov_name, tool_name);
265                    let call_name = Self::call_name_for_provider(&full_name, &provider_type);
266                    let resolved = ResolvedTool {
267                        provider: prov.clone(),
268                        protocol: protocol.clone(),
269                        call_name,
270                    };
271
272                    let mut rcache = self.resolved_tools_cache.write().await;
273                    rcache.insert(full_name, resolved.clone());
274                    rcache.insert(tool_name.to_string(), resolved.clone());
275                    return Ok(resolved);
276                }
277            }
278        }
279
280        Err(UtcpError::ToolNotFound(tool_name.to_string()).into())
281    }
282}
283
284#[async_trait]
285impl UtcpClientInterface for UtcpClient {
286    async fn register_tool_provider(&self, prov: Arc<dyn Provider>) -> Result<Vec<Tool>> {
287        self.register_tool_provider_with_tools(prov, Vec::new())
288            .await
289    }
290
291    async fn register_tool_provider_with_tools(
292        &self,
293        prov: Arc<dyn Provider>,
294        tools_override: Vec<Tool>,
295    ) -> Result<Vec<Tool>> {
296        let provider_name = prov.name();
297        let provider_type = prov.type_();
298
299        // Check cache first
300        {
301            let cache = self.provider_tools_cache.read().await;
302            if let Some(tools) = cache.get(&provider_name) {
303                return Ok(tools.clone());
304            }
305        }
306
307        // Get communication protocol for this provider type
308        let protocol_key = provider_type.as_key().to_string();
309        let protocol = self
310            .communication_protocols
311            .get(&protocol_key)
312            .ok_or_else(|| {
313                anyhow!(
314                    "No communication protocol found for provider type: {:?}",
315                    provider_type
316                )
317            })?
318            .clone();
319
320        // Register with protocol
321        let tools = if !tools_override.is_empty() {
322            tools_override
323        } else if provider_type == ProviderType::Http {
324            if let Some(http_prov) = prov.as_any().downcast_ref::<HttpProvider>() {
325                match OpenApiConverter::new_from_url(&http_prov.url, Some(provider_name.clone()))
326                    .await
327                {
328                    Ok(converter) => {
329                        let manual = converter.convert();
330                        if manual.tools.is_empty() {
331                            protocol.register_tool_provider(prov.as_ref()).await?
332                        } else {
333                            manual.tools
334                        }
335                    }
336                    Err(_) => protocol.register_tool_provider(prov.as_ref()).await?,
337                }
338            } else {
339                protocol.register_tool_provider(prov.as_ref()).await?
340            }
341        } else {
342            protocol.register_tool_provider(prov.as_ref()).await?
343        };
344
345        // Normalize tool names (prefix with provider name)
346        let mut normalized_tools = Vec::new();
347        for mut tool in tools {
348            if !tool.name.starts_with(&format!("{}.", provider_name)) {
349                tool.name = format!("{}.{}", provider_name, tool.name.trim_start_matches('.'));
350            }
351            normalized_tools.push(tool);
352        }
353
354        // Save to repository
355        self.tool_repository
356            .save_provider_with_tools(prov.clone(), normalized_tools.clone())
357            .await?;
358
359        // Update cache
360        {
361            let mut cache = self.provider_tools_cache.write().await;
362            cache.insert(provider_name, normalized_tools.clone());
363        }
364
365        {
366            let mut resolved = self.resolved_tools_cache.write().await;
367            for tool in &normalized_tools {
368                let call_name = Self::call_name_for_provider(&tool.name, &provider_type);
369                let resolved_entry = ResolvedTool {
370                    provider: prov.clone(),
371                    protocol: protocol.clone(),
372                    call_name,
373                };
374
375                // Full name
376                resolved.insert(tool.name.clone(), resolved_entry.clone());
377
378                // Bare name (v1.0 style)
379                if let Some((_, bare)) = tool.name.split_once('.') {
380                    resolved.insert(bare.to_string(), resolved_entry);
381                }
382            }
383        }
384
385        Ok(normalized_tools)
386    }
387
388    async fn deregister_tool_provider(&self, provider_name: &str) -> Result<()> {
389        // Get provider from repository
390        let prov = self
391            .tool_repository
392            .get_provider(provider_name)
393            .await?
394            .ok_or_else(|| anyhow!("Provider not found: {}", provider_name))?;
395
396        // Get communication protocol
397        let provider_type = prov.type_();
398        let protocol_key = provider_type.as_key().to_string();
399        let protocol = self
400            .communication_protocols
401            .get(&protocol_key)
402            .ok_or_else(|| {
403                anyhow!(
404                    "No communication protocol found for provider type: {:?}",
405                    provider_type
406                )
407            })?;
408
409        // Deregister from protocol
410        protocol.deregister_tool_provider(prov.as_ref()).await?;
411
412        // Remove from repository
413        self.tool_repository.remove_provider(provider_name).await?;
414
415        // Clear cache
416        {
417            let mut cache = self.provider_tools_cache.write().await;
418            cache.remove(provider_name);
419        }
420        {
421            let mut resolved = self.resolved_tools_cache.write().await;
422            resolved.retain(|tool_name, _| !tool_name.starts_with(&format!("{}.", provider_name)));
423        }
424
425        Ok(())
426    }
427
428    async fn call_tool(
429        &self,
430        tool_name: &str,
431        args: HashMap<String, serde_json::Value>,
432    ) -> Result<serde_json::Value> {
433        let resolved = self.resolve_tool(tool_name).await?;
434
435        // Validate protocol is allowed by the provider
436        Self::validate_allowed_protocol(&resolved, tool_name)?;
437
438        resolved
439            .protocol
440            .call_tool(&resolved.call_name, args, resolved.provider.as_ref())
441            .await
442    }
443
444    async fn search_tools(&self, query: &str, limit: usize) -> Result<Vec<Tool>> {
445        self.search_strategy.search_tools(query, limit).await
446    }
447
448    fn get_transports(&self) -> HashMap<String, Arc<dyn CommunicationProtocol>> {
449        self.communication_protocols.as_map()
450    }
451
452    async fn call_tool_stream(
453        &self,
454        tool_name: &str,
455        args: HashMap<String, serde_json::Value>,
456    ) -> Result<Box<dyn StreamResult>> {
457        let resolved = self.resolve_tool(tool_name).await?;
458
459        // Validate protocol is allowed by the provider
460        Self::validate_allowed_protocol(&resolved, tool_name)?;
461
462        resolved
463            .protocol
464            .call_tool_stream(&resolved.call_name, args, resolved.provider.as_ref())
465            .await
466    }
467}