crb_superagent/supervisor/
mod.rs

1pub mod forward;
2pub mod stacker;
3
4pub use forward::ForwardTo;
5pub use stacker::Stacker;
6
7use anyhow::Error;
8use async_trait::async_trait;
9use crb_agent::{
10    Address, Agent, AgentContext, AgentSession, Context, Envelope, Link, MessageFor, RunAgent,
11};
12use crb_core::Tag;
13use crb_runtime::{
14    InteractiveRuntime, InterruptionLevel, Interruptor, ManagedContext, ReachableContext, Runtime,
15};
16use derive_more::{Deref, DerefMut, From, Into};
17use std::cmp::Ordering;
18use std::collections::{BTreeMap, HashSet};
19use std::fmt::Debug;
20use std::hash::{Hash, Hasher};
21use typed_slab::TypedSlab;
22
23#[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Hash, From, Into)]
24pub struct ActivityId(usize);
25
26pub trait Supervisor: Agent {
27    type BasedOn: AgentContext<Self>;
28    type GroupBy: Debug + Ord + Clone + Send + Eq + Hash;
29
30    fn finished(&mut self, _rel: &Relation<Self>, _ctx: &mut Context<Self>) {}
31}
32
33pub trait SupervisorContext<S: Supervisor> {
34    fn tracker(&mut self) -> &mut Tracker<S>;
35    fn session(&mut self) -> &mut SupervisorSession<S>;
36}
37
38#[derive(Deref, DerefMut)]
39pub struct SupervisorSession<S: Supervisor> {
40    #[deref]
41    #[deref_mut]
42    pub session: S::BasedOn,
43    pub tracker: Tracker<S>,
44}
45
46impl<S> Default for SupervisorSession<S>
47where
48    S: Supervisor,
49    S::BasedOn: Default,
50{
51    fn default() -> Self {
52        Self {
53            session: S::BasedOn::default(),
54            tracker: Tracker::new(),
55        }
56    }
57}
58
59impl<S: Supervisor> ReachableContext for SupervisorSession<S> {
60    type Address = Address<S>;
61
62    fn address(&self) -> &Self::Address {
63        self.session.address()
64    }
65}
66
67impl<S: Supervisor> AsRef<Address<S>> for SupervisorSession<S> {
68    fn as_ref(&self) -> &Address<S> {
69        self.address()
70    }
71}
72
73impl<S: Supervisor> ManagedContext for SupervisorSession<S> {
74    fn is_alive(&self) -> bool {
75        self.session.is_alive() && !self.tracker.is_terminated()
76    }
77
78    fn shutdown(&mut self) {
79        self.tracker.terminate_all();
80        if self.tracker.is_terminated() {
81            self.session.shutdown();
82        }
83    }
84
85    fn stop(&mut self) {
86        self.session.stop();
87    }
88}
89
90#[async_trait]
91impl<S: Supervisor> AgentContext<S> for SupervisorSession<S> {
92    fn session(&mut self) -> &mut AgentSession<S> {
93        self.session.session()
94    }
95
96    async fn next_envelope(&mut self) -> Option<Envelope<S>> {
97        self.session.next_envelope().await
98    }
99}
100
101impl<S: Supervisor> SupervisorContext<S> for SupervisorSession<S> {
102    fn tracker(&mut self) -> &mut Tracker<S> {
103        &mut self.tracker
104    }
105
106    fn session(&mut self) -> &mut SupervisorSession<S> {
107        self
108    }
109}
110
111#[derive(Debug, Default)]
112struct Group {
113    interrupted: bool,
114    ids: HashSet<ActivityId>,
115}
116
117impl Group {
118    fn is_finished(&self) -> bool {
119        self.interrupted && self.ids.is_empty()
120    }
121}
122
123pub struct Tracker<S: Supervisor> {
124    groups: BTreeMap<S::GroupBy, Group>,
125    activities: TypedSlab<ActivityId, Activity<S>>,
126    terminating: bool,
127}
128
129impl<S: Supervisor> Default for Tracker<S> {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135impl<S: Supervisor> Tracker<S> {
136    pub fn new() -> Self {
137        Self {
138            groups: BTreeMap::new(),
139            activities: TypedSlab::new(),
140            terminating: false,
141        }
142    }
143
144    pub fn is_empty(&self) -> bool {
145        self.groups.is_empty() && self.activities.is_empty()
146    }
147
148    pub fn is_terminated(&self) -> bool {
149        self.terminating && self.is_empty()
150    }
151
152    pub fn terminate_group(&mut self, group: S::GroupBy) {
153        if let Some(group) = self.groups.get(&group) {
154            for id in group.ids.iter() {
155                if let Some(activity) = self.activities.get_mut(*id) {
156                    activity.interrupt();
157                }
158            }
159        }
160    }
161
162    pub fn terminate_all(&mut self) {
163        self.try_terminate_next();
164    }
165
166    fn register_activity(&mut self, activity: Activity<S>) -> Relation<S> {
167        let group = activity.group.clone();
168        let id = self.activities.insert(activity);
169        let group_record = self.groups.entry(group.clone()).or_default();
170        group_record.ids.insert(id);
171        if group_record.interrupted {
172            // Interrupt if the group is terminating
173            self.activities.get_mut(id).map(Activity::interrupt);
174        }
175        Relation { id, group }
176    }
177
178    fn unregister_activity(&mut self, rel: &Relation<S>) {
179        if let Some(activity) = self.activities.remove(rel.id) {
180            // TODO: check rel.group == activity.group ?
181            if let Some(group) = self.groups.get_mut(&activity.group) {
182                group.ids.remove(&rel.id);
183                if group.ids.is_empty() {
184                    self.groups.remove(&activity.group);
185                }
186            }
187        }
188        if self.terminating {
189            self.try_terminate_next();
190        }
191    }
192
193    fn existing_groups(&self) -> Vec<S::GroupBy> {
194        self.groups.keys().rev().cloned().collect()
195    }
196
197    fn try_terminate_next(&mut self) {
198        self.terminating = true;
199        for group_name in self.existing_groups() {
200            if let Some(group) = self.groups.get_mut(&group_name) {
201                let name = std::any::type_name::<S>();
202                log::trace!("Agent {name} is terminating group: {:?}", group_name);
203                if !group.interrupted {
204                    group.interrupted = true;
205                    // Send an interruption signal to all active members of the group.
206                    for id in group.ids.iter() {
207                        if let Some(activity) = self.activities.get_mut(*id) {
208                            activity.interrupt();
209                        }
210                    }
211                }
212                if !group.is_finished() {
213                    break;
214                }
215            }
216        }
217    }
218}
219
220impl<S> SupervisorSession<S>
221where
222    S: Supervisor,
223    S::Context: SupervisorContext<S>,
224{
225    pub fn spawn_agent<A>(&mut self, agent: A, group: S::GroupBy) -> (Link<A>, Relation<S>)
226    where
227        A: Agent,
228        A::Context: Default,
229    {
230        let runtime = RunAgent::<A>::new(agent);
231        let (address, relation) = self.spawn_runtime(runtime, group);
232        (address.into(), relation)
233    }
234
235    pub fn spawn_runtime<B>(
236        &mut self,
237        trackable: B,
238        group: S::GroupBy,
239    ) -> (<B::Context as ReachableContext>::Address, Relation<S>)
240    where
241        B: InteractiveRuntime,
242    {
243        let addr = trackable.address();
244        let rel = self.spawn_trackable(trackable, group);
245        (addr, rel)
246    }
247
248    pub fn spawn_trackable<B>(&mut self, mut trackable: B, group: S::GroupBy) -> Relation<S>
249    where
250        B: Runtime,
251    {
252        let activity = Activity {
253            group,
254            interruptor: trackable.get_interruptor(),
255            level: trackable.interruption_level(),
256        };
257        let rel = self.tracker.register_activity(activity);
258        let detacher = DetacherFor {
259            supervisor: self.address().clone(),
260            rel: rel.clone(),
261        };
262
263        let fut = async move {
264            let name = std::any::type_name::<S>();
265            let rn_name = std::any::type_name::<B>();
266            trackable.routine().await;
267            // This notification equals calling `detach_trackable`
268            if let Err(err) = detacher.detach() {
269                log::error!(
270                    "Can't notify a supervisor {name} from {rn_name} to detach an activity: {err}"
271                );
272            } else {
273                log::debug!("A supervisor {name} notified about termination of {rn_name}.");
274            }
275        };
276        crb_core::spawn(fut);
277        rel
278    }
279
280    pub fn assign<R, T>(&mut self, trackable: R, group: S::GroupBy, tag: T) -> Relation<S>
281    where
282        R: ForwardTo<S, T>,
283        T: Tag,
284    {
285        let address = self.address().clone();
286        let trackable = trackable.into_trackable(address, tag);
287        self.spawn_trackable(trackable, group)
288    }
289}
290
291struct Activity<S: Supervisor> {
292    group: S::GroupBy,
293    interruptor: Box<dyn Interruptor>,
294    level: InterruptionLevel,
295}
296
297impl<S: Supervisor> Activity<S> {
298    fn interrupt(&mut self) {
299        self.interruptor.interrupt_with_level(self.level);
300    }
301}
302
303pub struct Relation<S: Supervisor> {
304    pub id: ActivityId,
305    pub group: S::GroupBy,
306}
307
308impl<S: Supervisor> Clone for Relation<S> {
309    fn clone(&self) -> Self {
310        Self {
311            id: self.id,
312            group: self.group.clone(),
313        }
314    }
315}
316
317impl<S: Supervisor> PartialEq for Relation<S> {
318    fn eq(&self, other: &Self) -> bool {
319        self.id == other.id && self.group == other.group
320    }
321}
322
323impl<S: Supervisor> Eq for Relation<S> {}
324
325impl<S: Supervisor> PartialOrd for Relation<S> {
326    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
327        Some(self.cmp(other))
328    }
329}
330
331impl<S: Supervisor> Ord for Relation<S> {
332    fn cmp(&self, other: &Self) -> Ordering {
333        self.id
334            .cmp(&other.id)
335            .then_with(|| self.group.cmp(&other.group))
336    }
337}
338
339impl<S: Supervisor> Hash for Relation<S> {
340    fn hash<H: Hasher>(&self, state: &mut H) {
341        self.id.hash(state);
342        self.group.hash(state);
343    }
344}
345
346struct DetachFrom<S: Supervisor> {
347    rel: Relation<S>,
348}
349
350#[async_trait]
351impl<S> MessageFor<S> for DetachFrom<S>
352where
353    S: Supervisor,
354    S::Context: SupervisorContext<S>,
355{
356    async fn handle(self: Box<Self>, agent: &mut S, ctx: &mut Context<S>) -> Result<(), Error> {
357        let tracker = ctx.tracker();
358        tracker.unregister_activity(&self.rel);
359        if tracker.is_terminated() {
360            ctx.shutdown();
361        }
362        agent.finished(&self.rel, ctx);
363        Ok(())
364    }
365}
366
367pub struct DetacherFor<S: Supervisor> {
368    rel: Relation<S>,
369    supervisor: <S::Context as ReachableContext>::Address,
370}
371
372impl<S> DetacherFor<S>
373where
374    S: Supervisor,
375    S::Context: SupervisorContext<S>,
376{
377    pub fn detach(self) -> Result<(), Error> {
378        let msg = DetachFrom { rel: self.rel };
379        self.supervisor.send(msg)
380    }
381}