Skip to main content

atd_runtime/
tier.rs

1//! Tier-aware dispatch policy.
2//!
3//! `atd_protocol::ToolTier` (Hot/Warm/Cold) is the authoritative tier enum —
4//! this module adds the **policy** that maps each tier to a timeout budget
5//! and a max-output budget used at dispatch time. SP-12 makes the `tier`
6//! signal load-bearing; a future SP can extend this with placement /
7//! priority semantics without changing the field on the definition.
8
9use std::time::Duration;
10
11pub use atd_protocol::ToolTier;
12
13/// Stable lower-case string label for a `ToolTier`. Used on the audit-event
14/// wire (`CallEvent::tier`) and anywhere a human-readable tier tag is needed.
15pub fn tier_as_str(tier: ToolTier) -> &'static str {
16    match tier {
17        ToolTier::Hot => "hot",
18        ToolTier::Warm => "warm",
19        ToolTier::Cold => "cold",
20    }
21}
22
23/// Per-tier budgets used when constructing `CallContext` for a tool call.
24/// `Warm` defaults match the pre-SP-12 server config (1 MiB / 60 s) to keep
25/// the 9 existing tools' behavior unchanged.
26#[derive(Debug, Clone)]
27pub struct TierPolicy {
28    pub hot_timeout: Duration,
29    pub warm_timeout: Duration,
30    pub cold_timeout: Duration,
31    pub hot_max_output: usize,
32    pub warm_max_output: usize,
33    pub cold_max_output: usize,
34}
35
36impl TierPolicy {
37    /// Canonical defaults:
38    /// - Hot: 500 ms / 64 KiB — latency-critical tools (sensors, cached state).
39    /// - Warm: 5 s / 1 MiB — typical tool invocations (current server default).
40    /// - Cold: 60 s / 16 MiB — long-running / large-output tools.
41    pub fn defaults() -> Self {
42        Self {
43            hot_timeout: Duration::from_millis(500),
44            warm_timeout: Duration::from_secs(5),
45            cold_timeout: Duration::from_secs(60),
46            hot_max_output: 64 * 1024,
47            warm_max_output: 1024 * 1024,
48            cold_max_output: 16 * 1024 * 1024,
49        }
50    }
51
52    pub fn timeout(&self, tier: ToolTier) -> Duration {
53        match tier {
54            ToolTier::Hot => self.hot_timeout,
55            ToolTier::Warm => self.warm_timeout,
56            ToolTier::Cold => self.cold_timeout,
57        }
58    }
59
60    pub fn max_output(&self, tier: ToolTier) -> usize {
61        match tier {
62            ToolTier::Hot => self.hot_max_output,
63            ToolTier::Warm => self.warm_max_output,
64            ToolTier::Cold => self.cold_max_output,
65        }
66    }
67
68    /// Apply a single `"<tier>=<key>=<value>"` override. Keys: `timeout_ms`,
69    /// `max_output_bytes`. Returns a human-readable error on malformed input
70    /// so the CLI can surface it via exit 2.
71    pub fn apply_override(&mut self, spec: &str) -> Result<(), String> {
72        let parts: Vec<&str> = spec.splitn(3, '=').collect();
73        if parts.len() != 3 {
74            return Err(format!("expected '<tier>=<key>=<value>', got '{spec}'"));
75        }
76        let tier = match parts[0] {
77            "hot" => ToolTier::Hot,
78            "warm" => ToolTier::Warm,
79            "cold" => ToolTier::Cold,
80            other => return Err(format!("unknown tier '{other}' (want hot|warm|cold)")),
81        };
82        match parts[1] {
83            "timeout_ms" => {
84                let v: u64 = parts[2]
85                    .parse()
86                    .map_err(|e| format!("invalid timeout_ms '{}': {e}", parts[2]))?;
87                let d = Duration::from_millis(v);
88                match tier {
89                    ToolTier::Hot => self.hot_timeout = d,
90                    ToolTier::Warm => self.warm_timeout = d,
91                    ToolTier::Cold => self.cold_timeout = d,
92                }
93            }
94            "max_output_bytes" => {
95                let v: usize = parts[2]
96                    .parse()
97                    .map_err(|e| format!("invalid max_output_bytes '{}': {e}", parts[2]))?;
98                match tier {
99                    ToolTier::Hot => self.hot_max_output = v,
100                    ToolTier::Warm => self.warm_max_output = v,
101                    ToolTier::Cold => self.cold_max_output = v,
102                }
103            }
104            other => {
105                return Err(format!(
106                    "unknown key '{other}' (want timeout_ms|max_output_bytes)"
107                ));
108            }
109        }
110        Ok(())
111    }
112}
113
114impl Default for TierPolicy {
115    fn default() -> Self {
116        Self::defaults()
117    }
118}
119
120/// Parse a tier hint from an optional string (e.g. a `tier` field on a tool
121/// definition, which may be absent). Defaults to `Warm` on `None` or an
122/// unrecognized value — this is the SP-12 back-compat behavior locked in by
123/// spec §8 Q5. A future SP can flip unknown-values to a hard error once all
124/// builtin tools opt in.
125pub fn tier_from_opt_str(s: Option<&str>) -> ToolTier {
126    match s.map(|v| v.to_ascii_lowercase()).as_deref() {
127        Some("hot") => ToolTier::Hot,
128        Some("cold") => ToolTier::Cold,
129        _ => ToolTier::Warm,
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tier_from_opt_str_defaults_to_warm_on_none() {
        assert_eq!(tier_from_opt_str(None), ToolTier::Warm);
    }

    #[test]
    fn tier_from_opt_str_defaults_to_warm_on_unknown() {
        assert_eq!(tier_from_opt_str(Some("nebulous")), ToolTier::Warm);
    }

    #[test]
    fn tier_from_opt_str_parses_known_values_case_insensitive() {
        assert_eq!(tier_from_opt_str(Some("hot")), ToolTier::Hot);
        assert_eq!(tier_from_opt_str(Some("HOT")), ToolTier::Hot);
        assert_eq!(tier_from_opt_str(Some("cold")), ToolTier::Cold);
        assert_eq!(tier_from_opt_str(Some("Warm")), ToolTier::Warm);
    }

    #[test]
    fn tier_as_str_labels_are_stable() {
        // These strings go on the audit-event wire; changing them is a
        // wire-format change.
        assert_eq!(tier_as_str(ToolTier::Hot), "hot");
        assert_eq!(tier_as_str(ToolTier::Warm), "warm");
        assert_eq!(tier_as_str(ToolTier::Cold), "cold");
    }

    #[test]
    fn tier_as_str_round_trips_through_tier_from_opt_str() {
        assert_eq!(tier_from_opt_str(Some(tier_as_str(ToolTier::Hot))), ToolTier::Hot);
        assert_eq!(tier_from_opt_str(Some(tier_as_str(ToolTier::Warm))), ToolTier::Warm);
        assert_eq!(tier_from_opt_str(Some(tier_as_str(ToolTier::Cold))), ToolTier::Cold);
    }

    #[test]
    fn defaults_match_current_server_warm_budget() {
        // Migration-safety pin: Warm timeout >= 5s and max_output >= 1 MiB so
        // existing tools (which carry no tier) keep working. This test is a
        // tripwire — intentionally relax it only when flipping Warm defaults.
        let p = TierPolicy::defaults();
        assert!(p.warm_timeout >= Duration::from_secs(5));
        assert!(p.warm_max_output >= 1024 * 1024);
    }

    #[test]
    fn timeout_lookup_by_tier() {
        let p = TierPolicy::defaults();
        assert_eq!(p.timeout(ToolTier::Hot), Duration::from_millis(500));
        assert_eq!(p.timeout(ToolTier::Warm), Duration::from_secs(5));
        assert_eq!(p.timeout(ToolTier::Cold), Duration::from_secs(60));
    }

    #[test]
    fn max_output_lookup_by_tier() {
        let p = TierPolicy::defaults();
        assert_eq!(p.max_output(ToolTier::Hot), 64 * 1024);
        assert_eq!(p.max_output(ToolTier::Warm), 1024 * 1024);
        assert_eq!(p.max_output(ToolTier::Cold), 16 * 1024 * 1024);
    }

    #[test]
    fn apply_override_timeout_ms() {
        let mut p = TierPolicy::defaults();
        p.apply_override("hot=timeout_ms=300").unwrap();
        assert_eq!(p.hot_timeout, Duration::from_millis(300));
        // Other tiers untouched.
        assert_eq!(p.warm_timeout, Duration::from_secs(5));
    }

    #[test]
    fn apply_override_max_output_bytes() {
        let mut p = TierPolicy::defaults();
        p.apply_override("cold=max_output_bytes=33554432").unwrap();
        assert_eq!(p.cold_max_output, 33_554_432);
    }

    #[test]
    fn apply_override_later_spec_wins() {
        let mut p = TierPolicy::defaults();
        p.apply_override("warm=timeout_ms=100").unwrap();
        p.apply_override("warm=timeout_ms=200").unwrap();
        assert_eq!(p.warm_timeout, Duration::from_millis(200));
    }

    #[test]
    fn apply_override_rejects_malformed_spec() {
        let mut p = TierPolicy::defaults();
        assert!(p.apply_override("no_equals").is_err());
        assert!(p.apply_override("hot=timeout_ms").is_err());
        assert!(p.apply_override("bogus=timeout_ms=100").is_err());
        assert!(p.apply_override("hot=bogus=100").is_err());
        assert!(p.apply_override("hot=timeout_ms=not_a_number").is_err());
        // The value is everything after the second '=', so an extra '=' ends
        // up in the number and fails to parse.
        assert!(p.apply_override("hot=timeout_ms=1=2").is_err());
    }

    #[test]
    fn apply_override_error_leaves_policy_unchanged() {
        let mut p = TierPolicy::defaults();
        assert!(p.apply_override("cold=timeout_ms=-5").is_err());
        assert!(p.apply_override("cold=max_output_bytes=lots").is_err());
        assert_eq!(p.cold_timeout, Duration::from_secs(60));
        assert_eq!(p.cold_max_output, 16 * 1024 * 1024);
    }
}