Skip to main content

ralph_workflow/reducer/state/agent_chain/
transitions.rs

1// State transition methods for AgentChainState.
2//
3// These methods implement the fallback chain progression: advancing models,
4// switching agents, and starting retry cycles with backoff.
5
6use std::sync::Arc;
7
8use super::backoff::calculate_backoff_delay_ms;
9use super::{AgentChainState, AgentDrain, AgentRole, DrainMode, RateLimitContinuationPrompt};
10
11impl AgentChainState {
12    #[must_use]
13    pub fn advance_to_next_model(&self) -> Self {
14        let start_agent_index = self.current_agent_index;
15
16        // When models are configured, we try each model for the current agent once.
17        // If the models list is exhausted, advance to the next agent/retry cycle
18        // instead of looping models indefinitely.
19        let next = match self.models_per_agent.get(self.current_agent_index) {
20            Some(models) if !models.is_empty() => {
21                if self.current_model_index + 1 < models.len() {
22                    // Simple model advance - only increment model index
23                    Self {
24                        agents: Arc::clone(&self.agents),
25                        current_agent_index: self.current_agent_index,
26                        models_per_agent: Arc::clone(&self.models_per_agent),
27                        current_model_index: self.current_model_index + 1,
28                        retry_cycle: self.retry_cycle,
29                        max_cycles: self.max_cycles,
30                        retry_delay_ms: self.retry_delay_ms,
31                        backoff_multiplier: self.backoff_multiplier,
32                        max_backoff_ms: self.max_backoff_ms,
33                        backoff_pending_ms: self.backoff_pending_ms,
34                        current_role: self.current_role,
35                        current_drain: self.current_drain,
36                        current_mode: self.current_mode,
37                        rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
38                        last_session_id: self.last_session_id.clone(),
39                        last_failure_reason: self.last_failure_reason.clone(),
40                    }
41                } else {
42                    self.switch_to_next_agent()
43                }
44            }
45            _ => self.switch_to_next_agent(),
46        };
47
48        // Clear session ID when switching to a different agent
49        if next.current_agent_index != start_agent_index {
50            Self {
51                last_session_id: None,
52                ..next
53            }
54        } else {
55            next
56        }
57    }
58
59    #[must_use]
60    pub fn switch_to_next_agent(&self) -> Self {
61        if self.current_agent_index + 1 < self.agents.len() {
62            // Advance to next agent
63            Self {
64                agents: Arc::clone(&self.agents),
65                current_agent_index: self.current_agent_index + 1,
66                models_per_agent: Arc::clone(&self.models_per_agent),
67                current_model_index: 0,
68                retry_cycle: self.retry_cycle,
69                max_cycles: self.max_cycles,
70                retry_delay_ms: self.retry_delay_ms,
71                backoff_multiplier: self.backoff_multiplier,
72                max_backoff_ms: self.max_backoff_ms,
73                backoff_pending_ms: None,
74                current_role: self.current_role,
75                current_drain: self.current_drain,
76                current_mode: self.current_mode,
77                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
78                last_session_id: self.last_session_id.clone(),
79                last_failure_reason: self.last_failure_reason.clone(),
80            }
81        } else {
82            // Wrap around to first agent and increment retry cycle
83            let new_retry_cycle = self.retry_cycle + 1;
84            let new_backoff_pending_ms = if new_retry_cycle >= self.max_cycles {
85                None
86            } else {
87                // Create temporary state to calculate backoff
88                let temp = Self {
89                    agents: Arc::clone(&self.agents),
90                    current_agent_index: 0,
91                    models_per_agent: Arc::clone(&self.models_per_agent),
92                    current_model_index: 0,
93                    retry_cycle: new_retry_cycle,
94                    max_cycles: self.max_cycles,
95                    retry_delay_ms: self.retry_delay_ms,
96                    backoff_multiplier: self.backoff_multiplier,
97                    max_backoff_ms: self.max_backoff_ms,
98                    backoff_pending_ms: None,
99                    current_role: self.current_role,
100                    current_drain: self.current_drain,
101                    current_mode: self.current_mode,
102                    rate_limit_continuation_prompt: None,
103                    last_session_id: None,
104                    last_failure_reason: None,
105                };
106                Some(temp.calculate_backoff_delay_ms_for_retry_cycle())
107            };
108
109            Self {
110                agents: Arc::clone(&self.agents),
111                current_agent_index: 0,
112                models_per_agent: Arc::clone(&self.models_per_agent),
113                current_model_index: 0,
114                retry_cycle: new_retry_cycle,
115                max_cycles: self.max_cycles,
116                retry_delay_ms: self.retry_delay_ms,
117                backoff_multiplier: self.backoff_multiplier,
118                max_backoff_ms: self.max_backoff_ms,
119                backoff_pending_ms: new_backoff_pending_ms,
120                current_role: self.current_role,
121                current_drain: self.current_drain,
122                current_mode: self.current_mode,
123                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
124                last_session_id: self.last_session_id.clone(),
125                last_failure_reason: self.last_failure_reason.clone(),
126            }
127        }
128    }
129
130    /// Switch to a specific agent by name.
131    ///
132    /// If `to_agent` is unknown, falls back to `switch_to_next_agent()` to keep the
133    /// reducer deterministic.
134    #[must_use]
135    pub fn switch_to_agent_named(&self, to_agent: &str) -> Self {
136        let Some(target_index) = self.agents.iter().position(|a| a == to_agent) else {
137            return self.switch_to_next_agent();
138        };
139
140        if target_index == self.current_agent_index {
141            // Same agent - just reset model index
142            return Self {
143                agents: Arc::clone(&self.agents),
144                current_agent_index: self.current_agent_index,
145                models_per_agent: Arc::clone(&self.models_per_agent),
146                current_model_index: 0,
147                retry_cycle: self.retry_cycle,
148                max_cycles: self.max_cycles,
149                retry_delay_ms: self.retry_delay_ms,
150                backoff_multiplier: self.backoff_multiplier,
151                max_backoff_ms: self.max_backoff_ms,
152                backoff_pending_ms: None,
153                current_role: self.current_role,
154                current_drain: self.current_drain,
155                current_mode: self.current_mode,
156                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
157                last_session_id: self.last_session_id.clone(),
158                last_failure_reason: self.last_failure_reason.clone(),
159            };
160        }
161
162        if target_index <= self.current_agent_index {
163            // Treat switching to an earlier agent as starting a new retry cycle.
164            let new_retry_cycle = self.retry_cycle + 1;
165            let new_backoff_pending_ms = if new_retry_cycle >= self.max_cycles && target_index == 0
166            {
167                None
168            } else {
169                // Create temporary state to calculate backoff
170                let temp = Self {
171                    agents: Arc::clone(&self.agents),
172                    current_agent_index: target_index,
173                    models_per_agent: Arc::clone(&self.models_per_agent),
174                    current_model_index: 0,
175                    retry_cycle: new_retry_cycle,
176                    max_cycles: self.max_cycles,
177                    retry_delay_ms: self.retry_delay_ms,
178                    backoff_multiplier: self.backoff_multiplier,
179                    max_backoff_ms: self.max_backoff_ms,
180                    backoff_pending_ms: None,
181                    current_role: self.current_role,
182                    current_drain: self.current_drain,
183                    current_mode: self.current_mode,
184                    rate_limit_continuation_prompt: None,
185                    last_session_id: None,
186                    last_failure_reason: None,
187                };
188                Some(temp.calculate_backoff_delay_ms_for_retry_cycle())
189            };
190
191            Self {
192                agents: Arc::clone(&self.agents),
193                current_agent_index: target_index,
194                models_per_agent: Arc::clone(&self.models_per_agent),
195                current_model_index: 0,
196                retry_cycle: new_retry_cycle,
197                max_cycles: self.max_cycles,
198                retry_delay_ms: self.retry_delay_ms,
199                backoff_multiplier: self.backoff_multiplier,
200                max_backoff_ms: self.max_backoff_ms,
201                backoff_pending_ms: new_backoff_pending_ms,
202                current_role: self.current_role,
203                current_drain: self.current_drain,
204                current_mode: self.current_mode,
205                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
206                last_session_id: self.last_session_id.clone(),
207                last_failure_reason: self.last_failure_reason.clone(),
208            }
209        } else {
210            // Advancing to later agent
211            Self {
212                agents: Arc::clone(&self.agents),
213                current_agent_index: target_index,
214                models_per_agent: Arc::clone(&self.models_per_agent),
215                current_model_index: 0,
216                retry_cycle: self.retry_cycle,
217                max_cycles: self.max_cycles,
218                retry_delay_ms: self.retry_delay_ms,
219                backoff_multiplier: self.backoff_multiplier,
220                max_backoff_ms: self.max_backoff_ms,
221                backoff_pending_ms: None,
222                current_role: self.current_role,
223                current_drain: self.current_drain,
224                current_mode: self.current_mode,
225                rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
226                last_session_id: self.last_session_id.clone(),
227                last_failure_reason: self.last_failure_reason.clone(),
228            }
229        }
230    }
231
232    /// Switch to next agent after rate limit, preserving prompt for continuation.
233    ///
234    /// This is used when an agent hits a 429 rate limit error. Instead of
235    /// retrying with the same agent (which would likely hit rate limits again),
236    /// we switch to the next agent and preserve the prompt so the new agent
237    /// can continue the same work.
238    #[must_use]
239    pub fn switch_to_next_agent_with_prompt(&self, prompt: Option<String>) -> Self {
240        let base = self.switch_to_next_agent();
241        // Back-compat: older callers didn't track role. Preserve prompt only.
242        Self {
243            agents: base.agents,
244            current_agent_index: base.current_agent_index,
245            models_per_agent: base.models_per_agent,
246            current_model_index: base.current_model_index,
247            retry_cycle: base.retry_cycle,
248            max_cycles: base.max_cycles,
249            retry_delay_ms: base.retry_delay_ms,
250            backoff_multiplier: base.backoff_multiplier,
251            max_backoff_ms: base.max_backoff_ms,
252            backoff_pending_ms: base.backoff_pending_ms,
253            current_role: base.current_role,
254            current_drain: base.current_drain,
255            current_mode: base.current_mode,
256            rate_limit_continuation_prompt: prompt.map(|p| RateLimitContinuationPrompt {
257                drain: base.current_drain,
258                role: base.current_role,
259                prompt: p,
260            }),
261            last_session_id: base.last_session_id,
262            last_failure_reason: base.last_failure_reason.clone(),
263        }
264    }
265
266    /// Switch to next agent after rate limit, preserving prompt for continuation (role-scoped).
267    #[must_use]
268    pub fn switch_to_next_agent_with_prompt_for_role(
269        &self,
270        role: AgentRole,
271        prompt: Option<String>,
272    ) -> Self {
273        let base = self.switch_to_next_agent();
274        Self {
275            agents: base.agents,
276            current_agent_index: base.current_agent_index,
277            models_per_agent: base.models_per_agent,
278            current_model_index: base.current_model_index,
279            retry_cycle: base.retry_cycle,
280            max_cycles: base.max_cycles,
281            retry_delay_ms: base.retry_delay_ms,
282            backoff_multiplier: base.backoff_multiplier,
283            max_backoff_ms: base.max_backoff_ms,
284            backoff_pending_ms: base.backoff_pending_ms,
285            current_role: base.current_role,
286            current_drain: base.current_drain,
287            current_mode: base.current_mode,
288            rate_limit_continuation_prompt: prompt.map(|p| RateLimitContinuationPrompt {
289                drain: base.current_drain,
290                role,
291                prompt: p,
292            }),
293            last_session_id: base.last_session_id,
294            last_failure_reason: base.last_failure_reason.clone(),
295        }
296    }
297
298    /// Clear continuation prompt after successful execution.
299    ///
300    /// Called when an agent successfully completes its task, clearing any
301    /// saved prompt context from previous rate-limited agents.
302    #[must_use]
303    pub fn clear_continuation_prompt(&self) -> Self {
304        Self {
305            agents: Arc::clone(&self.agents),
306            current_agent_index: self.current_agent_index,
307            models_per_agent: Arc::clone(&self.models_per_agent),
308            current_model_index: self.current_model_index,
309            retry_cycle: self.retry_cycle,
310            max_cycles: self.max_cycles,
311            retry_delay_ms: self.retry_delay_ms,
312            backoff_multiplier: self.backoff_multiplier,
313            max_backoff_ms: self.max_backoff_ms,
314            backoff_pending_ms: self.backoff_pending_ms,
315            current_role: self.current_role,
316            current_drain: self.current_drain,
317            current_mode: self.current_mode,
318            rate_limit_continuation_prompt: None,
319            last_session_id: self.last_session_id.clone(),
320            last_failure_reason: None,
321        }
322    }
323
324    #[must_use]
325    pub fn reset_for_drain(&self, drain: AgentDrain) -> Self {
326        Self {
327            agents: Arc::clone(&self.agents),
328            current_agent_index: 0,
329            models_per_agent: Arc::clone(&self.models_per_agent),
330            current_model_index: 0,
331            retry_cycle: 0,
332            max_cycles: self.max_cycles,
333            retry_delay_ms: self.retry_delay_ms,
334            backoff_multiplier: self.backoff_multiplier,
335            max_backoff_ms: self.max_backoff_ms,
336            backoff_pending_ms: None,
337            current_role: drain.role(),
338            current_drain: drain,
339            current_mode: DrainMode::Normal,
340            rate_limit_continuation_prompt: None,
341            last_session_id: None,
342            last_failure_reason: None,
343        }
344    }
345
346    #[must_use]
347    pub fn reset_for_role(&self, role: AgentRole) -> Self {
348        self.reset_for_drain(match role {
349            AgentRole::Developer => AgentDrain::Development,
350            AgentRole::Reviewer => AgentDrain::Review,
351            AgentRole::Commit => AgentDrain::Commit,
352            AgentRole::Analysis => AgentDrain::Analysis,
353        })
354    }
355
356    #[must_use]
357    pub fn reset(&self) -> Self {
358        Self {
359            agents: Arc::clone(&self.agents),
360            current_agent_index: 0,
361            models_per_agent: Arc::clone(&self.models_per_agent),
362            current_model_index: 0,
363            retry_cycle: self.retry_cycle,
364            max_cycles: self.max_cycles,
365            retry_delay_ms: self.retry_delay_ms,
366            backoff_multiplier: self.backoff_multiplier,
367            max_backoff_ms: self.max_backoff_ms,
368            backoff_pending_ms: None,
369            current_role: self.current_role,
370            current_drain: self.current_drain,
371            current_mode: DrainMode::Normal,
372            rate_limit_continuation_prompt: None,
373            last_session_id: None,
374            last_failure_reason: None,
375        }
376    }
377
378    /// Store session ID from agent response for potential reuse.
379    #[must_use]
380    pub fn with_session_id(&self, session_id: Option<String>) -> Self {
381        Self {
382            agents: Arc::clone(&self.agents),
383            current_agent_index: self.current_agent_index,
384            models_per_agent: Arc::clone(&self.models_per_agent),
385            current_model_index: self.current_model_index,
386            retry_cycle: self.retry_cycle,
387            max_cycles: self.max_cycles,
388            retry_delay_ms: self.retry_delay_ms,
389            backoff_multiplier: self.backoff_multiplier,
390            max_backoff_ms: self.max_backoff_ms,
391            backoff_pending_ms: self.backoff_pending_ms,
392            current_role: self.current_role,
393            current_drain: self.current_drain,
394            current_mode: self.current_mode,
395            rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
396            last_session_id: session_id,
397            last_failure_reason: self.last_failure_reason.clone(),
398        }
399    }
400
401    /// Store last failure reason for CLI output context.
402    #[must_use]
403    pub fn with_failure_reason(&self, reason: Option<String>) -> Self {
404        Self {
405            agents: Arc::clone(&self.agents),
406            current_agent_index: self.current_agent_index,
407            models_per_agent: Arc::clone(&self.models_per_agent),
408            current_model_index: self.current_model_index,
409            retry_cycle: self.retry_cycle,
410            max_cycles: self.max_cycles,
411            retry_delay_ms: self.retry_delay_ms,
412            backoff_multiplier: self.backoff_multiplier,
413            max_backoff_ms: self.max_backoff_ms,
414            backoff_pending_ms: self.backoff_pending_ms,
415            current_role: self.current_role,
416            current_drain: self.current_drain,
417            current_mode: self.current_mode,
418            rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
419            last_session_id: self.last_session_id.clone(),
420            last_failure_reason: reason,
421        }
422    }
423
424    /// Clear session ID (e.g., when switching agents or starting new work).
425    #[must_use]
426    pub fn clear_session_id(&self) -> Self {
427        Self {
428            agents: Arc::clone(&self.agents),
429            current_agent_index: self.current_agent_index,
430            models_per_agent: Arc::clone(&self.models_per_agent),
431            current_model_index: self.current_model_index,
432            retry_cycle: self.retry_cycle,
433            max_cycles: self.max_cycles,
434            retry_delay_ms: self.retry_delay_ms,
435            backoff_multiplier: self.backoff_multiplier,
436            max_backoff_ms: self.max_backoff_ms,
437            backoff_pending_ms: self.backoff_pending_ms,
438            current_role: self.current_role,
439            current_drain: self.current_drain,
440            current_mode: self.current_mode,
441            rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
442            last_session_id: None,
443            last_failure_reason: self.last_failure_reason.clone(),
444        }
445    }
446
447    #[must_use]
448    pub fn start_retry_cycle(&self) -> Self {
449        let new_retry_cycle = self.retry_cycle + 1;
450        let new_backoff_pending_ms = if new_retry_cycle >= self.max_cycles {
451            None
452        } else {
453            // Create temporary state to calculate backoff
454            let temp = Self {
455                agents: Arc::clone(&self.agents),
456                current_agent_index: 0,
457                models_per_agent: Arc::clone(&self.models_per_agent),
458                current_model_index: 0,
459                retry_cycle: new_retry_cycle,
460                max_cycles: self.max_cycles,
461                retry_delay_ms: self.retry_delay_ms,
462                backoff_multiplier: self.backoff_multiplier,
463                max_backoff_ms: self.max_backoff_ms,
464                backoff_pending_ms: None,
465                current_role: self.current_role,
466                current_drain: self.current_drain,
467                current_mode: self.current_mode,
468                rate_limit_continuation_prompt: None,
469                last_session_id: None,
470                last_failure_reason: None,
471            };
472            Some(temp.calculate_backoff_delay_ms_for_retry_cycle())
473        };
474
475        Self {
476            agents: Arc::clone(&self.agents),
477            current_agent_index: 0,
478            models_per_agent: Arc::clone(&self.models_per_agent),
479            current_model_index: 0,
480            retry_cycle: new_retry_cycle,
481            max_cycles: self.max_cycles,
482            retry_delay_ms: self.retry_delay_ms,
483            backoff_multiplier: self.backoff_multiplier,
484            max_backoff_ms: self.max_backoff_ms,
485            backoff_pending_ms: new_backoff_pending_ms,
486            current_role: self.current_role,
487            current_drain: self.current_drain,
488            current_mode: self.current_mode,
489            rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
490            last_session_id: self.last_session_id.clone(),
491            last_failure_reason: self.last_failure_reason.clone(),
492        }
493    }
494
495    #[must_use]
496    pub fn clear_backoff_pending(&self) -> Self {
497        Self {
498            agents: Arc::clone(&self.agents),
499            current_agent_index: self.current_agent_index,
500            models_per_agent: Arc::clone(&self.models_per_agent),
501            current_model_index: self.current_model_index,
502            retry_cycle: self.retry_cycle,
503            max_cycles: self.max_cycles,
504            retry_delay_ms: self.retry_delay_ms,
505            backoff_multiplier: self.backoff_multiplier,
506            max_backoff_ms: self.max_backoff_ms,
507            backoff_pending_ms: None,
508            current_role: self.current_role,
509            current_drain: self.current_drain,
510            current_mode: self.current_mode,
511            rate_limit_continuation_prompt: self.rate_limit_continuation_prompt.clone(),
512            last_session_id: self.last_session_id.clone(),
513            last_failure_reason: self.last_failure_reason.clone(),
514        }
515    }
516
517    pub(super) fn calculate_backoff_delay_ms_for_retry_cycle(&self) -> u64 {
518        // The first retry cycle should use the base delay.
519        let cycle_index = self.retry_cycle.saturating_sub(1);
520        calculate_backoff_delay_ms(
521            self.retry_delay_ms,
522            self.backoff_multiplier,
523            self.max_backoff_ms,
524            cycle_index,
525        )
526    }
527}
528
#[cfg(test)]
mod backoff_semantics_tests {
    use super::*;

    /// Jumping back to a non-first agent on the cycle that reaches `max_cycles`
    /// must still schedule a backoff delay.
    #[test]
    fn test_switch_to_agent_named_preserves_backoff_when_retry_cycle_hits_max_but_state_is_not_exhausted(
    ) {
        // Three agents without configured models; currently on the last agent,
        // one retry cycle short of the limit.
        let agent_names: Vec<String> = ["a", "b", "c"]
            .iter()
            .map(|name| (*name).to_string())
            .collect();
        let models_per_agent = vec![Vec::new(), Vec::new(), Vec::new()];

        let on_last_agent = AgentChainState::initial()
            .with_agents(agent_names, models_per_agent, AgentRole::Developer)
            .with_max_cycles(2)
            .with_retry_cycle(1)
            .with_current_agent_index(2);

        // "b" sits at index 1, i.e. earlier than the current agent.
        let switched = on_last_agent.switch_to_agent_named("b");

        assert_eq!(switched.current_agent_index, 1);
        assert_eq!(switched.retry_cycle, 2);
        assert!(
            switched.backoff_pending_ms.is_some(),
            "backoff should remain pending unless the state is fully exhausted"
        );
    }
}