Skip to main content

lumen_runtime/
tools.rs

1//! Tool dispatch interface for external tool invocations.
2//!
3//! This module provides two layers:
4//! - **`ToolDispatcher`** — the low-level dispatch trait used by the VM.
5//! - **`ToolProvider`** — the high-level pluggable provider trait for external integrations.
6//!
7//! A [`ProviderRegistry`] collects named providers and implements `ToolDispatcher`,
8//! so it can be plugged directly into the VM's `tool_dispatcher` slot.
9
10use serde::{Deserialize, Serialize};
11use std::{collections::HashMap, future::Future, pin::Pin, time::Instant};
12use thiserror::Error;
13
14// ---------------------------------------------------------------------------
15// Errors
16// ---------------------------------------------------------------------------
17
18#[derive(Debug, Error)]
19pub enum ToolError {
20    #[error("tool not found: {0}")]
21    NotFound(String),
22    #[error("invalid arguments: {0}")]
23    InvalidArgs(String),
24    #[error("tool execution failed: {0}")]
25    ExecutionFailed(String),
26    #[error("policy violation: {0}")]
27    PolicyViolation(String),
28    #[error("rate limit exceeded: {message}")]
29    RateLimit {
30        retry_after_ms: Option<u64>,
31        message: String,
32    },
33    #[error("authentication failed: {message}")]
34    AuthError { message: String },
35    #[error("model not found: {model} (provider: {provider})")]
36    ModelNotFound { model: String, provider: String },
37    #[error("timeout: elapsed {elapsed_ms}ms, limit {limit_ms}ms")]
38    Timeout { elapsed_ms: u64, limit_ms: u64 },
39    #[error("provider unavailable: {provider} ({reason})")]
40    ProviderUnavailable { provider: String, reason: String },
41    #[error("output validation failed: expected {expected_schema}, got {actual}")]
42    OutputValidationFailed {
43        expected_schema: String,
44        actual: String,
45    },
46    #[error("provider not registered: {0}")]
47    NotRegistered(String),
48    // Legacy variant for backward compatibility
49    #[error("tool invocation failed: {0}")]
50    InvocationFailed(String),
51}
52
53// ---------------------------------------------------------------------------
54// Low-level dispatch (consumed by the VM)
55// ---------------------------------------------------------------------------
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ToolRequest {
59    pub tool_id: String,
60    pub version: String,
61    pub args: serde_json::Value,
62    pub policy: serde_json::Value,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ToolResponse {
67    pub outputs: serde_json::Value,
68    pub latency_ms: u64,
69}
70
71/// Boxed async result used by tool dispatcher and provider async paths.
72pub type ToolFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ToolError>> + Send + 'a>>;
73
74/// Tool dispatch trait — implementations handle HTTP, MCP, or built-in tool calls.
75pub trait ToolDispatcher: Send + Sync {
76    fn dispatch(&self, request: &ToolRequest) -> Result<ToolResponse, ToolError>;
77
78    /// Async dispatch hook.
79    ///
80    /// Default implementation preserves backwards compatibility by delegating
81    /// to sync `dispatch`.
82    fn dispatch_async<'a>(&'a self, request: &'a ToolRequest) -> ToolFuture<'a, ToolResponse> {
83        Box::pin(async move { self.dispatch(request) })
84    }
85}
86
87/// Stub tool dispatcher for testing (returns configured responses).
88#[derive(Default)]
89pub struct StubDispatcher {
90    responses: HashMap<String, serde_json::Value>,
91}
92
93impl StubDispatcher {
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    pub fn set_response(&mut self, tool_id: &str, response: serde_json::Value) {
99        self.responses.insert(tool_id.to_string(), response);
100    }
101}
102
103impl ToolDispatcher for StubDispatcher {
104    fn dispatch(&self, request: &ToolRequest) -> Result<ToolResponse, ToolError> {
105        if let Some(response) = self.responses.get(&request.tool_id) {
106            Ok(ToolResponse {
107                outputs: response.clone(),
108                latency_ms: 0,
109            })
110        } else {
111            Err(ToolError::NotFound(request.tool_id.clone()))
112        }
113    }
114}
115
116// ---------------------------------------------------------------------------
117// High-level pluggable provider trait
118// ---------------------------------------------------------------------------
119
/// Retry settings for tool calls.
///
/// Defaults: 3 retries, 100 ms base delay, 10 000 ms maximum delay.
/// NOTE(review): nothing in this module consumes the policy yet. The backoff
/// semantics of the delay fields are presumed from their names; confirm with callers.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries after the first attempt.
    pub max_retries: u32,
    /// Base retry delay in milliseconds.
    pub base_delay_ms: u64,
    /// Upper bound on a retry delay in milliseconds.
    pub max_delay_ms: u64,
}

impl Default for RetryPolicy {
    fn default() -> Self {
        let max_retries = 3;
        let base_delay_ms = 100;
        let max_delay_ms = 10_000;
        RetryPolicy {
            max_retries,
            base_delay_ms,
            max_delay_ms,
        }
    }
}
137
138/// Schema describing a tool's input/output types and declared effects.
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct ToolSchema {
141    pub name: String,
142    pub description: String,
143    /// JSON Schema for the tool's input.
144    pub input_schema: serde_json::Value,
145    /// JSON Schema for the tool's output.
146    pub output_schema: serde_json::Value,
147    /// Declared effect kinds (e.g. `["http", "trace"]`).
148    pub effects: Vec<String>,
149}
150
/// A capability a [`ToolProvider`] can advertise through
/// [`ToolProvider::capabilities`].
///
/// NOTE: `ProviderRegistry` does not yet check capabilities against requests.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum Capability {
    /// Text generation.
    TextGeneration,
    /// Chat.
    Chat,
    /// Embeddings.
    Embedding,
    /// Vision input.
    Vision,
    /// Tool use.
    ToolUse,
    /// Structured output.
    StructuredOutput,
    /// Streaming output.
    Streaming,
}
162
163/// A pluggable tool provider. Implementations live in separate crates
164/// (e.g. an HTTP provider, an MCP provider, a mock provider).
165pub trait ToolProvider: Send + Sync {
166    /// Human-readable provider name (e.g. `"openai"`, `"anthropic"`).
167    fn name(&self) -> &str;
168
169    /// Semver version of the provider implementation.
170    fn version(&self) -> &str;
171
172    /// Schema describing the tool this provider exposes.
173    fn schema(&self) -> &ToolSchema;
174
175    /// Execute the tool with the given JSON input, returning JSON output.
176    fn call(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError>;
177
178    /// Async execution hook.
179    ///
180    /// Default implementation preserves backwards compatibility by delegating
181    /// to sync `call`.
182    fn call_async<'a>(&'a self, input: serde_json::Value) -> ToolFuture<'a, serde_json::Value> {
183        Box::pin(async move { self.call(input) })
184    }
185
186    /// Declared effect kinds this provider may trigger.
187    fn effects(&self) -> Vec<String> {
188        self.schema().effects.clone()
189    }
190
191    /// Capabilities supported by this provider (default: empty).
192    fn capabilities(&self) -> Vec<Capability> {
193        vec![]
194    }
195}
196
197// ---------------------------------------------------------------------------
198// NullProvider — returns an error for unregistered tools
199// ---------------------------------------------------------------------------
200
201/// A sentinel provider that always returns `ToolError::NotRegistered`.
202/// Used as a placeholder when no real provider has been registered for a tool.
203pub struct NullProvider {
204    tool_name: String,
205    schema: ToolSchema,
206}
207
208impl NullProvider {
209    pub fn new(tool_name: &str) -> Self {
210        Self {
211            tool_name: tool_name.to_string(),
212            schema: ToolSchema {
213                name: tool_name.to_string(),
214                description: format!("Unregistered tool: {}", tool_name),
215                input_schema: serde_json::Value::Null,
216                output_schema: serde_json::Value::Null,
217                effects: vec![],
218            },
219        }
220    }
221}
222
223impl ToolProvider for NullProvider {
224    fn name(&self) -> &str {
225        &self.tool_name
226    }
227
228    fn version(&self) -> &str {
229        "0.0.0"
230    }
231
232    fn schema(&self) -> &ToolSchema {
233        &self.schema
234    }
235
236    fn call(&self, _input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
237        Err(ToolError::NotRegistered(self.tool_name.clone()))
238    }
239}
240
241// ---------------------------------------------------------------------------
242// ProviderRegistry
243// ---------------------------------------------------------------------------
244
245/// A registry of named tool providers. Implements `ToolDispatcher` so it can
246/// be plugged directly into the VM.
247pub struct ProviderRegistry {
248    providers: HashMap<String, Box<dyn ToolProvider>>,
249}
250
251impl ProviderRegistry {
252    pub fn new() -> Self {
253        Self {
254            providers: HashMap::new(),
255        }
256    }
257
258    /// Register a provider under the given name, replacing any previous one.
259    pub fn register(&mut self, name: &str, provider: Box<dyn ToolProvider>) {
260        self.providers.insert(name.to_string(), provider);
261    }
262
263    /// Look up a provider by name.
264    pub fn get(&self, name: &str) -> Option<&dyn ToolProvider> {
265        self.providers.get(name).map(|p| p.as_ref())
266    }
267
268    /// Return the names of all registered providers.
269    pub fn list(&self) -> Vec<&str> {
270        let mut names: Vec<&str> = self.providers.keys().map(|s| s.as_str()).collect();
271        names.sort();
272        names
273    }
274
275    /// Check whether a provider is registered under the given name.
276    pub fn has(&self, name: &str) -> bool {
277        self.providers.contains_key(name)
278    }
279
280    /// Remove a provider by name, returning `true` if it existed.
281    pub fn unregister(&mut self, name: &str) -> bool {
282        self.providers.remove(name).is_some()
283    }
284
285    /// Number of registered providers.
286    pub fn len(&self) -> usize {
287        self.providers.len()
288    }
289
290    /// Whether the registry is empty.
291    pub fn is_empty(&self) -> bool {
292        self.providers.is_empty()
293    }
294}
295
296impl Default for ProviderRegistry {
297    fn default() -> Self {
298        Self::new()
299    }
300}
301
302fn validate_provider_output(
303    schema: &serde_json::Value,
304    output: &serde_json::Value,
305) -> Result<(), ToolError> {
306    if let Err(reason) = validate_schema_value(schema, output, "$") {
307        let expected_schema = serde_json::to_string(schema).unwrap_or_else(|_| "<schema>".into());
308        let actual_output = serde_json::to_string(output).unwrap_or_else(|_| "<output>".into());
309        return Err(ToolError::OutputValidationFailed {
310            expected_schema,
311            actual: format!("{actual_output} ({reason})"),
312        });
313    }
314    Ok(())
315}
316
317fn validate_schema_value(
318    schema: &serde_json::Value,
319    value: &serde_json::Value,
320    path: &str,
321) -> Result<(), String> {
322    let schema_obj = match schema {
323        serde_json::Value::Null => return Ok(()),
324        serde_json::Value::Bool(true) => return Ok(()),
325        serde_json::Value::Bool(false) => {
326            return Err(format!("{path}: schema is false"));
327        }
328        serde_json::Value::Object(map) if map.is_empty() => return Ok(()),
329        serde_json::Value::Object(map) => map,
330        _ => return Ok(()),
331    };
332
333    if let Some(const_value) = schema_obj.get("const") {
334        if const_value != value {
335            return Err(format!("{path}: value does not match const"));
336        }
337    }
338
339    if let Some(enum_values) = schema_obj.get("enum").and_then(|v| v.as_array()) {
340        if !enum_values.iter().any(|candidate| candidate == value) {
341            return Err(format!("{path}: value is not in enum"));
342        }
343    }
344
345    if let Some(type_decl) = schema_obj.get("type") {
346        let type_matches = match type_decl {
347            serde_json::Value::String(expected) => value_matches_type(value, expected),
348            serde_json::Value::Array(candidates) => candidates
349                .iter()
350                .filter_map(|candidate| candidate.as_str())
351                .any(|expected| value_matches_type(value, expected)),
352            _ => true,
353        };
354        if !type_matches {
355            return Err(format!(
356                "{path}: expected type {}, got {}",
357                type_decl,
358                value_type_name(value)
359            ));
360        }
361    }
362
363    if let Some(obj) = value.as_object() {
364        if let Some(required_fields) = schema_obj.get("required").and_then(|v| v.as_array()) {
365            for required in required_fields.iter().filter_map(|field| field.as_str()) {
366                if !obj.contains_key(required) {
367                    return Err(format!("{path}: missing required property '{required}'"));
368                }
369            }
370        }
371
372        let props = schema_obj.get("properties").and_then(|v| v.as_object());
373        if let Some(props) = props {
374            for (name, prop_schema) in props {
375                if let Some(prop_value) = obj.get(name) {
376                    let prop_path = format!("{path}.{name}");
377                    validate_schema_value(prop_schema, prop_value, &prop_path)?;
378                }
379            }
380        }
381
382        if let Some(additional) = schema_obj.get("additionalProperties") {
383            for (key, extra_value) in obj {
384                if props.is_some_and(|p| p.contains_key(key)) {
385                    continue;
386                }
387                match additional {
388                    serde_json::Value::Bool(true) => {}
389                    serde_json::Value::Bool(false) => {
390                        return Err(format!(
391                            "{path}: additional property '{key}' is not allowed"
392                        ));
393                    }
394                    schema => {
395                        let prop_path = format!("{path}.{key}");
396                        validate_schema_value(schema, extra_value, &prop_path)?;
397                    }
398                }
399            }
400        }
401    }
402
403    if let (Some(items_schema), Some(items)) = (schema_obj.get("items"), value.as_array()) {
404        for (index, item) in items.iter().enumerate() {
405            let item_path = format!("{path}[{index}]");
406            validate_schema_value(items_schema, item, &item_path)?;
407        }
408    }
409
410    Ok(())
411}
412
413fn value_type_name(value: &serde_json::Value) -> &'static str {
414    match value {
415        serde_json::Value::Null => "null",
416        serde_json::Value::Bool(_) => "boolean",
417        serde_json::Value::Number(number) => {
418            if number.is_i64() || number.is_u64() {
419                "integer"
420            } else {
421                "number"
422            }
423        }
424        serde_json::Value::String(_) => "string",
425        serde_json::Value::Array(_) => "array",
426        serde_json::Value::Object(_) => "object",
427    }
428}
429
430fn value_matches_type(value: &serde_json::Value, expected_type: &str) -> bool {
431    match expected_type {
432        "null" => value.is_null(),
433        "boolean" => value.is_boolean(),
434        "number" => value.is_number(),
435        "integer" => value.as_i64().is_some() || value.as_u64().is_some(),
436        "string" => value.is_string(),
437        "array" => value.is_array(),
438        "object" => value.is_object(),
439        _ => true,
440    }
441}
442
443/// The registry doubles as a `ToolDispatcher`.  It resolves `request.tool_id`
444/// to a registered provider, forwards the call, and wraps the result in a
445/// `ToolResponse`.
446impl ToolDispatcher for ProviderRegistry {
447    fn dispatch(&self, request: &ToolRequest) -> Result<ToolResponse, ToolError> {
448        let provider = self
449            .providers
450            .get(&request.tool_id)
451            .ok_or_else(|| ToolError::NotRegistered(request.tool_id.clone()))?;
452
453        // Check capabilities (future: validate against request requirements)
454        let _capabilities = provider.capabilities();
455
456        let start = Instant::now();
457        let output = provider.call(request.args.clone())?;
458        let latency_ms = start.elapsed().as_millis() as u64;
459        validate_provider_output(&provider.schema().output_schema, &output)?;
460
461        Ok(ToolResponse {
462            outputs: output,
463            latency_ms,
464        })
465    }
466
467    fn dispatch_async<'a>(&'a self, request: &'a ToolRequest) -> ToolFuture<'a, ToolResponse> {
468        Box::pin(async move {
469            let provider = self
470                .providers
471                .get(&request.tool_id)
472                .ok_or_else(|| ToolError::NotRegistered(request.tool_id.clone()))?;
473
474            // Check capabilities (future: validate against request requirements)
475            let _capabilities = provider.capabilities();
476
477            let start = Instant::now();
478            let output = provider.call_async(request.args.clone()).await?;
479            let latency_ms = start.elapsed().as_millis() as u64;
480            validate_provider_output(&provider.schema().output_schema, &output)?;
481
482            Ok(ToolResponse {
483                outputs: output,
484                latency_ms,
485            })
486        })
487    }
488}
489
490// ---------------------------------------------------------------------------
491// Tests
492// ---------------------------------------------------------------------------
493
494#[cfg(test)]
495mod tests {
496    use super::*;
497    use serde_json::json;
498    use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
499
500    // -- helpers ----------------------------------------------------------
501
502    /// A simple in-memory provider for testing.
503    struct EchoProvider {
504        provider_name: String,
505        provider_version: String,
506        schema: ToolSchema,
507    }
508
509    impl EchoProvider {
510        fn new(name: &str) -> Self {
511            Self {
512                provider_name: name.to_string(),
513                provider_version: "1.0.0".to_string(),
514                schema: ToolSchema {
515                    name: name.to_string(),
516                    description: format!("Echo provider: {}", name),
517                    input_schema: json!({"type": "object"}),
518                    output_schema: json!({"type": "object"}),
519                    effects: vec!["echo".to_string()],
520                },
521            }
522        }
523    }
524
525    impl ToolProvider for EchoProvider {
526        fn name(&self) -> &str {
527            &self.provider_name
528        }
529        fn version(&self) -> &str {
530            &self.provider_version
531        }
532        fn schema(&self) -> &ToolSchema {
533            &self.schema
534        }
535        fn call(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
536            Ok(json!({ "echo": input }))
537        }
538    }
539
540    /// A provider that always fails.
541    struct FailingProvider;
542
543    impl ToolProvider for FailingProvider {
544        fn name(&self) -> &str {
545            "failing"
546        }
547        fn version(&self) -> &str {
548            "0.1.0"
549        }
550        fn schema(&self) -> &ToolSchema {
551            // Leak a static schema for testing convenience.
552            // (Tests don't care about the small allocation.)
553            Box::leak(Box::new(ToolSchema {
554                name: "failing".to_string(),
555                description: "Always fails".to_string(),
556                input_schema: json!({}),
557                output_schema: json!({}),
558                effects: vec![],
559            }))
560        }
561        fn call(&self, _input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
562            Err(ToolError::InvocationFailed("intentional failure".into()))
563        }
564    }
565
566    /// Provider with distinct sync/async behavior so tests can verify dispatch path.
567    struct DualPathProvider {
568        schema: ToolSchema,
569    }
570
571    impl DualPathProvider {
572        fn new(name: &str) -> Self {
573            Self {
574                schema: ToolSchema {
575                    name: name.to_string(),
576                    description: "Provider with distinct sync/async responses".to_string(),
577                    input_schema: json!({"type": "object"}),
578                    output_schema: json!({"type": "object"}),
579                    effects: vec!["test".to_string()],
580                },
581            }
582        }
583    }
584
585    impl ToolProvider for DualPathProvider {
586        fn name(&self) -> &str {
587            &self.schema.name
588        }
589        fn version(&self) -> &str {
590            "1.0.0"
591        }
592        fn schema(&self) -> &ToolSchema {
593            &self.schema
594        }
595        fn call(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
596            Ok(json!({ "path": "sync", "echo": input }))
597        }
598        fn call_async<'a>(&'a self, input: serde_json::Value) -> ToolFuture<'a, serde_json::Value> {
599            Box::pin(async move { Ok(json!({ "path": "async", "echo": input })) })
600        }
601    }
602
603    struct UnionOutputProvider {
604        schema: ToolSchema,
605        output: serde_json::Value,
606    }
607
608    impl UnionOutputProvider {
609        fn new(name: &str, output_schema: serde_json::Value, output: serde_json::Value) -> Self {
610            Self {
611                schema: ToolSchema {
612                    name: name.to_string(),
613                    description: "Provider used for schema validation tests".to_string(),
614                    input_schema: json!({"type": "object"}),
615                    output_schema,
616                    effects: vec!["test".to_string()],
617                },
618                output,
619            }
620        }
621    }
622
623    impl ToolProvider for UnionOutputProvider {
624        fn name(&self) -> &str {
625            &self.schema.name
626        }
627        fn version(&self) -> &str {
628            "1.0.0"
629        }
630        fn schema(&self) -> &ToolSchema {
631            &self.schema
632        }
633        fn call(&self, _input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
634            Ok(self.output.clone())
635        }
636        fn call_async<'a>(
637            &'a self,
638            _input: serde_json::Value,
639        ) -> ToolFuture<'a, serde_json::Value> {
640            let output = self.output.clone();
641            Box::pin(async move { Ok(output) })
642        }
643    }
644
645    fn noop_waker() -> Waker {
646        fn clone(_: *const ()) -> RawWaker {
647            RawWaker::new(std::ptr::null(), &VTABLE)
648        }
649        fn wake(_: *const ()) {}
650        fn wake_by_ref(_: *const ()) {}
651        fn drop(_: *const ()) {}
652
653        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);
654        // SAFETY: The no-op vtable never dereferences the data pointer.
655        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
656    }
657
658    fn block_on<F: Future>(future: F) -> F::Output {
659        let mut future = Box::pin(future);
660        let waker = noop_waker();
661        let mut cx = Context::from_waker(&waker);
662
663        loop {
664            match future.as_mut().poll(&mut cx) {
665                Poll::Ready(output) => return output,
666                Poll::Pending => std::thread::yield_now(),
667            }
668        }
669    }
670
671    // -- ToolProvider trait ------------------------------------------------
672
673    #[test]
674    fn echo_provider_returns_wrapped_input() {
675        let provider = EchoProvider::new("test_echo");
676        let result = provider.call(json!({"x": 1})).unwrap();
677        assert_eq!(result, json!({"echo": {"x": 1}}));
678    }
679
680    #[test]
681    fn provider_metadata_accessors() {
682        let provider = EchoProvider::new("my_tool");
683        assert_eq!(provider.name(), "my_tool");
684        assert_eq!(provider.version(), "1.0.0");
685        assert_eq!(provider.schema().name, "my_tool");
686        assert_eq!(provider.effects(), vec!["echo".to_string()]);
687    }
688
689    // -- NullProvider -----------------------------------------------------
690
691    #[test]
692    fn null_provider_returns_not_registered() {
693        let null = NullProvider::new("missing_tool");
694        assert_eq!(null.name(), "missing_tool");
695        assert_eq!(null.version(), "0.0.0");
696        let err = null.call(json!({})).unwrap_err();
697        match err {
698            ToolError::NotRegistered(name) => assert_eq!(name, "missing_tool"),
699            other => panic!("expected NotRegistered, got: {}", other),
700        }
701    }
702
703    #[test]
704    fn null_provider_schema_describes_unregistered() {
705        let null = NullProvider::new("xyz");
706        let schema = null.schema();
707        assert_eq!(schema.name, "xyz");
708        assert!(schema.description.contains("Unregistered"));
709        assert_eq!(schema.input_schema, serde_json::Value::Null);
710        assert!(schema.effects.is_empty());
711    }
712
713    // -- ProviderRegistry -------------------------------------------------
714
715    #[test]
716    fn registry_starts_empty() {
717        let reg = ProviderRegistry::new();
718        assert!(reg.is_empty());
719        assert_eq!(reg.len(), 0);
720        assert!(reg.list().is_empty());
721    }
722
723    #[test]
724    fn registry_register_and_get() {
725        let mut reg = ProviderRegistry::new();
726        reg.register("echo", Box::new(EchoProvider::new("echo")));
727        assert!(reg.has("echo"));
728        assert!(!reg.has("other"));
729        assert_eq!(reg.len(), 1);
730
731        let provider = reg.get("echo").unwrap();
732        assert_eq!(provider.name(), "echo");
733    }
734
735    #[test]
736    fn registry_get_missing_returns_none() {
737        let reg = ProviderRegistry::new();
738        assert!(reg.get("nonexistent").is_none());
739    }
740
741    #[test]
742    fn registry_list_returns_sorted_names() {
743        let mut reg = ProviderRegistry::new();
744        reg.register("zebra", Box::new(EchoProvider::new("zebra")));
745        reg.register("alpha", Box::new(EchoProvider::new("alpha")));
746        reg.register("mid", Box::new(EchoProvider::new("mid")));
747        assert_eq!(reg.list(), vec!["alpha", "mid", "zebra"]);
748    }
749
750    #[test]
751    fn registry_replace_existing_provider() {
752        let mut reg = ProviderRegistry::new();
753        reg.register("tool", Box::new(EchoProvider::new("v1")));
754        assert_eq!(reg.get("tool").unwrap().name(), "v1");
755
756        reg.register("tool", Box::new(EchoProvider::new("v2")));
757        assert_eq!(reg.get("tool").unwrap().name(), "v2");
758        assert_eq!(reg.len(), 1);
759    }
760
761    #[test]
762    fn registry_unregister() {
763        let mut reg = ProviderRegistry::new();
764        reg.register("tool", Box::new(EchoProvider::new("tool")));
765        assert!(reg.unregister("tool"));
766        assert!(!reg.has("tool"));
767        assert!(!reg.unregister("tool")); // second time returns false
768    }
769
770    #[test]
771    fn registry_multiple_providers() {
772        let mut reg = ProviderRegistry::new();
773        reg.register("a", Box::new(EchoProvider::new("a")));
774        reg.register("b", Box::new(EchoProvider::new("b")));
775        reg.register("c", Box::new(EchoProvider::new("c")));
776        assert_eq!(reg.len(), 3);
777        assert!(reg.has("a"));
778        assert!(reg.has("b"));
779        assert!(reg.has("c"));
780    }
781
782    // -- ProviderRegistry as ToolDispatcher --------------------------------
783
784    #[test]
785    fn registry_dispatches_to_registered_provider() {
786        let mut reg = ProviderRegistry::new();
787        reg.register("echo", Box::new(EchoProvider::new("echo")));
788
789        let request = ToolRequest {
790            tool_id: "echo".to_string(),
791            version: "1.0.0".to_string(),
792            args: json!({"hello": "world"}),
793            policy: json!({}),
794        };
795        let response = reg.dispatch(&request).unwrap();
796        assert_eq!(response.outputs, json!({"echo": {"hello": "world"}}));
797    }
798
799    #[test]
800    fn registry_dispatch_missing_tool_returns_not_registered() {
801        let reg = ProviderRegistry::new();
802        let request = ToolRequest {
803            tool_id: "missing".to_string(),
804            version: "".to_string(),
805            args: json!({}),
806            policy: json!({}),
807        };
808        let err = reg.dispatch(&request).unwrap_err();
809        match err {
810            ToolError::NotRegistered(name) => assert_eq!(name, "missing"),
811            other => panic!("expected NotRegistered, got: {}", other),
812        }
813    }
814
815    #[test]
816    fn registry_dispatch_propagates_provider_error() {
817        let mut reg = ProviderRegistry::new();
818        reg.register("fail", Box::new(FailingProvider));
819
820        let request = ToolRequest {
821            tool_id: "fail".to_string(),
822            version: "".to_string(),
823            args: json!({}),
824            policy: json!({}),
825        };
826        let err = reg.dispatch(&request).unwrap_err();
827        match err {
828            ToolError::InvocationFailed(msg) => assert!(msg.contains("intentional")),
829            other => panic!("expected InvocationFailed, got: {}", other),
830        }
831    }
832
833    #[test]
834    fn registry_dispatch_measures_latency() {
835        let mut reg = ProviderRegistry::new();
836        reg.register("echo", Box::new(EchoProvider::new("echo")));
837
838        let request = ToolRequest {
839            tool_id: "echo".to_string(),
840            version: "1.0.0".to_string(),
841            args: json!({}),
842            policy: json!({}),
843        };
844        let response = reg.dispatch(&request).unwrap();
845        // Latency should be very small but non-negative
846        assert!(response.latency_ms < 1000);
847    }
848
849    #[test]
850    fn registry_dispatch_rejects_schema_mismatch_output() {
851        let mut reg = ProviderRegistry::new();
852        reg.register(
853            "bad",
854            Box::new(UnionOutputProvider::new(
855                "bad",
856                json!({"type": "string"}),
857                json!({"status": "wrong-shape"}),
858            )),
859        );
860
861        let request = ToolRequest {
862            tool_id: "bad".to_string(),
863            version: "1.0.0".to_string(),
864            args: json!({}),
865            policy: json!({}),
866        };
867
868        let err = reg
869            .dispatch(&request)
870            .expect_err("schema mismatch should fail");
871        match err {
872            ToolError::OutputValidationFailed {
873                expected_schema,
874                actual,
875            } => {
876                assert!(expected_schema.contains("\"string\""));
877                assert!(actual.contains("wrong-shape"));
878            }
879            other => panic!("expected OutputValidationFailed, got: {other}"),
880        }
881    }
882
883    #[test]
884    fn registry_dispatch_async_rejects_schema_mismatch_output() {
885        let mut reg = ProviderRegistry::new();
886        reg.register(
887            "bad_async",
888            Box::new(UnionOutputProvider::new(
889                "bad_async",
890                json!({"type": "object", "required": ["ok"]}),
891                json!({"missing_ok": true}),
892            )),
893        );
894
895        let request = ToolRequest {
896            tool_id: "bad_async".to_string(),
897            version: "1.0.0".to_string(),
898            args: json!({}),
899            policy: json!({}),
900        };
901
902        let err = block_on(reg.dispatch_async(&request)).expect_err("schema mismatch should fail");
903        match err {
904            ToolError::OutputValidationFailed { actual, .. } => {
905                assert!(actual.contains("missing required property"));
906            }
907            other => panic!("expected OutputValidationFailed, got: {other}"),
908        }
909    }
910
911    #[test]
912    fn registry_dispatch_accepts_union_output_type() {
913        let mut reg = ProviderRegistry::new();
914        reg.register(
915            "union",
916            Box::new(UnionOutputProvider::new(
917                "union",
918                json!({"type": ["object", "string", "null"]}),
919                json!("ok"),
920            )),
921        );
922
923        let request = ToolRequest {
924            tool_id: "union".to_string(),
925            version: "1.0.0".to_string(),
926            args: json!({}),
927            policy: json!({}),
928        };
929
930        let response = reg
931            .dispatch(&request)
932            .expect("union schema should validate");
933        assert_eq!(response.outputs, json!("ok"));
934    }
935
936    #[test]
937    fn registry_dispatch_async_uses_provider_async_path() {
938        let mut reg = ProviderRegistry::new();
939        reg.register("dual", Box::new(DualPathProvider::new("dual")));
940
941        let request = ToolRequest {
942            tool_id: "dual".to_string(),
943            version: "1.0.0".to_string(),
944            args: json!({"hello": "world"}),
945            policy: json!({}),
946        };
947
948        let sync_response = reg.dispatch(&request).unwrap();
949        assert_eq!(
950            sync_response.outputs,
951            json!({"path": "sync", "echo": {"hello": "world"}})
952        );
953
954        let async_response = block_on(reg.dispatch_async(&request)).unwrap();
955        assert_eq!(
956            async_response.outputs,
957            json!({"path": "async", "echo": {"hello": "world"}})
958        );
959    }
960
961    #[test]
962    fn registry_dispatch_async_missing_tool_returns_not_registered() {
963        let reg = ProviderRegistry::new();
964        let request = ToolRequest {
965            tool_id: "missing".to_string(),
966            version: "".to_string(),
967            args: json!({}),
968            policy: json!({}),
969        };
970
971        let err = block_on(reg.dispatch_async(&request)).unwrap_err();
972        match err {
973            ToolError::NotRegistered(name) => assert_eq!(name, "missing"),
974            other => panic!("expected NotRegistered, got: {}", other),
975        }
976    }
977
978    #[test]
979    fn registry_dispatch_async_propagates_provider_error() {
980        let mut reg = ProviderRegistry::new();
981        reg.register("fail", Box::new(FailingProvider));
982
983        let request = ToolRequest {
984            tool_id: "fail".to_string(),
985            version: "".to_string(),
986            args: json!({}),
987            policy: json!({}),
988        };
989
990        let err = block_on(reg.dispatch_async(&request)).unwrap_err();
991        match err {
992            ToolError::InvocationFailed(msg) => assert!(msg.contains("intentional")),
993            other => panic!("expected InvocationFailed, got: {}", other),
994        }
995    }
996
997    // -- Provider schema access -------------------------------------------
998
999    #[test]
1000    fn provider_schema_round_trip() {
1001        let provider = EchoProvider::new("my_tool");
1002        let schema = provider.schema();
1003        assert_eq!(schema.name, "my_tool");
1004        assert_eq!(schema.description, "Echo provider: my_tool");
1005        assert_eq!(schema.input_schema, json!({"type": "object"}));
1006        assert_eq!(schema.output_schema, json!({"type": "object"}));
1007        assert_eq!(schema.effects, vec!["echo"]);
1008    }
1009
1010    #[test]
1011    fn schema_serialization() {
1012        let schema = ToolSchema {
1013            name: "test".to_string(),
1014            description: "A test tool".to_string(),
1015            input_schema: json!({"type": "string"}),
1016            output_schema: json!({"type": "number"}),
1017            effects: vec!["io".to_string()],
1018        };
1019        let json_str = serde_json::to_string(&schema).unwrap();
1020        let roundtrip: ToolSchema = serde_json::from_str(&json_str).unwrap();
1021        assert_eq!(roundtrip.name, "test");
1022        assert_eq!(roundtrip.effects, vec!["io"]);
1023    }
1024
1025    // -- Default trait impl -----------------------------------------------
1026
1027    #[test]
1028    fn registry_default_is_empty() {
1029        let reg = ProviderRegistry::default();
1030        assert!(reg.is_empty());
1031    }
1032
1033    // -- StubDispatcher (pre-existing, verify still works) -----------------
1034
1035    #[test]
1036    fn stub_dispatcher_returns_configured_response() {
1037        let mut stub = StubDispatcher::new();
1038        stub.set_response("tool_a", json!({"ok": true}));
1039        let req = ToolRequest {
1040            tool_id: "tool_a".to_string(),
1041            version: "1".to_string(),
1042            args: json!({}),
1043            policy: json!({}),
1044        };
1045        let resp = stub.dispatch(&req).unwrap();
1046        assert_eq!(resp.outputs, json!({"ok": true}));
1047        assert_eq!(resp.latency_ms, 0);
1048    }
1049
1050    #[test]
1051    fn stub_dispatcher_returns_not_found_for_unknown() {
1052        let stub = StubDispatcher::new();
1053        let req = ToolRequest {
1054            tool_id: "unknown".to_string(),
1055            version: "".to_string(),
1056            args: json!({}),
1057            policy: json!({}),
1058        };
1059        let err = stub.dispatch(&req).unwrap_err();
1060        match err {
1061            ToolError::NotFound(name) => assert_eq!(name, "unknown"),
1062            other => panic!("expected NotFound, got: {}", other),
1063        }
1064    }
1065
1066    // -- Error type tests --------------------------------------------------
1067
1068    #[test]
1069    fn error_rate_limit_with_retry_after() {
1070        let err = ToolError::RateLimit {
1071            retry_after_ms: Some(5000),
1072            message: "Too many requests".to_string(),
1073        };
1074        let err_str = err.to_string();
1075        assert!(err_str.contains("rate limit exceeded"));
1076        assert!(err_str.contains("Too many requests"));
1077    }
1078
1079    #[test]
1080    fn error_rate_limit_without_retry_after() {
1081        let err = ToolError::RateLimit {
1082            retry_after_ms: None,
1083            message: "Rate limited".to_string(),
1084        };
1085        assert!(err.to_string().contains("Rate limited"));
1086    }
1087
1088    #[test]
1089    fn error_auth_error() {
1090        let err = ToolError::AuthError {
1091            message: "Invalid API key".to_string(),
1092        };
1093        assert!(err.to_string().contains("authentication failed"));
1094        assert!(err.to_string().contains("Invalid API key"));
1095    }
1096
1097    #[test]
1098    fn error_model_not_found() {
1099        let err = ToolError::ModelNotFound {
1100            model: "gpt-5".to_string(),
1101            provider: "openai".to_string(),
1102        };
1103        let err_str = err.to_string();
1104        assert!(err_str.contains("model not found"));
1105        assert!(err_str.contains("gpt-5"));
1106        assert!(err_str.contains("openai"));
1107    }
1108
1109    #[test]
1110    fn error_timeout() {
1111        let err = ToolError::Timeout {
1112            elapsed_ms: 35000,
1113            limit_ms: 30000,
1114        };
1115        let err_str = err.to_string();
1116        assert!(err_str.contains("timeout"));
1117        assert!(err_str.contains("35000"));
1118        assert!(err_str.contains("30000"));
1119    }
1120
1121    #[test]
1122    fn error_provider_unavailable() {
1123        let err = ToolError::ProviderUnavailable {
1124            provider: "gemini".to_string(),
1125            reason: "Service under maintenance".to_string(),
1126        };
1127        let err_str = err.to_string();
1128        assert!(err_str.contains("provider unavailable"));
1129        assert!(err_str.contains("gemini"));
1130        assert!(err_str.contains("Service under maintenance"));
1131    }
1132
1133    #[test]
1134    fn error_output_validation_failed() {
1135        let err = ToolError::OutputValidationFailed {
1136            expected_schema: r#"{"type": "string"}"#.to_string(),
1137            actual: r#"{"value": 123}"#.to_string(),
1138        };
1139        let err_str = err.to_string();
1140        assert!(err_str.contains("output validation failed"));
1141        assert!(err_str.contains("expected"));
1142        assert!(err_str.contains("got"));
1143    }
1144
1145    #[test]
1146    fn error_invalid_args() {
1147        let err = ToolError::InvalidArgs("Missing required field 'prompt'".to_string());
1148        assert!(err.to_string().contains("invalid arguments"));
1149        assert!(err.to_string().contains("Missing required field"));
1150    }
1151
1152    #[test]
1153    fn error_execution_failed() {
1154        let err = ToolError::ExecutionFailed("Network timeout".to_string());
1155        assert!(err.to_string().contains("execution failed"));
1156        assert!(err.to_string().contains("Network timeout"));
1157    }
1158
1159    // -- Capability tests --------------------------------------------------
1160
1161    #[test]
1162    fn capability_equality() {
1163        assert_eq!(Capability::TextGeneration, Capability::TextGeneration);
1164        assert_ne!(Capability::Chat, Capability::Embedding);
1165    }
1166
1167    #[test]
1168    fn capability_in_vector() {
1169        let caps = [
1170            Capability::Chat,
1171            Capability::TextGeneration,
1172            Capability::Vision,
1173        ];
1174        assert!(caps.contains(&Capability::Chat));
1175        assert!(caps.contains(&Capability::Vision));
1176        assert!(!caps.contains(&Capability::Streaming));
1177    }
1178
1179    // -- RetryPolicy tests -------------------------------------------------
1180
1181    #[test]
1182    fn retry_policy_default() {
1183        let policy = RetryPolicy::default();
1184        assert_eq!(policy.max_retries, 3);
1185        assert_eq!(policy.base_delay_ms, 100);
1186        assert_eq!(policy.max_delay_ms, 10_000);
1187    }
1188
1189    #[test]
1190    fn retry_policy_custom() {
1191        let policy = RetryPolicy {
1192            max_retries: 5,
1193            base_delay_ms: 200,
1194            max_delay_ms: 30_000,
1195        };
1196        assert_eq!(policy.max_retries, 5);
1197        assert_eq!(policy.base_delay_ms, 200);
1198        assert_eq!(policy.max_delay_ms, 30_000);
1199    }
1200}