1use std::collections::HashSet;
5use std::sync::Arc;
6
7use agner_actors::{ActorID, Context, Event, Exit, Never, Signal, System};
8use agner_init_ack::ContextInitAckExt;
9use agner_utils::result_err_flatten::ResultErrFlattenIn;
10use agner_utils::std_error_pp::StdErrorPP;
11
12use tokio::sync::oneshot;
13
14use crate::common::{CreateChild, StartChildError};
15
16mod child_spec;
17pub use child_spec::UniformChildSpec;
18
19#[derive(Debug, Clone, thiserror::Error)]
20pub enum SupervisorError {
21 #[error("Failed to start a child")]
22 StartChildError(#[source] StartChildError),
23
24 #[error("oneshot-rx error")]
25 OneshotRx(#[source] oneshot::error::RecvError),
26
27 #[error("Timeout")]
28 Timeout(#[source] Arc<tokio::time::error::Elapsed>),
29}
30
31pub async fn start_child<A>(
32 system: &System,
33 sup: ActorID,
34 args: A,
35) -> Result<ActorID, SupervisorError>
36where
37 A: Send + 'static,
38{
39 let (tx, rx) = oneshot::channel();
40 system.send(sup, Message::Start(args, tx)).await;
41 rx.await.err_flatten_in()
42}
43
44pub async fn stop_child<A>(
45 system: &System,
46 sup: ActorID,
47 child: ActorID,
48) -> Result<Exit, SupervisorError>
49where
50 A: Send + 'static,
51{
52 let (tx, rx) = oneshot::channel();
53 system.send(sup, Message::<A>::Stop(child, tx)).await;
54 rx.await.err_flatten_in()
55}
56
57pub enum Message<InArgs> {
58 Start(InArgs, oneshot::Sender<Result<ActorID, SupervisorError>>),
59 Stop(ActorID, oneshot::Sender<Result<Exit, SupervisorError>>),
60}
61
/// Specification of a uniform supervisor: a thin wrapper around the child
/// specification `CS`.
#[derive(Clone, Debug)]
pub struct SupSpec<CS>(CS);

impl<CS> SupSpec<CS> {
    /// Wraps `child_spec` into a supervisor specification.
    pub fn new(child_spec: CS) -> Self {
        SupSpec(child_spec)
    }
}
70
71pub async fn run<SupArg, B, A, M>(
73 context: &mut Context<Message<SupArg>>,
74 sup_spec: SupSpec<UniformChildSpec<B, A, M>>,
75) -> Result<Never, Exit>
76where
77 UniformChildSpec<B, A, M>: CreateChild<Args = SupArg>,
78 SupArg: Unpin + Send + 'static,
79 B: Send + Sync + 'static,
80 A: Send + Sync + 'static,
81 M: Send + Sync + 'static,
82{
83 context.trap_exit(true).await;
84 context.init_ack_ok(Default::default());
85
86 let SupSpec(mut child_spec) = sup_spec;
87
88 let mut shutting_down = None;
89 let mut children: HashSet<ActorID> = Default::default();
90 loop {
91 match context.next_event().await {
92 Event::Message(Message::Start(args, reply_to)) => {
93 tracing::trace!("starting child");
94
95 let result =
96 child_spec.create_child(&context.system(), context.actor_id(), args).await;
97
98 if let Some(actor_id) = result.as_ref().ok().copied() {
99 children.insert(actor_id);
100 }
101
102 tracing::trace!("start result {:?}", result);
103
104 let _ = reply_to.send(result.map_err(Into::into));
105 },
106 Event::Message(Message::Stop(actor_id, reply_to)) =>
107 if children.contains(&actor_id) {
108 tracing::trace!("stopping child {}", actor_id);
109
110 let system = context.system();
111 let job = {
112 let shutdown_sequence = child_spec.shutdown_sequence().to_owned();
113 async move {
114 tracing::trace!("stop-job enter [child: {}]", actor_id);
115 let result =
116 crate::common::stop_child(system, actor_id, shutdown_sequence)
117 .await;
118
119 tracing::trace!(
120 "stop-job done [child: {}; result: {:?}]",
121 actor_id,
122 result
123 );
124
125 if let Ok(exit) = result {
126 let _ = reply_to.send(Ok(exit));
127 }
128 }
129 };
130 context.spawn_job(job).await;
131 } else {
132 tracing::trace!(
133 "received a request to stop an unknown actor ({}). Ignoring.",
134 actor_id
135 );
136 let _ = reply_to.send(Ok(Exit::no_actor()));
137 },
138 Event::Signal(Signal::Exit(actor_id, exit_reason)) =>
139 if actor_id == context.actor_id() {
140 tracing::trace!("received a shutdown signal to myself. Shutting down");
141
142 shutting_down = Some(exit_reason.to_owned());
143
144 let system = context.system();
145 let mut has_some_children = false;
146 for actor_id in children.iter().copied() {
147 has_some_children = true;
148
149 system.exit(actor_id, Exit::shutdown()).await;
150 }
151
152 if !has_some_children {
153 context.exit(exit_reason).await;
154 unreachable!()
155 }
156 } else if children.remove(&actor_id) {
157 tracing::trace!("child {} terminated [exit: {}]", actor_id, exit_reason.pp());
158 if children.is_empty() {
159 if let Some(exit_reason) = shutting_down {
160 tracing::trace!(
161 "last child terminated. Shutting down: {}",
162 exit_reason.pp()
163 );
164 context.exit(exit_reason).await;
165 unreachable!()
166 }
167 }
168 } else {
169 tracing::trace!(
170 "unknown linked process ({}) termianted. Shutting down [exit: {}]",
171 actor_id,
172 exit_reason.pp()
173 );
174 context.exit(Exit::linked(actor_id, exit_reason)).await;
175 unreachable!()
176 },
177 }
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    use futures::StreamExt;
    use std::convert::Infallible;
    use std::time::Duration;

    use agner_actors::System;

    use crate::common::InitType;

    /// Walks through the public API end to end: start three workers, stop one
    /// through the supervisor, shut the supervisor down, and check that no
    /// actors are left in the system.
    #[tokio::test]
    async fn ergonomics() {
        // A worker that logs its arguments, sleeps, and then never completes on its own.
        async fn worker(
            _context: &mut Context<Infallible>,
            (worker_id, worker_name): (usize, &'static str),
        ) -> Result<Never, Exit> {
            tracing::info!("worker [id: {:?}, name: {:?}]", worker_id, worker_name);
            tokio::time::sleep(Duration::from_secs(3)).await;
            std::future::pending().await
        }

        // Each call of the args closure pairs the given name with the next sequential id.
        let child_spec = UniformChildSpec::uniform()
            .behaviour(worker)
            .args_call1({
                let mut next_id = 0;
                move |name| {
                    next_id += 1;
                    (next_id, name)
                }
            })
            .init_type(InitType::no_ack());

        let sup_spec = SupSpec::new(child_spec);

        let system = System::new(Default::default());
        let sup = system.spawn(crate::uniform::run, sup_spec, Default::default()).await.unwrap();

        let first = start_child(&system, sup, "one").await.unwrap();
        let second = start_child(&system, sup, "two").await.unwrap();
        let third = start_child(&system, sup, "three").await.unwrap();

        // Stopping a single child through the supervisor.
        let first_exit = stop_child::<&str>(&system, sup, first).await.unwrap();
        assert!(first_exit.is_shutdown());
        assert!(system.wait(first).await.is_shutdown());

        // Shutting the supervisor down takes the remaining children with it.
        system.exit(sup, Exit::shutdown()).await;
        for actor in [sup, second, third] {
            assert!(system.wait(actor).await.is_shutdown());
        }

        let remaining = system.all_actors().collect::<Vec<_>>().await;
        assert!(remaining.is_empty());
    }
}
235
236impl From<oneshot::error::RecvError> for SupervisorError {
237 fn from(e: oneshot::error::RecvError) -> Self {
238 Self::OneshotRx(e)
239 }
240}
241impl From<StartChildError> for SupervisorError {
242 fn from(e: StartChildError) -> Self {
243 Self::StartChildError(e)
244 }
245}
246impl From<tokio::time::error::Elapsed> for SupervisorError {
247 fn from(e: tokio::time::error::Elapsed) -> Self {
248 Self::Timeout(Arc::new(e))
249 }
250}