use crate::config::{Config, RoutingStrategy};
use crate::error::{AppError, AppResult};
use crate::models::ModelSelector;
use crate::router::{HybridRouter, LlmBasedRouter, Router, RuleBasedRouter};
use std::sync::Arc;
/// Shared, reference-counted handle to the [`crate::metrics::Metrics`] instance;
/// this is the type handed out by [`AppState::metrics`].
type MetricsHandle = Arc<crate::metrics::Metrics>;
pub mod chat;
pub mod health;
pub mod metrics;
pub mod models;
pub mod openai;
#[derive(Clone)]
pub struct AppState {
config: Arc<Config>,
selector: Arc<ModelSelector>,
router: Arc<Router>,
metrics: Arc<crate::metrics::Metrics>,
}
impl AppState {
pub fn new(config: Arc<Config>) -> AppResult<Self> {
let metrics = {
let m = crate::metrics::Metrics::new()
.map_err(|e| AppError::Internal(format!("Failed to initialize metrics: {}", e)))?;
tracing::info!("Metrics collection enabled");
Arc::new(m)
};
let selector = Arc::new(ModelSelector::new(config.clone(), metrics.clone()));
let router = match config.routing.strategy {
RoutingStrategy::Rule => {
tracing::info!("Initializing rule-based router (no LLM routing)");
Arc::new(Router::Rule(RuleBasedRouter::new()))
}
RoutingStrategy::Llm => {
let router_tier = config.routing.router_tier();
let router_timeout_secs = config.routing.router_timeout_for_tier(router_tier);
tracing::info!(
"Initializing LLM-based router with {:?} tier for routing decisions (timeout: {}s)",
router_tier,
router_timeout_secs
);
let llm_router = LlmBasedRouter::new(
selector.clone(),
router_tier,
router_timeout_secs,
metrics.clone(),
)?;
Arc::new(Router::Llm(llm_router))
}
RoutingStrategy::Hybrid => {
tracing::info!(
"Initializing hybrid router (rule-based with LLM fallback using {:?} tier)",
config.routing.router_tier()
);
let hybrid_router =
HybridRouter::new(config.clone(), selector.clone(), metrics.clone())?;
Arc::new(Router::Hybrid(hybrid_router))
}
RoutingStrategy::Tool => {
return Err(AppError::Config(
"Tool-based routing is not yet implemented. Use 'rule', 'llm', or 'hybrid'."
.to_string(),
));
}
};
Ok(Self {
config,
selector,
router,
metrics,
})
}
pub fn config(&self) -> &Config {
&self.config
}
pub fn selector(&self) -> &ModelSelector {
&self.selector
}
pub fn selector_arc(&self) -> Arc<ModelSelector> {
Arc::clone(&self.selector)
}
pub fn router(&self) -> &Router {
&self.router
}
pub fn metrics(&self) -> MetricsHandle {
self.metrics.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal config with one endpoint per tier and the given
    /// `[routing] strategy` string.
    fn create_test_config_with_strategy(strategy: &str) -> Config {
        let toml = format!(
            r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30
[[models.fast]]
name = "fast-1"
base_url = "http://localhost:1234/v1"
max_tokens = 2048
temperature = 0.7
weight = 1.0
priority = 1
[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1235/v1"
max_tokens = 4096
temperature = 0.7
weight = 1.0
priority = 1
[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1236/v1"
max_tokens = 8192
temperature = 0.7
weight = 1.0
priority = 1
[routing]
strategy = "{}"
default_importance = "normal"
router_tier = "balanced"
"#,
            strategy
        );
        toml::from_str(&toml).expect("should parse TOML config")
    }

    /// Default test config using the rule-based strategy.
    fn create_test_config() -> Config {
        create_test_config_with_strategy("rule")
    }

    #[tokio::test]
    async fn test_appstate_new_creates_state() {
        let config = Arc::new(create_test_config());
        let state = AppState::new(config).expect("AppState::new should succeed with rule strategy");
        assert_eq!(state.config().server.port, 3000);
        assert_eq!(
            state
                .selector()
                .endpoint_count(crate::router::TargetModel::Fast),
            1
        );
    }

    #[tokio::test]
    async fn test_appstate_is_clonable() {
        let config = Arc::new(create_test_config());
        let state = AppState::new(config).expect("AppState::new should succeed with rule strategy");
        let state2 = state.clone();
        assert_eq!(state2.config().server.port, 3000);
        // Clones must share components rather than re-initialize them.
        assert!(Arc::ptr_eq(&state.selector_arc(), &state2.selector_arc()));
        assert!(Arc::ptr_eq(&state.metrics(), &state2.metrics()));
    }

    #[tokio::test]
    async fn test_appstate_provides_access_to_components() {
        let config = Arc::new(create_test_config());
        let state = AppState::new(config).expect("AppState::new should succeed with rule strategy");
        let _ = state.config();
        let _ = state.router();
        // The borrowed and owned selector accessors must expose the same instance.
        assert!(std::ptr::eq(state.selector(), &*state.selector_arc()));
        // `metrics()` hands out the shared handle, not a fresh instance.
        assert!(Arc::ptr_eq(&state.metrics(), &state.metrics()));
    }

    #[tokio::test]
    async fn test_appstate_rejects_tool_strategy() {
        // NOTE(review): assumes "tool" is the serialized name of
        // `RoutingStrategy::Tool`, matching "rule"/"llm"/"hybrid` — confirm.
        let config = Arc::new(create_test_config_with_strategy("tool"));
        let result = AppState::new(config);
        assert!(
            matches!(result, Err(AppError::Config(_))),
            "tool strategy should be rejected as a configuration error"
        );
    }
}