1use std::sync::Arc;
15
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18
19pub use crate::agents::{AgentDrain, AgentRole, DrainMode};
20
21mod backoff;
22mod transitions;
23
/// Persisted state of the agent fallback chain: the configured agents and
/// their models, which agent/model is selected, retry/backoff bookkeeping,
/// and the drain/mode the chain is currently serving.
///
/// Only `Serialize` is derived here. `Deserialize` is implemented by hand
/// (through a private mirror struct) so that checkpoints written before
/// `current_drain` existed can still be loaded. Consequently the
/// `#[serde(default ...)]` attributes on this struct are not what supplies
/// defaults on deserialization; the mirror struct carries its own copies.
#[derive(Clone, Serialize, Debug)]
pub struct AgentChainState {
    /// Configured agent names; `current_agent_index` indexes into this list.
    pub agents: Arc<[String]>,
    /// Index of the selected entry in `agents`. Not bounds-checked here;
    /// `current_agent()` returns `None` when it is out of range.
    pub current_agent_index: usize,
    /// Model lists parallel to `agents`: entry `i` holds the models for
    /// `agents[i]`.
    pub models_per_agent: Arc<[Vec<String>]>,
    /// Index of the selected model within
    /// `models_per_agent[current_agent_index]`.
    pub current_model_index: usize,
    /// Retry-cycle counter; `is_exhausted()` compares it against `max_cycles`.
    pub retry_cycle: u32,
    /// Cycle budget checked by `is_exhausted()`.
    pub max_cycles: u32,
    /// Retry delay in milliseconds; defaults to `default_retry_delay_ms()`.
    #[serde(default = "default_retry_delay_ms")]
    pub retry_delay_ms: u64,
    /// Backoff growth factor; defaults to `default_backoff_multiplier()`.
    #[serde(default = "default_backoff_multiplier")]
    pub backoff_multiplier: f64,
    /// Upper bound on the backoff delay in milliseconds; defaults to
    /// `default_max_backoff_ms()`.
    #[serde(default = "default_max_backoff_ms")]
    pub max_backoff_ms: u64,
    /// A backoff delay that is waiting to be applied, if any.
    /// NOTE(review): presumably set/cleared by the `backoff`/`transitions`
    /// submodules, which are not visible here — confirm.
    #[serde(default)]
    pub backoff_pending_ms: Option<u64>,
    /// Stored role. `with_drain` and the custom deserializer set it to
    /// `current_drain.role()`; `active_role()` ignores this field and derives
    /// the role from `current_drain` instead.
    pub current_role: AgentRole,
    /// Drain the chain is currently serving; it is also mixed into
    /// `consumer_signature_sha256()`.
    #[serde(default = "default_current_drain")]
    pub current_drain: AgentDrain,
    /// Drain mode (e.g. `Normal` or `Continuation`).
    #[serde(default)]
    pub current_mode: DrainMode,
    /// Prompt to resume with after a rate limit, tagged with its drain and
    /// role. Legacy checkpoints stored a bare string here; the custom
    /// deserializer upgrades that to the structured form.
    #[serde(default)]
    pub rate_limit_continuation_prompt: Option<RateLimitContinuationPrompt>,
    /// Session identifier recorded for the chain.
    /// NOTE(review): presumably the most recent agent session — set outside
    /// this file; confirm.
    #[serde(default)]
    pub last_session_id: Option<String>,
    /// Failure reason recorded for the chain.
    /// NOTE(review): presumably the most recent failure — set outside this
    /// file; confirm.
    #[serde(default)]
    pub last_failure_reason: Option<String>,
}
94
/// A continuation prompt saved after a rate limit, together with the drain
/// and role it was produced for.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct RateLimitContinuationPrompt {
    /// Drain the prompt belongs to.
    pub drain: AgentDrain,
    /// Role for `drain`. When loaded through `AgentChainState`'s custom
    /// deserializer this is always recomputed as `drain.role()`.
    pub role: AgentRole,
    /// The prompt text itself.
    pub prompt: String,
}
102
/// On-disk shapes accepted for `AgentChainState::rate_limit_continuation_prompt`.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// first a bare string (the legacy format), then the structured object.
#[derive(Deserialize)]
#[serde(untagged)]
enum RateLimitContinuationPromptRepr {
    /// Legacy format: only the prompt text, with no drain or role attached.
    LegacyString(String),
    /// Structured format. The `role` key must be present but its value is
    /// discarded (the role is recomputed from the drain). `drain` may be
    /// missing in older data, in which case the state's drain is used.
    Structured {
        #[serde(rename = "role")]
        _role: AgentRole,
        #[serde(default)]
        drain: Option<AgentDrain>,
        prompt: String,
    },
}
115
116fn infer_legacy_current_drain(
117 current_drain: Option<AgentDrain>,
118 current_role: Option<AgentRole>,
119 current_mode: DrainMode,
120 continuation_prompt: Option<&RateLimitContinuationPromptRepr>,
121) -> AgentDrain {
122 if let Some(current_drain) = current_drain {
123 return current_drain;
124 }
125
126 if let Some(prompt_drain) = continuation_prompt.and_then(|prompt| match prompt {
127 RateLimitContinuationPromptRepr::LegacyString(_) => None,
128 RateLimitContinuationPromptRepr::Structured { drain, .. } => *drain,
129 }) {
130 return prompt_drain;
131 }
132
133 match (current_role, current_mode) {
134 (Some(AgentRole::Reviewer), DrainMode::Continuation) => AgentDrain::Fix,
135 (Some(AgentRole::Developer), DrainMode::Continuation) => AgentDrain::Development,
136 (Some(current_role), _) => AgentDrain::from(current_role),
137 (None, _) => default_current_drain(),
138 }
139}
140
141impl<'de> Deserialize<'de> for AgentChainState {
142 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
143 where
144 D: serde::Deserializer<'de>,
145 {
146 #[derive(Deserialize)]
147 struct AgentChainStateSerde {
148 agents: Arc<[String]>,
149 current_agent_index: usize,
150 models_per_agent: Arc<[Vec<String>]>,
151 current_model_index: usize,
152 retry_cycle: u32,
153 max_cycles: u32,
154 #[serde(default = "default_retry_delay_ms")]
155 retry_delay_ms: u64,
156 #[serde(default = "default_backoff_multiplier")]
157 backoff_multiplier: f64,
158 #[serde(default = "default_max_backoff_ms")]
159 max_backoff_ms: u64,
160 #[serde(default)]
161 backoff_pending_ms: Option<u64>,
162 #[serde(default)]
163 current_drain: Option<AgentDrain>,
164 #[serde(default)]
165 current_role: Option<AgentRole>,
166 #[serde(default)]
167 current_mode: DrainMode,
168 #[serde(default)]
169 rate_limit_continuation_prompt: Option<RateLimitContinuationPromptRepr>,
170 #[serde(default)]
171 last_session_id: Option<String>,
172 #[serde(default)]
173 last_failure_reason: Option<String>,
174 }
175
176 let raw = AgentChainStateSerde::deserialize(deserializer)?;
177 let current_drain = infer_legacy_current_drain(
178 raw.current_drain,
179 raw.current_role,
180 raw.current_mode,
181 raw.rate_limit_continuation_prompt.as_ref(),
182 );
183 let current_role = current_drain.role();
184
185 let rate_limit_continuation_prompt = raw.rate_limit_continuation_prompt.map(|repr| {
186 match repr {
187 RateLimitContinuationPromptRepr::LegacyString(prompt) => {
188 RateLimitContinuationPrompt {
191 drain: current_drain,
192 role: current_role,
193 prompt,
194 }
195 }
196 RateLimitContinuationPromptRepr::Structured {
197 _role: _,
198 drain,
199 prompt,
200 } => {
201 let prompt_drain = drain.unwrap_or(current_drain);
202 RateLimitContinuationPrompt {
203 drain: prompt_drain,
204 role: prompt_drain.role(),
205 prompt,
206 }
207 }
208 }
209 });
210
211 Ok(Self {
212 agents: raw.agents,
213 current_agent_index: raw.current_agent_index,
214 models_per_agent: raw.models_per_agent,
215 current_model_index: raw.current_model_index,
216 retry_cycle: raw.retry_cycle,
217 max_cycles: raw.max_cycles,
218 retry_delay_ms: raw.retry_delay_ms,
219 backoff_multiplier: raw.backoff_multiplier,
220 max_backoff_ms: raw.max_backoff_ms,
221 backoff_pending_ms: raw.backoff_pending_ms,
222 current_role,
223 current_drain,
224 current_mode: raw.current_mode,
225 rate_limit_continuation_prompt,
226 last_session_id: raw.last_session_id,
227 last_failure_reason: raw.last_failure_reason,
228 })
229 }
230}
231
232impl Default for AgentChainState {
233 fn default() -> Self {
234 Self {
235 agents: Arc::from(vec![]),
236 current_agent_index: 0,
237 models_per_agent: Arc::from(vec![]),
238 current_model_index: 0,
239 retry_cycle: 0,
240 max_cycles: 3,
241 retry_delay_ms: default_retry_delay_ms(),
242 backoff_multiplier: default_backoff_multiplier(),
243 max_backoff_ms: default_max_backoff_ms(),
244 backoff_pending_ms: None,
245 current_role: AgentRole::Developer,
246 current_drain: default_current_drain(),
247 current_mode: DrainMode::Normal,
248 rate_limit_continuation_prompt: None,
249 last_session_id: None,
250 last_failure_reason: None,
251 }
252 }
253}
254
/// Default retry delay, in milliseconds (one second).
const fn default_retry_delay_ms() -> u64 {
    1_000
}
258
/// Default backoff growth factor: each step doubles the delay.
const fn default_backoff_multiplier() -> f64 {
    2.0_f64
}
262
/// Default ceiling for the backoff delay, in milliseconds (one minute).
const fn default_max_backoff_ms() -> u64 {
    60_000
}
266
/// Drain used when none is stored or inferable: `AgentDrain::Planning`.
const fn default_current_drain() -> AgentDrain {
    AgentDrain::Planning
}
270
271const fn agent_drain_signature_tag(drain: AgentDrain) -> &'static [u8] {
272 match drain {
273 AgentDrain::Planning => b"planning\n",
274 AgentDrain::Development => b"development\n",
275 AgentDrain::Review => b"review\n",
276 AgentDrain::Fix => b"fix\n",
277 AgentDrain::Commit => b"commit\n",
278 AgentDrain::Analysis => b"analysis\n",
279 }
280}
281
282impl AgentChainState {
283 #[must_use]
284 pub fn initial() -> Self {
285 Self {
286 agents: Arc::from(vec![]),
287 current_agent_index: 0,
288 models_per_agent: Arc::from(vec![]),
289 current_model_index: 0,
290 retry_cycle: 0,
291 max_cycles: 3,
292 retry_delay_ms: default_retry_delay_ms(),
293 backoff_multiplier: default_backoff_multiplier(),
294 max_backoff_ms: default_max_backoff_ms(),
295 backoff_pending_ms: None,
296 current_role: AgentRole::Developer,
297 current_drain: default_current_drain(),
298 current_mode: DrainMode::Normal,
299 rate_limit_continuation_prompt: None,
300 last_session_id: None,
301 last_failure_reason: None,
302 }
303 }
304
305 #[must_use]
306 pub fn matches_runtime_drain(&self, runtime_drain: AgentDrain) -> bool {
307 self.current_drain == runtime_drain
308 }
309
310 #[must_use]
311 pub fn with_agents(
312 self,
313 agents: Vec<String>,
314 models_per_agent: Vec<Vec<String>>,
315 role: AgentRole,
316 ) -> Self {
317 let current_drain = match role {
318 AgentRole::Developer => AgentDrain::Development,
319 AgentRole::Reviewer => AgentDrain::Review,
320 AgentRole::Commit => AgentDrain::Commit,
321 AgentRole::Analysis => AgentDrain::Analysis,
322 };
323 Self {
324 agents: Arc::from(agents),
325 models_per_agent: Arc::from(models_per_agent),
326 current_role: role,
327 current_drain,
328 current_mode: DrainMode::Normal,
329 ..self
330 }
331 }
332
333 #[must_use]
334 pub fn with_drain(self, drain: AgentDrain) -> Self {
335 Self {
336 current_drain: drain,
337 current_role: drain.role(),
338 ..self
339 }
340 }
341
342 #[must_use]
343 pub fn with_mode(self, mode: DrainMode) -> Self {
344 Self {
345 current_mode: mode,
346 ..self
347 }
348 }
349
350 #[must_use]
351 pub const fn active_role(&self) -> AgentRole {
352 self.current_drain.role()
353 }
354
355 #[must_use]
360 pub fn with_max_cycles(self, max_cycles: u32) -> Self {
361 Self { max_cycles, ..self }
362 }
363
364 #[must_use]
365 pub fn with_backoff_policy(
366 self,
367 retry_delay_ms: u64,
368 backoff_multiplier: f64,
369 max_backoff_ms: u64,
370 ) -> Self {
371 Self {
372 retry_delay_ms,
373 backoff_multiplier,
374 max_backoff_ms,
375 ..self
376 }
377 }
378
379 #[must_use]
380 pub fn with_retry_cycle(self, retry_cycle: u32) -> Self {
381 Self {
382 retry_cycle,
383 ..self
384 }
385 }
386
387 #[must_use]
388 pub fn with_current_agent_index(self, current_agent_index: usize) -> Self {
389 Self {
390 current_agent_index,
391 ..self
392 }
393 }
394
395 #[must_use]
396 pub fn current_agent(&self) -> Option<&String> {
397 self.agents.get(self.current_agent_index)
398 }
399
400 #[must_use]
409 pub fn consumer_signature_sha256(&self) -> String {
410 use itertools::Itertools;
411
412 let sorted_pairs: Vec<(String, Vec<String>)> = self
413 .agents
414 .iter()
415 .enumerate()
416 .map(|(idx, agent)| {
417 let models: Vec<String> = self
418 .models_per_agent
419 .get(idx)
420 .map_or_else(Vec::new, |m| m.clone());
421 (agent.clone(), models)
422 })
423 .sorted_by_key(|(agent, models)| (agent.clone(), models.clone()))
424 .collect();
425
426 let update_chain: Vec<Vec<u8>> = sorted_pairs
427 .iter()
428 .map(|(agent, models)| {
429 let models_bytes: Vec<u8> = models
430 .iter()
431 .map(|m| m.as_bytes())
432 .collect::<Vec<_>>()
433 .join(&b',');
434 let line: Vec<u8> = std::iter::empty()
435 .chain(agent.as_bytes().iter().copied())
436 .chain([b'|'])
437 .chain(models_bytes.iter().copied())
438 .chain([b'\n'])
439 .collect();
440 line
441 })
442 .collect();
443
444 let hasher = update_chain.iter().fold(
445 Digest::chain_update(Sha256::new(), agent_drain_signature_tag(self.current_drain)),
446 |h, chunk| Digest::chain_update(h, chunk.as_slice()),
447 );
448 let digest = hasher.finalize();
449 digest
450 .iter()
451 .map(|b| format!("{b:02x}"))
452 .collect::<String>()
453 }
454
455 #[cfg(test)]
456 fn legacy_consumer_signature_sha256_for_test(&self) -> String {
457 use itertools::Itertools;
458
459 let rendered: Vec<String> = self
460 .agents
461 .iter()
462 .enumerate()
463 .map(|(idx, agent)| {
464 let models = self
465 .models_per_agent
466 .get(idx)
467 .map_or([].as_slice(), std::vec::Vec::as_slice);
468 format!(
469 "{}|{}",
470 agent,
471 models
472 .iter()
473 .map(|s| s.as_str())
474 .collect::<Vec<_>>()
475 .join(",")
476 )
477 })
478 .sorted()
479 .collect();
480
481 let update_chain: Vec<&[u8]> = rendered
482 .iter()
483 .flat_map(|line| [line.as_bytes(), b"\n"])
484 .collect();
485
486 let hasher = update_chain.iter().fold(
487 Digest::chain_update(Sha256::new(), agent_drain_signature_tag(self.current_drain)),
488 |h, chunk| Digest::chain_update(h, *chunk),
489 );
490 let digest = hasher.finalize();
491 digest
492 .iter()
493 .map(|b| format!("{b:02x}"))
494 .collect::<String>()
495 }
496
497 #[must_use]
504 pub fn current_model(&self) -> Option<&String> {
505 self.models_per_agent
506 .get(self.current_agent_index)
507 .and_then(|models| models.get(self.current_model_index))
508 }
509
510 #[must_use]
511 pub const fn is_exhausted(&self) -> bool {
512 self.retry_cycle >= self.max_cycles
513 && self.current_agent_index == 0
514 && self.current_model_index == 0
515 }
516}
517
518#[cfg(test)]
519mod consumer_signature_tests {
520 use super::*;
521
522 #[test]
523 fn test_consumer_signature_sorting_matches_legacy_rendered_pair_ordering() {
524 let state = AgentChainState::initial().with_agents(
530 vec!["agent".to_string(), "agent".to_string()],
531 vec![
532 vec!["b".to_string()],
533 vec!["a".to_string(), "z".to_string()],
534 ],
535 AgentRole::Developer,
536 );
537
538 assert_eq!(
539 state.consumer_signature_sha256(),
540 state.legacy_consumer_signature_sha256_for_test(),
541 "consumer signature ordering must remain stable for the same configured consumers"
542 );
543 }
544
545 #[test]
546 fn test_consumer_signature_uses_stable_drain_encoding() {
547 let state = AgentChainState::initial()
548 .with_agents(
549 vec!["agent-a".to_string()],
550 vec![vec!["m1".to_string(), "m2".to_string()]],
551 AgentRole::Reviewer,
552 )
553 .with_drain(AgentDrain::Fix);
554
555 let data = b"fix\nagent-a|m1,m2\n".to_vec();
556 let expected = Sha256::digest(&data)
557 .iter()
558 .fold(String::new(), |mut acc, b| {
559 use std::fmt::Write;
560 write!(acc, "{b:02x}").unwrap();
561 acc
562 });
563
564 assert_eq!(
565 state.consumer_signature_sha256(),
566 expected,
567 "role encoding must be stable and explicit"
568 );
569 }
570}
571
#[cfg(test)]
mod legacy_rate_limit_prompt_tests {
    use super::*;

    /// A continuation prompt persisted by older builds as a bare string must
    /// be upgraded using the drain (and therefore role) already in the state.
    #[test]
    fn test_legacy_rate_limit_continuation_prompt_uses_current_role_on_deserialize() {
        let reviewer_state = AgentChainState::initial().with_agents(
            vec!["a".to_string()],
            vec![vec![]],
            AgentRole::Reviewer,
        );

        // Simulate a legacy checkpoint: replace the structured prompt with a
        // plain string, then round-trip through JSON text.
        let legacy_json = {
            let mut value =
                serde_json::to_value(&reviewer_state).expect("serialize AgentChainState");
            value["rate_limit_continuation_prompt"] =
                serde_json::Value::String("legacy prompt".into());
            serde_json::to_string(&value).expect("serialize JSON value")
        };

        let decoded: AgentChainState =
            serde_json::from_str(&legacy_json).expect("deserialize AgentChainState");

        let RateLimitContinuationPrompt {
            drain,
            role,
            prompt,
        } = decoded
            .rate_limit_continuation_prompt
            .expect("expected legacy prompt to deserialize");

        assert_eq!(drain, AgentDrain::Review);
        assert_eq!(role, AgentRole::Reviewer);
        assert_eq!(prompt, "legacy prompt");
    }
}