agner_sup/
uniform.rs

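//! Uniform Supervisor
//! ==================
//!
//! A supervisor whose children are all created from one [`UniformChildSpec`]; children are
//! started and stopped on demand via [`start_child`] and [`stop_child`].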
3
4use std::collections::HashSet;
5use std::sync::Arc;
6
7use agner_actors::{ActorID, Context, Event, Exit, Never, Signal, System};
8use agner_init_ack::ContextInitAckExt;
9use agner_utils::result_err_flatten::ResultErrFlattenIn;
10use agner_utils::std_error_pp::StdErrorPP;
11
12use tokio::sync::oneshot;
13
14use crate::common::{CreateChild, StartChildError};
15
16mod child_spec;
17pub use child_spec::UniformChildSpec;
18
19#[derive(Debug, Clone, thiserror::Error)]
20pub enum SupervisorError {
21    #[error("Failed to start a child")]
22    StartChildError(#[source] StartChildError),
23
24    #[error("oneshot-rx error")]
25    OneshotRx(#[source] oneshot::error::RecvError),
26
27    #[error("Timeout")]
28    Timeout(#[source] Arc<tokio::time::error::Elapsed>),
29}
30
31pub async fn start_child<A>(
32    system: &System,
33    sup: ActorID,
34    args: A,
35) -> Result<ActorID, SupervisorError>
36where
37    A: Send + 'static,
38{
39    let (tx, rx) = oneshot::channel();
40    system.send(sup, Message::Start(args, tx)).await;
41    rx.await.err_flatten_in()
42}
43
44pub async fn stop_child<A>(
45    system: &System,
46    sup: ActorID,
47    child: ActorID,
48) -> Result<Exit, SupervisorError>
49where
50    A: Send + 'static,
51{
52    let (tx, rx) = oneshot::channel();
53    system.send(sup, Message::<A>::Stop(child, tx)).await;
54    rx.await.err_flatten_in()
55}
56
57pub enum Message<InArgs> {
58    Start(InArgs, oneshot::Sender<Result<ActorID, SupervisorError>>),
59    Stop(ActorID, oneshot::Sender<Result<Exit, SupervisorError>>),
60}
61
/// Configuration of a uniform supervisor: wraps the child spec every child is created from.
#[derive(Debug, Clone)]
pub struct SupSpec<CS>(CS);

impl<CS> SupSpec<CS> {
    /// Wraps `child_spec` into a supervisor spec.
    pub fn new(child_spec: CS) -> Self {
        SupSpec(child_spec)
    }
}
70
71/// The behaviour function of the [Uniform Supervisor](crate::uniform).
72pub async fn run<SupArg, B, A, M>(
73    context: &mut Context<Message<SupArg>>,
74    sup_spec: SupSpec<UniformChildSpec<B, A, M>>,
75) -> Result<Never, Exit>
76where
77    UniformChildSpec<B, A, M>: CreateChild<Args = SupArg>,
78    SupArg: Unpin + Send + 'static,
79    B: Send + Sync + 'static,
80    A: Send + Sync + 'static,
81    M: Send + Sync + 'static,
82{
83    context.trap_exit(true).await;
84    context.init_ack_ok(Default::default());
85
86    let SupSpec(mut child_spec) = sup_spec;
87
88    let mut shutting_down = None;
89    let mut children: HashSet<ActorID> = Default::default();
90    loop {
91        match context.next_event().await {
92            Event::Message(Message::Start(args, reply_to)) => {
93                tracing::trace!("starting child");
94
95                let result =
96                    child_spec.create_child(&context.system(), context.actor_id(), args).await;
97
98                if let Some(actor_id) = result.as_ref().ok().copied() {
99                    children.insert(actor_id);
100                }
101
102                tracing::trace!("start result {:?}", result);
103
104                let _ = reply_to.send(result.map_err(Into::into));
105            },
106            Event::Message(Message::Stop(actor_id, reply_to)) =>
107                if children.contains(&actor_id) {
108                    tracing::trace!("stopping child {}", actor_id);
109
110                    let system = context.system();
111                    let job = {
112                        let shutdown_sequence = child_spec.shutdown_sequence().to_owned();
113                        async move {
114                            tracing::trace!("stop-job enter [child: {}]", actor_id);
115                            let result =
116                                crate::common::stop_child(system, actor_id, shutdown_sequence)
117                                    .await;
118
119                            tracing::trace!(
120                                "stop-job done [child: {}; result: {:?}]",
121                                actor_id,
122                                result
123                            );
124
125                            if let Ok(exit) = result {
126                                let _ = reply_to.send(Ok(exit));
127                            }
128                        }
129                    };
130                    context.spawn_job(job).await;
131                } else {
132                    tracing::trace!(
133                        "received a request to stop an unknown actor ({}). Ignoring.",
134                        actor_id
135                    );
136                    let _ = reply_to.send(Ok(Exit::no_actor()));
137                },
138            Event::Signal(Signal::Exit(actor_id, exit_reason)) =>
139                if actor_id == context.actor_id() {
140                    tracing::trace!("received a shutdown signal to myself. Shutting down");
141
142                    shutting_down = Some(exit_reason.to_owned());
143
144                    let system = context.system();
145                    let mut has_some_children = false;
146                    for actor_id in children.iter().copied() {
147                        has_some_children = true;
148
149                        system.exit(actor_id, Exit::shutdown()).await;
150                    }
151
152                    if !has_some_children {
153                        context.exit(exit_reason).await;
154                        unreachable!()
155                    }
156                } else if children.remove(&actor_id) {
157                    tracing::trace!("child {} terminated [exit: {}]", actor_id, exit_reason.pp());
158                    if children.is_empty() {
159                        if let Some(exit_reason) = shutting_down {
160                            tracing::trace!(
161                                "last child terminated. Shutting down: {}",
162                                exit_reason.pp()
163                            );
164                            context.exit(exit_reason).await;
165                            unreachable!()
166                        }
167                    }
168                } else {
169                    tracing::trace!(
170                        "unknown linked process ({}) termianted. Shutting down [exit: {}]",
171                        actor_id,
172                        exit_reason.pp()
173                    );
174                    context.exit(Exit::linked(actor_id, exit_reason)).await;
175                    unreachable!()
176                },
177        }
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    use std::convert::Infallible;
    use std::time::Duration;

    use agner_actors::System;
    use futures::StreamExt;

    use crate::common::InitType;

    /// Drives a uniform supervisor end to end: starts three workers, stops the first one
    /// explicitly, then shuts the supervisor down and checks that nothing is left running.
    #[tokio::test]
    async fn ergonomics() {
        /// Worker behaviour: logs its arguments and then never finishes on its own.
        async fn worker(
            _context: &mut Context<Infallible>,
            (worker_id, worker_name): (usize, &'static str),
        ) -> Result<Never, Exit> {
            tracing::info!("worker [id: {:?}, name: {:?}]", worker_id, worker_name);
            tokio::time::sleep(Duration::from_secs(3)).await;
            std::future::pending().await
        }

        // Every worker name is paired with a sequential id, starting from 1.
        let child_spec = UniformChildSpec::uniform()
            .behaviour(worker)
            .args_call1({
                let mut last_id = 0;
                move |name| {
                    last_id += 1;
                    (last_id, name)
                }
            })
            .init_type(InitType::no_ack());

        let system = System::new(Default::default());
        let sup = system
            .spawn(crate::uniform::run, SupSpec::new(child_spec), Default::default())
            .await
            .unwrap();

        let first = start_child(&system, sup, "one").await.unwrap();
        let second = start_child(&system, sup, "two").await.unwrap();
        let third = start_child(&system, sup, "three").await.unwrap();

        // An explicitly stopped child reports a shutdown exit.
        let first_exit = stop_child::<&str>(&system, sup, first).await.unwrap();
        assert!(first_exit.is_shutdown());
        assert!(system.wait(first).await.is_shutdown());

        // Shutting the supervisor down takes the remaining children with it.
        system.exit(sup, Exit::shutdown()).await;
        assert!(system.wait(sup).await.is_shutdown());
        assert!(system.wait(second).await.is_shutdown());
        assert!(system.wait(third).await.is_shutdown());

        let leftovers: Vec<_> = system.all_actors().collect().await;
        assert!(leftovers.is_empty());
    }
}
235
236impl From<oneshot::error::RecvError> for SupervisorError {
237    fn from(e: oneshot::error::RecvError) -> Self {
238        Self::OneshotRx(e)
239    }
240}
241impl From<StartChildError> for SupervisorError {
242    fn from(e: StartChildError) -> Self {
243        Self::StartChildError(e)
244    }
245}
246impl From<tokio::time::error::Elapsed> for SupervisorError {
247    fn from(e: tokio::time::error::Elapsed) -> Self {
248        Self::Timeout(Arc::new(e))
249    }
250}