1use std::borrow::Cow;
32use std::collections::{HashMap, HashSet};
33use std::future::Future;
34use std::pin::Pin;
35use std::time::Duration;
36
37use serde::{Deserialize, Serialize};
38use serde_json::Value;
39
40use crate::chat::{ChatMessage, ChatResponse};
41use crate::error::LlmError;
42use crate::stream::ChatStream;
43
/// An asynchronous chat-completion backend.
///
/// The methods return `impl Future` directly (return-position `impl Trait` in a trait),
/// which makes this trait unusable as `dyn Provider`. Use [`DynProvider`] when a trait
/// object is needed; every `Provider` implements it through the blanket impl in this module.
pub trait Provider: Send + Sync {
    /// Runs a single, non-streaming chat request described by `params`.
    ///
    /// The returned future is `Send`. That lets [`DynProvider::generate_boxed`] box it as
    /// `Pin<Box<dyn Future<...> + Send>>`.
    ///
    /// # Errors
    ///
    /// Resolves to an [`LlmError`] on failure. Which variants occur is up to the implementation.
    fn generate(
        &self,
        params: &ChatParams,
    ) -> impl Future<Output = Result<ChatResponse, LlmError>> + Send;

    /// Starts a streaming chat request described by `params`.
    ///
    /// The future resolves to a [`ChatStream`], presumably yielding incremental output.
    /// Confirm this against `crate::stream`. Like [`Provider::generate`], the future is
    /// `Send` so it can be boxed by [`DynProvider::stream_boxed`].
    ///
    /// # Errors
    ///
    /// Resolves to an [`LlmError`] if the stream cannot be started.
    fn stream(
        &self,
        params: &ChatParams,
    ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;

    /// Describes this provider: its name, model, context window and advertised capabilities.
    fn metadata(&self) -> ProviderMetadata;
}
78
/// Object-safe counterpart of [`Provider`], usable as `dyn DynProvider`.
///
/// Each method returns a heap-allocated, pinned, `Send` future that borrows both `self`
/// and `params` for `'a`. You normally do not implement this trait by hand: the blanket
/// `impl<T: Provider> DynProvider for T` in this module forwards to [`Provider`].
pub trait DynProvider: Send + Sync {
    /// Boxed form of [`Provider::generate`].
    ///
    /// # Errors
    ///
    /// Resolves to an [`LlmError`] on failure, exactly as the underlying provider does.
    fn generate_boxed<'a>(
        &'a self,
        params: &'a ChatParams,
    ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, LlmError>> + Send + 'a>>;

    /// Boxed form of [`Provider::stream`].
    ///
    /// # Errors
    ///
    /// Resolves to an [`LlmError`] if the stream cannot be started.
    fn stream_boxed<'a>(
        &'a self,
        params: &'a ChatParams,
    ) -> Pin<Box<dyn Future<Output = Result<ChatStream, LlmError>> + Send + 'a>>;

    /// Same as [`Provider::metadata`].
    ///
    /// Note: a concrete type implementing `Provider` also gets this method through the
    /// blanket impl. When both traits are in scope, `x.metadata()` may be ambiguous and
    /// need a fully-qualified call.
    fn metadata(&self) -> ProviderMetadata;
}
113
114impl<T: Provider> DynProvider for T {
115 fn generate_boxed<'a>(
116 &'a self,
117 params: &'a ChatParams,
118 ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, LlmError>> + Send + 'a>> {
119 Box::pin(self.generate(params))
120 }
121
122 fn stream_boxed<'a>(
123 &'a self,
124 params: &'a ChatParams,
125 ) -> Pin<Box<dyn Future<Output = Result<ChatStream, LlmError>> + Send + 'a>> {
126 Box::pin(self.stream(params))
127 }
128
129 fn metadata(&self) -> ProviderMetadata {
130 Provider::metadata(self)
131 }
132}
133
/// Static description of a provider/model pair, as returned by [`Provider::metadata`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderMetadata {
    /// Provider name.
    ///
    /// It is a `Cow<'static, str>`, so a string literal can be stored without allocating
    /// (`"mock".into()`), while a runtime `String` can be stored as `Cow::Owned`.
    pub name: Cow<'static, str>,
    /// Model identifier.
    pub model: String,
    /// Size of the model's context window. The unit is presumably tokens; confirm with the providers.
    pub context_window: u64,
    /// Features this provider/model advertises.
    pub capabilities: HashSet<Capability>,
}
150
/// A feature a provider may advertise in [`ProviderMetadata::capabilities`].
///
/// The enum is `#[non_exhaustive]`: code outside this crate must use a wildcard arm when
/// matching, which lets variants be added later. The per-variant notes are inferred from
/// the names only. NOTE(review): confirm them against the provider implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Capability {
    /// Tool calling; presumably relates to [`ChatParams::tools`].
    Tools,
    /// Schema-constrained output; presumably relates to [`ChatParams::structured_output`].
    StructuredOutput,
    /// Extended reasoning; presumably relates to [`ChatParams::reasoning_budget`].
    Reasoning,
    /// Presumably image input support.
    Vision,
    /// Presumably prompt/response caching support.
    Caching,
}
169
170#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
201pub struct ChatParams {
202 pub messages: Vec<ChatMessage>,
204 pub tools: Option<Vec<ToolDefinition>>,
206 pub tool_choice: Option<ToolChoice>,
208 pub temperature: Option<f32>,
210 pub max_tokens: Option<u32>,
212 pub system: Option<String>,
215 pub reasoning_budget: Option<u32>,
218 pub structured_output: Option<JsonSchema>,
220 #[serde(skip)]
222 pub timeout: Option<Duration>,
223 #[serde(skip)]
226 pub extra_headers: Option<http::HeaderMap>,
227 pub metadata: HashMap<String, Value>,
230}
231
/// How the model should choose among the tools in [`ChatParams::tools`].
///
/// The enum is `#[non_exhaustive]`, so downstream matches need a wildcard arm. The
/// variant semantics below are inferred from the names. NOTE(review): confirm them
/// against each provider's mapping.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ToolChoice {
    /// Presumably lets the model decide whether to call a tool.
    Auto,
    /// Presumably forbids tool calls. (Not to be confused with `Option::None`.)
    None,
    /// Presumably forces the model to call some tool.
    Required,
    /// Forces a specific tool, identified by name, presumably [`ToolDefinition::name`].
    Specific(String),
}
245
/// Shared, thread-safe filter deciding whether a failed tool call may be retried.
///
/// It receives a `&str` and returns `true` to allow a retry. The string is presumably the
/// failure's error message; confirm against the retry loop that consumes it.
pub type RetryPredicate = std::sync::Arc<dyn Fn(&str) -> bool + Send + Sync>;

/// Retry and backoff policy for a single tool.
///
/// `Debug` and `PartialEq` are written by hand because the optional [`RetryPredicate`]
/// closure supports neither. Both look only at *whether* a predicate is set, never at
/// which closure it is.
#[derive(Clone)]
pub struct ToolRetryConfig {
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Backoff before the first retry.
    pub initial_backoff: Duration,
    /// Upper bound for the backoff.
    pub max_backoff: Duration,
    /// Factor the backoff is presumably multiplied by after each retry.
    pub backoff_multiplier: f64,
    /// Jitter factor, presumably a fraction of the current backoff (confirm).
    pub jitter: f64,
    /// Optional filter. When `None`, presumably every failure is retryable.
    pub retry_if: Option<RetryPredicate>,
}

impl Default for ToolRetryConfig {
    /// 3 retries, 100 ms initial backoff, 5 s cap, multiplier 2.0, jitter 0.5, no predicate.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(100);
        let max_backoff = Duration::from_secs(5);
        ToolRetryConfig {
            retry_if: None,
            jitter: 0.5,
            backoff_multiplier: 2.0,
            max_backoff,
            initial_backoff,
            max_retries: 3,
        }
    }
}

impl std::fmt::Debug for ToolRetryConfig {
    /// Prints every numeric field. The closure itself is replaced by a `has_retry_if`
    /// flag because `dyn Fn` has no `Debug` impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_retry_if = self.retry_if.is_some();
        let mut out = f.debug_struct("ToolRetryConfig");
        out.field("max_retries", &self.max_retries);
        out.field("initial_backoff", &self.initial_backoff);
        out.field("max_backoff", &self.max_backoff);
        out.field("backoff_multiplier", &self.backoff_multiplier);
        out.field("jitter", &self.jitter);
        out.field("has_retry_if", &has_retry_if);
        out.finish()
    }
}

impl PartialEq for ToolRetryConfig {
    /// Field-wise equality. For `retry_if`, only the presence of a predicate is compared:
    /// two configs holding different closures are considered equal.
    fn eq(&self, other: &Self) -> bool {
        let same_backoff =
            (self.initial_backoff, self.max_backoff) == (other.initial_backoff, other.max_backoff);
        let same_factors =
            self.backoff_multiplier == other.backoff_multiplier && self.jitter == other.jitter;
        let same_filter_presence = self.retry_if.is_some() == other.retry_if.is_some();
        self.max_retries == other.max_retries && same_backoff && same_factors && same_filter_presence
    }
}
324
/// A tool the model may call, offered through [`ChatParams::tools`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name; presumably what [`ToolChoice::Specific`] refers to (confirm).
    pub name: String,
    /// Human-readable description of the tool, presumably passed to the model.
    pub description: String,
    /// JSON Schema for the tool's arguments.
    pub parameters: JsonSchema,
    /// Optional retry policy for calls to this tool.
    ///
    /// `#[serde(skip)]`: never serialized, and always `None` after deserialization.
    /// Because the field takes part in the derived `PartialEq`, a round-tripped copy of a
    /// definition that had a retry policy will not compare equal to the original.
    #[serde(skip)]
    pub retry: Option<ToolRetryConfig>,
}
346
347#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
369pub struct JsonSchema(Value);
370
371impl JsonSchema {
372 pub fn new(schema: Value) -> Self {
374 Self(schema)
375 }
376
377 pub fn as_value(&self) -> &Value {
379 &self.0
380 }
381
382 #[cfg(feature = "schema")]
390 pub fn from_type<T: schemars::JsonSchema>() -> Result<Self, serde_json::Error> {
391 let schema = schemars::schema_for!(T);
392 let value = serde_json::to_value(schema)?;
393 Ok(Self(value))
394 }
395
396 #[cfg(feature = "schema")]
404 pub fn validate(&self, value: &Value) -> Result<(), LlmError> {
405 let validator = jsonschema::validator_for(&self.0)
406 .map_err(|e| LlmError::InvalidRequest(format!("invalid JSON schema: {e}")))?;
407 let errors: Vec<String> = validator
408 .iter_errors(value)
409 .map(|e| e.to_string())
410 .collect();
411 if errors.is_empty() {
412 Ok(())
413 } else {
414 Err(LlmError::SchemaValidation {
415 message: errors.join("; "),
416 schema: self.0.clone(),
417 actual: value.clone(),
418 })
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    //! Unit tests for the provider types: derives, serde round-trips, defaults and the
    //! hand-written `ToolRetryConfig` impls.

    use super::*;

    // ---- Capability -------------------------------------------------------------------

    #[test]
    fn test_capability_hash_set() {
        let caps: HashSet<Capability> = HashSet::from([
            Capability::Tools,
            Capability::StructuredOutput,
            Capability::Reasoning,
            Capability::Vision,
            Capability::Caching,
        ]);
        // All five variants are distinct under Hash/Eq.
        assert_eq!(caps.len(), 5);
    }

    #[test]
    fn test_capability_copy() {
        let c = Capability::Tools;
        let c2 = c;
        // `c` is still usable after the assignment because `Capability: Copy`.
        assert_eq!(c, c2);
    }

    #[test]
    fn test_capability_serde_roundtrip() {
        for cap in [
            Capability::Tools,
            Capability::StructuredOutput,
            Capability::Reasoning,
            Capability::Vision,
            Capability::Caching,
        ] {
            let json = serde_json::to_string(&cap).unwrap();
            let back: Capability = serde_json::from_str(&json).unwrap();
            assert_eq!(cap, back);
        }
    }

    // ---- ProviderMetadata -------------------------------------------------------------

    #[test]
    fn test_provider_metadata_clone_eq() {
        let m = ProviderMetadata {
            name: "mock".into(),
            model: "test-model".into(),
            context_window: 128_000,
            capabilities: HashSet::from([Capability::Tools]),
        };
        assert_eq!(m, m.clone());
    }

    #[test]
    fn test_provider_metadata_owned_name() {
        let name = String::from("custom-provider");
        let m = ProviderMetadata {
            name: Cow::Owned(name),
            model: "test".into(),
            context_window: 4096,
            capabilities: HashSet::new(),
        };
        assert_eq!(m.name, "custom-provider");
    }

    #[test]
    fn test_provider_metadata_serde_roundtrip() {
        let m = ProviderMetadata {
            name: "anthropic".into(),
            model: "claude-sonnet-4".into(),
            context_window: 200_000,
            capabilities: HashSet::from([Capability::Tools, Capability::Vision]),
        };
        let json = serde_json::to_string(&m).unwrap();
        let back: ProviderMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(m, back);
    }

    // ---- ChatParams -------------------------------------------------------------------

    #[test]
    fn test_chat_params_defaults() {
        let p = ChatParams::default();
        assert!(p.messages.is_empty());
        assert!(p.tools.is_none());
        assert!(p.tool_choice.is_none());
        assert!(p.temperature.is_none());
        assert!(p.max_tokens.is_none());
        assert!(p.system.is_none());
        assert!(p.reasoning_budget.is_none());
        assert!(p.structured_output.is_none());
        assert!(p.timeout.is_none());
        assert!(p.extra_headers.is_none());
        assert!(p.metadata.is_empty());
    }

    #[test]
    fn test_chat_params_full() {
        let p = ChatParams {
            messages: vec![ChatMessage::user("hi")],
            tools: Some(vec![]),
            tool_choice: Some(ToolChoice::Auto),
            temperature: Some(0.7),
            max_tokens: Some(1024),
            system: Some("you are helpful".into()),
            reasoning_budget: Some(2048),
            structured_output: Some(JsonSchema::new(serde_json::json!({"type": "object"}))),
            timeout: Some(Duration::from_secs(30)),
            extra_headers: Some(http::HeaderMap::new()),
            metadata: HashMap::from([("key".into(), serde_json::json!("val"))]),
        };
        assert_eq!(p.messages.len(), 1);
        assert!(p.tools.is_some());
        assert_eq!(p.tools, Some(vec![]));
        assert_eq!(p.tool_choice, Some(ToolChoice::Auto));
        assert_eq!(p.temperature, Some(0.7));
        assert_eq!(p.max_tokens, Some(1024));
        assert_eq!(p.system.as_deref(), Some("you are helpful"));
        assert_eq!(p.reasoning_budget, Some(2048));
        assert_eq!(
            p.structured_output,
            Some(JsonSchema::new(serde_json::json!({"type": "object"})))
        );
        assert_eq!(p.timeout, Some(Duration::from_secs(30)));
        assert!(p.extra_headers.is_some());
        assert_eq!(p.metadata["key"], serde_json::json!("val"));
    }

    #[test]
    fn test_chat_params_serde_roundtrip_with_metadata() {
        let p = ChatParams {
            messages: vec![ChatMessage::user("hi")],
            metadata: HashMap::from([
                ("provider_key".into(), serde_json::json!("abc123")),
                ("flags".into(), serde_json::json!({"stream": true})),
            ]),
            ..Default::default()
        };
        let json = serde_json::to_string(&p).unwrap();
        let back: ChatParams = serde_json::from_str(&json).unwrap();
        assert_eq!(back.metadata.len(), 2);
        assert_eq!(back.metadata["provider_key"], serde_json::json!("abc123"));
        assert_eq!(back.metadata["flags"], serde_json::json!({"stream": true}));
    }

    #[test]
    fn test_chat_params_serde_roundtrip_skips_timeout_and_headers() {
        let p = ChatParams {
            messages: vec![ChatMessage::user("hi")],
            temperature: Some(0.7),
            timeout: Some(Duration::from_secs(30)),
            extra_headers: Some(http::HeaderMap::new()),
            ..Default::default()
        };
        let json = serde_json::to_string(&p).unwrap();
        let back: ChatParams = serde_json::from_str(&json).unwrap();
        // `#[serde(skip)]` fields come back as their `Default` (`None`).
        assert_eq!(back.timeout, None);
        assert_eq!(back.extra_headers, None);
        // Everything else survives the round-trip.
        assert_eq!(back.messages.len(), 1);
        assert_eq!(back.temperature, Some(0.7));
    }

    // ---- ToolChoice -------------------------------------------------------------------

    #[test]
    fn test_tool_choice_all_variants() {
        let variants = [
            ToolChoice::Auto,
            ToolChoice::None,
            ToolChoice::Required,
            ToolChoice::Specific("my_tool".into()),
        ];
        for v in &variants {
            assert_eq!(*v, v.clone());
        }
    }

    #[test]
    fn test_tool_choice_serde_roundtrip() {
        let tc = ToolChoice::Specific("search".into());
        let json = serde_json::to_string(&tc).unwrap();
        let back: ToolChoice = serde_json::from_str(&json).unwrap();
        assert_eq!(tc, back);
    }

    // ---- ToolRetryConfig --------------------------------------------------------------

    #[test]
    fn test_tool_retry_config_default() {
        let cfg = ToolRetryConfig::default();
        assert_eq!(cfg.max_retries, 3);
        assert_eq!(cfg.initial_backoff, Duration::from_millis(100));
        assert_eq!(cfg.max_backoff, Duration::from_secs(5));
        assert_eq!(cfg.backoff_multiplier, 2.0);
        assert_eq!(cfg.jitter, 0.5);
        assert!(cfg.retry_if.is_none());
    }

    #[test]
    fn test_tool_retry_config_debug_reports_predicate_presence() {
        let plain = ToolRetryConfig::default();
        assert!(format!("{plain:?}").contains("has_retry_if: false"));

        let pred: RetryPredicate = std::sync::Arc::new(|msg: &str| msg.contains("timeout"));
        let filtered = ToolRetryConfig {
            retry_if: Some(pred),
            ..ToolRetryConfig::default()
        };
        assert!(format!("{filtered:?}").contains("has_retry_if: true"));
    }

    #[test]
    fn test_tool_retry_config_eq_ignores_predicate_identity() {
        let accept_all: RetryPredicate = std::sync::Arc::new(|_: &str| true);
        let reject_all: RetryPredicate = std::sync::Arc::new(|_: &str| false);
        let with_a = ToolRetryConfig {
            retry_if: Some(accept_all),
            ..ToolRetryConfig::default()
        };
        let with_b = ToolRetryConfig {
            retry_if: Some(reject_all),
            ..ToolRetryConfig::default()
        };
        // Closures cannot be compared, so only whether a predicate is set matters.
        assert_eq!(with_a, with_b);
        assert_ne!(with_a, ToolRetryConfig::default());
        assert_ne!(
            ToolRetryConfig {
                max_retries: 5,
                ..ToolRetryConfig::default()
            },
            ToolRetryConfig::default()
        );
    }

    // ---- ToolDefinition ---------------------------------------------------------------

    #[test]
    fn test_tool_definition_serde_roundtrip() {
        let td = ToolDefinition {
            name: "search".into(),
            description: "Search the web".into(),
            parameters: JsonSchema::new(serde_json::json!({"type": "object"})),
            retry: None,
        };
        let json = serde_json::to_string(&td).unwrap();
        let back: ToolDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(td, back);
    }

    #[test]
    fn test_tool_definition_serde_skips_retry() {
        let td = ToolDefinition {
            name: "search".into(),
            description: "Search the web".into(),
            parameters: JsonSchema::new(serde_json::json!({"type": "object"})),
            retry: Some(ToolRetryConfig::default()),
        };
        let json = serde_json::to_string(&td).unwrap();
        assert!(!json.contains("retry"));
        let back: ToolDefinition = serde_json::from_str(&json).unwrap();
        assert!(back.retry.is_none());
        assert_eq!(back.name, td.name);
        assert_eq!(back.description, td.description);
        assert_eq!(back.parameters, td.parameters);
        // `retry` participates in equality, so the round-tripped copy differs.
        assert_ne!(td, back);
    }

    // ---- JsonSchema -------------------------------------------------------------------

    #[test]
    fn test_json_schema_from_raw() {
        let schema = JsonSchema::new(serde_json::json!({"type": "object"}));
        assert_eq!(*schema.as_value(), serde_json::json!({"type": "object"}));
    }

    #[test]
    fn test_json_schema_clone_eq() {
        let s = JsonSchema::new(serde_json::json!({"type": "string"}));
        assert_eq!(s, s.clone());
    }

    #[test]
    fn test_json_schema_serde_roundtrip() {
        let s = JsonSchema::new(
            serde_json::json!({"type": "object", "properties": {"x": {"type": "integer"}}}),
        );
        let json = serde_json::to_string(&s).unwrap();
        let back: JsonSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(s, back);
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_from_type_simple() {
        #[derive(schemars::JsonSchema)]
        struct Foo {
            #[allow(dead_code)] // Only read by the schema derive, never at runtime.
            x: i32,
        }
        let schema = JsonSchema::from_type::<Foo>().unwrap();
        let props = schema
            .as_value()
            .get("properties")
            .expect("should have properties");
        assert!(props.get("x").is_some());
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_validate_valid() {
        let schema = JsonSchema::new(serde_json::json!({
            "type": "object",
            "properties": {
                "x": {"type": "integer"}
            },
            "required": ["x"]
        }));
        assert!(schema.validate(&serde_json::json!({"x": 42})).is_ok());
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_validate_missing_field() {
        let schema = JsonSchema::new(serde_json::json!({
            "type": "object",
            "properties": {
                "x": {"type": "integer"}
            },
            "required": ["x"]
        }));
        let result = schema.validate(&serde_json::json!({}));
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            LlmError::SchemaValidation { .. }
        ));
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_validate_wrong_type() {
        let schema = JsonSchema::new(serde_json::json!({
            "type": "object",
            "properties": {
                "x": {"type": "integer"}
            },
            "required": ["x"]
        }));
        let result = schema.validate(&serde_json::json!({"x": "not a number"}));
        assert!(result.is_err());
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_validate_invalid_schema() {
        let schema = JsonSchema::new(serde_json::json!({"type": "bogus_not_a_type"}));
        let result = schema.validate(&serde_json::json!(42));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::InvalidRequest(_)));
    }
}