use crate::error::types::SupervisorError;
use crate::id::types::{ChildId, SupervisorPath};
use crate::spec::child::{BackoffPolicy, ChildSpec, HealthPolicy, RestartPolicy, ShutdownPolicy};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::time::Duration;
/// Strategy selector stored on supervisor specs, group strategies and
/// per-child overrides.
///
/// NOTE(review): the variant names mirror the Erlang/OTP supervisor restart
/// strategies. The restart scope each variant actually produces is computed
/// outside this file — confirm the semantics against the strategy executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum SupervisionStrategy {
/// Presumably only the failed child is restarted — TODO confirm.
OneForOne,
/// Presumably every child in scope is restarted — TODO confirm.
OneForAll,
/// Presumably the failed child and the children declared after it are
/// restarted — TODO confirm.
RestForOne,
}
/// Escalation choice that can be attached to a supervisor, a group strategy
/// or a per-child override.
///
/// NOTE(review): this file only stores the value; when escalation is
/// triggered (presumably once a restart budget is exhausted) and what each
/// variant does is decided elsewhere — confirm against the runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EscalationPolicy {
/// Presumably hands the failure to the parent supervisor — TODO confirm.
EscalateToParent,
/// Presumably shuts down the whole supervision tree — TODO confirm.
ShutdownTree,
/// Presumably isolates the failing restart scope — TODO confirm.
QuarantineScope,
}
/// A restart allowance expressed as a count paired with a time window.
///
/// NOTE(review): presumably "at most `max_restarts` restarts per `window`";
/// the enforcement logic is not in this file — confirm against the runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RestartBudget {
    /// Restart count component of the budget.
    pub max_restarts: u32,
    /// Time window component of the budget.
    pub window: Duration,
}

impl RestartBudget {
    /// Creates a budget from its two components.
    ///
    /// No validation happens here: zero values are accepted and stored as
    /// given.
    pub fn new(max_restarts: u32, window: Duration) -> Self {
        Self { max_restarts, window }
    }
}
/// Strategy settings that apply to a named group of children.
///
/// During spec validation the `group` name is compared against child tags,
/// so a group is effectively "all children carrying this tag".
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GroupStrategy {
    /// Group name; validation rejects names that are empty after trimming.
    pub group: String,
    /// Strategy selected for this group.
    pub strategy: SupervisionStrategy,
    /// Optional restart budget for the group; `None` when not configured.
    pub restart_budget: Option<RestartBudget>,
    /// Optional escalation policy for the group; `None` when not configured.
    pub escalation_policy: Option<EscalationPolicy>,
}

impl GroupStrategy {
    /// Creates a group strategy with no restart budget and no escalation
    /// policy.
    ///
    /// `group` accepts anything convertible into a `String`.
    pub fn new(group: impl Into<String>, strategy: SupervisionStrategy) -> Self {
        let group = group.into();
        Self {
            group,
            strategy,
            restart_budget: None,
            escalation_policy: None,
        }
    }
}
/// Strategy settings that apply to one specific child.
///
/// Spec validation requires `child_id` to name a declared child and allows
/// at most one override per child.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChildStrategyOverride {
    /// Identifier of the child this override targets.
    pub child_id: ChildId,
    /// Strategy selected for that child.
    pub strategy: SupervisionStrategy,
    /// Optional restart budget for the child; `None` when not configured.
    pub restart_budget: Option<RestartBudget>,
    /// Optional escalation policy for the child; `None` when not configured.
    pub escalation_policy: Option<EscalationPolicy>,
}

impl ChildStrategyOverride {
    /// Creates an override with no restart budget and no escalation policy.
    pub fn new(child_id: ChildId, strategy: SupervisionStrategy) -> Self {
        let (restart_budget, escalation_policy) = (None, None);
        Self {
            child_id,
            strategy,
            restart_budget,
            escalation_policy,
        }
    }
}
/// Switch and optional cap governing whether more children may be added.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DynamicSupervisorPolicy {
    /// When `false`, [`DynamicSupervisorPolicy::allows_addition`] always
    /// answers `false`.
    pub enabled: bool,
    /// Child count at which additions stop being allowed; `None` means no cap.
    pub child_limit: Option<usize>,
}

impl DynamicSupervisorPolicy {
    /// Enabled policy without any child cap.
    pub fn unbounded() -> Self {
        Self {
            enabled: true,
            child_limit: None,
        }
    }

    /// Enabled policy capped at `child_limit` children.
    ///
    /// The value is stored as given; a cap of zero is not rejected here.
    pub fn limited(child_limit: usize) -> Self {
        Self {
            child_limit: Some(child_limit),
            ..Self::unbounded()
        }
    }

    /// Reports whether one more child may be added when
    /// `current_child_count` children already exist.
    ///
    /// Returns `true` only if the policy is enabled and either no cap is set
    /// or the current count is strictly below the cap.
    pub fn allows_addition(&self, current_child_count: usize) -> bool {
        if !self.enabled {
            return false;
        }
        match self.child_limit {
            Some(limit) => current_child_count < limit,
            None => true,
        }
    }
}
/// Data describing how a single child failure is to be handled.
///
/// NOTE(review): nothing in this file constructs or consumes the plan; the
/// field descriptions below are inferred from names and types — confirm
/// against the code that builds it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StrategyExecutionPlan {
/// Identifier of the child whose failure the plan is about.
pub failed_child: ChildId,
/// Strategy recorded for this plan.
pub strategy: SupervisionStrategy,
/// Child identifiers covered by the plan (presumably the restart scope).
pub scope: Vec<ChildId>,
/// Group name associated with the plan, if any.
pub group: Option<String>,
/// Restart budget associated with the plan, if any.
pub restart_budget: Option<RestartBudget>,
/// Escalation policy associated with the plan, if any.
pub escalation_policy: Option<EscalationPolicy>,
/// Presumably mirrors `DynamicSupervisorPolicy::enabled` — TODO confirm.
pub dynamic_supervisor_enabled: bool,
}
/// Declarative description of a supervisor: its children, default policies,
/// strategy configuration and channel sizing.
///
/// Call `validate` before use; it rejects the invalid field combinations
/// noted on the individual fields below.
#[derive(Debug, Clone)]
pub struct SupervisorSpec {
/// Path of this supervisor; `root` uses `SupervisorPath::root()`.
pub path: SupervisorPath,
/// Supervisor-wide strategy.
pub strategy: SupervisionStrategy,
/// Child specifications; each one is validated by `validate`.
pub children: Vec<ChildSpec>,
/// Configuration version label; must not be empty or whitespace-only.
pub config_version: String,
/// Default restart policy (presumably used when a child does not set its
/// own — TODO confirm).
pub default_restart_policy: RestartPolicy,
/// Default backoff policy (same caveat as above).
pub default_backoff_policy: BackoffPolicy,
/// Default health policy (same caveat as above).
pub default_health_policy: HealthPolicy,
/// Default shutdown policy (same caveat as above).
pub default_shutdown_policy: ShutdownPolicy,
/// Supervisor failure limit; must be greater than zero.
pub supervisor_failure_limit: u32,
/// Optional supervisor-wide restart budget; when present, both of its
/// components must be non-zero.
pub restart_budget: Option<RestartBudget>,
/// Optional supervisor-wide escalation policy.
pub escalation_policy: Option<EscalationPolicy>,
/// Per-group strategies; group names must be non-empty, unique, used as a
/// tag by at least one child, and no child may match more than one group.
pub group_strategies: Vec<GroupStrategy>,
/// Per-child overrides; each must reference a declared child and a child
/// may be overridden at most once.
pub child_strategy_overrides: Vec<ChildStrategyOverride>,
/// Policy for adding children; a `child_limit` of `Some(0)` is rejected.
pub dynamic_supervisor_policy: DynamicSupervisorPolicy,
/// Control channel capacity; must be greater than zero.
pub control_channel_capacity: usize,
/// Event channel capacity; must be greater than zero.
pub event_channel_capacity: usize,
}
impl SupervisorSpec {
    /// Builds a root supervisor specification around `children` using the
    /// built-in defaults.
    ///
    /// Defaults chosen here: `OneForOne` strategy, config version
    /// `"unversioned"`, `Transient` restart policy, a failure limit of 1, no
    /// restart budget, no escalation policy, no group strategies or child
    /// overrides, and an unbounded dynamic supervisor policy. The control
    /// channel capacity comes from `channel_capacity_for_children`, and the
    /// event channel gets twice that (saturating).
    pub fn root(children: Vec<ChildSpec>) -> Self {
        let control_channel_capacity = channel_capacity_for_children(children.len());
        let event_channel_capacity = control_channel_capacity.saturating_mul(2);

        // NOTE(review): argument meanings are defined by the policy
        // constructors in `crate::spec::child`; presumably (initial delay,
        // max delay, jitter), (interval, timeout) and (graceful, forced) —
        // confirm there.
        let default_backoff_policy =
            BackoffPolicy::new(Duration::from_millis(10), Duration::from_secs(1), 0.0);
        let default_health_policy =
            HealthPolicy::new(Duration::from_secs(1), Duration::from_secs(3));
        let default_shutdown_policy =
            ShutdownPolicy::new(Duration::from_secs(5), Duration::from_secs(1));

        Self {
            path: SupervisorPath::root(),
            strategy: SupervisionStrategy::OneForOne,
            children,
            config_version: "unversioned".to_owned(),
            default_restart_policy: RestartPolicy::Transient,
            default_backoff_policy,
            default_health_policy,
            default_shutdown_policy,
            supervisor_failure_limit: 1,
            restart_budget: None,
            escalation_policy: None,
            group_strategies: Vec::new(),
            child_strategy_overrides: Vec::new(),
            dynamic_supervisor_policy: DynamicSupervisorPolicy::unbounded(),
            control_channel_capacity,
            event_channel_capacity,
        }
    }

    /// Checks the specification for configuration errors.
    ///
    /// # Errors
    ///
    /// Returns the first problem found, in this order: blank config version,
    /// zero failure limit, zero control channel capacity, zero event channel
    /// capacity, an invalid child, an invalid supervisor restart budget,
    /// invalid group strategies, invalid child strategy overrides, and an
    /// invalid dynamic supervisor policy.
    pub fn validate(&self) -> Result<(), SupervisorError> {
        // Scalar checks as (violated, message) pairs; the first violated
        // entry wins, so the order of this table is part of the contract.
        let scalar_checks = [
            (
                self.config_version.trim().is_empty(),
                "config version must not be empty",
            ),
            (
                self.supervisor_failure_limit == 0,
                "supervisor failure limit must be greater than zero",
            ),
            (
                self.control_channel_capacity == 0,
                "control channel capacity must be greater than zero",
            ),
            (
                self.event_channel_capacity == 0,
                "event channel capacity must be greater than zero",
            ),
        ];
        if let Some(&(_, message)) = scalar_checks.iter().find(|(violated, _)| *violated) {
            return Err(SupervisorError::fatal_config(message));
        }

        self.children
            .iter()
            .try_for_each(|child| child.validate())?;

        validate_restart_budget(self.restart_budget)?;
        validate_group_strategies(&self.group_strategies, &self.children)?;
        validate_child_strategy_overrides(self)?;
        validate_dynamic_policy(self.dynamic_supervisor_policy)
    }
}
fn validate_restart_budget(budget: Option<RestartBudget>) -> Result<(), SupervisorError> {
let Some(budget) = budget else {
return Ok(());
};
if budget.max_restarts == 0 {
return Err(SupervisorError::fatal_config(
"restart budget max_restarts must be greater than zero",
));
}
if budget.window.is_zero() {
return Err(SupervisorError::fatal_config(
"restart budget window must be greater than zero",
));
}
Ok(())
}
/// Validates the group strategies of a supervisor spec.
///
/// For each strategy, in declaration order:
/// 1. the group name must not be empty after trimming,
/// 2. the group name must not repeat an earlier one (exact string match),
/// 3. its restart budget, if any, must be valid.
///
/// Afterwards the groups are checked against the children's tags via
/// `validate_group_membership`.
///
/// # Errors
///
/// Returns a fatal configuration error describing the first violation.
fn validate_group_strategies(
    strategies: &[GroupStrategy],
    children: &[ChildSpec],
) -> Result<(), SupervisorError> {
    // Duplicate detection only needs to compare names, so borrow them from
    // the strategies instead of cloning a `String` per entry.
    let mut seen_groups: HashSet<&str> = HashSet::with_capacity(strategies.len());
    for strategy in strategies {
        let group = strategy.group.as_str();
        if group.trim().is_empty() {
            return Err(SupervisorError::fatal_config(
                "group strategy group must not be empty",
            ));
        }
        // `insert` returns false when the name was already present.
        if !seen_groups.insert(group) {
            return Err(SupervisorError::fatal_config(format!(
                "duplicate group strategy: {}",
                strategy.group
            )));
        }
        validate_restart_budget(strategy.restart_budget)?;
    }
    validate_group_membership(strategies, children)
}
/// Cross-checks group strategies against the tags carried by children.
///
/// Two rules are enforced, in this order:
/// 1. every group named by a strategy must appear as a tag on at least one
///    child (strategies are checked in declaration order);
/// 2. no child may carry tags matching more than one configured group
///    (children are checked in declaration order).
///
/// # Errors
///
/// Returns a fatal configuration error naming the first unused group or the
/// first ambiguous child.
fn validate_group_membership(
    strategies: &[GroupStrategy],
    children: &[ChildSpec],
) -> Result<(), SupervisorError> {
    // Index every tag once so that checking a group is a set lookup instead
    // of a rescan of all children for each strategy.
    let used_tags = children
        .iter()
        .flat_map(|child| child.tags.iter())
        .map(|tag| tag.as_str())
        .collect::<HashSet<&str>>();
    if let Some(unused) = strategies
        .iter()
        .find(|strategy| !used_tags.contains(strategy.group.as_str()))
    {
        return Err(SupervisorError::fatal_config(format!(
            "group strategy references unused group: {}",
            unused.group
        )));
    }

    // Borrowed group names suffice for the ambiguity check; no clones needed.
    let configured_groups = strategies
        .iter()
        .map(|strategy| strategy.group.as_str())
        .collect::<HashSet<&str>>();
    for child in children {
        let configured_group_count = child
            .tags
            .iter()
            .filter(|tag| configured_groups.contains(tag.as_str()))
            .count();
        if configured_group_count > 1 {
            return Err(SupervisorError::fatal_config(format!(
                "child strategy groups are ambiguous for child: {}",
                child.id
            )));
        }
    }
    Ok(())
}
/// Validates the per-child strategy overrides of `spec`.
///
/// For each override, in declaration order:
/// 1. it must reference a child declared in `spec.children`,
/// 2. it must be the only override for that child,
/// 3. its restart budget, if any, must be valid.
///
/// # Errors
///
/// Returns a fatal configuration error describing the first violation.
fn validate_child_strategy_overrides(spec: &SupervisorSpec) -> Result<(), SupervisorError> {
    // Both sets only answer membership questions while `spec` is borrowed,
    // so they hold references rather than cloned identifiers.
    let known_children = spec
        .children
        .iter()
        .map(|child| &child.id)
        .collect::<HashSet<_>>();
    let mut overridden = HashSet::with_capacity(spec.child_strategy_overrides.len());
    for child_override in &spec.child_strategy_overrides {
        let child_id = &child_override.child_id;
        if !known_children.contains(child_id) {
            return Err(SupervisorError::fatal_config(format!(
                "child strategy override references unknown child: {}",
                child_override.child_id
            )));
        }
        // `insert` returns false when this child already has an override.
        if !overridden.insert(child_id) {
            return Err(SupervisorError::fatal_config(format!(
                "duplicate child strategy override: {}",
                child_override.child_id
            )));
        }
        validate_restart_budget(child_override.restart_budget)?;
    }
    Ok(())
}
/// Validates a dynamic supervisor policy.
///
/// The only rejected configuration is a child limit of exactly zero; the
/// `enabled` flag is not consulted.
fn validate_dynamic_policy(policy: DynamicSupervisorPolicy) -> Result<(), SupervisorError> {
    match policy.child_limit {
        Some(0) => Err(SupervisorError::fatal_config(
            "dynamic supervisor child_limit must be greater than zero",
        )),
        Some(_) | None => Ok(()),
    }
}
/// Derives a channel capacity from the number of children: one slot more
/// than `child_count`, clamped to `usize::MAX` instead of overflowing.
fn channel_capacity_for_children(child_count: usize) -> usize {
    child_count.checked_add(1).unwrap_or(usize::MAX)
}