use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::chat::{ChatMessage, ChatResponse};
use crate::error::LlmError;
use crate::stream::ChatStream;
/// Asynchronous interface implemented by a concrete chat backend.
///
/// The trait requires `Send + Sync`, and both request methods return futures
/// that are `Send`. Every type implementing this trait also receives
/// [`DynProvider`] through the blanket impl in this file.
pub trait Provider: Send + Sync {
/// Runs the chat request described by `params`.
///
/// The returned future resolves to a [`ChatResponse`] on success or an
/// [`LlmError`] on failure.
fn generate(
&self,
params: &ChatParams,
) -> impl Future<Output = Result<ChatResponse, LlmError>> + Send;
/// Starts the chat request described by `params` and resolves to a
/// [`ChatStream`] on success or an [`LlmError`] on failure.
///
/// NOTE(review): presumably the stream yields the response incrementally;
/// confirm against `crate::stream`.
fn stream(
&self,
params: &ChatParams,
) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
/// Returns descriptive information about this provider as an owned
/// [`ProviderMetadata`].
fn metadata(&self) -> ProviderMetadata;
}
/// Boxed-future counterpart of [`Provider`].
///
/// Each request method mirrors a `Provider` method but returns a pinned,
/// heap-allocated `dyn Future` instead of an `impl Future`. The blanket impl in
/// this file provides this trait for every `T: Provider`.
///
/// NOTE(review): presumably this exists so providers can be used behind
/// `dyn DynProvider`; confirm with callers.
pub trait DynProvider: Send + Sync {
/// Boxed form of [`Provider::generate`].
///
/// The returned future borrows both `self` and `params` for `'a`.
fn generate_boxed<'a>(
&'a self,
params: &'a ChatParams,
) -> Pin<Box<dyn Future<Output = Result<ChatResponse, LlmError>> + Send + 'a>>;
/// Boxed form of [`Provider::stream`].
///
/// The returned future borrows both `self` and `params` for `'a`.
fn stream_boxed<'a>(
&'a self,
params: &'a ChatParams,
) -> Pin<Box<dyn Future<Output = Result<ChatStream, LlmError>> + Send + 'a>>;
/// Returns descriptive information about this provider as an owned
/// [`ProviderMetadata`].
fn metadata(&self) -> ProviderMetadata;
}
/// Blanket adapter: every [`Provider`] is automatically a [`DynProvider`].
///
/// Each method forwards to the `Provider` method of the same purpose; the two
/// request methods additionally box and pin the future they get back.
impl<T> DynProvider for T
where
    T: Provider,
{
    /// Forwards to [`Provider::generate`] and boxes the resulting future.
    fn generate_boxed<'a>(
        &'a self,
        params: &'a ChatParams,
    ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, LlmError>> + Send + 'a>> {
        let pending = <T as Provider>::generate(self, params);
        Box::pin(pending)
    }

    /// Forwards to [`Provider::stream`] and boxes the resulting future.
    fn stream_boxed<'a>(
        &'a self,
        params: &'a ChatParams,
    ) -> Pin<Box<dyn Future<Output = Result<ChatStream, LlmError>> + Send + 'a>> {
        let pending = <T as Provider>::stream(self, params);
        Box::pin(pending)
    }

    /// Forwards to [`Provider::metadata`].
    fn metadata(&self) -> ProviderMetadata {
        // Fully qualified: both traits declare `metadata`, and this must call
        // the `Provider` one rather than recurse into this method.
        <T as Provider>::metadata(self)
    }
}
/// Descriptive information about a provider, returned by
/// `Provider::metadata` / `DynProvider::metadata`.
///
/// Serializable and comparable for equality.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderMetadata {
/// Provider name; `Cow` lets it be either a `'static` literal or an owned
/// `String`.
pub name: Cow<'static, str>,
/// Model identifier string.
pub model: String,
/// Context window size.
/// NOTE(review): presumably measured in tokens; confirm.
pub context_window: u64,
/// Set of [`Capability`] values reported for this provider.
pub capabilities: HashSet<Capability>,
}
/// Feature flag reported in `ProviderMetadata::capabilities`.
///
/// `Copy` and `Hash` are derived so values can be stored in a `HashSet`. The
/// enum is `#[non_exhaustive]`, so matches outside this crate need a wildcard
/// arm.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names only; confirm against the provider implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Capability {
/// Presumably: tool / function calling.
Tools,
/// Presumably: schema-constrained (structured) output.
StructuredOutput,
/// Presumably: reasoning support.
Reasoning,
/// Presumably: image input.
Vision,
/// Presumably: prompt caching.
Caching,
}
/// Parameters for a single chat request, passed to `Provider::generate` and
/// `Provider::stream`.
///
/// `Default` yields an empty message list, every optional field `None`, and an
/// empty `metadata` map, so callers can use struct-update syntax
/// (`ChatParams { messages, ..Default::default() }`).
///
/// # Serialization
///
/// `timeout` and `extra_headers` are marked `#[serde(skip)]`: they are never
/// written and always come back as `None` after deserialization. `metadata`
/// is marked `#[serde(default)]` so a payload that omits the key deserializes
/// to an empty map, matching how omitted `Option` fields become `None`.
///
/// NOTE(review): the per-field descriptions of request semantics are inferred
/// from field names; how each value is interpreted is up to the provider.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct ChatParams {
    /// Messages that make up the request.
    pub messages: Vec<ChatMessage>,
    /// Tool definitions offered with the request, if any.
    pub tools: Option<Vec<ToolDefinition>>,
    /// Tool selection policy, if any (see [`ToolChoice`]).
    pub tool_choice: Option<ToolChoice>,
    /// Sampling temperature, if set.
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens, if set.
    pub max_tokens: Option<u32>,
    /// System prompt, if set.
    pub system: Option<String>,
    /// Reasoning budget, if set. Presumably counted in tokens; confirm.
    pub reasoning_budget: Option<u32>,
    /// Schema for structured output, if requested.
    pub structured_output: Option<JsonSchema>,
    /// Request timeout. Not serialized.
    #[serde(skip)]
    pub timeout: Option<Duration>,
    /// Additional HTTP headers. Not serialized.
    #[serde(skip)]
    pub extra_headers: Option<http::HeaderMap>,
    /// Free-form key/value data attached to the request.
    ///
    /// Defaults to an empty map when the key is absent from the input.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}
/// Tool selection policy carried in `ChatParams::tool_choice`.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names; confirm against the provider implementations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ToolChoice {
/// Presumably: the model decides whether to call a tool.
Auto,
/// Presumably: the model must not call any tool.
None,
/// Presumably: the model must call some tool.
Required,
/// Carries a string, presumably the name of the one tool to call.
Specific(String),
}
/// Shared, thread-safe predicate over a string slice, stored in
/// [`ToolRetryConfig::retry_if`].
///
/// NOTE(review): presumably it receives an error message and answers "should
/// this be retried?"; confirm with the code that invokes it.
pub type RetryPredicate = std::sync::Arc<dyn Fn(&str) -> bool + Send + Sync>;

/// Retry settings attachable to a tool definition.
///
/// `Debug` and `PartialEq` are written by hand because the optional predicate
/// is an opaque closure: `Debug` prints only whether one is set, and equality
/// compares only whether both sides have one (two different predicates compare
/// equal).
///
/// NOTE(review): the per-field descriptions are inferred from field names; the
/// code that consumes this config is not in this file.
#[derive(Clone)]
pub struct ToolRetryConfig {
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Backoff used for the first retry.
    pub initial_backoff: Duration,
    /// Upper bound on the backoff.
    pub max_backoff: Duration,
    /// Factor applied to the backoff between retries.
    pub backoff_multiplier: f64,
    /// Jitter amount. Presumably a fraction of the backoff; confirm.
    pub jitter: f64,
    /// Optional predicate (see [`RetryPredicate`]).
    pub retry_if: Option<RetryPredicate>,
}

impl Default for ToolRetryConfig {
    /// Defaults: 3 retries, 100 ms initial backoff, 5 s maximum backoff,
    /// multiplier 2.0, jitter 0.5, no predicate.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(100);
        let max_backoff = Duration::from_secs(5);
        ToolRetryConfig {
            retry_if: None,
            jitter: 0.5,
            backoff_multiplier: 2.0,
            max_backoff,
            initial_backoff,
            max_retries: 3,
        }
    }
}

impl std::fmt::Debug for ToolRetryConfig {
    /// Formats every plain field; the predicate is shown as a
    /// `has_retry_if` boolean because closures have no `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this impl
        // becomes a compile error.
        let Self {
            max_retries,
            initial_backoff,
            max_backoff,
            backoff_multiplier,
            jitter,
            retry_if,
        } = self;

        let mut out = f.debug_struct("ToolRetryConfig");
        out.field("max_retries", max_retries);
        out.field("initial_backoff", initial_backoff);
        out.field("max_backoff", max_backoff);
        out.field("backoff_multiplier", backoff_multiplier);
        out.field("jitter", jitter);
        out.field("has_retry_if", &retry_if.is_some());
        out.finish()
    }
}

impl PartialEq for ToolRetryConfig {
    /// Field-wise equality, except that `retry_if` is compared by presence
    /// only. Float fields use `f64` equality, so a NaN value never equals
    /// itself.
    fn eq(&self, other: &Self) -> bool {
        let counts_and_durations_match =
            (self.max_retries, self.initial_backoff, self.max_backoff)
                == (other.max_retries, other.initial_backoff, other.max_backoff);
        let floats_match = self.backoff_multiplier == other.backoff_multiplier
            && self.jitter == other.jitter;
        let predicate_presence_matches =
            self.retry_if.is_some() == other.retry_if.is_some();

        counts_and_durations_match && floats_match && predicate_presence_matches
    }
}
/// Description of a tool, carried in `ChatParams::tools`.
///
/// `retry` is marked `#[serde(skip)]`: it is never serialized and is `None`
/// after deserialization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Tool name.
pub name: String,
/// Human-readable description of the tool.
pub description: String,
/// Schema for the tool's parameters, as a [`JsonSchema`].
pub parameters: JsonSchema,
/// Optional retry settings for this tool. Not serialized.
#[serde(skip)]
pub retry: Option<ToolRetryConfig>,
}
/// Newtype around a JSON value that holds a JSON Schema document.
///
/// The inner value is private; build one with [`JsonSchema::new`] (or
/// `from_type` when the `schema` feature is enabled) and read it back with
/// [`JsonSchema::as_value`]. Construction performs no checking of the schema.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonSchema(Value);

impl JsonSchema {
    /// Wraps `schema` as-is, without validating it.
    pub fn new(schema: Value) -> Self {
        JsonSchema(schema)
    }

    /// Borrows the wrapped JSON value.
    pub fn as_value(&self) -> &Value {
        let Self(inner) = self;
        inner
    }

    /// Generates the schema for `T` with `schemars` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` produced if the generated schema
    /// cannot be converted to a JSON value.
    #[cfg(feature = "schema")]
    pub fn from_type<T: schemars::JsonSchema>() -> Result<Self, serde_json::Error> {
        serde_json::to_value(schemars::schema_for!(T)).map(Self)
    }

    /// Checks `value` against this schema.
    ///
    /// # Errors
    ///
    /// * `LlmError::InvalidRequest` if the wrapped schema itself cannot be
    ///   turned into a validator.
    /// * `LlmError::SchemaValidation` if `value` violates the schema; the
    ///   message joins every reported violation with `"; "`, and the error
    ///   carries clones of the schema and of `value`.
    #[cfg(feature = "schema")]
    pub fn validate(&self, value: &Value) -> Result<(), LlmError> {
        let compiled = match jsonschema::validator_for(&self.0) {
            Ok(validator) => validator,
            Err(e) => {
                return Err(LlmError::InvalidRequest(format!(
                    "invalid JSON schema: {e}"
                )));
            }
        };

        // Gather every violation, not just the first, so the caller sees the
        // complete list in a single error.
        let mut problems: Vec<String> = Vec::new();
        for violation in compiled.iter_errors(value) {
            problems.push(violation.to_string());
        }

        if !problems.is_empty() {
            return Err(LlmError::SchemaValidation {
                message: problems.join("; "),
                schema: self.0.clone(),
                actual: value.clone(),
            });
        }
        Ok(())
    }
}
// Unit tests for the types declared in this file. Tests gated on
// `feature = "schema"` exercise `JsonSchema::from_type` and
// `JsonSchema::validate`, which only exist with that feature.
#[cfg(test)]
mod tests {
use super::*;
// All five `Capability` variants are distinct entries in a `HashSet`.
#[test]
fn test_capability_hash_set() {
let caps: HashSet<Capability> = HashSet::from([
Capability::Tools,
Capability::StructuredOutput,
Capability::Reasoning,
Capability::Vision,
Capability::Caching,
]);
assert_eq!(caps.len(), 5);
}
// `Capability` is `Copy`: the original binding is still usable after being
// assigned to a second one.
#[test]
fn test_capability_copy() {
let c = Capability::Tools;
let c2 = c; assert_eq!(c, c2);
}
// A `Capability` survives a JSON serialize/deserialize round trip.
#[test]
fn test_capability_serde_roundtrip() {
let cap = Capability::Tools;
let json = serde_json::to_string(&cap).unwrap();
let back: Capability = serde_json::from_str(&json).unwrap();
assert_eq!(cap, back);
}
// A cloned `ProviderMetadata` compares equal to the original.
#[test]
fn test_provider_metadata_clone_eq() {
let m = ProviderMetadata {
name: "mock".into(),
model: "test-model".into(),
context_window: 128_000,
capabilities: HashSet::from([Capability::Tools]),
};
assert_eq!(m, m.clone());
}
// `name` accepts an owned `String` via `Cow::Owned` and compares against a
// string literal.
#[test]
fn test_provider_metadata_owned_name() {
let name = String::from("custom-provider");
let m = ProviderMetadata {
name: Cow::Owned(name),
model: "test".into(),
context_window: 4096,
capabilities: HashSet::new(),
};
assert_eq!(m.name, "custom-provider");
}
// `ChatParams::default()` has empty collections and every optional field
// unset.
#[test]
fn test_chat_params_defaults() {
let p = ChatParams::default();
assert!(p.messages.is_empty());
assert!(p.tools.is_none());
assert!(p.tool_choice.is_none());
assert!(p.temperature.is_none());
assert!(p.max_tokens.is_none());
assert!(p.system.is_none());
assert!(p.reasoning_budget.is_none());
assert!(p.structured_output.is_none());
assert!(p.timeout.is_none());
assert!(p.extra_headers.is_none());
assert!(p.metadata.is_empty());
}
// Every field of `ChatParams` can be populated in a struct literal.
#[test]
fn test_chat_params_full() {
let p = ChatParams {
messages: vec![ChatMessage::user("hi")],
tools: Some(vec![]),
tool_choice: Some(ToolChoice::Auto),
temperature: Some(0.7),
max_tokens: Some(1024),
system: Some("you are helpful".into()),
reasoning_budget: Some(2048),
structured_output: Some(JsonSchema::new(serde_json::json!({"type": "object"}))),
timeout: Some(Duration::from_secs(30)),
extra_headers: Some(http::HeaderMap::new()),
metadata: HashMap::from([("key".into(), serde_json::json!("val"))]),
};
assert_eq!(p.messages.len(), 1);
assert!(p.tools.is_some());
assert_eq!(p.temperature, Some(0.7));
}
// Each `ToolChoice` variant equals its own clone.
#[test]
fn test_tool_choice_all_variants() {
let variants = [
ToolChoice::Auto,
ToolChoice::None,
ToolChoice::Required,
ToolChoice::Specific("my_tool".into()),
];
for v in &variants {
assert_eq!(*v, v.clone());
}
}
// The data-carrying `ToolChoice::Specific` variant survives a JSON round trip.
#[test]
fn test_tool_choice_serde_roundtrip() {
let tc = ToolChoice::Specific("search".into());
let json = serde_json::to_string(&tc).unwrap();
let back: ToolChoice = serde_json::from_str(&json).unwrap();
assert_eq!(tc, back);
}
// `JsonSchema::new` stores the value unchanged and `as_value` returns it.
#[test]
fn test_json_schema_from_raw() {
let schema = JsonSchema::new(serde_json::json!({"type": "object"}));
assert_eq!(*schema.as_value(), serde_json::json!({"type": "object"}));
}
// `from_type` produces a schema whose `properties` include the struct's field.
#[cfg(feature = "schema")]
#[test]
fn test_json_schema_from_type_simple() {
#[derive(schemars::JsonSchema)]
struct Foo {
#[allow(dead_code)]
x: i32,
}
let schema = JsonSchema::from_type::<Foo>().unwrap();
let props = schema
.as_value()
.get("properties")
.expect("should have properties");
assert!(props.get("x").is_some());
}
// A value that satisfies the schema validates successfully.
#[cfg(feature = "schema")]
#[test]
fn test_json_schema_validate_valid() {
let schema = JsonSchema::new(serde_json::json!({
"type": "object",
"properties": {
"x": {"type": "integer"}
},
"required": ["x"]
}));
assert!(schema.validate(&serde_json::json!({"x": 42})).is_ok());
}
// A missing required field is reported as `LlmError::SchemaValidation`.
#[cfg(feature = "schema")]
#[test]
fn test_json_schema_validate_missing_field() {
let schema = JsonSchema::new(serde_json::json!({
"type": "object",
"properties": {
"x": {"type": "integer"}
},
"required": ["x"]
}));
let result = schema.validate(&serde_json::json!({}));
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
LlmError::SchemaValidation { .. }
));
}
// A property of the wrong JSON type fails validation.
#[cfg(feature = "schema")]
#[test]
fn test_json_schema_validate_wrong_type() {
let schema = JsonSchema::new(serde_json::json!({
"type": "object",
"properties": {
"x": {"type": "integer"}
},
"required": ["x"]
}));
let result = schema.validate(&serde_json::json!({"x": "not a number"}));
assert!(result.is_err());
}
// A schema that is itself invalid is reported as `LlmError::InvalidRequest`,
// not as a validation failure of the value.
#[cfg(feature = "schema")]
#[test]
fn test_json_schema_validate_invalid_schema() {
let schema = JsonSchema::new(serde_json::json!({"type": "bogus_not_a_type"}));
let result = schema.validate(&serde_json::json!(42));
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), LlmError::InvalidRequest(_)));
}
// A cloned `JsonSchema` compares equal to the original.
#[test]
fn test_json_schema_clone_eq() {
let s = JsonSchema::new(serde_json::json!({"type": "string"}));
assert_eq!(s, s.clone());
}
// A nested schema survives a JSON round trip.
#[test]
fn test_json_schema_serde_roundtrip() {
let s = JsonSchema::new(
serde_json::json!({"type": "object", "properties": {"x": {"type": "integer"}}}),
);
let json = serde_json::to_string(&s).unwrap();
let back: JsonSchema = serde_json::from_str(&json).unwrap();
assert_eq!(s, back);
}
// A `ToolDefinition` with `retry: None` survives a JSON round trip; `retry`
// is skipped by serde, so `None` is what deserialization yields.
#[test]
fn test_tool_definition_serde_roundtrip() {
let td = ToolDefinition {
name: "search".into(),
description: "Search the web".into(),
parameters: JsonSchema::new(serde_json::json!({"type": "object"})),
retry: None,
};
let json = serde_json::to_string(&td).unwrap();
let back: ToolDefinition = serde_json::from_str(&json).unwrap();
assert_eq!(td, back);
}
// `ProviderMetadata`, including its capability set, survives a JSON round trip.
#[test]
fn test_provider_metadata_serde_roundtrip() {
let m = ProviderMetadata {
name: "anthropic".into(),
model: "claude-sonnet-4".into(),
context_window: 200_000,
capabilities: HashSet::from([Capability::Tools, Capability::Vision]),
};
let json = serde_json::to_string(&m).unwrap();
let back: ProviderMetadata = serde_json::from_str(&json).unwrap();
assert_eq!(m, back);
}
// `ChatParams::metadata` entries (string and nested object) are preserved
// across a JSON round trip.
#[test]
fn test_chat_params_serde_roundtrip_with_metadata() {
let p = ChatParams {
messages: vec![ChatMessage::user("hi")],
metadata: HashMap::from([
("provider_key".into(), serde_json::json!("abc123")),
("flags".into(), serde_json::json!({"stream": true})),
]),
..Default::default()
};
let json = serde_json::to_string(&p).unwrap();
let back: ChatParams = serde_json::from_str(&json).unwrap();
assert_eq!(back.metadata.len(), 2);
assert_eq!(back.metadata["provider_key"], serde_json::json!("abc123"));
assert_eq!(back.metadata["flags"], serde_json::json!({"stream": true}));
}
// `timeout` and `extra_headers` are `#[serde(skip)]`: they come back as
// `None` after a round trip while the serialized fields are preserved.
#[test]
fn test_chat_params_serde_roundtrip_skips_timeout_and_headers() {
let p = ChatParams {
messages: vec![ChatMessage::user("hi")],
temperature: Some(0.7),
timeout: Some(Duration::from_secs(30)),
extra_headers: Some(http::HeaderMap::new()),
..Default::default()
};
let json = serde_json::to_string(&p).unwrap();
let back: ChatParams = serde_json::from_str(&json).unwrap();
assert_eq!(back.timeout, None);
assert_eq!(back.extra_headers, None);
assert_eq!(back.messages.len(), 1);
assert_eq!(back.temperature, Some(0.7));
}
}