use std::collections::HashMap;
use std::fmt;
use camel_component_api::CamelError;
use camel_component_api::NetworkRetryPolicy;
use crate::cost::PricingTable;
/// Serde fallback for `LlmGlobalConfig::max_prompt_bytes`: 32 KiB.
fn default_max_prompt_bytes() -> usize {
    32 * 1024
}
/// Serde fallback for `MockProviderConfig::response`.
fn default_mock_response() -> String {
    String::from("echo")
}
/// Serde fallback for `MockProviderConfig::default_model`.
fn default_mock_model() -> String {
    String::from("mock-model")
}
/// Component-wide LLM settings, deserialized with serde.
///
/// Every field may be omitted from the input; serde then fills in `None`, an
/// empty provider map, or the value returned by `default_max_prompt_bytes()`.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct LlmGlobalConfig {
/// Name of the provider to fall back to — presumably a key of `providers`.
/// NOTE(review): `validate` in this file does not check that the key exists.
#[serde(default)]
pub default_provider: Option<String>,
/// Global timeout in seconds. `validate` rejects `Some(0)`.
#[serde(default)]
pub timeout_secs: Option<u64>,
/// Prompt size limit in bytes (defaults to 32768). The limit is not enforced
/// anywhere in this file — presumably by the code that sends prompts.
#[serde(default = "default_max_prompt_bytes")]
pub max_prompt_bytes: usize,
/// Provider configurations keyed by a user-chosen provider name.
#[serde(default)]
pub providers: HashMap<String, LlmProviderConfig>,
}
impl Default for LlmGlobalConfig {
    /// Builds the same value serde produces for an empty document: no default
    /// provider, no timeout, no providers, and the standard prompt-size limit.
    fn default() -> Self {
        let max_prompt_bytes = default_max_prompt_bytes();
        let providers = HashMap::new();
        Self {
            max_prompt_bytes,
            providers,
            default_provider: None,
            timeout_secs: None,
        }
    }
}
/// Configuration of a single named provider.
///
/// Deserialized as an internally tagged enum: the `type` field of the entry
/// selects the variant by its lowercase name (`"openai"`, `"ollama"`, `"mock"`).
/// `Debug` is not derived here; it is implemented by hand for this enum.
#[derive(Clone, serde::Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum LlmProviderConfig {
/// Provider declared with `type = "openai"`.
Openai(OpenaiProviderConfig),
/// Provider declared with `type = "ollama"`.
Ollama(OllamaProviderConfig),
/// Provider declared with `type = "mock"`.
Mock(MockProviderConfig),
}
impl LlmProviderConfig {
pub fn max_concurrency(&self) -> Option<usize> {
match self {
LlmProviderConfig::Openai(c) => c.max_concurrency,
LlmProviderConfig::Ollama(c) => c.max_concurrency,
LlmProviderConfig::Mock(_) => None,
}
}
pub fn timeout_secs(&self) -> Option<u64> {
match self {
LlmProviderConfig::Openai(c) => c.timeout_secs,
LlmProviderConfig::Ollama(c) => c.timeout_secs,
LlmProviderConfig::Mock(_) => None,
}
}
pub fn network_retry(&self) -> Option<NetworkRetryPolicy> {
match self {
LlmProviderConfig::Openai(c) => c.network_retry.clone(),
LlmProviderConfig::Ollama(c) => c.network_retry.clone(),
LlmProviderConfig::Mock(_) => None,
}
}
pub fn pricing(&self) -> Option<PricingTable> {
match self {
LlmProviderConfig::Openai(c) => c.pricing.clone(),
LlmProviderConfig::Ollama(c) => c.pricing.clone(),
LlmProviderConfig::Mock(_) => None,
}
}
pub fn cache_ttl_secs(&self) -> Option<u64> {
match self {
LlmProviderConfig::Openai(c) => c.cache_ttl_secs,
LlmProviderConfig::Ollama(c) => c.cache_ttl_secs,
LlmProviderConfig::Mock(_) => None,
}
}
pub fn cache_max_entries(&self) -> Option<usize> {
match self {
LlmProviderConfig::Openai(c) => c.cache_max_entries,
LlmProviderConfig::Ollama(c) => c.cache_max_entries,
LlmProviderConfig::Mock(_) => None,
}
}
}
impl fmt::Debug for LlmProviderConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LlmProviderConfig::Openai(c) => f
.debug_struct("Openai")
.field("api_key", &"[REDACTED]")
.field("base_url", &c.base_url)
.field("default_model", &c.default_model)
.field("timeout_secs", &c.timeout_secs)
.field("max_concurrency", &c.max_concurrency)
.field("network_retry", &c.network_retry)
.field("pricing", &c.pricing)
.field("cache_ttl_secs", &c.cache_ttl_secs)
.field("cache_max_entries", &c.cache_max_entries)
.finish(),
LlmProviderConfig::Ollama(c) => f
.debug_struct("Ollama")
.field("base_url", &c.base_url)
.field("default_model", &c.default_model)
.field("timeout_secs", &c.timeout_secs)
.field("max_concurrency", &c.max_concurrency)
.field("network_retry", &c.network_retry)
.field("pricing", &c.pricing)
.field("cache_ttl_secs", &c.cache_ttl_secs)
.field("cache_max_entries", &c.cache_max_entries)
.finish(),
LlmProviderConfig::Mock(c) => f
.debug_struct("Mock")
.field("response", &c.response)
.field("default_model", &c.default_model)
.field("error_message", &c.error_message)
.finish(),
}
}
}
#[derive(Clone, serde::Deserialize)]
pub struct OpenaiProviderConfig {
pub api_key: String,
#[serde(default)]
pub base_url: Option<String>,
pub default_model: String,
#[serde(default)]
pub timeout_secs: Option<u64>,
#[serde(default)]
pub max_concurrency: Option<usize>,
#[serde(default)]
pub network_retry: Option<NetworkRetryPolicy>,
#[serde(default)]
pub pricing: Option<PricingTable>,
#[serde(default)]
pub cache_ttl_secs: Option<u64>,
#[serde(default)]
pub cache_max_entries: Option<usize>,
}
/// Settings for a provider declared with `type = "ollama"`.
///
/// Unlike the OpenAI variant there is no API key, and `base_url` is required.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct OllamaProviderConfig {
/// Base URL of the server (required), e.g. `http://localhost:11434`.
pub base_url: String,
/// Default model name for this provider (required).
pub default_model: String,
/// Per-provider timeout in seconds; `LlmGlobalConfig::validate` rejects `Some(0)`.
#[serde(default)]
pub timeout_secs: Option<u64>,
/// Concurrency cap; `LlmGlobalConfig::validate` rejects `Some(0)`.
#[serde(default)]
pub max_concurrency: Option<usize>,
/// Optional network retry policy.
#[serde(default)]
pub network_retry: Option<NetworkRetryPolicy>,
/// Optional per-1k-token pricing; checked by `LlmGlobalConfig::validate`.
#[serde(default)]
pub pricing: Option<PricingTable>,
/// Cache time-to-live in seconds; `LlmGlobalConfig::validate` rejects `Some(0)`.
#[serde(default)]
pub cache_ttl_secs: Option<u64>,
/// Maximum number of cache entries; `LlmGlobalConfig::validate` rejects `Some(0)`.
#[serde(default)]
pub cache_max_entries: Option<usize>,
}
/// Settings for a provider declared with `type = "mock"`.
///
/// All fields are optional in the input. The mock exposes no timeout,
/// concurrency, retry, pricing or cache settings.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct MockProviderConfig {
/// Response text; defaults to `"echo"`.
/// NOTE(review): how the mock uses this value is defined outside this file.
#[serde(default = "default_mock_response")]
pub response: String,
/// Model name; defaults to `"mock-model"`.
#[serde(default = "default_mock_model")]
pub default_model: String,
/// Optional error text — presumably makes the mock fail with this message
/// when set; confirm against the mock provider implementation.
#[serde(default)]
pub error_message: Option<String>,
}
/// Operation selected by the path part of an endpoint URI.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LlmOperation {
/// Selected by `llm:chat`.
Chat,
/// Selected by `llm:embed`.
Embed,
}
impl LlmGlobalConfig {
pub fn validate(&self) -> Result<(), CamelError> {
if self.timeout_secs == Some(0) {
return Err(CamelError::Config(
"global timeout_secs must be > 0 when set (got 0)".into(),
));
}
for (name, provider) in &self.providers {
if provider.timeout_secs() == Some(0) {
return Err(CamelError::Config(format!(
"provider '{name}' timeout_secs must be > 0 when set (got 0)"
)));
}
if provider.max_concurrency() == Some(0) {
return Err(CamelError::Config(format!(
"provider '{name}' max_concurrency must be > 0 when set (got 0)"
)));
}
if let Some(p) = provider.pricing()
&& (p.input_per_1k_tokens < 0.0 || p.output_per_1k_tokens < 0.0)
{
return Err(CamelError::Config(format!(
"provider '{name}' pricing has negative values: input={}, output={}",
p.input_per_1k_tokens, p.output_per_1k_tokens
)));
}
if provider.cache_ttl_secs() == Some(0) {
return Err(CamelError::Config(format!(
"provider '{name}' cache_ttl_secs must be > 0 when set (got 0)"
)));
}
if provider.cache_max_entries() == Some(0) {
return Err(CamelError::Config(format!(
"provider '{name}' cache_max_entries must be > 0 when set (got 0)"
)));
}
}
Ok(())
}
}
/// Per-endpoint settings, normally built from an endpoint URI by `from_uri`.
#[derive(Clone, Debug)]
pub struct LlmEndpointConfig {
/// Operation named by the URI path (`chat` or `embed`).
pub operation: LlmOperation,
/// Value of the `provider` query parameter, if present.
pub provider: Option<String>,
/// Value of the `model` query parameter, if present.
pub model: Option<String>,
/// Value of the `temperature` query parameter, if present.
pub temperature: Option<f64>,
/// Value of the `max_tokens` query parameter, if present.
pub max_tokens: Option<u32>,
/// Value of the `stream` query parameter; `true` when the parameter is absent.
pub stream: bool,
/// Value of the `system_prompt` query parameter, if present.
pub system_prompt: Option<String>,
}
impl Default for LlmEndpointConfig {
    /// A streaming chat endpoint with every optional setting left unset.
    fn default() -> Self {
        Self {
            // The only two non-empty defaults.
            operation: LlmOperation::Chat,
            stream: true,
            // Everything else is left unset.
            provider: None,
            model: None,
            system_prompt: None,
            temperature: None,
            max_tokens: None,
        }
    }
}
impl LlmEndpointConfig {
    /// Parses an endpoint URI of the form `llm:<operation>[?key=value&...]`.
    ///
    /// `<operation>` must be `chat` or `embed`. Recognised query parameters
    /// are `provider`, `model`, `temperature`, `max_tokens`, `stream` and
    /// `system_prompt`; unknown parameters are ignored, and when a key is
    /// repeated the last occurrence wins.
    ///
    /// `stream` is `true` when absent or equal to `true`/`1`; any other value
    /// disables streaming (kept lenient for backward compatibility).
    ///
    /// # Errors
    ///
    /// Returns `CamelError::InvalidUri` when the operation is unknown, when
    /// `temperature` is not a finite number, or when `max_tokens` is not an
    /// unsigned 32-bit integer.
    pub fn from_uri(uri: &str) -> Result<Self, CamelError> {
        let (path, query) = uri.split_once('?').unwrap_or((uri, ""));

        // Strip the scheme at most once: `trim_start_matches` would also
        // accept malformed input such as `llm:llm:chat`.
        let operation = match path.strip_prefix("llm:").unwrap_or(path) {
            "chat" => LlmOperation::Chat,
            "embed" => LlmOperation::Embed,
            other => {
                return Err(CamelError::InvalidUri(format!(
                    "unknown llm operation: '{other}' (expected 'chat' or 'embed')"
                )));
            }
        };

        let params: HashMap<String, String> = url::form_urlencoded::parse(query.as_bytes())
            .into_owned()
            .collect();

        // A typo in a numeric option must surface as an error instead of
        // silently falling back to the provider's default.
        let temperature = Self::parse_param::<f64>(&params, "temperature")?;
        if matches!(temperature, Some(t) if !t.is_finite()) {
            return Err(CamelError::InvalidUri(
                "invalid value for 'temperature': must be a finite number".to_string(),
            ));
        }
        let max_tokens = Self::parse_param::<u32>(&params, "max_tokens")?;

        let stream = params
            .get("stream")
            .map(|v| v == "true" || v == "1")
            .unwrap_or(true);

        Ok(Self {
            operation,
            provider: params.get("provider").cloned(),
            model: params.get("model").cloned(),
            temperature,
            max_tokens,
            stream,
            system_prompt: params.get("system_prompt").cloned(),
        })
    }

    /// Looks up `key` and parses its value, mapping a malformed value to `InvalidUri`.
    fn parse_param<T: std::str::FromStr>(
        params: &HashMap<String, String>,
        key: &str,
    ) -> Result<Option<T>, CamelError> {
        params
            .get(key)
            .map(|raw| {
                raw.parse::<T>().map_err(|_| {
                    CamelError::InvalidUri(format!("invalid value for '{key}': '{raw}'"))
                })
            })
            .transpose()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline OpenAI provider with every optional knob unset; tests override
    /// only the field they exercise via struct-update syntax.
    fn openai() -> OpenaiProviderConfig {
        OpenaiProviderConfig {
            api_key: "sk-test".into(),
            base_url: None,
            default_model: "gpt-4o".into(),
            timeout_secs: None,
            max_concurrency: None,
            network_retry: None,
            pricing: None,
            cache_ttl_secs: None,
            cache_max_entries: None,
        }
    }

    /// Baseline Ollama provider with every optional knob unset.
    fn ollama() -> OllamaProviderConfig {
        OllamaProviderConfig {
            base_url: "http://localhost:11434".into(),
            default_model: "llama3".into(),
            timeout_secs: None,
            max_concurrency: None,
            network_retry: None,
            pricing: None,
            cache_ttl_secs: None,
            cache_max_entries: None,
        }
    }

    /// Default global config holding exactly one named provider.
    fn global_with(name: &str, provider: LlmProviderConfig) -> LlmGlobalConfig {
        LlmGlobalConfig {
            providers: HashMap::from([(name.to_string(), provider)]),
            ..LlmGlobalConfig::default()
        }
    }

    #[test]
    fn rejects_zero_global_timeout() {
        let cfg = LlmGlobalConfig {
            timeout_secs: Some(0),
            ..LlmGlobalConfig::default()
        };
        assert!(
            cfg.validate().is_err(),
            "Some(0) global timeout_secs must be rejected"
        );
    }

    #[test]
    fn accepts_none_global_timeout() {
        let cfg = LlmGlobalConfig {
            timeout_secs: None,
            ..LlmGlobalConfig::default()
        };
        assert!(
            cfg.validate().is_ok(),
            "None global timeout_secs must be valid"
        );
    }

    #[test]
    fn rejects_zero_provider_timeout() {
        let provider = OpenaiProviderConfig {
            timeout_secs: Some(0),
            ..openai()
        };
        let cfg = global_with("bad", LlmProviderConfig::Openai(provider));
        assert!(
            cfg.validate().is_err(),
            "Some(0) provider timeout_secs must be rejected"
        );
    }

    #[test]
    fn rejects_zero_max_concurrency() {
        let provider = OpenaiProviderConfig {
            max_concurrency: Some(0),
            ..openai()
        };
        let cfg = global_with("bad", LlmProviderConfig::Openai(provider));
        assert!(
            cfg.validate().is_err(),
            "Some(0) max_concurrency must be rejected"
        );
    }

    #[test]
    fn rejects_zero_ollama_timeout() {
        let provider = OllamaProviderConfig {
            timeout_secs: Some(0),
            ..ollama()
        };
        let cfg = global_with("bad", LlmProviderConfig::Ollama(provider));
        assert!(
            cfg.validate().is_err(),
            "Some(0) Ollama timeout_secs must be rejected"
        );
    }

    #[test]
    fn rejects_negative_pricing() {
        let provider = OpenaiProviderConfig {
            pricing: Some(PricingTable {
                input_per_1k_tokens: -0.01,
                output_per_1k_tokens: 0.03,
            }),
            ..openai()
        };
        let cfg = global_with("bad", LlmProviderConfig::Openai(provider));
        assert!(
            cfg.validate().is_err(),
            "Negative input_per_1k_tokens must be rejected"
        );
    }

    #[test]
    fn rejects_zero_cache_ttl() {
        let provider = OpenaiProviderConfig {
            cache_ttl_secs: Some(0),
            ..openai()
        };
        let cfg = global_with("bad-cache", LlmProviderConfig::Openai(provider));
        assert!(
            cfg.validate().is_err(),
            "Some(0) cache_ttl_secs must be rejected"
        );
    }

    #[test]
    fn rejects_zero_cache_max_entries() {
        let provider = OpenaiProviderConfig {
            cache_max_entries: Some(0),
            ..openai()
        };
        let cfg = global_with("bad-entries", LlmProviderConfig::Openai(provider));
        assert!(
            cfg.validate().is_err(),
            "Some(0) cache_max_entries must be rejected"
        );
    }

    #[test]
    fn accepts_valid_provider_config() {
        let provider = OpenaiProviderConfig {
            timeout_secs: Some(30),
            max_concurrency: Some(5),
            ..openai()
        };
        let cfg = LlmGlobalConfig {
            timeout_secs: Some(60),
            ..global_with("valid", LlmProviderConfig::Openai(provider))
        };
        assert!(
            cfg.validate().is_ok(),
            "Valid config with non-zero timeouts and concurrency must pass"
        );
    }

    #[test]
    fn validate_error_contains_provider_name() {
        let provider = OpenaiProviderConfig {
            timeout_secs: Some(0),
            ..openai()
        };
        let cfg = global_with("my-openai", LlmProviderConfig::Openai(provider));
        let err = cfg.validate().unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("my-openai"),
            "validation error should include provider name: {msg}"
        );
    }

    #[test]
    fn deserialize_global_config_with_providers() {
        let toml_str = r#"
default_provider = "my-openai"
[providers.my-openai]
type = "openai"
api_key = "secret-key"
default_model = "gpt-4o"
[providers.local]
type = "ollama"
base_url = "http://localhost:11434"
default_model = "llama3"
[providers.test]
type = "mock"
response = "echo"
default_model = "mock-model"
"#;
        let cfg: LlmGlobalConfig = toml::from_str(toml_str).expect("parse");
        assert_eq!(cfg.default_provider.as_deref(), Some("my-openai"));
        assert_eq!(cfg.providers.len(), 3);
        assert!(matches!(
            cfg.providers["my-openai"],
            LlmProviderConfig::Openai(_)
        ));
        assert!(matches!(
            cfg.providers["local"],
            LlmProviderConfig::Ollama(_)
        ));
        assert!(matches!(cfg.providers["test"], LlmProviderConfig::Mock(_)));
    }

    #[test]
    fn default_config_has_no_providers() {
        let cfg = LlmGlobalConfig::default();
        assert!(cfg.providers.is_empty());
        assert_eq!(cfg.timeout_secs, None);
        assert_eq!(cfg.max_prompt_bytes, 32768);
    }

    #[test]
    fn debug_redacts_api_key() {
        let cfg = LlmProviderConfig::Openai(OpenaiProviderConfig {
            api_key: "sk-secret123".into(),
            ..openai()
        });
        let debug_str = format!("{:?}", cfg);
        assert!(debug_str.contains("[REDACTED]"));
        assert!(!debug_str.contains("sk-secret123"));
    }

    #[test]
    fn endpoint_config_from_uri_chat() {
        let ec = LlmEndpointConfig::from_uri(
            "llm:chat?provider=my-openai&model=gpt-4o&temperature=0.7&stream=false",
        )
        .expect("parse");
        assert_eq!(ec.operation, LlmOperation::Chat);
        assert_eq!(ec.provider.as_deref(), Some("my-openai"));
        assert_eq!(ec.model.as_deref(), Some("gpt-4o"));
        // The URI sets a temperature, so make sure it is actually picked up.
        assert_eq!(ec.temperature, Some(0.7));
        assert_eq!(ec.max_tokens, None);
        assert!(!ec.stream);
    }

    #[test]
    fn endpoint_config_from_uri_embed() {
        let ec = LlmEndpointConfig::from_uri("llm:embed?provider=local").expect("parse");
        assert_eq!(ec.operation, LlmOperation::Embed);
        // `stream` defaults to true when the parameter is absent.
        assert!(ec.stream);
    }

    #[test]
    fn from_uri_parses_max_tokens_and_system_prompt() {
        let ec = LlmEndpointConfig::from_uri("llm:chat?max_tokens=256&system_prompt=be%20brief")
            .expect("parse");
        assert_eq!(ec.max_tokens, Some(256));
        assert_eq!(ec.system_prompt.as_deref(), Some("be brief"));
    }

    #[test]
    fn from_uri_unknown_operation_returns_invalid_uri() {
        let result = LlmEndpointConfig::from_uri("llm:summarize?provider=x");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, CamelError::InvalidUri(_)));
        assert!(err.to_string().contains("summarize"));
    }

    #[test]
    fn mock_config_with_error_message_deserializes() {
        let toml_str = r#"
[providers.err]
type = "mock"
error_message = "boom"
"#;
        let cfg: LlmGlobalConfig = toml::from_str(toml_str).expect("parse");
        let mock_cfg = match &cfg.providers["err"] {
            LlmProviderConfig::Mock(c) => c,
            _ => panic!("expected Mock"),
        };
        assert_eq!(mock_cfg.error_message.as_deref(), Some("boom"));
    }

    #[test]
    fn from_uri_stream_parsing() {
        let ec = LlmEndpointConfig::from_uri("llm:chat?stream=false").unwrap();
        assert!(!ec.stream);
        let ec = LlmEndpointConfig::from_uri("llm:chat?stream=true").unwrap();
        assert!(ec.stream);
        let ec = LlmEndpointConfig::from_uri("llm:chat?stream=1").unwrap();
        assert!(ec.stream);
        let ec = LlmEndpointConfig::from_uri("llm:chat").unwrap();
        assert!(ec.stream);
    }
}