1use std::{
5 collections::{HashMap, HashSet},
6 sync::{Arc, Mutex},
7};
8
9use crate::{
10 error::{MutationError, RetrievalError},
11 policy::PolicyAgent,
12 storage::{StorageCollectionWrapper, StorageEngine},
13 util::Iterable,
14 Node,
15};
16use ankurah_proto::{self as proto, Attested, Clock, EntityId, EntityState, Event, EventId};
17use async_trait::async_trait;
18
/// An event that has an identifier and a parent clock.
///
/// Implementors must also be `Display`. The parent clock's member ids share the event's `Id`
/// type, so an event's parents can be looked up with the same key type as the event itself.
pub trait TEvent: std::fmt::Display {
    /// Identifier type used to compare and key events.
    type Id: Eq + PartialEq + Clone;
    /// Clock type describing this event's parents; its members are `Self::Id`s.
    type Parent: TClock<Id = Self::Id>;

    /// Returns this event's identifier (by value).
    fn id(&self) -> Self::Id;
    /// Returns a reference to this event's parent clock.
    fn parent(&self) -> &Self::Parent;
}
27
/// A clock: a collection of event identifiers exposed as a borrowed slice.
pub trait TClock {
    /// Identifier type of the clock's members.
    type Id: Eq + PartialEq + Clone;
    /// Returns the member ids of this clock.
    fn members(&self) -> &[Self::Id];
}
32
33impl TClock for Clock {
34 type Id = EventId;
35 fn members(&self) -> &[Self::Id] { self.as_slice() }
36}
37
/// Adapts the protocol `Event` to the generic `TEvent` interface: ids are `EventId`s and the
/// parent is the event's `parent` clock field.
impl TEvent for ankurah_proto::Event {
    type Id = ankurah_proto::EventId;
    type Parent = Clock;

    // NOTE(review): `self.id()` here is expected to resolve to an inherent `Event::id` method
    // (inherent methods take precedence over trait methods). If that inherent method does not
    // exist, this call resolves to `TEvent::id` and recurses forever — confirm in ankurah_proto.
    fn id(&self) -> EventId { self.id() }
    fn parent(&self) -> &Clock { &self.parent }
}
45
/// A source of events, with a staging area for events that are available for retrieval but
/// have not necessarily been persisted.
#[async_trait]
pub trait GetEvents {
    /// Event identifier type (also used as the staging key).
    type Id: Eq + PartialEq + Clone + std::fmt::Debug + Send + Sync;
    /// Event type produced by this source; its id type matches `Self::Id`.
    type Event: TEvent<Id = Self::Id> + std::fmt::Display;

    /// Estimated cost of retrieving a batch of events.
    ///
    /// The default ignores the batch size and always returns `1`.
    fn estimate_cost(&self, _batch_size: usize) -> usize {
        1
    }

    /// Fetches the events for `event_ids`.
    ///
    /// Returns a `usize` cost value together with the events that were found. The
    /// implementations in this file return only the events they could locate, so callers
    /// should not assume one event per requested id.
    async fn retrieve_event(&self, event_ids: Vec<Self::Id>) -> Result<(usize, Vec<Attested<Self::Event>>), RetrievalError>;

    /// Makes `events` available to `retrieve_event` without (yet) persisting them.
    fn stage_events(&self, events: impl IntoIterator<Item = Attested<Self::Event>>);

    /// Flags a previously staged event as used. The implementations in this file ignore ids
    /// that are not staged.
    fn mark_event_used(&self, event_id: &Self::Id);
}
68
/// An event source that can also load the stored state of an entity.
#[async_trait]
pub trait Retrieve: GetEvents {
    /// Loads the attested state of `entity_id`.
    ///
    /// The implementations in this file return `Ok(None)` when storage reports
    /// `RetrievalError::EntityNotFound`, and propagate any other error.
    async fn get_state(&self, entity_id: EntityId) -> Result<Option<Attested<EntityState>>, RetrievalError>;
}
74
/// Retriever backed only by a local storage collection (never contacts peers).
///
/// The state lives behind an `Arc`, so clones share the same collection handle and the same
/// staging map.
#[derive(Clone)]
pub struct LocalRetriever(Arc<LocalRetrieverInner>);
struct LocalRetrieverInner {
    /// Storage collection used for lookups and for persisting used events.
    collection: StorageCollectionWrapper,
    /// Staged events keyed by id, each paired with a "used" flag. Becomes `None` once
    /// `store_used_events` takes the map; `stage_events` re-creates it lazily.
    staged_events: Mutex<Option<HashMap<EventId, (Attested<Event>, bool)>>>,
}
83
84impl LocalRetriever {
85 pub fn new(collection: StorageCollectionWrapper) -> Self {
86 Self(Arc::new(LocalRetrieverInner { collection, staged_events: Mutex::new(Some(HashMap::new())) }))
87 }
88
89 pub async fn store_used_events(&mut self) -> Result<(), RetrievalError> {
90 let staged = { self.0.staged_events.lock().unwrap().take() };
91
92 if let Some(staged) = staged {
93 for (_id, (event, used)) in staged.iter() {
94 if *used {
95 self.0.collection.add_event(event).await?;
96 }
97 }
98 }
99
100 Ok(())
101 }
102}
103
104#[async_trait]
105impl GetEvents for LocalRetriever {
106 type Id = EventId;
107 type Event = ankurah_proto::Event;
108
109 async fn retrieve_event(&self, event_ids: Vec<Self::Id>) -> Result<(usize, Vec<Attested<Self::Event>>), RetrievalError> {
110 let mut events = Vec::with_capacity(event_ids.len());
111 let mut event_ids: HashSet<Self::Id> = event_ids.into_iter().collect();
112
113 {
115 if let Some(staged) = self.0.staged_events.lock().unwrap().as_mut() {
116 event_ids.retain(|id| {
117 if let Some((event, used)) = staged.get_mut(id) {
118 events.push(event.clone());
119 *used = true;
120 false
121 } else {
122 true
123 }
124 });
125 }
126 }
127
128 if event_ids.is_empty() {
129 return Ok((0, events));
130 }
131
132 let stored_events = self.0.collection.get_events(event_ids.into_iter().collect()).await?;
137 events.extend(stored_events);
138
139 Ok((1, events))
141 }
142
143 fn stage_events(&self, events: impl IntoIterator<Item = Attested<Self::Event>>) {
144 let mut staged = self.0.staged_events.lock().unwrap();
145 let staged = staged.get_or_insert_with(|| HashMap::new());
146
147 for event in events.into_iter() {
148 staged.insert(event.payload.id(), (event, false));
149 }
150 }
151
152 fn mark_event_used(&self, event_id: &Self::Id) {
153 let mut staged = self.0.staged_events.lock().unwrap();
154 let staged = staged.get_or_insert_with(|| HashMap::new());
155 staged.get_mut(event_id).map(|(_, used)| {
156 *used = true;
157 });
158 }
159}
160
161#[async_trait]
162impl Retrieve for LocalRetriever {
163 async fn get_state(&self, entity_id: EntityId) -> Result<Option<Attested<EntityState>>, RetrievalError> {
164 match self.0.collection.get_state(entity_id).await {
165 Ok(state) => Ok(Some(state)),
166 Err(RetrievalError::EntityNotFound(_)) => Ok(None),
167 Err(e) => Err(e),
168 }
169 }
170}
171
/// Retriever that resolves events through a node: staging first, then the node's local
/// storage for `collection`, then a durable peer.
pub struct EphemeralNodeRetriever<'a, SE, PA, C>
where
    SE: StorageEngine + Send + Sync + 'static,
    PA: PolicyAgent + Send + Sync + 'static,
    C: Iterable<PA::ContextData> + Send + Sync + 'a,
{
    /// Collection whose events/state are being retrieved.
    pub collection: proto::CollectionId,
    /// Node providing storage access and peer requests.
    pub node: &'a Node<SE, PA>,
    /// Context data passed to `node.request` when asking a peer for events.
    pub cdata: &'a C,
    /// Staged events keyed by id, each paired with a "used" flag. Becomes `None` once
    /// `store_used_events` takes the map; `stage_events` re-creates it lazily.
    staged_events: Mutex<Option<HashMap<EventId, (Attested<Event>, bool)>>>,
}
185
186impl<'a, SE, PA, C> EphemeralNodeRetriever<'a, SE, PA, C>
187where
188 SE: StorageEngine + Send + Sync + 'static,
189 PA: PolicyAgent + Send + Sync + 'static,
190 C: Iterable<PA::ContextData> + Send + Sync + 'a,
191{
192 pub fn new(collection: proto::CollectionId, node: &'a Node<SE, PA>, cdata: &'a C) -> Self {
193 Self { collection, node, cdata, staged_events: Mutex::new(Some(HashMap::new())) }
194 }
195
196 pub async fn store_used_events(&self) -> Result<(), MutationError> {
197 let staged = { self.staged_events.lock().unwrap().take() };
198
199 if let Some(staged) = staged {
200 let collection = self.node.system.collection(&self.collection).await?;
203 for (_id, (event, used)) in staged.iter() {
204 if *used {
205 collection.add_event(event).await?;
206 }
207 }
208 }
209
210 Ok(())
211 }
212}
213
214#[async_trait]
215impl<'a, SE, PA, C> GetEvents for EphemeralNodeRetriever<'a, SE, PA, C>
216where
217 SE: StorageEngine + Send + Sync + 'static,
218 PA: PolicyAgent + Send + Sync + 'static,
219 C: Iterable<PA::ContextData> + Send + Sync + 'a,
220{
221 type Id = EventId;
222 type Event = Event;
223
224 async fn retrieve_event(&self, event_ids: Vec<Self::Id>) -> Result<(usize, Vec<Attested<Self::Event>>), RetrievalError> {
225 let mut events = Vec::with_capacity(event_ids.len());
226 let mut event_ids: HashSet<Self::Id> = event_ids.into_iter().collect();
227
228 {
230 if let Some(staged) = self.staged_events.lock().unwrap().as_mut() {
231 event_ids.retain(|id| {
232 if let Some((event, used)) = staged.get_mut(id) {
233 events.push(event.clone());
234 *used = true;
235 false
236 } else {
237 true
238 }
239 });
240 }
241 }
242
243 if event_ids.is_empty() {
244 return Ok((0, events));
245 }
246
247 let collection = self.node.system.collection(&self.collection).await?;
253 for event in collection.get_events(event_ids.iter().cloned().collect()).await? {
255 event_ids.remove(&event.payload.id());
256 events.push(event);
257 }
258
259 if event_ids.is_empty() {
260 return Ok((1, events));
261 }
262
263 let Some(peer_id) = self.node.get_durable_peer_random() else {
265 return Ok((1, events)); };
267
268 match self
269 .node
270 .request(
271 peer_id,
272 self.cdata,
273 proto::NodeRequestBody::GetEvents { collection: self.collection.clone(), event_ids: event_ids.into_iter().collect() }, )
275 .await?
276 {
278 proto::NodeResponseBody::GetEvents(peer_events) => {
279 for event in peer_events.iter() {
280 collection.add_event(event).await?;
281 }
282 events.extend(peer_events);
283 }
284 proto::NodeResponseBody::Error(e) => {
285 return Err(RetrievalError::StorageError(format!("Error from peer: {}", e).into()));
286 }
287 _ => return Err(RetrievalError::StorageError("Unexpected response type from peer".into())),
288 }
289 Ok((5, events))
290 }
291
292 fn stage_events(&self, events: impl IntoIterator<Item = Attested<Self::Event>>) {
293 let mut staged = self.staged_events.lock().unwrap();
294 let staged = staged.get_or_insert_with(|| HashMap::new());
295
296 for event in events.into_iter() {
297 staged.insert(event.payload.id(), (event, false));
298 }
299 }
300
301 fn mark_event_used(&self, event_id: &Self::Id) {
302 let mut staged = self.staged_events.lock().unwrap();
303 let staged = staged.get_or_insert_with(|| HashMap::new());
304 staged.get_mut(event_id).map(|(_, used)| {
305 *used = true;
306 });
307 }
308}
309
310#[async_trait]
311impl<'a, SE, PA, C> Retrieve for EphemeralNodeRetriever<'a, SE, PA, C>
312where
313 SE: StorageEngine + Send + Sync + 'static,
314 PA: PolicyAgent + Send + Sync + 'static,
315 C: Iterable<PA::ContextData> + Send + Sync + 'a,
316{
317 async fn get_state(&self, entity_id: EntityId) -> Result<Option<Attested<EntityState>>, RetrievalError> {
318 let collection = self.node.collections.get(&self.collection).await?;
319 match collection.get_state(entity_id).await {
320 Ok(state) => Ok(Some(state)),
321 Err(RetrievalError::EntityNotFound(_)) => Ok(None),
322 Err(e) => Err(e),
323 }
324 }
325}