use crate::{RsllmError, RsllmResult, Provider};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use url::Url;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
pub provider: ProviderConfig,
pub model: ModelConfig,
pub http: HttpConfig,
pub retry: RetryConfig,
pub headers: HashMap<String, String>,
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
provider: ProviderConfig::default(),
model: ModelConfig::default(),
http: HttpConfig::default(),
retry: RetryConfig::default(),
headers: HashMap::new(),
}
}
}
impl ClientConfig {
    /// Returns a [`ClientConfigBuilder`] starting from the default configuration.
    pub fn builder() -> ClientConfigBuilder {
        ClientConfigBuilder::new()
    }

    /// Builds a configuration from `RSLLM_*` environment variables, layered over the defaults.
    ///
    /// First tries to load a `.env` file; if that fails, it is ignored. Then reads these
    /// variables, in order:
    /// - `RSLLM_PROVIDER`
    /// - `RSLLM_API_KEY`
    /// - `RSLLM_BASE_URL`
    /// - `RSLLM_MODEL`
    /// - `RSLLM_TEMPERATURE`
    /// - `RSLLM_MAX_TOKENS`
    /// - `RSLLM_TIMEOUT` (whole seconds)
    ///
    /// A variable that is unset or not valid Unicode leaves the default in place.
    ///
    /// # Errors
    ///
    /// Returns the first parse failure among the variables that are set. The result is **not**
    /// validated; call [`ClientConfig::validate`] for that.
    pub fn from_env() -> RsllmResult<Self> {
        /// Parses variable `name` if present; a parse failure becomes a configuration error
        /// carrying `message`.
        fn parsed_var<T: std::str::FromStr>(
            name: &str,
            message: &'static str,
        ) -> RsllmResult<Option<T>> {
            match std::env::var(name) {
                Ok(raw) => raw
                    .parse::<T>()
                    .map(Some)
                    .map_err(|_| RsllmError::configuration(message)),
                Err(_) => Ok(None),
            }
        }

        // A missing or unreadable `.env` file is not an error: real environment variables
        // still apply.
        dotenv::dotenv().ok();

        let mut config = Self::default();

        if let Ok(raw) = std::env::var("RSLLM_PROVIDER") {
            config.provider.provider = raw.parse()?;
        }
        if let Ok(raw) = std::env::var("RSLLM_API_KEY") {
            config.provider.api_key = Some(raw);
        }
        if let Ok(raw) = std::env::var("RSLLM_BASE_URL") {
            config.provider.base_url = Some(raw.parse()?);
        }
        if let Ok(raw) = std::env::var("RSLLM_MODEL") {
            config.model.model = raw;
        }
        if let Some(value) = parsed_var::<f32>("RSLLM_TEMPERATURE", "Invalid temperature value")? {
            config.model.temperature = Some(value);
        }
        if let Some(value) = parsed_var::<u32>("RSLLM_MAX_TOKENS", "Invalid max_tokens value")? {
            config.model.max_tokens = Some(value);
        }
        if let Some(secs) = parsed_var::<u64>("RSLLM_TIMEOUT", "Invalid timeout value")? {
            config.http.timeout = Duration::from_secs(secs);
        }

        Ok(config)
    }

    /// Validates the provider, model and HTTP sections, in that order.
    ///
    /// The `retry` section and `headers` are not checked.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by a section's own `validate`.
    pub fn validate(&self) -> RsllmResult<()> {
        self.provider.validate()?;
        self.model.validate()?;
        self.http.validate()
    }
}
/// Provider selection plus connection and authentication settings.
///
/// `Debug` is implemented by hand so the API key is never printed. Output shows
/// `Some("<redacted>")` or `None` in its place.
#[derive(Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Backend to talk to.
    pub provider: Provider,
    /// API key. `validate` requires one for OpenAI and Claude.
    pub api_key: Option<String>,
    /// Override for the provider's default base URL (see `effective_base_url`).
    pub base_url: Option<Url>,
    /// Organization identifier, if any.
    pub organization_id: Option<String>,
    /// Free-form provider-specific settings; not interpreted in this module.
    pub custom_settings: HashMap<String, serde_json::Value>,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets must not end up in logs. `ClientConfig`'s derived `Debug` delegates here,
        // so this also covers it.
        f.debug_struct("ProviderConfig")
            .field("provider", &self.provider)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("organization_id", &self.organization_id)
            .field("custom_settings", &self.custom_settings)
            .finish()
    }
}

impl Default for ProviderConfig {
    /// OpenAI, with no API key, base URL override, organization or custom settings.
    fn default() -> Self {
        Self {
            provider: Provider::OpenAI,
            api_key: None,
            base_url: None,
            organization_id: None,
            custom_settings: HashMap::new(),
        }
    }
}
impl ProviderConfig {
    /// Checks that the provider settings are usable.
    ///
    /// # Errors
    ///
    /// - `Provider::OpenAI` and `Provider::Claude` require an API key that is present and not
    ///   blank (empty or whitespace-only). `Provider::Ollama` needs no key.
    /// - A `base_url` override, if set, must use the `http` or `https` scheme.
    pub fn validate(&self) -> RsllmResult<()> {
        match self.provider {
            Provider::OpenAI | Provider::Claude => {
                // `Some("")` counts as missing. For example, `from_env` stores `Some("")`
                // when `RSLLM_API_KEY` is set but empty, which would otherwise pass here.
                let has_key = matches!(
                    self.api_key.as_deref(),
                    Some(key) if !key.trim().is_empty()
                );
                if !has_key {
                    return Err(RsllmError::configuration(format!(
                        "API key required for provider: {:?}",
                        self.provider
                    )));
                }
            }
            // No credentials are checked for Ollama.
            Provider::Ollama => {}
        }

        if let Some(url) = &self.base_url {
            if url.scheme() != "http" && url.scheme() != "https" {
                return Err(RsllmError::configuration(
                    "Base URL must use HTTP or HTTPS scheme",
                ));
            }
        }

        Ok(())
    }

    /// Returns the configured `base_url` override, or else the provider's default base URL.
    ///
    /// Currently this always returns `Ok`.
    pub fn effective_base_url(&self) -> RsllmResult<Url> {
        let url = match &self.base_url {
            Some(url) => url.clone(),
            None => self.provider.default_base_url(),
        };
        Ok(url)
    }
}
/// Model selection and sampling parameters.
///
/// Every optional parameter defaults to `None` except `temperature`, which defaults to
/// `Some(0.7)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model identifier (default `"gpt-3.5-turbo"`); `validate` rejects an empty name.
    pub model: String,
    /// Temperature; `validate` requires `0.0..=2.0` when set.
    pub temperature: Option<f32>,
    /// Token limit; unset by default.
    pub max_tokens: Option<u32>,
    /// Top-p; `validate` requires `0.0..=1.0` when set.
    pub top_p: Option<f32>,
    /// Frequency penalty; `validate` requires `-2.0..=2.0` when set.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty; `validate` requires `-2.0..=2.0` when set.
    pub presence_penalty: Option<f32>,
    /// Stop strings; unset by default.
    pub stop: Option<Vec<String>>,
    /// Streaming flag; `false` by default.
    pub stream: bool,
}

impl Default for ModelConfig {
    fn default() -> Self {
        Self {
            model: String::from("gpt-3.5-turbo"),
            temperature: Some(0.7),
            stream: false,
            max_tokens: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            stop: None,
        }
    }
}
impl ModelConfig {
    /// Checks the model name and sampling parameters.
    ///
    /// Unset (`None`) parameters are always accepted. Bounds are inclusive. NaN is rejected,
    /// because no range contains it.
    ///
    /// # Errors
    ///
    /// Returns a validation error for the first offending field, checked in this order:
    /// - `model`: the name is empty or whitespace-only.
    /// - `temperature`: outside `0.0..=2.0`.
    /// - `top_p`: outside `0.0..=1.0`.
    /// - `frequency_penalty`, then `presence_penalty`: outside `-2.0..=2.0`.
    /// - `max_tokens`: set to `Some(0)`.
    pub fn validate(&self) -> RsllmResult<()> {
        /// Fails with `message` for `field` when `value` is set but outside `range`.
        fn check_range(
            field: &'static str,
            value: Option<f32>,
            range: std::ops::RangeInclusive<f32>,
            message: &'static str,
        ) -> RsllmResult<()> {
            match value {
                Some(v) if !range.contains(&v) => Err(RsllmError::validation(field, message)),
                _ => Ok(()),
            }
        }

        // A whitespace-only name is as unusable as an empty one.
        if self.model.trim().is_empty() {
            return Err(RsllmError::validation("model", "Model name cannot be empty"));
        }

        check_range(
            "temperature",
            self.temperature,
            0.0..=2.0,
            "Temperature must be between 0.0 and 2.0",
        )?;
        check_range(
            "top_p",
            self.top_p,
            0.0..=1.0,
            "Top-p must be between 0.0 and 1.0",
        )?;
        check_range(
            "frequency_penalty",
            self.frequency_penalty,
            -2.0..=2.0,
            "Frequency penalty must be between -2.0 and 2.0",
        )?;
        check_range(
            "presence_penalty",
            self.presence_penalty,
            -2.0..=2.0,
            "Presence penalty must be between -2.0 and 2.0",
        )?;

        // A zero token budget leaves no room for any output; reject it up front.
        if self.max_tokens == Some(0) {
            return Err(RsllmError::validation(
                "max_tokens",
                "Max tokens must be greater than 0",
            ));
        }

        Ok(())
    }
}
/// HTTP transport settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpConfig {
    /// Request timeout (default 30 s); `validate` requires it to be non-zero.
    pub timeout: Duration,
    /// Connection timeout (default 10 s); `validate` requires it to be non-zero.
    pub connect_timeout: Duration,
    /// Redirect limit (default 5).
    pub max_redirects: u32,
    /// User-Agent string (default `rsllm/<crate::VERSION>`).
    pub user_agent: String,
    /// TLS verification flag (default `true`).
    pub verify_tls: bool,
}

impl Default for HttpConfig {
    fn default() -> Self {
        let user_agent = format!("rsllm/{}", crate::VERSION);
        Self {
            user_agent,
            timeout: Duration::from_secs(30),
            connect_timeout: Duration::from_secs(10),
            max_redirects: 5,
            verify_tls: true,
        }
    }
}
impl HttpConfig {
    /// Checks that both timeouts are non-zero.
    ///
    /// Sub-second timeouts such as `Duration::from_millis(500)` are accepted. Only an exactly
    /// zero duration is rejected.
    ///
    /// # Errors
    ///
    /// Returns a validation error for `timeout`, then for `connect_timeout`, if either is zero.
    pub fn validate(&self) -> RsllmResult<()> {
        // Use `is_zero` rather than `as_secs() == 0`. `as_secs` truncates, so a valid 500 ms
        // timeout would be reported as "must be greater than 0".
        if self.timeout.is_zero() {
            return Err(RsllmError::validation("timeout", "Timeout must be greater than 0"));
        }
        if self.connect_timeout.is_zero() {
            return Err(RsllmError::validation(
                "connect_timeout",
                "Connect timeout must be greater than 0",
            ));
        }
        Ok(())
    }
}
/// Retry and backoff parameters.
///
/// NOTE(review): this module only stores these values and never validates them. The exact
/// meaning of each field (the backoff formula) lives wherever retries are performed; confirm
/// there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Retry limit (default 3).
    pub max_retries: u32,
    /// Starting backoff delay (default 500 ms).
    pub base_delay: Duration,
    /// Backoff delay cap (default 30 s).
    pub max_delay: Duration,
    /// Backoff growth factor (default 2.0).
    pub backoff_multiplier: f32,
    /// Jitter flag (default `true`).
    pub jitter: bool,
}

impl Default for RetryConfig {
    fn default() -> Self {
        let base_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(30);
        Self {
            max_retries: 3,
            base_delay,
            max_delay,
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }
}
/// Fluent builder for [`ClientConfig`], starting from [`ClientConfig::default`].
///
/// `base_url` parses its input immediately. All other checks run once, in
/// [`ClientConfigBuilder::build`].
#[derive(Debug, Clone)]
#[must_use = "a ClientConfigBuilder does nothing until `build` is called"]
pub struct ClientConfigBuilder {
    config: ClientConfig,
}
impl ClientConfigBuilder {
pub fn new() -> Self {
Self {
config: ClientConfig::default(),
}
}
pub fn provider(mut self, provider: Provider) -> Self {
self.config.provider.provider = provider;
self
}
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
self.config.provider.api_key = Some(api_key.into());
self
}
pub fn base_url(mut self, base_url: impl AsRef<str>) -> RsllmResult<Self> {
self.config.provider.base_url = Some(base_url.as_ref().parse()?);
Ok(self)
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.config.model.model = model.into();
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.config.model.temperature = Some(temperature);
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.config.model.max_tokens = Some(max_tokens);
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.config.model.stream = stream;
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.config.http.timeout = timeout;
self
}
pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.config.headers.insert(key.into(), value.into());
self
}
pub fn build(self) -> RsllmResult<ClientConfig> {
self.config.validate()?;
Ok(self.config)
}
}
impl Default for ClientConfigBuilder {
fn default() -> Self {
Self::new()
}
}