use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use crate::state::CryoState;
/// Condition that triggers a rotation, as configured by `CryoConfig::rotate_on`.
///
/// Variants are (de)serialized in kebab-case: `"quick-exit"`, `"any-failure"`
/// and `"never"`. A config that omits the key gets [`RotateOn::Never`].
///
/// NOTE(review): presumably this drives advancing through `CryoConfig::providers`
/// (see `CryoState::provider_index`) — confirm against the caller.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum RotateOn {
    /// Serialized as `"quick-exit"`.
    QuickExit,
    /// Serialized as `"any-failure"`.
    AnyFailure,
    /// Serialized as `"never"`; this is the default.
    #[default]
    Never,
}
/// One entry of the `providers` list in `CryoConfig`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ProviderConfig {
    /// Provider identifier. It has no serde default, so an entry without a
    /// `name` key fails to deserialize.
    pub name: String,
    /// Extra environment variables attached to this provider; an empty map
    /// when the key is omitted.
    /// NOTE(review): presumably exported to the agent process — confirm at the call site.
    #[serde(default)]
    pub env: HashMap<String, String>,
}
/// Top-level settings as stored in `cryo.toml` (see `load_config` / `save_config`).
///
/// Every key is optional in the file. A missing key falls back to the
/// `default_*` helper named in its `#[serde(default = "...")]` attribute, or to
/// the field type's `Default` for a bare `#[serde(default)]`. The manual
/// `Default` impl mirrors these same values.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CryoConfig {
    /// Agent name; defaults to `"opencode"`. Can be replaced by
    /// `CryoState::agent_override` via `apply_overrides`.
    #[serde(default = "default_agent")]
    pub agent: String,
    /// Retry limit; defaults to 5. Can be replaced by
    /// `CryoState::max_retries_override`.
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
    /// Session duration limit; defaults to 0. Can be replaced by
    /// `CryoState::max_session_duration_override`.
    /// NOTE(review): the unit and whether 0 means "unlimited" are not visible here — confirm.
    #[serde(default)]
    pub max_session_duration: u64,
    /// Defaults to `true`.
    #[serde(default = "default_watch_inbox")]
    pub watch_inbox: bool,
    /// Web host; defaults to `"127.0.0.1"`.
    #[serde(default = "default_web_host")]
    pub web_host: String,
    /// Web port; defaults to 3945.
    #[serde(default = "default_web_port")]
    pub web_port: u16,
    /// Defaults to `"notify"`.
    #[serde(default = "default_fallback_alert")]
    pub fallback_alert: String,
    /// Defaults to `"09:00"`.
    /// NOTE(review): looks like an `HH:MM` clock time — confirm the expected format.
    #[serde(default = "default_report_time")]
    pub report_time: String,
    /// Defaults to 0.
    #[serde(default)]
    pub report_interval: u64,
    /// Rotation policy; defaults to [`RotateOn::Never`].
    #[serde(default)]
    pub rotate_on: RotateOn,
    /// Configured providers; empty when omitted.
    /// NOTE(review): this array-of-tables field is declared before the scalar
    /// poll-interval fields. Older `toml` serializers (0.5.x) emit struct fields
    /// in declaration order and fail with `ValueAfterTable` once a plain value
    /// follows a table. If such a version is pinned, `save_config` would fail
    /// whenever a provider is configured. Confirm the `toml` version, or move
    /// this field last.
    #[serde(default)]
    pub providers: Vec<ProviderConfig>,
    /// Defaults to 5 (shared `default_poll_interval`).
    #[serde(default = "default_poll_interval")]
    pub zulip_poll_interval: u64,
    /// Defaults to 5 (shared `default_poll_interval`).
    #[serde(default = "default_poll_interval")]
    pub gh_poll_interval: u64,
}
/// Serde default for `CryoConfig::agent`; also used by `CryoConfig::default`.
fn default_agent() -> String {
    String::from("opencode")
}
/// Serde default for `CryoConfig::max_retries`; also used by `CryoConfig::default`.
fn default_max_retries() -> u32 {
    const RETRIES: u32 = 5;
    RETRIES
}
/// Serde default for `CryoConfig::watch_inbox`: enabled unless the file says otherwise.
fn default_watch_inbox() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Serde default for `CryoConfig::web_host`: the IPv4 loopback address, `"127.0.0.1"`.
fn default_web_host() -> String {
    // `Ipv4Addr`'s `Display` renders LOCALHOST as exactly "127.0.0.1".
    std::net::Ipv4Addr::LOCALHOST.to_string()
}
/// Serde default for `CryoConfig::web_port`; also used by `CryoConfig::default`.
fn default_web_port() -> u16 {
    const PORT: u16 = 3945;
    PORT
}
/// Serde default for `CryoConfig::fallback_alert`; also used by `CryoConfig::default`.
fn default_fallback_alert() -> String {
    "notify".into()
}
/// Serde default for `CryoConfig::report_time`; also used by `CryoConfig::default`.
fn default_report_time() -> String {
    "09:00".to_owned()
}
/// Shared serde default for `CryoConfig::zulip_poll_interval` and
/// `CryoConfig::gh_poll_interval`.
/// NOTE(review): the unit (presumably seconds) is not visible here — confirm.
fn default_poll_interval() -> u64 {
    const INTERVAL: u64 = 5;
    INTERVAL
}
impl Default for CryoConfig {
    /// Builds a config with every field set to the value serde would substitute
    /// for a missing key.
    fn default() -> Self {
        // Both poll intervals share the same default helper.
        let poll_interval = default_poll_interval();
        Self {
            agent: default_agent(),
            max_retries: default_max_retries(),
            // Bare `#[serde(default)]` on a `u64` yields 0; mirror that here.
            max_session_duration: 0,
            watch_inbox: default_watch_inbox(),
            web_host: default_web_host(),
            web_port: default_web_port(),
            fallback_alert: default_fallback_alert(),
            report_time: default_report_time(),
            report_interval: 0,
            rotate_on: RotateOn::default(),
            providers: Vec::default(),
            zulip_poll_interval: poll_interval,
            gh_poll_interval: poll_interval,
        }
    }
}
impl CryoConfig {
pub fn apply_overrides(&mut self, state: &CryoState) {
if let Some(ref agent) = state.agent_override {
self.agent = agent.clone();
}
if let Some(max_retries) = state.max_retries_override {
self.max_retries = max_retries;
}
if let Some(max_session_duration) = state.max_session_duration_override {
self.max_session_duration = max_session_duration;
}
}
}
/// Returns the path of the config file, `cryo.toml`, inside `dir`.
pub fn config_path(dir: &Path) -> PathBuf {
    const FILE_NAME: &str = "cryo.toml";
    let mut location = dir.to_path_buf();
    location.push(FILE_NAME);
    location
}
pub fn load_config(path: &Path) -> Result<Option<CryoConfig>> {
if !path.exists() {
return Ok(None);
}
let contents = std::fs::read_to_string(path)?;
let config: CryoConfig = toml::from_str(&contents)?;
Ok(Some(config))
}
pub fn save_config(path: &Path, config: &CryoConfig) -> Result<()> {
let toml = toml::to_string_pretty(config)?;
std::fs::write(path, toml)?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CryoState` whose only interesting fields are the three overrides.
    fn state_with_overrides(
        agent: Option<&str>,
        max_retries: Option<u32>,
        max_session_duration: Option<u64>,
    ) -> crate::state::CryoState {
        crate::state::CryoState {
            session_number: 1,
            pid: None,
            retry_count: 0,
            agent_override: agent.map(str::to_string),
            max_retries_override: max_retries,
            max_session_duration_override: max_session_duration,
            next_wake: None,
            last_report_time: None,
            provider_index: None,
        }
    }

    #[test]
    fn test_load_missing_file_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cryo.toml");
        assert!(
            load_config(&path).unwrap().is_none(),
            "A missing config file should load as None"
        );
    }

    #[test]
    fn test_load_malformed_toml() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cryo.toml");
        std::fs::write(&path, "this is {{{{ not valid toml").unwrap();
        let result = load_config(&path);
        assert!(result.is_err(), "Should return error for malformed TOML");
    }

    #[test]
    fn test_load_partial_toml() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cryo.toml");
        std::fs::write(&path, "agent = \"claude\"\n").unwrap();
        let config = load_config(&path).unwrap().unwrap();
        assert_eq!(config.agent, "claude");
        assert_eq!(config.max_retries, 5, "Should use default max_retries");
        assert_eq!(config.max_session_duration, 0, "Should use default timeout");
        assert!(config.watch_inbox, "Should use default watch_inbox");
        assert_eq!(config.rotate_on, RotateOn::Never, "Should use default rotate_on");
        assert!(config.providers.is_empty(), "Should default to no providers");
    }

    #[test]
    fn test_load_rotate_on_and_providers() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cryo.toml");
        std::fs::write(
            &path,
            r#"
rotate_on = "quick-exit"

[[providers]]
name = "primary"

[providers.env]
API_KEY = "secret"

[[providers]]
name = "backup"
"#,
        )
        .unwrap();
        let config = load_config(&path).unwrap().unwrap();
        assert_eq!(config.rotate_on, RotateOn::QuickExit);
        assert_eq!(config.providers.len(), 2);
        assert_eq!(config.providers[0].name, "primary");
        assert_eq!(
            config.providers[0].env.get("API_KEY").map(String::as_str),
            Some("secret")
        );
        assert_eq!(config.providers[1].name, "backup");
        assert!(config.providers[1].env.is_empty());
    }

    #[test]
    fn test_save_then_load_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let path = config_path(dir.path());
        let mut config = CryoConfig::default();
        config.agent = "claude".to_string();
        config.web_port = 8080;
        config.rotate_on = RotateOn::AnyFailure;
        save_config(&path, &config).unwrap();
        let loaded = load_config(&path)
            .unwrap()
            .expect("a freshly saved config should load");
        assert_eq!(loaded.agent, "claude");
        assert_eq!(loaded.web_port, 8080);
        assert_eq!(loaded.rotate_on, RotateOn::AnyFailure);
        assert_eq!(loaded.max_retries, config.max_retries);
        assert_eq!(loaded.web_host, config.web_host);
        assert_eq!(loaded.report_time, config.report_time);
        assert!(loaded.providers.is_empty());
    }

    #[test]
    fn test_apply_overrides_all_fields() {
        let mut config = CryoConfig::default();
        let state = state_with_overrides(Some("claude"), Some(10), Some(300));
        config.apply_overrides(&state);
        assert_eq!(config.agent, "claude");
        assert_eq!(config.max_retries, 10);
        assert_eq!(config.max_session_duration, 300);
    }

    #[test]
    fn test_apply_overrides_none_fields() {
        let original = CryoConfig::default();
        let mut config = CryoConfig::default();
        let state = state_with_overrides(None, None, None);
        config.apply_overrides(&state);
        assert_eq!(config.agent, original.agent);
        assert_eq!(config.max_retries, original.max_retries);
        assert_eq!(config.max_session_duration, original.max_session_duration);
    }
}