initium 2.1.0

Swiss-army toolbox for Kubernetes initContainers — wait-for, seed, render, fetch in a single static Rust binary
use std::time::{Duration, Instant};

/// Tuning knobs for the exponential-backoff retry loop in [`do_retry`].
///
/// The pause after the failed attempt with 0-based index `n` is
/// `min(initial_delay * backoff_factor^n, max_delay)` plus a random jitter of up to
/// `jitter_fraction` times that value (see [`delay`]).
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    /// Total number of times the operation may be invoked, including the first try; must be >= 1.
    pub max_attempts: u32,
    /// Pause after the first failure (before jitter); must be non-zero.
    pub initial_delay: Duration,
    /// Upper bound on the exponential pause, applied before jitter; must be >= `initial_delay`.
    pub max_delay: Duration,
    /// Multiplier applied per attempt; must be a number >= 1.0 (NaN is rejected).
    pub backoff_factor: f64,
    /// Maximum extra pause as a fraction of the capped pause; must lie in `[0, 1]`.
    pub jitter_fraction: f64,
}

impl Config {
    /// Checks that every field is within its accepted range.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a human-readable message naming the first out-of-range
    /// field, checked in declaration order.
    pub fn validate(&self) -> Result<(), String> {
        if self.max_attempts == 0 {
            return Err(format!(
                "max-attempts must be >= 1, got {}",
                self.max_attempts
            ));
        }
        if self.initial_delay.is_zero() {
            return Err("initial-delay must be > 0".into());
        }
        if self.max_delay < self.initial_delay {
            return Err(format!(
                "max-delay ({:?}) must be >= initial-delay ({:?})",
                self.max_delay, self.initial_delay
            ));
        }
        // NaN compares false against everything, so a bare `< 1.0` test would accept it.
        if self.backoff_factor.is_nan() || self.backoff_factor < 1.0 {
            return Err(format!(
                "backoff-factor must be >= 1.0, got {}",
                self.backoff_factor
            ));
        }
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected here too.
        if !(0.0..=1.0).contains(&self.jitter_fraction) {
            return Err(format!(
                "jitter-fraction must be in [0, 1], got {}",
                self.jitter_fraction
            ));
        }
        Ok(())
    }
}

/// Returns how long to sleep after the failed attempt with 0-based index `attempt`.
///
/// The base pause is `initial_delay * backoff_factor^attempt`, capped at `max_delay`.
/// When `jitter_fraction > 0`, a random extra of up to `jitter_fraction` times the
/// capped pause is added on top. The result can therefore exceed `max_delay`, but by
/// no more than that fraction.
///
/// Never panics. A total that is NaN or non-positive yields [`Duration::ZERO`], and
/// one too large for a [`Duration`] saturates to [`Duration::MAX`]. These cases only
/// arise from out-of-range configs or an extremely large `max_delay`.
pub fn delay(cfg: &Config, attempt: u32) -> Duration {
    // `powi` takes an `i32`. An `as` cast would wrap attempts above `i32::MAX` into
    // negative exponents and collapse the pause towards zero, so saturate instead.
    let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
    let base = cfg.initial_delay.as_secs_f64() * cfg.backoff_factor.powi(exponent);
    // `f64::min` returns the non-NaN operand, so a NaN or infinite base is capped too.
    let capped = base.min(cfg.max_delay.as_secs_f64());
    let jitter = if cfg.jitter_fraction > 0.0 {
        capped * cfg.jitter_fraction * rand::random::<f64>()
    } else {
        0.0
    };
    let secs = capped + jitter;
    // `Duration::from_secs_f64` panics on negative, non-finite, or overflowing input.
    // `Duration::MAX.as_secs_f64()` rounds up to 2^64, which itself overflows, hence `>=`.
    if secs.is_nan() || secs <= 0.0 {
        Duration::ZERO
    } else if secs >= Duration::MAX.as_secs_f64() {
        Duration::MAX
    } else {
        Duration::from_secs_f64(secs)
    }
}

/// Outcome of a [`do_retry`] run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryResult {
    /// 0-based index of the last attempt made (0 when no attempt was made at all).
    pub attempt: u32,
    /// `None` if the operation succeeded; otherwise a message explaining why retrying stopped.
    pub err: Option<String>,
}

/// Calls `f` until it succeeds, sleeping with exponential backoff between failures.
///
/// `f` receives the 0-based attempt index and is called at most `cfg.max_attempts`
/// times. After each failure except the last, the current thread sleeps for
/// `delay(cfg, attempt)` via `std::thread::sleep`.
///
/// If `deadline` is set and the next sleep would end after it, retrying stops
/// immediately instead of sleeping. At least one attempt is still made.
///
/// The returned [`RetryResult`] carries the index of the last attempt made. Its `err`
/// is `None` on success; otherwise it describes why retrying stopped, including the
/// last error returned by `f`. When `max_attempts == 0` (a config rejected by
/// [`Config::validate`]), `f` is never called.
pub fn do_retry<F>(cfg: &Config, deadline: Option<Instant>, mut f: F) -> RetryResult
where
    F: FnMut(u32) -> std::result::Result<(), String>,
{
    for attempt in 0..cfg.max_attempts {
        let e = match f(attempt) {
            Ok(()) => return RetryResult { attempt, err: None },
            Err(e) => e,
        };
        // `attempt < max_attempts`, so `attempt + 1` cannot overflow.
        if attempt + 1 == cfg.max_attempts {
            return RetryResult {
                attempt,
                err: Some(format!(
                    "all {} attempts failed, last error: {}",
                    cfg.max_attempts, e
                )),
            };
        }
        let d = delay(cfg, attempt);
        if let Some(dl) = deadline {
            // `Instant + Duration` panics on overflow. A wake-up time that cannot be
            // represented certainly lies past the deadline.
            let past_deadline = match Instant::now().checked_add(d) {
                Some(wake) => wake > dl,
                None => true,
            };
            if past_deadline {
                return RetryResult {
                    attempt,
                    err: Some(format!(
                        "deadline exceeded after attempt {}, last error: {}",
                        attempt + 1,
                        e
                    )),
                };
            }
        }
        std::thread::sleep(d);
    }
    // Only reachable when `max_attempts == 0`, because the loop body never ran.
    RetryResult {
        attempt: 0,
        err: Some("max attempts reached".into()),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Small, jitter-free config so that delays are deterministic and tests stay fast.
    fn test_config() -> Config {
        Config {
            max_attempts: 3,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            backoff_factor: 2.0,
            jitter_fraction: 0.0,
        }
    }

    #[test]
    fn test_validate_ok() {
        assert!(test_config().validate().is_ok());
    }

    #[test]
    fn test_validate_max_attempts() {
        let mut cfg = test_config();
        cfg.max_attempts = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_initial_delay() {
        let mut cfg = test_config();
        cfg.initial_delay = Duration::ZERO;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_max_delay() {
        let mut cfg = test_config();
        cfg.max_delay = Duration::from_millis(1);
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_backoff() {
        let mut cfg = test_config();
        cfg.backoff_factor = 0.5;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_jitter() {
        let mut cfg = test_config();
        cfg.jitter_fraction = 1.5;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_jitter_nan() {
        let mut cfg = test_config();
        cfg.jitter_fraction = f64::NAN;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_delay_first_attempt_is_initial_delay() {
        let cfg = test_config();
        let d = delay(&cfg, 0);
        // Allow a little slack for the f64 round-trip.
        assert!(d >= Duration::from_micros(9_999));
        assert!(d <= Duration::from_micros(10_001));
    }

    #[test]
    fn test_delay_exponential() {
        let cfg = test_config();
        let d0 = delay(&cfg, 0);
        let d1 = delay(&cfg, 1);
        let d2 = delay(&cfg, 2);
        assert!(d1 > d0);
        assert!(d2 > d1);
    }

    #[test]
    fn test_delay_capped() {
        let cfg = test_config();
        let d = delay(&cfg, 100);
        assert!(d <= cfg.max_delay + Duration::from_millis(1));
    }

    #[test]
    fn test_delay_jitter_bounds() {
        let mut cfg = test_config();
        cfg.jitter_fraction = 0.5;
        // Attempt 1: the capped pause is 20ms and the jitter adds at most 50% (10ms).
        for _ in 0..20 {
            let d = delay(&cfg, 1);
            assert!(d >= Duration::from_micros(19_999), "delay too short: {:?}", d);
            assert!(d <= Duration::from_micros(30_001), "delay too long: {:?}", d);
        }
    }

    #[test]
    fn test_do_success() {
        let cfg = test_config();
        let result = do_retry(&cfg, None, |_| Ok(()));
        assert!(result.err.is_none());
        assert_eq!(result.attempt, 0);
    }

    #[test]
    fn test_do_eventual_success() {
        let cfg = test_config();
        let result = do_retry(&cfg, None, |attempt| {
            if attempt < 2 {
                Err("not yet".into())
            } else {
                Ok(())
            }
        });
        assert!(result.err.is_none());
        assert_eq!(result.attempt, 2);
    }

    #[test]
    fn test_do_passes_attempt_index() {
        let cfg = test_config();
        let mut seen = Vec::new();
        let result = do_retry(&cfg, None, |attempt| {
            seen.push(attempt);
            if attempt < 2 {
                Err("not yet".into())
            } else {
                Ok(())
            }
        });
        assert!(result.err.is_none());
        assert_eq!(seen, vec![0, 1, 2]);
    }

    #[test]
    fn test_do_all_fail() {
        let cfg = test_config();
        let mut calls = 0;
        let result = do_retry(&cfg, None, |_| {
            calls += 1;
            Err("fail".into())
        });
        assert_eq!(calls, 3);
        assert_eq!(result.attempt, 2);
        let err = result.err.expect("all attempts failed, so an error is expected");
        assert!(err.contains("all 3 attempts failed"));
        assert!(err.contains("fail"));
    }

    #[test]
    fn test_do_zero_attempts_never_calls() {
        let mut cfg = test_config();
        cfg.max_attempts = 0;
        let mut calls = 0;
        let result = do_retry(&cfg, None, |_| {
            calls += 1;
            Ok(())
        });
        assert_eq!(calls, 0);
        assert!(result.err.is_some());
    }

    #[test]
    fn test_do_deadline() {
        let cfg = Config {
            max_attempts: 100,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(1),
            backoff_factor: 1.0,
            jitter_fraction: 0.0,
        };
        // The first 50ms pause already overshoots a 10ms deadline, so exactly one call is made.
        let deadline = Instant::now() + Duration::from_millis(10);
        let mut calls = 0;
        let result = do_retry(&cfg, Some(deadline), |_| {
            calls += 1;
            Err("fail".into())
        });
        assert_eq!(calls, 1);
        assert_eq!(result.attempt, 0);
        let err = result.err.expect("deadline should stop retrying with an error");
        assert!(err.contains("deadline"));
    }
}