use std::sync::Arc;
use std::time::{Duration, Instant};
use futures::stream::{FuturesUnordered, StreamExt};
use parking_lot::RwLock;
use tokio::task::{JoinError, JoinHandle};
use tokio_util::sync::CancellationToken;
use oxios_gateway::{ActiveWebDist, Gateway, Surface, SurfaceContext};
use crate::kernel::Kernel;
/// How a [`TaskSupervisor`] run ended.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShutdownOutcome {
    /// Shutdown was requested, either by Ctrl-C or by cancelling the root token.
    Graceful,
    /// A critical task exited, or the web surface exhausted its restart budget.
    Fatal {
        /// Name of the task that triggered the shutdown (e.g. `"web"`).
        name: String,
        /// Human-readable description of why the task stopped.
        reason: String,
    },
}

/// Identifies a tracked task so its completion is routed to the right policy.
#[derive(Clone, Debug)]
enum Tag {
    /// A fail-fast task: any exit is fatal. Carries the task name for reporting.
    Critical(String),
    /// The web surface server task: exits trigger a restart with backoff.
    Web,
}
/// Restart policy applied to the web surface by the task supervisor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RestartConfig {
    /// Number of restarts tolerated before the failure is escalated to
    /// [`ShutdownOutcome::Fatal`] (escalation happens once the retry count exceeds this).
    pub max_retries: u32,
    /// If the surface had been running at least this long when it exited, its retry count
    /// is reset before the exit is counted.
    pub reset_after: Duration,
    /// Delay before the first restart attempt; doubles with each further attempt.
    pub initial_backoff: Duration,
    /// Upper bound on the delay between restart attempts.
    pub max_backoff: Duration,
}

impl Default for RestartConfig {
    /// Five retries, a five-minute reset window, and backoff starting at 500 ms and capped
    /// at 30 s.
    fn default() -> Self {
        Self {
            max_retries: 5,
            reset_after: Duration::from_secs(300),
            initial_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(30),
        }
    }
}
type TaggedFut = std::pin::Pin<Box<dyn Future<Output = (Tag, Result<(), JoinError>)> + Send>>;
fn tagged(tag: Tag, handle: JoinHandle<()>) -> TaggedFut {
Box::pin(async move { (tag, handle.await) })
}
pub struct WebSurfaceRestarter {
gateway: Arc<Gateway>,
kernel_handle: Arc<oxios_kernel::KernelHandle>,
config: Arc<RwLock<oxios_kernel::OxiosConfig>>,
config_path: std::path::PathBuf,
web_dist: ActiveWebDist,
}
impl WebSurfaceRestarter {
pub fn new(
kernel: &Kernel,
config: Arc<RwLock<oxios_kernel::OxiosConfig>>,
config_path: std::path::PathBuf,
web_dist: ActiveWebDist,
) -> Self {
Self {
gateway: kernel.gateway(),
kernel_handle: kernel.handle(),
config,
config_path,
web_dist,
}
}
async fn start(&self, shutdown: CancellationToken) -> Result<JoinHandle<()>, anyhow::Error> {
let _ = self.gateway.unregister("web").await;
let surface = crate::api::WebSurface::new();
let ctx = SurfaceContext {
kernel: self.kernel_handle.clone(),
config: self.config.clone(),
config_path: self.config_path.clone(),
web_dist: self.web_dist.clone(),
shutdown,
};
let handle = surface.start(ctx).await?;
if let Some(channel) = handle.channel {
self.gateway.register(channel).await?;
}
Ok(handle
.tasks
.into_iter()
.next()
.expect("web surface spawns a server task"))
}
}
/// Restart bookkeeping for the tracked web surface.
struct WebState {
    /// Rebuilds the web surface when a restart is due.
    restarter: Arc<WebSurfaceRestarter>,
    /// Exits and failed restarts counted against `RestartConfig::max_retries`; reset to 0
    /// when the surface had been running for at least `RestartConfig::reset_after`.
    retries: u32,
    /// When the current web surface instance was started (set on tracking and on each
    /// successful restart).
    last_start: Instant,
}
/// A web restart scheduled after an exit or a failed restart attempt.
struct PendingRestart {
    /// Instant at which the restart should be attempted.
    deadline: Instant,
}
/// Watches tracked tasks, restarts the web surface with backoff, and decides how the
/// process shuts down.
pub struct TaskSupervisor {
    /// Cancelled during drain; restarted web surfaces receive child tokens of it.
    root: CancellationToken,
    /// Retry budget and backoff policy for web restarts.
    restart: RestartConfig,
    /// Every tracked task, each yielding its tag together with its join result.
    tasks: FuturesUnordered<TaggedFut>,
    /// Present once `track_web` has been called.
    web: Option<WebState>,
    /// A scheduled web restart, if any.
    pending: Option<PendingRestart>,
    /// One-shot callback run during drain, after the root token is cancelled.
    stop_gateway: Option<Box<dyn FnOnce() + Send>>,
}
impl TaskSupervisor {
    /// Creates a supervisor that cancels `root` on shutdown and restarts the web surface
    /// according to `restart`.
    pub fn new(root: CancellationToken, restart: RestartConfig) -> Self {
        Self {
            root,
            restart,
            tasks: FuturesUnordered::new(),
            web: None,
            pending: None,
            stop_gateway: None,
        }
    }

    /// Tracks a fail-fast task: when it finishes for any reason, the run ends with
    /// [`ShutdownOutcome::Fatal`] carrying `name`.
    pub fn track_critical(&mut self, name: impl Into<String>, handle: JoinHandle<()>) {
        self.tasks.push(tagged(Tag::Critical(name.into()), handle));
    }

    /// Tracks the web surface server task. When it exits, it is restarted through
    /// `restarter` with exponential backoff until the retry budget is exhausted.
    ///
    /// Calling this again replaces the stored restart state (retry count and start time).
    pub fn track_web(&mut self, handle: JoinHandle<()>, restarter: Arc<WebSurfaceRestarter>) {
        self.web = Some(WebState {
            restarter,
            retries: 0,
            last_start: Instant::now(),
        });
        self.tasks.push(tagged(Tag::Web, handle));
    }

    /// Registers a callback that is invoked once during drain, after the root token has
    /// been cancelled.
    pub fn with_gateway_stop(&mut self, stop: impl FnOnce() + Send + 'static) {
        self.stop_gateway = Some(Box::new(stop));
    }

    /// Supervises until shutdown is requested or a fatal failure occurs, then drains.
    ///
    /// Draining waits up to 10 s after a graceful stop and up to 3 s after a fatal one.
    pub async fn run(mut self) -> ShutdownOutcome {
        let outcome = self.watch().await;
        let drain_timeout = match &outcome {
            ShutdownOutcome::Graceful => Duration::from_secs(10),
            ShutdownOutcome::Fatal { .. } => Duration::from_secs(3),
        };
        self.drain(drain_timeout).await;
        outcome
    }

    /// Waits for the event that ends supervision: Ctrl-C, root cancellation, a critical task
    /// exit, or an exhausted web restart budget. Web exits and due restarts are handled in
    /// place and supervision continues.
    async fn watch(&mut self) -> ShutdownOutcome {
        // Create the Ctrl-C listener once and keep it alive across iterations. Recreating it
        // every loop would drop the listener while a completion is being handled, and a
        // signal delivered in that window could be missed.
        let ctrl_c = tokio::signal::ctrl_c();
        tokio::pin!(ctrl_c);
        // Cleared if the listener cannot be installed. Shutdown then relies on the root
        // token instead of treating the installation error as a shutdown request.
        let mut ctrl_c_armed = true;
        loop {
            let timer = self
                .pending
                .as_ref()
                .map(|p| tokio::time::sleep_until(p.deadline.into()));
            tokio::select! {
                biased;
                res = &mut ctrl_c, if ctrl_c_armed => {
                    match res {
                        Ok(()) => return ShutdownOutcome::Graceful,
                        Err(err) => {
                            ctrl_c_armed = false;
                            tracing::warn!(
                                error = %err,
                                "Failed to listen for Ctrl-C; shutdown relies on the root token"
                            );
                        }
                    }
                }
                _ = self.root.cancelled() => return ShutdownOutcome::Graceful,
                Some((tag, result)) = self.tasks.next() => {
                    if let Some(outcome) = self.on_completion(tag, result).await {
                        return outcome;
                    }
                }
                _ = async {
                    match timer {
                        Some(sleep) => sleep.await,
                        None => std::future::pending::<()>().await,
                    }
                } => {
                    if let Some(outcome) = self.fire_web_restart().await {
                        return outcome;
                    }
                }
            }
        }
    }

    /// Cancels the root token, runs the gateway-stop callback (at most once), and waits up
    /// to `timeout` for every tracked task to finish.
    ///
    /// Tasks still running at the deadline are not joined. Their handles are dropped along
    /// with the supervisor, which detaches rather than aborts them.
    async fn drain(&mut self, timeout: Duration) {
        tracing::info!(
            timeout_secs = timeout.as_secs(),
            "Supervisor draining tracked tasks"
        );
        self.root.cancel();
        if let Some(stop) = self.stop_gateway.take() {
            stop();
        }
        let drained = tokio::time::timeout(timeout, async {
            while self.tasks.next().await.is_some() {}
        })
        .await;
        if drained.is_err() {
            tracing::warn!(
                remaining = self.tasks.len(),
                timeout_secs = timeout.as_secs(),
                "Supervisor drain timed out; remaining tasks were not joined"
            );
        }
    }

    /// Applies the exit policy for a tracked task that finished.
    ///
    /// A critical task exit is always fatal. A web exit schedules a restart after backoff,
    /// or becomes fatal once the retry count exceeds `max_retries`. The count is first reset
    /// if the surface had been running for at least `reset_after`.
    async fn on_completion(
        &mut self,
        tag: Tag,
        result: Result<(), JoinError>,
    ) -> Option<ShutdownOutcome> {
        match tag {
            Tag::Critical(name) => {
                let reason = describe_exit(&name, result);
                tracing::error!(task = %name, %reason, "Critical task exited unexpectedly");
                Some(ShutdownOutcome::Fatal { name, reason })
            }
            Tag::Web => {
                let reason = describe_exit("web", result);
                let ws = self.web.as_mut().expect("web state present while tracked");
                // A long healthy run earns back the full retry budget.
                if ws.last_start.elapsed() >= self.restart.reset_after {
                    ws.retries = 0;
                }
                ws.retries += 1;
                if ws.retries > self.restart.max_retries {
                    tracing::error!(
                        task = "web",
                        retries = ws.retries,
                        %reason,
                        "Web surface restart budget exhausted; escalating to fatal"
                    );
                    return Some(ShutdownOutcome::Fatal {
                        name: "web".to_string(),
                        reason: format!("{reason} (retries exhausted: {})", ws.retries),
                    });
                }
                let backoff = compute_backoff(&self.restart, ws.retries);
                tracing::warn!(
                    task = "web",
                    retry = ws.retries,
                    max = self.restart.max_retries,
                    backoff_ms = backoff.as_millis() as u64,
                    %reason,
                    "Web surface exited; scheduling restart"
                );
                self.pending = Some(PendingRestart {
                    deadline: Instant::now() + backoff,
                });
                None
            }
        }
    }

    /// Performs the pending web restart, if any.
    ///
    /// On success the new server task is tracked. On failure the attempt counts against the
    /// retry budget: another restart is scheduled with backoff, or the run becomes fatal
    /// once the budget is exceeded.
    async fn fire_web_restart(&mut self) -> Option<ShutdownOutcome> {
        let _pending = self.pending.take()?;
        let ws = self.web.as_mut().expect("web state present while tracked");
        let child = self.root.child_token();
        match ws.restarter.start(child.clone()).await {
            Ok(handle) => {
                ws.last_start = Instant::now();
                self.tasks.push(tagged(Tag::Web, handle));
                tracing::info!(task = "web", "Web surface restarted");
                oxios_kernel::metrics::get_metrics().inc_supervisor_restart();
                None
            }
            Err(e) => {
                // Signal anything the failed attempt may have spawned against this token to
                // stop. It is a child of `root`, so cancelling it does not cancel `root`.
                child.cancel();
                ws.retries += 1;
                if ws.retries > self.restart.max_retries {
                    tracing::error!(
                        task = "web",
                        retries = ws.retries,
                        error = %e,
                        "Web restart failed; budget exhausted"
                    );
                    return Some(ShutdownOutcome::Fatal {
                        name: "web".to_string(),
                        reason: format!("restart failed: {e}"),
                    });
                }
                let backoff = compute_backoff(&self.restart, ws.retries);
                tracing::warn!(
                    task = "web",
                    retry = ws.retries,
                    backoff_ms = backoff.as_millis() as u64,
                    error = %e,
                    "Web restart failed; rescheduling"
                );
                self.pending = Some(PendingRestart {
                    deadline: Instant::now() + backoff,
                });
                None
            }
        }
    }
}
/// Delay before restart `attempt` (1-based): `initial_backoff * 2^(attempt - 1)`, capped
/// at `max_backoff`.
///
/// Attempt `0` is treated like attempt `1`. The multiplier and the product both saturate,
/// so very large attempt counts cannot overflow or panic.
fn compute_backoff(cfg: &RestartConfig, attempt: u32) -> Duration {
    let doublings = attempt.saturating_sub(1);
    let factor = 2u32.saturating_pow(doublings);
    let uncapped = cfg.initial_backoff.saturating_mul(factor);
    std::cmp::min(uncapped, cfg.max_backoff)
}
/// Renders a one-line, human-readable reason for why the task `name` stopped.
///
/// The four cases are a normal return (unexpected for supervised tasks), cancellation, a
/// panic, or any other join error, which is included verbatim.
fn describe_exit(name: &str, result: Result<(), JoinError>) -> String {
    let detail = match result {
        Ok(()) => "completed unexpectedly".to_owned(),
        Err(err) if err.is_cancelled() => "was cancelled".to_owned(),
        Err(err) if err.is_panic() => "panicked".to_owned(),
        Err(err) => format!("error: {err}"),
    };
    format!("'{name}' task {detail}")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backoff_is_exponential_and_capped() {
        let cfg = RestartConfig {
            max_retries: 5,
            reset_after: Duration::from_secs(300),
            initial_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(30),
        };
        assert_eq!(compute_backoff(&cfg, 1), Duration::from_millis(500));
        assert_eq!(compute_backoff(&cfg, 2), Duration::from_secs(1));
        assert_eq!(compute_backoff(&cfg, 3), Duration::from_secs(2));
        assert_eq!(compute_backoff(&cfg, 4), Duration::from_secs(4));
        assert_eq!(compute_backoff(&cfg, 100), Duration::from_secs(30));
    }

    #[test]
    fn backoff_attempt_zero_is_initial() {
        let cfg = RestartConfig::default();
        assert_eq!(compute_backoff(&cfg, 0), cfg.initial_backoff);
    }

    #[tokio::test]
    async fn describe_exit_covers_each_outcome() {
        assert_eq!(
            describe_exit("gw", Ok(())),
            "'gw' task completed unexpectedly"
        );

        let panicking: JoinHandle<()> = tokio::spawn(async { panic!("boom") });
        assert_eq!(describe_exit("gw", panicking.await), "'gw' task panicked");

        let parked: JoinHandle<()> = tokio::spawn(std::future::pending::<()>());
        parked.abort();
        assert_eq!(describe_exit("gw", parked.await), "'gw' task was cancelled");
    }

    #[tokio::test]
    async fn failfast_task_exit_is_fatal() {
        let mut sup = TaskSupervisor::new(CancellationToken::new(), RestartConfig::default());
        sup.track_critical("gateway", tokio::spawn(async {}));
        match sup.run().await {
            ShutdownOutcome::Fatal { name, .. } => assert_eq!(name, "gateway"),
            ShutdownOutcome::Graceful => panic!("expected Fatal for critical task exit"),
        }
    }

    #[tokio::test]
    async fn critical_task_panic_is_fatal_with_reason() {
        let mut sup = TaskSupervisor::new(CancellationToken::new(), RestartConfig::default());
        sup.track_critical("worker", tokio::spawn(async { panic!("boom") }));
        match sup.run().await {
            ShutdownOutcome::Fatal { name, reason } => {
                assert_eq!(name, "worker");
                assert_eq!(reason, "'worker' task panicked");
            }
            ShutdownOutcome::Graceful => panic!("expected Fatal for panicking critical task"),
        }
    }

    /// Cancelling the root token (the programmatic shutdown path) yields a graceful outcome.
    #[tokio::test]
    async fn root_cancellation_is_graceful() {
        let root = CancellationToken::new();
        let mut sup = TaskSupervisor::new(root.clone(), RestartConfig::default());
        let task_token = root.clone();
        sup.track_critical(
            "gateway",
            tokio::spawn(async move {
                task_token.cancelled().await;
            }),
        );
        root.cancel();
        let outcome = sup.run().await;
        assert!(matches!(outcome, ShutdownOutcome::Graceful));
    }
}