Skip to main content

ralph_workflow/reducer/state/agent_chain/
mod.rs

1// Agent fallback chain state.
2//
3// Contains AgentChainState and backoff computation helpers.
4//
5// # Performance Optimization
6//
7// AgentChainState uses Arc<[T]> for immutable collections (agents, models_per_agent)
8// to enable cheap state copying during state transitions. This eliminates O(n) deep
9// copy overhead and makes state transitions O(1) for collection fields.
10//
11// The reducer creates new state instances on every event, so this optimization
12// significantly reduces memory allocations and improves performance.
13
14use std::sync::Arc;
15
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18
19pub use crate::agents::AgentRole;
20
21mod backoff;
22mod transitions;
23
/// Agent fallback chain state (explicit, not loop indices).
///
/// Tracks position in the multi-level fallback chain:
/// - Agent level (primary → fallback1 → fallback2)
/// - Model level (within each agent, try different models)
/// - Retry cycle (exhaust all agents, start over with exponential backoff)
///
/// # Serialization
///
/// Only `Serialize` is derived. Deserialization is hand-written (the
/// `impl Deserialize for AgentChainState` in this module) so that legacy
/// checkpoints can be upgraded. The `#[serde(default = ...)]` attributes on the
/// fields below are deserialize-only serde attributes and have no effect on this
/// derive; they mirror the defaults used by the hand-written impl and should be
/// kept in sync with it.
///
/// # Memory Optimization
///
/// Uses Arc<[T]> for `agents` and `models_per_agent` collections to enable
/// cheap cloning during state transitions. Since these collections are immutable
/// after construction, `Arc::clone` only increments a reference count instead of
/// deep copying the entire collection.
#[derive(Clone, Serialize, Debug)]
pub struct AgentChainState {
    /// Agent names in fallback order. Arc<[String]> enables cheap cloning
    /// via reference counting instead of deep copying the collection.
    pub agents: Arc<[String]>,
    /// Index into `agents` of the agent currently being tried.
    /// `current_agent()` returns `None` when this is out of bounds.
    pub current_agent_index: usize,
    /// Models per agent, parallel to `agents` by index: entry `i` lists the
    /// models configured for `agents[i]`. Arc for immutable outer collection
    /// with cheap cloning. Inner Vec<String> is kept for runtime indexing during
    /// model selection.
    pub models_per_agent: Arc<[Vec<String>]>,
    /// Index into `models_per_agent[current_agent_index]` of the model currently
    /// being tried. `current_model()` returns `None` when this is out of bounds.
    pub current_model_index: usize,
    /// Current retry cycle (one cycle = a pass through every agent).
    /// NOTE(review): presumably advanced by the `transitions` submodule — confirm.
    pub retry_cycle: u32,
    /// Retry-cycle budget; `is_exhausted()` compares `retry_cycle` against it.
    pub max_cycles: u32,
    /// Base delay between retry cycles in milliseconds.
    #[serde(default = "default_retry_delay_ms")]
    pub retry_delay_ms: u64,
    /// Multiplier for exponential backoff.
    #[serde(default = "default_backoff_multiplier")]
    pub backoff_multiplier: f64,
    /// Maximum backoff delay in milliseconds.
    #[serde(default = "default_max_backoff_ms")]
    pub max_backoff_ms: u64,
    /// Pending backoff delay (milliseconds) that must be waited before continuing.
    /// `None` means no wait is pending.
    #[serde(default)]
    pub backoff_pending_ms: Option<u64>,
    /// Role the chain is currently executing for. It is part of the consumer
    /// signature and is used to scope legacy (role-less) continuation prompts
    /// when deserializing.
    pub current_role: AgentRole,
    /// Prompt context preserved from a rate-limited agent for continuation.
    ///
    /// When an agent hits 429, we save the prompt here so the next agent can
    /// continue the SAME role/task instead of starting from scratch.
    ///
    /// IMPORTANT: This must be role-scoped to prevent cross-task contamination
    /// (e.g., a developer continuation prompt overriding an analysis prompt).
    #[serde(default)]
    pub rate_limit_continuation_prompt: Option<RateLimitContinuationPrompt>,
    /// Session ID from the last agent response.
    ///
    /// Used for XSD retry to continue with the same session when possible.
    /// Agents that support sessions (e.g., Claude Code) emit session IDs
    /// that can be passed back for continuation.
    #[serde(default)]
    pub last_session_id: Option<String>,
}
79
80/// Role-scoped continuation prompt captured from a rate limit (429).
81#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
82pub struct RateLimitContinuationPrompt {
83    pub role: AgentRole,
84    pub prompt: String,
85}
86
/// Wire shapes accepted for `rate_limit_continuation_prompt` when deserializing
/// `AgentChainState`.
///
/// `#[serde(untagged)]` tries the variants in declaration order and picks the
/// first one that matches, so the order below is part of the behavior.
#[derive(Deserialize)]
#[serde(untagged)]
enum RateLimitContinuationPromptRepr {
    /// Legacy checkpoints: a bare prompt string with no role attached.
    LegacyString(String),
    /// Current format: a role-scoped prompt (same fields as
    /// `RateLimitContinuationPrompt`).
    Structured { role: AgentRole, prompt: String },
}
93
94impl<'de> Deserialize<'de> for AgentChainState {
95    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
96    where
97        D: serde::Deserializer<'de>,
98    {
99        #[derive(Deserialize)]
100        struct AgentChainStateSerde {
101            agents: Arc<[String]>,
102            current_agent_index: usize,
103            models_per_agent: Arc<[Vec<String>]>,
104            current_model_index: usize,
105            retry_cycle: u32,
106            max_cycles: u32,
107            #[serde(default = "default_retry_delay_ms")]
108            retry_delay_ms: u64,
109            #[serde(default = "default_backoff_multiplier")]
110            backoff_multiplier: f64,
111            #[serde(default = "default_max_backoff_ms")]
112            max_backoff_ms: u64,
113            #[serde(default)]
114            backoff_pending_ms: Option<u64>,
115            current_role: AgentRole,
116            #[serde(default)]
117            rate_limit_continuation_prompt: Option<RateLimitContinuationPromptRepr>,
118            #[serde(default)]
119            last_session_id: Option<String>,
120        }
121
122        let raw = AgentChainStateSerde::deserialize(deserializer)?;
123
124        let rate_limit_continuation_prompt = raw.rate_limit_continuation_prompt.map(|repr| {
125            match repr {
126                RateLimitContinuationPromptRepr::LegacyString(prompt) => {
127                    // Legacy checkpoints stored only the prompt string. Scope it to the
128                    // chain's current role so resume can't cross-contaminate roles.
129                    RateLimitContinuationPrompt {
130                        role: raw.current_role,
131                        prompt,
132                    }
133                }
134                RateLimitContinuationPromptRepr::Structured { role, prompt } => {
135                    RateLimitContinuationPrompt { role, prompt }
136                }
137            }
138        });
139
140        Ok(Self {
141            agents: raw.agents,
142            current_agent_index: raw.current_agent_index,
143            models_per_agent: raw.models_per_agent,
144            current_model_index: raw.current_model_index,
145            retry_cycle: raw.retry_cycle,
146            max_cycles: raw.max_cycles,
147            retry_delay_ms: raw.retry_delay_ms,
148            backoff_multiplier: raw.backoff_multiplier,
149            max_backoff_ms: raw.max_backoff_ms,
150            backoff_pending_ms: raw.backoff_pending_ms,
151            current_role: raw.current_role,
152            rate_limit_continuation_prompt,
153            last_session_id: raw.last_session_id,
154        })
155    }
156}
157
/// Default base delay between retry cycles: one second, in milliseconds.
const fn default_retry_delay_ms() -> u64 {
    1_000
}
161
/// Default multiplier applied for exponential backoff between retry cycles.
const fn default_backoff_multiplier() -> f64 {
    2.0_f64
}
165
/// Default ceiling on the backoff delay: one minute, in milliseconds.
const fn default_max_backoff_ms() -> u64 {
    60 * 1_000
}
169
170const fn agent_role_signature_tag(role: AgentRole) -> &'static [u8] {
171    match role {
172        AgentRole::Developer => b"developer\n",
173        AgentRole::Reviewer => b"reviewer\n",
174        AgentRole::Commit => b"commit\n",
175        AgentRole::Analysis => b"analysis\n",
176    }
177}
178
179impl AgentChainState {
180    #[must_use]
181    pub fn initial() -> Self {
182        Self {
183            agents: Arc::from(vec![]),
184            current_agent_index: 0,
185            models_per_agent: Arc::from(vec![]),
186            current_model_index: 0,
187            retry_cycle: 0,
188            max_cycles: 3,
189            retry_delay_ms: default_retry_delay_ms(),
190            backoff_multiplier: default_backoff_multiplier(),
191            max_backoff_ms: default_max_backoff_ms(),
192            backoff_pending_ms: None,
193            current_role: AgentRole::Developer,
194            rate_limit_continuation_prompt: None,
195            last_session_id: None,
196        }
197    }
198
199    #[must_use]
200    pub fn with_agents(
201        mut self,
202        agents: Vec<String>,
203        models_per_agent: Vec<Vec<String>>,
204        role: AgentRole,
205    ) -> Self {
206        self.agents = Arc::from(agents);
207        self.models_per_agent = Arc::from(models_per_agent);
208        self.current_role = role;
209        self
210    }
211
212    /// Builder method to set the maximum number of retry cycles.
213    ///
214    /// A retry cycle is when all agents have been exhausted and we start
215    /// over with exponential backoff.
216    #[must_use]
217    pub const fn with_max_cycles(mut self, max_cycles: u32) -> Self {
218        self.max_cycles = max_cycles;
219        self
220    }
221
222    #[must_use]
223    pub const fn with_backoff_policy(
224        mut self,
225        retry_delay_ms: u64,
226        backoff_multiplier: f64,
227        max_backoff_ms: u64,
228    ) -> Self {
229        self.retry_delay_ms = retry_delay_ms;
230        self.backoff_multiplier = backoff_multiplier;
231        self.max_backoff_ms = max_backoff_ms;
232        self
233    }
234
235    #[must_use]
236    pub fn current_agent(&self) -> Option<&String> {
237        self.agents.get(self.current_agent_index)
238    }
239
240    /// Stable signature of the current consumer set (agents + configured models + role).
241    ///
242    /// This is used to dedupe oversize materialization decisions across reducer retries.
243    /// The signature is stable under:
244    /// - switching the current agent/model index
245    /// - retry cycles
246    ///
247    /// It changes only when the configured consumer set changes.
248    #[must_use]
249    pub fn consumer_signature_sha256(&self) -> String {
250        let mut pairs: Vec<(&str, &[String])> = self
251            .agents
252            .iter()
253            .enumerate()
254            .map(|(idx, agent)| {
255                let models: &[String] = self
256                    .models_per_agent
257                    .get(idx)
258                    .map_or([].as_slice(), std::vec::Vec::as_slice);
259                (agent.as_str(), models)
260            })
261            .collect();
262
263        // Sort so the signature is stable even if callers reorder the configured
264        // consumer set.
265        pairs.sort_by(|(agent_a, models_a), (agent_b, models_b)| {
266            use std::cmp::Ordering;
267
268            let agent_ord = agent_a.cmp(agent_b);
269            if agent_ord != Ordering::Equal {
270                return agent_ord;
271            }
272
273            for (a, b) in models_a.iter().zip(models_b.iter()) {
274                let ord = a.cmp(b);
275                if ord != Ordering::Equal {
276                    return ord;
277                }
278            }
279
280            models_a.len().cmp(&models_b.len())
281        });
282
283        let mut hasher = Sha256::new();
284        hasher.update(agent_role_signature_tag(self.current_role));
285        for (agent, models) in pairs {
286            hasher.update(agent.as_bytes());
287            hasher.update(b"|");
288            for (idx, model) in models.iter().enumerate() {
289                if idx > 0 {
290                    hasher.update(b",");
291                }
292                hasher.update(model.as_bytes());
293            }
294            hasher.update(b"\n");
295        }
296        let digest = hasher.finalize();
297        digest.iter().fold(String::new(), |mut s, b| {
298            use std::fmt::Write;
299            write!(&mut s, "{b:02x}").unwrap();
300            s
301        })
302    }
303
304    #[cfg(test)]
305    fn legacy_consumer_signature_sha256_for_test(&self) -> String {
306        let mut rendered: Vec<String> = self
307            .agents
308            .iter()
309            .enumerate()
310            .map(|(idx, agent)| {
311                let models = self
312                    .models_per_agent
313                    .get(idx)
314                    .map_or([].as_slice(), std::vec::Vec::as_slice);
315                format!("{}|{}", agent, models.join(","))
316            })
317            .collect();
318
319        rendered.sort();
320
321        let mut hasher = Sha256::new();
322        hasher.update(agent_role_signature_tag(self.current_role));
323        for line in rendered {
324            hasher.update(line.as_bytes());
325            hasher.update(b"\n");
326        }
327        let digest = hasher.finalize();
328        digest.iter().fold(String::new(), |mut s, b| {
329            use std::fmt::Write;
330            write!(&mut s, "{b:02x}").unwrap();
331            s
332        })
333    }
334
335    /// Get the currently selected model for the current agent.
336    ///
337    /// Returns `None` if:
338    /// - No models are configured
339    /// - The current agent index is out of bounds
340    /// - The current model index is out of bounds
341    #[must_use]
342    pub fn current_model(&self) -> Option<&String> {
343        self.models_per_agent
344            .get(self.current_agent_index)
345            .and_then(|models| models.get(self.current_model_index))
346    }
347
348    #[must_use]
349    pub const fn is_exhausted(&self) -> bool {
350        self.retry_cycle >= self.max_cycles
351            && self.current_agent_index == 0
352            && self.current_model_index == 0
353    }
354}
355
#[cfg(test)]
mod consumer_signature_tests {
    use super::*;

    /// Lower-case hex rendering of a digest, matching the encoding produced by
    /// `consumer_signature_sha256`.
    fn to_lower_hex(bytes: &[u8]) -> String {
        use std::fmt::Write;

        let mut rendered = String::with_capacity(bytes.len() * 2);
        for byte in bytes {
            write!(rendered, "{byte:02x}").unwrap();
        }
        rendered
    }

    #[test]
    fn test_consumer_signature_sorting_matches_legacy_rendered_pair_ordering() {
        // Regression guard for the pre-optimization signature ordering: consumers
        // are sorted by the lexicographic order of their rendered
        // `agent|models_csv` strings.
        //
        // Comparing model lists length-first would reorder consumers whose first
        // model differs: "a,z" must sort before "b" even though it is longer.
        let agents = vec![String::from("agent"), String::from("agent")];
        let models_per_agent = vec![
            vec![String::from("b")],
            vec![String::from("a"), String::from("z")],
        ];
        let state = AgentChainState::initial().with_agents(
            agents,
            models_per_agent,
            AgentRole::Developer,
        );

        let optimized = state.consumer_signature_sha256();
        let legacy = state.legacy_consumer_signature_sha256_for_test();
        assert_eq!(
            optimized, legacy,
            "consumer signature ordering must remain stable for the same configured consumers"
        );
    }

    #[test]
    fn test_consumer_signature_uses_stable_role_encoding() {
        // The signature is persisted in reducer state and used for dedupe, so the
        // role must be hashed through an explicit, stable tag rather than `Debug`
        // output (which would change if a variant were renamed).
        let state = AgentChainState::initial().with_agents(
            vec![String::from("agent-a")],
            vec![vec![String::from("m1"), String::from("m2")]],
            AgentRole::Reviewer,
        );

        // Role tag first, then one `agent|models_csv` line per consumer.
        let preimage: &[u8] = b"reviewer\nagent-a|m1,m2\n";
        let mut hasher = Sha256::new();
        hasher.update(preimage);
        let expected = to_lower_hex(&hasher.finalize());

        assert_eq!(
            state.consumer_signature_sha256(),
            expected,
            "role encoding must be stable and explicit"
        );
    }
}
415
#[cfg(test)]
mod legacy_rate_limit_prompt_tests {
    use super::*;

    #[test]
    fn test_legacy_rate_limit_continuation_prompt_uses_current_role_on_deserialize() {
        // Older checkpoints persisted `rate_limit_continuation_prompt` as a plain
        // string. On resume, that prompt must be attributed to the chain's
        // `current_role` (the role the checkpoint was executing) rather than
        // falling back to Developer.
        let state = AgentChainState::initial().with_agents(
            vec![String::from("a")],
            vec![Vec::new()],
            AgentRole::Reviewer,
        );

        // Serialize a current checkpoint, then downgrade the prompt field to the
        // legacy bare-string shape.
        let mut checkpoint = serde_json::to_value(&state).expect("serialize AgentChainState");
        checkpoint["rate_limit_continuation_prompt"] = serde_json::Value::from("legacy prompt");
        let legacy_json = serde_json::to_string(&checkpoint).expect("serialize JSON value");

        let restored: AgentChainState =
            serde_json::from_str(&legacy_json).expect("deserialize AgentChainState");
        let continuation = restored
            .rate_limit_continuation_prompt
            .expect("expected legacy prompt to deserialize");

        assert_eq!(
            continuation,
            RateLimitContinuationPrompt {
                role: AgentRole::Reviewer,
                prompt: String::from("legacy prompt"),
            }
        );
    }
}