1use anyhow::Error;
2use async_trait::async_trait;
3use crb_agent::{Address, Agent, AgentContext, AgentSession, MessageFor, RunAgent};
4use crb_runtime::{Context, Controller, InteractiveRuntime, Interruptor, ManagedContext, Runtime};
5use derive_more::{Deref, DerefMut, From, Into};
6use std::collections::{BTreeMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use typed_slab::TypedSlab;
10
11pub trait Supervisor: Agent {
12 type GroupBy: Debug + Ord + Clone + Sync + Send + Eq + Hash;
13}
14
15pub trait SupervisorContext<S: Supervisor> {
16 fn session(&mut self) -> &mut SupervisorSession<S>;
17}
18
19#[derive(Deref, DerefMut)]
20pub struct SupervisorSession<S: Supervisor> {
21 #[deref]
22 #[deref_mut]
23 session: AgentSession<S>,
24 tracker: Tracker<S>,
25}
26
27impl<S: Supervisor> Default for SupervisorSession<S> {
28 fn default() -> Self {
29 Self {
30 session: AgentSession::default(),
31 tracker: Tracker::new(),
32 }
33 }
34}
35
36impl<S: Supervisor> Context for SupervisorSession<S> {
37 type Address = Address<S>;
38
39 fn address(&self) -> &Self::Address {
40 self.session.address()
41 }
42}
43
44impl<S: Supervisor> ManagedContext for SupervisorSession<S> {
45 fn controller(&mut self) -> &mut Controller {
46 self.session.controller()
47 }
48
49 fn shutdown(&mut self) {
50 self.session.shutdown();
51 }
52}
53
54impl<S: Supervisor> AgentContext<S> for SupervisorSession<S> {
55 fn session(&mut self) -> &mut AgentSession<S> {
56 &mut self.session
57 }
58}
59
60impl<S: Supervisor> SupervisorContext<S> for SupervisorSession<S> {
61 fn session(&mut self) -> &mut SupervisorSession<S> {
62 self
63 }
64}
65
66impl<S: Supervisor> SupervisorSession<S> {
67 pub fn spawn_actor<A>(
68 &mut self,
69 input: A,
70 group: S::GroupBy,
71 ) -> <A::Context as Context>::Address
72 where
73 A: Agent,
74 A::Context: Default,
75 S: Supervisor<Context = SupervisorSession<S>>,
76 {
77 let runtime = RunAgent::<A>::new(input);
78 self.spawn_runtime(runtime, group)
79 }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From, Into)]
83pub struct ActivityId(usize);
84
85#[derive(Debug, Default)]
86struct Group {
87 interrupted: bool,
88 ids: HashSet<ActivityId>,
89}
90
91impl Group {
92 fn is_finished(&self) -> bool {
93 self.interrupted && self.ids.is_empty()
94 }
95}
96
97pub struct Tracker<S: Supervisor> {
98 groups: BTreeMap<S::GroupBy, Group>,
99 activities: TypedSlab<ActivityId, Activity<S>>,
100 terminating: bool,
101}
102
103impl<S: Supervisor> Tracker<S> {
104 pub fn new() -> Self {
105 Self {
106 groups: BTreeMap::new(),
107 activities: TypedSlab::new(),
108 terminating: false,
109 }
110 }
111
112 pub fn terminate_group(&mut self, group: S::GroupBy) {
113 if let Some(group) = self.groups.get(&group) {
114 for id in group.ids.iter() {
115 if let Some(activity) = self.activities.get_mut(*id) {
116 if let Err(err) = activity.interrupt() {
117 log::error!("Can't interrupt an activity in a group: {err}");
118 }
119 }
120 }
121 }
122 }
123
124 pub fn terminate_all(&mut self) {
125 self.try_terminate_next();
126 }
127
128 fn register_activity(
129 &mut self,
130 group: S::GroupBy,
131 interruptor: Interruptor,
132 ) -> SupervisedBy<S> {
133 let activity = Activity {
134 group: group.clone(),
135 interruptor,
136 };
137 let id = self.activities.insert(activity);
138 let group_record = self.groups.entry(group.clone()).or_default();
139 group_record.ids.insert(id);
140 if group_record.interrupted {
141 self.activities.get_mut(id).map(Activity::interrupt);
143 }
144 SupervisedBy { id, group }
145 }
146
147 fn unregister_activity(&mut self, rel: &SupervisedBy<S>) {
148 if let Some(activity) = self.activities.remove(rel.id) {
149 if let Some(group) = self.groups.get_mut(&activity.group) {
151 group.ids.remove(&rel.id);
152 }
153 }
154 if self.terminating {
155 self.try_terminate_next();
156 }
157 }
158
159 fn existing_groups(&self) -> Vec<S::GroupBy> {
160 self.groups.keys().rev().cloned().collect()
161 }
162
163 fn try_terminate_next(&mut self) {
164 self.terminating = true;
165 for group_name in self.existing_groups() {
166 if let Some(group) = self.groups.get_mut(&group_name) {
167 if !group.interrupted {
168 group.interrupted = true;
169 for id in group.ids.iter() {
171 if let Some(activity) = self.activities.get_mut(*id) {
172 if let Err(err) = activity.interrupt() {
173 log::error!("Can't interrupt the next activity: {err}");
174 }
175 }
176 }
177 }
178 if !group.is_finished() {
179 break;
180 }
181 }
182 }
183 }
184}
185
186impl<S> SupervisorSession<S>
187where
188 S: Supervisor,
189 S::Context: SupervisorContext<S>,
190{
191 pub fn spawn_runtime<B>(
192 &mut self,
193 trackable: B,
194 group: S::GroupBy,
195 ) -> <B::Context as Context>::Address
196 where
197 B: InteractiveRuntime,
198 {
199 let addr = trackable.address();
200 self.spawn_trackable(trackable, group);
201 addr
202 }
203
204 pub fn spawn_trackable<B>(&mut self, mut trackable: B, group: S::GroupBy)
205 where
206 B: Runtime,
207 {
208 let interruptor = trackable.get_interruptor();
209 let rel = self.tracker.register_activity(group, interruptor);
210 let detacher = DetacherFor {
211 supervisor: self.address().clone(),
212 rel,
213 };
214
215 let fut = async move {
216 trackable.routine().await;
217 if let Err(err) = detacher.detach() {
219 log::error!("Can't notify a supervisor to detach an activity: {err}");
220 }
221 };
222 crb_core::spawn(fut);
223 }
224}
225
226struct Activity<S: Supervisor> {
227 group: S::GroupBy,
228 interruptor: Interruptor,
230}
231
232impl<S: Supervisor> Activity<S> {
233 fn interrupt(&mut self) -> Result<(), Error> {
234 self.interruptor.stop(false)
235 }
236}
237
238struct SupervisedBy<S: Supervisor> {
239 id: ActivityId,
240 group: S::GroupBy,
241}
242
243impl<S: Supervisor> Clone for SupervisedBy<S> {
244 fn clone(&self) -> Self {
245 Self {
246 id: self.id.clone(),
247 group: self.group.clone(),
248 }
249 }
250}
251
252struct DetachTrackable<S: Supervisor> {
253 rel: SupervisedBy<S>,
254}
255
256#[async_trait]
257impl<S> MessageFor<S> for DetachTrackable<S>
258where
259 S: Supervisor,
260 S::Context: SupervisorContext<S>,
261{
262 async fn handle(self: Box<Self>, _actor: &mut S, ctx: &mut S::Context) -> Result<(), Error> {
263 SupervisorContext::session(ctx)
264 .tracker
265 .unregister_activity(&self.rel);
266 Ok(())
267 }
268}
269
270pub struct DetacherFor<S: Supervisor> {
271 rel: SupervisedBy<S>,
272 supervisor: <S::Context as Context>::Address,
273}
274
275impl<S> DetacherFor<S>
276where
277 S: Supervisor,
278 S::Context: SupervisorContext<S>,
279{
280 pub fn detach(self) -> Result<(), Error> {
281 let msg = DetachTrackable { rel: self.rel };
282 self.supervisor.send(msg)
283 }
284}