use crate::router::TargetModel;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::str::FromStr;
/// Top-level application configuration, parsed from TOML.
///
/// `server`, `models` and `routing` are required sections. `observability`
/// and `timeouts` are optional and fall back to their `Default` values when
/// the section is missing.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
/// Host, port and the global request timeout.
pub server: ServerConfig,
/// Model endpoints grouped into the fast / balanced / deep tiers.
pub models: ModelsConfig,
/// Routing strategy and router-specific settings.
pub routing: RoutingConfig,
/// Logging settings; `ObservabilityConfig::default()` when omitted.
#[serde(default)]
pub observability: ObservabilityConfig,
/// Optional per-tier timeout overrides; every tier is `None` when omitted.
#[serde(default)]
pub timeouts: TimeoutsConfig,
}
/// The `[server]` section of the configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerConfig {
/// Host value as written in the config file (e.g. "0.0.0.0" in the tests).
pub host: String,
/// Port number.
pub port: u16,
/// Global request timeout in seconds, used by `Config::timeout_for_tier`
/// when a tier has no override. Defaults to `default_request_timeout()`
/// when omitted; `Config::validate` requires a value in 1..=300.
#[serde(default = "default_request_timeout")]
pub request_timeout_seconds: u64,
}
/// Serde default for `ServerConfig::request_timeout_seconds`.
fn default_request_timeout() -> u64 {
    const DEFAULT_REQUEST_TIMEOUT_SECONDS: u64 = 30;
    DEFAULT_REQUEST_TIMEOUT_SECONDS
}
/// Model endpoints grouped by tier.
///
/// `Config::validate` rejects a configuration in which any of the three
/// lists is empty, and rejects an endpoint name that appears in more than
/// one tier (repeating a name inside a single tier is allowed).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelsConfig {
/// Endpoints of the "fast" tier.
pub fast: Vec<ModelEndpoint>,
/// Endpoints of the "balanced" tier.
pub balanced: Vec<ModelEndpoint>,
/// Endpoints of the "deep" tier.
pub deep: Vec<ModelEndpoint>,
}
/// A single model endpoint entry (`[[models.<tier>]]` in the TOML file).
///
/// Fields are private; read them through the accessor methods. The value
/// constraints mentioned below are enforced by `Config::validate`, not by
/// deserialization.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelEndpoint {
/// Endpoint name; must not be reused across different tiers.
name: String,
/// Base URL; must start with "http://" or "https://" and end with "/v1".
base_url: String,
/// Token limit; must be greater than 0 and no larger than `u32::MAX`.
max_tokens: usize,
/// Sampling temperature; must be finite and within 0.0..=2.0.
/// Defaults to `default_temperature()` when omitted.
#[serde(default = "default_temperature")]
temperature: f64,
/// Weight; must be a positive finite number. Defaults to `default_weight()`.
/// NOTE(review): presumably used for weighted load balancing — confirm against the selector.
#[serde(default = "default_weight")]
weight: f64,
/// Priority; defaults to `default_priority()`.
/// NOTE(review): ordering semantics (higher vs. lower wins) are not visible in this file.
#[serde(default = "default_priority")]
priority: u8,
}
/// Read-only accessors for the private fields of [`ModelEndpoint`].
impl ModelEndpoint {
    /// Endpoint name as configured.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Base URL as configured.
    pub fn base_url(&self) -> &str {
        self.base_url.as_str()
    }

    /// Configured token limit.
    pub fn max_tokens(&self) -> usize {
        self.max_tokens
    }

    /// Configured sampling temperature.
    pub fn temperature(&self) -> f64 {
        self.temperature
    }

    /// Configured weight.
    pub fn weight(&self) -> f64 {
        self.weight
    }

    /// Configured priority.
    pub fn priority(&self) -> u8 {
        self.priority
    }
}
/// Serde default for `ModelEndpoint::temperature`.
fn default_temperature() -> f64 {
    const DEFAULT_TEMPERATURE: f64 = 0.7;
    DEFAULT_TEMPERATURE
}
/// Serde default for `ModelEndpoint::weight`.
fn default_weight() -> f64 {
    const DEFAULT_WEIGHT: f64 = 1.0;
    DEFAULT_WEIGHT
}
/// Serde default for `ModelEndpoint::priority`.
fn default_priority() -> u8 {
    const DEFAULT_PRIORITY: u8 = 1;
    DEFAULT_PRIORITY
}
/// Per-tier timeout values for the router (`routing.router_timeouts`).
///
/// All three fields are required when this table is present in the config;
/// when the whole table is omitted, `RouterTimeouts::default()` is used.
/// Each value must be greater than 0 (checked by `RouterTimeouts::new` and
/// by `Config::validate`).
/// NOTE(review): the unit is not stated in this file — presumably seconds; confirm.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RouterTimeouts {
/// Timeout applied when the router tier is "fast".
fast: u64,
/// Timeout applied when the router tier is "balanced".
balanced: u64,
/// Timeout applied when the router tier is "deep".
deep: u64,
}
/// Default router timeouts: the same value (10) for every tier.
impl Default for RouterTimeouts {
    fn default() -> Self {
        const DEFAULT_ROUTER_TIMEOUT: u64 = 10;
        Self {
            fast: DEFAULT_ROUTER_TIMEOUT,
            balanced: DEFAULT_ROUTER_TIMEOUT,
            deep: DEFAULT_ROUTER_TIMEOUT,
        }
    }
}
impl RouterTimeouts {
    /// Builds a validated set of router timeouts.
    ///
    /// # Errors
    /// Returns a message naming the first tier (checked in the order fast,
    /// balanced, deep) whose value is 0.
    pub fn new(fast: u64, balanced: u64, deep: u64) -> Result<Self, String> {
        let candidate = Self {
            fast,
            balanced,
            deep,
        };
        candidate.validate().map(|()| candidate)
    }

    /// Timeout for the fast tier.
    pub fn fast(&self) -> u64 {
        self.fast
    }

    /// Timeout for the balanced tier.
    pub fn balanced(&self) -> u64 {
        self.balanced
    }

    /// Timeout for the deep tier.
    pub fn deep(&self) -> u64 {
        self.deep
    }

    /// Checks that every tier has a non-zero timeout.
    ///
    /// Tiers are examined in the order fast, balanced, deep and only the
    /// first offending tier is reported.
    fn validate(&self) -> Result<(), String> {
        let checks = [
            (self.fast, "router_timeouts.fast must be greater than 0"),
            (
                self.balanced,
                "router_timeouts.balanced must be greater than 0",
            ),
            (self.deep, "router_timeouts.deep must be greater than 0"),
        ];
        match checks.iter().find(|(value, _)| *value == 0) {
            Some((_, message)) => Err((*message).to_string()),
            None => Ok(()),
        }
    }
}
/// The `[routing]` section of the configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RoutingConfig {
/// Which routing strategy to use (required).
pub strategy: RoutingStrategy,
/// Importance assumed when none is given; falls back to
/// `Importance::default()` when omitted.
#[serde(default)]
pub default_importance: crate::router::Importance,
/// Tier used by the router itself; falls back to `TargetModel::default()`
/// when omitted (the tests in this file expect that to be `Balanced`).
/// Read it through `RoutingConfig::router_tier`.
#[serde(default)]
router_tier: TargetModel,
/// Per-tier router timeouts; `RouterTimeouts::default()` when omitted.
#[serde(default)]
pub router_timeouts: RouterTimeouts,
}
impl RoutingConfig {
    /// Returns the configured router tier.
    pub fn router_tier(&self) -> TargetModel {
        self.router_tier
    }

    /// Returns the router timeout configured for `tier`.
    pub fn router_timeout_for_tier(&self, tier: TargetModel) -> u64 {
        let timeouts = &self.router_timeouts;
        match tier {
            TargetModel::Fast => timeouts.fast(),
            TargetModel::Balanced => timeouts.balanced(),
            TargetModel::Deep => timeouts.deep(),
        }
    }
}
/// Routing strategy selected by `routing.strategy`.
///
/// Variants are (de)serialized in lowercase: "rule", "llm", "hybrid", "tool".
/// What each strategy does is implemented outside this file.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum RoutingStrategy {
/// Written as "rule".
Rule,
/// Written as "llm".
Llm,
/// Written as "hybrid".
Hybrid,
/// Written as "tool".
Tool,
}
/// The optional `[observability]` section of the configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ObservabilityConfig {
/// Log level as a free-form string; defaults to `default_log_level()`
/// ("info") when omitted. The value is not validated in this file.
#[serde(default = "default_log_level")]
pub log_level: String,
}
/// Used when the `[observability]` section is missing entirely; yields the
/// same log level as an `[observability]` table without `log_level`.
impl Default for ObservabilityConfig {
    fn default() -> Self {
        let log_level = default_log_level();
        Self { log_level }
    }
}
/// Serde default for `ObservabilityConfig::log_level`.
fn default_log_level() -> String {
    String::from("info")
}
/// Optional per-tier request timeout overrides (the `[timeouts]` section), in seconds.
///
/// `None` means "no override": `Config::timeout_for_tier` then falls back to
/// `server.request_timeout_seconds`. `Deserialize` is implemented by hand
/// (not derived) so that parsed values go through `TimeoutsConfig::new`,
/// which only accepts values in 1..=300.
#[derive(Debug, Clone, Default, Serialize)]
pub struct TimeoutsConfig {
/// Override for the fast tier.
fast: Option<u64>,
/// Override for the balanced tier.
balanced: Option<u64>,
/// Override for the deep tier.
deep: Option<u64>,
}
impl TimeoutsConfig {
    /// Builds a validated set of per-tier timeout overrides.
    ///
    /// `None` leaves a tier without an override. Tiers are checked in the
    /// order fast, balanced, deep and the first invalid value is reported.
    ///
    /// # Errors
    /// Returns `AppError::Config` when a supplied value is 0 or greater
    /// than 300 seconds.
    pub fn new(
        fast: Option<u64>,
        balanced: Option<u64>,
        deep: Option<u64>,
    ) -> crate::error::AppResult<Self> {
        let candidates = [("fast", fast), ("balanced", balanced), ("deep", deep)];
        for (tier_name, maybe_timeout) in candidates {
            // Tiers without an override have nothing to validate.
            let timeout = match maybe_timeout {
                Some(seconds) => seconds,
                None => continue,
            };
            match timeout {
                0 => {
                    return Err(crate::error::AppError::Config(format!(
                        "timeouts.{} must be greater than 0, got {}",
                        tier_name, timeout
                    )));
                }
                seconds if seconds > 300 => {
                    return Err(crate::error::AppError::Config(format!(
                        "timeouts.{} cannot exceed 300 seconds (5 minutes), got {}. \
                         This configuration policy prevents connection pool exhaustion and ensures timely failure detection.",
                        tier_name, timeout
                    )));
                }
                _ => {}
            }
        }
        Ok(Self {
            fast,
            balanced,
            deep,
        })
    }

    /// Override for the fast tier, if any.
    pub fn fast(&self) -> Option<u64> {
        self.fast
    }

    /// Override for the balanced tier, if any.
    pub fn balanced(&self) -> Option<u64> {
        self.balanced
    }

    /// Override for the deep tier, if any.
    pub fn deep(&self) -> Option<u64> {
        self.deep
    }
}
/// Hand-written `Deserialize` so that every parsed timeout goes through
/// [`TimeoutsConfig::new`] and out-of-range values are rejected at parse time.
///
/// The input must be a map whose keys are a subset of `fast`, `balanced` and
/// `deep`. A key may be absent or explicitly null; both mean "no override".
/// Accepting null matters because the derived `Serialize` writes unset tiers
/// as null in formats that have one (e.g. JSON), and that output must be
/// readable again. A key that appears twice is an error.
impl<'de> Deserialize<'de> for TimeoutsConfig {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};
        use std::fmt;

        /// Keys recognised inside the timeouts map (matched in lowercase).
        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "lowercase")]
        enum Field {
            Fast,
            Balanced,
            Deep,
        }

        struct TimeoutsConfigVisitor;

        impl<'de> Visitor<'de> for TimeoutsConfigVisitor {
            type Value = TimeoutsConfig;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a struct with optional timeout fields (fast, balanced, deep)")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TimeoutsConfig, V::Error>
            where
                V: MapAccess<'de>,
            {
                // Outer Option: "has this key been seen?" (duplicate detection).
                // Inner Option: the value itself, where None is an explicit null.
                let mut fast: Option<Option<u64>> = None;
                let mut balanced: Option<Option<u64>> = None;
                let mut deep: Option<Option<u64>> = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Fast => {
                            if fast.is_some() {
                                return Err(de::Error::duplicate_field("fast"));
                            }
                            fast = Some(map.next_value()?);
                        }
                        Field::Balanced => {
                            if balanced.is_some() {
                                return Err(de::Error::duplicate_field("balanced"));
                            }
                            balanced = Some(map.next_value()?);
                        }
                        Field::Deep => {
                            if deep.is_some() {
                                return Err(de::Error::duplicate_field("deep"));
                            }
                            deep = Some(map.next_value()?);
                        }
                    }
                }

                // A missing key and an explicit null both collapse to "no override".
                TimeoutsConfig::new(fast.flatten(), balanced.flatten(), deep.flatten())
                    .map_err(|e| de::Error::custom(format!("Invalid timeout configuration: {}", e)))
            }
        }

        deserializer.deserialize_struct(
            "TimeoutsConfig",
            &["fast", "balanced", "deep"],
            TimeoutsConfigVisitor,
        )
    }
}
impl Config {
pub fn from_file<P: AsRef<Path>>(path: P) -> crate::error::AppResult<Self> {
let path_display = path.as_ref().display().to_string();
let content = std::fs::read_to_string(path.as_ref()).map_err(|source| {
let remediation = match source.kind() {
std::io::ErrorKind::NotFound => {
let current_dir = std::env::current_dir()
.map(|p| p.display().to_string())
.unwrap_or_else(|_| "<unknown>".to_string());
format!(
"\nFile not found. Check that:\n\
1. Path '{}' is correct\n\
2. File exists and is readable\n\
3. Current working directory is: {}",
path_display, current_dir
)
}
std::io::ErrorKind::PermissionDenied => {
format!(
"\nPermission denied. Check that:\n\
1. File '{}' has read permissions (chmod +r)\n\
2. Parent directories have execute permissions (chmod +x)\n\
3. Process runs as user with file access",
path_display
)
}
_ => String::new(),
};
crate::error::AppError::ConfigFileRead {
path: path_display.clone(),
source,
remediation,
}
})?;
let config: Self = toml::from_str(&content).map_err(|source| {
crate::error::AppError::ConfigParseFailed {
path: path_display.clone(),
source,
}
})?;
config
.validate()
.map_err(|e| crate::error::AppError::ConfigValidationFailed {
path: path_display,
reason: e.to_string(),
})?;
Ok(config)
}
pub fn timeout_for_tier(&self, tier: crate::router::TargetModel) -> u64 {
let tier_timeout = match tier {
crate::router::TargetModel::Fast => self.timeouts.fast(),
crate::router::TargetModel::Balanced => self.timeouts.balanced(),
crate::router::TargetModel::Deep => self.timeouts.deep(),
};
match tier_timeout {
Some(timeout) => {
tracing::debug!(
tier = ?tier,
timeout_seconds = timeout,
"Using tier-specific timeout override"
);
timeout
}
None => {
tracing::debug!(
tier = ?tier,
timeout_seconds = self.server.request_timeout_seconds,
"No tier-specific timeout configured, using global default"
);
self.server.request_timeout_seconds
}
}
}
pub fn validate(&self) -> crate::error::AppResult<()> {
for (tier_name, endpoints) in [
("fast", &self.models.fast),
("balanced", &self.models.balanced),
("deep", &self.models.deep),
] {
for endpoint in endpoints {
if endpoint.weight <= 0.0
|| endpoint.weight.is_nan()
|| endpoint.weight.is_infinite()
{
return Err(crate::error::AppError::Config(format!(
"Configuration error: Endpoint '{}' in tier '{}' has invalid weight {}. \
Weight must be a positive finite number.",
endpoint.name, tier_name, endpoint.weight
)));
}
if endpoint.max_tokens == 0 {
return Err(crate::error::AppError::Config(format!(
"Configuration error: Endpoint '{}' in tier '{}' has max_tokens=0. \
max_tokens must be greater than 0.",
endpoint.name, tier_name
)));
}
if endpoint.max_tokens > u32::MAX as usize {
return Err(crate::error::AppError::Config(format!(
"Configuration error: Endpoint '{}' in tier '{}' has max_tokens={} which exceeds u32::MAX ({}). \
max_tokens must fit in u32 for compatibility with open-agent-sdk.",
endpoint.name,
tier_name,
endpoint.max_tokens,
u32::MAX
)));
}
if !endpoint.base_url.starts_with("http://")
&& !endpoint.base_url.starts_with("https://")
{
return Err(crate::error::AppError::Config(format!(
"Configuration error: Endpoint '{}' in tier '{}' has invalid base_url '{}'. \
base_url must start with 'http://' or 'https://'.",
endpoint.name, tier_name, endpoint.base_url
)));
}
if !endpoint.base_url.ends_with("/v1") {
return Err(crate::error::AppError::Config(format!(
"Configuration error: Endpoint '{}' in tier '{}' has invalid base_url '{}'. \
base_url must end with '/v1' (e.g., 'http://host:port/v1') for OpenAI API compatibility.",
endpoint.name, tier_name, endpoint.base_url
)));
}
if endpoint.temperature < 0.0
|| endpoint.temperature > 2.0
|| endpoint.temperature.is_nan()
|| endpoint.temperature.is_infinite()
{
return Err(crate::error::AppError::Config(format!(
"Configuration error: Endpoint '{}' in tier '{}' has invalid temperature {}. \
temperature must be a finite number between 0.0 and 2.0.",
endpoint.name, tier_name, endpoint.temperature
)));
}
}
}
if self.models.fast.is_empty() {
return Err(crate::error::AppError::Config(
"Configuration error: models.fast has no endpoints. \
All three tiers (fast, balanced, deep) must have at least one endpoint \
because routers can select any tier based on request characteristics. \
See config.toml or tests for configuration examples."
.to_string(),
));
}
if self.models.balanced.is_empty() {
return Err(crate::error::AppError::Config(
"Configuration error: models.balanced has no endpoints. \
All three tiers (fast, balanced, deep) must have at least one endpoint \
because routers can select any tier based on request characteristics. \
See config.toml or tests for configuration examples."
.to_string(),
));
}
if self.models.deep.is_empty() {
return Err(crate::error::AppError::Config(
"Configuration error: models.deep has no endpoints. \
All three tiers (fast, balanced, deep) must have at least one endpoint \
because routers can select any tier based on request characteristics. \
See config.toml or tests for configuration examples."
.to_string(),
));
}
{
use std::collections::HashMap;
let mut name_to_tier: HashMap<&str, &str> = HashMap::new();
for (tier_name, endpoints) in [
("fast", &self.models.fast),
("balanced", &self.models.balanced),
("deep", &self.models.deep),
] {
for endpoint in endpoints {
if let Some(existing_tier) = name_to_tier.get(endpoint.name.as_str()) {
if *existing_tier != tier_name {
return Err(crate::error::AppError::Config(format!(
"Configuration error: Endpoint name '{}' exists in both '{}' and '{}' tiers. \
Endpoint names must be unique across different tiers. \
Duplicates within the same tier (for load balancing) are allowed.",
endpoint.name, existing_tier, tier_name
)));
}
} else {
name_to_tier.insert(&endpoint.name, tier_name);
}
}
}
}
if self.server.request_timeout_seconds == 0 {
return Err(crate::error::AppError::Config(
"Configuration error: request_timeout_seconds must be greater than 0".to_string(),
));
}
if self.server.request_timeout_seconds > 300 {
return Err(crate::error::AppError::Config(format!(
"Configuration error: request_timeout_seconds cannot exceed 300 seconds (5 minutes), got {}",
self.server.request_timeout_seconds
)));
}
self.routing
.router_timeouts
.validate()
.map_err(|e| crate::error::AppError::Config(format!("Configuration error: {}", e)))?;
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build()
.map_err(|e| {
crate::error::AppError::Config(format!(
"Failed to create HTTP client (TLS configuration error): {}.\n\
This usually indicates:\n\
- Invalid or expired TLS certificates\n\
- Missing CA certificate bundle\n\
- Incompatible system TLS libraries\n\
\n\
Please check your system's TLS configuration and certificates.",
e
))
})?;
Ok(())
}
}
/// Parses and validates a configuration held in memory as TOML text.
impl FromStr for Config {
    type Err = crate::error::AppError;

    /// # Errors
    /// - `AppError::ConfigParseFailed` when the text cannot be parsed. As
    ///   there is no file path, `path` holds a description of the input
    ///   (its byte and line counts) instead.
    /// - Whatever [`Config::validate`] returns when validation fails.
    fn from_str(toml_str: &str) -> Result<Self, Self::Err> {
        // Only evaluated on the error path.
        let describe_input = || {
            format!(
                "<string> ({} bytes, {} lines)",
                toml_str.len(),
                toml_str.lines().count()
            )
        };

        let config = toml::from_str::<Config>(toml_str).map_err(|source| {
            crate::error::AppError::ConfigParseFailed {
                path: describe_input(),
                source,
            }
        })?;

        config.validate()?;
        Ok(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Full, valid configuration used by most tests. The fast tier has two
    /// endpoints with the same name, which is allowed within a single tier.
    const TEST_CONFIG: &str = r#"
[server]
host = "0.0.0.0"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "qwen/qwen3-vl-8b"
base_url = "http://192.168.1.67:1234/v1"
max_tokens = 4096
temperature = 0.7
weight = 1.0
priority = 1
[[models.fast]]
name = "qwen/qwen3-vl-8b"
base_url = "http://192.168.1.72:1234/v1"
max_tokens = 4096
temperature = 0.7
weight = 1.0
priority = 1
[[models.balanced]]
name = "qwen/qwen3-30b-a3b-2507"
base_url = "http://192.168.1.61:1234/v1"
max_tokens = 8192
temperature = 0.7
weight = 1.0
priority = 1
[[models.deep]]
name = "/home/steve/dev/llama.cpp/models/gpt-oss-120b-mxfp4.gguf"
base_url = "https://strix-ai.localbrandonfamily.com/v1"
max_tokens = 16384
temperature = 0.7
weight = 1.0
priority = 1
[routing]
strategy = "hybrid"
default_importance = "normal"
router_tier = "balanced"
[observability]
log_level = "info"
"#;

    /// Minimal valid configuration without a `[timeouts]` table. The timeout
    /// tests append their own table through `config_with_timeouts`.
    const BASE_CONFIG: &str = r#"
[server]
host = "127.0.0.1"
port = 3000
[[models.fast]]
name = "test-fast"
base_url = "http://localhost:1234/v1"
max_tokens = 4096
[[models.balanced]]
name = "test-balanced"
base_url = "http://localhost:1235/v1"
max_tokens = 8192
[[models.deep]]
name = "test-deep"
base_url = "http://localhost:1236/v1"
max_tokens = 16384
[routing]
strategy = "rule"
router_tier = "balanced"
"#;

    /// Returns `BASE_CONFIG` followed by a `[timeouts]` table with the given body.
    fn config_with_timeouts(timeouts_body: &str) -> String {
        format!("{}[timeouts]\n{}", BASE_CONFIG, timeouts_body)
    }

    /// Parses `TEST_CONFIG`, which must always be valid.
    fn valid_config() -> Config {
        Config::from_str(TEST_CONFIG).expect("Test operation should succeed")
    }

    /// Asserts that validation fails and returns the error message.
    fn validation_error(config: &Config) -> String {
        config
            .validate()
            .expect_err("validation should fail")
            .to_string()
    }

    #[test]
    fn test_config_from_str_parses_successfully() {
        let config = Config::from_str(TEST_CONFIG).expect("should parse config");
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 3000);
        assert_eq!(config.server.request_timeout_seconds, 30);
    }

    #[test]
    fn test_config_parses_model_endpoints() {
        let config = Config::from_str(TEST_CONFIG).expect("should parse config");
        assert_eq!(config.models.fast.len(), 2);
        assert_eq!(config.models.fast[0].name, "qwen/qwen3-vl-8b");
        assert_eq!(
            config.models.fast[0].base_url,
            "http://192.168.1.67:1234/v1"
        );
        assert_eq!(config.models.fast[0].max_tokens, 4096);
        assert_eq!(config.models.fast[0].weight, 1.0);
        assert_eq!(config.models.fast[0].priority, 1);
        assert_eq!(
            config.models.fast[1].base_url,
            "http://192.168.1.72:1234/v1"
        );
        assert_eq!(config.models.balanced.len(), 1);
        assert_eq!(config.models.balanced[0].name, "qwen/qwen3-30b-a3b-2507");
        assert_eq!(config.models.deep.len(), 1);
        assert_eq!(config.models.deep[0].max_tokens, 16384);
    }

    #[test]
    fn test_config_parses_routing_strategy() {
        let config = Config::from_str(TEST_CONFIG).expect("should parse config");
        assert_eq!(config.routing.strategy, RoutingStrategy::Hybrid);
        assert_eq!(
            config.routing.default_importance,
            crate::router::Importance::Normal
        );
        assert_eq!(config.routing.router_tier(), TargetModel::Balanced);
    }

    #[test]
    fn test_config_parses_observability() {
        let config = Config::from_str(TEST_CONFIG).expect("should parse config");
        assert_eq!(config.observability.log_level, "info");
    }

    #[test]
    fn test_routing_strategy_enum_values() {
        let cases = [
            (r#""rule""#, RoutingStrategy::Rule),
            (r#""llm""#, RoutingStrategy::Llm),
            (r#""hybrid""#, RoutingStrategy::Hybrid),
            (r#""tool""#, RoutingStrategy::Tool),
        ];
        for (json, expected) in cases {
            assert_eq!(
                serde_json::from_str::<RoutingStrategy>(json)
                    .expect("Test operation should succeed"),
                expected
            );
        }
    }

    #[test]
    fn test_config_with_missing_observability_uses_defaults() {
        let minimal_config = r#"
[server]
host = "127.0.0.1"
port = 8080
[[models.fast]]
name = "test-fast"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
[[models.balanced]]
name = "test-balanced"
base_url = "http://localhost:1235/v1"
max_tokens = 4096
[[models.deep]]
name = "test-deep"
base_url = "http://localhost:1236/v1"
max_tokens = 8192
[routing]
strategy = "rule"
default_importance = "normal"
router_tier = "balanced"
"#;
        let config = Config::from_str(minimal_config).expect("should parse minimal config");
        assert_eq!(config.observability.log_level, "info");
        assert_eq!(config.models.fast[0].weight, 1.0);
        assert_eq!(config.models.fast[0].priority, 1);
    }

    #[test]
    fn test_config_validation_invalid_router_tier_fails() {
        let config_str = r#"
[server]
host = "127.0.0.1"
port = 8080
[[models.fast]]
name = "test"
base_url = "http://localhost:1234/v1"
max_tokens = 4096
[[models.balanced]]
name = "test"
base_url = "http://localhost:1235/v1"
max_tokens = 8192
[[models.deep]]
name = "test"
base_url = "http://localhost:1236/v1"
max_tokens = 16384
[routing]
strategy = "rule"
router_tier = "invalid"
"#;
        let result = Config::from_str(config_str);
        assert!(
            result.is_err(),
            "Should fail to deserialize invalid router_tier"
        );
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("router_tier") || err_msg.contains("invalid"));
        assert!(
            err_msg.contains("fast") || err_msg.contains("balanced") || err_msg.contains("deep"),
            "Error should list valid values"
        );
    }

    #[test]
    fn test_config_validation_router_tier_with_no_endpoints_fails() {
        let config_str = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:11434/v1"
max_tokens = 2048
weight = 1.0
priority = 1
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1234/v1"
max_tokens = 4096
weight = 1.0
priority = 1
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:8080/v1"
max_tokens = 8192
weight = 1.0
priority = 1
[routing]
strategy = "llm"
default_importance = "normal"
router_tier = "deep"
"#;
        let mut config = Config::from_str(config_str).expect("Test operation should succeed");
        // An empty tier cannot be expressed in the TOML above, so empty it after parsing.
        config.models.deep.clear();
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("deep") || err_msg.contains("Deep"));
        assert!(err_msg.contains("endpoint"));
    }

    #[test]
    fn test_config_router_tier_defaults_to_balanced() {
        let config_str = r#"
[server]
host = "127.0.0.1"
port = 3000
[[models.fast]]
name = "test-fast"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
[[models.balanced]]
name = "test-balanced"
base_url = "http://localhost:1235/v1"
max_tokens = 4096
[[models.deep]]
name = "test-deep"
base_url = "http://localhost:1236/v1"
max_tokens = 8192
[routing]
strategy = "rule"
# router_tier omitted - should default to balanced
"#;
        let config: Config =
            toml::from_str(config_str).expect("should parse config without router_tier");
        assert_eq!(
            config.routing.router_tier(),
            TargetModel::Balanced,
            "router_tier should default to Balanced when omitted"
        );
    }

    #[test]
    fn test_config_validation_negative_weight_fails() {
        let mut config = valid_config();
        config.models.fast[0].weight = -1.0;
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("weight"));
        assert!(err_msg.contains("positive"));
    }

    #[test]
    fn test_config_validation_zero_weight_fails() {
        let mut config = valid_config();
        config.models.balanced[0].weight = 0.0;
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("weight"));
        assert!(err_msg.contains("positive"));
    }

    #[test]
    fn test_config_validation_nan_weight_fails() {
        let mut config = valid_config();
        config.models.deep[0].weight = f64::NAN;
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("weight"));
    }

    #[test]
    fn test_config_validation_infinite_weight_fails() {
        let mut config = valid_config();
        config.models.deep[0].weight = f64::INFINITY;
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("weight"));
    }

    #[test]
    fn test_config_validation_zero_max_tokens_fails() {
        let mut config = valid_config();
        config.models.fast[0].max_tokens = 0;
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("max_tokens"));
        assert!(err_msg.contains("greater than 0"));
    }

    #[test]
    fn test_config_validation_invalid_base_url_fails() {
        let mut config = valid_config();
        config.models.balanced[0].base_url = "ftp://invalid.com".to_string();
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("base_url"));
        assert!(err_msg.contains("http"));
    }

    #[test]
    fn test_config_validation_missing_protocol_base_url_fails() {
        let mut config = valid_config();
        config.models.deep[0].base_url = "localhost:1234/v1".to_string();
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("base_url"));
        assert!(err_msg.contains("http"));
    }

    #[test]
    fn test_config_validation_base_url_must_end_with_v1() {
        let mut config = valid_config();
        config.models.fast[0].base_url = "http://localhost:1234".to_string();
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("base_url"));
        assert!(err_msg.contains("/v1"));
        assert!(err_msg.contains("OpenAI API"));
    }

    #[test]
    fn test_config_validation_out_of_range_temperature_fails() {
        let mut config = valid_config();
        config.models.fast[0].temperature = 2.5;
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("temperature"));
    }

    #[test]
    fn test_config_validation_duplicate_name_across_tiers_fails() {
        let mut config = valid_config();
        config.models.balanced[0].name = config.models.fast[0].name.clone();
        let err_msg = validation_error(&config);
        assert!(err_msg.contains("unique"));
    }

    #[test]
    fn test_config_validation_zero_timeout_fails() {
        let mut config = valid_config();
        config.server.request_timeout_seconds = 0;
        let err_msg = validation_error(&config);
        assert!(
            err_msg.contains("request_timeout_seconds") && err_msg.contains("greater than 0"),
            "Expected error about request_timeout_seconds > 0, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_config_validation_excessive_timeout_fails() {
        let mut config = valid_config();
        config.server.request_timeout_seconds = 301;
        let err_msg = validation_error(&config);
        assert!(
            err_msg.contains("request_timeout_seconds") && err_msg.contains("300"),
            "Expected error about request_timeout_seconds max 300, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_config_validation_valid_timeout_succeeds() {
        let mut config = valid_config();
        // Lower bound, upper bound and a typical value.
        for seconds in [1, 300, 30] {
            config.server.request_timeout_seconds = seconds;
            assert!(config.validate().is_ok());
        }
    }

    #[test]
    fn test_config_parses_per_tier_timeouts() {
        let toml_text = config_with_timeouts("fast = 15\nbalanced = 30\ndeep = 60\n");
        let config = Config::from_str(&toml_text).expect("should parse config with timeouts");
        assert_eq!(config.timeouts.fast, Some(15));
        assert_eq!(config.timeouts.balanced, Some(30));
        assert_eq!(config.timeouts.deep, Some(60));
    }

    #[test]
    fn test_config_timeouts_optional_fields_default_to_none() {
        // Only `fast` is overridden; balanced and deep use the global default.
        let toml_text = config_with_timeouts("fast = 15\n");
        let config = Config::from_str(&toml_text).expect("should parse partial timeouts");
        assert_eq!(config.timeouts.fast, Some(15));
        assert_eq!(config.timeouts.balanced, None);
        assert_eq!(config.timeouts.deep, None);
    }

    #[test]
    fn test_config_timeouts_section_optional() {
        let config = Config::from_str(TEST_CONFIG).expect("should parse without timeouts section");
        assert_eq!(config.timeouts.fast, None);
        assert_eq!(config.timeouts.balanced, None);
        assert_eq!(config.timeouts.deep, None);
    }

    #[test]
    fn test_config_timeout_for_tier_uses_override() {
        let mut config = valid_config();
        config.server.request_timeout_seconds = 30;
        config.timeouts.fast = Some(15);
        config.timeouts.balanced = Some(45);
        config.timeouts.deep = Some(60);
        assert_eq!(
            config.timeout_for_tier(crate::router::TargetModel::Fast),
            15
        );
        assert_eq!(
            config.timeout_for_tier(crate::router::TargetModel::Balanced),
            45
        );
        assert_eq!(
            config.timeout_for_tier(crate::router::TargetModel::Deep),
            60
        );
    }

    #[test]
    fn test_config_timeout_for_tier_uses_global_default() {
        let config = valid_config();
        assert_eq!(
            config.timeout_for_tier(crate::router::TargetModel::Fast),
            30
        );
        assert_eq!(
            config.timeout_for_tier(crate::router::TargetModel::Balanced),
            30
        );
        assert_eq!(
            config.timeout_for_tier(crate::router::TargetModel::Deep),
            30
        );
    }

    #[test]
    fn test_config_timeout_for_tier_mixed_overrides() {
        let mut config = valid_config();
        config.server.request_timeout_seconds = 40;
        config.timeouts.fast = Some(20);
        assert_eq!(
            config.timeout_for_tier(crate::router::TargetModel::Fast),
            20
        );
        assert_eq!(
            config.timeout_for_tier(crate::router::TargetModel::Balanced),
            40
        );
        assert_eq!(
            config.timeout_for_tier(crate::router::TargetModel::Deep),
            40
        );
    }

    #[test]
    fn test_timeouts_config_deserialization_rejects_zero_timeout() {
        let result = Config::from_str(&config_with_timeouts("fast = 0\n"));
        assert!(
            result.is_err(),
            "Config parsing should fail with zero timeout (custom Deserialize should reject it)"
        );
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("fast") && err_msg.contains("timeout"),
            "Error should mention which timeout field is invalid"
        );
    }

    #[test]
    fn test_timeouts_config_deserialization_rejects_timeout_too_high() {
        let result = Config::from_str(&config_with_timeouts("deep = 301\n"));
        assert!(
            result.is_err(),
            "Config parsing should fail with timeout > 300 (custom Deserialize should reject it)"
        );
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("deep") && err_msg.contains("300"),
            "Error should mention which timeout field exceeds limit"
        );
    }

    #[test]
    fn test_timeouts_config_deserialization_accepts_valid_timeouts() {
        let result = Config::from_str(&config_with_timeouts(
            "fast = 15\nbalanced = 30\ndeep = 60\n",
        ));
        assert!(
            result.is_ok(),
            "Config parsing should succeed with valid timeouts (1-300)"
        );
        let config = result.expect("Test operation should succeed");
        assert_eq!(config.timeouts.fast(), Some(15));
        assert_eq!(config.timeouts.balanced(), Some(30));
        assert_eq!(config.timeouts.deep(), Some(60));
    }

    #[test]
    fn test_timeouts_config_deserialization_accepts_boundary_values() {
        let result = Config::from_str(&config_with_timeouts("fast = 1\ndeep = 300\n"));
        assert!(
            result.is_ok(),
            "Config parsing should succeed with boundary values 1 and 300"
        );
        let config = result.expect("Test operation should succeed");
        assert_eq!(config.timeouts.fast(), Some(1));
        assert_eq!(config.timeouts.deep(), Some(300));
    }
}