use crate::error::{AgentError, Result};
use crate::runtime::ContainerId;
use std::time::Duration;
use zlayer_init_actions::InitAction;
use zlayer_spec::{ErrorsSpec, InitFailureAction, InitSpec};
/// Default number of retries applied by `run_with_retries` when the
/// caller does not override it via `with_max_retries` (total attempts
/// = retries + 1 initial attempt).
const DEFAULT_MAX_INIT_RETRIES: u32 = 3;
/// Exponential backoff configuration for init-step retries.
#[derive(Debug, Clone)]
pub struct BackoffConfig {
    /// Delay used for the first retry (attempt 0).
    pub initial_delay: Duration,
    /// Upper bound on any computed delay.
    pub max_delay: Duration,
    /// Growth factor: delay = initial_delay * multiplier^attempt.
    pub multiplier: f64,
}

impl Default for BackoffConfig {
    /// 1s initial delay, doubling per attempt, capped at 60s.
    fn default() -> Self {
        Self {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            multiplier: 2.0,
        }
    }
}

impl BackoffConfig {
    /// Returns the delay for the given 0-based retry `attempt`,
    /// capped at `max_delay`.
    ///
    /// The result is clamped to a non-negative finite value so that
    /// `Duration::from_secs_f64` cannot panic: `min` absorbs `+inf`
    /// produced by `powi` overflow (and, per `f64::min` semantics,
    /// replaces a NaN product with the cap), while `max(0.0)` guards
    /// against negative delays from a user-supplied negative
    /// `multiplier` (the fields are public, so this is reachable).
    #[must_use]
    #[allow(clippy::cast_possible_wrap)]
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        let delay_secs = self.initial_delay.as_secs_f64() * self.multiplier.powi(attempt as i32);
        let capped_secs = delay_secs.min(self.max_delay.as_secs_f64()).max(0.0);
        Duration::from_secs_f64(capped_secs)
    }
}
/// Runs a container's declarative init steps, applying the configured
/// failure policy (fail fast, restart, or exponential backoff).
pub struct InitOrchestrator {
/// Identity of the container the steps belong to; used in log fields
/// and error messages.
id: ContainerId,
/// The ordered list of init steps to execute.
spec: InitSpec,
/// Policy consulted by `run` to decide how init failures are retried.
error_policy: ErrorsSpec,
/// Maximum retry count for the Restart/Backoff policies.
max_retries: u32,
/// Delay schedule used when the Backoff policy is active.
backoff_config: BackoffConfig,
}
impl InitOrchestrator {
#[must_use]
pub fn new(id: ContainerId, spec: InitSpec) -> Self {
Self {
id,
spec,
error_policy: ErrorsSpec::default(),
max_retries: DEFAULT_MAX_INIT_RETRIES,
backoff_config: BackoffConfig::default(),
}
}
/// Creates an orchestrator with an explicit error policy; retry count
/// and backoff schedule start at their defaults.
#[must_use]
pub fn with_error_policy(id: ContainerId, spec: InitSpec, error_policy: ErrorsSpec) -> Self {
    Self {
        error_policy,
        max_retries: DEFAULT_MAX_INIT_RETRIES,
        backoff_config: BackoffConfig::default(),
        id,
        spec,
    }
}
#[must_use]
pub fn with_max_retries(mut self, max_retries: u32) -> Self {
self.max_retries = max_retries;
self
}
#[must_use]
pub fn with_backoff_config(mut self, config: BackoffConfig) -> Self {
self.backoff_config = config;
self
}
pub async fn run(&self) -> Result<()> {
match self.error_policy.on_init_failure.action {
InitFailureAction::Fail => {
self.run_init_steps().await
}
InitFailureAction::Restart => {
self.run_with_retries(false).await
}
InitFailureAction::Backoff => {
self.run_with_retries(true).await
}
}
}
/// Runs the init steps up to `max_retries + 1` times, returning the
/// first success or the error from the final failed attempt.
///
/// When `use_backoff` is true the delay between attempts follows
/// `backoff_config`; otherwise a fixed 100ms pause is used (the
/// Restart policy).
async fn run_with_retries(&self, use_backoff: bool) -> Result<()> {
let mut last_error = None;
// attempt 0 is the initial try; attempts 1..=max_retries are retries.
for attempt in 0..=self.max_retries {
// No delay before the very first attempt.
if attempt > 0 {
let delay = if use_backoff {
// attempt - 1 so the first retry sleeps for initial_delay
// (delay_for_attempt(0)), not one multiplier step beyond it.
self.backoff_config.delay_for_attempt(attempt - 1)
} else {
Duration::from_millis(100) };
tracing::info!(
container = %self.id,
attempt = attempt,
max_retries = self.max_retries,
delay_ms = delay.as_millis(),
"Retrying init steps after failure"
);
tokio::time::sleep(delay).await;
}
match self.run_init_steps().await {
Ok(()) => {
// Only log recovery when at least one retry was needed.
if attempt > 0 {
tracing::info!(
container = %self.id,
attempt = attempt,
"Init steps succeeded after retry"
);
}
return Ok(());
}
Err(e) => {
tracing::warn!(
container = %self.id,
attempt = attempt,
max_retries = self.max_retries,
error = %e,
"Init steps failed"
);
// Keep the most recent error to report after the loop.
last_error = Some(e);
}
}
}
// The loop always executes at least once and either returns Ok or
// sets last_error, so the fallback below should be unreachable in
// practice; it is kept as a defensive default.
Err(last_error.unwrap_or_else(|| AgentError::InitActionFailed {
id: self.id.to_string(),
reason: "Init failed after all retries".to_string(),
}))
}
/// Runs every configured init step once, in declaration order.
///
/// Each step's spec is translated into an executable action, then run
/// under the step's timeout (300s when unspecified). A failed or
/// timed-out step is handled per its `on_failure` policy: `Fail`
/// aborts the whole sequence, `Warn` logs and moves on, `Continue`
/// silently moves on.
async fn run_init_steps(&self) -> Result<()> {
    for step in &self.spec.steps {
        // Build the action from its declarative spec. The 30s value is
        // presumably the action-level default timeout expected by
        // from_spec (distinct from the per-step timeout below) —
        // TODO confirm against zlayer_init_actions::from_spec.
        let action = zlayer_init_actions::from_spec(
            &step.uses,
            &step.with,
            Duration::from_secs(30),
        )
        .map_err(|e| AgentError::InitActionFailed {
            id: self.id.to_string(),
            reason: e.to_string(),
        })?;
        let timeout = step.timeout.unwrap_or(Duration::from_secs(300));
        let result = tokio::time::timeout(timeout, self.execute_step(&action, step)).await;
        match result {
            // Step completed successfully within the timeout.
            Ok(Ok(())) => {}
            // Step ran but reported an error.
            Ok(Err(e)) => match step.on_failure {
                zlayer_spec::FailureAction::Fail => {
                    return Err(AgentError::InitActionFailed {
                        id: self.id.to_string(),
                        reason: format!("step '{}' failed: {}", step.id, e),
                    });
                }
                zlayer_spec::FailureAction::Warn => {
                    tracing::warn!(
                        container = %self.id,
                        step = %step.id,
                        error = %e,
                        "Init step failed (continuing due to warn policy)"
                    );
                }
                zlayer_spec::FailureAction::Continue => {}
            },
            // Step exceeded its timeout.
            Err(_) => match step.on_failure {
                zlayer_spec::FailureAction::Fail => return Err(AgentError::Timeout { timeout }),
                zlayer_spec::FailureAction::Warn => {
                    tracing::warn!(
                        container = %self.id,
                        step = %step.id,
                        timeout_secs = timeout.as_secs(),
                        "Init step timed out (continuing due to warn policy)"
                    );
                }
                zlayer_spec::FailureAction::Continue => {}
            },
        }
        // TODO: `step.retry` is declared in the spec but not yet
        // honored here; per-step retry remains unimplemented.
    }
    Ok(())
}
/// Dispatches a single init action and normalizes any failure into
/// `AgentError::InitActionFailed` tagged with this container's id.
async fn execute_step(&self, action: &InitAction, _step: &zlayer_spec::InitStep) -> Result<()> {
    // Every variant reports errors identically, so first reduce the
    // per-variant results to a common Result<(), String>, then wrap
    // the error exactly once below.
    let outcome: std::result::Result<(), String> = match action {
        InitAction::WaitTcp(a) => a.execute().await.map_err(|e| e.to_string()),
        InitAction::WaitHttp(a) => a.execute().await.map_err(|e| e.to_string()),
        InitAction::Run(a) => a.execute().await.map_err(|e| e.to_string()),
        #[cfg(feature = "s3")]
        InitAction::S3Push(a) => a.execute().await.map_err(|e| e.to_string()),
        #[cfg(feature = "s3")]
        InitAction::S3Pull(a) => a.execute().await.map_err(|e| e.to_string()),
    };
    outcome.map_err(|reason| AgentError::InitActionFailed {
        id: self.id.to_string(),
        reason,
    })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a throwaway container identity.
    fn test_container_id() -> ContainerId {
        ContainerId {
            service: "test".to_string(),
            replica: 1,
        }
    }

    #[tokio::test]
    async fn test_init_orchestrator_success() {
        // An empty step list must complete successfully.
        let orchestrator = InitOrchestrator::new(test_container_id(), InitSpec { steps: vec![] });
        orchestrator.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_init_orchestrator_with_error_policy() {
        let orchestrator = InitOrchestrator::with_error_policy(
            test_container_id(),
            InitSpec { steps: vec![] },
            ErrorsSpec::default(),
        );
        orchestrator.run().await.unwrap();
    }

    #[test]
    fn test_backoff_config_default() {
        let config = BackoffConfig::default();
        assert_eq!(config.initial_delay, Duration::from_secs(1));
        assert_eq!(config.max_delay, Duration::from_secs(60));
        assert!(
            (config.multiplier - 2.0).abs() < f64::EPSILON,
            "multiplier should be 2.0"
        );
    }

    #[test]
    fn test_backoff_delay_calculation() {
        let config = BackoffConfig::default();
        // Doubles every attempt, then saturates at the 60s cap.
        let expected_secs = [1u64, 2, 4, 8, 16, 32, 60, 60];
        for (attempt, secs) in expected_secs.iter().enumerate() {
            assert_eq!(
                config.delay_for_attempt(attempt as u32),
                Duration::from_secs(*secs)
            );
        }
    }

    #[test]
    fn test_backoff_custom_config() {
        let config = BackoffConfig {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            multiplier: 3.0,
        };
        // Triples every attempt (100ms, 300ms, ...), capped at 10s.
        let expected_ms = [100u64, 300, 900, 2700, 8100, 10_000];
        for (attempt, ms) in expected_ms.iter().enumerate() {
            assert_eq!(
                config.delay_for_attempt(attempt as u32),
                Duration::from_millis(*ms)
            );
        }
    }

    #[test]
    fn test_orchestrator_builder_pattern() {
        let id = test_container_id();
        let spec = InitSpec { steps: vec![] };
        let orchestrator = InitOrchestrator::new(id.clone(), spec.clone())
            .with_max_retries(5)
            .with_backoff_config(BackoffConfig {
                initial_delay: Duration::from_millis(500),
                max_delay: Duration::from_secs(30),
                multiplier: 1.5,
            });
        assert_eq!(orchestrator.max_retries, 5);
        assert_eq!(
            orchestrator.backoff_config.initial_delay,
            Duration::from_millis(500)
        );
    }
}