use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
pub max_retries: u32,
pub initial_backoff_ms: u64,
pub backoff_multiplier: f32,
pub max_backoff_ms: u64,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_backoff_ms: 1000,
backoff_multiplier: 2.0,
max_backoff_ms: 30000,
}
}
}
/// Connection settings for a single provider.
///
/// Only `api_key` is required. Every other field is optional and is left out
/// of serialized output while it is `None`.
///
/// `Debug` is implemented by hand (below) so that formatting a config for logs
/// does not leak the API key or header values; read the fields directly when
/// the raw values are needed.
#[derive(Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// API key for the provider.
    pub api_key: String,
    /// Optional base URL (presumably overrides the provider's default endpoint).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Request timeout in milliseconds; `ProviderConfig::timeout` supplies a
    /// fallback when this is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
    /// Optional proxy address.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub proxy: Option<String>,
    /// Retry policy; `ProviderConfig::retry_config` falls back to
    /// `RetryConfig::default()` when this is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry: Option<RetryConfig>,
    /// Extra headers as name -> value pairs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
    /// Optional cap on in-flight requests (enforcement is outside this file).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_concurrent_requests: Option<usize>,
}

impl std::fmt::Debug for ProviderConfig {
    /// Formats every field except secrets: the API key is replaced by a
    /// placeholder, and only header *names* are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values can carry credentials (e.g. `Authorization`), so only
        // the names are printed; they are sorted so output is deterministic.
        let header_names: Option<Vec<&str>> = self.headers.as_ref().map(|headers| {
            let mut names: Vec<&str> = headers.keys().map(String::as_str).collect();
            names.sort_unstable();
            names
        });
        f.debug_struct("ProviderConfig")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("timeout_ms", &self.timeout_ms)
            .field("proxy", &self.proxy)
            .field("retry", &self.retry)
            .field("headers", &header_names)
            .field("max_concurrent_requests", &self.max_concurrent_requests)
            .finish()
    }
}
impl ProviderConfig {
    /// Timeout returned by [`ProviderConfig::timeout`] when `timeout_ms` is
    /// unset: 120_000 ms (two minutes).
    pub const DEFAULT_TIMEOUT_MS: u64 = 120_000;

    /// Creates a config holding `api_key`, with every optional setting unset.
    #[must_use]
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: None,
            timeout_ms: None,
            proxy: None,
            retry: None,
            headers: None,
            max_concurrent_requests: None,
        }
    }

    /// Sets `base_url`.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Sets the request timeout in milliseconds.
    #[must_use]
    pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = Some(timeout_ms);
        self
    }

    /// Sets `proxy`.
    #[must_use]
    pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
        self.proxy = Some(proxy.into());
        self
    }

    /// Sets the retry policy.
    #[must_use]
    pub fn with_retry(mut self, retry: RetryConfig) -> Self {
        self.retry = Some(retry);
        self
    }

    /// Adds one header, creating the header map on first use. Headers added
    /// earlier are kept; an existing value for the same `key` is overwritten.
    #[must_use]
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value.into());
        self
    }

    /// Replaces the whole header map, discarding any headers added earlier
    /// (including via [`ProviderConfig::with_header`]).
    #[must_use]
    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
        self.headers = Some(headers);
        self
    }

    /// Sets `max_concurrent_requests`.
    #[must_use]
    pub fn with_max_concurrent_requests(mut self, max: usize) -> Self {
        self.max_concurrent_requests = Some(max);
        self
    }

    /// Returns the configured timeout, or [`Self::DEFAULT_TIMEOUT_MS`] when
    /// `timeout_ms` is unset.
    #[must_use]
    pub fn timeout(&self) -> std::time::Duration {
        std::time::Duration::from_millis(self.timeout_ms.unwrap_or(Self::DEFAULT_TIMEOUT_MS))
    }

    /// Returns an owned copy of the retry policy, or `RetryConfig::default()`
    /// when none was configured.
    #[must_use]
    pub fn retry_config(&self) -> RetryConfig {
        self.retry.clone().unwrap_or_default()
    }
}
/// Cheaply clonable handle to an immutable [`ProviderConfig`].
///
/// Cloning copies only the `Arc` pointer, so every clone refers to the same
/// configuration instance.
#[derive(Clone, Debug)]
pub struct SharedProviderConfig {
    inner: Arc<ProviderConfig>,
}
impl SharedProviderConfig {
pub fn new(config: ProviderConfig) -> Self {
Self {
inner: Arc::new(config),
}
}
pub fn get(&self) -> &ProviderConfig {
&self.inner
}
}
impl From<ProviderConfig> for SharedProviderConfig {
fn from(config: ProviderConfig) -> Self {
Self::new(config)
}
}
impl std::ops::Deref for SharedProviderConfig {
    type Target = ProviderConfig;

    /// Exposes the shared [`ProviderConfig`] directly, so fields and methods
    /// are reachable without calling `get()`.
    fn deref(&self) -> &ProviderConfig {
        &*self.inner
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Every builder setter stores its value, and the stored header is the one given.
    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("test-key")
            .with_base_url("https://api.example.com")
            .with_timeout_ms(5000)
            .with_proxy("http://proxy.local:8080")
            .with_header("X-Custom", "value")
            .with_max_concurrent_requests(4)
            .with_retry(RetryConfig::default());
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.base_url.as_deref(), Some("https://api.example.com"));
        assert_eq!(config.timeout_ms, Some(5000));
        assert_eq!(config.timeout(), Duration::from_millis(5000));
        assert_eq!(config.proxy.as_deref(), Some("http://proxy.local:8080"));
        assert_eq!(config.max_concurrent_requests, Some(4));
        let headers = config
            .headers
            .as_ref()
            .expect("with_header creates the header map");
        assert_eq!(headers.get("X-Custom").map(String::as_str), Some("value"));
        assert!(config.retry.is_some());
    }

    /// Unset options report their documented fallbacks.
    #[test]
    fn test_unset_options_use_fallbacks() {
        let config = ProviderConfig::new("k");
        assert_eq!(config.timeout(), Duration::from_millis(120_000));
        let retry = config.retry_config();
        let default = RetryConfig::default();
        assert_eq!(retry.max_retries, default.max_retries);
        assert_eq!(retry.initial_backoff_ms, default.initial_backoff_ms);
        assert_eq!(retry.backoff_multiplier, default.backoff_multiplier);
        assert_eq!(retry.max_backoff_ms, default.max_backoff_ms);
        assert!(config.base_url.is_none());
        assert!(config.proxy.is_none());
        assert!(config.headers.is_none());
        assert!(config.max_concurrent_requests.is_none());
    }

    /// `with_header` keeps earlier headers and overwrites a repeated key.
    #[test]
    fn test_with_header_accumulates_and_overwrites() {
        let config = ProviderConfig::new("k")
            .with_header("A", "1")
            .with_header("B", "2")
            .with_header("A", "3");
        let headers = config.headers.expect("headers were added");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers.get("A").map(String::as_str), Some("3"));
        assert_eq!(headers.get("B").map(String::as_str), Some("2"));
    }

    /// `with_headers` replaces the whole map, dropping earlier `with_header` entries.
    #[test]
    fn test_with_headers_replaces_previous_headers() {
        let mut replacement = HashMap::new();
        replacement.insert("B".to_string(), "2".to_string());
        let config = ProviderConfig::new("k")
            .with_header("A", "1")
            .with_headers(replacement);
        let headers = config.headers.expect("headers were set");
        assert_eq!(headers.len(), 1);
        assert_eq!(headers.get("B").map(String::as_str), Some("2"));
        assert!(!headers.contains_key("A"));
    }

    #[test]
    fn test_retry_config_default() {
        let retry = RetryConfig::default();
        assert_eq!(retry.max_retries, 3);
        assert_eq!(retry.initial_backoff_ms, 1000);
        assert_eq!(retry.backoff_multiplier, 2.0);
        assert_eq!(retry.max_backoff_ms, 30000);
    }

    /// Clones share one allocation and see the same data through `get` and `Deref`.
    #[test]
    fn test_shared_config() {
        let shared1 = SharedProviderConfig::new(ProviderConfig::new("test-key"));
        let shared2 = shared1.clone();
        assert_eq!(shared1.api_key, shared2.api_key);
        assert_eq!(shared1.get().api_key, "test-key");
        assert!(Arc::ptr_eq(&shared1.inner, &shared2.inner));
        assert_eq!(Arc::strong_count(&shared1.inner), 2);
    }

    /// `From`/`into` wraps the config, and methods are reachable through `Deref`.
    #[test]
    fn test_shared_config_from() {
        let shared: SharedProviderConfig = ProviderConfig::new("k").with_timeout_ms(10).into();
        assert_eq!(shared.timeout(), Duration::from_millis(10));
        assert_eq!(Arc::strong_count(&shared.inner), 1);
    }
}