ralph_workflow/reducer/state/agent_chain/
mod.rs1use std::sync::Arc;
15
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18
19pub use crate::agents::AgentRole;
20
21mod backoff;
22mod transitions;
23
24#[derive(Clone, Serialize, Debug)]
38pub struct AgentChainState {
39 pub agents: Arc<[String]>,
42 pub current_agent_index: usize,
43 pub models_per_agent: Arc<[Vec<String>]>,
46 pub current_model_index: usize,
47 pub retry_cycle: u32,
48 pub max_cycles: u32,
49 #[serde(default = "default_retry_delay_ms")]
51 pub retry_delay_ms: u64,
52 #[serde(default = "default_backoff_multiplier")]
54 pub backoff_multiplier: f64,
55 #[serde(default = "default_max_backoff_ms")]
57 pub max_backoff_ms: u64,
58 #[serde(default)]
60 pub backoff_pending_ms: Option<u64>,
61 pub current_role: AgentRole,
62 #[serde(default)]
70 pub rate_limit_continuation_prompt: Option<RateLimitContinuationPrompt>,
71 #[serde(default)]
77 pub last_session_id: Option<String>,
78}
79
80#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
82pub struct RateLimitContinuationPrompt {
83 pub role: AgentRole,
84 pub prompt: String,
85}
86
/// Accepted encodings of `rate_limit_continuation_prompt` when deserializing
/// `AgentChainState`.
///
/// Being `untagged`, serde tries the variants in declaration order and keeps
/// the first that matches: a bare string decodes as `LegacyString`, an object
/// with `role` and `prompt` decodes as `Structured`. Do not reorder variants.
#[derive(Deserialize)]
#[serde(untagged)]
enum RateLimitContinuationPromptRepr {
    /// Legacy form: only the prompt text, no role recorded.
    LegacyString(String),
    /// Current form, mirroring the fields of `RateLimitContinuationPrompt`.
    Structured { role: AgentRole, prompt: String },
}
93
94impl<'de> Deserialize<'de> for AgentChainState {
95 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
96 where
97 D: serde::Deserializer<'de>,
98 {
99 #[derive(Deserialize)]
100 struct AgentChainStateSerde {
101 agents: Arc<[String]>,
102 current_agent_index: usize,
103 models_per_agent: Arc<[Vec<String>]>,
104 current_model_index: usize,
105 retry_cycle: u32,
106 max_cycles: u32,
107 #[serde(default = "default_retry_delay_ms")]
108 retry_delay_ms: u64,
109 #[serde(default = "default_backoff_multiplier")]
110 backoff_multiplier: f64,
111 #[serde(default = "default_max_backoff_ms")]
112 max_backoff_ms: u64,
113 #[serde(default)]
114 backoff_pending_ms: Option<u64>,
115 current_role: AgentRole,
116 #[serde(default)]
117 rate_limit_continuation_prompt: Option<RateLimitContinuationPromptRepr>,
118 #[serde(default)]
119 last_session_id: Option<String>,
120 }
121
122 let raw = AgentChainStateSerde::deserialize(deserializer)?;
123
124 let rate_limit_continuation_prompt = raw.rate_limit_continuation_prompt.map(|repr| {
125 match repr {
126 RateLimitContinuationPromptRepr::LegacyString(prompt) => {
127 RateLimitContinuationPrompt {
130 role: raw.current_role,
131 prompt,
132 }
133 }
134 RateLimitContinuationPromptRepr::Structured { role, prompt } => {
135 RateLimitContinuationPrompt { role, prompt }
136 }
137 }
138 });
139
140 Ok(Self {
141 agents: raw.agents,
142 current_agent_index: raw.current_agent_index,
143 models_per_agent: raw.models_per_agent,
144 current_model_index: raw.current_model_index,
145 retry_cycle: raw.retry_cycle,
146 max_cycles: raw.max_cycles,
147 retry_delay_ms: raw.retry_delay_ms,
148 backoff_multiplier: raw.backoff_multiplier,
149 max_backoff_ms: raw.max_backoff_ms,
150 backoff_pending_ms: raw.backoff_pending_ms,
151 current_role: raw.current_role,
152 rate_limit_continuation_prompt,
153 last_session_id: raw.last_session_id,
154 })
155 }
156}
157
/// Default base retry delay, in milliseconds (one second).
const fn default_retry_delay_ms() -> u64 {
    1_000
}
161
/// Default backoff growth factor.
const fn default_backoff_multiplier() -> f64 {
    2.0_f64
}
165
/// Default backoff cap, in milliseconds (one minute).
const fn default_max_backoff_ms() -> u64 {
    60_000
}
169
170const fn agent_role_signature_tag(role: AgentRole) -> &'static [u8] {
171 match role {
172 AgentRole::Developer => b"developer\n",
173 AgentRole::Reviewer => b"reviewer\n",
174 AgentRole::Commit => b"commit\n",
175 AgentRole::Analysis => b"analysis\n",
176 }
177}
178
179impl AgentChainState {
180 #[must_use]
181 pub fn initial() -> Self {
182 Self {
183 agents: Arc::from(vec![]),
184 current_agent_index: 0,
185 models_per_agent: Arc::from(vec![]),
186 current_model_index: 0,
187 retry_cycle: 0,
188 max_cycles: 3,
189 retry_delay_ms: default_retry_delay_ms(),
190 backoff_multiplier: default_backoff_multiplier(),
191 max_backoff_ms: default_max_backoff_ms(),
192 backoff_pending_ms: None,
193 current_role: AgentRole::Developer,
194 rate_limit_continuation_prompt: None,
195 last_session_id: None,
196 }
197 }
198
199 #[must_use]
200 pub fn with_agents(
201 mut self,
202 agents: Vec<String>,
203 models_per_agent: Vec<Vec<String>>,
204 role: AgentRole,
205 ) -> Self {
206 self.agents = Arc::from(agents);
207 self.models_per_agent = Arc::from(models_per_agent);
208 self.current_role = role;
209 self
210 }
211
212 #[must_use]
217 pub const fn with_max_cycles(mut self, max_cycles: u32) -> Self {
218 self.max_cycles = max_cycles;
219 self
220 }
221
222 #[must_use]
223 pub const fn with_backoff_policy(
224 mut self,
225 retry_delay_ms: u64,
226 backoff_multiplier: f64,
227 max_backoff_ms: u64,
228 ) -> Self {
229 self.retry_delay_ms = retry_delay_ms;
230 self.backoff_multiplier = backoff_multiplier;
231 self.max_backoff_ms = max_backoff_ms;
232 self
233 }
234
235 #[must_use]
236 pub fn current_agent(&self) -> Option<&String> {
237 self.agents.get(self.current_agent_index)
238 }
239
240 #[must_use]
249 pub fn consumer_signature_sha256(&self) -> String {
250 let mut pairs: Vec<(&str, &[String])> = self
251 .agents
252 .iter()
253 .enumerate()
254 .map(|(idx, agent)| {
255 let models: &[String] = self
256 .models_per_agent
257 .get(idx)
258 .map_or([].as_slice(), std::vec::Vec::as_slice);
259 (agent.as_str(), models)
260 })
261 .collect();
262
263 pairs.sort_by(|(agent_a, models_a), (agent_b, models_b)| {
266 use std::cmp::Ordering;
267
268 let agent_ord = agent_a.cmp(agent_b);
269 if agent_ord != Ordering::Equal {
270 return agent_ord;
271 }
272
273 for (a, b) in models_a.iter().zip(models_b.iter()) {
274 let ord = a.cmp(b);
275 if ord != Ordering::Equal {
276 return ord;
277 }
278 }
279
280 models_a.len().cmp(&models_b.len())
281 });
282
283 let mut hasher = Sha256::new();
284 hasher.update(agent_role_signature_tag(self.current_role));
285 for (agent, models) in pairs {
286 hasher.update(agent.as_bytes());
287 hasher.update(b"|");
288 for (idx, model) in models.iter().enumerate() {
289 if idx > 0 {
290 hasher.update(b",");
291 }
292 hasher.update(model.as_bytes());
293 }
294 hasher.update(b"\n");
295 }
296 let digest = hasher.finalize();
297 digest.iter().fold(String::new(), |mut s, b| {
298 use std::fmt::Write;
299 write!(&mut s, "{b:02x}").unwrap();
300 s
301 })
302 }
303
304 #[cfg(test)]
305 fn legacy_consumer_signature_sha256_for_test(&self) -> String {
306 let mut rendered: Vec<String> = self
307 .agents
308 .iter()
309 .enumerate()
310 .map(|(idx, agent)| {
311 let models = self
312 .models_per_agent
313 .get(idx)
314 .map_or([].as_slice(), std::vec::Vec::as_slice);
315 format!("{}|{}", agent, models.join(","))
316 })
317 .collect();
318
319 rendered.sort();
320
321 let mut hasher = Sha256::new();
322 hasher.update(agent_role_signature_tag(self.current_role));
323 for line in rendered {
324 hasher.update(line.as_bytes());
325 hasher.update(b"\n");
326 }
327 let digest = hasher.finalize();
328 digest.iter().fold(String::new(), |mut s, b| {
329 use std::fmt::Write;
330 write!(&mut s, "{b:02x}").unwrap();
331 s
332 })
333 }
334
335 #[must_use]
342 pub fn current_model(&self) -> Option<&String> {
343 self.models_per_agent
344 .get(self.current_agent_index)
345 .and_then(|models| models.get(self.current_model_index))
346 }
347
348 #[must_use]
349 pub const fn is_exhausted(&self) -> bool {
350 self.retry_cycle >= self.max_cycles
351 && self.current_agent_index == 0
352 && self.current_model_index == 0
353 }
354}
355
#[cfg(test)]
mod consumer_signature_tests {
    use super::*;

    /// Lowercase hex encoding of a finished SHA-256 digest.
    fn to_hex(hasher: Sha256) -> String {
        hasher
            .finalize()
            .iter()
            .map(|byte| format!("{byte:02x}"))
            .collect()
    }

    #[test]
    fn test_consumer_signature_sorting_matches_legacy_rendered_pair_ordering() {
        // Same agent name twice, so ordering is decided purely by the model lists.
        let agent_names = vec!["agent".to_string(), "agent".to_string()];
        let model_lists = vec![
            vec!["b".to_string()],
            vec!["a".to_string(), "z".to_string()],
        ];
        let state =
            AgentChainState::initial().with_agents(agent_names, model_lists, AgentRole::Developer);

        let current = state.consumer_signature_sha256();
        let legacy = state.legacy_consumer_signature_sha256_for_test();
        assert_eq!(
            current,
            legacy,
            "consumer signature ordering must remain stable for the same configured consumers"
        );
    }

    #[test]
    fn test_consumer_signature_uses_stable_role_encoding() {
        let state = AgentChainState::initial().with_agents(
            vec!["agent-a".to_string()],
            vec![vec!["m1".to_string(), "m2".to_string()]],
            AgentRole::Reviewer,
        );

        // Hand-build the exact byte stream the signature is expected to hash.
        let chunks: [&[u8]; 7] = [b"reviewer\n", b"agent-a", b"|", b"m1", b",", b"m2", b"\n"];
        let mut reference = Sha256::new();
        for chunk in chunks {
            reference.update(chunk);
        }
        let expected = to_hex(reference);

        assert_eq!(
            state.consumer_signature_sha256(),
            expected,
            "role encoding must be stable and explicit"
        );
    }
}
415
#[cfg(test)]
mod legacy_rate_limit_prompt_tests {
    use super::*;

    #[test]
    fn test_legacy_rate_limit_continuation_prompt_uses_current_role_on_deserialize() {
        let state = AgentChainState::initial().with_agents(
            vec!["a".to_string()],
            vec![Vec::new()],
            AgentRole::Reviewer,
        );

        // Rewrite the prompt field into the legacy bare-string form.
        let mut value = serde_json::to_value(&state).expect("serialize AgentChainState");
        value["rate_limit_continuation_prompt"] = serde_json::Value::from("legacy prompt");

        let encoded = serde_json::to_string(&value).expect("serialize JSON value");
        let decoded: AgentChainState =
            serde_json::from_str(&encoded).expect("deserialize AgentChainState");

        let RateLimitContinuationPrompt { role, prompt } = decoded
            .rate_limit_continuation_prompt
            .expect("expected legacy prompt to deserialize");
        assert_eq!(role, AgentRole::Reviewer);
        assert_eq!(prompt, "legacy prompt");
    }
}