use super::adapter::ProviderAdapter;
use crate::error::LlmError;
use crate::types::{CommonParams, HttpConfig};
use std::sync::Arc;
/// Connection and request settings for an OpenAI-compatible provider.
///
/// Build one with [`OpenAiCompatibleConfig::new`] plus the `with_*` builder methods, then
/// check it with [`OpenAiCompatibleConfig::validate`].
///
/// `Debug` is implemented by hand so that formatting a config (e.g. in logs) never
/// prints the API key or the values of custom headers.
#[derive(Clone)]
pub struct OpenAiCompatibleConfig {
    /// Identifier of the provider; must be non-empty to pass validation.
    pub provider_id: String,
    /// Credential for the provider; must be non-empty to pass validation.
    pub api_key: String,
    /// Root URL of the provider API; must use an `http://` or `https://` scheme.
    pub base_url: String,
    /// Model identifier; empty until set with `with_model`.
    pub model: String,
    /// Generation parameters shared across providers.
    pub common_params: CommonParams,
    /// HTTP client settings.
    pub http_config: HttpConfig,
    /// Extra headers registered through `with_header`; presumably attached to every
    /// outgoing request (TODO confirm at the call site).
    pub custom_headers: reqwest::header::HeaderMap,
    /// Provider-specific adapter (see [`ProviderAdapter`]).
    pub adapter: Arc<dyn ProviderAdapter>,
}

impl std::fmt::Debug for OpenAiCompatibleConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never expose the secret itself, but keep "is it missing?" visible for debugging.
        let api_key: &str = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        // Header values frequently carry credentials (e.g. `x-api-key`), so only the
        // names are shown.
        let header_names: Vec<&reqwest::header::HeaderName> = self.custom_headers.keys().collect();
        f.debug_struct("OpenAiCompatibleConfig")
            .field("provider_id", &self.provider_id)
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("common_params", &self.common_params)
            .field("http_config", &self.http_config)
            .field("custom_headers", &header_names)
            .field("adapter", &self.adapter)
            .finish()
    }
}
impl OpenAiCompatibleConfig {
    /// Creates a configuration for an OpenAI-compatible provider.
    ///
    /// The model starts empty (set it with [`Self::with_model`]). Common parameters and
    /// HTTP settings start at their `Default` values, and no custom headers are set.
    pub fn new(
        provider_id: &str,
        api_key: &str,
        base_url: &str,
        adapter: Arc<dyn ProviderAdapter>,
    ) -> Self {
        Self {
            provider_id: provider_id.to_string(),
            api_key: api_key.to_string(),
            base_url: base_url.to_string(),
            model: String::new(),
            common_params: CommonParams::default(),
            http_config: HttpConfig::default(),
            custom_headers: reqwest::header::HeaderMap::new(),
            adapter,
        }
    }

    /// Sets the model identifier.
    #[must_use]
    pub fn with_model(mut self, model: &str) -> Self {
        self.model = model.to_string();
        self
    }

    /// Replaces the common generation parameters.
    #[must_use]
    pub fn with_common_params(mut self, params: CommonParams) -> Self {
        self.common_params = params;
        self
    }

    /// Replaces the HTTP client settings.
    #[must_use]
    pub fn with_http_config(mut self, config: HttpConfig) -> Self {
        self.http_config = config;
        self
    }

    /// Adds a custom header. An existing header with the same name is replaced.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::ConfigurationError`] if `key` is not a valid header name or
    /// `value` is not a valid header value.
    pub fn with_header(mut self, key: &str, value: &str) -> Result<Self, LlmError> {
        let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
            .map_err(|e| LlmError::ConfigurationError(format!("Invalid header name: {}", e)))?;
        let header_value = reqwest::header::HeaderValue::from_str(value)
            .map_err(|e| LlmError::ConfigurationError(format!("Invalid header value: {}", e)))?;
        self.custom_headers.insert(header_name, header_value);
        Ok(self)
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::ConfigurationError`] in any of these cases:
    /// - the provider ID, API key or base URL is empty or whitespace-only;
    /// - the base URL does not start with `http://` or `https://`;
    /// - nothing (no host) follows the scheme in the base URL.
    ///
    /// The scheme is compared case-insensitively, as URL schemes are.
    pub fn validate(&self) -> Result<(), LlmError> {
        // Whitespace-only values are as unusable as empty ones (e.g. a "   " API key
        // would produce a malformed auth header), so reject both.
        if self.provider_id.trim().is_empty() {
            return Err(LlmError::ConfigurationError(
                "Provider ID cannot be empty".to_string(),
            ));
        }
        if self.api_key.trim().is_empty() {
            return Err(LlmError::ConfigurationError(
                "API key cannot be empty".to_string(),
            ));
        }
        if self.base_url.trim().is_empty() {
            return Err(LlmError::ConfigurationError(
                "Base URL cannot be empty".to_string(),
            ));
        }
        match Self::strip_http_scheme(&self.base_url) {
            None => Err(LlmError::ConfigurationError(
                "Base URL must start with http:// or https://".to_string(),
            )),
            // "https://" or "https:///v1" carry no host to connect to.
            Some(rest) if rest.is_empty() || rest.starts_with('/') => {
                Err(LlmError::ConfigurationError(
                    "Base URL must include a host after the scheme".to_string(),
                ))
            }
            Some(_) => Ok(()),
        }
    }

    /// Returns the part of `url` following an `http://` or `https://` prefix, matched
    /// case-insensitively. Returns `None` if `url` has neither prefix.
    fn strip_http_scheme(url: &str) -> Option<&str> {
        ["http://", "https://"].iter().find_map(|scheme| {
            let head = url.get(..scheme.len())?;
            if head.eq_ignore_ascii_case(scheme) {
                url.get(scheme.len()..)
            } else {
                None
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::openai_compatible::registry::{
        ConfigurableAdapter, ProviderConfig, ProviderFieldMappings,
    };

    /// Builds the adapter every test uses: a minimal chat-only provider definition.
    fn test_adapter() -> Arc<dyn ProviderAdapter> {
        Arc::new(ConfigurableAdapter::new(ProviderConfig {
            id: "test".to_string(),
            name: "Test Provider".to_string(),
            base_url: "https://api.test.com/v1".to_string(),
            field_mappings: ProviderFieldMappings::default(),
            capabilities: vec!["chat".to_string()],
            default_model: Some("test-model".to_string()),
            supports_reasoning: false,
        }))
    }

    /// Creates a config with the given identity fields and the shared test adapter.
    fn make_config(provider_id: &str, api_key: &str, base_url: &str) -> OpenAiCompatibleConfig {
        OpenAiCompatibleConfig::new(provider_id, api_key, base_url, test_adapter())
    }

    /// Asserts that `validate` rejects `config` with a configuration error.
    fn assert_config_error(config: &OpenAiCompatibleConfig) {
        assert!(matches!(
            config.validate(),
            Err(LlmError::ConfigurationError(_))
        ));
    }

    #[test]
    fn test_config_creation() {
        let config = make_config("test", "test-key", "https://api.test.com/v1");
        assert_eq!(config.provider_id, "test");
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.base_url, "https://api.test.com/v1");
        assert!(config.model.is_empty());
        assert!(config.custom_headers.is_empty());
    }

    #[test]
    fn test_config_with_model() {
        let config =
            make_config("test", "test-key", "https://api.test.com/v1").with_model("test-model");
        assert_eq!(config.model, "test-model");
    }

    #[test]
    fn test_config_validation() {
        assert!(make_config("test", "test-key", "https://api.test.com/v1")
            .validate()
            .is_ok());
        assert!(make_config("test", "test-key", "http://localhost:8080/v1")
            .validate()
            .is_ok());

        assert_config_error(&make_config("", "test-key", "https://api.test.com/v1"));
        assert_config_error(&make_config("test", "", "https://api.test.com/v1"));
        assert_config_error(&make_config("test", "test-key", ""));
        assert_config_error(&make_config("test", "test-key", "invalid-url"));
    }

    #[test]
    fn test_config_with_header() {
        let config = make_config("test", "test-key", "https://api.test.com/v1")
            .with_header("X-Custom", "test-value")
            .unwrap();
        // Header names are case-insensitive, so both spellings must resolve.
        assert!(config.custom_headers.contains_key("X-Custom"));
        assert_eq!(
            config
                .custom_headers
                .get("x-custom")
                .and_then(|value| value.to_str().ok()),
            Some("test-value")
        );
    }

    #[test]
    fn test_config_with_invalid_header_is_rejected() {
        // Spaces are not allowed in header names.
        let bad_name = make_config("test", "test-key", "https://api.test.com/v1")
            .with_header("Bad Header", "value");
        assert!(matches!(bad_name, Err(LlmError::ConfigurationError(_))));

        // Control characters such as '\n' are not allowed in header values.
        let bad_value = make_config("test", "test-key", "https://api.test.com/v1")
            .with_header("X-Custom", "line\nbreak");
        assert!(matches!(bad_value, Err(LlmError::ConfigurationError(_))));
    }
}