1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
use std::{
    ops::Deref,
    sync::{Arc, Mutex},
};

use async_channel::{Receiver, Sender};
use futures_lite::prelude::*;
use slab::Slab;

/// A nursery represents a dynamic scope in which tasks can be spawned. It is used for *structured concurrency*: it is intended to ensure that tasks spawned within the nursery terminate before the nursery falls out of scope.
///
/// We intentionally force all futures spawned in the nursery to return `anyhow::Result<()>`, and we do not expose join handles. This encourages a CSP-style anonymous-process way of thinking, and integration with `anyhow` allows for powerful error-propagation techniques.
///
/// Spawning is done through a [`NurseryHandle`], which is cloneable, rather than through the `Nursery` itself. `Nursery` dereferences to its own `NurseryHandle`, so you can call `spawn` directly on a `Nursery` as well.
pub struct Nursery {
    /// The nursery's own handle; it owns the sending halves of both channels below and a reference to the shared task slab.
    nhandle: NurseryHandle,
    /// Receives the errors that spawned tasks propagate to the nursery.
    recv_error: Receiver<anyhow::Error>,
    /// Receives a `()` notification from each spawned task when that task's body goes out of scope.
    recv_term: Receiver<()>,
}

impl Deref for Nursery {
    type Target = NurseryHandle;
    fn deref(&self) -> &Self::Target {
        &self.nhandle
    }
}

impl Default for Nursery {
    fn default() -> Self {
        Self::new()
    }
}

impl Nursery {
    /// Creates a new, empty nursery.
    ///
    /// Two unbounded channels are created: one carrying propagated task errors and one carrying task-termination notifications. The sending halves live in the nursery's handle (and in every clone of it); the receiving halves stay with the nursery.
    pub fn new() -> Self {
        let (send_error, recv_error) = async_channel::unbounded();
        let (send_term, recv_term) = async_channel::unbounded();
        Self {
            nhandle: NurseryHandle {
                task_holder: Arc::new(Mutex::new(Slab::default())),
                send_error,
                send_term,
            },
            recv_error,
            recv_term,
        }
    }

    /// Returns a new clone of this nursery's [`NurseryHandle`], which can be moved elsewhere and used to spawn tasks into the nursery.
    ///
    /// Note that [`Nursery::wait`] does not return `Ok(())` while any such clone is still alive.
    pub fn handle(&self) -> NurseryHandle {
        self.nhandle.clone()
    }

    /// Waits for the tasks in the nursery to terminate. If any error is propagated, returns the first such error immediately.
    ///
    /// Errors that were propagated *before* this method was called are reported as well; they are never silently discarded.
    ///
    /// Otherwise this function asynchronously blocks until every other `NurseryHandle` (including the clones held by still-running tasks) has been dropped, and then returns `Ok(())`.
    ///
    /// NOTE(review): running tasks hold their own references to the task slab, so returning early with an error does not appear to cancel them explicitly — confirm whether cancellation is intended here.
    pub async fn wait(self) -> anyhow::Result<()> {
        let Self {
            nhandle,
            recv_error,
            recv_term,
        } = self;
        // Drop the nursery's own senders: once every task and every external
        // handle is gone too, the error channel closes and `recv` reports it.
        // Relying on channel closure (rather than sampling `sender_count()`)
        // avoids both losing already-queued errors and racing against tasks
        // that signal termination before their senders are actually dropped.
        drop(nhandle);
        // Termination notifications are not needed for this; closing the
        // receiver just makes the tasks' best-effort `try_send` a no-op.
        drop(recv_term);
        match recv_error.recv().await {
            // A task propagated an error (possibly before `wait` was called).
            Ok(err) => Err(err),
            // Channel closed and drained: everything finished without error.
            Err(_) => Ok(()),
        }
    }

    /// Helper function that waits for nursery tasks synchronously, blocking the current thread until [`Nursery::wait`] completes.
    pub fn wait_sync(self) -> anyhow::Result<()> {
        futures_lite::future::block_on(self.wait())
    }
}

/// A cloneable handle through which tasks are spawned into a nursery.
///
/// All clones share the same task slab and the same error and termination channels.
#[derive(Clone)]
pub struct NurseryHandle {
    /// Shared slab holding the executor task handles of tasks spawned through this handle or any of its clones.
    task_holder: Arc<Mutex<Slab<async_executor::Task<()>>>>,
    /// Sending half of the channel on which tasks report errors that are to be propagated.
    send_error: Sender<anyhow::Error>,
    /// Sending half of the channel on which each task posts a `()` when its body goes out of scope.
    send_term: Sender<()>,
}

impl NurseryHandle {
    /// Spawns a task in the nursery, using the given recovery strategy.
    ///
    /// `task_gen` is called exactly once with a clone of this handle (so the task can itself spawn further tasks into the same nursery) and must return the future to run.
    ///
    /// If the future resolves to `Err`, `on_error` decides what happens: `OnError::Custom` closures are invoked with the error until they yield `OnError::Ignore` (the error is discarded) or `OnError::Propagate` (the error is sent to the nursery's error channel).
    ///
    /// No join handle is returned; the task's executor handle is kept in the nursery's shared slab and is removed again once the task has finished, whether it succeeded or failed.
    pub fn spawn<F: Future<Output = anyhow::Result<()>> + Send + 'static>(
        &self,
        on_error: OnError,
        task_gen: impl FnOnce(NurseryHandle) -> F + Send + 'static,
    ) {
        let this = self.clone();
        // The task learns its own slab key through this one-shot channel,
        // because the key only exists after the task has been spawned.
        let (send_tid, recv_tid) = async_oneshot::oneshot();
        let task = crate::spawn(async move {
            // Always notify the nursery when this body goes out of scope,
            // including when the task is cancelled part-way through.
            scopeguard::defer!({
                let _ = this.send_term.try_send(());
            });
            if let Err(err) = task_gen(this.clone()).await {
                // Resolve chained custom strategies down to a final decision.
                let propagate = {
                    let mut strategy = on_error;
                    loop {
                        match strategy {
                            OnError::Ignore => break false,
                            OnError::Propagate => break true,
                            OnError::Custom(f) => strategy = f(&err),
                        }
                    }
                };
                if propagate {
                    // Best effort: the nursery may already have gone away.
                    let _ = this.send_error.send(err).await;
                }
            }
            // Release our own slot on success as well as on failure, so that
            // finished tasks do not accumulate in the slab.
            if let Ok(tid) = recv_tid.await {
                drop(this.task_holder.lock().unwrap().remove(tid));
            }
        });
        let task_id = self.task_holder.lock().unwrap().insert(task);
        // If the task is already gone (its receiver was dropped), nobody else
        // will clean up the slot, so remove it here.
        if send_tid.send(task_id).is_err() {
            drop(self.task_holder.lock().unwrap().remove(task_id));
        }
    }
}

/// The strategy used to recover from errors that a task returns.
pub enum OnError {
    /// Discard the error; the nursery is not notified of it.
    Ignore,
    /// Send the error to the nursery's error channel.
    Propagate,
    /// Call the closure with the error; the `OnError` it returns is then applied in turn (and may itself be another `Custom`).
    Custom(Box<dyn FnOnce(&anyhow::Error) -> OnError + Send + Sync>),
}

impl OnError {
    /// Creates an error strategy based on the given closure. Convenience wrapper over `OnError::Custom` that boxes the closure for you.
    pub fn custom(f: impl FnOnce(&anyhow::Error) -> OnError + 'static + Send + Sync) -> Self {
        Self::Custom(Box::new(f))
    }
    /// Creates an error strategy that runs the given closure and ignores the error.
    pub fn ignore_with(f: impl FnOnce(&anyhow::Error) + 'static + Send + Sync) -> Self {
        Self::Custom(Box::new(move |e| {
            f(e);
            Self::Ignore
        }))
    }
    /// Creates an error strategy that runs the given closure and propagates it to the nursery.
    pub fn propagate_with(f: impl FnOnce(&anyhow::Error) + 'static + Send + Sync) -> Self {
        Self::Custom(Box::new(move |e| {
            f(e);
            Self::Propagate
        }))
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicUsize;

    use super::*;

    /// A parent task that itself succeeds spawns a child that fails with a
    /// propagating strategy; waiting on the nursery must surface that error.
    #[test]
    fn nursery_simple() {
        let root = Nursery::new();
        let shared = Arc::new(AtomicUsize::new(0));
        root.spawn(OnError::Ignore, move |handle| async move {
            eprintln!("hello world");
            // Log the child's error, then hand it to the nursery.
            let report = OnError::propagate_with(|e| eprintln!("error: {}", e));
            handle.spawn(report, |_| async {
                eprintln!("attempt");
                anyhow::bail!("oh no");
            });
            drop(handle);
            drop(shared);
            Ok(())
        });
        let outcome = root.wait_sync();
        assert!(outcome.is_err())
    }
}