1use std::sync::Arc;
15
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18
19pub use crate::agents::{AgentDrain, AgentRole, DrainMode};
20
21mod backoff;
22mod transitions;
23
/// Snapshot of the agent fallback chain: which agent/model is active, the
/// retry/backoff bookkeeping, and the drain (and mode) the chain is serving.
///
/// Only `Serialize` is derived. Deserialization is hand-written (the
/// `Deserialize` impl in this module) so that legacy payloads can be upgraded.
/// As a result, the `#[serde(default ...)]` attributes on this struct have no
/// effect when reading. The hand-written impl declares its own defaults on a
/// private wire struct.
#[derive(Clone, Serialize, Debug)]
pub struct AgentChainState {
    /// Configured agent names; `current_agent_index` selects the active one.
    pub agents: Arc<[String]>,
    /// Index into `agents` of the active agent.
    pub current_agent_index: usize,
    /// Model lists aligned by index with `agents` (`models_per_agent[i]`
    /// belongs to `agents[i]`); an agent may have no entry here.
    pub models_per_agent: Arc<[Vec<String>]>,
    /// Index into the active agent's model list.
    pub current_model_index: usize,
    /// Retry-cycle counter; `is_exhausted` compares it against `max_cycles`.
    pub retry_cycle: u32,
    /// Number of retry cycles allowed before the chain counts as exhausted.
    pub max_cycles: u32,
    /// Retry delay in milliseconds (default 1000).
    /// NOTE(review): presumably the base delay fed to the backoff logic in
    /// the `backoff` submodule — confirm there.
    #[serde(default = "default_retry_delay_ms")]
    pub retry_delay_ms: u64,
    /// Backoff multiplier (default 2.0); presumably applied per retry.
    #[serde(default = "default_backoff_multiplier")]
    pub backoff_multiplier: f64,
    /// Backoff ceiling in milliseconds (default 60000); presumably caps the
    /// computed delay — confirm in the `backoff` submodule.
    #[serde(default = "default_max_backoff_ms")]
    pub max_backoff_ms: u64,
    /// Backoff delay (ms) waiting to be applied, if any; `None` initially.
    #[serde(default)]
    pub backoff_pending_ms: Option<u64>,
    /// Role of the active drain. Kept equal to `current_drain.role()` by
    /// `with_drain` and by deserialization.
    pub current_role: AgentRole,
    /// Drain the chain is serving; `initial()` uses the default (planning) drain.
    #[serde(default = "default_current_drain")]
    pub current_drain: AgentDrain,
    /// Mode of the current drain; `Continuation` also steers legacy drain
    /// inference during deserialization.
    #[serde(default)]
    pub current_mode: DrainMode,
    /// Continuation prompt tagged with the drain/role it belongs to.
    /// Legacy payloads may store it as a bare string; deserialization
    /// normalizes that form. Presumably re-sent after a rate limit — confirm
    /// against callers.
    #[serde(default)]
    pub rate_limit_continuation_prompt: Option<RateLimitContinuationPrompt>,
    /// Presumably the session id of the most recent agent run — TODO confirm.
    #[serde(default)]
    pub last_session_id: Option<String>,
    /// Presumably a description of the most recent failure — TODO confirm.
    #[serde(default)]
    pub last_failure_reason: Option<String>,
}
94
/// Continuation prompt stored on the chain, tagged with the drain and role it
/// was recorded for.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct RateLimitContinuationPrompt {
    /// Drain this prompt belongs to.
    pub drain: AgentDrain,
    /// Role for the prompt; `AgentChainState` deserialization always sets it
    /// to `drain.role()`.
    pub role: AgentRole,
    /// Prompt text.
    pub prompt: String,
}
102
/// Wire forms accepted for `rate_limit_continuation_prompt` when reading.
///
/// `untagged`: serde tries the variants in declaration order, so a JSON
/// string matches `LegacyString` and a map matches `Structured`.
#[derive(Deserialize)]
#[serde(untagged)]
enum RateLimitContinuationPromptRepr {
    /// Oldest form: only the prompt text; drain and role come from the chain.
    LegacyString(String),
    /// Structured form.
    Structured {
        /// Required for the map to match, but its value is ignored: the
        /// resulting role is always derived from the resolved drain.
        #[serde(rename = "role")]
        _role: AgentRole,
        /// Absent in older structured payloads; falls back to the chain's drain.
        #[serde(default)]
        drain: Option<AgentDrain>,
        prompt: String,
    },
}
115
116fn infer_legacy_current_drain(
117 current_drain: Option<AgentDrain>,
118 current_role: Option<AgentRole>,
119 current_mode: DrainMode,
120 continuation_prompt: Option<&RateLimitContinuationPromptRepr>,
121) -> AgentDrain {
122 if let Some(current_drain) = current_drain {
123 return current_drain;
124 }
125
126 if let Some(prompt_drain) = continuation_prompt.and_then(|prompt| match prompt {
127 RateLimitContinuationPromptRepr::LegacyString(_) => None,
128 RateLimitContinuationPromptRepr::Structured { drain, .. } => *drain,
129 }) {
130 return prompt_drain;
131 }
132
133 match (current_role, current_mode) {
134 (Some(AgentRole::Reviewer), DrainMode::Continuation) => AgentDrain::Fix,
135 (Some(AgentRole::Developer), DrainMode::Continuation) => AgentDrain::Development,
136 (Some(current_role), _) => AgentDrain::from(current_role),
137 (None, _) => default_current_drain(),
138 }
139}
140
141impl<'de> Deserialize<'de> for AgentChainState {
142 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
143 where
144 D: serde::Deserializer<'de>,
145 {
146 #[derive(Deserialize)]
147 struct AgentChainStateSerde {
148 agents: Arc<[String]>,
149 current_agent_index: usize,
150 models_per_agent: Arc<[Vec<String>]>,
151 current_model_index: usize,
152 retry_cycle: u32,
153 max_cycles: u32,
154 #[serde(default = "default_retry_delay_ms")]
155 retry_delay_ms: u64,
156 #[serde(default = "default_backoff_multiplier")]
157 backoff_multiplier: f64,
158 #[serde(default = "default_max_backoff_ms")]
159 max_backoff_ms: u64,
160 #[serde(default)]
161 backoff_pending_ms: Option<u64>,
162 #[serde(default)]
163 current_drain: Option<AgentDrain>,
164 #[serde(default)]
165 current_role: Option<AgentRole>,
166 #[serde(default)]
167 current_mode: DrainMode,
168 #[serde(default)]
169 rate_limit_continuation_prompt: Option<RateLimitContinuationPromptRepr>,
170 #[serde(default)]
171 last_session_id: Option<String>,
172 #[serde(default)]
173 last_failure_reason: Option<String>,
174 }
175
176 let raw = AgentChainStateSerde::deserialize(deserializer)?;
177 let current_drain = infer_legacy_current_drain(
178 raw.current_drain,
179 raw.current_role,
180 raw.current_mode,
181 raw.rate_limit_continuation_prompt.as_ref(),
182 );
183 let current_role = current_drain.role();
184
185 let rate_limit_continuation_prompt = raw.rate_limit_continuation_prompt.map(|repr| {
186 match repr {
187 RateLimitContinuationPromptRepr::LegacyString(prompt) => {
188 RateLimitContinuationPrompt {
191 drain: current_drain,
192 role: current_role,
193 prompt,
194 }
195 }
196 RateLimitContinuationPromptRepr::Structured {
197 _role: _,
198 drain,
199 prompt,
200 } => {
201 let prompt_drain = drain.unwrap_or(current_drain);
202 RateLimitContinuationPrompt {
203 drain: prompt_drain,
204 role: prompt_drain.role(),
205 prompt,
206 }
207 }
208 }
209 });
210
211 Ok(Self {
212 agents: raw.agents,
213 current_agent_index: raw.current_agent_index,
214 models_per_agent: raw.models_per_agent,
215 current_model_index: raw.current_model_index,
216 retry_cycle: raw.retry_cycle,
217 max_cycles: raw.max_cycles,
218 retry_delay_ms: raw.retry_delay_ms,
219 backoff_multiplier: raw.backoff_multiplier,
220 max_backoff_ms: raw.max_backoff_ms,
221 backoff_pending_ms: raw.backoff_pending_ms,
222 current_role,
223 current_drain,
224 current_mode: raw.current_mode,
225 rate_limit_continuation_prompt,
226 last_session_id: raw.last_session_id,
227 last_failure_reason: raw.last_failure_reason,
228 })
229 }
230}
231
/// Retry delay in milliseconds, used by `AgentChainState::initial` and when a
/// persisted payload omits `retry_delay_ms`.
const fn default_retry_delay_ms() -> u64 {
    1_000
}
235
/// Backoff multiplier, used by `AgentChainState::initial` and when a persisted
/// payload omits `backoff_multiplier`.
const fn default_backoff_multiplier() -> f64 {
    2.0_f64
}
239
/// Backoff ceiling in milliseconds, used by `AgentChainState::initial` and when
/// a persisted payload omits `max_backoff_ms`.
const fn default_max_backoff_ms() -> u64 {
    60_000
}
243
/// Fallback drain: used by `AgentChainState::initial` and by legacy drain
/// inference when a payload records neither a drain nor a role.
const fn default_current_drain() -> AgentDrain {
    AgentDrain::Planning
}
247
248const fn agent_drain_signature_tag(drain: AgentDrain) -> &'static [u8] {
249 match drain {
250 AgentDrain::Planning => b"planning\n",
251 AgentDrain::Development => b"development\n",
252 AgentDrain::Review => b"review\n",
253 AgentDrain::Fix => b"fix\n",
254 AgentDrain::Commit => b"commit\n",
255 AgentDrain::Analysis => b"analysis\n",
256 }
257}
258
259impl AgentChainState {
260 #[must_use]
261 pub fn initial() -> Self {
262 Self {
263 agents: Arc::from(vec![]),
264 current_agent_index: 0,
265 models_per_agent: Arc::from(vec![]),
266 current_model_index: 0,
267 retry_cycle: 0,
268 max_cycles: 3,
269 retry_delay_ms: default_retry_delay_ms(),
270 backoff_multiplier: default_backoff_multiplier(),
271 max_backoff_ms: default_max_backoff_ms(),
272 backoff_pending_ms: None,
273 current_role: AgentRole::Developer,
274 current_drain: default_current_drain(),
275 current_mode: DrainMode::Normal,
276 rate_limit_continuation_prompt: None,
277 last_session_id: None,
278 last_failure_reason: None,
279 }
280 }
281
282 #[must_use]
283 pub fn matches_runtime_drain(&self, runtime_drain: AgentDrain) -> bool {
284 self.current_drain == runtime_drain
285 }
286
287 #[must_use]
288 pub fn with_agents(
289 mut self,
290 agents: Vec<String>,
291 models_per_agent: Vec<Vec<String>>,
292 role: AgentRole,
293 ) -> Self {
294 self.agents = Arc::from(agents);
295 self.models_per_agent = Arc::from(models_per_agent);
296 self.current_role = role;
297 self.current_drain = match role {
298 AgentRole::Developer => AgentDrain::Development,
299 AgentRole::Reviewer => AgentDrain::Review,
300 AgentRole::Commit => AgentDrain::Commit,
301 AgentRole::Analysis => AgentDrain::Analysis,
302 };
303 self.current_mode = DrainMode::Normal;
304 self
305 }
306
307 #[must_use]
308 pub const fn with_drain(mut self, drain: AgentDrain) -> Self {
309 self.current_drain = drain;
310 self.current_role = drain.role();
311 self
312 }
313
314 #[must_use]
315 pub const fn with_mode(mut self, mode: DrainMode) -> Self {
316 self.current_mode = mode;
317 self
318 }
319
320 #[must_use]
321 pub const fn active_role(&self) -> AgentRole {
322 self.current_drain.role()
323 }
324
325 #[must_use]
330 pub const fn with_max_cycles(mut self, max_cycles: u32) -> Self {
331 self.max_cycles = max_cycles;
332 self
333 }
334
335 #[must_use]
336 pub const fn with_backoff_policy(
337 mut self,
338 retry_delay_ms: u64,
339 backoff_multiplier: f64,
340 max_backoff_ms: u64,
341 ) -> Self {
342 self.retry_delay_ms = retry_delay_ms;
343 self.backoff_multiplier = backoff_multiplier;
344 self.max_backoff_ms = max_backoff_ms;
345 self
346 }
347
348 #[must_use]
349 pub fn current_agent(&self) -> Option<&String> {
350 self.agents.get(self.current_agent_index)
351 }
352
353 #[must_use]
362 pub fn consumer_signature_sha256(&self) -> String {
363 let mut pairs: Vec<(&str, &[String])> = self
364 .agents
365 .iter()
366 .enumerate()
367 .map(|(idx, agent)| {
368 let models: &[String] = self
369 .models_per_agent
370 .get(idx)
371 .map_or([].as_slice(), std::vec::Vec::as_slice);
372 (agent.as_str(), models)
373 })
374 .collect();
375
376 pairs.sort_by(|(agent_a, models_a), (agent_b, models_b)| {
379 use std::cmp::Ordering;
380
381 let agent_ord = agent_a.cmp(agent_b);
382 if agent_ord != Ordering::Equal {
383 return agent_ord;
384 }
385
386 for (a, b) in models_a.iter().zip(models_b.iter()) {
387 let ord = a.cmp(b);
388 if ord != Ordering::Equal {
389 return ord;
390 }
391 }
392
393 models_a.len().cmp(&models_b.len())
394 });
395
396 let mut hasher = Sha256::new();
397 hasher.update(agent_drain_signature_tag(self.current_drain));
398 for (agent, models) in pairs {
399 hasher.update(agent.as_bytes());
400 hasher.update(b"|");
401 for (idx, model) in models.iter().enumerate() {
402 if idx > 0 {
403 hasher.update(b",");
404 }
405 hasher.update(model.as_bytes());
406 }
407 hasher.update(b"\n");
408 }
409 let digest = hasher.finalize();
410 digest.iter().fold(String::new(), |mut s, b| {
411 use std::fmt::Write;
412 write!(&mut s, "{b:02x}").unwrap();
413 s
414 })
415 }
416
417 #[cfg(test)]
418 fn legacy_consumer_signature_sha256_for_test(&self) -> String {
419 let mut rendered: Vec<String> = self
420 .agents
421 .iter()
422 .enumerate()
423 .map(|(idx, agent)| {
424 let models = self
425 .models_per_agent
426 .get(idx)
427 .map_or([].as_slice(), std::vec::Vec::as_slice);
428 format!("{}|{}", agent, models.join(","))
429 })
430 .collect();
431
432 rendered.sort();
433
434 let mut hasher = Sha256::new();
435 hasher.update(agent_drain_signature_tag(self.current_drain));
436 for line in rendered {
437 hasher.update(line.as_bytes());
438 hasher.update(b"\n");
439 }
440 let digest = hasher.finalize();
441 digest.iter().fold(String::new(), |mut s, b| {
442 use std::fmt::Write;
443 write!(&mut s, "{b:02x}").unwrap();
444 s
445 })
446 }
447
448 #[must_use]
455 pub fn current_model(&self) -> Option<&String> {
456 self.models_per_agent
457 .get(self.current_agent_index)
458 .and_then(|models| models.get(self.current_model_index))
459 }
460
461 #[must_use]
462 pub const fn is_exhausted(&self) -> bool {
463 self.retry_cycle >= self.max_cycles
464 && self.current_agent_index == 0
465 && self.current_model_index == 0
466 }
467}
468
#[cfg(test)]
mod consumer_signature_tests {
    use super::*;

    /// Two entries for the same agent must hash in the same order as the
    /// legacy string-sorted implementation.
    #[test]
    fn test_consumer_signature_sorting_matches_legacy_rendered_pair_ordering() {
        let agents = vec!["agent".to_string(), "agent".to_string()];
        let models = vec![
            vec!["b".to_string()],
            vec!["a".to_string(), "z".to_string()],
        ];
        let state = AgentChainState::initial().with_agents(agents, models, AgentRole::Developer);

        let actual = state.consumer_signature_sha256();
        let legacy = state.legacy_consumer_signature_sha256_for_test();
        assert_eq!(
            actual, legacy,
            "consumer signature ordering must remain stable for the same configured consumers"
        );
    }

    /// Pins the exact byte stream hashed for one agent under the fix drain.
    #[test]
    fn test_consumer_signature_uses_stable_drain_encoding() {
        use std::fmt::Write as _;

        let state = AgentChainState::initial()
            .with_agents(
                vec!["agent-a".to_string()],
                vec![vec!["m1".to_string(), "m2".to_string()]],
                AgentRole::Reviewer,
            )
            .with_drain(AgentDrain::Fix);

        // Drain tag first, then the single rendered line `agent-a|m1,m2\n`.
        let chunks: [&[u8]; 7] = [b"fix\n", b"agent-a", b"|", b"m1", b",", b"m2", b"\n"];
        let mut hasher = Sha256::new();
        for chunk in chunks {
            hasher.update(chunk);
        }

        let mut expected = String::new();
        for byte in hasher.finalize().iter() {
            write!(expected, "{byte:02x}").unwrap();
        }

        assert_eq!(
            state.consumer_signature_sha256(),
            expected,
            "role encoding must be stable and explicit"
        );
    }
}
530
#[cfg(test)]
mod legacy_rate_limit_prompt_tests {
    use super::*;

    /// A bare-string prompt (the legacy encoding) must be attributed to the
    /// chain's drain and role when decoded.
    #[test]
    fn test_legacy_rate_limit_continuation_prompt_uses_current_role_on_deserialize() {
        let reviewer_chain = AgentChainState::initial().with_agents(
            vec!["a".to_string()],
            vec![vec![]],
            AgentRole::Reviewer,
        );

        // Re-encode the prompt field in the old bare-string form.
        let mut payload =
            serde_json::to_value(&reviewer_chain).expect("serialize AgentChainState");
        payload["rate_limit_continuation_prompt"] = serde_json::Value::from("legacy prompt");
        let encoded = serde_json::to_string(&payload).expect("serialize JSON value");

        let decoded = serde_json::from_str::<AgentChainState>(&encoded)
            .expect("deserialize AgentChainState");
        let prompt = decoded
            .rate_limit_continuation_prompt
            .expect("expected legacy prompt to deserialize");

        assert_eq!(
            prompt,
            RateLimitContinuationPrompt {
                drain: AgentDrain::Review,
                role: AgentRole::Reviewer,
                prompt: "legacy prompt".to_string(),
            }
        );
    }
}