crb-superagent 0.0.29

CRB | Composable Runtime Blocks | Agent Extensions
Documentation
pub mod forward;
pub mod stacker;

pub use forward::ForwardTo;
pub use stacker::Stacker;

use anyhow::Error;
use async_trait::async_trait;
use crb_agent::{
    Address, Agent, AgentContext, AgentSession, Context, Envelope, MessageFor, RunAgent,
};
use crb_core::Tag;
use crb_runtime::{
    InteractiveRuntime, InterruptionLevel, Interruptor, ManagedContext, ReachableContext, Runtime,
};
use derive_more::{Deref, DerefMut, From, Into};
use std::cmp::Ordering;
use std::collections::{BTreeMap, HashSet};
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use typed_slab::TypedSlab;

#[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Hash, From, Into)]
pub struct ActivityId(usize);

pub trait Supervisor: Agent {
    type BasedOn: AgentContext<Self>;
    type GroupBy: Debug + Ord + Clone + Send + Eq + Hash;

    fn finished(&mut self, _rel: &Relation<Self>, _ctx: &mut Context<Self>) {}
}

pub trait SupervisorContext<S: Supervisor> {
    fn tracker(&mut self) -> &mut Tracker<S>;
    fn session(&mut self) -> &mut SupervisorSession<S>;
}

#[derive(Deref, DerefMut)]
pub struct SupervisorSession<S: Supervisor> {
    #[deref]
    #[deref_mut]
    pub session: S::BasedOn,
    pub tracker: Tracker<S>,
}

impl<S> Default for SupervisorSession<S>
where
    S: Supervisor,
    S::BasedOn: Default,
{
    fn default() -> Self {
        Self {
            session: S::BasedOn::default(),
            tracker: Tracker::new(),
        }
    }
}

impl<S: Supervisor> ReachableContext for SupervisorSession<S> {
    type Address = Address<S>;

    fn address(&self) -> &Self::Address {
        self.session.address()
    }
}

impl<S: Supervisor> AsRef<Address<S>> for SupervisorSession<S> {
    fn as_ref(&self) -> &Address<S> {
        self.address()
    }
}

impl<S: Supervisor> ManagedContext for SupervisorSession<S> {
    fn is_alive(&self) -> bool {
        self.session.is_alive() && !self.tracker.is_terminated()
    }

    fn shutdown(&mut self) {
        self.tracker.terminate_all();
        if self.tracker.is_terminated() {
            self.session.shutdown();
        }
    }

    fn stop(&mut self) {
        self.session.stop();
    }
}

#[async_trait]
impl<S: Supervisor> AgentContext<S> for SupervisorSession<S> {
    fn session(&mut self) -> &mut AgentSession<S> {
        self.session.session()
    }

    async fn next_envelope(&mut self) -> Option<Envelope<S>> {
        self.session.next_envelope().await
    }
}

impl<S: Supervisor> SupervisorContext<S> for SupervisorSession<S> {
    fn tracker(&mut self) -> &mut Tracker<S> {
        &mut self.tracker
    }

    fn session(&mut self) -> &mut SupervisorSession<S> {
        self
    }
}

#[derive(Debug, Default)]
struct Group {
    interrupted: bool,
    ids: HashSet<ActivityId>,
}

impl Group {
    fn is_finished(&self) -> bool {
        self.interrupted && self.ids.is_empty()
    }
}

pub struct Tracker<S: Supervisor> {
    groups: BTreeMap<S::GroupBy, Group>,
    activities: TypedSlab<ActivityId, Activity<S>>,
    terminating: bool,
}

impl<S: Supervisor> Default for Tracker<S> {
    fn default() -> Self {
        Self::new()
    }
}

impl<S: Supervisor> Tracker<S> {
    pub fn new() -> Self {
        Self {
            groups: BTreeMap::new(),
            activities: TypedSlab::new(),
            terminating: false,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.groups.is_empty() && self.activities.is_empty()
    }

    pub fn is_terminated(&self) -> bool {
        self.terminating && self.is_empty()
    }

    pub fn terminate_group(&mut self, group: S::GroupBy) {
        if let Some(group) = self.groups.get(&group) {
            for id in group.ids.iter() {
                if let Some(activity) = self.activities.get_mut(*id) {
                    activity.interrupt();
                }
            }
        }
    }

    pub fn terminate_all(&mut self) {
        self.try_terminate_next();
    }

    fn register_activity(&mut self, activity: Activity<S>) -> Relation<S> {
        let group = activity.group.clone();
        let id = self.activities.insert(activity);
        let group_record = self.groups.entry(group.clone()).or_default();
        group_record.ids.insert(id);
        if group_record.interrupted {
            // Interrupt if the group is terminating
            self.activities.get_mut(id).map(Activity::interrupt);
        }
        Relation { id, group }
    }

    fn unregister_activity(&mut self, rel: &Relation<S>) {
        if let Some(activity) = self.activities.remove(rel.id) {
            // TODO: check rel.group == activity.group ?
            if let Some(group) = self.groups.get_mut(&activity.group) {
                group.ids.remove(&rel.id);
                if group.ids.is_empty() {
                    self.groups.remove(&activity.group);
                }
            }
        }
        if self.terminating {
            self.try_terminate_next();
        }
    }

    fn existing_groups(&self) -> Vec<S::GroupBy> {
        self.groups.keys().rev().cloned().collect()
    }

    fn try_terminate_next(&mut self) {
        self.terminating = true;
        for group_name in self.existing_groups() {
            if let Some(group) = self.groups.get_mut(&group_name) {
                let name = std::any::type_name::<S>();
                log::trace!("Agent {name} is terminating group: {:?}", group_name);
                if !group.interrupted {
                    group.interrupted = true;
                    // Send an interruption signal to all active members of the group.
                    for id in group.ids.iter() {
                        if let Some(activity) = self.activities.get_mut(*id) {
                            activity.interrupt();
                        }
                    }
                }
                if !group.is_finished() {
                    break;
                }
            }
        }
    }
}

impl<S> SupervisorSession<S>
where
    S: Supervisor,
    S::Context: SupervisorContext<S>,
{
    pub fn spawn_agent<A>(
        &mut self,
        agent: A,
        group: S::GroupBy,
    ) -> (<A::Context as ReachableContext>::Address, Relation<S>)
    where
        A: Agent,
        A::Context: Default,
    {
        let runtime = RunAgent::<A>::new(agent);
        self.spawn_runtime(runtime, group)
    }

    pub fn spawn_runtime<B>(
        &mut self,
        trackable: B,
        group: S::GroupBy,
    ) -> (<B::Context as ReachableContext>::Address, Relation<S>)
    where
        B: InteractiveRuntime,
    {
        let addr = trackable.address();
        let rel = self.spawn_trackable(trackable, group);
        (addr, rel)
    }

    pub fn spawn_trackable<B>(&mut self, mut trackable: B, group: S::GroupBy) -> Relation<S>
    where
        B: Runtime,
    {
        let activity = Activity {
            group,
            interruptor: trackable.get_interruptor(),
            level: trackable.interruption_level(),
        };
        let rel = self.tracker.register_activity(activity);
        let detacher = DetacherFor {
            supervisor: self.address().clone(),
            rel: rel.clone(),
        };

        let fut = async move {
            let name = std::any::type_name::<S>();
            let rn_name = std::any::type_name::<B>();
            trackable.routine().await;
            // This notification equals calling `detach_trackable`
            if let Err(err) = detacher.detach() {
                log::error!(
                    "Can't notify a supervisor {name} from {rn_name} to detach an activity: {err}"
                );
            } else {
                log::debug!("A supervisor {name} notified about termination of {rn_name}.");
            }
        };
        crb_core::spawn(fut);
        rel
    }

    pub fn assign<R, T>(&mut self, trackable: R, group: S::GroupBy, tag: T) -> Relation<S>
    where
        R: ForwardTo<S, T>,
        T: Tag,
    {
        let address = self.address().clone();
        let trackable = trackable.into_trackable(address, tag);
        self.spawn_trackable(trackable, group)
    }
}

struct Activity<S: Supervisor> {
    group: S::GroupBy,
    interruptor: Box<dyn Interruptor>,
    level: InterruptionLevel,
}

impl<S: Supervisor> Activity<S> {
    fn interrupt(&mut self) {
        self.interruptor.interrupt_with_level(self.level);
    }
}

pub struct Relation<S: Supervisor> {
    pub id: ActivityId,
    pub group: S::GroupBy,
}

impl<S: Supervisor> Clone for Relation<S> {
    fn clone(&self) -> Self {
        Self {
            id: self.id,
            group: self.group.clone(),
        }
    }
}

impl<S: Supervisor> PartialEq for Relation<S> {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.group == other.group
    }
}

impl<S: Supervisor> Eq for Relation<S> {}

impl<S: Supervisor> PartialOrd for Relation<S> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<S: Supervisor> Ord for Relation<S> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.id
            .cmp(&other.id)
            .then_with(|| self.group.cmp(&other.group))
    }
}

impl<S: Supervisor> Hash for Relation<S> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.id.hash(state);
        self.group.hash(state);
    }
}

struct DetachFrom<S: Supervisor> {
    rel: Relation<S>,
}

#[async_trait]
impl<S> MessageFor<S> for DetachFrom<S>
where
    S: Supervisor,
    S::Context: SupervisorContext<S>,
{
    async fn handle(self: Box<Self>, agent: &mut S, ctx: &mut Context<S>) -> Result<(), Error> {
        let tracker = ctx.tracker();
        tracker.unregister_activity(&self.rel);
        if tracker.is_terminated() {
            ctx.shutdown();
        }
        agent.finished(&self.rel, ctx);
        Ok(())
    }
}

pub struct DetacherFor<S: Supervisor> {
    rel: Relation<S>,
    supervisor: <S::Context as ReachableContext>::Address,
}

impl<S> DetacherFor<S>
where
    S: Supervisor,
    S::Context: SupervisorContext<S>,
{
    pub fn detach(self) -> Result<(), Error> {
        let msg = DetachFrom { rel: self.rel };
        self.supervisor.send(msg)
    }
}