1use std::sync::Arc;
15
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18
19pub use crate::agents::{AgentDrain, AgentRole, DrainMode};
20
21mod backoff;
22mod transitions;
23
24#[derive(Clone, Serialize, Debug)]
38pub struct AgentChainState {
39 pub agents: Arc<[String]>,
42 pub current_agent_index: usize,
43 pub models_per_agent: Arc<[Vec<String>]>,
46 pub current_model_index: usize,
47 pub retry_cycle: u32,
48 pub max_cycles: u32,
49 #[serde(default = "default_retry_delay_ms")]
51 pub retry_delay_ms: u64,
52 #[serde(default = "default_backoff_multiplier")]
54 pub backoff_multiplier: f64,
55 #[serde(default = "default_max_backoff_ms")]
57 pub max_backoff_ms: u64,
58 #[serde(default)]
60 pub backoff_pending_ms: Option<u64>,
61 pub current_role: AgentRole,
67 #[serde(default = "default_current_drain")]
68 pub current_drain: AgentDrain,
69 #[serde(default)]
70 pub current_mode: DrainMode,
71 #[serde(default)]
79 pub rate_limit_continuation_prompt: Option<RateLimitContinuationPrompt>,
80 #[serde(default)]
86 pub last_session_id: Option<String>,
87}
88
89#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
91pub struct RateLimitContinuationPrompt {
92 pub drain: AgentDrain,
93 pub role: AgentRole,
94 pub prompt: String,
95}
96
97#[derive(Deserialize)]
98#[serde(untagged)]
99enum RateLimitContinuationPromptRepr {
100 LegacyString(String),
101 Structured {
102 #[serde(rename = "role")]
103 _role: AgentRole,
104 #[serde(default)]
105 drain: Option<AgentDrain>,
106 prompt: String,
107 },
108}
109
110fn infer_legacy_current_drain(
111 current_drain: Option<AgentDrain>,
112 current_role: Option<AgentRole>,
113 current_mode: DrainMode,
114 continuation_prompt: Option<&RateLimitContinuationPromptRepr>,
115) -> AgentDrain {
116 if let Some(current_drain) = current_drain {
117 return current_drain;
118 }
119
120 if let Some(prompt_drain) = continuation_prompt.and_then(|prompt| match prompt {
121 RateLimitContinuationPromptRepr::LegacyString(_) => None,
122 RateLimitContinuationPromptRepr::Structured { drain, .. } => *drain,
123 }) {
124 return prompt_drain;
125 }
126
127 match (current_role, current_mode) {
128 (Some(AgentRole::Reviewer), DrainMode::Continuation) => AgentDrain::Fix,
129 (Some(AgentRole::Developer), DrainMode::Continuation) => AgentDrain::Development,
130 (Some(current_role), _) => AgentDrain::from(current_role),
131 (None, _) => default_current_drain(),
132 }
133}
134
135impl<'de> Deserialize<'de> for AgentChainState {
136 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
137 where
138 D: serde::Deserializer<'de>,
139 {
140 #[derive(Deserialize)]
141 struct AgentChainStateSerde {
142 agents: Arc<[String]>,
143 current_agent_index: usize,
144 models_per_agent: Arc<[Vec<String>]>,
145 current_model_index: usize,
146 retry_cycle: u32,
147 max_cycles: u32,
148 #[serde(default = "default_retry_delay_ms")]
149 retry_delay_ms: u64,
150 #[serde(default = "default_backoff_multiplier")]
151 backoff_multiplier: f64,
152 #[serde(default = "default_max_backoff_ms")]
153 max_backoff_ms: u64,
154 #[serde(default)]
155 backoff_pending_ms: Option<u64>,
156 #[serde(default)]
157 current_drain: Option<AgentDrain>,
158 #[serde(default)]
159 current_role: Option<AgentRole>,
160 #[serde(default)]
161 current_mode: DrainMode,
162 #[serde(default)]
163 rate_limit_continuation_prompt: Option<RateLimitContinuationPromptRepr>,
164 #[serde(default)]
165 last_session_id: Option<String>,
166 }
167
168 let raw = AgentChainStateSerde::deserialize(deserializer)?;
169 let current_drain = infer_legacy_current_drain(
170 raw.current_drain,
171 raw.current_role,
172 raw.current_mode,
173 raw.rate_limit_continuation_prompt.as_ref(),
174 );
175 let current_role = current_drain.role();
176
177 let rate_limit_continuation_prompt = raw.rate_limit_continuation_prompt.map(|repr| {
178 match repr {
179 RateLimitContinuationPromptRepr::LegacyString(prompt) => {
180 RateLimitContinuationPrompt {
183 drain: current_drain,
184 role: current_role,
185 prompt,
186 }
187 }
188 RateLimitContinuationPromptRepr::Structured {
189 _role: _,
190 drain,
191 prompt,
192 } => {
193 let prompt_drain = drain.unwrap_or(current_drain);
194 RateLimitContinuationPrompt {
195 drain: prompt_drain,
196 role: prompt_drain.role(),
197 prompt,
198 }
199 }
200 }
201 });
202
203 Ok(Self {
204 agents: raw.agents,
205 current_agent_index: raw.current_agent_index,
206 models_per_agent: raw.models_per_agent,
207 current_model_index: raw.current_model_index,
208 retry_cycle: raw.retry_cycle,
209 max_cycles: raw.max_cycles,
210 retry_delay_ms: raw.retry_delay_ms,
211 backoff_multiplier: raw.backoff_multiplier,
212 max_backoff_ms: raw.max_backoff_ms,
213 backoff_pending_ms: raw.backoff_pending_ms,
214 current_role,
215 current_drain,
216 current_mode: raw.current_mode,
217 rate_limit_continuation_prompt,
218 last_session_id: raw.last_session_id,
219 })
220 }
221}
222
/// Retry delay in milliseconds used by `AgentChainState::initial` and when a
/// serialized state omits `retry_delay_ms`.
const fn default_retry_delay_ms() -> u64 {
    1_000
}
226
/// Backoff multiplier used by `AgentChainState::initial` and when a
/// serialized state omits `backoff_multiplier`.
const fn default_backoff_multiplier() -> f64 {
    2.0_f64
}
230
/// Maximum backoff in milliseconds used by `AgentChainState::initial` and
/// when a serialized state omits `max_backoff_ms`.
const fn default_max_backoff_ms() -> u64 {
    60_000
}
234
235const fn default_current_drain() -> AgentDrain {
236 AgentDrain::Planning
237}
238
239const fn agent_drain_signature_tag(drain: AgentDrain) -> &'static [u8] {
240 match drain {
241 AgentDrain::Planning => b"planning\n",
242 AgentDrain::Development => b"development\n",
243 AgentDrain::Review => b"review\n",
244 AgentDrain::Fix => b"fix\n",
245 AgentDrain::Commit => b"commit\n",
246 AgentDrain::Analysis => b"analysis\n",
247 }
248}
249
250impl AgentChainState {
251 #[must_use]
252 pub fn initial() -> Self {
253 Self {
254 agents: Arc::from(vec![]),
255 current_agent_index: 0,
256 models_per_agent: Arc::from(vec![]),
257 current_model_index: 0,
258 retry_cycle: 0,
259 max_cycles: 3,
260 retry_delay_ms: default_retry_delay_ms(),
261 backoff_multiplier: default_backoff_multiplier(),
262 max_backoff_ms: default_max_backoff_ms(),
263 backoff_pending_ms: None,
264 current_role: AgentRole::Developer,
265 current_drain: default_current_drain(),
266 current_mode: DrainMode::Normal,
267 rate_limit_continuation_prompt: None,
268 last_session_id: None,
269 }
270 }
271
272 #[must_use]
273 pub fn matches_runtime_drain(&self, runtime_drain: AgentDrain) -> bool {
274 self.current_drain == runtime_drain
275 }
276
277 #[must_use]
278 pub fn with_agents(
279 mut self,
280 agents: Vec<String>,
281 models_per_agent: Vec<Vec<String>>,
282 role: AgentRole,
283 ) -> Self {
284 self.agents = Arc::from(agents);
285 self.models_per_agent = Arc::from(models_per_agent);
286 self.current_role = role;
287 self.current_drain = match role {
288 AgentRole::Developer => AgentDrain::Development,
289 AgentRole::Reviewer => AgentDrain::Review,
290 AgentRole::Commit => AgentDrain::Commit,
291 AgentRole::Analysis => AgentDrain::Analysis,
292 };
293 self.current_mode = DrainMode::Normal;
294 self
295 }
296
297 #[must_use]
298 pub const fn with_drain(mut self, drain: AgentDrain) -> Self {
299 self.current_drain = drain;
300 self.current_role = drain.role();
301 self
302 }
303
304 #[must_use]
305 pub const fn with_mode(mut self, mode: DrainMode) -> Self {
306 self.current_mode = mode;
307 self
308 }
309
310 #[must_use]
311 pub const fn active_role(&self) -> AgentRole {
312 self.current_drain.role()
313 }
314
315 #[must_use]
320 pub const fn with_max_cycles(mut self, max_cycles: u32) -> Self {
321 self.max_cycles = max_cycles;
322 self
323 }
324
325 #[must_use]
326 pub const fn with_backoff_policy(
327 mut self,
328 retry_delay_ms: u64,
329 backoff_multiplier: f64,
330 max_backoff_ms: u64,
331 ) -> Self {
332 self.retry_delay_ms = retry_delay_ms;
333 self.backoff_multiplier = backoff_multiplier;
334 self.max_backoff_ms = max_backoff_ms;
335 self
336 }
337
338 #[must_use]
339 pub fn current_agent(&self) -> Option<&String> {
340 self.agents.get(self.current_agent_index)
341 }
342
343 #[must_use]
352 pub fn consumer_signature_sha256(&self) -> String {
353 let mut pairs: Vec<(&str, &[String])> = self
354 .agents
355 .iter()
356 .enumerate()
357 .map(|(idx, agent)| {
358 let models: &[String] = self
359 .models_per_agent
360 .get(idx)
361 .map_or([].as_slice(), std::vec::Vec::as_slice);
362 (agent.as_str(), models)
363 })
364 .collect();
365
366 pairs.sort_by(|(agent_a, models_a), (agent_b, models_b)| {
369 use std::cmp::Ordering;
370
371 let agent_ord = agent_a.cmp(agent_b);
372 if agent_ord != Ordering::Equal {
373 return agent_ord;
374 }
375
376 for (a, b) in models_a.iter().zip(models_b.iter()) {
377 let ord = a.cmp(b);
378 if ord != Ordering::Equal {
379 return ord;
380 }
381 }
382
383 models_a.len().cmp(&models_b.len())
384 });
385
386 let mut hasher = Sha256::new();
387 hasher.update(agent_drain_signature_tag(self.current_drain));
388 for (agent, models) in pairs {
389 hasher.update(agent.as_bytes());
390 hasher.update(b"|");
391 for (idx, model) in models.iter().enumerate() {
392 if idx > 0 {
393 hasher.update(b",");
394 }
395 hasher.update(model.as_bytes());
396 }
397 hasher.update(b"\n");
398 }
399 let digest = hasher.finalize();
400 digest.iter().fold(String::new(), |mut s, b| {
401 use std::fmt::Write;
402 write!(&mut s, "{b:02x}").unwrap();
403 s
404 })
405 }
406
407 #[cfg(test)]
408 fn legacy_consumer_signature_sha256_for_test(&self) -> String {
409 let mut rendered: Vec<String> = self
410 .agents
411 .iter()
412 .enumerate()
413 .map(|(idx, agent)| {
414 let models = self
415 .models_per_agent
416 .get(idx)
417 .map_or([].as_slice(), std::vec::Vec::as_slice);
418 format!("{}|{}", agent, models.join(","))
419 })
420 .collect();
421
422 rendered.sort();
423
424 let mut hasher = Sha256::new();
425 hasher.update(agent_drain_signature_tag(self.current_drain));
426 for line in rendered {
427 hasher.update(line.as_bytes());
428 hasher.update(b"\n");
429 }
430 let digest = hasher.finalize();
431 digest.iter().fold(String::new(), |mut s, b| {
432 use std::fmt::Write;
433 write!(&mut s, "{b:02x}").unwrap();
434 s
435 })
436 }
437
438 #[must_use]
445 pub fn current_model(&self) -> Option<&String> {
446 self.models_per_agent
447 .get(self.current_agent_index)
448 .and_then(|models| models.get(self.current_model_index))
449 }
450
451 #[must_use]
452 pub const fn is_exhausted(&self) -> bool {
453 self.retry_cycle >= self.max_cycles
454 && self.current_agent_index == 0
455 && self.current_model_index == 0
456 }
457}
458
#[cfg(test)]
mod consumer_signature_tests {
    use super::*;

    /// Two agents share a name, so their relative order is decided purely by
    /// the model lists; the current sort must agree with the legacy
    /// rendered-string sort on that case.
    #[test]
    fn test_consumer_signature_sorting_matches_legacy_rendered_pair_ordering() {
        let agents = vec![String::from("agent"), String::from("agent")];
        let models = vec![
            vec![String::from("b")],
            vec![String::from("a"), String::from("z")],
        ];
        let state = AgentChainState::initial().with_agents(agents, models, AgentRole::Developer);

        let current = state.consumer_signature_sha256();
        let legacy = state.legacy_consumer_signature_sha256_for_test();

        assert_eq!(
            current, legacy,
            "consumer signature ordering must remain stable for the same configured consumers"
        );
    }

    /// Pins the exact byte stream that is hashed for a single agent with two
    /// models under the `Fix` drain.
    #[test]
    fn test_consumer_signature_uses_stable_drain_encoding() {
        let state = AgentChainState::initial()
            .with_agents(
                vec![String::from("agent-a")],
                vec![vec![String::from("m1"), String::from("m2")]],
                AgentRole::Reviewer,
            )
            .with_drain(AgentDrain::Fix);

        // Drain tag, then `agent|model,model` terminated by a newline.
        let chunks: [&[u8]; 7] = [b"fix\n", b"agent-a", b"|", b"m1", b",", b"m2", b"\n"];
        let mut hasher = Sha256::new();
        for chunk in chunks {
            hasher.update(chunk);
        }
        let expected: String = hasher
            .finalize()
            .iter()
            .map(|byte| format!("{byte:02x}"))
            .collect();

        assert_eq!(
            state.consumer_signature_sha256(),
            expected,
            "role encoding must be stable and explicit"
        );
    }
}
520
#[cfg(test)]
mod legacy_rate_limit_prompt_tests {
    use super::*;

    /// A prompt stored as a bare string must come back tagged with the drain
    /// and role of the state it was stored in.
    #[test]
    fn test_legacy_rate_limit_continuation_prompt_uses_current_role_on_deserialize() {
        let reviewer_state = AgentChainState::initial().with_agents(
            vec![String::from("a")],
            vec![Vec::new()],
            AgentRole::Reviewer,
        );

        // Serialize the state, then overwrite the prompt with the legacy
        // string-only shape.
        let mut payload =
            serde_json::to_value(&reviewer_state).expect("serialize AgentChainState");
        payload["rate_limit_continuation_prompt"] =
            serde_json::Value::String("legacy prompt".into());

        let encoded = serde_json::to_string(&payload).expect("serialize JSON value");
        let decoded: AgentChainState =
            serde_json::from_str(&encoded).expect("deserialize AgentChainState");

        let RateLimitContinuationPrompt {
            drain,
            role,
            prompt,
        } = decoded
            .rate_limit_continuation_prompt
            .expect("expected legacy prompt to deserialize");

        assert_eq!(drain, AgentDrain::Review);
        assert_eq!(role, AgentRole::Reviewer);
        assert_eq!(prompt, "legacy prompt");
    }
}