crb_agent/
context.rs

1use crate::address::{Address, AddressJoint, Envelope};
2use crate::agent::Agent;
3use crate::extension::ExtensionFor;
4use crate::performers::Next;
5use anyhow::{Result, anyhow};
6use async_trait::async_trait;
7use crb_runtime::{Controller, ManagedContext, ReachableContext};
8use derive_more::{Deref, DerefMut};
9use std::any::{Any, TypeId};
10use std::collections::HashMap;
11
12#[derive(Deref, DerefMut)]
13pub struct Context<A: Agent> {
14    #[deref]
15    #[deref_mut]
16    context: A::Context,
17    extensions: Option<HashMap<TypeId, Box<dyn Any + Send>>>,
18}
19
20impl<A: Agent> Context<A> {
21    pub fn wrap(context: A::Context) -> Self {
22        Self {
23            context,
24            extensions: None,
25        }
26    }
27
28    pub fn add_extension<E>(&mut self, ext: E)
29    where
30        E: ExtensionFor<A>,
31    {
32        self.extensions
33            .get_or_insert_default()
34            .insert(TypeId::of::<E>(), Box::new(ext));
35    }
36
37    pub fn be<E>(&mut self) -> Result<E::View<'_>>
38    where
39        E: ExtensionFor<A>,
40    {
41        let type_id = TypeId::of::<E>();
42        let ext = self
43            .extensions
44            .get_or_insert_default()
45            .get_mut(&type_id)
46            .and_then(|boxed| boxed.downcast_mut::<E>())
47            .ok_or_else(|| anyhow!("Extension {:?} is not available.", type_id))?;
48        Ok(ext.extend(&mut self.context))
49    }
50}
51
52impl<A: Agent> Context<A>
53where
54    A::Context: ReachableContext,
55{
56    pub fn address(&self) -> &<A::Context as ReachableContext>::Address {
57        ReachableContext::address(&self.context)
58    }
59}
60
61#[async_trait]
62pub trait AgentContext<A: Agent>
63where
64    Self: ReachableContext<Address = Address<A>>,
65    Self: ManagedContext,
66{
67    // TODO: Replace with explicit methods
68    fn session(&mut self) -> &mut AgentSession<A>;
69
70    async fn next_envelope(&mut self) -> Option<Envelope<A>>;
71}
72
73#[derive(Deref, DerefMut)]
74pub struct AgentSession<A: Agent> {
75    pub controller: Controller,
76    pub next_state: Option<Next<A>>,
77    pub joint: AddressJoint<A>,
78    #[deref]
79    #[deref_mut]
80    pub address: Address<A>,
81}
82
83impl<A: Agent> AgentSession<A> {
84    pub fn joint(&mut self) -> &mut AddressJoint<A> {
85        &mut self.joint
86    }
87
88    pub fn do_next(&mut self, next_state: Next<A>) {
89        self.next_state = Some(next_state);
90    }
91}
92
93impl<A: Agent> Default for AgentSession<A> {
94    fn default() -> Self {
95        let controller = Controller::default();
96        let stopper = controller.stopper.clone();
97        let (address, joint) = AddressJoint::new_pair(stopper);
98        Self {
99            controller,
100            next_state: None,
101            joint,
102            address,
103        }
104    }
105}
106
107impl<A: Agent> ReachableContext for AgentSession<A> {
108    type Address = Address<A>;
109
110    fn address(&self) -> &Self::Address {
111        &self.address
112    }
113}
114
115impl<A: Agent> AsRef<Address<A>> for AgentSession<A> {
116    fn as_ref(&self) -> &Address<A> {
117        self.address()
118    }
119}
120
121impl<A: Agent> ManagedContext for AgentSession<A> {
122    fn is_alive(&self) -> bool {
123        self.controller.is_active()
124    }
125
126    fn shutdown(&mut self) {
127        self.joint.close();
128    }
129
130    fn stop(&mut self) {
131        self.controller.stop(false);
132    }
133}
134
135#[async_trait]
136impl<A: Agent> AgentContext<A> for AgentSession<A> {
137    fn session(&mut self) -> &mut AgentSession<A> {
138        self
139    }
140
141    async fn next_envelope(&mut self) -> Option<Envelope<A>> {
142        self.joint().next_envelope().await
143    }
144}