use serde::{Deserialize, Serialize};
/// Identifies which LLM backend a configuration targets.
///
/// The serde attribute `rename_all = "lowercase"` makes every variant
/// serialize and deserialize as its name in lower case, e.g. `OpenAI` is
/// written as `"openai"` and `LiteRt` as `"litert"`.
///
/// NOTE(review): the one-line variant descriptions below are inferred from the
/// variant names only; this file does not show how each provider is handled.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
/// OpenAI provider; the one selected by `LlmConfig::default()`.
OpenAI,
/// LiteRT provider (presumably an on-device runtime; confirm).
LiteRt,
/// Anthropic provider.
Anthropic,
/// Ollama provider.
Ollama,
/// Gemini provider.
Gemini,
/// Mistral provider.
Mistral,
/// Bedrock provider (presumably Amazon Bedrock; confirm).
Bedrock,
/// Local provider (presumably a locally hosted model; confirm).
Local,
/// Mock provider (presumably a stand-in for tests; confirm).
Mock,
}
/// Settings describing how to talk to an LLM backend.
///
/// The struct is plain data: every field is public and it derives serde
/// `Serialize`/`Deserialize`. How each value is applied is decided by the code
/// that consumes the configuration, which is not part of this definition.
///
/// NOTE(review): the derived `Debug` output includes `api_key` and
/// `fallback_api_key` verbatim; take care when logging this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
/// Backend this configuration targets.
pub provider: LlmProvider,
/// Model identifier, stored as given (no validation is visible here).
pub model: String,
/// Optional API key; `None` when no key has been supplied.
pub api_key: Option<String>,
/// Optional endpoint override; `None` when none has been supplied.
/// NOTE(review): presumably a base URL; confirm the expected format.
pub endpoint: Option<String>,
/// Sampling temperature. Not range-checked by this type.
pub temperature: f32,
/// Token limit. NOTE(review): presumably the cap on generated tokens; confirm.
pub max_tokens: u32,
/// Whether streaming is requested.
pub streaming: bool,
/// Timeout, in seconds as the field name indicates.
pub timeout_seconds: u64,
/// Maximum number of retries.
/// NOTE(review): whether this counts the first attempt is not visible here.
pub max_retries: u32,
/// Whether rate limiting is turned on.
pub rate_limit_enabled: bool,
/// Request count used for rate limiting.
/// NOTE(review): presumably "requests per `rate_limit_interval_seconds`"; confirm.
pub rate_limit_requests: u32,
/// Rate-limit interval, in seconds as the field name indicates.
pub rate_limit_interval_seconds: u64,
/// Optional fallback model identifier.
/// NOTE(review): the conditions that trigger the fallback are not visible here.
pub fallback_model: Option<String>,
/// Optional API key paired with the fallback model.
pub fallback_api_key: Option<String>,
}
impl Default for LlmConfig {
fn default() -> Self {
Self {
provider: LlmProvider::OpenAI,
model: "gpt-4".to_string(),
api_key: None,
endpoint: None,
temperature: 0.0,
max_tokens: 16384,
streaming: false,
timeout_seconds: 120,
max_retries: 3,
rate_limit_enabled: false,
rate_limit_requests: 60,
rate_limit_interval_seconds: 60,
fallback_model: None,
fallback_api_key: None,
}
}
}
impl LlmConfig {
    /// Creates a configuration for `provider` and `model`.
    ///
    /// Every other field takes its value from the `Default` implementation.
    #[must_use]
    pub fn new(provider: LlmProvider, model: impl Into<String>) -> Self {
        Self {
            provider,
            model: model.into(),
            ..Self::default()
        }
    }

    // The builder methods below consume `self` and hand back the updated value,
    // so ignoring the return value silently discards the change. `#[must_use]`
    // turns that mistake into a compiler warning.

    /// Returns the configuration with `api_key` set to `Some(api_key)`.
    #[must_use]
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Returns the configuration with `endpoint` set to `Some(endpoint)`.
    #[must_use]
    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.endpoint = Some(endpoint.into());
        self
    }

    /// Returns the configuration with `temperature` replaced.
    ///
    /// The value is stored as given; no range check is performed.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature;
        self
    }

    /// Returns the configuration with `max_tokens` replaced.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Returns the configuration with `timeout_seconds` replaced.
    #[must_use]
    pub fn with_timeout_seconds(mut self, timeout_seconds: u64) -> Self {
        self.timeout_seconds = timeout_seconds;
        self
    }

    /// Returns the configuration with `max_retries` replaced.
    #[must_use]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }
}