1use std::sync::{Arc, RwLock};
6
7use async_trait::async_trait;
8use breaker_machines::{CircuitBreaker, Config as BreakerConfig};
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::mpsc;
14
15use crate::content::types::{Content, TextContent, ImageContent};
16use crate::server::multiplexer::ClientRequester;
17use crate::server::session::Session;
18use crate::server::visibility::{ExecutionContext, VisibilityContext};
19use crate::transport::traits::JsonRpcNotification;
20
/// Errors produced while looking up or executing a tool.
#[derive(Debug, Error)]
pub enum ToolError {
    /// No tool with the given name is available to the caller. `ToolRegistry`
    /// also returns this for tools a session hides or that are not visible in
    /// the caller's `VisibilityContext`.
    #[error("Tool not found: {0}")]
    NotFound(String),

    /// Argument validation failure reported by a tool implementation.
    #[error("Invalid arguments: {0}")]
    InvalidArguments(String),

    /// Failure reported by a tool implementation while it runs; the registry
    /// itself does not construct this variant.
    #[error("Execution error: {0}")]
    Execution(String),

    /// Server-side failure, e.g. `ToolRegistry` finding no breaker for a tool
    /// or failing to acquire a breaker lock.
    #[error("Internal error: {0}")]
    Internal(String),

    /// The tool's circuit breaker is open, so the call was rejected without
    /// running the tool. `tool` is the tool name; `message` is human-readable.
    #[error("Circuit breaker open for tool '{tool}': {message}")]
    CircuitOpen { tool: String, message: String },
}
44
/// Serializable description of a tool, as produced by the `ToolRegistry`
/// listing methods.
///
/// `input_schema` is serialized under the key `inputSchema`; `description` and
/// `execution` are omitted from the serialized form when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
    /// Tool name, as reported by `Tool::name`.
    pub name: String,

    /// Optional human-readable description from `Tool::description`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,

    /// JSON value describing the tool's arguments, from `Tool::input_schema`
    /// (a JSON Schema object in the tests; presumably always one — confirm).
    #[serde(rename = "inputSchema")]
    pub input_schema: Value,

    /// Optional execution metadata from `Tool::execution`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub execution: Option<crate::protocol::types::ToolExecution>,
}
63
/// A callable tool that can be registered with a `ToolRegistry`.
///
/// Implementors must be `Send + Sync`; the registry stores them as
/// `Arc<dyn Tool>` and may invoke them from any task.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name under which the tool is registered and invoked. Registering another
    /// tool with the same name replaces this one.
    fn name(&self) -> &str;

    /// Optional human-readable description. Defaults to `None`.
    fn description(&self) -> Option<&str> {
        None
    }

    /// JSON value describing the tool's arguments; exposed as
    /// `ToolInfo::input_schema`.
    fn input_schema(&self) -> Value;

    /// Optional execution metadata exposed as `ToolInfo::execution`.
    /// Defaults to `None`.
    fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
        None
    }

    /// Whether the tool is listed for, and callable in, the given context.
    /// `ToolRegistry::list_for_session` skips invisible tools, and
    /// `ToolRegistry::call_for_session` reports them as `ToolError::NotFound`.
    /// Defaults to `true`.
    fn is_visible(&self, _ctx: &VisibilityContext) -> bool {
        true
    }

    /// Runs the tool with the call's parameters and session context, returning
    /// the content blocks to send back or a `ToolError`.
    async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError>;
}
128
129pub trait ToolHelpers {
134 fn text(&self, content: &str) -> Box<dyn Content> {
136 Box::new(TextContent::new(content))
137 }
138
139 fn image(&self, data: &str, mime_type: &str) -> Box<dyn Content> {
141 Box::new(ImageContent::new(data, mime_type))
142 }
143}
144
145impl<T: Tool + ?Sized> ToolHelpers for T {}
147
/// Circuit-breaker settings used by `ToolRegistry` when it registers a tool.
///
/// Each field is forwarded to the `breaker_machines` builder method, and the
/// `BreakerConfig` field, of the same name.
#[derive(Debug, Clone)]
pub struct ToolBreakerConfig {
    /// Failure count handed to the breaker (presumably the number of failures
    /// within the window that opens it — confirm against `breaker_machines`).
    pub failure_threshold: usize,
    /// Failure window length, in seconds.
    pub failure_window_secs: f64,
    /// Half-open timeout, in seconds (presumably how long an open breaker waits
    /// before admitting trial calls).
    pub half_open_timeout_secs: f64,
    /// Success count handed to the breaker (presumably successes needed to
    /// close it again).
    pub success_threshold: usize,
}

impl Default for ToolBreakerConfig {
    /// 5 failures, 60 s window, 30 s half-open timeout, 2 successes.
    fn default() -> Self {
        const FAILURES: usize = 5;
        const WINDOW_SECS: f64 = 60.0;
        const HALF_OPEN_SECS: f64 = 30.0;
        const SUCCESSES: usize = 2;

        ToolBreakerConfig {
            failure_threshold: FAILURES,
            failure_window_secs: WINDOW_SECS,
            half_open_timeout_secs: HALF_OPEN_SECS,
            success_threshold: SUCCESSES,
        }
    }
}
171
172impl From<ToolBreakerConfig> for BreakerConfig {
173 fn from(cfg: ToolBreakerConfig) -> Self {
174 BreakerConfig {
175 failure_threshold: Some(cfg.failure_threshold),
176 failure_rate_threshold: None,
177 minimum_calls: 1,
178 failure_window_secs: cfg.failure_window_secs,
179 half_open_timeout_secs: cfg.half_open_timeout_secs,
180 success_threshold: cfg.success_threshold,
181 jitter_factor: 0.1,
182 }
183 }
184}
185
186#[derive(Clone)]
188pub struct ToolRegistry {
189 tools: Arc<DashMap<String, Arc<dyn Tool>>>,
190 breakers: Arc<DashMap<String, Arc<RwLock<CircuitBreaker>>>>,
191 breaker_config: Arc<RwLock<ToolBreakerConfig>>,
192 notification_tx: Option<mpsc::UnboundedSender<JsonRpcNotification>>,
193}
194
195impl ToolRegistry {
196 pub fn new() -> Self {
198 Self {
199 tools: Arc::new(DashMap::new()),
200 breakers: Arc::new(DashMap::new()),
201 breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
202 notification_tx: None,
203 }
204 }
205
206 pub fn with_notifications(notification_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
208 Self {
209 tools: Arc::new(DashMap::new()),
210 breakers: Arc::new(DashMap::new()),
211 breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
212 notification_tx: Some(notification_tx),
213 }
214 }
215
216 pub fn set_notification_tx(&mut self, tx: mpsc::UnboundedSender<JsonRpcNotification>) {
218 self.notification_tx = Some(tx);
219 }
220
221 pub fn set_breaker_config(&self, config: ToolBreakerConfig) {
223 if let Ok(mut cfg) = self.breaker_config.write() {
224 *cfg = config;
225 }
226 }
227
228 pub fn register<T: Tool + 'static>(&self, tool: T) {
230 let name = tool.name().to_string();
231
232 let breaker_config = self.breaker_config.read()
234 .map(|c| c.clone())
235 .unwrap_or_default();
236
237 let breaker = CircuitBreaker::builder(&name)
238 .failure_threshold(breaker_config.failure_threshold)
239 .failure_window_secs(breaker_config.failure_window_secs)
240 .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
241 .success_threshold(breaker_config.success_threshold)
242 .build();
243
244 self.breakers.insert(name.clone(), Arc::new(RwLock::new(breaker)));
245 self.tools.insert(name, Arc::new(tool));
246 }
247
248 pub fn register_boxed(&self, tool: Arc<dyn Tool>) {
250 let name = tool.name().to_string();
251
252 let breaker_config = self.breaker_config.read()
253 .map(|c| c.clone())
254 .unwrap_or_default();
255
256 let breaker = CircuitBreaker::builder(&name)
257 .failure_threshold(breaker_config.failure_threshold)
258 .failure_window_secs(breaker_config.failure_window_secs)
259 .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
260 .success_threshold(breaker_config.success_threshold)
261 .build();
262
263 self.breakers.insert(name.clone(), Arc::new(RwLock::new(breaker)));
264 self.tools.insert(name, tool);
265 }
266
267 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
269 self.tools.get(name).map(|t| Arc::clone(&t))
270 }
271
272 pub fn is_circuit_open(&self, name: &str) -> bool {
274 self.breakers
275 .get(name)
276 .and_then(|b| b.read().ok().map(|breaker| breaker.is_open()))
277 .unwrap_or(false)
278 }
279
280 pub fn list(&self) -> Vec<ToolInfo> {
282 self.tools
283 .iter()
284 .map(|entry| {
285 let tool = entry.value();
286 ToolInfo {
287 name: tool.name().to_string(),
288 description: tool.description().map(|s| s.to_string()),
289 input_schema: tool.input_schema(),
290 execution: tool.execution(),
291 }
292 })
293 .collect()
294 }
295
296 pub fn list_available(&self) -> Vec<ToolInfo> {
298 self.tools
299 .iter()
300 .filter(|entry| !self.is_circuit_open(entry.key()))
301 .map(|entry| {
302 let tool = entry.value();
303 ToolInfo {
304 name: tool.name().to_string(),
305 description: tool.description().map(|s| s.to_string()),
306 input_schema: tool.input_schema(),
307 execution: tool.execution(),
308 }
309 })
310 .collect()
311 }
312
313 fn send_notification(&self, method: &str, params: Option<Value>) {
315 if let Some(tx) = &self.notification_tx {
316 let notification = JsonRpcNotification::new(method, params);
317 let _ = tx.send(notification);
318 }
319 }
320
321 fn notify_tools_changed(&self) {
323 self.send_notification("notifications/tools/list_changed", None);
324 }
325
326 fn notify_message(&self, level: &str, logger: &str, message: &str) {
328 self.send_notification(
329 "notifications/message",
330 Some(serde_json::json!({
331 "level": level,
332 "logger": logger,
333 "data": message
334 })),
335 );
336 }
337
338 pub async fn call(
347 &self,
348 name: &str,
349 params: Value,
350 session: &Session,
351 client_requester: Option<ClientRequester>,
352 ) -> Result<Vec<Box<dyn Content>>, ToolError> {
353 let tool = self
354 .get(name)
355 .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
356
357 let breaker = self
358 .breakers
359 .get(name)
360 .ok_or_else(|| ToolError::Internal(format!("No circuit breaker for tool '{}'", name)))?;
361
362 let was_open = {
364 let breaker_guard = breaker.read()
365 .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
366 breaker_guard.is_open()
367 };
368
369 if was_open {
370 return Err(ToolError::CircuitOpen {
371 tool: name.to_string(),
372 message: "Too many recent failures. Service temporarily unavailable.".to_string(),
373 });
374 }
375
376 let ctx = match client_requester {
378 Some(cr) => ExecutionContext::new(params, session).with_client_requester(cr),
379 None => ExecutionContext::new(params, session),
380 };
381
382 let start = std::time::Instant::now();
384 let result = tool.execute(ctx).await;
385 let duration_secs = start.elapsed().as_secs_f64();
386
387 let breaker_guard = breaker.write()
389 .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
390
391 let was_closed_before = !breaker_guard.is_open();
392
393 match &result {
394 Ok(_) => {
395 breaker_guard.record_success(duration_secs);
396 if was_open && !breaker_guard.is_open() {
398 self.notify_tools_changed();
399 self.notify_message(
400 "info",
401 "breaker-machines",
402 &format!("Tool '{}' recovered and available", name),
403 );
404 }
405 }
406 Err(_) => {
407 breaker_guard.record_failure(duration_secs);
408 if was_closed_before && breaker_guard.is_open() {
410 self.notify_tools_changed();
411 self.notify_message(
412 "warning",
413 "breaker-machines",
414 &format!("Tool '{}' disabled: circuit breaker open after failures", name),
415 );
416 }
417 }
418 }
419
420 result
421 }
422
423 pub fn list_for_session(&self, session: &Session, ctx: &VisibilityContext<'_>) -> Vec<ToolInfo> {
432 let mut tools = std::collections::HashMap::new();
433
434 for entry in self.tools.iter() {
436 let name = entry.key().clone();
437 if !session.is_tool_hidden(&name) && !self.is_circuit_open(&name) {
438 let tool = entry.value();
439 if tool.is_visible(ctx) {
440 tools.insert(
441 name,
442 ToolInfo {
443 name: tool.name().to_string(),
444 description: tool.description().map(|s| s.to_string()),
445 input_schema: tool.input_schema(),
446 execution: tool.execution(),
447 },
448 );
449 }
450 }
451 }
452
453 for entry in session.tool_extras().iter() {
455 let name = entry.key().clone();
456 let tool = entry.value();
457 if tool.is_visible(ctx) {
458 tools.insert(
459 name,
460 ToolInfo {
461 name: tool.name().to_string(),
462 description: tool.description().map(|s| s.to_string(),),
463 input_schema: tool.input_schema(),
464 execution: tool.execution(),
465 },
466 );
467 }
468 }
469
470 for entry in session.tool_overrides().iter() {
472 let name = entry.key().clone();
473 let tool = entry.value();
474 if tool.is_visible(ctx) {
475 tools.insert(
476 name,
477 ToolInfo {
478 name: tool.name().to_string(),
479 description: tool.description().map(|s| s.to_string(),),
480 input_schema: tool.input_schema(),
481 execution: tool.execution(),
482 },
483 );
484 }
485 }
486
487 tools.into_values().collect()
488 }
489
490 pub async fn call_for_session(
500 &self,
501 name: &str,
502 params: Value,
503 session: &Session,
504 visibility_ctx: &VisibilityContext<'_>,
505 client_requester: Option<ClientRequester>,
506 ) -> Result<Vec<Box<dyn Content>>, ToolError> {
507 let resolved_name = session.resolve_tool_alias(name);
509 let resolved = resolved_name.as_ref();
510
511 let exec_ctx = match (visibility_ctx.environment, client_requester.as_ref()) {
514 (Some(env), Some(cr)) => ExecutionContext::with_environment(params.clone(), session, env)
515 .with_client_requester(cr.clone()),
516 (Some(env), None) => ExecutionContext::with_environment(params.clone(), session, env),
517 (None, Some(cr)) => ExecutionContext::new(params.clone(), session)
518 .with_client_requester(cr.clone()),
519 (None, None) => ExecutionContext::new(params.clone(), session),
520 };
521
522 if let Some(tool) = session.get_tool_override(resolved) {
524 if !tool.is_visible(visibility_ctx) {
525 return Err(ToolError::NotFound(name.to_string()));
526 }
527 return tool.execute(exec_ctx).await;
528 }
529
530 if let Some(tool) = session.get_tool_extra(resolved) {
532 if !tool.is_visible(visibility_ctx) {
533 return Err(ToolError::NotFound(name.to_string()));
534 }
535 return tool.execute(exec_ctx).await;
536 }
537
538 if session.is_tool_hidden(resolved) {
540 return Err(ToolError::NotFound(name.to_string()));
541 }
542
543 let tool = self
545 .get(resolved)
546 .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
547
548 if !tool.is_visible(visibility_ctx) {
549 return Err(ToolError::NotFound(name.to_string()));
550 }
551
552 self.call(resolved, params, session, client_requester).await
554 }
555
556 pub fn len(&self) -> usize {
558 self.tools.len()
559 }
560
561 pub fn is_empty(&self) -> bool {
563 self.tools.is_empty()
564 }
565}
566
567impl Default for ToolRegistry {
568 fn default() -> Self {
569 Self::new()
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use crate::content::types::TextContent;

    /// Minimal tool that echoes its `message` argument back as text.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> Option<&str> {
            Some("Echoes back the input message")
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "Message to echo"
                    }
                },
                "required": ["message"]
            })
        }

        async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
            let message = match ctx.params.get("message").and_then(Value::as_str) {
                Some(text) => text,
                None => {
                    return Err(ToolError::InvalidArguments(
                        "Missing 'message' field".to_string(),
                    ))
                }
            };

            let echoed = format!("Echo: {}", message);
            Ok(vec![Box::new(TextContent::new(echoed))])
        }
    }

    #[test]
    fn test_registry_creation() {
        assert!(ToolRegistry::new().is_empty());
    }

    #[test]
    fn test_tool_registration() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_get_tool() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let found = registry.get("echo").map(|tool| tool.name().to_string());
        assert_eq!(found.as_deref(), Some("echo"));

        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_list_tools() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let listed = registry.list();
        assert_eq!(listed.len(), 1);

        let info = &listed[0];
        assert_eq!(info.name, "echo");
        assert_eq!(info.description.as_deref(), Some("Echoes back the input message"));
    }

    #[tokio::test]
    async fn test_call_tool() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let args = serde_json::json!({ "message": "Hello, world!" });
        let output = registry.call("echo", args, &session, None).await.unwrap();

        assert_eq!(output.len(), 1);
    }

    #[tokio::test]
    async fn test_call_missing_tool() {
        let session = Session::new();

        let outcome = ToolRegistry::new()
            .call("nonexistent", serde_json::json!({}), &session, None)
            .await;

        assert!(matches!(outcome, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_tool_invalid_arguments() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let outcome = registry
            .call("echo", serde_json::json!({}), &session, None)
            .await;

        assert!(matches!(outcome, Err(ToolError::InvalidArguments(_))));
    }
}