1use crate::address::{Address, AddressJoint, Envelope};
2use crate::agent::Agent;
3use crate::extension::ExtensionFor;
4use crate::performers::Next;
5use anyhow::{Result, anyhow};
6use async_trait::async_trait;
7use crb_runtime::{Controller, ManagedContext, ReachableContext};
8use derive_more::{Deref, DerefMut};
9use std::any::{Any, TypeId};
10use std::collections::HashMap;
11
12#[derive(Deref, DerefMut)]
13pub struct Context<A: Agent> {
14 #[deref]
15 #[deref_mut]
16 context: A::Context,
17 extensions: Option<HashMap<TypeId, Box<dyn Any + Send>>>,
18}
19
20impl<A: Agent> Context<A> {
21 pub fn wrap(context: A::Context) -> Self {
22 Self {
23 context,
24 extensions: None,
25 }
26 }
27
28 pub fn add_extension<E>(&mut self, ext: E)
29 where
30 E: ExtensionFor<A>,
31 {
32 self.extensions
33 .get_or_insert_default()
34 .insert(TypeId::of::<E>(), Box::new(ext));
35 }
36
37 pub fn be<E>(&mut self) -> Result<E::View<'_>>
38 where
39 E: ExtensionFor<A>,
40 {
41 let type_id = TypeId::of::<E>();
42 let ext = self
43 .extensions
44 .get_or_insert_default()
45 .get_mut(&type_id)
46 .and_then(|boxed| boxed.downcast_mut::<E>())
47 .ok_or_else(|| anyhow!("Extension {:?} is not available.", type_id))?;
48 Ok(ext.extend(&mut self.context))
49 }
50}
51
52impl<A: Agent> Context<A>
53where
54 A::Context: ReachableContext,
55{
56 pub fn address(&self) -> &<A::Context as ReachableContext>::Address {
57 ReachableContext::address(&self.context)
58 }
59}
60
61#[async_trait]
62pub trait AgentContext<A: Agent>
63where
64 Self: ReachableContext<Address = Address<A>>,
65 Self: ManagedContext,
66{
67 fn session(&mut self) -> &mut AgentSession<A>;
69
70 async fn next_envelope(&mut self) -> Option<Envelope<A>>;
71}
72
73#[derive(Deref, DerefMut)]
74pub struct AgentSession<A: Agent> {
75 pub controller: Controller,
76 pub next_state: Option<Next<A>>,
77 pub joint: AddressJoint<A>,
78 #[deref]
79 #[deref_mut]
80 pub address: Address<A>,
81}
82
83impl<A: Agent> AgentSession<A> {
84 pub fn joint(&mut self) -> &mut AddressJoint<A> {
85 &mut self.joint
86 }
87
88 pub fn do_next(&mut self, next_state: Next<A>) {
89 self.next_state = Some(next_state);
90 }
91}
92
93impl<A: Agent> Default for AgentSession<A> {
94 fn default() -> Self {
95 let controller = Controller::default();
96 let stopper = controller.stopper.clone();
97 let (address, joint) = AddressJoint::new_pair(stopper);
98 Self {
99 controller,
100 next_state: None,
101 joint,
102 address,
103 }
104 }
105}
106
107impl<A: Agent> ReachableContext for AgentSession<A> {
108 type Address = Address<A>;
109
110 fn address(&self) -> &Self::Address {
111 &self.address
112 }
113}
114
115impl<A: Agent> AsRef<Address<A>> for AgentSession<A> {
116 fn as_ref(&self) -> &Address<A> {
117 self.address()
118 }
119}
120
121impl<A: Agent> ManagedContext for AgentSession<A> {
122 fn is_alive(&self) -> bool {
123 self.controller.is_active()
124 }
125
126 fn shutdown(&mut self) {
127 self.joint.close();
128 }
129
130 fn stop(&mut self) {
131 self.controller.stop(false);
132 }
133}
134
135#[async_trait]
136impl<A: Agent> AgentContext<A> for AgentSession<A> {
137 fn session(&mut self) -> &mut AgentSession<A> {
138 self
139 }
140
141 async fn next_envelope(&mut self) -> Option<Envelope<A>> {
142 self.joint().next_envelope().await
143 }
144}