use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::time::Duration;
/// Top-level configuration, loaded from a JSON or YAML file by
/// [`Config::from_file`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Model definitions, keyed by model name (e.g. `"llama"`).
    pub models: HashMap<String, ModelConfig>,
    /// Policy settings; when the `policy` section is omitted entirely,
    /// [`PolicyConfig::default`] is used.
    #[serde(default)]
    pub policy: PolicyConfig,
    /// Port for this process (NOTE(review): presumably the listen port of
    /// the front-end server — confirm). Defaults to 3000 when omitted.
    #[serde(default = "default_port")]
    pub port: u16,
}
/// Settings for a single model entry in [`Config::models`].
///
/// `wake`, `sleep` and `alive` are stored verbatim as strings; the test
/// fixtures use both script paths and inline shell snippets. How they are
/// executed is decided elsewhere — NOTE(review): presumably run via a shell,
/// confirm against the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Port associated with this model (presumably where its backend
    /// listens — confirm).
    pub port: u16,
    /// Command string used to wake the model.
    pub wake: String,
    /// Command string used to put the model to sleep.
    pub sleep: String,
    /// Command string used to check whether the model is alive.
    pub alive: String,
}
/// Serde default for `Config::port` when the key is absent: `3000`.
fn default_port() -> u16 {
    const DEFAULT_PORT: u16 = 3000;
    DEFAULT_PORT
}
impl Config {
    /// Loads a [`Config`] from the file at `path`.
    ///
    /// The format is chosen from the file extension. `yaml` and `yml`, matched
    /// ASCII case-insensitively so `config.YAML` also works, are parsed as YAML.
    /// Any other extension, or no extension at all, is parsed as JSON.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read as UTF-8 text, or if its
    /// contents fail to deserialize in the selected format. Both errors carry
    /// the file path as context.
    pub async fn from_file(path: &Path) -> Result<Self> {
        let contents = tokio::fs::read_to_string(path)
            .await
            .with_context(|| format!("Failed to read config file: {}", path.display()))?;

        if Self::has_yaml_extension(path) {
            serde_yaml::from_str(&contents)
                .with_context(|| format!("Failed to parse YAML config: {}", path.display()))
        } else {
            serde_json::from_str(&contents)
                .with_context(|| format!("Failed to parse JSON config: {}", path.display()))
        }
    }

    /// Returns `true` if `path` ends in a `yaml`/`yml` extension, ignoring ASCII case.
    fn has_yaml_extension(path: &Path) -> bool {
        // Non-UTF-8 extensions cannot be "yaml"/"yml", so they fall through to JSON,
        // matching the original behavior.
        match path.extension().and_then(|ext| ext.to_str()) {
            Some(ext) => ext.eq_ignore_ascii_case("yaml") || ext.eq_ignore_ascii_case("yml"),
            None => false,
        }
    }
}
/// Settings used to build the model-switching policy (see
/// [`PolicyConfig::build_policy`]).
///
/// Every field has a serde default, so a partially specified `policy`
/// section is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyConfig {
    /// Optional request timeout in seconds; `None` when omitted.
    /// Converted to a `Duration` before being passed to the policy.
    #[serde(default)]
    pub request_timeout_secs: Option<u64>,
    /// Drain-before-switch flag passed to the policy; defaults to `true`
    /// (see `default_drain_before_switch`).
    #[serde(default = "default_drain_before_switch")]
    pub drain_before_switch: bool,
    /// Minimum active time in seconds; defaults to `0`.
    /// Converted to a `Duration` before being passed to the policy.
    #[serde(default)]
    pub min_active_secs: u64,
}
impl Default for PolicyConfig {
    /// Produces the same values serde fills in for missing fields: no request
    /// timeout, `drain_before_switch` from [`default_drain_before_switch`], and
    /// a zero minimum active time.
    fn default() -> Self {
        let drain_before_switch = default_drain_before_switch();
        Self {
            min_active_secs: u64::default(),
            drain_before_switch,
            request_timeout_secs: Option::default(),
        }
    }
}
/// Serde default for `PolicyConfig::drain_before_switch`. It returns `true`, so
/// the flag is enabled unless the config explicitly sets it to `false`.
fn default_drain_before_switch() -> bool {
    const DRAIN_BY_DEFAULT: bool = true;
    DRAIN_BY_DEFAULT
}
impl PolicyConfig {
    /// Builds the boxed switch policy described by this configuration.
    ///
    /// This always constructs a `crate::policy::FifoPolicy`. Second-valued
    /// fields are converted to [`Duration`]s before being handed over.
    pub fn build_policy(&self) -> Box<dyn crate::policy::SwitchPolicy> {
        let timeout = self.request_timeout_secs.map(Duration::from_secs);
        let min_active = Duration::from_secs(self.min_active_secs);
        let policy =
            crate::policy::FifoPolicy::new(timeout, self.drain_before_switch, min_active);
        Box::new(policy)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_json() {
        let json = r#"{
            "models": {
                "llama": {
                    "port": 8001,
                    "wake": "./scripts/wake-llama.sh",
                    "sleep": "./scripts/sleep-llama.sh",
                    "alive": "./scripts/alive-llama.sh"
                },
                "mistral": {
                    "port": 8002,
                    "wake": "./scripts/wake-mistral.sh",
                    "sleep": "./scripts/sleep-mistral.sh",
                    "alive": "./scripts/alive-mistral.sh"
                }
            },
            "policy": {
                "request_timeout_secs": 30
            },
            "port": 3000
        }"#;
        let config: Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.models.len(), 2);
        assert_eq!(config.models["llama"].port, 8001);
        assert_eq!(config.policy.request_timeout_secs, Some(30));
        // Fields left out of a partially specified policy section still get defaults.
        assert!(config.policy.drain_before_switch);
        assert_eq!(config.policy.min_active_secs, 0);
    }

    #[test]
    fn test_parse_yaml() {
        // YAML nesting, including the `sleep` block scalar, is carried entirely
        // by indentation, so the relative indentation below is significant.
        let yaml = r#"
models:
  llama:
    port: 8001
    wake: ./scripts/wake-llama.sh
    sleep: |
      kill $(cat /tmp/llama.pid)
      rm /tmp/llama.pid
    alive: curl -sf http://localhost:8001/health
policy:
  request_timeout_secs: 60
port: 4000
"#;
        let config: Config = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.models.len(), 1);
        assert_eq!(config.models["llama"].port, 8001);
        assert_eq!(config.models["llama"].wake, "./scripts/wake-llama.sh");
        assert!(config.models["llama"].sleep.contains("kill"));
        // The block scalar keeps both lines of the multi-line command.
        assert!(config.models["llama"].sleep.contains("rm /tmp/llama.pid"));
        assert_eq!(config.models["llama"].sleep.lines().count(), 2);
        assert_eq!(
            config.models["llama"].alive,
            "curl -sf http://localhost:8001/health"
        );
        assert_eq!(config.policy.request_timeout_secs, Some(60));
        assert_eq!(config.port, 4000);
    }

    #[test]
    fn test_defaults() {
        let json = r#"{
            "models": {
                "llama": {
                    "port": 8001,
                    "wake": "./wake.sh",
                    "sleep": "./sleep.sh",
                    "alive": "./alive.sh"
                }
            }
        }"#;
        let config: Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.port, 3000);
        assert_eq!(config.policy.request_timeout_secs, None);
        assert!(config.policy.drain_before_switch);
        assert_eq!(config.policy.min_active_secs, 0);
    }

    #[test]
    fn test_policy_default_matches_serde_defaults() {
        // `PolicyConfig::default()` (used when `policy` is absent) must agree with
        // the per-field serde defaults (used when `policy` is present but empty).
        let from_serde: PolicyConfig = serde_json::from_str("{}").unwrap();
        let from_default = PolicyConfig::default();
        assert_eq!(from_serde.request_timeout_secs, from_default.request_timeout_secs);
        assert_eq!(from_serde.drain_before_switch, from_default.drain_before_switch);
        assert_eq!(from_serde.min_active_secs, from_default.min_active_secs);
    }
}