use crate::error::{LlmError, LlmResult};
use crate::internals::retry::RetryPolicy;
use crate::logging::log_debug;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::fmt::Debug;
/// Provider-specific connection settings held by an [`LLMConfig`].
///
/// Implementations are stored behind `Box<dyn ProviderConfig>`. The `Send + Sync` bounds
/// let such a box be shared across threads, and [`ProviderConfig::as_any`] allows the
/// concrete config type to be recovered by downcasting.
pub trait ProviderConfig: Send + Sync + Debug + Any {
    /// Lowercase identifier of the provider, e.g. `"anthropic"` or `"ollama"`.
    fn provider_name(&self) -> &'static str;

    /// Context-window size, in tokens, configured for this provider.
    fn max_context_tokens(&self) -> usize;

    /// Checks that all settings this provider requires are present.
    ///
    /// # Errors
    ///
    /// Returns an error describing the first missing or invalid setting.
    fn validate(&self) -> LlmResult<()>;

    /// Base URL of the provider's API endpoint.
    fn base_url(&self) -> &str;

    /// API key used to authenticate, or `None` when this provider config carries none.
    fn api_key(&self) -> Option<&str>;

    /// Model identifier used when a request does not name one explicitly.
    fn default_model(&self) -> &str;

    /// Exposes `self` as `&dyn Any` so callers can downcast to the concrete config type.
    fn as_any(&self) -> &dyn Any;

    /// Retry policy associated with this provider.
    fn retry_policy(&self) -> &RetryPolicy;
}
/// Complete LLM configuration: the selected provider plus default generation parameters.
///
/// `Clone` is implemented manually because `provider` is a trait object.
#[derive(Debug)]
pub struct LLMConfig {
    /// Settings for the selected provider (endpoint, credentials, model, limits).
    pub provider: Box<dyn ProviderConfig>,
    /// Default generation parameters for this configuration.
    pub default_params: DefaultLLMParams,
}
impl LLMConfig {
    /// Returns an owned copy of `self.provider`.
    ///
    /// `Clone` cannot be a supertrait of `ProviderConfig`, because `Clone: Sized` would make
    /// the trait unusable as `dyn ProviderConfig`. Instead the concrete type is recovered via
    /// [`ProviderConfig::as_any`], and each built-in provider config is cloned explicitly.
    ///
    /// # Panics
    ///
    /// Panics if the provider is not one of `AnthropicConfig`, `OpenAIConfig`,
    /// `LMStudioConfig` or `OllamaConfig`. Because `ProviderConfig` is a public trait,
    /// this is reachable for implementations defined outside this module.
    fn clone_provider(&self) -> Box<dyn ProviderConfig> {
        let any_ref = self.provider.as_any();
        if let Some(config) = any_ref.downcast_ref::<AnthropicConfig>() {
            Box::new(config.clone())
        } else if let Some(config) = any_ref.downcast_ref::<OpenAIConfig>() {
            Box::new(config.clone())
        } else if let Some(config) = any_ref.downcast_ref::<LMStudioConfig>() {
            Box::new(config.clone())
        } else if let Some(config) = any_ref.downcast_ref::<OllamaConfig>() {
            Box::new(config.clone())
        } else {
            // Not `unreachable!`: external crates can implement this public trait, so this
            // branch is a real (if unsupported) path. Name the provider to aid debugging.
            panic!(
                "cannot clone LLMConfig: unsupported ProviderConfig implementation `{}` \
                 (only the built-in anthropic, openai, lmstudio and ollama configs are cloneable)",
                self.provider.provider_name()
            )
        }
    }
}
impl Clone for LLMConfig {
fn clone(&self) -> Self {
Self {
provider: self.clone_provider(),
default_params: self.default_params.clone(),
}
}
}
/// Default generation parameters carried by an [`LLMConfig`].
///
/// Field order is kept stable because it determines the serialized field order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DefaultLLMParams {
    /// Sampling temperature.
    pub temperature: f64,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Nucleus (top-p) sampling threshold.
    pub top_p: f64,
    /// Top-k sampling cutoff.
    pub top_k: u32,
    /// Min-p sampling threshold.
    pub min_p: f64,
    /// Presence penalty.
    pub presence_penalty: f64,
}
impl Default for DefaultLLMParams {
fn default() -> Self {
Self {
temperature: 0.7,
max_tokens: 1000,
top_p: 0.9,
top_k: 40,
min_p: 0.05,
presence_penalty: 0.0,
}
}
}
/// Connection settings for the Anthropic API.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AnthropicConfig {
    /// API key; required by `validate`.
    pub api_key: Option<String>,
    /// Base URL of the API endpoint.
    pub base_url: String,
    /// Model used when a request does not specify one.
    pub default_model: String,
    /// Context-window size, in tokens.
    pub max_context_tokens: usize,
    /// Retry policy for requests to this provider.
    pub retry_policy: RetryPolicy,
    /// Whether prompt caching is enabled.
    pub enable_prompt_caching: bool,
    /// Prompt-cache TTL as a string, e.g. `"1h"`.
    pub cache_ttl: String,
}
impl Default for AnthropicConfig {
fn default() -> Self {
Self {
api_key: None,
base_url: "https://api.anthropic.com".to_string(),
default_model: "claude-3-5-sonnet-20241022".to_string(),
max_context_tokens: 200_000,
retry_policy: RetryPolicy::default(),
enable_prompt_caching: true, cache_ttl: "1h".to_string(), }
}
}
impl ProviderConfig for AnthropicConfig {
    fn provider_name(&self) -> &'static str {
        "anthropic"
    }

    fn max_context_tokens(&self) -> usize {
        self.max_context_tokens
    }

    /// Requires a non-blank API key.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when `api_key` is `None`, empty, or whitespace-only.
    fn validate(&self) -> LlmResult<()> {
        // A blank key (e.g. `ANTHROPIC_API_KEY=""`) is as unusable as a missing one; reject it
        // here instead of deferring the failure to request time.
        match self.api_key.as_deref() {
            Some(key) if !key.trim().is_empty() => Ok(()),
            _ => Err(LlmError::configuration_error(
                "Anthropic API key is required",
            )),
        }
    }

    fn base_url(&self) -> &str {
        &self.base_url
    }

    fn api_key(&self) -> Option<&str> {
        self.api_key.as_deref()
    }

    fn default_model(&self) -> &str {
        &self.default_model
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn retry_policy(&self) -> &RetryPolicy {
        &self.retry_policy
    }
}
/// Connection settings for the OpenAI API (or a compatible endpoint via `base_url`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OpenAIConfig {
    /// API key; required by `validate`.
    pub api_key: Option<String>,
    /// Base URL of the API endpoint.
    pub base_url: String,
    /// Model used when a request does not specify one.
    pub default_model: String,
    /// Context-window size, in tokens.
    pub max_context_tokens: usize,
    /// Retry policy for requests to this provider.
    pub retry_policy: RetryPolicy,
}
impl Default for OpenAIConfig {
fn default() -> Self {
Self {
api_key: None,
base_url: "https://api.openai.com".to_string(),
default_model: "gpt-4".to_string(),
max_context_tokens: 128_000,
retry_policy: RetryPolicy::default(),
}
}
}
impl ProviderConfig for OpenAIConfig {
    fn provider_name(&self) -> &'static str {
        "openai"
    }

    fn max_context_tokens(&self) -> usize {
        self.max_context_tokens
    }

    /// Requires a non-blank API key.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when `api_key` is `None`, empty, or whitespace-only.
    fn validate(&self) -> LlmResult<()> {
        // A blank key (e.g. `OPENAI_API_KEY=""`) is as unusable as a missing one; reject it
        // here instead of deferring the failure to request time.
        match self.api_key.as_deref() {
            Some(key) if !key.trim().is_empty() => Ok(()),
            _ => Err(LlmError::configuration_error("OpenAI API key is required")),
        }
    }

    fn base_url(&self) -> &str {
        &self.base_url
    }

    fn api_key(&self) -> Option<&str> {
        self.api_key.as_deref()
    }

    fn default_model(&self) -> &str {
        &self.default_model
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn retry_policy(&self) -> &RetryPolicy {
        &self.retry_policy
    }
}
/// Connection settings for an LM Studio server; no API key is stored.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LMStudioConfig {
    /// Base URL of the server; must be non-empty.
    pub base_url: String,
    /// Model used when a request does not specify one.
    pub default_model: String,
    /// Context-window size, in tokens.
    pub max_context_tokens: usize,
    /// Retry policy for requests to this provider.
    pub retry_policy: RetryPolicy,
}
impl Default for LMStudioConfig {
fn default() -> Self {
Self {
base_url: "http://localhost:1234".to_string(),
default_model: "local-model".to_string(),
max_context_tokens: 4_096,
retry_policy: RetryPolicy::default(),
}
}
}
impl ProviderConfig for LMStudioConfig {
    fn provider_name(&self) -> &'static str {
        "lmstudio"
    }

    fn max_context_tokens(&self) -> usize {
        self.max_context_tokens
    }

    /// Only the base URL is mandatory for LM Studio.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when `base_url` is empty.
    fn validate(&self) -> LlmResult<()> {
        if !self.base_url.is_empty() {
            return Ok(());
        }
        Err(LlmError::configuration_error(
            "LM Studio base URL is required",
        ))
    }

    fn base_url(&self) -> &str {
        self.base_url.as_str()
    }

    /// This config stores no API key, so this is always `None`.
    fn api_key(&self) -> Option<&str> {
        None
    }

    fn default_model(&self) -> &str {
        self.default_model.as_str()
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn retry_policy(&self) -> &RetryPolicy {
        &self.retry_policy
    }
}
/// Connection settings for an Ollama server; no API key is stored.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OllamaConfig {
    /// Base URL of the server; must be non-empty.
    pub base_url: String,
    /// Model used when a request does not specify one.
    pub default_model: String,
    /// Context-window size, in tokens.
    pub max_context_tokens: usize,
    /// Retry policy for requests to this provider.
    pub retry_policy: RetryPolicy,
}
impl Default for OllamaConfig {
fn default() -> Self {
Self {
base_url: "http://localhost:11434".to_string(),
default_model: "llama2".to_string(),
max_context_tokens: 4_096,
retry_policy: RetryPolicy::default(),
}
}
}
impl ProviderConfig for OllamaConfig {
    fn provider_name(&self) -> &'static str {
        "ollama"
    }

    fn max_context_tokens(&self) -> usize {
        self.max_context_tokens
    }

    /// Only the base URL is mandatory for Ollama.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when `base_url` is empty.
    fn validate(&self) -> LlmResult<()> {
        if !self.base_url.is_empty() {
            return Ok(());
        }
        Err(LlmError::configuration_error("Ollama base URL is required"))
    }

    fn base_url(&self) -> &str {
        self.base_url.as_str()
    }

    /// This config stores no API key, so this is always `None`.
    fn api_key(&self) -> Option<&str> {
        None
    }

    fn default_model(&self) -> &str {
        self.default_model.as_str()
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn retry_policy(&self) -> &RetryPolicy {
        &self.retry_policy
    }
}
impl LLMConfig {
    /// Builds and validates a configuration for the named provider from explicit values.
    ///
    /// `provider_name` is matched case-insensitively against `anthropic`, `openai`,
    /// `lmstudio` and `ollama`. Each `None` argument leaves the provider default in place,
    /// with one exception: a missing `api_key` for `anthropic` / `openai` is read from
    /// `ANTHROPIC_API_KEY` / `OPENAI_API_KEY` respectively. `api_key` is ignored for
    /// `lmstudio` and `ollama`.
    ///
    /// # Errors
    ///
    /// Returns a configuration error if `provider_name` is unsupported, or any error from
    /// [`ProviderConfig::validate`] (e.g. a hosted provider left without an API key).
    pub fn create_provider(
        provider_name: &str,
        api_key: Option<String>,
        base_url: Option<String>,
        model: Option<String>,
    ) -> LlmResult<Self> {
        log_debug!(
            provider = %provider_name,
            has_api_key = api_key.is_some(),
            has_base_url = base_url.is_some(),
            has_model = model.is_some(),
            "Creating provider configuration"
        );
        let provider: Box<dyn ProviderConfig> = match provider_name.to_lowercase().as_str() {
            "anthropic" => Self::create_anthropic_provider(api_key, base_url, model),
            "openai" => Self::create_openai_provider(api_key, base_url, model),
            "lmstudio" => Self::create_lmstudio_provider(base_url, model),
            "ollama" => Self::create_ollama_provider(base_url, model),
            _ => {
                return Err(LlmError::configuration_error(format!(
                    "Unsupported provider: {}. Supported providers: anthropic, openai, lmstudio, ollama",
                    provider_name
                )));
            }
        };
        provider.validate()?;
        Ok(Self {
            provider,
            default_params: DefaultLLMParams::default(),
        })
    }

    /// Anthropic config with explicit overrides; the key falls back to `ANTHROPIC_API_KEY`.
    fn create_anthropic_provider(
        api_key: Option<String>,
        base_url: Option<String>,
        model: Option<String>,
    ) -> Box<dyn ProviderConfig> {
        let mut config = AnthropicConfig::default();
        config.api_key = api_key.or_else(|| std::env::var("ANTHROPIC_API_KEY").ok());
        if let Some(url) = base_url {
            config.base_url = url;
        }
        if let Some(m) = model {
            config.default_model = m;
        }
        Box::new(config)
    }

    /// OpenAI config with explicit overrides; the key falls back to `OPENAI_API_KEY`.
    fn create_openai_provider(
        api_key: Option<String>,
        base_url: Option<String>,
        model: Option<String>,
    ) -> Box<dyn ProviderConfig> {
        let mut config = OpenAIConfig::default();
        // Mirror the Anthropic path: without this fallback, `create_provider("openai", None, ..)`
        // fails validation even when `OPENAI_API_KEY` is set.
        config.api_key = api_key.or_else(|| std::env::var("OPENAI_API_KEY").ok());
        if let Some(url) = base_url {
            config.base_url = url;
        }
        if let Some(m) = model {
            config.default_model = m;
        }
        Box::new(config)
    }

    /// LM Studio config with optional base URL / model overrides.
    fn create_lmstudio_provider(
        base_url: Option<String>,
        model: Option<String>,
    ) -> Box<dyn ProviderConfig> {
        let mut config = LMStudioConfig::default();
        if let Some(url) = base_url {
            config.base_url = url;
        }
        if let Some(m) = model {
            config.default_model = m;
        }
        Box::new(config)
    }

    /// Ollama config with optional base URL / model overrides.
    fn create_ollama_provider(
        base_url: Option<String>,
        model: Option<String>,
    ) -> Box<dyn ProviderConfig> {
        let mut config = OllamaConfig::default();
        if let Some(url) = base_url {
            config.base_url = url;
        }
        if let Some(m) = model {
            config.default_model = m;
        }
        Box::new(config)
    }

    /// Builds and validates a configuration from environment variables.
    ///
    /// `AI_PROVIDER` selects the provider (case-insensitive, default `anthropic`). The
    /// provider-specific variables read are:
    /// - `anthropic`: `ANTHROPIC_API_KEY`
    /// - `openai`: `OPENAI_API_KEY`, `OPENAI_BASE_URL`
    /// - `lmstudio`: `LM_STUDIO_BASE_URL`, falling back to `OPENAI_BASE_URL`
    /// - `ollama`: none; `OllamaConfig::default()` is used
    ///
    /// # Errors
    ///
    /// Returns an unsupported-provider error for an unknown `AI_PROVIDER`, or any error from
    /// [`ProviderConfig::validate`].
    pub fn from_env() -> LlmResult<Self> {
        let provider_name =
            std::env::var("AI_PROVIDER").unwrap_or_else(|_| "anthropic".to_string());
        log_debug!(
            target_provider = %provider_name,
            "Loading LLM configuration from environment"
        );
        // Case-insensitive to match `create_provider`. `ollama` was previously rejected here
        // even though it is a supported provider.
        let provider: Box<dyn ProviderConfig> = match provider_name.to_lowercase().as_str() {
            "anthropic" => Self::anthropic_from_env(),
            "openai" => Self::openai_from_env(),
            "lmstudio" => Self::lmstudio_from_env(),
            "ollama" => Box::new(OllamaConfig::default()),
            _ => {
                return Err(LlmError::unsupported_provider(provider_name));
            }
        };
        provider.validate()?;
        log_debug!(
            provider = provider.provider_name(),
            max_context_tokens = provider.max_context_tokens(),
            base_url = provider.base_url(),
            has_api_key = provider.api_key().is_some(),
            "LLM configuration loaded and validated"
        );
        Ok(Self {
            provider,
            default_params: DefaultLLMParams::default(),
        })
    }

    /// Anthropic defaults plus `ANTHROPIC_API_KEY` when set.
    fn anthropic_from_env() -> Box<dyn ProviderConfig> {
        let mut config = AnthropicConfig::default();
        if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
            config.api_key = Some(api_key);
        }
        Box::new(config)
    }

    /// OpenAI defaults plus `OPENAI_API_KEY` / `OPENAI_BASE_URL` when set.
    fn openai_from_env() -> Box<dyn ProviderConfig> {
        let mut config = OpenAIConfig::default();
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            config.api_key = Some(api_key);
        }
        if let Ok(base_url) = std::env::var("OPENAI_BASE_URL") {
            config.base_url = base_url;
        }
        Box::new(config)
    }

    /// LM Studio defaults; base URL from `LM_STUDIO_BASE_URL`, else `OPENAI_BASE_URL`.
    fn lmstudio_from_env() -> Box<dyn ProviderConfig> {
        let mut config = LMStudioConfig::default();
        if let Ok(base_url) = std::env::var("LM_STUDIO_BASE_URL") {
            config.base_url = base_url;
        } else if let Ok(base_url) = std::env::var("OPENAI_BASE_URL") {
            config.base_url = base_url;
        }
        Box::new(config)
    }
}