Skip to main content

ralph_workflow/agents/
fallback.rs

//! Fallback chain configuration for agent fault tolerance.
//!
//! This module defines the `FallbackConfig` structure that controls how Ralph
//! handles agent failures. It supports:
//! - Agent-level fallback (try different agents)
//! - Provider-level fallback (try different models within the same agent)
//! - Exponential backoff with cycling

9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Agent role (developer, reviewer, or commit).
13///
14/// Each role can have its own chain of fallback agents.
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16pub enum AgentRole {
17    /// Developer agent: implements features based on PROMPT.md.
18    Developer,
19    /// Reviewer agent: reviews code and fixes issues.
20    Reviewer,
21    /// Commit agent: generates commit messages from diffs.
22    Commit,
23    /// Analysis agent: independently verifies progress (diff vs plan).
24    Analysis,
25}
26
27impl std::fmt::Display for AgentRole {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            Self::Developer => write!(f, "developer"),
31            Self::Reviewer => write!(f, "reviewer"),
32            Self::Commit => write!(f, "commit"),
33            Self::Analysis => write!(f, "analysis"),
34        }
35    }
36}
37
38/// Agent chain configuration for preferred agents and fallback switching.
39///
40/// The agent chain defines both:
41/// 1. The **preferred agent** (first in the list) for each role
42/// 2. The **fallback agents** (remaining in the list) to try if the preferred fails
43///
44/// This provides a unified way to configure which agents to use and in what order.
45/// Ralph automatically switches to the next agent in the chain when encountering
46/// errors like rate limits or auth failures.
47///
48/// ## Provider-Level Fallback
49///
50/// In addition to agent-level fallback, you can configure provider-level fallback
51/// within a single agent using the `provider_fallback` field. This is useful for
52/// agents like opencode that support multiple providers/models via the `-m` flag.
53///
54/// Example:
55/// ```toml
56/// [agent_chain]
57/// provider_fallback.opencode = ["-m opencode/glm-4.7-free", "-m opencode/claude-sonnet-4"]
58/// ```
59///
60/// ## Exponential Backoff and Cycling
61///
62/// When all fallbacks are exhausted, Ralph uses exponential backoff and cycles
63/// back to the first agent in the chain:
64/// - Base delay starts at `retry_delay_ms` (default: 1000ms)
65/// - Each cycle multiplies by `backoff_multiplier` (default: 2.0)
66/// - Capped at `max_backoff_ms` (default: 60000ms = 1 minute)
67/// - Maximum cycles controlled by `max_cycles` (default: 3)
68#[derive(Debug, Clone, Deserialize, Serialize)]
69pub struct FallbackConfig {
70    /// Ordered list of agents for developer role (first = preferred, rest = fallbacks).
71    #[serde(default)]
72    pub developer: Vec<String>,
73    /// Ordered list of agents for reviewer role (first = preferred, rest = fallbacks).
74    #[serde(default)]
75    pub reviewer: Vec<String>,
76    /// Ordered list of agents for commit role (first = preferred, rest = fallbacks).
77    #[serde(default)]
78    pub commit: Vec<String>,
79    /// Ordered list of agents for analysis role (first = preferred, rest = fallbacks).
80    ///
81    /// If empty, analysis falls back to the developer chain.
82    #[serde(default)]
83    pub analysis: Vec<String>,
84    /// Provider-level fallback: maps agent name to list of model flags to try.
85    /// Example: `opencode = ["-m opencode/glm-4.7-free", "-m opencode/claude-sonnet-4"]`
86    #[serde(default)]
87    pub provider_fallback: HashMap<String, Vec<String>>,
88    /// Maximum number of retries per agent before moving to next.
89    #[serde(default = "default_max_retries")]
90    pub max_retries: u32,
91    /// Base delay between retries in milliseconds.
92    #[serde(default = "default_retry_delay_ms")]
93    pub retry_delay_ms: u64,
94    /// Multiplier for exponential backoff (default: 2.0).
95    #[serde(default = "default_backoff_multiplier")]
96    pub backoff_multiplier: f64,
97    /// Maximum backoff delay in milliseconds (default: 60000 = 1 minute).
98    #[serde(default = "default_max_backoff_ms")]
99    pub max_backoff_ms: u64,
100    /// Maximum number of cycles through all agents before giving up (default: 3).
101    #[serde(default = "default_max_cycles")]
102    pub max_cycles: u32,
103}
104
/// Serde/`Default` value for `FallbackConfig::max_retries`: 3 retries per agent.
const fn default_max_retries() -> u32 {
    3
}

/// Serde/`Default` value for `FallbackConfig::retry_delay_ms`: one second.
const fn default_retry_delay_ms() -> u64 {
    1_000
}

/// Serde/`Default` value for `FallbackConfig::backoff_multiplier`: double each cycle.
const fn default_backoff_multiplier() -> f64 {
    2.0
}

/// Serde/`Default` value for `FallbackConfig::max_backoff_ms`: one minute.
const fn default_max_backoff_ms() -> u64 {
    60 * 1_000
}

/// Serde/`Default` value for `FallbackConfig::max_cycles`: 3 passes over the chain.
const fn default_max_cycles() -> u32 {
    3
}
124
// IEEE 754 double precision constants for f64_to_u64_via_bits
const IEEE_754_EXP_BIAS: i32 = 1023;
const IEEE_754_EXP_MASK: u64 = 0x7FF;
const IEEE_754_MANTISSA_MASK: u64 = 0x000F_FFFF_FFFF_FFFF;
const IEEE_754_IMPLICIT_ONE: u64 = 1u64 << 52;

/// Convert a non-negative `f64` to `u64` by decoding its IEEE 754 bits.
///
/// This avoids `as` casts (and the clippy cast lints they trigger). For every
/// finite, non-negative input it matches `value as u64`:
/// - the fractional part is truncated toward zero;
/// - integers in `[2^52, 2^64)` are converted exactly;
/// - values `>= 2^64` saturate to `u64::MAX`.
///
/// Negative values, NaN, and both infinities return 0. Callers in this module
/// clamp their input first, so infinities never reach this function.
fn f64_to_u64_via_bits(value: f64) -> u64 {
    if !value.is_finite() || value < 0.0 {
        return 0;
    }

    // Layout: bit 63 = sign (only set here for -0.0, which decodes to 0 below),
    // bits 52..=62 = biased exponent, bits 0..=51 = fraction.
    let bits = value.to_bits();
    let exp_field = (bits >> 52) & IEEE_754_EXP_MASK;
    let mantissa = bits & IEEE_754_MANTISSA_MASK;

    // The exponent field is masked to 11 bits, so it always fits in an i32.
    let exp_biased =
        i32::try_from(exp_field).expect("11-bit IEEE 754 exponent field fits in i32");
    let exp = exp_biased - IEEE_754_EXP_BIAS;

    // Zero, subnormals (exponent field 0 => exp = -1023) and normal values
    // below 1.0 all truncate to 0.
    if exp < 0 {
        return 0;
    }

    // Normal value = (1.fraction) * 2^exp, i.e. the 53-bit integer
    // `full_mantissa` scaled by 2^(exp - 52).
    let full_mantissa = mantissa | IEEE_754_IMPLICIT_ONE;
    match exp {
        // Binary point inside the mantissa: drop the fractional bits.
        0..=51 => full_mantissa >> (52 - exp),
        // Exact integer that still fits in 64 bits (53 + 11 = 64 bits max).
        52..=63 => full_mantissa << (exp - 52),
        // Too large for u64: saturate like `as` does.
        _ => u64::MAX,
    }
}
186
187impl Default for FallbackConfig {
188    fn default() -> Self {
189        Self {
190            developer: Vec::new(),
191            reviewer: Vec::new(),
192            commit: Vec::new(),
193            analysis: Vec::new(),
194            provider_fallback: HashMap::new(),
195            max_retries: default_max_retries(),
196            retry_delay_ms: default_retry_delay_ms(),
197            backoff_multiplier: default_backoff_multiplier(),
198            max_backoff_ms: default_max_backoff_ms(),
199            max_cycles: default_max_cycles(),
200        }
201    }
202}
203
204impl FallbackConfig {
205    /// Calculate exponential backoff delay for a given cycle.
206    ///
207    /// Uses the formula: min(base * multiplier^cycle, `max_backoff`)
208    ///
209    /// Uses integer arithmetic to avoid floating-point casting issues.
210    #[must_use]
211    pub fn calculate_backoff(&self, cycle: u32) -> u64 {
212        // For common multiplier values, use direct integer computation
213        // to avoid f64->u64 conversion and associated clippy lints.
214        let multiplier_hundredths = self.get_multiplier_hundredths();
215        let base_hundredths = self.retry_delay_ms.saturating_mul(100);
216
217        // Calculate: base * (multiplier^cycle) / 100^cycle
218        // Use saturating arithmetic to avoid overflow
219        let mut delay_hundredths = base_hundredths;
220        for _ in 0..cycle {
221            delay_hundredths = delay_hundredths.saturating_mul(multiplier_hundredths);
222            delay_hundredths = delay_hundredths.saturating_div(100);
223        }
224
225        // Convert back to milliseconds
226        delay_hundredths.div_euclid(100).min(self.max_backoff_ms)
227    }
228
229    /// Get the multiplier as hundredths (e.g., 2.0 -> 200, 1.5 -> 150).
230    ///
231    /// Uses a lookup table for common values to avoid f64->u64 casts.
232    /// For uncommon values, uses a safe conversion with validation.
233    fn get_multiplier_hundredths(&self) -> u64 {
234        const EPSILON: f64 = 0.0001;
235
236        // Common multiplier values - use exact integer matches
237        // This avoids the cast for the vast majority of cases
238        let m = self.backoff_multiplier;
239        if (m - 1.0).abs() < EPSILON {
240            return 100;
241        } else if (m - 1.5).abs() < EPSILON {
242            return 150;
243        } else if (m - 2.0).abs() < EPSILON {
244            return 200;
245        } else if (m - 2.5).abs() < EPSILON {
246            return 250;
247        } else if (m - 3.0).abs() < EPSILON {
248            return 300;
249        } else if (m - 4.0).abs() < EPSILON {
250            return 400;
251        } else if (m - 5.0).abs() < EPSILON {
252            return 500;
253        } else if (m - 10.0).abs() < EPSILON {
254            return 1000;
255        }
256
257        // For uncommon values, compute using the original formula
258        // The value is clamped to [0.0, 1000.0] so the result is in [0.0, 100000.0]
259        // We use to_bits() and manual decoding to avoid cast lints
260        let clamped = m.clamp(0.0, 1000.0);
261        let multiplied = clamped * 100.0;
262        let rounded = multiplied.round();
263
264        // Manual f64 to u64 conversion using IEEE 754 bit representation
265        f64_to_u64_via_bits(rounded)
266    }
267
268    /// Get fallback agents for a role.
269    #[must_use]
270    pub fn get_fallbacks(&self, role: AgentRole) -> &[String] {
271        match role {
272            AgentRole::Developer => &self.developer,
273            AgentRole::Reviewer => &self.reviewer,
274            AgentRole::Commit => self.get_effective_commit_fallbacks(),
275            AgentRole::Analysis => self.get_effective_analysis_fallbacks(),
276        }
277    }
278
279    /// Get effective fallback agents for analysis role.
280    ///
281    /// Falls back to developer chain if analysis chain is empty.
282    fn get_effective_analysis_fallbacks(&self) -> &[String] {
283        if self.analysis.is_empty() {
284            &self.developer
285        } else {
286            &self.analysis
287        }
288    }
289
290    /// Get effective fallback agents for commit role.
291    ///
292    /// Falls back to reviewer chain if commit chain is empty.
293    /// This ensures commit message generation can use the same agents
294    /// configured for code review when no dedicated commit agents are specified.
295    fn get_effective_commit_fallbacks(&self) -> &[String] {
296        if self.commit.is_empty() {
297            &self.reviewer
298        } else {
299            &self.commit
300        }
301    }
302
303    /// Check if fallback is configured for a role.
304    #[must_use]
305    pub fn has_fallbacks(&self, role: AgentRole) -> bool {
306        !self.get_fallbacks(role).is_empty()
307    }
308
309    /// Get provider-level fallback model flags for an agent.
310    ///
311    /// Returns the list of model flags to try for the given agent name.
312    /// Empty slice if no provider fallback is configured for this agent.
313    pub fn get_provider_fallbacks(&self, agent_name: &str) -> &[String] {
314        self.provider_fallback
315            .get(agent_name)
316            .map_or(&[], std::vec::Vec::as_slice)
317    }
318
319    /// Check if provider-level fallback is configured for an agent.
320    #[must_use]
321    pub fn has_provider_fallbacks(&self, agent_name: &str) -> bool {
322        self.provider_fallback
323            .get(agent_name)
324            .is_some_and(|v| !v.is_empty())
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agent_role_display() {
        assert_eq!(format!("{}", AgentRole::Developer), "developer");
        assert_eq!(format!("{}", AgentRole::Reviewer), "reviewer");
        assert_eq!(format!("{}", AgentRole::Commit), "commit");
        assert_eq!(format!("{}", AgentRole::Analysis), "analysis");
    }

    #[test]
    fn test_fallback_config_defaults() {
        let config = FallbackConfig::default();
        assert!(config.developer.is_empty());
        assert!(config.reviewer.is_empty());
        assert!(config.commit.is_empty());
        assert!(config.analysis.is_empty());
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.retry_delay_ms, 1000);
        // Use approximate comparison for floating point
        assert!((config.backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert_eq!(config.max_backoff_ms, 60000);
        assert_eq!(config.max_cycles, 3);
    }

    #[test]
    fn test_fallback_config_calculate_backoff() {
        let config = FallbackConfig {
            retry_delay_ms: 1000,
            backoff_multiplier: 2.0,
            max_backoff_ms: 60000,
            ..Default::default()
        };

        assert_eq!(config.calculate_backoff(0), 1000);
        assert_eq!(config.calculate_backoff(1), 2000);
        assert_eq!(config.calculate_backoff(2), 4000);
        assert_eq!(config.calculate_backoff(3), 8000);

        // Should cap at max
        assert_eq!(config.calculate_backoff(10), 60000);
    }

    #[test]
    fn test_calculate_backoff_fractional_multiplier() {
        let config = FallbackConfig {
            retry_delay_ms: 1000,
            backoff_multiplier: 1.5,
            max_backoff_ms: 60000,
            ..Default::default()
        };
        assert_eq!(config.calculate_backoff(1), 1500);
        assert_eq!(config.calculate_backoff(2), 2250);

        // A multiplier with two decimal places is applied exactly.
        let config = FallbackConfig {
            backoff_multiplier: 1.25,
            ..config
        };
        assert_eq!(config.calculate_backoff(1), 1250);
    }

    #[test]
    fn test_calculate_backoff_non_growing_multipliers() {
        let constant = FallbackConfig {
            backoff_multiplier: 1.0,
            ..Default::default()
        };
        assert_eq!(constant.calculate_backoff(5), 1000);

        let shrinking = FallbackConfig {
            backoff_multiplier: 0.5,
            ..Default::default()
        };
        assert_eq!(shrinking.calculate_backoff(1), 500);
        assert_eq!(shrinking.calculate_backoff(2), 250);
    }

    #[test]
    fn test_calculate_backoff_caps_base_delay() {
        let config = FallbackConfig {
            retry_delay_ms: 120_000,
            max_backoff_ms: 60000,
            ..Default::default()
        };
        assert_eq!(config.calculate_backoff(0), 60000);
    }

    #[test]
    fn test_fallback_config_get_fallbacks() {
        let config = FallbackConfig {
            developer: vec!["claude".to_string(), "codex".to_string()],
            reviewer: vec!["codex".to_string()],
            ..Default::default()
        };

        assert_eq!(
            config.get_fallbacks(AgentRole::Developer),
            &["claude", "codex"]
        );
        assert_eq!(config.get_fallbacks(AgentRole::Reviewer), &["codex"]);

        // Analysis defaults to developer chain when not configured.
        assert_eq!(
            config.get_fallbacks(AgentRole::Analysis),
            &["claude", "codex"]
        );
    }

    #[test]
    fn test_analysis_uses_own_chain_when_configured() {
        let config = FallbackConfig {
            developer: vec!["dev-agent".to_string()],
            analysis: vec!["analysis-agent".to_string()],
            ..Default::default()
        };

        assert_eq!(
            config.get_fallbacks(AgentRole::Analysis),
            &["analysis-agent"]
        );
    }

    #[test]
    fn test_fallback_config_has_fallbacks() {
        let config = FallbackConfig {
            developer: vec!["claude".to_string()],
            reviewer: vec![],
            ..Default::default()
        };

        assert!(config.has_fallbacks(AgentRole::Developer));
        assert!(config.has_fallbacks(AgentRole::Analysis));
        assert!(!config.has_fallbacks(AgentRole::Reviewer));
    }

    #[test]
    fn test_fallback_config_defaults_provider_fallback() {
        let config = FallbackConfig::default();
        assert!(config.get_provider_fallbacks("opencode").is_empty());
        assert!(!config.has_provider_fallbacks("opencode"));
    }

    #[test]
    fn test_provider_fallback_config() {
        let mut provider_fallback = HashMap::new();
        provider_fallback.insert(
            "opencode".to_string(),
            vec![
                "-m opencode/glm-4.7-free".to_string(),
                "-m opencode/claude-sonnet-4".to_string(),
            ],
        );

        let config = FallbackConfig {
            provider_fallback,
            ..Default::default()
        };

        let fallbacks = config.get_provider_fallbacks("opencode");
        assert_eq!(fallbacks.len(), 2);
        assert_eq!(fallbacks[0], "-m opencode/glm-4.7-free");
        assert_eq!(fallbacks[1], "-m opencode/claude-sonnet-4");

        assert!(config.has_provider_fallbacks("opencode"));
        assert!(!config.has_provider_fallbacks("claude"));
    }

    #[test]
    fn test_empty_provider_fallback_list_is_not_configured() {
        let mut provider_fallback = HashMap::new();
        provider_fallback.insert("opencode".to_string(), Vec::new());

        let config = FallbackConfig {
            provider_fallback,
            ..Default::default()
        };

        assert!(config.get_provider_fallbacks("opencode").is_empty());
        assert!(!config.has_provider_fallbacks("opencode"));
    }

    #[test]
    fn test_fallback_config_from_toml() {
        let toml_str = r#"
            developer = ["claude", "codex"]
            reviewer = ["codex", "claude"]
            max_retries = 5
            retry_delay_ms = 2000

            [provider_fallback]
            opencode = ["-m opencode/glm-4.7-free", "-m zai/glm-4.7"]
        "#;

        let config: FallbackConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.developer, vec!["claude", "codex"]);
        assert_eq!(config.reviewer, vec!["codex", "claude"]);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.retry_delay_ms, 2000);
        assert_eq!(config.get_provider_fallbacks("opencode").len(), 2);
    }

    #[test]
    fn test_fallback_config_from_empty_toml_uses_defaults() {
        let config: FallbackConfig = toml::from_str("").unwrap();
        assert!(config.developer.is_empty());
        assert!(config.reviewer.is_empty());
        assert!(config.commit.is_empty());
        assert!(config.analysis.is_empty());
        assert!(config.provider_fallback.is_empty());
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.retry_delay_ms, 1000);
        assert!((config.backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert_eq!(config.max_backoff_ms, 60000);
        assert_eq!(config.max_cycles, 3);
    }

    #[test]
    fn test_commit_uses_reviewer_chain_when_empty() {
        // When commit chain is empty, it should fall back to reviewer chain
        let config = FallbackConfig {
            commit: vec![],
            reviewer: vec!["agent1".to_string(), "agent2".to_string()],
            ..Default::default()
        };

        // Commit role should use reviewer chain when commit chain is empty
        assert_eq!(
            config.get_fallbacks(AgentRole::Commit),
            &["agent1", "agent2"]
        );
        assert!(config.has_fallbacks(AgentRole::Commit));
    }

    #[test]
    fn test_commit_uses_own_chain_when_configured() {
        // When commit chain is configured, it should use its own chain
        let config = FallbackConfig {
            commit: vec!["commit-agent".to_string()],
            reviewer: vec!["reviewer-agent".to_string()],
            ..Default::default()
        };

        // Commit role should use its own chain
        assert_eq!(config.get_fallbacks(AgentRole::Commit), &["commit-agent"]);
    }
}