1use std::time::Duration;
10
11pub use atd_protocol::ToolTier;
12
13pub fn tier_as_str(tier: ToolTier) -> &'static str {
16 match tier {
17 ToolTier::Hot => "hot",
18 ToolTier::Warm => "warm",
19 ToolTier::Cold => "cold",
20 }
21}
22
23#[derive(Debug, Clone)]
27pub struct TierPolicy {
28 pub hot_timeout: Duration,
29 pub warm_timeout: Duration,
30 pub cold_timeout: Duration,
31 pub hot_max_output: usize,
32 pub warm_max_output: usize,
33 pub cold_max_output: usize,
34}
35
36impl TierPolicy {
37 pub fn defaults() -> Self {
42 Self {
43 hot_timeout: Duration::from_millis(500),
44 warm_timeout: Duration::from_secs(5),
45 cold_timeout: Duration::from_secs(60),
46 hot_max_output: 64 * 1024,
47 warm_max_output: 1024 * 1024,
48 cold_max_output: 16 * 1024 * 1024,
49 }
50 }
51
52 pub fn timeout(&self, tier: ToolTier) -> Duration {
53 match tier {
54 ToolTier::Hot => self.hot_timeout,
55 ToolTier::Warm => self.warm_timeout,
56 ToolTier::Cold => self.cold_timeout,
57 }
58 }
59
60 pub fn max_output(&self, tier: ToolTier) -> usize {
61 match tier {
62 ToolTier::Hot => self.hot_max_output,
63 ToolTier::Warm => self.warm_max_output,
64 ToolTier::Cold => self.cold_max_output,
65 }
66 }
67
68 pub fn apply_override(&mut self, spec: &str) -> Result<(), String> {
72 let parts: Vec<&str> = spec.splitn(3, '=').collect();
73 if parts.len() != 3 {
74 return Err(format!("expected '<tier>=<key>=<value>', got '{spec}'"));
75 }
76 let tier = match parts[0] {
77 "hot" => ToolTier::Hot,
78 "warm" => ToolTier::Warm,
79 "cold" => ToolTier::Cold,
80 other => return Err(format!("unknown tier '{other}' (want hot|warm|cold)")),
81 };
82 match parts[1] {
83 "timeout_ms" => {
84 let v: u64 = parts[2]
85 .parse()
86 .map_err(|e| format!("invalid timeout_ms '{}': {e}", parts[2]))?;
87 let d = Duration::from_millis(v);
88 match tier {
89 ToolTier::Hot => self.hot_timeout = d,
90 ToolTier::Warm => self.warm_timeout = d,
91 ToolTier::Cold => self.cold_timeout = d,
92 }
93 }
94 "max_output_bytes" => {
95 let v: usize = parts[2]
96 .parse()
97 .map_err(|e| format!("invalid max_output_bytes '{}': {e}", parts[2]))?;
98 match tier {
99 ToolTier::Hot => self.hot_max_output = v,
100 ToolTier::Warm => self.warm_max_output = v,
101 ToolTier::Cold => self.cold_max_output = v,
102 }
103 }
104 other => {
105 return Err(format!(
106 "unknown key '{other}' (want timeout_ms|max_output_bytes)"
107 ));
108 }
109 }
110 Ok(())
111 }
112}
113
114impl Default for TierPolicy {
115 fn default() -> Self {
116 Self::defaults()
117 }
118}
119
120pub fn tier_from_opt_str(s: Option<&str>) -> ToolTier {
126 match s.map(|v| v.to_ascii_lowercase()).as_deref() {
127 Some("hot") => ToolTier::Hot,
128 Some("cold") => ToolTier::Cold,
129 _ => ToolTier::Warm,
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tier_from_opt_str_defaults_to_warm_on_none() {
        assert_eq!(tier_from_opt_str(None), ToolTier::Warm);
    }

    #[test]
    fn tier_from_opt_str_defaults_to_warm_on_unknown() {
        assert_eq!(tier_from_opt_str(Some("nebulous")), ToolTier::Warm);
    }

    #[test]
    fn tier_from_opt_str_parses_known_values_case_insensitive() {
        assert_eq!(tier_from_opt_str(Some("hot")), ToolTier::Hot);
        assert_eq!(tier_from_opt_str(Some("HOT")), ToolTier::Hot);
        assert_eq!(tier_from_opt_str(Some("cold")), ToolTier::Cold);
        assert_eq!(tier_from_opt_str(Some("Warm")), ToolTier::Warm);
    }

    #[test]
    fn tier_as_str_names_each_tier() {
        assert_eq!(tier_as_str(ToolTier::Hot), "hot");
        assert_eq!(tier_as_str(ToolTier::Warm), "warm");
        assert_eq!(tier_as_str(ToolTier::Cold), "cold");
    }

    #[test]
    fn tier_as_str_round_trips_through_tier_from_opt_str() {
        assert_eq!(tier_from_opt_str(Some(tier_as_str(ToolTier::Hot))), ToolTier::Hot);
        assert_eq!(tier_from_opt_str(Some(tier_as_str(ToolTier::Warm))), ToolTier::Warm);
        assert_eq!(tier_from_opt_str(Some(tier_as_str(ToolTier::Cold))), ToolTier::Cold);
    }

    #[test]
    fn defaults_match_current_server_warm_budget() {
        let p = TierPolicy::defaults();
        // Lower bounds (>=) rather than exact values: per the test name, the
        // warm defaults must not fall below the existing server budget.
        assert!(p.warm_timeout >= Duration::from_secs(5));
        assert!(p.warm_max_output >= 1024 * 1024);
    }

    #[test]
    fn default_trait_matches_defaults() {
        assert_eq!(
            format!("{:?}", TierPolicy::default()),
            format!("{:?}", TierPolicy::defaults())
        );
    }

    #[test]
    fn timeout_lookup_by_tier() {
        let p = TierPolicy::defaults();
        assert_eq!(p.timeout(ToolTier::Hot), Duration::from_millis(500));
        assert_eq!(p.timeout(ToolTier::Warm), Duration::from_secs(5));
        assert_eq!(p.timeout(ToolTier::Cold), Duration::from_secs(60));
    }

    #[test]
    fn max_output_lookup_by_tier() {
        let p = TierPolicy::defaults();
        assert_eq!(p.max_output(ToolTier::Hot), 64 * 1024);
        assert_eq!(p.max_output(ToolTier::Warm), 1024 * 1024);
        assert_eq!(p.max_output(ToolTier::Cold), 16 * 1024 * 1024);
    }

    #[test]
    fn apply_override_timeout_ms() {
        let mut p = TierPolicy::defaults();
        p.apply_override("hot=timeout_ms=300").unwrap();
        assert_eq!(p.hot_timeout, Duration::from_millis(300));
        // Other tiers are untouched.
        assert_eq!(p.warm_timeout, Duration::from_secs(5));
        assert_eq!(p.cold_timeout, Duration::from_secs(60));
    }

    #[test]
    fn apply_override_max_output_bytes() {
        let mut p = TierPolicy::defaults();
        p.apply_override("cold=max_output_bytes=33554432").unwrap();
        assert_eq!(p.cold_max_output, 33_554_432);
    }

    #[test]
    fn apply_override_rejects_malformed_spec() {
        let mut p = TierPolicy::defaults();
        assert!(p.apply_override("no_equals").is_err());
        assert!(p.apply_override("hot=timeout_ms").is_err());
        assert!(p.apply_override("bogus=timeout_ms=100").is_err());
        assert!(p.apply_override("hot=bogus=100").is_err());
        assert!(p.apply_override("hot=timeout_ms=not_a_number").is_err());
        assert!(p.apply_override("hot=timeout_ms=").is_err());
        assert!(p.apply_override("hot=max_output_bytes=-1").is_err());
        // A third '=' ends up in the value, which then fails to parse.
        assert!(p.apply_override("hot=timeout_ms=1=2").is_err());
    }

    #[test]
    fn apply_override_failure_leaves_policy_unchanged() {
        let mut p = TierPolicy::defaults();
        assert!(p.apply_override("hot=timeout_ms=not_a_number").is_err());
        assert!(p.apply_override("warm=max_output_bytes=lots").is_err());
        assert_eq!(format!("{p:?}"), format!("{:?}", TierPolicy::defaults()));
    }
}