Skip to main content

ralph_workflow/agents/
fallback.rs

1//! Fallback chain configuration for agent fault tolerance.
2//!
3//! This module defines the `FallbackConfig` structure that controls how Ralph
4//! handles agent failures. It supports:
5//! - Agent-level fallback (try different agents)
6//! - Provider-level fallback (try different models within same agent)
7//! - Exponential backoff with cycling
8
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Agent role (developer, reviewer, or commit).
13///
14/// Each role can have its own chain of fallback agents.
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16pub enum AgentRole {
17    /// Developer agent: implements features based on PROMPT.md.
18    Developer,
19    /// Reviewer agent: reviews code and fixes issues.
20    Reviewer,
21    /// Commit agent: generates commit messages from diffs.
22    Commit,
23    /// Analysis agent: independently verifies progress (diff vs plan).
24    Analysis,
25}
26
27impl std::fmt::Display for AgentRole {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            Self::Developer => write!(f, "developer"),
31            Self::Reviewer => write!(f, "reviewer"),
32            Self::Commit => write!(f, "commit"),
33            Self::Analysis => write!(f, "analysis"),
34        }
35    }
36}
37
38/// Agent chain configuration for preferred agents and fallback switching.
39///
40/// The agent chain defines both:
41/// 1. The **preferred agent** (first in the list) for each role
42/// 2. The **fallback agents** (remaining in the list) to try if the preferred fails
43///
44/// This provides a unified way to configure which agents to use and in what order.
45/// Ralph automatically switches to the next agent in the chain when encountering
46/// errors like rate limits or auth failures.
47///
48/// ## Provider-Level Fallback
49///
50/// In addition to agent-level fallback, you can configure provider-level fallback
51/// within a single agent using the `provider_fallback` field. This is useful for
52/// agents like opencode that support multiple providers/models via the `-m` flag.
53///
54/// Example:
55/// ```toml
56/// [agent_chain]
57/// provider_fallback.opencode = ["-m opencode/glm-4.7-free", "-m opencode/claude-sonnet-4"]
58/// ```
59///
60/// ## Exponential Backoff and Cycling
61///
62/// When all fallbacks are exhausted, Ralph uses exponential backoff and cycles
63/// back to the first agent in the chain:
64/// - Base delay starts at `retry_delay_ms` (default: 1000ms)
65/// - Each cycle multiplies by `backoff_multiplier` (default: 2.0)
66/// - Capped at `max_backoff_ms` (default: 60000ms = 1 minute)
67/// - Maximum cycles controlled by `max_cycles` (default: 3)
68#[derive(Debug, Clone, Deserialize)]
69pub struct FallbackConfig {
70    /// Ordered list of agents for developer role (first = preferred, rest = fallbacks).
71    #[serde(default)]
72    pub developer: Vec<String>,
73    /// Ordered list of agents for reviewer role (first = preferred, rest = fallbacks).
74    #[serde(default)]
75    pub reviewer: Vec<String>,
76    /// Ordered list of agents for commit role (first = preferred, rest = fallbacks).
77    #[serde(default)]
78    pub commit: Vec<String>,
79    /// Ordered list of agents for analysis role (first = preferred, rest = fallbacks).
80    ///
81    /// If empty, analysis falls back to the developer chain.
82    #[serde(default)]
83    pub analysis: Vec<String>,
84    /// Provider-level fallback: maps agent name to list of model flags to try.
85    /// Example: `opencode = ["-m opencode/glm-4.7-free", "-m opencode/claude-sonnet-4"]`
86    #[serde(default)]
87    pub provider_fallback: HashMap<String, Vec<String>>,
88    /// Maximum number of retries per agent before moving to next.
89    #[serde(default = "default_max_retries")]
90    pub max_retries: u32,
91    /// Base delay between retries in milliseconds.
92    #[serde(default = "default_retry_delay_ms")]
93    pub retry_delay_ms: u64,
94    /// Multiplier for exponential backoff (default: 2.0).
95    #[serde(default = "default_backoff_multiplier")]
96    pub backoff_multiplier: f64,
97    /// Maximum backoff delay in milliseconds (default: 60000 = 1 minute).
98    #[serde(default = "default_max_backoff_ms")]
99    pub max_backoff_ms: u64,
100    /// Maximum number of cycles through all agents before giving up (default: 3).
101    #[serde(default = "default_max_cycles")]
102    pub max_cycles: u32,
103}
104
/// Default for `max_retries`: three retries per agent.
const fn default_max_retries() -> u32 {
    const RETRIES_PER_AGENT: u32 = 3;
    RETRIES_PER_AGENT
}
108
/// Default for `retry_delay_ms`: a one-second base delay.
const fn default_retry_delay_ms() -> u64 {
    const BASE_DELAY_MS: u64 = 1_000;
    BASE_DELAY_MS
}
112
/// Default for `backoff_multiplier`: the delay doubles on every cycle.
const fn default_backoff_multiplier() -> f64 {
    const DOUBLING: f64 = 2.0;
    DOUBLING
}
116
/// Default for `max_backoff_ms`: one minute.
const fn default_max_backoff_ms() -> u64 {
    const ONE_MINUTE_MS: u64 = 60 * 1_000;
    ONE_MINUTE_MS
}
120
/// Default for `max_cycles`: three passes through the agent chain.
const fn default_max_cycles() -> u32 {
    const CYCLES: u32 = 3;
    CYCLES
}
124
// IEEE 754 double precision layout used by `f64_to_u64_via_bits`:
// bit 63 = sign, bits 52-62 = biased exponent, bits 0-51 = stored mantissa.
const IEEE_754_EXP_BIAS: u64 = 1023;
const IEEE_754_EXP_MASK: u64 = 0x7FF;
const IEEE_754_MANTISSA_MASK: u64 = 0x000F_FFFF_FFFF_FFFF;
const IEEE_754_IMPLICIT_ONE: u64 = 1u64 << 52;
const IEEE_754_MANTISSA_BITS: u64 = 52;

/// Convert an `f64` to `u64` by decoding its IEEE 754 bits, without any `as` cast.
///
/// The fractional part is discarded (truncation toward zero), so callers that
/// want rounding must round before calling.
///
/// Returns:
/// - `0` for NaN, infinities, negative values and anything below `1.0`
/// - the exact integer part for every finite value in `[1.0, 2^64)`
/// - `u64::MAX` for finite values of `2^64` and above (saturation)
fn f64_to_u64_via_bits(value: f64) -> u64 {
    // NaN and +/-infinity are not finite; negatives, zeros, subnormals and
    // proper fractions are all below 1.0 and truncate to 0.
    if !value.is_finite() || value < 1.0 {
        return 0;
    }

    let bits = value.to_bits();

    // `value >= 1.0` guarantees a clear sign bit and a biased exponent of at
    // least the bias, so this subtraction cannot underflow.
    let exponent = ((bits >> IEEE_754_MANTISSA_BITS) & IEEE_754_EXP_MASK) - IEEE_754_EXP_BIAS;

    // 53-bit significand `1.mantissa`, read as the integer `significand * 2^-52`.
    let significand = (bits & IEEE_754_MANTISSA_MASK) | IEEE_754_IMPLICIT_ONE;

    if exponent <= IEEE_754_MANTISSA_BITS {
        // Binary point lies inside the significand: drop the fractional bits.
        significand >> (IEEE_754_MANTISSA_BITS - exponent)
    } else if exponent < 64 {
        // Integer wider than 53 bits but still below 2^64: scale up. The shift is
        // at most 11 bits, so the 53-bit significand cannot overflow.
        significand << (exponent - IEEE_754_MANTISSA_BITS)
    } else {
        // 2^64 or larger does not fit in u64.
        u64::MAX
    }
}
186
187impl Default for FallbackConfig {
188    fn default() -> Self {
189        Self {
190            developer: Vec::new(),
191            reviewer: Vec::new(),
192            commit: Vec::new(),
193            analysis: Vec::new(),
194            provider_fallback: HashMap::new(),
195            max_retries: default_max_retries(),
196            retry_delay_ms: default_retry_delay_ms(),
197            backoff_multiplier: default_backoff_multiplier(),
198            max_backoff_ms: default_max_backoff_ms(),
199            max_cycles: default_max_cycles(),
200        }
201    }
202}
203
204impl FallbackConfig {
205    /// Calculate exponential backoff delay for a given cycle.
206    ///
207    /// Uses the formula: min(base * multiplier^cycle, `max_backoff`)
208    ///
209    /// Uses integer arithmetic to avoid floating-point casting issues.
210    pub fn calculate_backoff(&self, cycle: u32) -> u64 {
211        // For common multiplier values, use direct integer computation
212        // to avoid f64->u64 conversion and associated clippy lints.
213        let multiplier_hundredths = self.get_multiplier_hundredths();
214        let base_hundredths = self.retry_delay_ms.saturating_mul(100);
215
216        // Calculate: base * (multiplier^cycle) / 100^cycle
217        // Use saturating arithmetic to avoid overflow
218        let mut delay_hundredths = base_hundredths;
219        for _ in 0..cycle {
220            delay_hundredths = delay_hundredths.saturating_mul(multiplier_hundredths);
221            delay_hundredths = delay_hundredths.saturating_div(100);
222        }
223
224        // Convert back to milliseconds
225        delay_hundredths.div_euclid(100).min(self.max_backoff_ms)
226    }
227
228    /// Get the multiplier as hundredths (e.g., 2.0 -> 200, 1.5 -> 150).
229    ///
230    /// Uses a lookup table for common values to avoid f64->u64 casts.
231    /// For uncommon values, uses a safe conversion with validation.
232    fn get_multiplier_hundredths(&self) -> u64 {
233        const EPSILON: f64 = 0.0001;
234
235        // Common multiplier values - use exact integer matches
236        // This avoids the cast for the vast majority of cases
237        let m = self.backoff_multiplier;
238        if (m - 1.0).abs() < EPSILON {
239            return 100;
240        } else if (m - 1.5).abs() < EPSILON {
241            return 150;
242        } else if (m - 2.0).abs() < EPSILON {
243            return 200;
244        } else if (m - 2.5).abs() < EPSILON {
245            return 250;
246        } else if (m - 3.0).abs() < EPSILON {
247            return 300;
248        } else if (m - 4.0).abs() < EPSILON {
249            return 400;
250        } else if (m - 5.0).abs() < EPSILON {
251            return 500;
252        } else if (m - 10.0).abs() < EPSILON {
253            return 1000;
254        }
255
256        // For uncommon values, compute using the original formula
257        // The value is clamped to [0.0, 1000.0] so the result is in [0.0, 100000.0]
258        // We use to_bits() and manual decoding to avoid cast lints
259        let clamped = m.clamp(0.0, 1000.0);
260        let multiplied = clamped * 100.0;
261        let rounded = multiplied.round();
262
263        // Manual f64 to u64 conversion using IEEE 754 bit representation
264        f64_to_u64_via_bits(rounded)
265    }
266
267    /// Get fallback agents for a role.
268    pub fn get_fallbacks(&self, role: AgentRole) -> &[String] {
269        match role {
270            AgentRole::Developer => &self.developer,
271            AgentRole::Reviewer => &self.reviewer,
272            AgentRole::Commit => self.get_effective_commit_fallbacks(),
273            AgentRole::Analysis => self.get_effective_analysis_fallbacks(),
274        }
275    }
276
277    /// Get effective fallback agents for analysis role.
278    ///
279    /// Falls back to developer chain if analysis chain is empty.
280    fn get_effective_analysis_fallbacks(&self) -> &[String] {
281        if self.analysis.is_empty() {
282            &self.developer
283        } else {
284            &self.analysis
285        }
286    }
287
288    /// Get effective fallback agents for commit role.
289    ///
290    /// Falls back to reviewer chain if commit chain is empty.
291    /// This ensures commit message generation can use the same agents
292    /// configured for code review when no dedicated commit agents are specified.
293    fn get_effective_commit_fallbacks(&self) -> &[String] {
294        if self.commit.is_empty() {
295            &self.reviewer
296        } else {
297            &self.commit
298        }
299    }
300
301    /// Check if fallback is configured for a role.
302    pub fn has_fallbacks(&self, role: AgentRole) -> bool {
303        !self.get_fallbacks(role).is_empty()
304    }
305
306    /// Get provider-level fallback model flags for an agent.
307    ///
308    /// Returns the list of model flags to try for the given agent name.
309    /// Empty slice if no provider fallback is configured for this agent.
310    pub fn get_provider_fallbacks(&self, agent_name: &str) -> &[String] {
311        self.provider_fallback
312            .get(agent_name)
313            .map_or(&[], std::vec::Vec::as_slice)
314    }
315
316    /// Check if provider-level fallback is configured for an agent.
317    pub fn has_provider_fallbacks(&self, agent_name: &str) -> bool {
318        self.provider_fallback
319            .get(agent_name)
320            .is_some_and(|v| !v.is_empty())
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned agent list stored by the config.
    fn chain(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn test_agent_role_display() {
        let expected = [
            (AgentRole::Developer, "developer"),
            (AgentRole::Reviewer, "reviewer"),
            (AgentRole::Commit, "commit"),
            (AgentRole::Analysis, "analysis"),
        ];
        for (role, label) in expected {
            assert_eq!(format!("{}", role), label);
        }
    }

    #[test]
    fn test_fallback_config_defaults() {
        let config = FallbackConfig::default();

        // All chains start out empty.
        assert!(config.developer.is_empty());
        assert!(config.reviewer.is_empty());
        assert!(config.commit.is_empty());
        assert!(config.analysis.is_empty());

        // Retry and backoff tuning.
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.retry_delay_ms, 1000);
        // Floating point: compare approximately.
        assert!((config.backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert_eq!(config.max_backoff_ms, 60000);
        assert_eq!(config.max_cycles, 3);
    }

    #[test]
    fn test_fallback_config_calculate_backoff() {
        let config = FallbackConfig {
            retry_delay_ms: 1000,
            backoff_multiplier: 2.0,
            max_backoff_ms: 60000,
            ..Default::default()
        };

        // The delay doubles on every cycle...
        for (cycle, expected_ms) in [(0, 1000), (1, 2000), (2, 4000), (3, 8000)] {
            assert_eq!(config.calculate_backoff(cycle), expected_ms);
        }

        // ...until it hits the cap.
        assert_eq!(config.calculate_backoff(10), 60000);
    }

    #[test]
    fn test_fallback_config_get_fallbacks() {
        let config = FallbackConfig {
            developer: chain(&["claude", "codex"]),
            reviewer: chain(&["codex"]),
            ..Default::default()
        };

        let developer_chain = config.get_fallbacks(AgentRole::Developer);
        assert_eq!(developer_chain, &["claude", "codex"]);

        let reviewer_chain = config.get_fallbacks(AgentRole::Reviewer);
        assert_eq!(reviewer_chain, &["codex"]);

        // With no analysis chain configured, the developer chain is used.
        let analysis_chain = config.get_fallbacks(AgentRole::Analysis);
        assert_eq!(analysis_chain, &["claude", "codex"]);
    }

    #[test]
    fn test_fallback_config_has_fallbacks() {
        let config = FallbackConfig {
            developer: chain(&["claude"]),
            reviewer: Vec::new(),
            ..Default::default()
        };

        assert!(config.has_fallbacks(AgentRole::Developer));
        assert!(config.has_fallbacks(AgentRole::Analysis));
        assert!(!config.has_fallbacks(AgentRole::Reviewer));
    }

    #[test]
    fn test_fallback_config_defaults_provider_fallback() {
        let config = FallbackConfig::default();
        assert!(!config.has_provider_fallbacks("opencode"));
        assert!(config.get_provider_fallbacks("opencode").is_empty());
    }

    #[test]
    fn test_provider_fallback_config() {
        let opencode_models = chain(&["-m opencode/glm-4.7-free", "-m opencode/claude-sonnet-4"]);
        let config = FallbackConfig {
            provider_fallback: HashMap::from([("opencode".to_string(), opencode_models)]),
            ..Default::default()
        };

        let fallbacks = config.get_provider_fallbacks("opencode");
        assert_eq!(fallbacks.len(), 2);
        assert_eq!(fallbacks[0], "-m opencode/glm-4.7-free");
        assert_eq!(fallbacks[1], "-m opencode/claude-sonnet-4");

        assert!(config.has_provider_fallbacks("opencode"));
        assert!(!config.has_provider_fallbacks("claude"));
    }

    #[test]
    fn test_fallback_config_from_toml() {
        let source = r#"
            developer = ["claude", "codex"]
            reviewer = ["codex", "claude"]
            max_retries = 5
            retry_delay_ms = 2000

            [provider_fallback]
            opencode = ["-m opencode/glm-4.7-free", "-m zai/glm-4.7"]
        "#;

        let config: FallbackConfig = toml::from_str(source).unwrap();

        assert_eq!(config.developer, vec!["claude", "codex"]);
        assert_eq!(config.reviewer, vec!["codex", "claude"]);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.retry_delay_ms, 2000);
        assert_eq!(config.get_provider_fallbacks("opencode").len(), 2);
    }

    #[test]
    fn test_commit_uses_reviewer_chain_when_empty() {
        // No commit chain configured: the reviewer chain takes over.
        let config = FallbackConfig {
            commit: Vec::new(),
            reviewer: chain(&["agent1", "agent2"]),
            ..Default::default()
        };

        let commit_chain = config.get_fallbacks(AgentRole::Commit);
        assert_eq!(commit_chain, &["agent1", "agent2"]);
        assert!(config.has_fallbacks(AgentRole::Commit));
    }

    #[test]
    fn test_commit_uses_own_chain_when_configured() {
        // A dedicated commit chain wins over the reviewer chain.
        let config = FallbackConfig {
            commit: chain(&["commit-agent"]),
            reviewer: chain(&["reviewer-agent"]),
            ..Default::default()
        };

        let commit_chain = config.get_fallbacks(AgentRole::Commit);
        assert_eq!(commit_chain, &["commit-agent"]);
    }
}