use ractor::concurrency::{Duration, Instant};
use ractor::{ActorCell, ActorId, ActorProcessingErr, ActorRef, Message, SpawnErr};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;

/// Errors raised by supervisor operations.
#[derive(Error, Debug, Clone)]
pub enum SupervisorError {
    #[error("Child '{child_id}' not found in specs")]
    ChildNotFound { child_id: String },

    #[error("Child '{pid}' does not have a name set")]
    ChildNameNotSet { pid: ActorId },

    #[error("Spawn error '{child_id}': {reason}")]
    ChildSpawnError { child_id: String, reason: String },

    #[error("Meltdown: {reason}")]
    Meltdown { reason: String },
}

/// Per-child backoff policy: given the child id, its restart count, the
/// instant of its last failure, and the spec's `reset_after`, returns an
/// optional delay to wait before restarting (`None` restarts immediately).
#[derive(Clone)]
pub struct ChildBackoffFn(pub Arc<BackoffFn>);

type BackoffFn = dyn Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync;

impl ChildBackoffFn {
    pub fn new<F>(func: F) -> Self
    where
        F: Fn(&str, usize, Instant, Option<Duration>) -> Option<Duration> + Send + Sync + 'static,
    {
        Self(Arc::new(func))
    }

    pub fn call(
        &self,
        child_id: &str,
        restart_count: usize,
        last_restart: Instant,
        reset_after: Option<Duration>,
    ) -> Option<Duration> {
        (self.0)(child_id, restart_count, last_restart, reset_after)
    }
}
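
// A minimal sketch (not part of this module's API) of how a backoff policy can
// be built with `ChildBackoffFn::new`: the first restart is immediate, later
// ones back off exponentially with a 30-second cap. The function name
// `example_exponential_backoff` and the concrete delays are illustrative only.
#[allow(dead_code)]
fn example_exponential_backoff() -> ChildBackoffFn {
    ChildBackoffFn::new(|_child_id, restart_count, _last_fail_instant, _reset_after| {
        if restart_count <= 1 {
            // First failure: restart right away.
            None
        } else {
            // 200ms, 400ms, 800ms, ... capped at 30 seconds.
            let exponent = (restart_count as u32 - 1).min(8);
            Some(Duration::from_millis(100 * 2u64.pow(exponent)).min(Duration::from_secs(30)))
        }
    })
}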

/// Boxed future returned by a [`SpawnFn`], resolving to the spawned child's
/// [`ActorCell`] or a [`SpawnErr`].
pub type SpawnFuture = Pin<Box<dyn Future<Output = Result<ActorCell, SpawnErr>> + Send>>;

/// Factory used to (re)spawn a child: given the supervising actor's cell and
/// the child id, it returns a future that resolves to the new child's cell.
#[derive(Clone)]
pub struct SpawnFn(pub Arc<dyn Fn(ActorCell, String) -> SpawnFuture + Send + Sync>);

impl SpawnFn {
    pub fn new<F, Fut>(func: F) -> Self
    where
        F: Fn(ActorCell, String) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<ActorCell, SpawnErr>> + Send + 'static,
    {
        Self(Arc::new(move |cell, id| Box::pin(func(cell, id))))
    }

    pub fn call(&self, cell: ActorCell, id: String) -> SpawnFuture {
        (self.0)(cell, id)
    }
}
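
// A minimal sketch (an assumed helper, not part of this module's API) of how a
// `SpawnFn` is typically constructed: it spawns a fresh instance of some
// `ractor::Actor`, linked to the supervising cell so exits flow back as
// supervision events. A factory closure is used because each restart needs a
// new actor instance while the `SpawnFn` itself must be callable repeatedly.
#[allow(dead_code)]
fn spawn_fn_for<A>(
    actor_factory: impl Fn() -> A + Send + Sync + 'static,
    args: A::Arguments,
) -> SpawnFn
where
    A: ractor::Actor,
    A::Arguments: Clone + Send + Sync + 'static,
{
    SpawnFn::new(move |supervisor: ActorCell, child_id: String| {
        let actor = actor_factory();
        let args = args.clone();
        async move {
            // Link the child to the supervisor and register it under the child id.
            let (child_ref, _join_handle) =
                ractor::Actor::spawn_linked(Some(child_id), actor, args, supervisor).await?;
            Ok(child_ref.get_cell())
        }
    })
}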

/// Restart policy for a child.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Restart {
    /// Always restart the child, whether it exits normally or abnormally.
    Permanent,
    /// Restart the child only if it exits abnormally.
    Transient,
    /// Never restart the child.
    Temporary,
}

/// Specification describing how the supervisor runs and restarts one child.
#[derive(Clone)]
pub struct ChildSpec {
    /// Unique identifier for this child within the supervisor.
    pub id: String,

    /// Restart policy applied when the child exits.
    pub restart: Restart,

    /// Factory used to (re)spawn the child actor.
    pub spawn_fn: SpawnFn,

    /// Optional per-child backoff applied before each restart.
    pub backoff_fn: Option<ChildBackoffFn>,

    /// If set, the child's failure count is reset once this much time has
    /// passed since its last failure.
    pub reset_after: Option<Duration>,
}

impl std::fmt::Debug for ChildSpec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChildSpec")
            .field("id", &self.id)
            .field("restart", &self.restart)
            .field("reset_after", &self.reset_after)
            .finish()
    }
}
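
// A minimal sketch tying the pieces above together: it assembles a `ChildSpec`
// for an arbitrary actor type using the illustrative helpers `spawn_fn_for`
// and `example_exponential_backoff` defined earlier in this file. The literal
// values (the "worker" id, the 60-second reset) are placeholders.
#[allow(dead_code)]
fn example_child_spec<A>(
    actor_factory: impl Fn() -> A + Send + Sync + 'static,
    args: A::Arguments,
) -> ChildSpec
where
    A: ractor::Actor,
    A::Arguments: Clone + Send + Sync + 'static,
{
    ChildSpec {
        id: "worker".to_string(),
        // Restart this child even after a clean exit.
        restart: Restart::Permanent,
        spawn_fn: spawn_fn_for(actor_factory, args),
        // Space restarts out instead of hot-looping on a crashing child.
        backoff_fn: Some(example_exponential_backoff()),
        // Forget old failures after a minute of stability.
        reset_after: Some(Duration::from_secs(60)),
    }
}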

/// Per-child failure bookkeeping used to drive backoff and reset logic.
#[derive(Debug, Clone)]
pub struct ChildFailureState {
    pub restart_count: usize,
    pub last_fail_instant: Instant,
}

/// One entry in the supervisor-wide restart log used for meltdown detection.
#[derive(Clone, Debug)]
pub struct RestartLog {
    pub child_id: String,
    pub timestamp: Instant,
}

/// Supervisor-level restart limits and the restart strategy to apply.
pub trait CoreSupervisorOptions<Strategy> {
    /// Maximum number of restarts tolerated within `max_window` before meltdown.
    fn max_restarts(&self) -> usize;
    /// Sliding time window over which restarts are counted.
    fn max_window(&self) -> Duration;
    /// If set, the restart log is cleared after this much quiet time.
    fn reset_after(&self) -> Option<Duration>;
    /// The restart strategy to apply when a child must be restarted.
    fn strategy(&self) -> Strategy;
}
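
// A minimal sketch of an options type implementing `CoreSupervisorOptions`.
// `ExampleStrategy` and `ExampleOptions` are stand-ins invented for this
// example; a real strategy type would come from the supervisor implementation
// that uses this trait.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
struct ExampleStrategy;

#[allow(dead_code)]
struct ExampleOptions {
    max_restarts: usize,
    max_window: Duration,
    reset_after: Option<Duration>,
}

impl CoreSupervisorOptions<ExampleStrategy> for ExampleOptions {
    fn max_restarts(&self) -> usize {
        self.max_restarts
    }

    fn max_window(&self) -> Duration {
        self.max_window
    }

    fn reset_after(&self) -> Option<Duration> {
        self.reset_after
    }

    fn strategy(&self) -> ExampleStrategy {
        ExampleStrategy
    }
}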

/// The reason a child exited.
#[derive(Debug)]
pub enum ExitReason {
    Normal,
    Reason(Option<String>),
    Error(Box<dyn std::error::Error + Send + Sync>),
}

/// Core supervision behavior shared by supervisor actors: per-child failure
/// bookkeeping, restart decisions, meltdown detection, and restart scheduling.
pub trait SupervisorCore {
    type Message: Message;
    type Strategy;
    type Options: CoreSupervisorOptions<Self::Strategy>;

    /// Mutable access to the per-child failure bookkeeping.
    fn child_failure_state(&mut self) -> &mut HashMap<String, ChildFailureState>;
    /// Mutable access to the supervisor-wide restart log.
    fn restart_log(&mut self) -> &mut Vec<RestartLog>;
    /// The supervisor's restart limits and strategy.
    fn options(&self) -> &Self::Options;
    /// Builds the message the supervisor sends itself to restart `child_spec`.
    fn restart_msg(
        &self,
        child_spec: &ChildSpec,
        strategy: Self::Strategy,
        myself: ActorRef<Self::Message>,
    ) -> Self::Message;

    /// Records a failure for `child_spec`, first resetting its counter if the
    /// child has been quiet for at least its `reset_after` threshold.
    fn prepare_child_failure(&mut self, child_spec: &ChildSpec) -> Result<(), ActorProcessingErr> {
        let child_id = &child_spec.id;
        let now = Instant::now();
        let entry = self
            .child_failure_state()
            .entry(child_id.clone())
            .or_insert_with(|| ChildFailureState {
                restart_count: 0,
                last_fail_instant: now,
            });

        if let Some(threshold) = child_spec.reset_after {
            if now.duration_since(entry.last_fail_instant) >= threshold {
                entry.restart_count = 0;
            }
        }

        entry.restart_count += 1;
        entry.last_fail_instant = now;
        Ok(())
    }

    /// Applies the child's restart policy to an exit. Returns `true` (after
    /// recording the failure) if the child should be restarted.
    fn handle_child_exit(
        &mut self,
        child_spec: &ChildSpec,
        abnormal: bool,
    ) -> Result<bool, ActorProcessingErr> {
        let policy = child_spec.restart;

        let should_restart = match policy {
            Restart::Permanent => true,
            Restart::Transient => abnormal,
            Restart::Temporary => false,
        };

        if should_restart {
            self.prepare_child_failure(child_spec)?;
        }

        Ok(should_restart)
    }

    /// Logs the exit and schedules a restart if the child's policy calls for one.
    fn handle_child_restart(
        &mut self,
        child_spec: &ChildSpec,
        abnormal: bool,
        myself: ActorRef<Self::Message>,
        reason: &ExitReason,
    ) -> Result<(), ActorProcessingErr> {
        if self.handle_child_exit(child_spec, abnormal)? {
            log_child_restart(child_spec, abnormal, reason);
            self.schedule_restart(child_spec, self.options().strategy(), myself.clone())?;
        }

        Ok(())
    }

    /// Records a restart in the supervisor-wide log and returns a
    /// [`SupervisorError::Meltdown`] if more than `max_restarts` restarts have
    /// occurred within `max_window`.
    fn track_global_restart(&mut self, child_id: &str) -> Result<(), ActorProcessingErr> {
        let now: Instant = Instant::now();

        let max_restarts = self.options().max_restarts();
        let max_window = self.options().max_window();
        let reset_after = self.options().reset_after();

        let restart_log = self.restart_log();

        // If the supervisor has been quiet for long enough, forget past restarts.
        if let (Some(thresh), Some(latest)) = (reset_after, restart_log.last()) {
            if now.duration_since(latest.timestamp) >= thresh {
                restart_log.clear();
            }
        }

        restart_log.push(RestartLog {
            child_id: child_id.to_string(),
            timestamp: now,
        });

        // Only restarts inside the sliding window count toward meltdown.
        restart_log.retain(|t| now.duration_since(t.timestamp) < max_window);

        if restart_log.len() > max_restarts {
            Err(SupervisorError::Meltdown {
                reason: "max_restarts exceeded".to_string(),
            }
            .into())
        } else {
            Ok(())
        }
    }

    /// Builds the restart message and sends it to the supervisor, either
    /// immediately or after the delay produced by the child's backoff function.
    fn schedule_restart(
        &mut self,
        child_spec: &ChildSpec,
        strategy: Self::Strategy,
        myself: ActorRef<Self::Message>,
    ) -> Result<(), ActorProcessingErr> {
        let child_id = &child_spec.id;

        let (restart_count, last_fail_instant) = {
            let failure_state = self.child_failure_state();
            let st = failure_state
                .get(child_id)
                .ok_or(SupervisorError::ChildNotFound {
                    child_id: child_id.clone(),
                })?;
            (st.restart_count, st.last_fail_instant)
        };
        let msg = self.restart_msg(child_spec, strategy, myself.clone());

        let delay = child_spec
            .backoff_fn
            .as_ref()
            .and_then(|cb: &ChildBackoffFn| {
                cb.call(
                    child_id,
                    restart_count,
                    last_fail_instant,
                    child_spec.reset_after,
                )
            });

        match delay {
            Some(delay) => {
                myself.send_after(delay, move || msg);
            }
            None => {
                myself.send_message(msg)?;
            }
        }

        Ok(())
    }
}

/// Logs an impending child restart, using `error!` for abnormal exits and
/// `warn!` for normal ones.
fn log_child_restart(child_spec: &ChildSpec, abnormal: bool, reason: &ExitReason) {
    match (abnormal, reason) {
        (true, ExitReason::Error(err)) => log::error!(
            "Restarting child: {}, exit: abnormal, error: {:?}",
            child_spec.id,
            err
        ),
        (false, ExitReason::Error(err)) => log::warn!(
            "Restarting child: {}, exit: normal, error: {:?}",
            child_spec.id,
            err
        ),
        (true, ExitReason::Reason(Some(reason))) => log::error!(
            "Restarting child: {}, exit: abnormal, reason: {}",
            child_spec.id,
            reason
        ),
        (false, ExitReason::Reason(Some(reason))) => log::warn!(
            "Restarting child: {}, exit: normal, reason: {}",
            child_spec.id,
            reason
        ),
        (true, ExitReason::Reason(None)) => {
            log::error!("Restarting child: {}, exit: abnormal", child_spec.id)
        }
        (false, ExitReason::Reason(None)) => {
            log::warn!("Restarting child: {}, exit: normal", child_spec.id)
        }
        (true, ExitReason::Normal) => {
            log::error!("Restarting child: {}, exit: abnormal", child_spec.id)
        }
        (false, ExitReason::Normal) => {
            log::warn!("Restarting child: {}, exit: normal", child_spec.id)
        }
    }
}