crb_superagent/supervisor/
mod.rs1pub mod forward;
2pub mod stacker;
3
4pub use forward::ForwardTo;
5pub use stacker::Stacker;
6
7use anyhow::Error;
8use async_trait::async_trait;
9use crb_agent::{
10 Address, Agent, AgentContext, AgentSession, Context, Envelope, Link, MessageFor, RunAgent,
11};
12use crb_core::Tag;
13use crb_runtime::{
14 InteractiveRuntime, InterruptionLevel, Interruptor, ManagedContext, ReachableContext, Runtime,
15};
16use derive_more::{Deref, DerefMut, From, Into};
17use std::cmp::Ordering;
18use std::collections::{BTreeMap, HashSet};
19use std::fmt::Debug;
20use std::hash::{Hash, Hasher};
21use typed_slab::TypedSlab;
22
23#[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Hash, From, Into)]
24pub struct ActivityId(usize);
25
26pub trait Supervisor: Agent {
27 type BasedOn: AgentContext<Self>;
28 type GroupBy: Debug + Ord + Clone + Send + Eq + Hash;
29
30 fn finished(&mut self, _rel: &Relation<Self>, _ctx: &mut Context<Self>) {}
31}
32
33pub trait SupervisorContext<S: Supervisor> {
34 fn tracker(&mut self) -> &mut Tracker<S>;
35 fn session(&mut self) -> &mut SupervisorSession<S>;
36}
37
38#[derive(Deref, DerefMut)]
39pub struct SupervisorSession<S: Supervisor> {
40 #[deref]
41 #[deref_mut]
42 pub session: S::BasedOn,
43 pub tracker: Tracker<S>,
44}
45
46impl<S> Default for SupervisorSession<S>
47where
48 S: Supervisor,
49 S::BasedOn: Default,
50{
51 fn default() -> Self {
52 Self {
53 session: S::BasedOn::default(),
54 tracker: Tracker::new(),
55 }
56 }
57}
58
59impl<S: Supervisor> ReachableContext for SupervisorSession<S> {
60 type Address = Address<S>;
61
62 fn address(&self) -> &Self::Address {
63 self.session.address()
64 }
65}
66
67impl<S: Supervisor> AsRef<Address<S>> for SupervisorSession<S> {
68 fn as_ref(&self) -> &Address<S> {
69 self.address()
70 }
71}
72
73impl<S: Supervisor> ManagedContext for SupervisorSession<S> {
74 fn is_alive(&self) -> bool {
75 self.session.is_alive() && !self.tracker.is_terminated()
76 }
77
78 fn shutdown(&mut self) {
79 self.tracker.terminate_all();
80 if self.tracker.is_terminated() {
81 self.session.shutdown();
82 }
83 }
84
85 fn stop(&mut self) {
86 self.session.stop();
87 }
88}
89
90#[async_trait]
91impl<S: Supervisor> AgentContext<S> for SupervisorSession<S> {
92 fn session(&mut self) -> &mut AgentSession<S> {
93 self.session.session()
94 }
95
96 async fn next_envelope(&mut self) -> Option<Envelope<S>> {
97 self.session.next_envelope().await
98 }
99}
100
101impl<S: Supervisor> SupervisorContext<S> for SupervisorSession<S> {
102 fn tracker(&mut self) -> &mut Tracker<S> {
103 &mut self.tracker
104 }
105
106 fn session(&mut self) -> &mut SupervisorSession<S> {
107 self
108 }
109}
110
111#[derive(Debug, Default)]
112struct Group {
113 interrupted: bool,
114 ids: HashSet<ActivityId>,
115}
116
117impl Group {
118 fn is_finished(&self) -> bool {
119 self.interrupted && self.ids.is_empty()
120 }
121}
122
123pub struct Tracker<S: Supervisor> {
124 groups: BTreeMap<S::GroupBy, Group>,
125 activities: TypedSlab<ActivityId, Activity<S>>,
126 terminating: bool,
127}
128
129impl<S: Supervisor> Default for Tracker<S> {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135impl<S: Supervisor> Tracker<S> {
136 pub fn new() -> Self {
137 Self {
138 groups: BTreeMap::new(),
139 activities: TypedSlab::new(),
140 terminating: false,
141 }
142 }
143
144 pub fn is_empty(&self) -> bool {
145 self.groups.is_empty() && self.activities.is_empty()
146 }
147
148 pub fn is_terminated(&self) -> bool {
149 self.terminating && self.is_empty()
150 }
151
152 pub fn terminate_group(&mut self, group: S::GroupBy) {
153 if let Some(group) = self.groups.get(&group) {
154 for id in group.ids.iter() {
155 if let Some(activity) = self.activities.get_mut(*id) {
156 activity.interrupt();
157 }
158 }
159 }
160 }
161
162 pub fn terminate_all(&mut self) {
163 self.try_terminate_next();
164 }
165
166 fn register_activity(&mut self, activity: Activity<S>) -> Relation<S> {
167 let group = activity.group.clone();
168 let id = self.activities.insert(activity);
169 let group_record = self.groups.entry(group.clone()).or_default();
170 group_record.ids.insert(id);
171 if group_record.interrupted {
172 self.activities.get_mut(id).map(Activity::interrupt);
174 }
175 Relation { id, group }
176 }
177
178 fn unregister_activity(&mut self, rel: &Relation<S>) {
179 if let Some(activity) = self.activities.remove(rel.id) {
180 if let Some(group) = self.groups.get_mut(&activity.group) {
182 group.ids.remove(&rel.id);
183 if group.ids.is_empty() {
184 self.groups.remove(&activity.group);
185 }
186 }
187 }
188 if self.terminating {
189 self.try_terminate_next();
190 }
191 }
192
193 fn existing_groups(&self) -> Vec<S::GroupBy> {
194 self.groups.keys().rev().cloned().collect()
195 }
196
197 fn try_terminate_next(&mut self) {
198 self.terminating = true;
199 for group_name in self.existing_groups() {
200 if let Some(group) = self.groups.get_mut(&group_name) {
201 let name = std::any::type_name::<S>();
202 log::trace!("Agent {name} is terminating group: {:?}", group_name);
203 if !group.interrupted {
204 group.interrupted = true;
205 for id in group.ids.iter() {
207 if let Some(activity) = self.activities.get_mut(*id) {
208 activity.interrupt();
209 }
210 }
211 }
212 if !group.is_finished() {
213 break;
214 }
215 }
216 }
217 }
218}
219
220impl<S> SupervisorSession<S>
221where
222 S: Supervisor,
223 S::Context: SupervisorContext<S>,
224{
225 pub fn spawn_agent<A>(&mut self, agent: A, group: S::GroupBy) -> (Link<A>, Relation<S>)
226 where
227 A: Agent,
228 A::Context: Default,
229 {
230 let runtime = RunAgent::<A>::new(agent);
231 let (address, relation) = self.spawn_runtime(runtime, group);
232 (address.into(), relation)
233 }
234
235 pub fn spawn_runtime<B>(
236 &mut self,
237 trackable: B,
238 group: S::GroupBy,
239 ) -> (<B::Context as ReachableContext>::Address, Relation<S>)
240 where
241 B: InteractiveRuntime,
242 {
243 let addr = trackable.address();
244 let rel = self.spawn_trackable(trackable, group);
245 (addr, rel)
246 }
247
248 pub fn spawn_trackable<B>(&mut self, mut trackable: B, group: S::GroupBy) -> Relation<S>
249 where
250 B: Runtime,
251 {
252 let activity = Activity {
253 group,
254 interruptor: trackable.get_interruptor(),
255 level: trackable.interruption_level(),
256 };
257 let rel = self.tracker.register_activity(activity);
258 let detacher = DetacherFor {
259 supervisor: self.address().clone(),
260 rel: rel.clone(),
261 };
262
263 let fut = async move {
264 let name = std::any::type_name::<S>();
265 let rn_name = std::any::type_name::<B>();
266 trackable.routine().await;
267 if let Err(err) = detacher.detach() {
269 log::error!(
270 "Can't notify a supervisor {name} from {rn_name} to detach an activity: {err}"
271 );
272 } else {
273 log::debug!("A supervisor {name} notified about termination of {rn_name}.");
274 }
275 };
276 crb_core::spawn(fut);
277 rel
278 }
279
280 pub fn assign<R, T>(&mut self, trackable: R, group: S::GroupBy, tag: T) -> Relation<S>
281 where
282 R: ForwardTo<S, T>,
283 T: Tag,
284 {
285 let address = self.address().clone();
286 let trackable = trackable.into_trackable(address, tag);
287 self.spawn_trackable(trackable, group)
288 }
289}
290
291struct Activity<S: Supervisor> {
292 group: S::GroupBy,
293 interruptor: Box<dyn Interruptor>,
294 level: InterruptionLevel,
295}
296
297impl<S: Supervisor> Activity<S> {
298 fn interrupt(&mut self) {
299 self.interruptor.interrupt_with_level(self.level);
300 }
301}
302
303pub struct Relation<S: Supervisor> {
304 pub id: ActivityId,
305 pub group: S::GroupBy,
306}
307
308impl<S: Supervisor> Clone for Relation<S> {
309 fn clone(&self) -> Self {
310 Self {
311 id: self.id,
312 group: self.group.clone(),
313 }
314 }
315}
316
317impl<S: Supervisor> PartialEq for Relation<S> {
318 fn eq(&self, other: &Self) -> bool {
319 self.id == other.id && self.group == other.group
320 }
321}
322
323impl<S: Supervisor> Eq for Relation<S> {}
324
325impl<S: Supervisor> PartialOrd for Relation<S> {
326 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
327 Some(self.cmp(other))
328 }
329}
330
331impl<S: Supervisor> Ord for Relation<S> {
332 fn cmp(&self, other: &Self) -> Ordering {
333 self.id
334 .cmp(&other.id)
335 .then_with(|| self.group.cmp(&other.group))
336 }
337}
338
339impl<S: Supervisor> Hash for Relation<S> {
340 fn hash<H: Hasher>(&self, state: &mut H) {
341 self.id.hash(state);
342 self.group.hash(state);
343 }
344}
345
346struct DetachFrom<S: Supervisor> {
347 rel: Relation<S>,
348}
349
350#[async_trait]
351impl<S> MessageFor<S> for DetachFrom<S>
352where
353 S: Supervisor,
354 S::Context: SupervisorContext<S>,
355{
356 async fn handle(self: Box<Self>, agent: &mut S, ctx: &mut Context<S>) -> Result<(), Error> {
357 let tracker = ctx.tracker();
358 tracker.unregister_activity(&self.rel);
359 if tracker.is_terminated() {
360 ctx.shutdown();
361 }
362 agent.finished(&self.rel, ctx);
363 Ok(())
364 }
365}
366
367pub struct DetacherFor<S: Supervisor> {
368 rel: Relation<S>,
369 supervisor: <S::Context as ReachableContext>::Address,
370}
371
372impl<S> DetacherFor<S>
373where
374 S: Supervisor,
375 S::Context: SupervisorContext<S>,
376{
377 pub fn detach(self) -> Result<(), Error> {
378 let msg = DetachFrom { rel: self.rel };
379 self.supervisor.send(msg)
380 }
381}