// ankurah_core/retrieval.rs

//! Implements GetEvents for NodeAndContext, allowing event retrieval from local and remote sources.
//! This lives in lineage because event retrieval is a lineage concern, not a context/session concern.
3
4use std::{
5    collections::{HashMap, HashSet},
6    sync::{Arc, Mutex},
7};
8
9use crate::{
10    error::{MutationError, RetrievalError},
11    policy::PolicyAgent,
12    storage::{StorageCollectionWrapper, StorageEngine},
13    util::Iterable,
14    Node,
15};
16use ankurah_proto::{self as proto, Attested, Clock, EntityId, EntityState, Event, EventId};
17use async_trait::async_trait;
18
19/// a trait for events and eventlike things that can be descended
20pub trait TEvent: std::fmt::Display {
21    type Id: Eq + PartialEq + Clone;
22    type Parent: TClock<Id = Self::Id>;
23
24    fn id(&self) -> Self::Id;
25    fn parent(&self) -> &Self::Parent;
26}
27
/// A clock-like value that exposes a slice of member ids.
pub trait TClock {
    /// Identifier type of the clock's members.
    // `Eq` already implies `PartialEq`, so only `Eq` is spelled out.
    type Id: Eq + Clone;

    /// Returns the member ids of this clock as a slice.
    fn members(&self) -> &[Self::Id];
}
32
33impl TClock for Clock {
34    type Id = EventId;
35    fn members(&self) -> &[Self::Id] { self.as_slice() }
36}
37
38impl TEvent for ankurah_proto::Event {
39    type Id = ankurah_proto::EventId;
40    type Parent = Clock;
41
42    fn id(&self) -> EventId { self.id() }
43    fn parent(&self) -> &Clock { &self.parent }
44}
45
46#[async_trait]
47pub trait GetEvents {
48    type Id: Eq + PartialEq + Clone + std::fmt::Debug + Send + Sync;
49    type Event: TEvent<Id = Self::Id> + std::fmt::Display;
50
51    /// Estimate the budget cost for retrieving a batch of events
52    /// This allows different implementations to model their cost structure
53    fn estimate_cost(&self, _batch_size: usize) -> usize {
54        // Default implementation: fixed cost of 1 per batch
55        1
56    }
57
58    /// retrieve the events from the store OR the remote peer
59    async fn retrieve_event(&self, event_ids: Vec<Self::Id>) -> Result<(usize, Vec<Attested<Self::Event>>), RetrievalError>;
60
61    /// Stage events for immediate retrieval without storage. Used when applying EventBridge deltas.
62    /// Staged events are available for lineage comparison at zero budget cost before being persisted.
63    fn stage_events(&self, events: impl IntoIterator<Item = Attested<Self::Event>>);
64
65    /// Mark an event as used. Used when applying EventBridge deltas.
66    fn mark_event_used(&self, event_id: &Self::Id);
67}
68
69#[async_trait]
70pub trait Retrieve: GetEvents {
71    // Each implementation of Retrieve determines whether to use local or remote storage
72    async fn get_state(&self, entity_id: EntityId) -> Result<Option<Attested<EntityState>>, RetrievalError>;
73}
74
75/// Durable node retriever - retrieves everything locally from storage
76#[derive(Clone)]
77pub struct LocalRetriever(Arc<LocalRetrieverInner>);
78struct LocalRetrieverInner {
79    collection: StorageCollectionWrapper,
80    // Tuple is (event, was_used)
81    staged_events: Mutex<Option<HashMap<EventId, (Attested<Event>, bool)>>>,
82}
83
84impl LocalRetriever {
85    pub fn new(collection: StorageCollectionWrapper) -> Self {
86        Self(Arc::new(LocalRetrieverInner { collection, staged_events: Mutex::new(Some(HashMap::new())) }))
87    }
88
89    pub async fn store_used_events(&mut self) -> Result<(), RetrievalError> {
90        let staged = { self.0.staged_events.lock().unwrap().take() };
91
92        if let Some(staged) = staged {
93            for (_id, (event, used)) in staged.iter() {
94                if *used {
95                    self.0.collection.add_event(event).await?;
96                }
97            }
98        }
99
100        Ok(())
101    }
102}
103
104#[async_trait]
105impl GetEvents for LocalRetriever {
106    type Id = EventId;
107    type Event = ankurah_proto::Event;
108
109    async fn retrieve_event(&self, event_ids: Vec<Self::Id>) -> Result<(usize, Vec<Attested<Self::Event>>), RetrievalError> {
110        let mut events = Vec::with_capacity(event_ids.len());
111        let mut event_ids: HashSet<Self::Id> = event_ids.into_iter().collect();
112
113        // First check staged events (zero cost)
114        {
115            if let Some(staged) = self.0.staged_events.lock().unwrap().as_mut() {
116                event_ids.retain(|id| {
117                    if let Some((event, used)) = staged.get_mut(id) {
118                        events.push(event.clone());
119                        *used = true;
120                        false
121                    } else {
122                        true
123                    }
124                });
125            }
126        }
127
128        if event_ids.is_empty() {
129            return Ok((0, events));
130        }
131
132        // staged events are free
133        // cost for local retrieval is 1 per batch
134
135        // Then retrieve from storage if needed
136        let stored_events = self.0.collection.get_events(event_ids.into_iter().collect()).await?;
137        events.extend(stored_events);
138
139        // TODO: push the consumption figure to the store, because its not necessarily the same for all stores
140        Ok((1, events))
141    }
142
143    fn stage_events(&self, events: impl IntoIterator<Item = Attested<Self::Event>>) {
144        let mut staged = self.0.staged_events.lock().unwrap();
145        let staged = staged.get_or_insert_with(|| HashMap::new());
146
147        for event in events.into_iter() {
148            staged.insert(event.payload.id(), (event, false));
149        }
150    }
151
152    fn mark_event_used(&self, event_id: &Self::Id) {
153        let mut staged = self.0.staged_events.lock().unwrap();
154        let staged = staged.get_or_insert_with(|| HashMap::new());
155        staged.get_mut(event_id).map(|(_, used)| {
156            *used = true;
157        });
158    }
159}
160
161#[async_trait]
162impl Retrieve for LocalRetriever {
163    async fn get_state(&self, entity_id: EntityId) -> Result<Option<Attested<EntityState>>, RetrievalError> {
164        match self.0.collection.get_state(entity_id).await {
165            Ok(state) => Ok(Some(state)),
166            Err(RetrievalError::EntityNotFound(_)) => Ok(None),
167            Err(e) => Err(e),
168        }
169    }
170}
171
172/// Ephemeral node retriever - retrieves events remotely, states locally, with multiple contexts for authentication
173pub struct EphemeralNodeRetriever<'a, SE, PA, C>
174where
175    SE: StorageEngine + Send + Sync + 'static,
176    PA: PolicyAgent + Send + Sync + 'static,
177    C: Iterable<PA::ContextData> + Send + Sync + 'a,
178{
179    pub collection: proto::CollectionId,
180    pub node: &'a Node<SE, PA>,
181    pub cdata: &'a C,
182    // Tuple is (event, was_used)
183    staged_events: Mutex<Option<HashMap<EventId, (Attested<Event>, bool)>>>,
184}
185
186impl<'a, SE, PA, C> EphemeralNodeRetriever<'a, SE, PA, C>
187where
188    SE: StorageEngine + Send + Sync + 'static,
189    PA: PolicyAgent + Send + Sync + 'static,
190    C: Iterable<PA::ContextData> + Send + Sync + 'a,
191{
192    pub fn new(collection: proto::CollectionId, node: &'a Node<SE, PA>, cdata: &'a C) -> Self {
193        Self { collection, node, cdata, staged_events: Mutex::new(Some(HashMap::new())) }
194    }
195
196    pub async fn store_used_events(&self) -> Result<(), MutationError> {
197        let staged = { self.staged_events.lock().unwrap().take() };
198
199        if let Some(staged) = staged {
200            // For ephemeral nodes, storing events is optional
201            // Only store if we actually want to persist them
202            let collection = self.node.system.collection(&self.collection).await?;
203            for (_id, (event, used)) in staged.iter() {
204                if *used {
205                    collection.add_event(event).await?;
206                }
207            }
208        }
209
210        Ok(())
211    }
212}
213
214#[async_trait]
215impl<'a, SE, PA, C> GetEvents for EphemeralNodeRetriever<'a, SE, PA, C>
216where
217    SE: StorageEngine + Send + Sync + 'static,
218    PA: PolicyAgent + Send + Sync + 'static,
219    C: Iterable<PA::ContextData> + Send + Sync + 'a,
220{
221    type Id = EventId;
222    type Event = Event;
223
224    async fn retrieve_event(&self, event_ids: Vec<Self::Id>) -> Result<(usize, Vec<Attested<Self::Event>>), RetrievalError> {
225        let mut events = Vec::with_capacity(event_ids.len());
226        let mut event_ids: HashSet<Self::Id> = event_ids.into_iter().collect();
227
228        // First check staged events (zero cost)
229        {
230            if let Some(staged) = self.staged_events.lock().unwrap().as_mut() {
231                event_ids.retain(|id| {
232                    if let Some((event, used)) = staged.get_mut(id) {
233                        events.push(event.clone());
234                        *used = true;
235                        false
236                    } else {
237                        true
238                    }
239                });
240            }
241        }
242
243        if event_ids.is_empty() {
244            return Ok((0, events));
245        }
246
247        // staged events are free
248        // cost for local retrieval is 1 per batch
249        // cost for remote retrieval is 5 per batch
250
251        // Then try to get events from local storage
252        let collection = self.node.system.collection(&self.collection).await?;
253        // TODO update get_events to take &HashSet
254        for event in collection.get_events(event_ids.iter().cloned().collect()).await? {
255            event_ids.remove(&event.payload.id());
256            events.push(event);
257        }
258
259        if event_ids.is_empty() {
260            return Ok((1, events));
261        }
262
263        // If we have missing events and a durable peer, try to fetch them
264        let Some(peer_id) = self.node.get_durable_peer_random() else {
265            return Ok((1, events)); // no durable peers - return what we have
266        };
267
268        match self
269            .node
270            .request(
271                peer_id,
272                self.cdata,
273                proto::NodeRequestBody::GetEvents { collection: self.collection.clone(), event_ids: event_ids.into_iter().collect() }, // TODO update ::GetEvents to take HashSet
274            )
275            .await?
276            // .map_err(|e| RetrievalError::StorageError(format!("Request failed: {}", e).into()))?
277        {
278            proto::NodeResponseBody::GetEvents(peer_events) => {
279                for event in peer_events.iter() {
280                    collection.add_event(event).await?;
281                }
282                events.extend(peer_events);
283            }
284            proto::NodeResponseBody::Error(e) => {
285                return Err(RetrievalError::StorageError(format!("Error from peer: {}", e).into()));
286            }
287            _ => return Err(RetrievalError::StorageError("Unexpected response type from peer".into())),
288        }
289        Ok((5, events))
290    }
291
292    fn stage_events(&self, events: impl IntoIterator<Item = Attested<Self::Event>>) {
293        let mut staged = self.staged_events.lock().unwrap();
294        let staged = staged.get_or_insert_with(|| HashMap::new());
295
296        for event in events.into_iter() {
297            staged.insert(event.payload.id(), (event, false));
298        }
299    }
300
301    fn mark_event_used(&self, event_id: &Self::Id) {
302        let mut staged = self.staged_events.lock().unwrap();
303        let staged = staged.get_or_insert_with(|| HashMap::new());
304        staged.get_mut(event_id).map(|(_, used)| {
305            *used = true;
306        });
307    }
308}
309
310#[async_trait]
311impl<'a, SE, PA, C> Retrieve for EphemeralNodeRetriever<'a, SE, PA, C>
312where
313    SE: StorageEngine + Send + Sync + 'static,
314    PA: PolicyAgent + Send + Sync + 'static,
315    C: Iterable<PA::ContextData> + Send + Sync + 'a,
316{
317    async fn get_state(&self, entity_id: EntityId) -> Result<Option<Attested<EntityState>>, RetrievalError> {
318        let collection = self.node.collections.get(&self.collection).await?;
319        match collection.get_state(entity_id).await {
320            Ok(state) => Ok(Some(state)),
321            Err(RetrievalError::EntityNotFound(_)) => Ok(None),
322            Err(e) => Err(e),
323        }
324    }
325}