use std::{
collections::{HashMap, HashSet},
sync::{Arc, Mutex},
};
use crate::{
error::{MutationError, RetrievalError},
policy::PolicyAgent,
storage::{StorageCollectionWrapper, StorageEngine},
util::Iterable,
Node,
};
use ankurah_proto::{self as proto, Attested, Clock, EntityId, EntityState, Event, EventId};
use async_trait::async_trait;
/// Abstract view of an event: something displayable that carries its own
/// identifier and a reference to a parent clock.
pub trait TEvent: std::fmt::Display {
    /// Identifier type for the event. (`Eq` already implies `PartialEq`.)
    type Id: Clone + Eq;
    /// Type returned by [`TEvent::parent`]; its member ids use the same `Id`
    /// type as the event itself.
    type Parent: TClock<Id = Self::Id>;

    /// Returns this event's identifier by value.
    fn id(&self) -> Self::Id;

    /// Borrows this event's parent clock.
    fn parent(&self) -> &Self::Parent;
}
/// Abstract view of a clock: a value that can expose the ids it holds as a slice.
pub trait TClock {
    /// Identifier type of the clock's members. (`Eq` already implies `PartialEq`.)
    type Id: Clone + Eq;

    /// Borrows the ids held by this clock.
    fn members(&self) -> &[Self::Id];
}
/// Lets a protocol [`Clock`] be used wherever a [`TClock`] is expected.
impl TClock for Clock {
    type Id = EventId;

    /// Borrows the clock's event ids via `Clock::as_slice`.
    fn members(&self) -> &[EventId] {
        self.as_slice()
    }
}
/// Lets a protocol event be used wherever a [`TEvent`] is expected.
impl TEvent for ankurah_proto::Event {
    type Id = EventId;
    type Parent = Clock;

    /// Returns the event's id.
    fn id(&self) -> Self::Id {
        // NOTE(review): this relies on `Event` having an inherent `id` method,
        // which method resolution prefers over this trait method — confirm.
        self.id()
    }

    /// Borrows the event's `parent` field.
    fn parent(&self) -> &Self::Parent {
        &self.parent
    }
}
/// Source of attested events that can be looked up by id, with a staging area
/// for events that have been received but not yet persisted.
#[async_trait]
pub trait GetEvents {
    /// Identifier type used to request events. (`Eq` already implies `PartialEq`.)
    type Id: Clone + Eq + std::fmt::Debug + Send + Sync;
    /// Event type produced by this source. `TEvent` already requires `Display`.
    type Event: TEvent<Id = Self::Id>;

    /// Estimated cost of retrieving a batch of `_batch_size` events.
    ///
    /// The default ignores the batch size and always reports `1`.
    fn estimate_cost(&self, _batch_size: usize) -> usize {
        1
    }

    /// Retrieves the events identified by `event_ids`.
    ///
    /// On success returns a `(cost, events)` pair. The implementations in this
    /// file report a cost of `0`, `1` or `5` depending on where the events had
    /// to be fetched from.
    async fn retrieve_event(&self, event_ids: Vec<Self::Id>) -> Result<(usize, Vec<Attested<Self::Event>>), RetrievalError>;

    /// Hands `events` to the retriever so that later lookups can be served
    /// from them.
    fn stage_events(&self, events: impl IntoIterator<Item = Attested<Self::Event>>);

    /// Flags the staged event with id `event_id` as used.
    fn mark_event_used(&self, event_id: &Self::Id);
}
/// Extends [`GetEvents`] with access to an entity's stored state.
#[async_trait]
pub trait Retrieve: GetEvents {
/// Fetches the attested state for `entity_id`.
///
/// `Ok(None)` signals "no state" without raising an error; the implementations
/// in this file return it when the underlying lookup fails with
/// `RetrievalError::EntityNotFound`.
async fn get_state(&self, entity_id: EntityId) -> Result<Option<Attested<EntityState>>, RetrievalError>;
}
/// Retriever that serves events and state from a single storage collection,
/// consulting its staged events first.
///
/// The state lives behind an `Arc`, so clones share the same collection handle
/// and the same staging area.
#[derive(Clone)]
pub struct LocalRetriever(Arc<LocalRetrieverInner>);
/// Shared state behind a [`LocalRetriever`].
struct LocalRetrieverInner {
// Storage collection that events and entity state are read from and written to.
collection: StorageCollectionWrapper,
// Staged events keyed by id; the `bool` records whether the event has been used
// (set by `retrieve_event` / `mark_event_used`). `store_used_events` takes the
// map out, leaving `None`; staging or marking afterwards recreates an empty map.
staged_events: Mutex<Option<HashMap<EventId, (Attested<Event>, bool)>>>,
}
impl LocalRetriever {
    /// Creates a retriever over `collection` with an empty staging area.
    pub fn new(collection: StorageCollectionWrapper) -> Self {
        let inner = LocalRetrieverInner { collection, staged_events: Mutex::new(Some(HashMap::new())) };
        Self(Arc::new(inner))
    }

    /// Persists every staged event that has been flagged as used.
    ///
    /// The staging map is taken out first (leaving `None` behind), so all
    /// staged events — used or not — are discarded from the staging area by
    /// this call. A later call without fresh staging finds nothing to store.
    ///
    /// # Errors
    /// Returns the first error produced while adding an event to the
    /// collection; events not yet written at that point are not retried.
    ///
    /// # Panics
    /// Panics if the staging mutex is poisoned.
    pub async fn store_used_events(&mut self) -> Result<(), RetrievalError> {
        // Keep the guard in its own scope so it is released before any `.await`.
        let taken = {
            let mut guard = self.0.staged_events.lock().unwrap();
            guard.take()
        };
        let Some(staged) = taken else {
            return Ok(());
        };

        for (_id, (event, used)) in &staged {
            if !*used {
                continue;
            }
            self.0.collection.add_event(event).await?;
        }
        Ok(())
    }
}
#[async_trait]
impl GetEvents for LocalRetriever {
    type Id = EventId;
    type Event = ankurah_proto::Event;

    /// Looks up `event_ids`, first among the staged events and then in the
    /// storage collection.
    ///
    /// Duplicate ids in the request are collapsed. Every staged event that is
    /// returned is also flagged as used. The reported cost is `0` when all ids
    /// were served from staging and `1` when the collection had to be queried.
    ///
    /// # Errors
    /// Propagates any error from the collection's `get_events`.
    ///
    /// # Panics
    /// Panics if the staging mutex is poisoned.
    async fn retrieve_event(&self, event_ids: Vec<Self::Id>) -> Result<(usize, Vec<Attested<Self::Event>>), RetrievalError> {
        let mut events = Vec::with_capacity(event_ids.len());
        let mut event_ids: HashSet<Self::Id> = event_ids.into_iter().collect();

        // Serve what we can from staging. The guard is confined to this block
        // so it is never held across an `.await`.
        {
            let mut guard = self.0.staged_events.lock().unwrap();
            if let Some(staged) = guard.as_mut() {
                event_ids.retain(|id| match staged.get_mut(id) {
                    Some((event, used)) => {
                        events.push(event.clone());
                        *used = true;
                        false // satisfied; drop it from the outstanding set
                    }
                    None => true,
                });
            }
        }

        if event_ids.is_empty() {
            return Ok((0, events));
        }

        // Whatever is still outstanding is requested from storage. The result
        // is appended as returned; it is not checked against the requested ids.
        let stored_events = self.0.collection.get_events(event_ids.into_iter().collect()).await?;
        events.extend(stored_events);
        Ok((1, events))
    }

    /// Adds `events` to the staging area, flagged as not yet used.
    ///
    /// An already-staged event with the same id is replaced, which resets its
    /// used flag. If the staging map was taken by `store_used_events`, a fresh
    /// one is created.
    ///
    /// # Panics
    /// Panics if the staging mutex is poisoned.
    fn stage_events(&self, events: impl IntoIterator<Item = Attested<Self::Event>>) {
        let mut guard = self.0.staged_events.lock().unwrap();
        let staged = guard.get_or_insert_with(HashMap::new);
        for event in events {
            staged.insert(event.payload.id(), (event, false));
        }
    }

    /// Flags the staged event `event_id` as used; does nothing if no such
    /// event is staged.
    ///
    /// # Panics
    /// Panics if the staging mutex is poisoned.
    fn mark_event_used(&self, event_id: &Self::Id) {
        let mut guard = self.0.staged_events.lock().unwrap();
        // Kept as `get_or_insert_with` so the staging slot ends up `Some`, exactly as before.
        let staged = guard.get_or_insert_with(HashMap::new);
        if let Some((_, used)) = staged.get_mut(event_id) {
            *used = true;
        }
    }
}
#[async_trait]
impl Retrieve for LocalRetriever {
    /// Reads the state of `entity_id` from the storage collection.
    ///
    /// An `EntityNotFound` error from storage is reported as `Ok(None)`; every
    /// other error is passed through unchanged.
    async fn get_state(&self, entity_id: EntityId) -> Result<Option<Attested<EntityState>>, RetrievalError> {
        let fetched = self.0.collection.get_state(entity_id).await;
        match fetched {
            // A missing entity is an expected outcome, not a failure.
            Err(RetrievalError::EntityNotFound(_)) => Ok(None),
            Err(other) => Err(other),
            Ok(state) => Ok(Some(state)),
        }
    }
}
/// Retriever bound to a borrowed [`Node`] and one collection.
///
/// Its `retrieve_event` implementation looks in the staged events, then in the
/// node's local collection, and finally asks a durable peer for anything still
/// missing.
pub struct EphemeralNodeRetriever<'a, SE, PA, C>
where
SE: StorageEngine + Send + Sync + 'static,
PA: PolicyAgent + Send + Sync + 'static,
C: Iterable<PA::ContextData> + Send + Sync + 'a,
{
/// Id of the collection this retriever reads from and writes to.
pub collection: proto::CollectionId,
/// Node used for local storage access and for peer requests.
pub node: &'a Node<SE, PA>,
/// Context data passed along with requests sent to peers.
pub cdata: &'a C,
// Staged events keyed by id; the `bool` records whether the event has been used.
// `store_used_events` takes the map out, leaving `None`.
staged_events: Mutex<Option<HashMap<EventId, (Attested<Event>, bool)>>>,
}
impl<'a, SE, PA, C> EphemeralNodeRetriever<'a, SE, PA, C>
where
    SE: StorageEngine + Send + Sync + 'static,
    PA: PolicyAgent + Send + Sync + 'static,
    C: Iterable<PA::ContextData> + Send + Sync + 'a,
{
    /// Creates a retriever for `collection` on `node` with an empty staging area.
    pub fn new(collection: proto::CollectionId, node: &'a Node<SE, PA>, cdata: &'a C) -> Self {
        let staged_events = Mutex::new(Some(HashMap::new()));
        Self { collection, node, cdata, staged_events }
    }

    /// Persists every staged event that has been flagged as used.
    ///
    /// The staging map is taken out first (leaving `None` behind), so all
    /// staged events are discarded from the staging area by this call. If the
    /// map had already been taken, this returns immediately without touching
    /// the collection.
    ///
    /// # Errors
    /// Fails if the collection cannot be obtained or if adding an event fails;
    /// events not yet written at that point are not retried.
    ///
    /// # Panics
    /// Panics if the staging mutex is poisoned.
    pub async fn store_used_events(&self) -> Result<(), MutationError> {
        // Keep the guard in its own scope so it is released before any `.await`.
        let taken = {
            let mut guard = self.staged_events.lock().unwrap();
            guard.take()
        };
        let Some(staged) = taken else {
            return Ok(());
        };

        // The collection is resolved even when no staged event turns out to be used.
        let collection = self.node.system.collection(&self.collection).await?;
        for (_id, (event, used)) in &staged {
            if !*used {
                continue;
            }
            collection.add_event(event).await?;
        }
        Ok(())
    }
}
#[async_trait]
impl<'a, SE, PA, C> GetEvents for EphemeralNodeRetriever<'a, SE, PA, C>
where
    SE: StorageEngine + Send + Sync + 'static,
    PA: PolicyAgent + Send + Sync + 'static,
    C: Iterable<PA::ContextData> + Send + Sync + 'a,
{
    type Id = EventId;
    type Event = Event;

    /// Looks up `event_ids` in three tiers and reports how far it had to go.
    ///
    /// 1. Staged events — each hit is flagged as used. Cost `0` if this
    ///    satisfies every id.
    /// 2. The node's local collection. Cost `1` if nothing is outstanding
    ///    afterwards, or if no durable peer is available (in which case the
    ///    result may be missing some of the requested events).
    /// 3. A durable peer picked by `get_durable_peer_random`. Events it returns
    ///    are added to the local collection and appended to the result.
    ///    Cost `5`.
    ///
    /// Duplicate ids in the request are collapsed.
    ///
    /// # Errors
    /// Propagates storage and request errors. A peer `Error` response or any
    /// unexpected response type is reported as `RetrievalError::StorageError`.
    ///
    /// # Panics
    /// Panics if the staging mutex is poisoned.
    async fn retrieve_event(&self, event_ids: Vec<Self::Id>) -> Result<(usize, Vec<Attested<Self::Event>>), RetrievalError> {
        let mut events = Vec::with_capacity(event_ids.len());
        let mut event_ids: HashSet<Self::Id> = event_ids.into_iter().collect();

        // Tier 1: staging. The guard is confined to this block so it is never
        // held across an `.await`.
        {
            let mut guard = self.staged_events.lock().unwrap();
            if let Some(staged) = guard.as_mut() {
                event_ids.retain(|id| match staged.get_mut(id) {
                    Some((event, used)) => {
                        events.push(event.clone());
                        *used = true;
                        false // satisfied; drop it from the outstanding set
                    }
                    None => true,
                });
            }
        }
        if event_ids.is_empty() {
            return Ok((0, events));
        }

        // Tier 2: local storage. Tick off each id that storage could supply.
        let collection = self.node.system.collection(&self.collection).await?;
        for event in collection.get_events(event_ids.iter().cloned().collect()).await? {
            event_ids.remove(&event.payload.id());
            events.push(event);
        }
        if event_ids.is_empty() {
            return Ok((1, events));
        }

        // Tier 3: a durable peer. Without one, return what was found locally.
        let Some(peer_id) = self.node.get_durable_peer_random() else {
            return Ok((1, events));
        };

        let request = proto::NodeRequestBody::GetEvents {
            collection: self.collection.clone(),
            event_ids: event_ids.into_iter().collect(),
        };
        match self.node.request(peer_id, self.cdata, request).await? {
            proto::NodeResponseBody::GetEvents(peer_events) => {
                // Store the peer's events locally before handing them back.
                // NOTE(review): they are not checked against the requested ids.
                for event in peer_events.iter() {
                    collection.add_event(event).await?;
                }
                events.extend(peer_events);
            }
            proto::NodeResponseBody::Error(e) => {
                return Err(RetrievalError::StorageError(format!("Error from peer: {}", e).into()));
            }
            _ => return Err(RetrievalError::StorageError("Unexpected response type from peer".into())),
        }
        Ok((5, events))
    }

    /// Adds `events` to the staging area, flagged as not yet used.
    ///
    /// An already-staged event with the same id is replaced, which resets its
    /// used flag. If the staging map was taken by `store_used_events`, a fresh
    /// one is created.
    ///
    /// # Panics
    /// Panics if the staging mutex is poisoned.
    fn stage_events(&self, events: impl IntoIterator<Item = Attested<Self::Event>>) {
        let mut guard = self.staged_events.lock().unwrap();
        let staged = guard.get_or_insert_with(HashMap::new);
        for event in events {
            staged.insert(event.payload.id(), (event, false));
        }
    }

    /// Flags the staged event `event_id` as used; does nothing if no such
    /// event is staged.
    ///
    /// # Panics
    /// Panics if the staging mutex is poisoned.
    fn mark_event_used(&self, event_id: &Self::Id) {
        let mut guard = self.staged_events.lock().unwrap();
        // Kept as `get_or_insert_with` so the staging slot ends up `Some`, exactly as before.
        let staged = guard.get_or_insert_with(HashMap::new);
        if let Some((_, used)) = staged.get_mut(event_id) {
            *used = true;
        }
    }
}
#[async_trait]
impl<'a, SE, PA, C> Retrieve for EphemeralNodeRetriever<'a, SE, PA, C>
where
    SE: StorageEngine + Send + Sync + 'static,
    PA: PolicyAgent + Send + Sync + 'static,
    C: Iterable<PA::ContextData> + Send + Sync + 'a,
{
    /// Reads the state of `entity_id` from this retriever's collection on the node.
    ///
    /// An `EntityNotFound` error from storage is reported as `Ok(None)`; every
    /// other error, including failure to obtain the collection, is passed
    /// through.
    async fn get_state(&self, entity_id: EntityId) -> Result<Option<Attested<EntityState>>, RetrievalError> {
        let collection = self.node.collections.get(&self.collection).await?;
        let fetched = collection.get_state(entity_id).await;
        match fetched {
            // A missing entity is an expected outcome, not a failure.
            Err(RetrievalError::EntityNotFound(_)) => Ok(None),
            Err(other) => Err(other),
            Ok(state) => Ok(Some(state)),
        }
    }
}