Skip to main content

ralph_workflow/reducer/state/agent_chain/
mod.rs

1// Agent fallback chain state.
2//
3// Contains AgentChainState and backoff computation helpers.
4//
5// # Performance Optimization
6//
7// AgentChainState uses Arc<[T]> for immutable collections (agents, models_per_agent)
8// to enable cheap state copying during state transitions. This eliminates O(n) deep
9// copy overhead and makes state transitions O(1) for collection fields.
10//
11// The reducer creates new state instances on every event, so this optimization
12// significantly reduces memory allocations and improves performance.
13
14use std::sync::Arc;
15
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18
19pub use crate::agents::{AgentDrain, AgentRole, DrainMode};
20
21mod backoff;
22mod transitions;
23
24/// Agent fallback chain state (explicit, not loop indices).
25///
26/// Tracks position in the multi-level fallback chain:
27/// - Agent level (primary → fallback1 → fallback2)
28/// - Model level (within each agent, try different models)
29/// - Retry cycle (exhaust all agents, start over with exponential backoff)
30///
31/// # Memory Optimization
32///
33/// Uses Arc<[T]> for `agents` and `models_per_agent` collections to enable
34/// cheap cloning during state transitions. Since these collections are immutable
35/// after construction, `Arc::clone` only increments a reference count instead of
36/// deep copying the entire collection.
37#[derive(Clone, Serialize, Debug)]
38pub struct AgentChainState {
39    /// Agent names in fallback order. Arc<[String]> enables cheap cloning
40    /// via reference counting instead of deep copying the collection.
41    pub agents: Arc<[String]>,
42    pub current_agent_index: usize,
43    /// Models per agent. Arc for immutable outer collection with cheap cloning.
44    /// Inner Vec<String> is kept for runtime indexing during model selection.
45    pub models_per_agent: Arc<[Vec<String>]>,
46    pub current_model_index: usize,
47    pub retry_cycle: u32,
48    pub max_cycles: u32,
49    /// Base delay between retry cycles in milliseconds.
50    #[serde(default = "default_retry_delay_ms")]
51    pub retry_delay_ms: u64,
52    /// Multiplier for exponential backoff.
53    #[serde(default = "default_backoff_multiplier")]
54    pub backoff_multiplier: f64,
55    /// Maximum backoff delay in milliseconds.
56    #[serde(default = "default_max_backoff_ms")]
57    pub max_backoff_ms: u64,
58    /// Pending backoff delay (milliseconds) that must be waited before continuing.
59    #[serde(default)]
60    pub backoff_pending_ms: Option<u64>,
61    /// Compatibility copy of the broad capability role.
62    ///
63    /// Runtime code should treat `current_drain` as authoritative and derive the
64    /// active role from it. This field is retained for checkpoint compatibility
65    /// and diagnostics only.
66    pub current_role: AgentRole,
67    #[serde(default = "default_current_drain")]
68    pub current_drain: AgentDrain,
69    #[serde(default)]
70    pub current_mode: DrainMode,
71    /// Prompt context preserved from a rate-limited agent for continuation.
72    ///
73    /// When an agent hits 429, we save the prompt here so the next agent can
74    /// continue the SAME role/task instead of starting from scratch.
75    ///
76    /// IMPORTANT: This must be role-scoped to prevent cross-task contamination
77    /// (e.g., a developer continuation prompt overriding an analysis prompt).
78    #[serde(default)]
79    pub rate_limit_continuation_prompt: Option<RateLimitContinuationPrompt>,
80    /// Session ID from the last agent response.
81    ///
82    /// Used for XSD retry to continue with the same session when possible.
83    /// Agents that support sessions (e.g., Claude Code) emit session IDs
84    /// that can be passed back for continuation.
85    #[serde(default)]
86    pub last_session_id: Option<String>,
87    /// Last failure reason from the most recent agent failure.
88    ///
89    /// Used to provide context in CLI output when a fallback agent is invoked.
90    /// Cleared on InvocationSucceeded or ChainInitialized.
91    #[serde(default)]
92    pub last_failure_reason: Option<String>,
93}
94
95/// Role-scoped continuation prompt captured from a rate limit (429).
96#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
97pub struct RateLimitContinuationPrompt {
98    pub drain: AgentDrain,
99    pub role: AgentRole,
100    pub prompt: String,
101}
102
/// Wire-level shapes accepted for `rate_limit_continuation_prompt` when
/// deserializing an `AgentChainState`.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// first a bare string, then the structured object form.
#[derive(Deserialize)]
#[serde(untagged)]
enum RateLimitContinuationPromptRepr {
    /// Legacy form: only the prompt text was stored.
    LegacyString(String),
    /// Object form carrying the prompt together with its scoping information.
    Structured {
        /// Serialized `role` key. It must be present for this variant to match,
        /// but the value is not read; the role is re-derived from the drain.
        #[serde(rename = "role")]
        _role: AgentRole,
        /// Drain the prompt was scoped to; `None` when the key is absent.
        #[serde(default)]
        drain: Option<AgentDrain>,
        /// The prompt text.
        prompt: String,
    },
}
115
116fn infer_legacy_current_drain(
117    current_drain: Option<AgentDrain>,
118    current_role: Option<AgentRole>,
119    current_mode: DrainMode,
120    continuation_prompt: Option<&RateLimitContinuationPromptRepr>,
121) -> AgentDrain {
122    if let Some(current_drain) = current_drain {
123        return current_drain;
124    }
125
126    if let Some(prompt_drain) = continuation_prompt.and_then(|prompt| match prompt {
127        RateLimitContinuationPromptRepr::LegacyString(_) => None,
128        RateLimitContinuationPromptRepr::Structured { drain, .. } => *drain,
129    }) {
130        return prompt_drain;
131    }
132
133    match (current_role, current_mode) {
134        (Some(AgentRole::Reviewer), DrainMode::Continuation) => AgentDrain::Fix,
135        (Some(AgentRole::Developer), DrainMode::Continuation) => AgentDrain::Development,
136        (Some(current_role), _) => AgentDrain::from(current_role),
137        (None, _) => default_current_drain(),
138    }
139}
140
141impl<'de> Deserialize<'de> for AgentChainState {
142    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
143    where
144        D: serde::Deserializer<'de>,
145    {
146        #[derive(Deserialize)]
147        struct AgentChainStateSerde {
148            agents: Arc<[String]>,
149            current_agent_index: usize,
150            models_per_agent: Arc<[Vec<String>]>,
151            current_model_index: usize,
152            retry_cycle: u32,
153            max_cycles: u32,
154            #[serde(default = "default_retry_delay_ms")]
155            retry_delay_ms: u64,
156            #[serde(default = "default_backoff_multiplier")]
157            backoff_multiplier: f64,
158            #[serde(default = "default_max_backoff_ms")]
159            max_backoff_ms: u64,
160            #[serde(default)]
161            backoff_pending_ms: Option<u64>,
162            #[serde(default)]
163            current_drain: Option<AgentDrain>,
164            #[serde(default)]
165            current_role: Option<AgentRole>,
166            #[serde(default)]
167            current_mode: DrainMode,
168            #[serde(default)]
169            rate_limit_continuation_prompt: Option<RateLimitContinuationPromptRepr>,
170            #[serde(default)]
171            last_session_id: Option<String>,
172            #[serde(default)]
173            last_failure_reason: Option<String>,
174        }
175
176        let raw = AgentChainStateSerde::deserialize(deserializer)?;
177        let current_drain = infer_legacy_current_drain(
178            raw.current_drain,
179            raw.current_role,
180            raw.current_mode,
181            raw.rate_limit_continuation_prompt.as_ref(),
182        );
183        let current_role = current_drain.role();
184
185        let rate_limit_continuation_prompt = raw.rate_limit_continuation_prompt.map(|repr| {
186            match repr {
187                RateLimitContinuationPromptRepr::LegacyString(prompt) => {
188                    // Legacy checkpoints stored only the prompt string. Scope it to the
189                    // resolved drain role so resume can't cross-contaminate drains.
190                    RateLimitContinuationPrompt {
191                        drain: current_drain,
192                        role: current_role,
193                        prompt,
194                    }
195                }
196                RateLimitContinuationPromptRepr::Structured {
197                    _role: _,
198                    drain,
199                    prompt,
200                } => {
201                    let prompt_drain = drain.unwrap_or(current_drain);
202                    RateLimitContinuationPrompt {
203                        drain: prompt_drain,
204                        role: prompt_drain.role(),
205                        prompt,
206                    }
207                }
208            }
209        });
210
211        Ok(Self {
212            agents: raw.agents,
213            current_agent_index: raw.current_agent_index,
214            models_per_agent: raw.models_per_agent,
215            current_model_index: raw.current_model_index,
216            retry_cycle: raw.retry_cycle,
217            max_cycles: raw.max_cycles,
218            retry_delay_ms: raw.retry_delay_ms,
219            backoff_multiplier: raw.backoff_multiplier,
220            max_backoff_ms: raw.max_backoff_ms,
221            backoff_pending_ms: raw.backoff_pending_ms,
222            current_role,
223            current_drain,
224            current_mode: raw.current_mode,
225            rate_limit_continuation_prompt,
226            last_session_id: raw.last_session_id,
227            last_failure_reason: raw.last_failure_reason,
228        })
229    }
230}
231
/// Default base delay between retry cycles: 1000 ms (one second).
const fn default_retry_delay_ms() -> u64 {
    1_000
}
235
/// Default exponential backoff multiplier: 2.0.
const fn default_backoff_multiplier() -> f64 {
    2.0_f64
}
239
/// Default ceiling for the backoff delay: 60000 ms (one minute).
const fn default_max_backoff_ms() -> u64 {
    60_000
}
243
/// Drain assumed when nothing else determines one: `AgentDrain::Planning`.
///
/// Used by `AgentChainState::initial` and as the final fallback when a
/// deserialized state carries neither a drain nor a role.
const fn default_current_drain() -> AgentDrain {
    AgentDrain::Planning
}
247
248const fn agent_drain_signature_tag(drain: AgentDrain) -> &'static [u8] {
249    match drain {
250        AgentDrain::Planning => b"planning\n",
251        AgentDrain::Development => b"development\n",
252        AgentDrain::Review => b"review\n",
253        AgentDrain::Fix => b"fix\n",
254        AgentDrain::Commit => b"commit\n",
255        AgentDrain::Analysis => b"analysis\n",
256    }
257}
258
259impl AgentChainState {
260    #[must_use]
261    pub fn initial() -> Self {
262        Self {
263            agents: Arc::from(vec![]),
264            current_agent_index: 0,
265            models_per_agent: Arc::from(vec![]),
266            current_model_index: 0,
267            retry_cycle: 0,
268            max_cycles: 3,
269            retry_delay_ms: default_retry_delay_ms(),
270            backoff_multiplier: default_backoff_multiplier(),
271            max_backoff_ms: default_max_backoff_ms(),
272            backoff_pending_ms: None,
273            current_role: AgentRole::Developer,
274            current_drain: default_current_drain(),
275            current_mode: DrainMode::Normal,
276            rate_limit_continuation_prompt: None,
277            last_session_id: None,
278            last_failure_reason: None,
279        }
280    }
281
282    #[must_use]
283    pub fn matches_runtime_drain(&self, runtime_drain: AgentDrain) -> bool {
284        self.current_drain == runtime_drain
285    }
286
287    #[must_use]
288    pub fn with_agents(
289        mut self,
290        agents: Vec<String>,
291        models_per_agent: Vec<Vec<String>>,
292        role: AgentRole,
293    ) -> Self {
294        self.agents = Arc::from(agents);
295        self.models_per_agent = Arc::from(models_per_agent);
296        self.current_role = role;
297        self.current_drain = match role {
298            AgentRole::Developer => AgentDrain::Development,
299            AgentRole::Reviewer => AgentDrain::Review,
300            AgentRole::Commit => AgentDrain::Commit,
301            AgentRole::Analysis => AgentDrain::Analysis,
302        };
303        self.current_mode = DrainMode::Normal;
304        self
305    }
306
307    #[must_use]
308    pub const fn with_drain(mut self, drain: AgentDrain) -> Self {
309        self.current_drain = drain;
310        self.current_role = drain.role();
311        self
312    }
313
314    #[must_use]
315    pub const fn with_mode(mut self, mode: DrainMode) -> Self {
316        self.current_mode = mode;
317        self
318    }
319
320    #[must_use]
321    pub const fn active_role(&self) -> AgentRole {
322        self.current_drain.role()
323    }
324
325    /// Builder method to set the maximum number of retry cycles.
326    ///
327    /// A retry cycle is when all agents have been exhausted and we start
328    /// over with exponential backoff.
329    #[must_use]
330    pub const fn with_max_cycles(mut self, max_cycles: u32) -> Self {
331        self.max_cycles = max_cycles;
332        self
333    }
334
335    #[must_use]
336    pub const fn with_backoff_policy(
337        mut self,
338        retry_delay_ms: u64,
339        backoff_multiplier: f64,
340        max_backoff_ms: u64,
341    ) -> Self {
342        self.retry_delay_ms = retry_delay_ms;
343        self.backoff_multiplier = backoff_multiplier;
344        self.max_backoff_ms = max_backoff_ms;
345        self
346    }
347
348    #[must_use]
349    pub fn current_agent(&self) -> Option<&String> {
350        self.agents.get(self.current_agent_index)
351    }
352
353    /// Stable signature of the current consumer set (agents + configured models + drain).
354    ///
355    /// This is used to dedupe oversize materialization decisions across reducer retries.
356    /// The signature is stable under:
357    /// - switching the current agent/model index
358    /// - retry cycles
359    ///
360    /// It changes only when the configured consumer set changes.
361    #[must_use]
362    pub fn consumer_signature_sha256(&self) -> String {
363        let mut pairs: Vec<(&str, &[String])> = self
364            .agents
365            .iter()
366            .enumerate()
367            .map(|(idx, agent)| {
368                let models: &[String] = self
369                    .models_per_agent
370                    .get(idx)
371                    .map_or([].as_slice(), std::vec::Vec::as_slice);
372                (agent.as_str(), models)
373            })
374            .collect();
375
376        // Sort so the signature is stable even if callers reorder the configured
377        // consumer set.
378        pairs.sort_by(|(agent_a, models_a), (agent_b, models_b)| {
379            use std::cmp::Ordering;
380
381            let agent_ord = agent_a.cmp(agent_b);
382            if agent_ord != Ordering::Equal {
383                return agent_ord;
384            }
385
386            for (a, b) in models_a.iter().zip(models_b.iter()) {
387                let ord = a.cmp(b);
388                if ord != Ordering::Equal {
389                    return ord;
390                }
391            }
392
393            models_a.len().cmp(&models_b.len())
394        });
395
396        let mut hasher = Sha256::new();
397        hasher.update(agent_drain_signature_tag(self.current_drain));
398        for (agent, models) in pairs {
399            hasher.update(agent.as_bytes());
400            hasher.update(b"|");
401            for (idx, model) in models.iter().enumerate() {
402                if idx > 0 {
403                    hasher.update(b",");
404                }
405                hasher.update(model.as_bytes());
406            }
407            hasher.update(b"\n");
408        }
409        let digest = hasher.finalize();
410        digest.iter().fold(String::new(), |mut s, b| {
411            use std::fmt::Write;
412            write!(&mut s, "{b:02x}").unwrap();
413            s
414        })
415    }
416
417    #[cfg(test)]
418    fn legacy_consumer_signature_sha256_for_test(&self) -> String {
419        let mut rendered: Vec<String> = self
420            .agents
421            .iter()
422            .enumerate()
423            .map(|(idx, agent)| {
424                let models = self
425                    .models_per_agent
426                    .get(idx)
427                    .map_or([].as_slice(), std::vec::Vec::as_slice);
428                format!("{}|{}", agent, models.join(","))
429            })
430            .collect();
431
432        rendered.sort();
433
434        let mut hasher = Sha256::new();
435        hasher.update(agent_drain_signature_tag(self.current_drain));
436        for line in rendered {
437            hasher.update(line.as_bytes());
438            hasher.update(b"\n");
439        }
440        let digest = hasher.finalize();
441        digest.iter().fold(String::new(), |mut s, b| {
442            use std::fmt::Write;
443            write!(&mut s, "{b:02x}").unwrap();
444            s
445        })
446    }
447
448    /// Get the currently selected model for the current agent.
449    ///
450    /// Returns `None` if:
451    /// - No models are configured
452    /// - The current agent index is out of bounds
453    /// - The current model index is out of bounds
454    #[must_use]
455    pub fn current_model(&self) -> Option<&String> {
456        self.models_per_agent
457            .get(self.current_agent_index)
458            .and_then(|models| models.get(self.current_model_index))
459    }
460
461    #[must_use]
462    pub const fn is_exhausted(&self) -> bool {
463        self.retry_cycle >= self.max_cycles
464            && self.current_agent_index == 0
465            && self.current_model_index == 0
466    }
467}
468
469#[cfg(test)]
470mod consumer_signature_tests {
471    use super::*;
472
473    #[test]
474    fn test_consumer_signature_sorting_matches_legacy_rendered_pair_ordering() {
475        // This regression test locks in the pre-optimization signature ordering:
476        // sort by the lexicographic ordering of the rendered `agent|models_csv` strings.
477        //
478        // A length-first models compare changes ordering when the first model differs.
479        // Example: "a,z" must sort before "b" even though it is longer.
480        let state = AgentChainState::initial().with_agents(
481            vec!["agent".to_string(), "agent".to_string()],
482            vec![
483                vec!["b".to_string()],
484                vec!["a".to_string(), "z".to_string()],
485            ],
486            AgentRole::Developer,
487        );
488
489        assert_eq!(
490            state.consumer_signature_sha256(),
491            state.legacy_consumer_signature_sha256_for_test(),
492            "consumer signature ordering must remain stable for the same configured consumers"
493        );
494    }
495
496    #[test]
497    fn test_consumer_signature_uses_stable_drain_encoding() {
498        // The consumer signature is persisted in reducer state and used for dedupe.
499        // It must not depend on Debug formatting (variant renames would change the hash).
500        // Instead, it should use a stable, explicit role tag.
501        let state = AgentChainState::initial()
502            .with_agents(
503                vec!["agent-a".to_string()],
504                vec![vec!["m1".to_string(), "m2".to_string()]],
505                AgentRole::Reviewer,
506            )
507            .with_drain(AgentDrain::Fix);
508
509        let mut hasher = Sha256::new();
510        hasher.update(b"fix\n");
511        hasher.update(b"agent-a");
512        hasher.update(b"|");
513        hasher.update(b"m1");
514        hasher.update(b",");
515        hasher.update(b"m2");
516        hasher.update(b"\n");
517        let expected = hasher.finalize().iter().fold(String::new(), |mut acc, b| {
518            use std::fmt::Write;
519            write!(acc, "{b:02x}").unwrap();
520            acc
521        });
522
523        assert_eq!(
524            state.consumer_signature_sha256(),
525            expected,
526            "role encoding must be stable and explicit"
527        );
528    }
529}
530
#[cfg(test)]
mod legacy_rate_limit_prompt_tests {
    use super::*;

    #[test]
    fn test_legacy_rate_limit_continuation_prompt_uses_current_role_on_deserialize() {
        // Older checkpoints stored `rate_limit_continuation_prompt` as a bare string.
        // On resume that prompt has to be scoped to the drain/role the checkpoint
        // was executing, not to a Developer default.
        let reviewer_state = AgentChainState::initial().with_agents(
            vec!["a".to_string()],
            vec![vec![]],
            AgentRole::Reviewer,
        );

        // Serialize the state, then overwrite the prompt with the legacy string form.
        let mut checkpoint =
            serde_json::to_value(&reviewer_state).expect("serialize AgentChainState");
        checkpoint["rate_limit_continuation_prompt"] =
            serde_json::Value::String("legacy prompt".into());

        let checkpoint_json = serde_json::to_string(&checkpoint).expect("serialize JSON value");
        let restored: AgentChainState =
            serde_json::from_str(&checkpoint_json).expect("deserialize AgentChainState");

        let restored_prompt = restored
            .rate_limit_continuation_prompt
            .expect("expected legacy prompt to deserialize");
        assert_eq!(restored_prompt.drain, AgentDrain::Review);
        assert_eq!(restored_prompt.role, AgentRole::Reviewer);
        assert_eq!(restored_prompt.prompt, "legacy prompt");
    }
}