// File: ralph_workflow/reducer/state/agent_chain/mod.rs

// Agent fallback chain state.
//
// Contains AgentChainState and backoff computation helpers.
//
// # Performance Optimization
//
// AgentChainState uses Arc<[T]> for its immutable collections (agents,
// models_per_agent) so that copying the state during a transition is cheap.
// Cloning those fields only bumps a reference count, which makes them O(1)
// to copy instead of requiring an O(n) deep copy.
//
// The reducer creates a new state instance on every event, so this
// optimization significantly reduces memory allocations and improves
// performance.
13
14use std::sync::Arc;
15
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18
19pub use crate::agents::{AgentDrain, AgentRole, DrainMode};
20
21mod backoff;
22mod transitions;
23
24/// Agent fallback chain state (explicit, not loop indices).
25///
26/// Tracks position in the multi-level fallback chain:
27/// - Agent level (primary → fallback1 → fallback2)
28/// - Model level (within each agent, try different models)
29/// - Retry cycle (exhaust all agents, start over with exponential backoff)
30///
31/// # Memory Optimization
32///
33/// Uses Arc<[T]> for `agents` and `models_per_agent` collections to enable
34/// cheap cloning during state transitions. Since these collections are immutable
35/// after construction, `Arc::clone` only increments a reference count instead of
36/// deep copying the entire collection.
37#[derive(Clone, Serialize, Debug)]
38pub struct AgentChainState {
39    /// Agent names in fallback order. Arc<[String]> enables cheap cloning
40    /// via reference counting instead of deep copying the collection.
41    pub agents: Arc<[String]>,
42    pub current_agent_index: usize,
43    /// Models per agent. Arc for immutable outer collection with cheap cloning.
44    /// Inner Vec<String> is kept for runtime indexing during model selection.
45    pub models_per_agent: Arc<[Vec<String>]>,
46    pub current_model_index: usize,
47    pub retry_cycle: u32,
48    pub max_cycles: u32,
49    /// Base delay between retry cycles in milliseconds.
50    #[serde(default = "default_retry_delay_ms")]
51    pub retry_delay_ms: u64,
52    /// Multiplier for exponential backoff.
53    #[serde(default = "default_backoff_multiplier")]
54    pub backoff_multiplier: f64,
55    /// Maximum backoff delay in milliseconds.
56    #[serde(default = "default_max_backoff_ms")]
57    pub max_backoff_ms: u64,
58    /// Pending backoff delay (milliseconds) that must be waited before continuing.
59    #[serde(default)]
60    pub backoff_pending_ms: Option<u64>,
61    /// Compatibility copy of the broad capability role.
62    ///
63    /// Runtime code should treat `current_drain` as authoritative and derive the
64    /// active role from it. This field is retained for checkpoint compatibility
65    /// and diagnostics only.
66    pub current_role: AgentRole,
67    #[serde(default = "default_current_drain")]
68    pub current_drain: AgentDrain,
69    #[serde(default)]
70    pub current_mode: DrainMode,
71    /// Prompt context preserved from a rate-limited agent for continuation.
72    ///
73    /// When an agent hits 429, we save the prompt here so the next agent can
74    /// continue the SAME role/task instead of starting from scratch.
75    ///
76    /// IMPORTANT: This must be role-scoped to prevent cross-task contamination
77    /// (e.g., a developer continuation prompt overriding an analysis prompt).
78    #[serde(default)]
79    pub rate_limit_continuation_prompt: Option<RateLimitContinuationPrompt>,
80    /// Session ID from the last agent response.
81    ///
82    /// Used for XSD retry to continue with the same session when possible.
83    /// Agents that support sessions (e.g., Claude Code) emit session IDs
84    /// that can be passed back for continuation.
85    #[serde(default)]
86    pub last_session_id: Option<String>,
87    /// Last failure reason from the most recent agent failure.
88    ///
89    /// Used to provide context in CLI output when a fallback agent is invoked.
90    /// Cleared on InvocationSucceeded or ChainInitialized.
91    #[serde(default)]
92    pub last_failure_reason: Option<String>,
93}
94
95/// Role-scoped continuation prompt captured from a rate limit (429).
96#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
97pub struct RateLimitContinuationPrompt {
98    pub drain: AgentDrain,
99    pub role: AgentRole,
100    pub prompt: String,
101}
102
103#[derive(Deserialize)]
104#[serde(untagged)]
105enum RateLimitContinuationPromptRepr {
106    LegacyString(String),
107    Structured {
108        #[serde(rename = "role")]
109        _role: AgentRole,
110        #[serde(default)]
111        drain: Option<AgentDrain>,
112        prompt: String,
113    },
114}
115
116fn infer_legacy_current_drain(
117    current_drain: Option<AgentDrain>,
118    current_role: Option<AgentRole>,
119    current_mode: DrainMode,
120    continuation_prompt: Option<&RateLimitContinuationPromptRepr>,
121) -> AgentDrain {
122    if let Some(current_drain) = current_drain {
123        return current_drain;
124    }
125
126    if let Some(prompt_drain) = continuation_prompt.and_then(|prompt| match prompt {
127        RateLimitContinuationPromptRepr::LegacyString(_) => None,
128        RateLimitContinuationPromptRepr::Structured { drain, .. } => *drain,
129    }) {
130        return prompt_drain;
131    }
132
133    match (current_role, current_mode) {
134        (Some(AgentRole::Reviewer), DrainMode::Continuation) => AgentDrain::Fix,
135        (Some(AgentRole::Developer), DrainMode::Continuation) => AgentDrain::Development,
136        (Some(current_role), _) => AgentDrain::from(current_role),
137        (None, _) => default_current_drain(),
138    }
139}
140
141impl<'de> Deserialize<'de> for AgentChainState {
142    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
143    where
144        D: serde::Deserializer<'de>,
145    {
146        #[derive(Deserialize)]
147        struct AgentChainStateSerde {
148            agents: Arc<[String]>,
149            current_agent_index: usize,
150            models_per_agent: Arc<[Vec<String>]>,
151            current_model_index: usize,
152            retry_cycle: u32,
153            max_cycles: u32,
154            #[serde(default = "default_retry_delay_ms")]
155            retry_delay_ms: u64,
156            #[serde(default = "default_backoff_multiplier")]
157            backoff_multiplier: f64,
158            #[serde(default = "default_max_backoff_ms")]
159            max_backoff_ms: u64,
160            #[serde(default)]
161            backoff_pending_ms: Option<u64>,
162            #[serde(default)]
163            current_drain: Option<AgentDrain>,
164            #[serde(default)]
165            current_role: Option<AgentRole>,
166            #[serde(default)]
167            current_mode: DrainMode,
168            #[serde(default)]
169            rate_limit_continuation_prompt: Option<RateLimitContinuationPromptRepr>,
170            #[serde(default)]
171            last_session_id: Option<String>,
172            #[serde(default)]
173            last_failure_reason: Option<String>,
174        }
175
176        let raw = AgentChainStateSerde::deserialize(deserializer)?;
177        let current_drain = infer_legacy_current_drain(
178            raw.current_drain,
179            raw.current_role,
180            raw.current_mode,
181            raw.rate_limit_continuation_prompt.as_ref(),
182        );
183        let current_role = current_drain.role();
184
185        let rate_limit_continuation_prompt = raw.rate_limit_continuation_prompt.map(|repr| {
186            match repr {
187                RateLimitContinuationPromptRepr::LegacyString(prompt) => {
188                    // Legacy checkpoints stored only the prompt string. Scope it to the
189                    // resolved drain role so resume can't cross-contaminate drains.
190                    RateLimitContinuationPrompt {
191                        drain: current_drain,
192                        role: current_role,
193                        prompt,
194                    }
195                }
196                RateLimitContinuationPromptRepr::Structured {
197                    _role: _,
198                    drain,
199                    prompt,
200                } => {
201                    let prompt_drain = drain.unwrap_or(current_drain);
202                    RateLimitContinuationPrompt {
203                        drain: prompt_drain,
204                        role: prompt_drain.role(),
205                        prompt,
206                    }
207                }
208            }
209        });
210
211        Ok(Self {
212            agents: raw.agents,
213            current_agent_index: raw.current_agent_index,
214            models_per_agent: raw.models_per_agent,
215            current_model_index: raw.current_model_index,
216            retry_cycle: raw.retry_cycle,
217            max_cycles: raw.max_cycles,
218            retry_delay_ms: raw.retry_delay_ms,
219            backoff_multiplier: raw.backoff_multiplier,
220            max_backoff_ms: raw.max_backoff_ms,
221            backoff_pending_ms: raw.backoff_pending_ms,
222            current_role,
223            current_drain,
224            current_mode: raw.current_mode,
225            rate_limit_continuation_prompt,
226            last_session_id: raw.last_session_id,
227            last_failure_reason: raw.last_failure_reason,
228        })
229    }
230}
231
232impl Default for AgentChainState {
233    fn default() -> Self {
234        Self {
235            agents: Arc::from(vec![]),
236            current_agent_index: 0,
237            models_per_agent: Arc::from(vec![]),
238            current_model_index: 0,
239            retry_cycle: 0,
240            max_cycles: 3,
241            retry_delay_ms: default_retry_delay_ms(),
242            backoff_multiplier: default_backoff_multiplier(),
243            max_backoff_ms: default_max_backoff_ms(),
244            backoff_pending_ms: None,
245            current_role: AgentRole::Developer,
246            current_drain: default_current_drain(),
247            current_mode: DrainMode::Normal,
248            rate_limit_continuation_prompt: None,
249            last_session_id: None,
250            last_failure_reason: None,
251        }
252    }
253}
254
/// Serde/default value for `retry_delay_ms`: 1000 milliseconds.
const fn default_retry_delay_ms() -> u64 {
    1_000
}
258
/// Serde/default value for `backoff_multiplier`: 2.0.
const fn default_backoff_multiplier() -> f64 {
    2.0_f64
}
262
/// Serde/default value for `max_backoff_ms`: 60000 milliseconds.
const fn default_max_backoff_ms() -> u64 {
    60_000
}
266
267const fn default_current_drain() -> AgentDrain {
268    AgentDrain::Planning
269}
270
271const fn agent_drain_signature_tag(drain: AgentDrain) -> &'static [u8] {
272    match drain {
273        AgentDrain::Planning => b"planning\n",
274        AgentDrain::Development => b"development\n",
275        AgentDrain::Review => b"review\n",
276        AgentDrain::Fix => b"fix\n",
277        AgentDrain::Commit => b"commit\n",
278        AgentDrain::Analysis => b"analysis\n",
279    }
280}
281
282impl AgentChainState {
283    #[must_use]
284    pub fn initial() -> Self {
285        Self {
286            agents: Arc::from(vec![]),
287            current_agent_index: 0,
288            models_per_agent: Arc::from(vec![]),
289            current_model_index: 0,
290            retry_cycle: 0,
291            max_cycles: 3,
292            retry_delay_ms: default_retry_delay_ms(),
293            backoff_multiplier: default_backoff_multiplier(),
294            max_backoff_ms: default_max_backoff_ms(),
295            backoff_pending_ms: None,
296            current_role: AgentRole::Developer,
297            current_drain: default_current_drain(),
298            current_mode: DrainMode::Normal,
299            rate_limit_continuation_prompt: None,
300            last_session_id: None,
301            last_failure_reason: None,
302        }
303    }
304
305    #[must_use]
306    pub fn matches_runtime_drain(&self, runtime_drain: AgentDrain) -> bool {
307        self.current_drain == runtime_drain
308    }
309
310    #[must_use]
311    pub fn with_agents(
312        self,
313        agents: Vec<String>,
314        models_per_agent: Vec<Vec<String>>,
315        role: AgentRole,
316    ) -> Self {
317        let current_drain = match role {
318            AgentRole::Developer => AgentDrain::Development,
319            AgentRole::Reviewer => AgentDrain::Review,
320            AgentRole::Commit => AgentDrain::Commit,
321            AgentRole::Analysis => AgentDrain::Analysis,
322        };
323        Self {
324            agents: Arc::from(agents),
325            models_per_agent: Arc::from(models_per_agent),
326            current_role: role,
327            current_drain,
328            current_mode: DrainMode::Normal,
329            ..self
330        }
331    }
332
333    #[must_use]
334    pub fn with_drain(self, drain: AgentDrain) -> Self {
335        Self {
336            current_drain: drain,
337            current_role: drain.role(),
338            ..self
339        }
340    }
341
342    #[must_use]
343    pub fn with_mode(self, mode: DrainMode) -> Self {
344        Self {
345            current_mode: mode,
346            ..self
347        }
348    }
349
350    #[must_use]
351    pub const fn active_role(&self) -> AgentRole {
352        self.current_drain.role()
353    }
354
355    /// Builder method to set the maximum number of retry cycles.
356    ///
357    /// A retry cycle is when all agents have been exhausted and we start
358    /// over with exponential backoff.
359    #[must_use]
360    pub fn with_max_cycles(self, max_cycles: u32) -> Self {
361        Self { max_cycles, ..self }
362    }
363
364    #[must_use]
365    pub fn with_backoff_policy(
366        self,
367        retry_delay_ms: u64,
368        backoff_multiplier: f64,
369        max_backoff_ms: u64,
370    ) -> Self {
371        Self {
372            retry_delay_ms,
373            backoff_multiplier,
374            max_backoff_ms,
375            ..self
376        }
377    }
378
379    #[must_use]
380    pub fn with_retry_cycle(self, retry_cycle: u32) -> Self {
381        Self {
382            retry_cycle,
383            ..self
384        }
385    }
386
387    #[must_use]
388    pub fn with_current_agent_index(self, current_agent_index: usize) -> Self {
389        Self {
390            current_agent_index,
391            ..self
392        }
393    }
394
395    #[must_use]
396    pub fn current_agent(&self) -> Option<&String> {
397        self.agents.get(self.current_agent_index)
398    }
399
400    /// Stable signature of the current consumer set (agents + configured models + drain).
401    ///
402    /// This is used to dedupe oversize materialization decisions across reducer retries.
403    /// The signature is stable under:
404    /// - switching the current agent/model index
405    /// - retry cycles
406    ///
407    /// It changes only when the configured consumer set changes.
408    #[must_use]
409    pub fn consumer_signature_sha256(&self) -> String {
410        use itertools::Itertools;
411
412        let sorted_pairs: Vec<(String, Vec<String>)> = self
413            .agents
414            .iter()
415            .enumerate()
416            .map(|(idx, agent)| {
417                let models: Vec<String> = self
418                    .models_per_agent
419                    .get(idx)
420                    .map_or_else(Vec::new, |m| m.clone());
421                (agent.clone(), models)
422            })
423            .sorted_by_key(|(agent, models)| (agent.clone(), models.clone()))
424            .collect();
425
426        let update_chain: Vec<Vec<u8>> = sorted_pairs
427            .iter()
428            .map(|(agent, models)| {
429                let models_bytes: Vec<u8> = models
430                    .iter()
431                    .map(|m| m.as_bytes())
432                    .collect::<Vec<_>>()
433                    .join(&b',');
434                let line: Vec<u8> = std::iter::empty()
435                    .chain(agent.as_bytes().iter().copied())
436                    .chain([b'|'])
437                    .chain(models_bytes.iter().copied())
438                    .chain([b'\n'])
439                    .collect();
440                line
441            })
442            .collect();
443
444        let hasher = update_chain.iter().fold(
445            Digest::chain_update(Sha256::new(), agent_drain_signature_tag(self.current_drain)),
446            |h, chunk| Digest::chain_update(h, chunk.as_slice()),
447        );
448        let digest = hasher.finalize();
449        digest
450            .iter()
451            .map(|b| format!("{b:02x}"))
452            .collect::<String>()
453    }
454
455    #[cfg(test)]
456    fn legacy_consumer_signature_sha256_for_test(&self) -> String {
457        use itertools::Itertools;
458
459        let rendered: Vec<String> = self
460            .agents
461            .iter()
462            .enumerate()
463            .map(|(idx, agent)| {
464                let models = self
465                    .models_per_agent
466                    .get(idx)
467                    .map_or([].as_slice(), std::vec::Vec::as_slice);
468                format!(
469                    "{}|{}",
470                    agent,
471                    models
472                        .iter()
473                        .map(|s| s.as_str())
474                        .collect::<Vec<_>>()
475                        .join(",")
476                )
477            })
478            .sorted()
479            .collect();
480
481        let update_chain: Vec<&[u8]> = rendered
482            .iter()
483            .flat_map(|line| [line.as_bytes(), b"\n"])
484            .collect();
485
486        let hasher = update_chain.iter().fold(
487            Digest::chain_update(Sha256::new(), agent_drain_signature_tag(self.current_drain)),
488            |h, chunk| Digest::chain_update(h, *chunk),
489        );
490        let digest = hasher.finalize();
491        digest
492            .iter()
493            .map(|b| format!("{b:02x}"))
494            .collect::<String>()
495    }
496
497    /// Get the currently selected model for the current agent.
498    ///
499    /// Returns `None` if:
500    /// - No models are configured
501    /// - The current agent index is out of bounds
502    /// - The current model index is out of bounds
503    #[must_use]
504    pub fn current_model(&self) -> Option<&String> {
505        self.models_per_agent
506            .get(self.current_agent_index)
507            .and_then(|models| models.get(self.current_model_index))
508    }
509
510    #[must_use]
511    pub const fn is_exhausted(&self) -> bool {
512        self.retry_cycle >= self.max_cycles
513            && self.current_agent_index == 0
514            && self.current_model_index == 0
515    }
516}
517
#[cfg(test)]
mod consumer_signature_tests {
    use super::*;

    /// The production signature must order consumers exactly like the legacy
    /// implementation: by the rendered `agent|models_csv` text.
    #[test]
    fn test_consumer_signature_sorting_matches_legacy_rendered_pair_ordering() {
        // Two entries share the agent name, so only the models decide the order.
        // A length-first comparison of the model lists would put ["b"] ahead of
        // ["a", "z"]; the rendered form "a,z" must nevertheless sort before "b".
        let agents = vec![String::from("agent"), String::from("agent")];
        let models_per_agent = vec![
            vec![String::from("b")],
            vec![String::from("a"), String::from("z")],
        ];
        let state =
            AgentChainState::initial().with_agents(agents, models_per_agent, AgentRole::Developer);

        let actual = state.consumer_signature_sha256();
        let legacy = state.legacy_consumer_signature_sha256_for_test();

        assert_eq!(
            actual, legacy,
            "consumer signature ordering must remain stable for the same configured consumers"
        );
    }

    /// The drain is hashed through its explicit textual tag, so the digest can
    /// be reproduced from a hand-written byte string.
    #[test]
    fn test_consumer_signature_uses_stable_drain_encoding() {
        let configured = AgentChainState::initial().with_agents(
            vec![String::from("agent-a")],
            vec![vec![String::from("m1"), String::from("m2")]],
            AgentRole::Reviewer,
        );
        let state = configured.with_drain(AgentDrain::Fix);

        // Drain tag line followed by the single rendered consumer line.
        let expected: String = Sha256::digest(b"fix\nagent-a|m1,m2\n")
            .iter()
            .map(|b| format!("{b:02x}"))
            .collect();

        assert_eq!(
            state.consumer_signature_sha256(),
            expected,
            "role encoding must be stable and explicit"
        );
    }
}
571
#[cfg(test)]
mod legacy_rate_limit_prompt_tests {
    use super::*;

    /// A bare-string continuation prompt from a legacy checkpoint must be
    /// scoped to the drain/role the checkpoint was executing, not to a default.
    #[test]
    fn test_legacy_rate_limit_continuation_prompt_uses_current_role_on_deserialize() {
        let reviewer_state = AgentChainState::initial().with_agents(
            vec![String::from("a")],
            vec![Vec::new()],
            AgentRole::Reviewer,
        );

        // Serialize the state, then overwrite the prompt with the legacy shape
        // (a plain JSON string instead of a structured object).
        let mut checkpoint =
            serde_json::to_value(&reviewer_state).expect("serialize AgentChainState");
        checkpoint["rate_limit_continuation_prompt"] = serde_json::Value::from("legacy prompt");

        // Round-trip through text, as a checkpoint read back from disk would be.
        let text = serde_json::to_string(&checkpoint).expect("serialize JSON value");
        let restored: AgentChainState =
            serde_json::from_str(&text).expect("deserialize AgentChainState");

        let restored_prompt = restored
            .rate_limit_continuation_prompt
            .expect("expected legacy prompt to deserialize");
        assert_eq!(restored_prompt.drain, AgentDrain::Review);
        assert_eq!(restored_prompt.role, AgentRole::Reviewer);
        assert_eq!(restored_prompt.prompt, "legacy prompt");
    }
}
601}