// crb_supervisor/agent.rs

1use anyhow::Error;
2use async_trait::async_trait;
3use crb_agent::{Address, Agent, AgentContext, AgentSession, MessageFor, RunAgent};
4use crb_runtime::{Context, Controller, InteractiveRuntime, Interruptor, ManagedContext, Runtime};
5use derive_more::{Deref, DerefMut, From, Into};
6use std::collections::{BTreeMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use typed_slab::TypedSlab;
10
11pub trait Supervisor: Agent {
12    type GroupBy: Debug + Ord + Clone + Sync + Send + Eq + Hash;
13}
14
15pub trait SupervisorContext<S: Supervisor> {
16    fn session(&mut self) -> &mut SupervisorSession<S>;
17}
18
19#[derive(Deref, DerefMut)]
20pub struct SupervisorSession<S: Supervisor> {
21    #[deref]
22    #[deref_mut]
23    session: AgentSession<S>,
24    tracker: Tracker<S>,
25}
26
27impl<S: Supervisor> Default for SupervisorSession<S> {
28    fn default() -> Self {
29        Self {
30            session: AgentSession::default(),
31            tracker: Tracker::new(),
32        }
33    }
34}
35
36impl<S: Supervisor> Context for SupervisorSession<S> {
37    type Address = Address<S>;
38
39    fn address(&self) -> &Self::Address {
40        self.session.address()
41    }
42}
43
44impl<S: Supervisor> ManagedContext for SupervisorSession<S> {
45    fn controller(&mut self) -> &mut Controller {
46        self.session.controller()
47    }
48
49    fn shutdown(&mut self) {
50        self.session.shutdown();
51    }
52}
53
54impl<S: Supervisor> AgentContext<S> for SupervisorSession<S> {
55    fn session(&mut self) -> &mut AgentSession<S> {
56        &mut self.session
57    }
58}
59
60impl<S: Supervisor> SupervisorContext<S> for SupervisorSession<S> {
61    fn session(&mut self) -> &mut SupervisorSession<S> {
62        self
63    }
64}
65
66impl<S: Supervisor> SupervisorSession<S> {
67    pub fn spawn_actor<A>(
68        &mut self,
69        input: A,
70        group: S::GroupBy,
71    ) -> <A::Context as Context>::Address
72    where
73        A: Agent,
74        A::Context: Default,
75        S: Supervisor<Context = SupervisorSession<S>>,
76    {
77        let runtime = RunAgent::<A>::new(input);
78        self.spawn_runtime(runtime, group)
79    }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From, Into)]
83pub struct ActivityId(usize);
84
85#[derive(Debug, Default)]
86struct Group {
87    interrupted: bool,
88    ids: HashSet<ActivityId>,
89}
90
91impl Group {
92    fn is_finished(&self) -> bool {
93        self.interrupted && self.ids.is_empty()
94    }
95}
96
97pub struct Tracker<S: Supervisor> {
98    groups: BTreeMap<S::GroupBy, Group>,
99    activities: TypedSlab<ActivityId, Activity<S>>,
100    terminating: bool,
101}
102
103impl<S: Supervisor> Tracker<S> {
104    pub fn new() -> Self {
105        Self {
106            groups: BTreeMap::new(),
107            activities: TypedSlab::new(),
108            terminating: false,
109        }
110    }
111
112    pub fn terminate_group(&mut self, group: S::GroupBy) {
113        if let Some(group) = self.groups.get(&group) {
114            for id in group.ids.iter() {
115                if let Some(activity) = self.activities.get_mut(*id) {
116                    if let Err(err) = activity.interrupt() {
117                        log::error!("Can't interrupt an activity in a group: {err}");
118                    }
119                }
120            }
121        }
122    }
123
124    pub fn terminate_all(&mut self) {
125        self.try_terminate_next();
126    }
127
128    fn register_activity(
129        &mut self,
130        group: S::GroupBy,
131        interruptor: Interruptor,
132    ) -> SupervisedBy<S> {
133        let activity = Activity {
134            group: group.clone(),
135            interruptor,
136        };
137        let id = self.activities.insert(activity);
138        let group_record = self.groups.entry(group.clone()).or_default();
139        group_record.ids.insert(id);
140        if group_record.interrupted {
141            // Interrupt if the group is terminating
142            self.activities.get_mut(id).map(Activity::interrupt);
143        }
144        SupervisedBy { id, group }
145    }
146
147    fn unregister_activity(&mut self, rel: &SupervisedBy<S>) {
148        if let Some(activity) = self.activities.remove(rel.id) {
149            // TODO: check rel.group == activity.group ?
150            if let Some(group) = self.groups.get_mut(&activity.group) {
151                group.ids.remove(&rel.id);
152            }
153        }
154        if self.terminating {
155            self.try_terminate_next();
156        }
157    }
158
159    fn existing_groups(&self) -> Vec<S::GroupBy> {
160        self.groups.keys().rev().cloned().collect()
161    }
162
163    fn try_terminate_next(&mut self) {
164        self.terminating = true;
165        for group_name in self.existing_groups() {
166            if let Some(group) = self.groups.get_mut(&group_name) {
167                if !group.interrupted {
168                    group.interrupted = true;
169                    // Send an interruption signal to all active members of the group.
170                    for id in group.ids.iter() {
171                        if let Some(activity) = self.activities.get_mut(*id) {
172                            if let Err(err) = activity.interrupt() {
173                                log::error!("Can't interrupt the next activity: {err}");
174                            }
175                        }
176                    }
177                }
178                if !group.is_finished() {
179                    break;
180                }
181            }
182        }
183    }
184}
185
186impl<S> SupervisorSession<S>
187where
188    S: Supervisor,
189    S::Context: SupervisorContext<S>,
190{
191    pub fn spawn_runtime<B>(
192        &mut self,
193        trackable: B,
194        group: S::GroupBy,
195    ) -> <B::Context as Context>::Address
196    where
197        B: InteractiveRuntime,
198    {
199        let addr = trackable.address();
200        self.spawn_trackable(trackable, group);
201        addr
202    }
203
204    pub fn spawn_trackable<B>(&mut self, mut trackable: B, group: S::GroupBy)
205    where
206        B: Runtime,
207    {
208        let interruptor = trackable.get_interruptor();
209        let rel = self.tracker.register_activity(group, interruptor);
210        let detacher = DetacherFor {
211            supervisor: self.address().clone(),
212            rel,
213        };
214
215        let fut = async move {
216            trackable.routine().await;
217            // This notification equals calling `detach_trackable`
218            if let Err(err) = detacher.detach() {
219                log::error!("Can't notify a supervisor to detach an activity: {err}");
220            }
221        };
222        crb_core::spawn(fut);
223    }
224}
225
226struct Activity<S: Supervisor> {
227    group: S::GroupBy,
228    // TODO: Consider to use JobHandle here
229    interruptor: Interruptor,
230}
231
232impl<S: Supervisor> Activity<S> {
233    fn interrupt(&mut self) -> Result<(), Error> {
234        self.interruptor.stop(false)
235    }
236}
237
238struct SupervisedBy<S: Supervisor> {
239    id: ActivityId,
240    group: S::GroupBy,
241}
242
243impl<S: Supervisor> Clone for SupervisedBy<S> {
244    fn clone(&self) -> Self {
245        Self {
246            id: self.id.clone(),
247            group: self.group.clone(),
248        }
249    }
250}
251
252struct DetachTrackable<S: Supervisor> {
253    rel: SupervisedBy<S>,
254}
255
256#[async_trait]
257impl<S> MessageFor<S> for DetachTrackable<S>
258where
259    S: Supervisor,
260    S::Context: SupervisorContext<S>,
261{
262    async fn handle(self: Box<Self>, _actor: &mut S, ctx: &mut S::Context) -> Result<(), Error> {
263        SupervisorContext::session(ctx)
264            .tracker
265            .unregister_activity(&self.rel);
266        Ok(())
267    }
268}
269
270pub struct DetacherFor<S: Supervisor> {
271    rel: SupervisedBy<S>,
272    supervisor: <S::Context as Context>::Address,
273}
274
275impl<S> DetacherFor<S>
276where
277    S: Supervisor,
278    S::Context: SupervisorContext<S>,
279{
280    pub fn detach(self) -> Result<(), Error> {
281        let msg = DetachTrackable { rel: self.rel };
282        self.supervisor.send(msg)
283    }
284}