use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Retry policy with exponential backoff (see `RetryConfig::calculate_backoff`).
///
/// Both durations are (de)serialized as whole milliseconds via `duration_millis`,
/// and a zero `max_attempts` is rejected already at deserialization time.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of attempts; must be >= 1 (enforced by `non_zero_u32`
    /// on deserialize and by `validate`). NOTE(review): whether this counts the
    /// initial call is decided by the retry loop, which is not in this file.
    #[serde(deserialize_with = "non_zero_u32")]
    pub max_attempts: u32,
    /// Base delay: the un-jittered result of `calculate_backoff(0)`.
    #[serde(with = "duration_millis")]
    pub initial_backoff: Duration,
    /// Upper bound on any delay returned by `calculate_backoff`.
    #[serde(with = "duration_millis")]
    pub max_backoff: Duration,
    /// Growth factor: delay = `initial_backoff * backoff_multiplier^attempt`.
    pub backoff_multiplier: f32,
    /// When true, the capped delay is scaled by `0.5 + rand() * 0.5`.
    pub jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 3,
initial_backoff: Duration::from_millis(100),
max_backoff: Duration::from_secs(10),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
impl RetryConfig {
    /// Checks the configuration for values that would make retries unusable.
    ///
    /// # Errors
    ///
    /// Returns a description of the first problem found:
    /// - `max_attempts` is zero;
    /// - `backoff_multiplier` is NaN, infinite, or negative.
    pub fn validate(&self) -> Result<(), String> {
        if self.max_attempts == 0 {
            return Err("max_attempts must be >= 1".to_string());
        }
        let multiplier = self.backoff_multiplier;
        if !multiplier.is_finite() || multiplier < 0.0 {
            return Err("backoff_multiplier must be finite and >= 0".to_string());
        }
        Ok(())
    }

    /// Returns the delay to wait before retry number `attempt` (0-based).
    ///
    /// The un-jittered delay is `initial_backoff * backoff_multiplier^attempt`
    /// (with `initial_backoff` truncated to whole milliseconds), capped at
    /// `max_backoff`. With `jitter` enabled the capped delay is scaled by
    /// `0.5 + rand() * 0.5`, i.e. into `[0.5, 1.0)` of its value assuming
    /// `rand()` yields a value in `[0, 1)`. The result never exceeds
    /// `max_backoff`.
    pub fn calculate_backoff(&self, attempt: u32) -> Duration {
        // Clamp before converting: a plain `attempt as i32` wraps values above
        // i32::MAX into negative exponents, which would *shrink* the delay.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        // f64 keeps millisecond counts exact far beyond f32's 2^24 limit.
        let initial_ms = self.initial_backoff.as_millis() as f64;
        let max_ms = self.max_backoff.as_millis() as f64;
        let base_ms = initial_ms * f64::from(self.backoff_multiplier).powi(exponent);
        // `f64::min` returns the non-NaN operand, so a NaN base falls back to
        // `max_ms`; an overflowing (infinite) base is capped the same way.
        let capped_ms = base_ms.min(max_ms);
        let delay_ms = if self.jitter {
            capped_ms * (0.5 + f64::from(rand()) * 0.5)
        } else {
            capped_ms
        };
        // Float-to-int `as` casts saturate (negative/NaN -> 0); the final `min`
        // guards against any rounding past `max_backoff`.
        Duration::from_millis(delay_ms as u64).min(self.max_backoff)
    }
}
fn non_zero_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = u32::deserialize(deserializer)?;
if value == 0 {
return Err(serde::de::Error::custom("max_attempts must be >= 1"));
}
Ok(value)
}
/// Draws an `f32` from the thread-local RNG via `Rng::gen` (presumably in
/// `[0, 1)`, per rand's `Standard` distribution for floats).
fn rand() -> f32 {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    rng.gen::<f32>()
}
/// Settings for the "healing" behaviour, with `Default`, `strict()` and
/// `lenient()` presets.
///
/// NOTE(review): the code that consumes these flags is not in this file; the
/// per-field notes below are inferred from names and presets — confirm
/// against the consumer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HealingConfig {
    /// Master switch; `true` in every preset.
    pub enabled: bool,
    /// Strict-mode flag; only `strict()` sets it.
    pub strict_mode: bool,
    /// Presumably allows converting values between types; off in `strict()`.
    pub allow_type_coercion: bool,
    /// Confidence threshold (0.95 strict, 0.7 default, 0.5 lenient).
    /// NOTE(review): presumably in [0, 1]; no validation exists here.
    pub min_confidence: f32,
    /// Presumably allows approximate matching; off in `strict()`.
    pub allow_fuzzy_matching: bool,
    /// Attempt budget (1 strict, 3 default, 5 lenient).
    pub max_attempts: u32,
}
impl Default for HealingConfig {
fn default() -> Self {
Self {
enabled: true,
strict_mode: false,
allow_type_coercion: true,
min_confidence: 0.7,
allow_fuzzy_matching: true,
max_attempts: 3,
}
}
}
impl HealingConfig {
pub fn strict() -> Self {
Self {
enabled: true,
strict_mode: true,
allow_type_coercion: false,
min_confidence: 0.95,
allow_fuzzy_matching: false,
max_attempts: 1,
}
}
pub fn lenient() -> Self {
Self {
enabled: true,
strict_mode: false,
allow_type_coercion: true,
min_confidence: 0.5,
allow_fuzzy_matching: true,
max_attempts: 5,
}
}
}
/// Feature flags for a provider; the derived `Default` is all `false` / `0`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Capabilities {
    /// Presumably whether streaming responses are supported.
    pub streaming: bool,
    /// Presumably whether function/tool calling is supported.
    pub function_calling: bool,
    /// Presumably whether image inputs are supported.
    pub vision: bool,
    /// Token limit. NOTE(review): the default is 0; whether 0 means
    /// "unknown" or "unlimited" is not defined here — confirm with callers.
    pub max_tokens: u32,
}
/// Connection settings for a single provider.
///
/// The API key is never exposed: serialization replaces a present key with
/// `"<redacted>"` (via `serialize_optional_secret`), and `Debug` is
/// implemented by hand below so that logging/`{:?}` does not leak it either.
///
/// Because of that redaction, deserializing this type's own serialized output
/// yields `api_key == Some("<redacted>")` rather than the original key, so the
/// serialized form cannot be used to persist credentials.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Provider name.
    pub name: String,
    /// Base URL of the provider's API.
    pub base_url: String,
    /// Optional secret; redacted in both `Serialize` and `Debug` output and
    /// omitted from serialized output when `None`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        serialize_with = "serialize_optional_secret"
    )]
    pub api_key: Option<String>,
    /// Optional default model name; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_model: Option<String>,
    /// Retry policy; `RetryConfig::default()` when absent from the input.
    #[serde(default)]
    pub retry_config: RetryConfig,
    /// Timeout, (de)serialized as whole milliseconds.
    #[serde(with = "duration_millis")]
    pub timeout: Duration,
    /// Feature flags; `Capabilities::default()` when absent from the input.
    #[serde(default)]
    pub capabilities: Capabilities,
}

impl std::fmt::Debug for ProviderConfig {
    /// Same shape as a derived `Debug`, except `api_key` shows
    /// `Some("<redacted>")` instead of the secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            .field("name", &self.name)
            .field("base_url", &self.base_url)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("default_model", &self.default_model)
            .field("retry_config", &self.retry_config)
            .field("timeout", &self.timeout)
            .field("capabilities", &self.capabilities)
            .finish()
    }
}
/// Serde `serialize_with` helper that hides a secret: `Some(_)` is written as
/// `Some("<redacted>")` and `None` as `None`; the actual value is never read.
fn serialize_optional_secret<S>(value: &Option<String>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    if value.is_none() {
        return serializer.serialize_none();
    }
    serializer.serialize_some("<redacted>")
}
impl ProviderConfig {
    /// Creates a config with no API key or default model, the default retry
    /// policy and capabilities, and a 30 second timeout.
    pub fn new(name: impl Into<String>, base_url: impl Into<String>) -> Self {
        let (name, base_url) = (name.into(), base_url.into());
        Self {
            name,
            base_url,
            api_key: None,
            default_model: None,
            retry_config: RetryConfig::default(),
            timeout: Duration::from_secs(30),
            capabilities: Capabilities::default(),
        }
    }

    /// Builder-style setter for the API key.
    pub fn with_api_key(self, api_key: impl Into<String>) -> Self {
        Self {
            api_key: Some(api_key.into()),
            ..self
        }
    }

    /// Builder-style setter for the default model.
    pub fn with_default_model(self, model: impl Into<String>) -> Self {
        Self {
            default_model: Some(model.into()),
            ..self
        }
    }

    /// Builder-style setter for the timeout.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }
}
/// Scope of a rate limiter; `PerInstance` is the default.
///
/// NOTE(review): the limiter that interprets this is not in this file;
/// presumably `PerInstance` gives each instance its own budget and `Shared`
/// lets instances draw from a common one — confirm against the limiter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum RateLimitScope {
    /// Default scope (selected by `#[default]`).
    #[default]
    PerInstance,
    /// Scope used by `RateLimitConfig::shared`.
    Shared,
}
/// Rate-limiting settings; see `RateLimitConfig::validate` for the invariants
/// that apply when `enabled` is true.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// When false, `validate` accepts any values (including zeros).
    pub enabled: bool,
    /// Must be >= 1 when `enabled`.
    pub requests_per_second: u32,
    /// Must be >= 1 when `enabled`.
    pub burst_size: u32,
    /// Defaults to `RateLimitScope::PerInstance` when absent from the input.
    #[serde(default)]
    pub scope: RateLimitScope,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
enabled: false,
requests_per_second: 10,
burst_size: 20,
scope: RateLimitScope::PerInstance,
}
}
}
impl RateLimitConfig {
pub fn validate(&self) -> Result<(), String> {
if !self.enabled {
return Ok(());
}
if self.requests_per_second == 0 {
return Err(
"requests_per_second must be >= 1 when rate limiting is enabled".to_string(),
);
}
if self.burst_size == 0 {
return Err("burst_size must be >= 1 when rate limiting is enabled".to_string());
}
Ok(())
}
pub fn new(requests_per_second: u32, burst_size: u32) -> Self {
Self {
enabled: true,
requests_per_second,
burst_size,
scope: RateLimitScope::PerInstance,
}
}
pub fn shared(requests_per_second: u32, burst_size: u32) -> Self {
Self {
enabled: true,
requests_per_second,
burst_size,
scope: RateLimitScope::Shared,
}
}
pub fn disabled() -> Self {
Self {
enabled: false,
requests_per_second: 0,
burst_size: 0,
scope: RateLimitScope::PerInstance,
}
}
}
/// Serde adapter (`#[serde(with = "duration_millis")]`) that stores a
/// `Duration` as a whole number of milliseconds.
///
/// Sub-millisecond precision is dropped on serialization.
mod duration_millis {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Writes `duration` as a `u64` millisecond count.
    ///
    /// Durations longer than `u64::MAX` milliseconds (e.g. `Duration::MAX`)
    /// saturate to `u64::MAX`; a plain `as u64` cast of the `u128` from
    /// `as_millis` would silently wrap to an unrelated value.
    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let millis = duration.as_millis().min(u128::from(u64::MAX)) as u64;
        serializer.serialize_u64(millis)
    }

    /// Reads a `u64` millisecond count into a `Duration`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        u64::deserialize(deserializer).map(Duration::from_millis)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.backoff_multiplier, 2.0);
        assert!(config.jitter);
    }

    #[test]
    fn test_retry_config_backoff() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };
        let backoff1 = config.calculate_backoff(0);
        let backoff2 = config.calculate_backoff(1);
        let backoff3 = config.calculate_backoff(2);
        assert_eq!(backoff1, Duration::from_millis(100));
        assert_eq!(backoff2, Duration::from_millis(200));
        assert_eq!(backoff3, Duration::from_millis(400));
    }

    #[test]
    fn test_retry_config_validate_rejects_zero_attempts() {
        let config = RetryConfig {
            max_attempts: 0,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };
        assert_eq!(
            config.validate().unwrap_err(),
            "max_attempts must be >= 1".to_string()
        );
    }

    #[test]
    fn test_retry_config_deserialize_rejects_zero_attempts() {
        // Plain quotes: inside r#"..."# a `\"` is a literal backslash + quote,
        // which made the JSON invalid and let this test pass for the wrong reason.
        let json = r#"{"max_attempts":0,"initial_backoff":100,"max_backoff":1000,"backoff_multiplier":2.0,"jitter":false}"#;
        let err = serde_json::from_str::<RetryConfig>(json).unwrap_err();
        assert!(
            err.to_string().contains("max_attempts must be >= 1"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_retry_config_deserialize_accepts_valid_json() {
        let json = r#"{"max_attempts":2,"initial_backoff":100,"max_backoff":1000,"backoff_multiplier":2.0,"jitter":false}"#;
        let parsed: RetryConfig = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.max_attempts, 2);
        assert_eq!(parsed.initial_backoff, Duration::from_millis(100));
        assert_eq!(parsed.max_backoff, Duration::from_secs(1));
        assert!(!parsed.jitter);
    }

    #[test]
    fn test_retry_config_max_backoff() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };
        let backoff = config.calculate_backoff(10);
        assert!(backoff <= Duration::from_secs(1));
    }

    #[test]
    fn test_healing_config_default() {
        let config = HealingConfig::default();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.7);
        assert!(config.allow_fuzzy_matching);
    }

    #[test]
    fn test_healing_config_strict() {
        let config = HealingConfig::strict();
        assert!(config.enabled);
        assert!(config.strict_mode);
        assert!(!config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.95);
        assert!(!config.allow_fuzzy_matching);
    }

    #[test]
    fn test_healing_config_lenient() {
        let config = HealingConfig::lenient();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.5);
        assert!(config.allow_fuzzy_matching);
    }

    #[test]
    fn test_capabilities_default() {
        let caps = Capabilities::default();
        assert!(!caps.streaming);
        assert!(!caps.function_calling);
        assert!(!caps.vision);
        assert_eq!(caps.max_tokens, 0);
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("openai", "https://api.openai.com/v1")
            .with_api_key("sk-test")
            .with_default_model("gpt-4")
            .with_timeout(Duration::from_secs(60));
        assert_eq!(config.name, "openai");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.api_key, Some("sk-test".to_string()));
        assert_eq!(config.default_model, Some("gpt-4".to_string()));
        assert_eq!(config.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_config_serialization() {
        let config = RetryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RetryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_serialization() {
        let config = ProviderConfig::new("test", "https://example.com");
        let json = serde_json::to_string(&config).unwrap();
        let parsed: ProviderConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config.name, parsed.name);
        assert_eq!(config.base_url, parsed.base_url);
    }

    #[test]
    fn test_provider_config_deserialize_applies_defaults() {
        let json = r#"{"name":"test","base_url":"https://example.com","timeout":1500}"#;
        let parsed: ProviderConfig = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.api_key, None);
        assert_eq!(parsed.default_model, None);
        assert_eq!(parsed.retry_config, RetryConfig::default());
        assert_eq!(parsed.timeout, Duration::from_millis(1500));
        assert_eq!(parsed.capabilities, Capabilities::default());
    }

    #[test]
    fn test_provider_config_serialization_redacts_api_key() {
        let config = ProviderConfig::new("test", "https://example.com").with_api_key("secret-key");
        let json = serde_json::to_string(&config).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(
            value.get("api_key"),
            Some(&serde_json::Value::String("<redacted>".to_string()))
        );
    }

    #[test]
    fn test_rate_limit_config_validate_enabled_requires_non_zero_values() {
        let invalid_rps = RateLimitConfig::new(0, 10);
        assert_eq!(
            invalid_rps.validate().unwrap_err(),
            "requests_per_second must be >= 1 when rate limiting is enabled"
        );
        let invalid_burst = RateLimitConfig::new(10, 0);
        assert_eq!(
            invalid_burst.validate().unwrap_err(),
            "burst_size must be >= 1 when rate limiting is enabled"
        );
    }

    #[test]
    fn test_rate_limit_config_validate_disabled_allows_zero_values() {
        let disabled = RateLimitConfig::disabled();
        assert!(disabled.validate().is_ok());
    }

    #[test]
    fn test_rate_limit_config_deserialize_defaults_scope() {
        let json = r#"{"enabled":true,"requests_per_second":5,"burst_size":10}"#;
        let parsed: RateLimitConfig = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.scope, RateLimitScope::PerInstance);
        assert!(parsed.validate().is_ok());
    }

    #[test]
    fn test_jitter_randomness() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        };
        let backoffs: Vec<Duration> = (0..10).map(|_| config.calculate_backoff(1)).collect();
        for backoff in &backoffs {
            let ms = backoff.as_millis();
            // Attempt 1 caps at 200ms and jitter scales it by [0.5, 1.0].
            assert!(ms >= 100, "Backoff too small: {}ms", ms);
            assert!(ms <= 200, "Backoff too large: {}ms", ms);
        }
        let unique_count = backoffs
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len();
        assert!(
            unique_count > 1,
            "All jitter values are the same - RNG may not be working"
        );
    }
}