1use serde::{Deserialize, Serialize};
11use std::{collections::HashMap, future::Future, pin::Pin, time::Instant};
12use thiserror::Error;
13
14#[derive(Debug, Error)]
19pub enum ToolError {
20 #[error("tool not found: {0}")]
21 NotFound(String),
22 #[error("invalid arguments: {0}")]
23 InvalidArgs(String),
24 #[error("tool execution failed: {0}")]
25 ExecutionFailed(String),
26 #[error("policy violation: {0}")]
27 PolicyViolation(String),
28 #[error("rate limit exceeded: {message}")]
29 RateLimit {
30 retry_after_ms: Option<u64>,
31 message: String,
32 },
33 #[error("authentication failed: {message}")]
34 AuthError { message: String },
35 #[error("model not found: {model} (provider: {provider})")]
36 ModelNotFound { model: String, provider: String },
37 #[error("timeout: elapsed {elapsed_ms}ms, limit {limit_ms}ms")]
38 Timeout { elapsed_ms: u64, limit_ms: u64 },
39 #[error("provider unavailable: {provider} ({reason})")]
40 ProviderUnavailable { provider: String, reason: String },
41 #[error("output validation failed: expected {expected_schema}, got {actual}")]
42 OutputValidationFailed {
43 expected_schema: String,
44 actual: String,
45 },
46 #[error("provider not registered: {0}")]
47 NotRegistered(String),
48 #[error("tool invocation failed: {0}")]
50 InvocationFailed(String),
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ToolRequest {
59 pub tool_id: String,
60 pub version: String,
61 pub args: serde_json::Value,
62 pub policy: serde_json::Value,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ToolResponse {
67 pub outputs: serde_json::Value,
68 pub latency_ms: u64,
69}
70
71pub type ToolFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ToolError>> + Send + 'a>>;
73
74pub trait ToolDispatcher: Send + Sync {
76 fn dispatch(&self, request: &ToolRequest) -> Result<ToolResponse, ToolError>;
77
78 fn dispatch_async<'a>(&'a self, request: &'a ToolRequest) -> ToolFuture<'a, ToolResponse> {
83 Box::pin(async move { self.dispatch(request) })
84 }
85}
86
87#[derive(Default)]
89pub struct StubDispatcher {
90 responses: HashMap<String, serde_json::Value>,
91}
92
93impl StubDispatcher {
94 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub fn set_response(&mut self, tool_id: &str, response: serde_json::Value) {
99 self.responses.insert(tool_id.to_string(), response);
100 }
101}
102
103impl ToolDispatcher for StubDispatcher {
104 fn dispatch(&self, request: &ToolRequest) -> Result<ToolResponse, ToolError> {
105 if let Some(response) = self.responses.get(&request.tool_id) {
106 Ok(ToolResponse {
107 outputs: response.clone(),
108 latency_ms: 0,
109 })
110 } else {
111 Err(ToolError::NotFound(request.tool_id.clone()))
112 }
113 }
114}
115
/// Retry/backoff parameters for tool calls.
///
/// This type only carries the configuration; nothing in this module applies it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Base delay between attempts, in milliseconds.
    pub base_delay_ms: u64,
    /// Upper bound on the delay between attempts, in milliseconds.
    pub max_delay_ms: u64,
}

impl Default for RetryPolicy {
    /// 3 retries, 100 ms base delay, 10 000 ms maximum delay.
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
        }
    }
}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct ToolSchema {
141 pub name: String,
142 pub description: String,
143 pub input_schema: serde_json::Value,
145 pub output_schema: serde_json::Value,
147 pub effects: Vec<String>,
149}
150
/// Features a [`ToolProvider`] may advertise through `ToolProvider::capabilities`.
///
/// Variant order is significant for the derived `Hash`, so new variants belong at the end.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Capability {
    /// Text generation.
    TextGeneration,
    /// Chat-style interaction.
    Chat,
    /// Embedding generation.
    Embedding,
    /// Image/vision input.
    Vision,
    /// Tool use.
    ToolUse,
    /// Structured output.
    StructuredOutput,
    /// Streaming responses.
    Streaming,
}
162
163pub trait ToolProvider: Send + Sync {
166 fn name(&self) -> &str;
168
169 fn version(&self) -> &str;
171
172 fn schema(&self) -> &ToolSchema;
174
175 fn call(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError>;
177
178 fn call_async<'a>(&'a self, input: serde_json::Value) -> ToolFuture<'a, serde_json::Value> {
183 Box::pin(async move { self.call(input) })
184 }
185
186 fn effects(&self) -> Vec<String> {
188 self.schema().effects.clone()
189 }
190
191 fn capabilities(&self) -> Vec<Capability> {
193 vec![]
194 }
195}
196
197pub struct NullProvider {
204 tool_name: String,
205 schema: ToolSchema,
206}
207
208impl NullProvider {
209 pub fn new(tool_name: &str) -> Self {
210 Self {
211 tool_name: tool_name.to_string(),
212 schema: ToolSchema {
213 name: tool_name.to_string(),
214 description: format!("Unregistered tool: {}", tool_name),
215 input_schema: serde_json::Value::Null,
216 output_schema: serde_json::Value::Null,
217 effects: vec![],
218 },
219 }
220 }
221}
222
223impl ToolProvider for NullProvider {
224 fn name(&self) -> &str {
225 &self.tool_name
226 }
227
228 fn version(&self) -> &str {
229 "0.0.0"
230 }
231
232 fn schema(&self) -> &ToolSchema {
233 &self.schema
234 }
235
236 fn call(&self, _input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
237 Err(ToolError::NotRegistered(self.tool_name.clone()))
238 }
239}
240
241pub struct ProviderRegistry {
248 providers: HashMap<String, Box<dyn ToolProvider>>,
249}
250
251impl ProviderRegistry {
252 pub fn new() -> Self {
253 Self {
254 providers: HashMap::new(),
255 }
256 }
257
258 pub fn register(&mut self, name: &str, provider: Box<dyn ToolProvider>) {
260 self.providers.insert(name.to_string(), provider);
261 }
262
263 pub fn get(&self, name: &str) -> Option<&dyn ToolProvider> {
265 self.providers.get(name).map(|p| p.as_ref())
266 }
267
268 pub fn list(&self) -> Vec<&str> {
270 let mut names: Vec<&str> = self.providers.keys().map(|s| s.as_str()).collect();
271 names.sort();
272 names
273 }
274
275 pub fn has(&self, name: &str) -> bool {
277 self.providers.contains_key(name)
278 }
279
280 pub fn unregister(&mut self, name: &str) -> bool {
282 self.providers.remove(name).is_some()
283 }
284
285 pub fn len(&self) -> usize {
287 self.providers.len()
288 }
289
290 pub fn is_empty(&self) -> bool {
292 self.providers.is_empty()
293 }
294}
295
296impl Default for ProviderRegistry {
297 fn default() -> Self {
298 Self::new()
299 }
300}
301
302fn validate_provider_output(
303 schema: &serde_json::Value,
304 output: &serde_json::Value,
305) -> Result<(), ToolError> {
306 if let Err(reason) = validate_schema_value(schema, output, "$") {
307 let expected_schema = serde_json::to_string(schema).unwrap_or_else(|_| "<schema>".into());
308 let actual_output = serde_json::to_string(output).unwrap_or_else(|_| "<output>".into());
309 return Err(ToolError::OutputValidationFailed {
310 expected_schema,
311 actual: format!("{actual_output} ({reason})"),
312 });
313 }
314 Ok(())
315}
316
317fn validate_schema_value(
318 schema: &serde_json::Value,
319 value: &serde_json::Value,
320 path: &str,
321) -> Result<(), String> {
322 let schema_obj = match schema {
323 serde_json::Value::Null => return Ok(()),
324 serde_json::Value::Bool(true) => return Ok(()),
325 serde_json::Value::Bool(false) => {
326 return Err(format!("{path}: schema is false"));
327 }
328 serde_json::Value::Object(map) if map.is_empty() => return Ok(()),
329 serde_json::Value::Object(map) => map,
330 _ => return Ok(()),
331 };
332
333 if let Some(const_value) = schema_obj.get("const") {
334 if const_value != value {
335 return Err(format!("{path}: value does not match const"));
336 }
337 }
338
339 if let Some(enum_values) = schema_obj.get("enum").and_then(|v| v.as_array()) {
340 if !enum_values.iter().any(|candidate| candidate == value) {
341 return Err(format!("{path}: value is not in enum"));
342 }
343 }
344
345 if let Some(type_decl) = schema_obj.get("type") {
346 let type_matches = match type_decl {
347 serde_json::Value::String(expected) => value_matches_type(value, expected),
348 serde_json::Value::Array(candidates) => candidates
349 .iter()
350 .filter_map(|candidate| candidate.as_str())
351 .any(|expected| value_matches_type(value, expected)),
352 _ => true,
353 };
354 if !type_matches {
355 return Err(format!(
356 "{path}: expected type {}, got {}",
357 type_decl,
358 value_type_name(value)
359 ));
360 }
361 }
362
363 if let Some(obj) = value.as_object() {
364 if let Some(required_fields) = schema_obj.get("required").and_then(|v| v.as_array()) {
365 for required in required_fields.iter().filter_map(|field| field.as_str()) {
366 if !obj.contains_key(required) {
367 return Err(format!("{path}: missing required property '{required}'"));
368 }
369 }
370 }
371
372 let props = schema_obj.get("properties").and_then(|v| v.as_object());
373 if let Some(props) = props {
374 for (name, prop_schema) in props {
375 if let Some(prop_value) = obj.get(name) {
376 let prop_path = format!("{path}.{name}");
377 validate_schema_value(prop_schema, prop_value, &prop_path)?;
378 }
379 }
380 }
381
382 if let Some(additional) = schema_obj.get("additionalProperties") {
383 for (key, extra_value) in obj {
384 if props.is_some_and(|p| p.contains_key(key)) {
385 continue;
386 }
387 match additional {
388 serde_json::Value::Bool(true) => {}
389 serde_json::Value::Bool(false) => {
390 return Err(format!(
391 "{path}: additional property '{key}' is not allowed"
392 ));
393 }
394 schema => {
395 let prop_path = format!("{path}.{key}");
396 validate_schema_value(schema, extra_value, &prop_path)?;
397 }
398 }
399 }
400 }
401 }
402
403 if let (Some(items_schema), Some(items)) = (schema_obj.get("items"), value.as_array()) {
404 for (index, item) in items.iter().enumerate() {
405 let item_path = format!("{path}[{index}]");
406 validate_schema_value(items_schema, item, &item_path)?;
407 }
408 }
409
410 Ok(())
411}
412
413fn value_type_name(value: &serde_json::Value) -> &'static str {
414 match value {
415 serde_json::Value::Null => "null",
416 serde_json::Value::Bool(_) => "boolean",
417 serde_json::Value::Number(number) => {
418 if number.is_i64() || number.is_u64() {
419 "integer"
420 } else {
421 "number"
422 }
423 }
424 serde_json::Value::String(_) => "string",
425 serde_json::Value::Array(_) => "array",
426 serde_json::Value::Object(_) => "object",
427 }
428}
429
430fn value_matches_type(value: &serde_json::Value, expected_type: &str) -> bool {
431 match expected_type {
432 "null" => value.is_null(),
433 "boolean" => value.is_boolean(),
434 "number" => value.is_number(),
435 "integer" => value.as_i64().is_some() || value.as_u64().is_some(),
436 "string" => value.is_string(),
437 "array" => value.is_array(),
438 "object" => value.is_object(),
439 _ => true,
440 }
441}
442
443impl ToolDispatcher for ProviderRegistry {
447 fn dispatch(&self, request: &ToolRequest) -> Result<ToolResponse, ToolError> {
448 let provider = self
449 .providers
450 .get(&request.tool_id)
451 .ok_or_else(|| ToolError::NotRegistered(request.tool_id.clone()))?;
452
453 let _capabilities = provider.capabilities();
455
456 let start = Instant::now();
457 let output = provider.call(request.args.clone())?;
458 let latency_ms = start.elapsed().as_millis() as u64;
459 validate_provider_output(&provider.schema().output_schema, &output)?;
460
461 Ok(ToolResponse {
462 outputs: output,
463 latency_ms,
464 })
465 }
466
467 fn dispatch_async<'a>(&'a self, request: &'a ToolRequest) -> ToolFuture<'a, ToolResponse> {
468 Box::pin(async move {
469 let provider = self
470 .providers
471 .get(&request.tool_id)
472 .ok_or_else(|| ToolError::NotRegistered(request.tool_id.clone()))?;
473
474 let _capabilities = provider.capabilities();
476
477 let start = Instant::now();
478 let output = provider.call_async(request.args.clone()).await?;
479 let latency_ms = start.elapsed().as_millis() as u64;
480 validate_provider_output(&provider.schema().output_schema, &output)?;
481
482 Ok(ToolResponse {
483 outputs: output,
484 latency_ms,
485 })
486 })
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Arc, OnceLock};
    use std::task::{Context, Poll, Wake, Waker};

    /// Provider that wraps its input as `{"echo": input}`.
    struct EchoProvider {
        provider_name: String,
        provider_version: String,
        schema: ToolSchema,
    }

    impl EchoProvider {
        fn new(name: &str) -> Self {
            Self {
                provider_name: name.to_string(),
                provider_version: "1.0.0".to_string(),
                schema: ToolSchema {
                    name: name.to_string(),
                    description: format!("Echo provider: {}", name),
                    input_schema: json!({"type": "object"}),
                    output_schema: json!({"type": "object"}),
                    effects: vec!["echo".to_string()],
                },
            }
        }
    }

    impl ToolProvider for EchoProvider {
        fn name(&self) -> &str {
            &self.provider_name
        }
        fn version(&self) -> &str {
            &self.provider_version
        }
        fn schema(&self) -> &ToolSchema {
            &self.schema
        }
        fn call(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(json!({ "echo": input }))
        }
    }

    /// Provider whose `call` always fails with `InvocationFailed`.
    struct FailingProvider;

    impl ToolProvider for FailingProvider {
        fn name(&self) -> &str {
            "failing"
        }
        fn version(&self) -> &str {
            "0.1.0"
        }
        fn schema(&self) -> &ToolSchema {
            // A unit struct has nowhere to store its schema, so build it once and share
            // it; repeated calls neither allocate nor leak.
            static SCHEMA: OnceLock<ToolSchema> = OnceLock::new();
            SCHEMA.get_or_init(|| ToolSchema {
                name: "failing".to_string(),
                description: "Always fails".to_string(),
                input_schema: json!({}),
                output_schema: json!({}),
                effects: vec![],
            })
        }
        fn call(&self, _input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Err(ToolError::InvocationFailed("intentional failure".into()))
        }
    }

    /// Provider whose sync and async paths return distinguishable outputs.
    struct DualPathProvider {
        schema: ToolSchema,
    }

    impl DualPathProvider {
        fn new(name: &str) -> Self {
            Self {
                schema: ToolSchema {
                    name: name.to_string(),
                    description: "Provider with distinct sync/async responses".to_string(),
                    input_schema: json!({"type": "object"}),
                    output_schema: json!({"type": "object"}),
                    effects: vec!["test".to_string()],
                },
            }
        }
    }

    impl ToolProvider for DualPathProvider {
        fn name(&self) -> &str {
            &self.schema.name
        }
        fn version(&self) -> &str {
            "1.0.0"
        }
        fn schema(&self) -> &ToolSchema {
            &self.schema
        }
        fn call(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(json!({ "path": "sync", "echo": input }))
        }
        fn call_async<'a>(&'a self, input: serde_json::Value) -> ToolFuture<'a, serde_json::Value> {
            Box::pin(async move { Ok(json!({ "path": "async", "echo": input })) })
        }
    }

    /// Provider returning a fixed output under a configurable output schema.
    struct UnionOutputProvider {
        schema: ToolSchema,
        output: serde_json::Value,
    }

    impl UnionOutputProvider {
        fn new(name: &str, output_schema: serde_json::Value, output: serde_json::Value) -> Self {
            Self {
                schema: ToolSchema {
                    name: name.to_string(),
                    description: "Provider used for schema validation tests".to_string(),
                    input_schema: json!({"type": "object"}),
                    output_schema,
                    effects: vec!["test".to_string()],
                },
                output,
            }
        }
    }

    impl ToolProvider for UnionOutputProvider {
        fn name(&self) -> &str {
            &self.schema.name
        }
        fn version(&self) -> &str {
            "1.0.0"
        }
        fn schema(&self) -> &ToolSchema {
            &self.schema
        }
        fn call(&self, _input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(self.output.clone())
        }
        fn call_async<'a>(
            &'a self,
            _input: serde_json::Value,
        ) -> ToolFuture<'a, serde_json::Value> {
            let output = self.output.clone();
            Box::pin(async move { Ok(output) })
        }
    }

    /// Waker that ignores wake-ups; `block_on` re-polls in a loop instead.
    struct NoopWake;

    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }

    fn noop_waker() -> Waker {
        Waker::from(Arc::new(NoopWake))
    }

    /// Minimal executor for tests: polls `future` until it completes, yielding the thread
    /// between polls.
    fn block_on<F: Future>(future: F) -> F::Output {
        let mut future = Box::pin(future);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(output) => return output,
                Poll::Pending => std::thread::yield_now(),
            }
        }
    }

    #[test]
    fn echo_provider_returns_wrapped_input() {
        let provider = EchoProvider::new("test_echo");
        let result = provider.call(json!({"x": 1})).unwrap();
        assert_eq!(result, json!({"echo": {"x": 1}}));
    }

    #[test]
    fn provider_metadata_accessors() {
        let provider = EchoProvider::new("my_tool");
        assert_eq!(provider.name(), "my_tool");
        assert_eq!(provider.version(), "1.0.0");
        assert_eq!(provider.schema().name, "my_tool");
        assert_eq!(provider.effects(), vec!["echo".to_string()]);
    }

    #[test]
    fn null_provider_returns_not_registered() {
        let null = NullProvider::new("missing_tool");
        assert_eq!(null.name(), "missing_tool");
        assert_eq!(null.version(), "0.0.0");
        let err = null.call(json!({})).unwrap_err();
        match err {
            ToolError::NotRegistered(name) => assert_eq!(name, "missing_tool"),
            other => panic!("expected NotRegistered, got: {}", other),
        }
    }

    #[test]
    fn null_provider_schema_describes_unregistered() {
        let null = NullProvider::new("xyz");
        let schema = null.schema();
        assert_eq!(schema.name, "xyz");
        assert!(schema.description.contains("Unregistered"));
        assert_eq!(schema.input_schema, serde_json::Value::Null);
        assert!(schema.effects.is_empty());
    }

    #[test]
    fn registry_starts_empty() {
        let reg = ProviderRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
        assert!(reg.list().is_empty());
    }

    #[test]
    fn registry_register_and_get() {
        let mut reg = ProviderRegistry::new();
        reg.register("echo", Box::new(EchoProvider::new("echo")));
        assert!(reg.has("echo"));
        assert!(!reg.has("other"));
        assert_eq!(reg.len(), 1);

        let provider = reg.get("echo").unwrap();
        assert_eq!(provider.name(), "echo");
    }

    #[test]
    fn registry_get_missing_returns_none() {
        let reg = ProviderRegistry::new();
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn registry_list_returns_sorted_names() {
        let mut reg = ProviderRegistry::new();
        reg.register("zebra", Box::new(EchoProvider::new("zebra")));
        reg.register("alpha", Box::new(EchoProvider::new("alpha")));
        reg.register("mid", Box::new(EchoProvider::new("mid")));
        assert_eq!(reg.list(), vec!["alpha", "mid", "zebra"]);
    }

    #[test]
    fn registry_replace_existing_provider() {
        let mut reg = ProviderRegistry::new();
        reg.register("tool", Box::new(EchoProvider::new("v1")));
        assert_eq!(reg.get("tool").unwrap().name(), "v1");

        reg.register("tool", Box::new(EchoProvider::new("v2")));
        assert_eq!(reg.get("tool").unwrap().name(), "v2");
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn registry_unregister() {
        let mut reg = ProviderRegistry::new();
        reg.register("tool", Box::new(EchoProvider::new("tool")));
        assert!(reg.unregister("tool"));
        assert!(!reg.has("tool"));
        assert!(!reg.unregister("tool"));
    }

    #[test]
    fn registry_multiple_providers() {
        let mut reg = ProviderRegistry::new();
        reg.register("a", Box::new(EchoProvider::new("a")));
        reg.register("b", Box::new(EchoProvider::new("b")));
        reg.register("c", Box::new(EchoProvider::new("c")));
        assert_eq!(reg.len(), 3);
        assert!(reg.has("a"));
        assert!(reg.has("b"));
        assert!(reg.has("c"));
    }

    #[test]
    fn registry_dispatches_to_registered_provider() {
        let mut reg = ProviderRegistry::new();
        reg.register("echo", Box::new(EchoProvider::new("echo")));

        let request = ToolRequest {
            tool_id: "echo".to_string(),
            version: "1.0.0".to_string(),
            args: json!({"hello": "world"}),
            policy: json!({}),
        };
        let response = reg.dispatch(&request).unwrap();
        assert_eq!(response.outputs, json!({"echo": {"hello": "world"}}));
    }

    #[test]
    fn registry_dispatch_missing_tool_returns_not_registered() {
        let reg = ProviderRegistry::new();
        let request = ToolRequest {
            tool_id: "missing".to_string(),
            version: "".to_string(),
            args: json!({}),
            policy: json!({}),
        };
        let err = reg.dispatch(&request).unwrap_err();
        match err {
            ToolError::NotRegistered(name) => assert_eq!(name, "missing"),
            other => panic!("expected NotRegistered, got: {}", other),
        }
    }

    #[test]
    fn registry_dispatch_propagates_provider_error() {
        let mut reg = ProviderRegistry::new();
        reg.register("fail", Box::new(FailingProvider));

        let request = ToolRequest {
            tool_id: "fail".to_string(),
            version: "".to_string(),
            args: json!({}),
            policy: json!({}),
        };
        let err = reg.dispatch(&request).unwrap_err();
        match err {
            ToolError::InvocationFailed(msg) => assert!(msg.contains("intentional")),
            other => panic!("expected InvocationFailed, got: {}", other),
        }
    }

    #[test]
    fn registry_dispatch_measures_latency() {
        let mut reg = ProviderRegistry::new();
        reg.register("echo", Box::new(EchoProvider::new("echo")));

        let request = ToolRequest {
            tool_id: "echo".to_string(),
            version: "1.0.0".to_string(),
            args: json!({}),
            policy: json!({}),
        };
        let response = reg.dispatch(&request).unwrap();
        assert!(response.latency_ms < 1000);
    }

    #[test]
    fn registry_dispatch_rejects_schema_mismatch_output() {
        let mut reg = ProviderRegistry::new();
        reg.register(
            "bad",
            Box::new(UnionOutputProvider::new(
                "bad",
                json!({"type": "string"}),
                json!({"status": "wrong-shape"}),
            )),
        );

        let request = ToolRequest {
            tool_id: "bad".to_string(),
            version: "1.0.0".to_string(),
            args: json!({}),
            policy: json!({}),
        };

        let err = reg
            .dispatch(&request)
            .expect_err("schema mismatch should fail");
        match err {
            ToolError::OutputValidationFailed {
                expected_schema,
                actual,
            } => {
                assert!(expected_schema.contains("\"string\""));
                assert!(actual.contains("wrong-shape"));
            }
            other => panic!("expected OutputValidationFailed, got: {other}"),
        }
    }

    #[test]
    fn registry_dispatch_async_rejects_schema_mismatch_output() {
        let mut reg = ProviderRegistry::new();
        reg.register(
            "bad_async",
            Box::new(UnionOutputProvider::new(
                "bad_async",
                json!({"type": "object", "required": ["ok"]}),
                json!({"missing_ok": true}),
            )),
        );

        let request = ToolRequest {
            tool_id: "bad_async".to_string(),
            version: "1.0.0".to_string(),
            args: json!({}),
            policy: json!({}),
        };

        let err = block_on(reg.dispatch_async(&request)).expect_err("schema mismatch should fail");
        match err {
            ToolError::OutputValidationFailed { actual, .. } => {
                assert!(actual.contains("missing required property"));
            }
            other => panic!("expected OutputValidationFailed, got: {other}"),
        }
    }

    #[test]
    fn registry_dispatch_accepts_union_output_type() {
        let mut reg = ProviderRegistry::new();
        reg.register(
            "union",
            Box::new(UnionOutputProvider::new(
                "union",
                json!({"type": ["object", "string", "null"]}),
                json!("ok"),
            )),
        );

        let request = ToolRequest {
            tool_id: "union".to_string(),
            version: "1.0.0".to_string(),
            args: json!({}),
            policy: json!({}),
        };

        let response = reg
            .dispatch(&request)
            .expect("union schema should validate");
        assert_eq!(response.outputs, json!("ok"));
    }

    #[test]
    fn registry_dispatch_rejects_disallowed_additional_property() {
        let mut reg = ProviderRegistry::new();
        reg.register(
            "strict",
            Box::new(UnionOutputProvider::new(
                "strict",
                json!({
                    "type": "object",
                    "properties": {"ok": {"type": "boolean"}},
                    "additionalProperties": false
                }),
                json!({"ok": true, "extra": 1}),
            )),
        );

        let request = ToolRequest {
            tool_id: "strict".to_string(),
            version: "1.0.0".to_string(),
            args: json!({}),
            policy: json!({}),
        };

        let err = reg
            .dispatch(&request)
            .expect_err("undeclared property should fail");
        match err {
            ToolError::OutputValidationFailed { actual, .. } => {
                assert!(actual.contains("additional property 'extra' is not allowed"));
            }
            other => panic!("expected OutputValidationFailed, got: {other}"),
        }
    }

    #[test]
    fn registry_dispatch_reports_nested_item_path() {
        let mut reg = ProviderRegistry::new();
        reg.register(
            "nested",
            Box::new(UnionOutputProvider::new(
                "nested",
                json!({
                    "type": "object",
                    "properties": {
                        "tags": {"type": "array", "items": {"type": "string"}}
                    }
                }),
                json!({"tags": ["a", 2]}),
            )),
        );

        let request = ToolRequest {
            tool_id: "nested".to_string(),
            version: "1.0.0".to_string(),
            args: json!({}),
            policy: json!({}),
        };

        let err = reg
            .dispatch(&request)
            .expect_err("non-string item should fail");
        match err {
            ToolError::OutputValidationFailed { actual, .. } => {
                assert!(actual.contains("$.tags[1]"));
                assert!(actual.contains("got integer"));
            }
            other => panic!("expected OutputValidationFailed, got: {other}"),
        }
    }

    #[test]
    fn registry_dispatch_async_uses_provider_async_path() {
        let mut reg = ProviderRegistry::new();
        reg.register("dual", Box::new(DualPathProvider::new("dual")));

        let request = ToolRequest {
            tool_id: "dual".to_string(),
            version: "1.0.0".to_string(),
            args: json!({"hello": "world"}),
            policy: json!({}),
        };

        let sync_response = reg.dispatch(&request).unwrap();
        assert_eq!(
            sync_response.outputs,
            json!({"path": "sync", "echo": {"hello": "world"}})
        );

        let async_response = block_on(reg.dispatch_async(&request)).unwrap();
        assert_eq!(
            async_response.outputs,
            json!({"path": "async", "echo": {"hello": "world"}})
        );
    }

    #[test]
    fn registry_dispatch_async_missing_tool_returns_not_registered() {
        let reg = ProviderRegistry::new();
        let request = ToolRequest {
            tool_id: "missing".to_string(),
            version: "".to_string(),
            args: json!({}),
            policy: json!({}),
        };

        let err = block_on(reg.dispatch_async(&request)).unwrap_err();
        match err {
            ToolError::NotRegistered(name) => assert_eq!(name, "missing"),
            other => panic!("expected NotRegistered, got: {}", other),
        }
    }

    #[test]
    fn registry_dispatch_async_propagates_provider_error() {
        let mut reg = ProviderRegistry::new();
        reg.register("fail", Box::new(FailingProvider));

        let request = ToolRequest {
            tool_id: "fail".to_string(),
            version: "".to_string(),
            args: json!({}),
            policy: json!({}),
        };

        let err = block_on(reg.dispatch_async(&request)).unwrap_err();
        match err {
            ToolError::InvocationFailed(msg) => assert!(msg.contains("intentional")),
            other => panic!("expected InvocationFailed, got: {}", other),
        }
    }

    #[test]
    fn provider_schema_round_trip() {
        let provider = EchoProvider::new("my_tool");
        let schema = provider.schema();
        assert_eq!(schema.name, "my_tool");
        assert_eq!(schema.description, "Echo provider: my_tool");
        assert_eq!(schema.input_schema, json!({"type": "object"}));
        assert_eq!(schema.output_schema, json!({"type": "object"}));
        assert_eq!(schema.effects, vec!["echo"]);
    }

    #[test]
    fn schema_serialization() {
        let schema = ToolSchema {
            name: "test".to_string(),
            description: "A test tool".to_string(),
            input_schema: json!({"type": "string"}),
            output_schema: json!({"type": "number"}),
            effects: vec!["io".to_string()],
        };
        let json_str = serde_json::to_string(&schema).unwrap();
        let roundtrip: ToolSchema = serde_json::from_str(&json_str).unwrap();
        assert_eq!(roundtrip.name, "test");
        assert_eq!(roundtrip.effects, vec!["io"]);
    }

    #[test]
    fn registry_default_is_empty() {
        let reg = ProviderRegistry::default();
        assert!(reg.is_empty());
    }

    #[test]
    fn stub_dispatcher_returns_configured_response() {
        let mut stub = StubDispatcher::new();
        stub.set_response("tool_a", json!({"ok": true}));
        let req = ToolRequest {
            tool_id: "tool_a".to_string(),
            version: "1".to_string(),
            args: json!({}),
            policy: json!({}),
        };
        let resp = stub.dispatch(&req).unwrap();
        assert_eq!(resp.outputs, json!({"ok": true}));
        assert_eq!(resp.latency_ms, 0);
    }

    #[test]
    fn stub_dispatcher_async_default_delegates_to_dispatch() {
        let mut stub = StubDispatcher::new();
        stub.set_response("tool_a", json!({"ok": true}));
        let req = ToolRequest {
            tool_id: "tool_a".to_string(),
            version: "1".to_string(),
            args: json!({}),
            policy: json!({}),
        };
        let resp = block_on(stub.dispatch_async(&req)).unwrap();
        assert_eq!(resp.outputs, json!({"ok": true}));
        assert_eq!(resp.latency_ms, 0);
    }

    #[test]
    fn stub_dispatcher_returns_not_found_for_unknown() {
        let stub = StubDispatcher::new();
        let req = ToolRequest {
            tool_id: "unknown".to_string(),
            version: "".to_string(),
            args: json!({}),
            policy: json!({}),
        };
        let err = stub.dispatch(&req).unwrap_err();
        match err {
            ToolError::NotFound(name) => assert_eq!(name, "unknown"),
            other => panic!("expected NotFound, got: {}", other),
        }
    }

    #[test]
    fn error_rate_limit_with_retry_after() {
        let err = ToolError::RateLimit {
            retry_after_ms: Some(5000),
            message: "Too many requests".to_string(),
        };
        let err_str = err.to_string();
        assert!(err_str.contains("rate limit exceeded"));
        assert!(err_str.contains("Too many requests"));
    }

    #[test]
    fn error_rate_limit_without_retry_after() {
        let err = ToolError::RateLimit {
            retry_after_ms: None,
            message: "Rate limited".to_string(),
        };
        assert!(err.to_string().contains("Rate limited"));
    }

    #[test]
    fn error_auth_error() {
        let err = ToolError::AuthError {
            message: "Invalid API key".to_string(),
        };
        assert!(err.to_string().contains("authentication failed"));
        assert!(err.to_string().contains("Invalid API key"));
    }

    #[test]
    fn error_model_not_found() {
        let err = ToolError::ModelNotFound {
            model: "gpt-5".to_string(),
            provider: "openai".to_string(),
        };
        let err_str = err.to_string();
        assert!(err_str.contains("model not found"));
        assert!(err_str.contains("gpt-5"));
        assert!(err_str.contains("openai"));
    }

    #[test]
    fn error_timeout() {
        let err = ToolError::Timeout {
            elapsed_ms: 35000,
            limit_ms: 30000,
        };
        let err_str = err.to_string();
        assert!(err_str.contains("timeout"));
        assert!(err_str.contains("35000"));
        assert!(err_str.contains("30000"));
    }

    #[test]
    fn error_provider_unavailable() {
        let err = ToolError::ProviderUnavailable {
            provider: "gemini".to_string(),
            reason: "Service under maintenance".to_string(),
        };
        let err_str = err.to_string();
        assert!(err_str.contains("provider unavailable"));
        assert!(err_str.contains("gemini"));
        assert!(err_str.contains("Service under maintenance"));
    }

    #[test]
    fn error_output_validation_failed() {
        let err = ToolError::OutputValidationFailed {
            expected_schema: r#"{"type": "string"}"#.to_string(),
            actual: r#"{"value": 123}"#.to_string(),
        };
        let err_str = err.to_string();
        assert!(err_str.contains("output validation failed"));
        assert!(err_str.contains("expected"));
        assert!(err_str.contains("got"));
    }

    #[test]
    fn error_invalid_args() {
        let err = ToolError::InvalidArgs("Missing required field 'prompt'".to_string());
        assert!(err.to_string().contains("invalid arguments"));
        assert!(err.to_string().contains("Missing required field"));
    }

    #[test]
    fn error_execution_failed() {
        let err = ToolError::ExecutionFailed("Network timeout".to_string());
        assert!(err.to_string().contains("execution failed"));
        assert!(err.to_string().contains("Network timeout"));
    }

    #[test]
    fn capability_equality() {
        assert_eq!(Capability::TextGeneration, Capability::TextGeneration);
        assert_ne!(Capability::Chat, Capability::Embedding);
    }

    #[test]
    fn capability_in_vector() {
        let caps = [
            Capability::Chat,
            Capability::TextGeneration,
            Capability::Vision,
        ];
        assert!(caps.contains(&Capability::Chat));
        assert!(caps.contains(&Capability::Vision));
        assert!(!caps.contains(&Capability::Streaming));
    }

    #[test]
    fn retry_policy_default() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_retries, 3);
        assert_eq!(policy.base_delay_ms, 100);
        assert_eq!(policy.max_delay_ms, 10_000);
    }

    #[test]
    fn retry_policy_custom() {
        let policy = RetryPolicy {
            max_retries: 5,
            base_delay_ms: 200,
            max_delay_ms: 30_000,
        };
        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.base_delay_ms, 200);
        assert_eq!(policy.max_delay_ms, 30_000);
    }
}