// ralph_workflow/reducer/state/agent_chain/mod.rs

// Agent fallback chain state.
//
// Contains AgentChainState and backoff computation helpers.
//
// # Performance Optimization
//
// AgentChainState uses Arc<[T]> for immutable collections (agents, models_per_agent)
// to enable cheap state copying during state transitions. This eliminates O(n) deep
// copy overhead and makes state transitions O(1) for collection fields.
//
// The reducer creates new state instances on every event, so this optimization
// significantly reduces memory allocations and improves performance.

use std::sync::Arc;

use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};

pub use crate::agents::{AgentDrain, AgentRole, DrainMode};

mod backoff;
mod transitions;
/// Agent fallback chain state (explicit, not loop indices).
///
/// Tracks position in the multi-level fallback chain:
/// - Agent level (primary → fallback1 → fallback2)
/// - Model level (within each agent, try different models)
/// - Retry cycle (exhaust all agents, start over with exponential backoff)
///
/// # Memory Optimization
///
/// Uses Arc<[T]> for `agents` and `models_per_agent` collections to enable
/// cheap cloning during state transitions. Since these collections are immutable
/// after construction, `Arc::clone` only increments a reference count instead of
/// deep copying the entire collection.
///
/// Serialization note: `Deserialize` is implemented manually (below) to
/// migrate legacy checkpoints that lack `current_drain` or store the
/// continuation prompt as a bare string.
#[derive(Clone, Serialize, Debug)]
pub struct AgentChainState {
    /// Agent names in fallback order. Arc<[String]> enables cheap cloning
    /// via reference counting instead of deep copying the collection.
    pub agents: Arc<[String]>,
    /// Index into `agents` of the currently active agent.
    pub current_agent_index: usize,
    /// Models per agent. Arc for immutable outer collection with cheap cloning.
    /// Inner Vec<String> is kept for runtime indexing during model selection.
    pub models_per_agent: Arc<[Vec<String>]>,
    /// Index into the current agent's model list.
    pub current_model_index: usize,
    /// Number of retry cycles completed (all agents exhausted → start over).
    pub retry_cycle: u32,
    /// Maximum retry cycles before the chain is considered exhausted.
    pub max_cycles: u32,
    /// Base delay between retry cycles in milliseconds.
    #[serde(default = "default_retry_delay_ms")]
    pub retry_delay_ms: u64,
    /// Multiplier for exponential backoff.
    #[serde(default = "default_backoff_multiplier")]
    pub backoff_multiplier: f64,
    /// Maximum backoff delay in milliseconds.
    #[serde(default = "default_max_backoff_ms")]
    pub max_backoff_ms: u64,
    /// Pending backoff delay (milliseconds) that must be waited before continuing.
    #[serde(default)]
    pub backoff_pending_ms: Option<u64>,
    /// Compatibility copy of the broad capability role.
    ///
    /// Runtime code should treat `current_drain` as authoritative and derive the
    /// active role from it. This field is retained for checkpoint compatibility
    /// and diagnostics only.
    pub current_role: AgentRole,
    /// Authoritative drain (task category) the chain is currently serving.
    /// Defaults to Planning when absent from a legacy checkpoint.
    #[serde(default = "default_current_drain")]
    pub current_drain: AgentDrain,
    /// Mode of the current drain (Normal vs. Continuation).
    #[serde(default)]
    pub current_mode: DrainMode,
    /// Prompt context preserved from a rate-limited agent for continuation.
    ///
    /// When an agent hits 429, we save the prompt here so the next agent can
    /// continue the SAME role/task instead of starting from scratch.
    ///
    /// IMPORTANT: This must be role-scoped to prevent cross-task contamination
    /// (e.g., a developer continuation prompt overriding an analysis prompt).
    #[serde(default)]
    pub rate_limit_continuation_prompt: Option<RateLimitContinuationPrompt>,
    /// Session ID from the last agent response.
    ///
    /// Used for XSD retry to continue with the same session when possible.
    /// Agents that support sessions (e.g., Claude Code) emit session IDs
    /// that can be passed back for continuation.
    #[serde(default)]
    pub last_session_id: Option<String>,
}

/// Role-scoped continuation prompt captured from a rate limit (429).
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct RateLimitContinuationPrompt {
    /// Drain this prompt belongs to; prevents cross-drain reuse on resume.
    pub drain: AgentDrain,
    /// Role derived from `drain` during deserialization (compat/diagnostics).
    pub role: AgentRole,
    /// The saved prompt text the next agent should continue with.
    pub prompt: String,
}

/// Wire representation for `rate_limit_continuation_prompt`.
///
/// Accepts two checkpoint formats via `untagged`:
/// - Legacy: a bare prompt string (no role/drain scoping).
/// - Structured: `{ role, drain?, prompt }` as serialized today.
#[derive(Deserialize)]
#[serde(untagged)]
enum RateLimitContinuationPromptRepr {
    /// Legacy checkpoints stored only the prompt string.
    LegacyString(String),
    /// Structured form. `role` is decoded but ignored (`_role`); the optional
    /// `drain`, when present, is authoritative for scoping.
    Structured {
        #[serde(rename = "role")]
        _role: AgentRole,
        #[serde(default)]
        drain: Option<AgentDrain>,
        prompt: String,
    },
}

110fn infer_legacy_current_drain(
111    current_drain: Option<AgentDrain>,
112    current_role: Option<AgentRole>,
113    current_mode: DrainMode,
114    continuation_prompt: Option<&RateLimitContinuationPromptRepr>,
115) -> AgentDrain {
116    if let Some(current_drain) = current_drain {
117        return current_drain;
118    }
119
120    if let Some(prompt_drain) = continuation_prompt.and_then(|prompt| match prompt {
121        RateLimitContinuationPromptRepr::LegacyString(_) => None,
122        RateLimitContinuationPromptRepr::Structured { drain, .. } => *drain,
123    }) {
124        return prompt_drain;
125    }
126
127    match (current_role, current_mode) {
128        (Some(AgentRole::Reviewer), DrainMode::Continuation) => AgentDrain::Fix,
129        (Some(AgentRole::Developer), DrainMode::Continuation) => AgentDrain::Development,
130        (Some(current_role), _) => AgentDrain::from(current_role),
131        (None, _) => default_current_drain(),
132    }
133}
134
impl<'de> Deserialize<'de> for AgentChainState {
    /// Custom deserialization with legacy-checkpoint migration.
    ///
    /// - `current_drain` may be absent in old checkpoints; it is inferred from
    ///   the continuation prompt and/or the legacy `current_role`/`current_mode`.
    /// - `rate_limit_continuation_prompt` may be a bare string (legacy) or a
    ///   structured object; both are normalized to `RateLimitContinuationPrompt`.
    /// - `current_role` is always re-derived from the resolved drain so the two
    ///   fields can never disagree after load.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Mirror of `AgentChainState` where the migration-sensitive fields
        // (`current_drain`, `current_role`, the continuation prompt) are
        // decoded permissively so old checkpoints still load.
        #[derive(Deserialize)]
        struct AgentChainStateSerde {
            agents: Arc<[String]>,
            current_agent_index: usize,
            models_per_agent: Arc<[Vec<String>]>,
            current_model_index: usize,
            retry_cycle: u32,
            max_cycles: u32,
            #[serde(default = "default_retry_delay_ms")]
            retry_delay_ms: u64,
            #[serde(default = "default_backoff_multiplier")]
            backoff_multiplier: f64,
            #[serde(default = "default_max_backoff_ms")]
            max_backoff_ms: u64,
            #[serde(default)]
            backoff_pending_ms: Option<u64>,
            #[serde(default)]
            current_drain: Option<AgentDrain>,
            #[serde(default)]
            current_role: Option<AgentRole>,
            #[serde(default)]
            current_mode: DrainMode,
            #[serde(default)]
            rate_limit_continuation_prompt: Option<RateLimitContinuationPromptRepr>,
            #[serde(default)]
            last_session_id: Option<String>,
        }

        let raw = AgentChainStateSerde::deserialize(deserializer)?;
        // Resolve the authoritative drain, inferring it for legacy checkpoints
        // that only carried role/mode (see `infer_legacy_current_drain`).
        let current_drain = infer_legacy_current_drain(
            raw.current_drain,
            raw.current_role,
            raw.current_mode,
            raw.rate_limit_continuation_prompt.as_ref(),
        );
        // Discard the stored compatibility role and re-derive from the drain.
        let current_role = current_drain.role();

        let rate_limit_continuation_prompt = raw.rate_limit_continuation_prompt.map(|repr| {
            match repr {
                RateLimitContinuationPromptRepr::LegacyString(prompt) => {
                    // Legacy checkpoints stored only the prompt string. Scope it to the
                    // resolved drain role so resume can't cross-contaminate drains.
                    RateLimitContinuationPrompt {
                        drain: current_drain,
                        role: current_role,
                        prompt,
                    }
                }
                RateLimitContinuationPromptRepr::Structured {
                    _role: _,
                    drain,
                    prompt,
                } => {
                    // Structured prompts may predate the `drain` field; fall
                    // back to the chain-level drain when it is absent.
                    let prompt_drain = drain.unwrap_or(current_drain);
                    RateLimitContinuationPrompt {
                        drain: prompt_drain,
                        role: prompt_drain.role(),
                        prompt,
                    }
                }
            }
        });

        Ok(Self {
            agents: raw.agents,
            current_agent_index: raw.current_agent_index,
            models_per_agent: raw.models_per_agent,
            current_model_index: raw.current_model_index,
            retry_cycle: raw.retry_cycle,
            max_cycles: raw.max_cycles,
            retry_delay_ms: raw.retry_delay_ms,
            backoff_multiplier: raw.backoff_multiplier,
            max_backoff_ms: raw.max_backoff_ms,
            backoff_pending_ms: raw.backoff_pending_ms,
            current_role,
            current_drain,
            current_mode: raw.current_mode,
            rate_limit_continuation_prompt,
            last_session_id: raw.last_session_id,
        })
    }
}

/// Default base delay between retry cycles: one second.
const fn default_retry_delay_ms() -> u64 {
    1_000
}

/// Default exponential backoff multiplier: the delay doubles each cycle.
const fn default_backoff_multiplier() -> f64 {
    2.0
}

/// Default backoff ceiling: one minute.
const fn default_max_backoff_ms() -> u64 {
    60_000
}

/// Drain used when a checkpoint predates the `current_drain` field and no
/// better value can be inferred: start at Planning.
const fn default_current_drain() -> AgentDrain {
    AgentDrain::Planning
}

/// Stable byte tag for a drain, hashed into `consumer_signature_sha256`.
///
/// These signatures are persisted in reducer state and used for dedupe, so the
/// tags must never change for existing variants — renaming an enum variant
/// must not affect them (which is why Debug formatting is not used).
const fn agent_drain_signature_tag(drain: AgentDrain) -> &'static [u8] {
    match drain {
        AgentDrain::Planning => b"planning\n",
        AgentDrain::Development => b"development\n",
        AgentDrain::Review => b"review\n",
        AgentDrain::Fix => b"fix\n",
        AgentDrain::Commit => b"commit\n",
        AgentDrain::Analysis => b"analysis\n",
    }
}

impl AgentChainState {
    /// Empty initial state: no agents/models configured, all indices and
    /// counters at zero, default backoff policy, Planning drain, Normal mode.
    ///
    /// NOTE(review): `current_role` starts as Developer while `current_drain`
    /// starts as Planning. Since `current_drain` is authoritative and
    /// `current_role` is compatibility-only, this mismatch looks intentional —
    /// confirm against callers.
    #[must_use]
    pub fn initial() -> Self {
        Self {
            agents: Arc::from(vec![]),
            current_agent_index: 0,
            models_per_agent: Arc::from(vec![]),
            current_model_index: 0,
            retry_cycle: 0,
            max_cycles: 3,
            retry_delay_ms: default_retry_delay_ms(),
            backoff_multiplier: default_backoff_multiplier(),
            max_backoff_ms: default_max_backoff_ms(),
            backoff_pending_ms: None,
            current_role: AgentRole::Developer,
            current_drain: default_current_drain(),
            current_mode: DrainMode::Normal,
            rate_limit_continuation_prompt: None,
            last_session_id: None,
        }
    }

    /// True when the chain's authoritative drain equals `runtime_drain`.
    #[must_use]
    pub fn matches_runtime_drain(&self, runtime_drain: AgentDrain) -> bool {
        self.current_drain == runtime_drain
    }

    /// Builder: replace the configured agents/models and set the drain (and
    /// compatibility role) implied by `role`, resetting the mode to Normal.
    ///
    /// Indices and retry counters are left untouched.
    #[must_use]
    pub fn with_agents(
        mut self,
        agents: Vec<String>,
        models_per_agent: Vec<Vec<String>>,
        role: AgentRole,
    ) -> Self {
        self.agents = Arc::from(agents);
        self.models_per_agent = Arc::from(models_per_agent);
        self.current_role = role;
        // Map the broad role onto its default drain.
        self.current_drain = match role {
            AgentRole::Developer => AgentDrain::Development,
            AgentRole::Reviewer => AgentDrain::Review,
            AgentRole::Commit => AgentDrain::Commit,
            AgentRole::Analysis => AgentDrain::Analysis,
        };
        self.current_mode = DrainMode::Normal;
        self
    }

    /// Builder: set the authoritative drain, keeping `current_role` in sync.
    #[must_use]
    pub const fn with_drain(mut self, drain: AgentDrain) -> Self {
        self.current_drain = drain;
        self.current_role = drain.role();
        self
    }

    /// Builder: set the drain mode.
    #[must_use]
    pub const fn with_mode(mut self, mode: DrainMode) -> Self {
        self.current_mode = mode;
        self
    }

    /// Role implied by the authoritative drain; prefer this over the stored
    /// `current_role` compatibility field.
    #[must_use]
    pub const fn active_role(&self) -> AgentRole {
        self.current_drain.role()
    }

    /// Builder method to set the maximum number of retry cycles.
    ///
    /// A retry cycle is when all agents have been exhausted and we start
    /// over with exponential backoff.
    #[must_use]
    pub const fn with_max_cycles(mut self, max_cycles: u32) -> Self {
        self.max_cycles = max_cycles;
        self
    }

    /// Builder: set the full backoff policy (base delay, multiplier, ceiling).
    #[must_use]
    pub const fn with_backoff_policy(
        mut self,
        retry_delay_ms: u64,
        backoff_multiplier: f64,
        max_backoff_ms: u64,
    ) -> Self {
        self.retry_delay_ms = retry_delay_ms;
        self.backoff_multiplier = backoff_multiplier;
        self.max_backoff_ms = max_backoff_ms;
        self
    }

    /// Name of the currently selected agent, or `None` when the index is out
    /// of bounds (e.g. no agents configured).
    #[must_use]
    pub fn current_agent(&self) -> Option<&String> {
        self.agents.get(self.current_agent_index)
    }

    /// Stable signature of the current consumer set (agents + configured models + drain).
    ///
    /// This is used to dedupe oversize materialization decisions across reducer retries.
    /// The signature is stable under:
    /// - switching the current agent/model index
    /// - retry cycles
    ///
    /// It changes only when the configured consumer set changes.
    #[must_use]
    pub fn consumer_signature_sha256(&self) -> String {
        // Pair each agent with its model slice; a missing models entry
        // contributes an empty slice.
        let mut pairs: Vec<(&str, &[String])> = self
            .agents
            .iter()
            .enumerate()
            .map(|(idx, agent)| {
                let models: &[String] = self
                    .models_per_agent
                    .get(idx)
                    .map_or([].as_slice(), std::vec::Vec::as_slice);
                (agent.as_str(), models)
            })
            .collect();

        // Sort so the signature is stable even if callers reorder the configured
        // consumer set.
        pairs.sort_by(|(agent_a, models_a), (agent_b, models_b)| {
            use std::cmp::Ordering;

            let agent_ord = agent_a.cmp(agent_b);
            if agent_ord != Ordering::Equal {
                return agent_ord;
            }

            // Element-wise model comparison first, length only as a tie-break:
            // this reproduces the lexicographic order of the legacy rendered
            // `agent|models_csv` strings (see regression test below) without
            // allocating those strings.
            for (a, b) in models_a.iter().zip(models_b.iter()) {
                let ord = a.cmp(b);
                if ord != Ordering::Equal {
                    return ord;
                }
            }

            models_a.len().cmp(&models_b.len())
        });

        // Hash layout: one drain-tag line, then one `agent|model,model\n`
        // line per consumer. Byte order is load-bearing — it is pinned by
        // the tests in `consumer_signature_tests`.
        let mut hasher = Sha256::new();
        hasher.update(agent_drain_signature_tag(self.current_drain));
        for (agent, models) in pairs {
            hasher.update(agent.as_bytes());
            hasher.update(b"|");
            for (idx, model) in models.iter().enumerate() {
                if idx > 0 {
                    hasher.update(b",");
                }
                hasher.update(model.as_bytes());
            }
            hasher.update(b"\n");
        }
        let digest = hasher.finalize();
        // Lowercase hex-encode the digest.
        digest.iter().fold(String::new(), |mut s, b| {
            use std::fmt::Write;
            write!(&mut s, "{b:02x}").unwrap();
            s
        })
    }

    /// Pre-optimization signature implementation (renders intermediate
    /// `agent|models_csv` strings and sorts those). Kept only so tests can
    /// pin that the allocation-free version above hashes identical bytes.
    #[cfg(test)]
    fn legacy_consumer_signature_sha256_for_test(&self) -> String {
        let mut rendered: Vec<String> = self
            .agents
            .iter()
            .enumerate()
            .map(|(idx, agent)| {
                let models = self
                    .models_per_agent
                    .get(idx)
                    .map_or([].as_slice(), std::vec::Vec::as_slice);
                format!("{}|{}", agent, models.join(","))
            })
            .collect();

        rendered.sort();

        let mut hasher = Sha256::new();
        hasher.update(agent_drain_signature_tag(self.current_drain));
        for line in rendered {
            hasher.update(line.as_bytes());
            hasher.update(b"\n");
        }
        let digest = hasher.finalize();
        digest.iter().fold(String::new(), |mut s, b| {
            use std::fmt::Write;
            write!(&mut s, "{b:02x}").unwrap();
            s
        })
    }

    /// Get the currently selected model for the current agent.
    ///
    /// Returns `None` if:
    /// - No models are configured
    /// - The current agent index is out of bounds
    /// - The current model index is out of bounds
    #[must_use]
    pub fn current_model(&self) -> Option<&String> {
        self.models_per_agent
            .get(self.current_agent_index)
            .and_then(|models| models.get(self.current_model_index))
    }

    /// True when every retry cycle has been consumed AND the chain has wrapped
    /// back to the first agent/model — i.e. nothing is left to try.
    #[must_use]
    pub const fn is_exhausted(&self) -> bool {
        self.retry_cycle >= self.max_cycles
            && self.current_agent_index == 0
            && self.current_model_index == 0
    }
}

#[cfg(test)]
mod consumer_signature_tests {
    use super::*;

    #[test]
    fn test_consumer_signature_sorting_matches_legacy_rendered_pair_ordering() {
        // This regression test locks in the pre-optimization signature ordering:
        // sort by the lexicographic ordering of the rendered `agent|models_csv` strings.
        //
        // A length-first models compare changes ordering when the first model differs.
        // Example: "a,z" must sort before "b" even though it is longer.
        let state = AgentChainState::initial().with_agents(
            vec!["agent".to_string(), "agent".to_string()],
            vec![
                vec!["b".to_string()],
                vec!["a".to_string(), "z".to_string()],
            ],
            AgentRole::Developer,
        );

        // The allocation-free implementation must hash the same bytes in the
        // same order as the legacy string-rendering implementation.
        assert_eq!(
            state.consumer_signature_sha256(),
            state.legacy_consumer_signature_sha256_for_test(),
            "consumer signature ordering must remain stable for the same configured consumers"
        );
    }

    #[test]
    fn test_consumer_signature_uses_stable_drain_encoding() {
        // The consumer signature is persisted in reducer state and used for dedupe.
        // It must not depend on Debug formatting (variant renames would change the hash).
        // Instead, it should use a stable, explicit role tag.
        let state = AgentChainState::initial()
            .with_agents(
                vec!["agent-a".to_string()],
                vec![vec!["m1".to_string(), "m2".to_string()]],
                AgentRole::Reviewer,
            )
            .with_drain(AgentDrain::Fix);

        // Re-derive the expected digest byte-for-byte: drain tag line, then
        // one `agent|model,model\n` line.
        let mut hasher = Sha256::new();
        hasher.update(b"fix\n");
        hasher.update(b"agent-a");
        hasher.update(b"|");
        hasher.update(b"m1");
        hasher.update(b",");
        hasher.update(b"m2");
        hasher.update(b"\n");
        let expected = hasher.finalize().iter().fold(String::new(), |mut acc, b| {
            use std::fmt::Write;
            write!(acc, "{b:02x}").unwrap();
            acc
        });

        assert_eq!(
            state.consumer_signature_sha256(),
            expected,
            "role encoding must be stable and explicit"
        );
    }
}

#[cfg(test)]
mod legacy_rate_limit_prompt_tests {
    use super::*;

    #[test]
    fn test_legacy_rate_limit_continuation_prompt_uses_current_role_on_deserialize() {
        // Legacy checkpoints stored `rate_limit_continuation_prompt` as a bare string.
        // When resuming, we must scope that prompt to the chain's `current_role`
        // (the role the checkpoint was executing) instead of defaulting to Developer.
        let state = AgentChainState::initial().with_agents(
            vec!["a".to_string()],
            vec![vec![]],
            AgentRole::Reviewer,
        );

        // Simulate a legacy checkpoint by downgrading the structured prompt
        // field to a bare JSON string before round-tripping.
        let mut v = serde_json::to_value(&state).expect("serialize AgentChainState");
        v["rate_limit_continuation_prompt"] = serde_json::Value::String("legacy prompt".into());

        let json = serde_json::to_string(&v).expect("serialize JSON value");
        let decoded: AgentChainState =
            serde_json::from_str(&json).expect("deserialize AgentChainState");

        // The bare string must come back scoped to the Review drain/role.
        let prompt = decoded
            .rate_limit_continuation_prompt
            .expect("expected legacy prompt to deserialize");
        assert_eq!(prompt.drain, AgentDrain::Review);
        assert_eq!(prompt.role, AgentRole::Reviewer);
        assert_eq!(prompt.prompt, "legacy prompt");
    }
}