1use std::borrow::Cow;
32use std::collections::{HashMap, HashSet};
33use std::future::Future;
34use std::pin::Pin;
35use std::time::Duration;
36
37use serde::{Deserialize, Serialize};
38use serde_json::Value;
39
40use crate::chat::{ChatMessage, ChatResponse};
41use crate::error::LlmError;
42use crate::stream::ChatStream;
43
44pub trait Provider: Send + Sync {
60 fn generate(
62 &self,
63 params: &ChatParams,
64 ) -> impl Future<Output = Result<ChatResponse, LlmError>> + Send;
65
66 fn stream(
71 &self,
72 params: &ChatParams,
73 ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
74
75 fn metadata(&self) -> ProviderMetadata;
77}
78
79pub trait DynProvider: Send + Sync {
98 fn generate_boxed<'a>(
100 &'a self,
101 params: &'a ChatParams,
102 ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, LlmError>> + Send + 'a>>;
103
104 fn stream_boxed<'a>(
106 &'a self,
107 params: &'a ChatParams,
108 ) -> Pin<Box<dyn Future<Output = Result<ChatStream, LlmError>> + Send + 'a>>;
109
110 fn metadata(&self) -> ProviderMetadata;
112}
113
114impl<T: Provider> DynProvider for T {
115 fn generate_boxed<'a>(
116 &'a self,
117 params: &'a ChatParams,
118 ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, LlmError>> + Send + 'a>> {
119 Box::pin(self.generate(params))
120 }
121
122 fn stream_boxed<'a>(
123 &'a self,
124 params: &'a ChatParams,
125 ) -> Pin<Box<dyn Future<Output = Result<ChatStream, LlmError>> + Send + 'a>> {
126 Box::pin(self.stream(params))
127 }
128
129 fn metadata(&self) -> ProviderMetadata {
130 Provider::metadata(self)
131 }
132}
133
134#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
140pub struct ProviderMetadata {
141 pub name: Cow<'static, str>,
143 pub model: String,
145 pub context_window: u64,
147 pub capabilities: HashSet<Capability>,
149}
150
151#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
156#[non_exhaustive]
157pub enum Capability {
158 Tools,
160 StructuredOutput,
162 Reasoning,
164 Vision,
166 Caching,
168}
169
170#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
201pub struct ChatParams {
202 pub messages: Vec<ChatMessage>,
204 pub tools: Option<Vec<ToolDefinition>>,
206 pub tool_choice: Option<ToolChoice>,
208 pub temperature: Option<f32>,
210 pub max_tokens: Option<u32>,
212 pub system: Option<String>,
215 pub reasoning_budget: Option<u32>,
218 pub structured_output: Option<JsonSchema>,
220 #[serde(skip)]
222 pub timeout: Option<Duration>,
223 #[serde(skip)]
226 pub extra_headers: Option<http::HeaderMap>,
227 pub metadata: HashMap<String, Value>,
230}
231
232#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
234#[non_exhaustive]
235pub enum ToolChoice {
236 Auto,
238 None,
240 Required,
242 Specific(String),
244}
245
/// Shared, thread-safe predicate stored in [`ToolRetryConfig::retry_if`].
///
/// It receives a `&str` (presumably the error message of the failed attempt;
/// confirm at the call site) and returns a `bool` retry decision.
pub type RetryPredicate = std::sync::Arc<dyn Fn(&str) -> bool + Send + Sync>;

/// Retry/backoff settings that can be attached to a tool.
///
/// `Debug` and `PartialEq` are written by hand rather than derived because
/// the `retry_if` closure supports neither; both impls look only at whether a
/// predicate is present.
///
/// NOTE(review): how `backoff_multiplier` and `jitter` are applied is decided
/// by the code that consumes this config, which is not visible here.
#[derive(Clone)]
pub struct ToolRetryConfig {
    /// Maximum number of retries. Default: `3`.
    pub max_retries: u32,
    /// Initial backoff delay. Default: 100 ms.
    pub initial_backoff: Duration,
    /// Upper bound on the backoff delay. Default: 5 s.
    pub max_backoff: Duration,
    /// Backoff growth factor. Default: `2.0`.
    pub backoff_multiplier: f64,
    /// Jitter amount. Default: `0.5`.
    pub jitter: f64,
    /// Optional predicate consulted for the retry decision. Default: `None`.
    pub retry_if: Option<RetryPredicate>,
}

impl Default for ToolRetryConfig {
    /// Three retries, 100 ms initial backoff capped at 5 s, multiplier `2.0`,
    /// jitter `0.5`, and no predicate.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(5),
            backoff_multiplier: 2.0,
            jitter: 0.5,
            retry_if: None,
        }
    }
}

impl std::fmt::Debug for ToolRetryConfig {
    /// Formats every plain field as-is. The predicate cannot be printed, so
    /// it is reported as a boolean `has_retry_if` field instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ToolRetryConfig")
            .field("max_retries", &self.max_retries)
            .field("initial_backoff", &self.initial_backoff)
            .field("max_backoff", &self.max_backoff)
            .field("backoff_multiplier", &self.backoff_multiplier)
            .field("jitter", &self.jitter)
            .field("has_retry_if", &self.retry_if.is_some())
            .finish()
    }
}

impl PartialEq for ToolRetryConfig {
    /// Compares all plain fields by value.
    ///
    /// Closures cannot be compared, so `retry_if` contributes only its
    /// presence: two configs holding *different* predicates still compare
    /// equal when every other field matches.
    fn eq(&self, other: &Self) -> bool {
        self.max_retries == other.max_retries
            && self.initial_backoff == other.initial_backoff
            && self.max_backoff == other.max_backoff
            && self.backoff_multiplier == other.backoff_multiplier
            && self.jitter == other.jitter
            && self.retry_if.is_some() == other.retry_if.is_some()
    }
}
324
325#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
329pub struct ToolDefinition {
330 pub name: String,
332 pub description: String,
335 pub parameters: JsonSchema,
337 #[serde(skip)]
340 pub retry: Option<ToolRetryConfig>,
341}
342
343#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
365pub struct JsonSchema(Value);
366
367impl JsonSchema {
368 pub fn new(schema: Value) -> Self {
370 Self(schema)
371 }
372
373 pub fn as_value(&self) -> &Value {
375 &self.0
376 }
377
378 #[cfg(feature = "schema")]
386 pub fn from_type<T: schemars::JsonSchema>() -> Result<Self, serde_json::Error> {
387 let schema = schemars::schema_for!(T);
388 let value = serde_json::to_value(schema)?;
389 Ok(Self(value))
390 }
391
392 #[cfg(feature = "schema")]
400 pub fn validate(&self, value: &Value) -> Result<(), LlmError> {
401 let validator = jsonschema::validator_for(&self.0)
402 .map_err(|e| LlmError::InvalidRequest(format!("invalid JSON schema: {e}")))?;
403 let errors: Vec<String> = validator
404 .iter_errors(value)
405 .map(|e| e.to_string())
406 .collect();
407 if errors.is_empty() {
408 Ok(())
409 } else {
410 Err(LlmError::SchemaValidation {
411 message: errors.join("; "),
412 schema: self.0.clone(),
413 actual: value.clone(),
414 })
415 }
416 }
417}
418
/// Unit tests for the data types in this file: derives, defaults, serde
/// round-trips and (with the `schema` feature) schema generation/validation.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Capability ---

    #[test]
    fn test_capability_hash_set() {
        let caps: HashSet<Capability> = HashSet::from([
            Capability::Tools,
            Capability::StructuredOutput,
            Capability::Reasoning,
            Capability::Vision,
            Capability::Caching,
        ]);
        assert_eq!(caps.len(), 5);
    }

    #[test]
    fn test_capability_copy() {
        let c = Capability::Tools;
        // `c` remains usable below only because `Capability` is `Copy`.
        let c2 = c;
        assert_eq!(c, c2);
    }

    #[test]
    fn test_capability_serde_roundtrip() {
        let cap = Capability::Tools;
        let json = serde_json::to_string(&cap).unwrap();
        let back: Capability = serde_json::from_str(&json).unwrap();
        assert_eq!(cap, back);
    }

    // --- ProviderMetadata ---

    #[test]
    fn test_provider_metadata_clone_eq() {
        let m = ProviderMetadata {
            name: "mock".into(),
            model: "test-model".into(),
            context_window: 128_000,
            capabilities: HashSet::from([Capability::Tools]),
        };
        assert_eq!(m, m.clone());
    }

    #[test]
    fn test_provider_metadata_owned_name() {
        let name = String::from("custom-provider");
        let m = ProviderMetadata {
            name: Cow::Owned(name),
            model: "test".into(),
            context_window: 4096,
            capabilities: HashSet::new(),
        };
        assert_eq!(m.name, "custom-provider");
    }

    // --- ChatParams ---

    #[test]
    fn test_chat_params_defaults() {
        let p = ChatParams::default();
        assert!(p.messages.is_empty());
        assert!(p.tools.is_none());
        assert!(p.tool_choice.is_none());
        assert!(p.temperature.is_none());
        assert!(p.max_tokens.is_none());
        assert!(p.system.is_none());
        assert!(p.reasoning_budget.is_none());
        assert!(p.structured_output.is_none());
        assert!(p.timeout.is_none());
        assert!(p.extra_headers.is_none());
        assert!(p.metadata.is_empty());
    }

    #[test]
    fn test_chat_params_full() {
        let p = ChatParams {
            messages: vec![ChatMessage::user("hi")],
            tools: Some(vec![]),
            tool_choice: Some(ToolChoice::Auto),
            temperature: Some(0.7),
            max_tokens: Some(1024),
            system: Some("you are helpful".into()),
            reasoning_budget: Some(2048),
            structured_output: Some(JsonSchema::new(serde_json::json!({"type": "object"}))),
            timeout: Some(Duration::from_secs(30)),
            extra_headers: Some(http::HeaderMap::new()),
            metadata: HashMap::from([("key".into(), serde_json::json!("val"))]),
        };
        assert_eq!(p.messages.len(), 1);
        assert!(p.tools.is_some());
        assert_eq!(p.temperature, Some(0.7));
    }

    // --- ToolChoice ---

    #[test]
    fn test_tool_choice_all_variants() {
        let variants = [
            ToolChoice::Auto,
            ToolChoice::None,
            ToolChoice::Required,
            ToolChoice::Specific("my_tool".into()),
        ];
        for v in &variants {
            assert_eq!(*v, v.clone());
        }
    }

    #[test]
    fn test_tool_choice_serde_roundtrip() {
        let tc = ToolChoice::Specific("search".into());
        let json = serde_json::to_string(&tc).unwrap();
        let back: ToolChoice = serde_json::from_str(&json).unwrap();
        assert_eq!(tc, back);
    }

    // --- JsonSchema ---

    #[test]
    fn test_json_schema_from_raw() {
        let schema = JsonSchema::new(serde_json::json!({"type": "object"}));
        assert_eq!(*schema.as_value(), serde_json::json!({"type": "object"}));
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_from_type_simple() {
        #[derive(schemars::JsonSchema)]
        struct Foo {
            // Only the generated schema is inspected; the field is never read.
            #[allow(dead_code)]
            x: i32,
        }
        let schema = JsonSchema::from_type::<Foo>().unwrap();
        let props = schema
            .as_value()
            .get("properties")
            .expect("should have properties");
        assert!(props.get("x").is_some());
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_validate_valid() {
        let schema = JsonSchema::new(serde_json::json!({
            "type": "object",
            "properties": {
                "x": {"type": "integer"}
            },
            "required": ["x"]
        }));
        assert!(schema.validate(&serde_json::json!({"x": 42})).is_ok());
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_validate_missing_field() {
        let schema = JsonSchema::new(serde_json::json!({
            "type": "object",
            "properties": {
                "x": {"type": "integer"}
            },
            "required": ["x"]
        }));
        let result = schema.validate(&serde_json::json!({}));
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            LlmError::SchemaValidation { .. }
        ));
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_validate_wrong_type() {
        let schema = JsonSchema::new(serde_json::json!({
            "type": "object",
            "properties": {
                "x": {"type": "integer"}
            },
            "required": ["x"]
        }));
        let result = schema.validate(&serde_json::json!({"x": "not a number"}));
        assert!(result.is_err());
    }

    #[cfg(feature = "schema")]
    #[test]
    fn test_json_schema_validate_invalid_schema() {
        // A schema that is itself malformed must surface as `InvalidRequest`,
        // not as a validation failure of the instance.
        let schema = JsonSchema::new(serde_json::json!({"type": "bogus_not_a_type"}));
        let result = schema.validate(&serde_json::json!(42));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::InvalidRequest(_)));
    }

    #[test]
    fn test_json_schema_clone_eq() {
        let s = JsonSchema::new(serde_json::json!({"type": "string"}));
        assert_eq!(s, s.clone());
    }

    #[test]
    fn test_json_schema_serde_roundtrip() {
        let s = JsonSchema::new(
            serde_json::json!({"type": "object", "properties": {"x": {"type": "integer"}}}),
        );
        let json = serde_json::to_string(&s).unwrap();
        let back: JsonSchema = serde_json::from_str(&json).unwrap();
        assert_eq!(s, back);
    }

    // --- Serde round-trips of composite types ---

    #[test]
    fn test_tool_definition_serde_roundtrip() {
        let td = ToolDefinition {
            name: "search".into(),
            description: "Search the web".into(),
            parameters: JsonSchema::new(serde_json::json!({"type": "object"})),
            retry: None,
        };
        let json = serde_json::to_string(&td).unwrap();
        let back: ToolDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(td, back);
    }

    #[test]
    fn test_provider_metadata_serde_roundtrip() {
        let m = ProviderMetadata {
            name: "anthropic".into(),
            model: "claude-sonnet-4".into(),
            context_window: 200_000,
            capabilities: HashSet::from([Capability::Tools, Capability::Vision]),
        };
        let json = serde_json::to_string(&m).unwrap();
        let back: ProviderMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(m, back);
    }

    #[test]
    fn test_chat_params_serde_roundtrip_with_metadata() {
        let p = ChatParams {
            messages: vec![ChatMessage::user("hi")],
            metadata: HashMap::from([
                ("provider_key".into(), serde_json::json!("abc123")),
                ("flags".into(), serde_json::json!({"stream": true})),
            ]),
            ..Default::default()
        };
        let json = serde_json::to_string(&p).unwrap();
        let back: ChatParams = serde_json::from_str(&json).unwrap();
        assert_eq!(back.metadata.len(), 2);
        assert_eq!(back.metadata["provider_key"], serde_json::json!("abc123"));
        assert_eq!(back.metadata["flags"], serde_json::json!({"stream": true}));
    }

    #[test]
    fn test_chat_params_serde_roundtrip_skips_timeout_and_headers() {
        let p = ChatParams {
            messages: vec![ChatMessage::user("hi")],
            temperature: Some(0.7),
            timeout: Some(Duration::from_secs(30)),
            extra_headers: Some(http::HeaderMap::new()),
            ..Default::default()
        };
        let json = serde_json::to_string(&p).unwrap();
        let back: ChatParams = serde_json::from_str(&json).unwrap();
        // `#[serde(skip)]` fields are dropped on the way out and defaulted on
        // the way back in.
        assert_eq!(back.timeout, None);
        assert_eq!(back.extra_headers, None);
        // Regular fields survive the round-trip.
        assert_eq!(back.messages.len(), 1);
        assert_eq!(back.temperature, Some(0.7));
    }
}