// ralph_workflow/reducer/state/agent_chain/transitions.rs

// State transition methods for `AgentChainState`.
//
// These methods implement the fallback chain progression: advancing models,
// switching agents, and starting retry cycles with backoff. Every transition
// takes `&self` and returns a new state, leaving the original untouched.

use std::sync::Arc;

use super::backoff::calculate_backoff_delay_ms;
use super::{AgentChainState, AgentRole, RateLimitContinuationPrompt};

11impl AgentChainState {
12    #[must_use]
13    pub fn advance_to_next_model(&self) -> Self {
14        let start_agent_index = self.current_agent_index;
15
16        // When models are configured, we try each model for the current agent once.
17        // If the models list is exhausted, advance to the next agent/retry cycle
18        // instead of looping models indefinitely.
19        let mut next = match self.models_per_agent.get(self.current_agent_index) {
20            Some(models) if !models.is_empty() => {
21                if self.current_model_index + 1 < models.len() {
22                    // Simple model advance - only increment model index
23                    Self {
24                        agents: Arc::clone(&self.agents),
25                        current_agent_index: self.current_agent_index,
26                        models_per_agent: Arc::clone(&self.models_per_agent),
27                        current_model_index: self.current_model_index + 1,
28                        retry_cycle: self.retry_cycle,
29                        max_cycles: self.max_cycles,
30                        retry_delay_ms: self.retry_delay_ms,
31                        backoff_multiplier: self.backoff_multiplier,
32                        max_backoff_ms: self.max_backoff_ms,
33                        backoff_pending_ms: self.backoff_pending_ms,
34                        current_role: self.current_role,
35                        rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
36                        last_session_id: self.last_session_id.clone(),
37                    }
38                } else {
39                    self.switch_to_next_agent()
40                }
41            }
42            _ => self.switch_to_next_agent(),
43        };
44
45        if next.current_agent_index != start_agent_index {
46            next.last_session_id = None;
47        }
48
49        next
50    }
51
52    #[must_use]
53    pub fn switch_to_next_agent(&self) -> Self {
54        if self.current_agent_index + 1 < self.agents.len() {
55            // Advance to next agent
56            Self {
57                agents: Arc::clone(&self.agents),
58                current_agent_index: self.current_agent_index + 1,
59                models_per_agent: Arc::clone(&self.models_per_agent),
60                current_model_index: 0,
61                retry_cycle: self.retry_cycle,
62                max_cycles: self.max_cycles,
63                retry_delay_ms: self.retry_delay_ms,
64                backoff_multiplier: self.backoff_multiplier,
65                max_backoff_ms: self.max_backoff_ms,
66                backoff_pending_ms: None,
67                current_role: self.current_role,
68                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
69                last_session_id: self.last_session_id.clone(),
70            }
71        } else {
72            // Wrap around to first agent and increment retry cycle
73            let new_retry_cycle = self.retry_cycle + 1;
74            let new_backoff_pending_ms = if new_retry_cycle >= self.max_cycles {
75                None
76            } else {
77                // Create temporary state to calculate backoff
78                let temp = Self {
79                    agents: Arc::clone(&self.agents),
80                    current_agent_index: 0,
81                    models_per_agent: Arc::clone(&self.models_per_agent),
82                    current_model_index: 0,
83                    retry_cycle: new_retry_cycle,
84                    max_cycles: self.max_cycles,
85                    retry_delay_ms: self.retry_delay_ms,
86                    backoff_multiplier: self.backoff_multiplier,
87                    max_backoff_ms: self.max_backoff_ms,
88                    backoff_pending_ms: None,
89                    current_role: self.current_role,
90                    rate_limit_continuation_prompt: None,
91                    last_session_id: None,
92                };
93                Some(temp.calculate_backoff_delay_ms_for_retry_cycle())
94            };
95
96            Self {
97                agents: Arc::clone(&self.agents),
98                current_agent_index: 0,
99                models_per_agent: Arc::clone(&self.models_per_agent),
100                current_model_index: 0,
101                retry_cycle: new_retry_cycle,
102                max_cycles: self.max_cycles,
103                retry_delay_ms: self.retry_delay_ms,
104                backoff_multiplier: self.backoff_multiplier,
105                max_backoff_ms: self.max_backoff_ms,
106                backoff_pending_ms: new_backoff_pending_ms,
107                current_role: self.current_role,
108                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
109                last_session_id: self.last_session_id.clone(),
110            }
111        }
112    }
113
114    /// Switch to a specific agent by name.
115    ///
116    /// If `to_agent` is unknown, falls back to `switch_to_next_agent()` to keep the
117    /// reducer deterministic.
118    #[must_use]
119    pub fn switch_to_agent_named(&self, to_agent: &str) -> Self {
120        let Some(target_index) = self.agents.iter().position(|a| a == to_agent) else {
121            return self.switch_to_next_agent();
122        };
123
124        if target_index == self.current_agent_index {
125            // Same agent - just reset model index
126            return Self {
127                agents: Arc::clone(&self.agents),
128                current_agent_index: self.current_agent_index,
129                models_per_agent: Arc::clone(&self.models_per_agent),
130                current_model_index: 0,
131                retry_cycle: self.retry_cycle,
132                max_cycles: self.max_cycles,
133                retry_delay_ms: self.retry_delay_ms,
134                backoff_multiplier: self.backoff_multiplier,
135                max_backoff_ms: self.max_backoff_ms,
136                backoff_pending_ms: None,
137                current_role: self.current_role,
138                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
139                last_session_id: self.last_session_id.clone(),
140            };
141        }
142
143        if target_index <= self.current_agent_index {
144            // Treat switching to an earlier agent as starting a new retry cycle.
145            let new_retry_cycle = self.retry_cycle + 1;
146            let new_backoff_pending_ms = if new_retry_cycle >= self.max_cycles && target_index == 0
147            {
148                None
149            } else {
150                // Create temporary state to calculate backoff
151                let temp = Self {
152                    agents: Arc::clone(&self.agents),
153                    current_agent_index: target_index,
154                    models_per_agent: Arc::clone(&self.models_per_agent),
155                    current_model_index: 0,
156                    retry_cycle: new_retry_cycle,
157                    max_cycles: self.max_cycles,
158                    retry_delay_ms: self.retry_delay_ms,
159                    backoff_multiplier: self.backoff_multiplier,
160                    max_backoff_ms: self.max_backoff_ms,
161                    backoff_pending_ms: None,
162                    current_role: self.current_role,
163                    rate_limit_continuation_prompt: None,
164                    last_session_id: None,
165                };
166                Some(temp.calculate_backoff_delay_ms_for_retry_cycle())
167            };
168
169            Self {
170                agents: Arc::clone(&self.agents),
171                current_agent_index: target_index,
172                models_per_agent: Arc::clone(&self.models_per_agent),
173                current_model_index: 0,
174                retry_cycle: new_retry_cycle,
175                max_cycles: self.max_cycles,
176                retry_delay_ms: self.retry_delay_ms,
177                backoff_multiplier: self.backoff_multiplier,
178                max_backoff_ms: self.max_backoff_ms,
179                backoff_pending_ms: new_backoff_pending_ms,
180                current_role: self.current_role,
181                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
182                last_session_id: self.last_session_id.clone(),
183            }
184        } else {
185            // Advancing to later agent
186            Self {
187                agents: Arc::clone(&self.agents),
188                current_agent_index: target_index,
189                models_per_agent: Arc::clone(&self.models_per_agent),
190                current_model_index: 0,
191                retry_cycle: self.retry_cycle,
192                max_cycles: self.max_cycles,
193                retry_delay_ms: self.retry_delay_ms,
194                backoff_multiplier: self.backoff_multiplier,
195                max_backoff_ms: self.max_backoff_ms,
196                backoff_pending_ms: None,
197                current_role: self.current_role,
198                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
199                last_session_id: self.last_session_id.clone(),
200            }
201        }
202    }
203
204    /// Switch to next agent after rate limit, preserving prompt for continuation.
205    ///
206    /// This is used when an agent hits a 429 rate limit error. Instead of
207    /// retrying with the same agent (which would likely hit rate limits again),
208    /// we switch to the next agent and preserve the prompt so the new agent
209    /// can continue the same work.
210    #[must_use]
211    pub fn switch_to_next_agent_with_prompt(&self, prompt: Option<String>) -> Self {
212        let base = self.switch_to_next_agent();
213        // Back-compat: older callers didn't track role. Preserve prompt only.
214        Self {
215            agents: base.agents,
216            current_agent_index: base.current_agent_index,
217            models_per_agent: base.models_per_agent,
218            current_model_index: base.current_model_index,
219            retry_cycle: base.retry_cycle,
220            max_cycles: base.max_cycles,
221            retry_delay_ms: base.retry_delay_ms,
222            backoff_multiplier: base.backoff_multiplier,
223            max_backoff_ms: base.max_backoff_ms,
224            backoff_pending_ms: base.backoff_pending_ms,
225            current_role: base.current_role,
226            rate_limit_continuation_prompt: prompt.map(|p| RateLimitContinuationPrompt {
227                role: base.current_role,
228                prompt: p,
229            }),
230            last_session_id: base.last_session_id,
231        }
232    }
233
234    /// Switch to next agent after rate limit, preserving prompt for continuation (role-scoped).
235    #[must_use]
236    pub fn switch_to_next_agent_with_prompt_for_role(
237        &self,
238        role: AgentRole,
239        prompt: Option<String>,
240    ) -> Self {
241        let base = self.switch_to_next_agent();
242        Self {
243            agents: base.agents,
244            current_agent_index: base.current_agent_index,
245            models_per_agent: base.models_per_agent,
246            current_model_index: base.current_model_index,
247            retry_cycle: base.retry_cycle,
248            max_cycles: base.max_cycles,
249            retry_delay_ms: base.retry_delay_ms,
250            backoff_multiplier: base.backoff_multiplier,
251            max_backoff_ms: base.max_backoff_ms,
252            backoff_pending_ms: base.backoff_pending_ms,
253            current_role: base.current_role,
254            rate_limit_continuation_prompt: prompt
255                .map(|p| RateLimitContinuationPrompt { role, prompt: p }),
256            last_session_id: base.last_session_id,
257        }
258    }
259
260    /// Clear continuation prompt after successful execution.
261    ///
262    /// Called when an agent successfully completes its task, clearing any
263    /// saved prompt context from previous rate-limited agents.
264    #[must_use]
265    pub fn clear_continuation_prompt(&self) -> Self {
266        Self {
267            agents: Arc::clone(&self.agents),
268            current_agent_index: self.current_agent_index,
269            models_per_agent: Arc::clone(&self.models_per_agent),
270            current_model_index: self.current_model_index,
271            retry_cycle: self.retry_cycle,
272            max_cycles: self.max_cycles,
273            retry_delay_ms: self.retry_delay_ms,
274            backoff_multiplier: self.backoff_multiplier,
275            max_backoff_ms: self.max_backoff_ms,
276            backoff_pending_ms: self.backoff_pending_ms,
277            current_role: self.current_role,
278            rate_limit_continuation_prompt: None,
279            last_session_id: self.last_session_id.clone(),
280        }
281    }
282
283    #[must_use]
284    pub fn reset_for_role(&self, role: AgentRole) -> Self {
285        Self {
286            agents: Arc::clone(&self.agents),
287            current_agent_index: 0,
288            models_per_agent: Arc::clone(&self.models_per_agent),
289            current_model_index: 0,
290            retry_cycle: 0,
291            max_cycles: self.max_cycles,
292            retry_delay_ms: self.retry_delay_ms,
293            backoff_multiplier: self.backoff_multiplier,
294            max_backoff_ms: self.max_backoff_ms,
295            backoff_pending_ms: None,
296            current_role: role,
297            rate_limit_continuation_prompt: None,
298            last_session_id: None,
299        }
300    }
301
302    #[must_use]
303    pub fn reset(&self) -> Self {
304        Self {
305            agents: Arc::clone(&self.agents),
306            current_agent_index: 0,
307            models_per_agent: Arc::clone(&self.models_per_agent),
308            current_model_index: 0,
309            retry_cycle: self.retry_cycle,
310            max_cycles: self.max_cycles,
311            retry_delay_ms: self.retry_delay_ms,
312            backoff_multiplier: self.backoff_multiplier,
313            max_backoff_ms: self.max_backoff_ms,
314            backoff_pending_ms: None,
315            current_role: self.current_role,
316            rate_limit_continuation_prompt: None,
317            last_session_id: None,
318        }
319    }
320
321    /// Store session ID from agent response for potential reuse.
322    #[must_use]
323    pub fn with_session_id(&self, session_id: Option<String>) -> Self {
324        Self {
325            agents: Arc::clone(&self.agents),
326            current_agent_index: self.current_agent_index,
327            models_per_agent: Arc::clone(&self.models_per_agent),
328            current_model_index: self.current_model_index,
329            retry_cycle: self.retry_cycle,
330            max_cycles: self.max_cycles,
331            retry_delay_ms: self.retry_delay_ms,
332            backoff_multiplier: self.backoff_multiplier,
333            max_backoff_ms: self.max_backoff_ms,
334            backoff_pending_ms: self.backoff_pending_ms,
335            current_role: self.current_role,
336            rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
337            last_session_id: session_id,
338        }
339    }
340
341    /// Clear session ID (e.g., when switching agents or starting new work).
342    #[must_use]
343    pub fn clear_session_id(&self) -> Self {
344        Self {
345            agents: Arc::clone(&self.agents),
346            current_agent_index: self.current_agent_index,
347            models_per_agent: Arc::clone(&self.models_per_agent),
348            current_model_index: self.current_model_index,
349            retry_cycle: self.retry_cycle,
350            max_cycles: self.max_cycles,
351            retry_delay_ms: self.retry_delay_ms,
352            backoff_multiplier: self.backoff_multiplier,
353            max_backoff_ms: self.max_backoff_ms,
354            backoff_pending_ms: self.backoff_pending_ms,
355            current_role: self.current_role,
356            rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
357            last_session_id: None,
358        }
359    }
360
361    #[must_use]
362    pub fn start_retry_cycle(&self) -> Self {
363        let new_retry_cycle = self.retry_cycle + 1;
364        let new_backoff_pending_ms = if new_retry_cycle >= self.max_cycles {
365            None
366        } else {
367            // Create temporary state to calculate backoff
368            let temp = Self {
369                agents: Arc::clone(&self.agents),
370                current_agent_index: 0,
371                models_per_agent: Arc::clone(&self.models_per_agent),
372                current_model_index: 0,
373                retry_cycle: new_retry_cycle,
374                max_cycles: self.max_cycles,
375                retry_delay_ms: self.retry_delay_ms,
376                backoff_multiplier: self.backoff_multiplier,
377                max_backoff_ms: self.max_backoff_ms,
378                backoff_pending_ms: None,
379                current_role: self.current_role,
380                rate_limit_continuation_prompt: None,
381                last_session_id: None,
382            };
383            Some(temp.calculate_backoff_delay_ms_for_retry_cycle())
384        };
385
386        Self {
387            agents: Arc::clone(&self.agents),
388            current_agent_index: 0,
389            models_per_agent: Arc::clone(&self.models_per_agent),
390            current_model_index: 0,
391            retry_cycle: new_retry_cycle,
392            max_cycles: self.max_cycles,
393            retry_delay_ms: self.retry_delay_ms,
394            backoff_multiplier: self.backoff_multiplier,
395            max_backoff_ms: self.max_backoff_ms,
396            backoff_pending_ms: new_backoff_pending_ms,
397            current_role: self.current_role,
398            rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
399            last_session_id: self.last_session_id.clone(),
400        }
401    }
402
403    #[must_use]
404    pub fn clear_backoff_pending(&self) -> Self {
405        Self {
406            agents: Arc::clone(&self.agents),
407            current_agent_index: self.current_agent_index,
408            models_per_agent: Arc::clone(&self.models_per_agent),
409            current_model_index: self.current_model_index,
410            retry_cycle: self.retry_cycle,
411            max_cycles: self.max_cycles,
412            retry_delay_ms: self.retry_delay_ms,
413            backoff_multiplier: self.backoff_multiplier,
414            max_backoff_ms: self.max_backoff_ms,
415            backoff_pending_ms: None,
416            current_role: self.current_role,
417            rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
418            last_session_id: self.last_session_id.clone(),
419        }
420    }
421
422    pub(super) fn calculate_backoff_delay_ms_for_retry_cycle(&self) -> u64 {
423        // The first retry cycle should use the base delay.
424        let cycle_index = self.retry_cycle.saturating_sub(1);
425        calculate_backoff_delay_ms(
426            self.retry_delay_ms,
427            self.backoff_multiplier,
428            self.max_backoff_ms,
429            cycle_index,
430        )
431    }
432}

#[cfg(test)]
mod backoff_semantics_tests {
    use super::*;

    /// Developer chain of three agents ("a", "b", "c"), none of which has a model list.
    fn three_agent_developer_chain() -> AgentChainState {
        AgentChainState::initial().with_agents(
            Vec::from(["a", "b", "c"].map(String::from)),
            vec![Vec::new(); 3],
            AgentRole::Developer,
        )
    }

    #[test]
    fn test_switch_to_agent_named_preserves_backoff_when_retry_cycle_hits_max_but_state_is_not_exhausted(
    ) {
        // Parked on the last agent one cycle below the limit. Jumping back to an earlier agent
        // bumps retry_cycle to exactly max_cycles, but the target is not agent 0.
        let state = AgentChainState {
            max_cycles: 2,
            retry_cycle: 1,
            current_agent_index: 2,
            ..three_agent_developer_chain()
        };

        let next = state.switch_to_agent_named("b");

        assert_eq!(next.current_agent_index, 1);
        assert_eq!(next.retry_cycle, 2);
        assert!(
            next.backoff_pending_ms.is_some(),
            "backoff should remain pending unless the state is fully exhausted"
        );
    }
}