use std::time::Duration;
pub use atd_protocol::ToolTier;
pub fn tier_as_str(tier: ToolTier) -> &'static str {
match tier {
ToolTier::Hot => "hot",
ToolTier::Warm => "warm",
ToolTier::Cold => "cold",
}
}
/// Per-tier limits: one timeout and one maximum output size for each
/// [`ToolTier`].
///
/// Build one with [`TierPolicy::defaults`] (or `Default::default()`), then
/// adjust individual entries with [`TierPolicy::apply_override`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TierPolicy {
    /// Timeout for the hot tier.
    pub hot_timeout: Duration,
    /// Timeout for the warm tier.
    pub warm_timeout: Duration,
    /// Timeout for the cold tier.
    pub cold_timeout: Duration,
    /// Maximum output size for the hot tier, in bytes.
    pub hot_max_output: usize,
    /// Maximum output size for the warm tier, in bytes.
    pub warm_max_output: usize,
    /// Maximum output size for the cold tier, in bytes.
    pub cold_max_output: usize,
}
impl TierPolicy {
    /// Returns the built-in limits:
    ///
    /// | tier | timeout | max output |
    /// |------|---------|------------|
    /// | hot  | 500 ms  | 64 KiB     |
    /// | warm | 5 s     | 1 MiB      |
    /// | cold | 60 s    | 16 MiB     |
    pub fn defaults() -> Self {
        Self {
            hot_timeout: Duration::from_millis(500),
            warm_timeout: Duration::from_secs(5),
            cold_timeout: Duration::from_secs(60),
            hot_max_output: 64 * 1024,
            warm_max_output: 1024 * 1024,
            cold_max_output: 16 * 1024 * 1024,
        }
    }

    /// Returns the timeout configured for `tier`.
    pub fn timeout(&self, tier: ToolTier) -> Duration {
        match tier {
            ToolTier::Hot => self.hot_timeout,
            ToolTier::Warm => self.warm_timeout,
            ToolTier::Cold => self.cold_timeout,
        }
    }

    /// Returns the maximum output size, in bytes, configured for `tier`.
    pub fn max_output(&self, tier: ToolTier) -> usize {
        match tier {
            ToolTier::Hot => self.hot_max_output,
            ToolTier::Warm => self.warm_max_output,
            ToolTier::Cold => self.cold_max_output,
        }
    }

    /// Applies a single override of the form `<tier>=<key>=<value>`.
    ///
    /// * `<tier>` is `hot`, `warm` or `cold`, matched ASCII-case-insensitively
    ///   (consistent with `tier_from_opt_str`).
    /// * `<key>` is `timeout_ms` (value parsed as `u64` milliseconds) or
    ///   `max_output_bytes` (value parsed as `usize` bytes).
    ///
    /// Whitespace around each of the three components is ignored, so
    /// `"hot = timeout_ms = 300"` is accepted.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the spec does not have three
    /// `=`-separated parts, names an unknown tier or key, or has a value that
    /// fails to parse. Every check runs before any field is written, so on
    /// error the policy is left unchanged.
    pub fn apply_override(&mut self, spec: &str) -> Result<(), String> {
        // splitn(3, ..) keeps any further '=' inside the value, which then
        // fails numeric parsing rather than being silently dropped.
        let parts: Vec<&str> = spec.splitn(3, '=').map(str::trim).collect();
        let (tier_name, key, value) = match parts.as_slice() {
            [tier_name, key, value] => (*tier_name, *key, *value),
            _ => return Err(format!("expected '<tier>=<key>=<value>', got '{spec}'")),
        };
        let tier = Self::parse_tier(tier_name)
            .ok_or_else(|| format!("unknown tier '{tier_name}' (want hot|warm|cold)"))?;
        match key {
            "timeout_ms" => {
                let ms: u64 = value
                    .parse()
                    .map_err(|e| format!("invalid timeout_ms '{value}': {e}"))?;
                *self.timeout_mut(tier) = Duration::from_millis(ms);
            }
            "max_output_bytes" => {
                let bytes: usize = value
                    .parse()
                    .map_err(|e| format!("invalid max_output_bytes '{value}': {e}"))?;
                *self.max_output_mut(tier) = bytes;
            }
            other => {
                return Err(format!(
                    "unknown key '{other}' (want timeout_ms|max_output_bytes)"
                ));
            }
        }
        Ok(())
    }

    /// Maps a tier name (ASCII-case-insensitive) to its `ToolTier`.
    fn parse_tier(name: &str) -> Option<ToolTier> {
        if name.eq_ignore_ascii_case("hot") {
            Some(ToolTier::Hot)
        } else if name.eq_ignore_ascii_case("warm") {
            Some(ToolTier::Warm)
        } else if name.eq_ignore_ascii_case("cold") {
            Some(ToolTier::Cold)
        } else {
            None
        }
    }

    /// Mutable access to the timeout field that backs `tier`.
    fn timeout_mut(&mut self, tier: ToolTier) -> &mut Duration {
        match tier {
            ToolTier::Hot => &mut self.hot_timeout,
            ToolTier::Warm => &mut self.warm_timeout,
            ToolTier::Cold => &mut self.cold_timeout,
        }
    }

    /// Mutable access to the max-output field that backs `tier`.
    fn max_output_mut(&mut self, tier: ToolTier) -> &mut usize {
        match tier {
            ToolTier::Hot => &mut self.hot_max_output,
            ToolTier::Warm => &mut self.warm_max_output,
            ToolTier::Cold => &mut self.cold_max_output,
        }
    }
}
impl Default for TierPolicy {
fn default() -> Self {
Self::defaults()
}
}
/// Interprets an optional tier name, falling back to `ToolTier::Warm`.
///
/// `"hot"` and `"cold"` are recognised ASCII-case-insensitively. `None`,
/// `"warm"`, and any other string (including ones with surrounding
/// whitespace) all yield `ToolTier::Warm`.
pub fn tier_from_opt_str(s: Option<&str>) -> ToolTier {
    match s {
        Some(name) if name.eq_ignore_ascii_case("hot") => ToolTier::Hot,
        Some(name) if name.eq_ignore_ascii_case("cold") => ToolTier::Cold,
        _ => ToolTier::Warm,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tier_from_opt_str_defaults_to_warm_on_none() {
        assert_eq!(tier_from_opt_str(None), ToolTier::Warm);
    }

    #[test]
    fn tier_from_opt_str_defaults_to_warm_on_unknown() {
        assert_eq!(tier_from_opt_str(Some("nebulous")), ToolTier::Warm);
    }

    #[test]
    fn tier_from_opt_str_parses_known_values_case_insensitive() {
        assert_eq!(tier_from_opt_str(Some("hot")), ToolTier::Hot);
        assert_eq!(tier_from_opt_str(Some("HOT")), ToolTier::Hot);
        assert_eq!(tier_from_opt_str(Some("cold")), ToolTier::Cold);
        assert_eq!(tier_from_opt_str(Some("Warm")), ToolTier::Warm);
    }

    #[test]
    fn tier_as_str_round_trips_through_tier_from_opt_str() {
        assert_eq!(tier_as_str(ToolTier::Hot), "hot");
        assert_eq!(tier_as_str(ToolTier::Warm), "warm");
        assert_eq!(tier_as_str(ToolTier::Cold), "cold");
        assert_eq!(tier_from_opt_str(Some(tier_as_str(ToolTier::Hot))), ToolTier::Hot);
        assert_eq!(tier_from_opt_str(Some(tier_as_str(ToolTier::Warm))), ToolTier::Warm);
        assert_eq!(tier_from_opt_str(Some(tier_as_str(ToolTier::Cold))), ToolTier::Cold);
    }

    #[test]
    fn defaults_match_current_server_warm_budget() {
        let p = TierPolicy::defaults();
        assert!(p.warm_timeout >= Duration::from_secs(5));
        assert!(p.warm_max_output >= 1024 * 1024);
    }

    #[test]
    fn default_trait_matches_defaults() {
        let via_trait = TierPolicy::default();
        let via_fn = TierPolicy::defaults();
        assert_eq!(via_trait.hot_timeout, via_fn.hot_timeout);
        assert_eq!(via_trait.warm_timeout, via_fn.warm_timeout);
        assert_eq!(via_trait.cold_timeout, via_fn.cold_timeout);
        assert_eq!(via_trait.hot_max_output, via_fn.hot_max_output);
        assert_eq!(via_trait.warm_max_output, via_fn.warm_max_output);
        assert_eq!(via_trait.cold_max_output, via_fn.cold_max_output);
    }

    #[test]
    fn timeout_lookup_by_tier() {
        let p = TierPolicy::defaults();
        assert_eq!(p.timeout(ToolTier::Hot), Duration::from_millis(500));
        assert_eq!(p.timeout(ToolTier::Warm), Duration::from_secs(5));
        assert_eq!(p.timeout(ToolTier::Cold), Duration::from_secs(60));
    }

    #[test]
    fn max_output_lookup_by_tier() {
        let p = TierPolicy::defaults();
        assert_eq!(p.max_output(ToolTier::Hot), 64 * 1024);
        assert_eq!(p.max_output(ToolTier::Warm), 1024 * 1024);
        assert_eq!(p.max_output(ToolTier::Cold), 16 * 1024 * 1024);
    }

    #[test]
    fn apply_override_timeout_ms() {
        let mut p = TierPolicy::defaults();
        p.apply_override("hot=timeout_ms=300").unwrap();
        assert_eq!(p.hot_timeout, Duration::from_millis(300));
        assert_eq!(p.warm_timeout, Duration::from_secs(5));
    }

    #[test]
    fn apply_override_max_output_bytes() {
        let mut p = TierPolicy::defaults();
        p.apply_override("cold=max_output_bytes=33554432").unwrap();
        assert_eq!(p.cold_max_output, 33_554_432);
    }

    #[test]
    fn apply_override_rejects_malformed_spec() {
        let mut p = TierPolicy::defaults();
        assert!(p.apply_override("no_equals").is_err());
        assert!(p.apply_override("hot=timeout_ms").is_err());
        assert!(p.apply_override("bogus=timeout_ms=100").is_err());
        assert!(p.apply_override("hot=bogus=100").is_err());
        assert!(p.apply_override("hot=timeout_ms=not_a_number").is_err());
    }

    #[test]
    fn apply_override_failure_leaves_policy_untouched() {
        let mut p = TierPolicy::defaults();
        assert!(p.apply_override("warm=timeout_ms=-5").is_err());
        assert!(p.apply_override("warm=max_output_bytes=lots").is_err());
        // A value containing '=' stays in the third component and fails to parse.
        assert!(p.apply_override("warm=timeout_ms=1=2").is_err());
        assert_eq!(p.warm_timeout, Duration::from_secs(5));
        assert_eq!(p.warm_max_output, 1024 * 1024);
    }
}