use alloc::{
boxed::Box,
collections::{btree_map, BTreeMap},
string::String,
sync::Arc,
vec::Vec,
};
use core::{cmp::Ordering, iter::Peekable, marker::PhantomData, mem, ops::Bound};
use aranya_buggy::{bug, Bug, BugExt};
use serde::{Deserialize, Serialize};
use yoke::{Yoke, Yokeable};
use crate::{
Address, Checkpoint, ClientError, ClientState, Command, CommandId, CommandRecall, Engine, Fact,
FactPerspective, GraphId, Keys, NullSink, Perspective, Policy, PolicyId, Prior, Priority,
Query, QueryMut, Revertable, Segment, Sink, Storage, StorageError, StorageProvider,
MAX_COMMAND_LENGTH,
};
/// Owned, heap-allocated byte string; used in this file for the value stored under a fact.
type Bytes = Box<[u8]>;
/// Ephemeral session layered over the facts of one graph.
///
/// Writes made through the session are kept in memory in `fact_log` and
/// `current_facts`; reads fall back to `base_facts` when the session holds no
/// entry for the requested key.
pub struct Session<SP: StorageProvider, E> {
/// Graph this session was created for; `receive` refuses commands for any other graph.
storage_id: GraphId,
/// Policy id taken from the head segment when the session was created.
policy_id: PolicyId,
/// Facts of the head segment at creation time.
base_facts: <SP::Storage as Storage>::FactIndex,
/// Ordered record of every insert (`Some`) and delete (`None`); replayed by `revert`.
fact_log: Vec<(String, Keys, Option<Bytes>)>,
/// Session writes keyed by fact name, then by keys; a `None` value marks a deleted fact.
current_facts: Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>,
/// Ties the session to an engine type without storing an engine.
_engine: PhantomData<E>,
/// Address of the graph head at creation, later replaced by the address of each
/// command passed to `add_command`.
head: Address,
}
/// Short-lived view over a [`Session`] that is handed to the policy while an action
/// or a received command is evaluated.
struct SessionPerspective<'a, SP: StorageProvider, E, MS> {
/// Session whose facts and head are read and updated.
session: &'a mut Session<SP, E>,
/// Receives the serialized bytes of every command passed to `add_command`.
message_sink: &'a mut MS,
}
impl<SP: StorageProvider, E> Session<SP, E> {
    /// Builds a session anchored at the current head of graph `storage_id`.
    ///
    /// The graph's storage is looked up in `provider`, the segment holding the head
    /// command is loaded, and that segment's facts and policy id are captured. The
    /// new session starts with an empty fact log and no session-local facts.
    ///
    /// # Errors
    ///
    /// Propagates failures from the storage lookup, from reading the head or its
    /// segment, from loading the segment's facts, and from computing the head
    /// command's address. A head location without a command is reported as a bug.
    pub(super) fn new(provider: &mut SP, storage_id: GraphId) -> Result<Self, ClientError> {
        let storage = provider.get_storage(storage_id)?;
        let head_location = storage.get_head()?;
        let segment = storage.get_segment(head_location)?;
        let head_command = segment
            .get_command(head_location)
            .assume("location must exist")?;
        let base_facts = segment.facts()?;
        let policy_id = segment.policy();
        let head = head_command.address()?;

        Ok(Self {
            storage_id,
            policy_id,
            base_facts,
            fact_log: Vec::new(),
            current_facts: Arc::default(),
            _engine: PhantomData,
            head,
        })
    }
}
impl<SP: StorageProvider, E: Engine> Session<SP, E> {
    /// Runs `action` through the session's policy.
    ///
    /// Effects are delivered to `effect_sink`, bracketed by `begin` and `commit`.
    /// Every command the policy passes to `add_command` is serialized and handed to
    /// `message_sink`.
    ///
    /// # Errors
    ///
    /// Fails if the policy cannot be fetched from the client's engine or if
    /// `call_action` returns an error. In the latter case the session's facts are
    /// reverted to their state before the call and both sinks are rolled back. If
    /// the revert itself fails, that error is returned and the sinks are not
    /// rolled back.
    ///
    /// NOTE(review): `revert` restores facts only. `head`, which `add_command`
    /// overwrites, keeps its new value after a failed action — confirm this is
    /// intended.
    pub fn action<ES, MS>(
        &mut self,
        client: &ClientState<E, SP>,
        effect_sink: &mut ES,
        message_sink: &mut MS,
        action: <E::Policy as Policy>::Action<'_>,
    ) -> Result<(), ClientError>
    where
        ES: Sink<E::Effect>,
        MS: for<'b> Sink<&'b [u8]>,
    {
        let policy = client.engine.get_policy(self.policy_id)?;
        let mut perspective = SessionPerspective {
            session: self,
            message_sink,
        };
        let checkpoint = perspective.checkpoint();
        effect_sink.begin();

        if let Err(err) = policy.call_action(action, &mut perspective, effect_sink) {
            // Undo fact changes first; only then discard buffered messages and effects.
            perspective.revert(checkpoint)?;
            perspective.message_sink.rollback();
            effect_sink.rollback();
            return Err(err.into());
        }

        effect_sink.commit();
        Ok(())
    }

    /// Evaluates a serialized session command.
    ///
    /// `command_bytes` is decoded with postcard into a `SessionCommand` and passed
    /// to the policy's `call_rule` with `CommandRecall::None`. Effects go to `sink`;
    /// the perspective's message sink is a `NullSink`.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::SessionDeserialize` when decoding fails, a bug error
    /// when the command names a different graph than this session, and otherwise
    /// whatever the policy lookup or `call_rule` reports. When `call_rule` fails the
    /// session's facts are reverted and `sink` is rolled back.
    pub fn receive(
        &mut self,
        client: &ClientState<E, SP>,
        sink: &mut impl Sink<E::Effect>,
        command_bytes: &[u8],
    ) -> Result<(), ClientError> {
        let command: SessionCommand<'_> =
            postcard::from_bytes(command_bytes).map_err(ClientError::SessionDeserialize)?;
        if command.storage_id != self.storage_id {
            bug!("ephemeral commands must be run on the same graph");
        }

        let policy = client.engine.get_policy(self.policy_id)?;
        let mut perspective = SessionPerspective {
            session: self,
            message_sink: &mut NullSink,
        };
        sink.begin();
        let checkpoint = perspective.checkpoint();

        match policy.call_rule(&command, &mut perspective, sink, CommandRecall::None) {
            Ok(_) => {
                sink.commit();
                Ok(())
            }
            Err(err) => {
                perspective.revert(checkpoint)?;
                sink.rollback();
                Err(err.into())
            }
        }
    }
}
/// Serialized form of a command exchanged within a session.
///
/// Encoded with postcard in `add_command` and decoded in `Session::receive`.
/// The field order is presumably part of the encoded layout, so keep it stable.
#[derive(Serialize, Deserialize)]
struct SessionCommand<'a> {
    /// Graph the command belongs to.
    storage_id: GraphId,
    /// Value reported as `Priority::Basic` by the `Command` impl.
    priority: u32,
    /// Identifier of the command.
    id: CommandId,
    /// Address of the command's single parent.
    parent: Address,
    /// Command payload; borrowed rather than copied when deserializing.
    #[serde(borrow)]
    data: &'a [u8],
}
/// Exposes a [`SessionCommand`] through the generic `Command` interface.
impl Command for SessionCommand<'_> {
/// Always `Priority::Basic`, wrapping the stored `priority` value.
fn priority(&self) -> Priority {
Priority::Basic(self.priority)
}
/// Returns the stored command id.
fn id(&self) -> CommandId {
self.id
}
/// Always a single parent: the stored `parent` address.
fn parent(&self) -> Prior<Address> {
Prior::Single(self.parent)
}
/// Session commands never carry policy bytes.
fn policy(&self) -> Option<&[u8]> {
None
}
/// Returns the command payload.
fn bytes(&self) -> &[u8] {
self.data
}
}
impl<'sc> SessionCommand<'sc> {
    /// Converts an arbitrary `command` into its session form for graph `storage_id`.
    ///
    /// The payload is borrowed from `command`, not copied.
    ///
    /// # Errors
    ///
    /// Reports a bug when `command` carries policy bytes, when its priority is not
    /// `Priority::Basic`, or when it does not have exactly one parent.
    fn from_cmd(storage_id: GraphId, command: &'sc impl Command) -> Result<Self, Bug> {
        if command.policy().is_some() {
            bug!("session command should have no policy")
        }

        let Priority::Basic(priority) = command.priority() else {
            bug!("wrong command type")
        };
        let id = command.id();
        let Prior::Single(parent) = command.parent() else {
            bug!("wrong command type")
        };

        Ok(SessionCommand {
            storage_id,
            priority,
            id,
            parent,
            data: command.bytes(),
        })
    }
}
/// Merges two key-ordered fact streams into one, letting `current` override `prior`.
struct QueryIterator<I1: Iterator, I2: Iterator> {
/// Facts read from the underlying fact index; items may be storage errors.
prior: Peekable<I1>,
/// Session writes; an entry whose value is `None` marks a deleted fact.
current: Peekable<I2>,
}
impl<I1, I2> QueryIterator<I1, I2>
where
    I1: Iterator<Item = Result<Fact, StorageError>>,
    I2: Iterator<Item = (Keys, Option<Bytes>)>,
{
    /// Creates a merged iterator over `prior` (stored facts) and `current` (session writes).
    ///
    /// The merge compares the front keys of both inputs, so both are presumably
    /// expected to be sorted by key in ascending order — TODO confirm with callers.
    fn new(prior: I1, current: I2) -> Self {
        let prior = prior.peekable();
        let current = current.peekable();
        Self { prior, current }
    }
}
impl<I1, I2> Iterator for QueryIterator<I1, I2>
where
    I1: Iterator<Item = Result<Fact, StorageError>>,
    I2: Iterator<Item = (Keys, Option<Bytes>)>,
{
    type Item = Result<Fact, StorageError>;

    /// Returns the next fact of the merged view.
    ///
    /// The smaller of the two front keys is yielded first. When both fronts have the
    /// same key, the `current` entry replaces the `prior` one. `current` entries
    /// whose value is `None` are deletions and are skipped. An error at the front of
    /// `prior` is returned as soon as it is encountered.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let Some(new) = self.current.peek() else {
                // No session writes left: the rest comes straight from `prior`.
                return self.prior.next();
            };
            if let Some(old) = self.prior.peek() {
                let Ok(old) = old else {
                    // Surface the storage error instead of comparing against it.
                    return self.prior.next();
                };
                match new.0.cmp(&old.key) {
                    Ordering::Equal => {
                        // Same key: drop the stored fact, the session entry wins.
                        let _ = self.prior.next();
                    }
                    Ordering::Greater => {
                        // The stored fact sorts first.
                        return self.prior.next();
                    }
                    Ordering::Less => {
                        // The session entry sorts first; fall through to consume it.
                    }
                }
            }
            let Some(slot) = self.current.next() else {
                bug!("expected Some after peek")
            };
            // `slot` is owned, so its keys are moved into the fact rather than
            // cloned component by component.
            if let (key, Some(value)) = slot {
                return Some(Ok(Fact { key, value }));
            }
            // A `None` value is a deletion: yield nothing and look at the next entries.
        }
    }
}
// `FactPerspective` is implemented with an empty body; nothing is overridden here.
impl<SP, E, MS> FactPerspective for SessionPerspective<'_, SP, E, MS> where SP: StorageProvider {}
impl<SP, E, MS> Query for SessionPerspective<'_, SP, E, MS>
where
    SP: StorageProvider,
{
    /// Merged iterator over stored facts and session writes.
    type QueryIterator = QueryIterator<
        <<SP::Storage as Storage>::FactIndex as Query>::QueryIterator,
        YokeIter<PrefixIter<'static>, Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>>,
    >;

    /// Looks up the fact `name` / `keys`.
    ///
    /// A session entry takes precedence over stored facts; a session deletion is
    /// returned as `Ok(None)`. Only when the session has no entry for the key is the
    /// lookup forwarded to `base_facts`.
    fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
        let session_entry = self
            .session
            .current_facts
            .get(name)
            .and_then(|facts| facts.get(keys));

        match session_entry {
            Some(slot) => Ok(slot.clone()),
            None => self.session.base_facts.query(name, keys),
        }
    }

    /// Iterates over the facts called `name` whose keys start with `prefix`,
    /// combining stored facts with session writes.
    ///
    /// The session's fact map is shared with the returned iterator by cloning the
    /// `Arc`, so the iterator does not borrow from `self`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `base_facts.query_prefix`.
    fn query_prefix(
        &self,
        name: &str,
        prefix: &[Box<[u8]>],
    ) -> Result<Self::QueryIterator, StorageError> {
        let prior = self.session.base_facts.query_prefix(name, prefix)?;

        let cart = Arc::clone(&self.session.current_facts);
        let current = Yoke::<PrefixIter<'static>, _>::attach_to_cart(cart, |map| {
            if let Some(facts) = map.get(name) {
                PrefixIter::new(facts, prefix.iter().cloned().collect())
            } else {
                // No session writes under this name: an empty iterator.
                PrefixIter::default()
            }
        });

        Ok(QueryIterator::new(prior, YokeIter::new(current)))
    }
}
/// Iterator over the entries of one fact map whose keys start with `prefix`.
#[derive(Default, Yokeable)]
struct PrefixIter<'map> {
/// Remaining map entries, starting at the first key that is not less than `prefix`.
range: btree_map::Range<'map, Keys, Option<Bytes>>,
/// Key prefix an entry must start with to be yielded.
prefix: Keys,
}
impl<'map> PrefixIter<'map> {
    /// Positions an iterator at the first entry of `map` whose key is not less than
    /// `prefix`; `next` then stops at the first key that no longer starts with it.
    fn new(map: &'map BTreeMap<Keys, Option<Bytes>>, prefix: Keys) -> Self {
        let lower: &[Box<[u8]>] = prefix.as_ref();
        let range = map.range::<[Box<[u8]>], _>((Bound::Included(lower), Bound::Unbounded));
        Self { range, prefix }
    }
}
impl Iterator for PrefixIter<'_> {
    type Item = (Keys, Option<Bytes>);

    /// Returns an owned copy of the next entry, or `None` once the range is exhausted
    /// or the next key does not start with the prefix.
    fn next(&mut self) -> Option<Self::Item> {
        let (key, value) = self.range.next()?;
        if key.starts_with(&self.prefix) {
            Some((key.clone(), value.clone()))
        } else {
            None
        }
    }
}
/// Iterator adapter that owns a `Yoke`, keeping the iterator `I` together with its cart `C`.
///
/// The `Option` lets `next` move the yoke out of `self` while advancing it.
struct YokeIter<I: for<'a> Yokeable<'a>, C>(Option<Yoke<I, C>>);
impl<I: for<'a> Yokeable<'a>, C> YokeIter<I, C> {
/// Wraps `yoke` so that it can be driven through the `Iterator` impl.
fn new(yoke: Yoke<I, C>) -> Self {
Self(Some(yoke))
}
}
impl<I, C> Iterator for YokeIter<I, C>
where
    I: Iterator + for<'a> Yokeable<'a>,
    for<'a> <I as Yokeable<'a>>::Output: Iterator<Item = I::Item>,
{
    type Item = I::Item;

    /// Advances the yoked iterator by one step and returns the item it produced.
    ///
    /// The yoke is moved out of `self`, projected through a closure that calls `next`
    /// on the inner iterator, and stored back. Returns `None` when no yoke is held.
    fn next(&mut self) -> Option<Self::Item> {
        let yoke = self.0.take()?;
        let mut produced = None;
        let advanced = yoke.map_project::<I, _>(|mut inner, _| {
            produced = inner.next();
            inner
        });
        self.0 = Some(advanced);
        produced
    }
}
impl<SP: StorageProvider, E, MS> QueryMut for SessionPerspective<'_, SP, E, MS> {
fn insert(&mut self, name: String, keys: Keys, value: Box<[u8]>) {
self.session
.fact_log
.push((name.clone(), keys.clone(), Some(value.clone())));
Arc::make_mut(&mut self.session.current_facts)
.entry(name)
.or_default()
.insert(keys, Some(value));
}
fn delete(&mut self, name: String, keys: Keys) {
self.session
.fact_log
.push((name.clone(), keys.clone(), None));
Arc::make_mut(&mut self.session.current_facts)
.entry(name)
.or_default()
.insert(keys, None);
}
}
impl<SP, E, MS> Perspective for SessionPerspective<'_, SP, E, MS>
where
    SP: StorageProvider,
    MS: for<'b> Sink<&'b [u8]>,
{
    /// Returns the policy id held by the session.
    fn policy(&self) -> PolicyId {
        self.session.policy_id
    }

    /// Publishes `command` to the message sink and makes it the session head.
    ///
    /// The command is converted to a `SessionCommand`, serialized with postcard into
    /// a stack buffer of `MAX_COMMAND_LENGTH` bytes and passed to the message sink.
    /// The returned index is always `0`.
    ///
    /// # Errors
    ///
    /// Fails when the command cannot be represented as a session command, when its
    /// address cannot be computed, or when serialization fails. In every failure
    /// case the session head is left unchanged and nothing is sent to the sink.
    fn add_command(&mut self, command: &impl Command) -> Result<usize, StorageError> {
        let command = SessionCommand::from_cmd(self.session.storage_id, command)?;
        let new_head = command.address()?;
        let mut buf = [0u8; MAX_COMMAND_LENGTH];
        let bytes = postcard::to_slice(&command, &mut buf).assume("can serialize")?;
        // Advance the head only after serialization succeeded, so a failure cannot
        // leave the head pointing at a command that was never emitted.
        self.session.head = new_head;
        self.message_sink.consume(bytes);
        Ok(0)
    }

    /// Not supported for sessions: asserts in debug builds and returns `false`.
    fn includes(&self, _id: CommandId) -> bool {
        debug_assert!(false, "only used in transactions");
        false
    }

    /// Returns the session head as a single parent address.
    fn head_address(&self) -> Result<Prior<Address>, Bug> {
        Ok(Prior::Single(self.session.head))
    }
}
impl<SP, E, MS> Revertable for SessionPerspective<'_, SP, E, MS>
where
SP: StorageProvider,
{
fn checkpoint(&self) -> Checkpoint {
Checkpoint {
index: self.session.fact_log.len(),
}
}
fn revert(&mut self, checkpoint: Checkpoint) -> Result<(), Bug> {
if checkpoint.index == self.session.fact_log.len() {
return Ok(());
}
if checkpoint.index > self.session.fact_log.len() {
bug!("A checkpoint's index should always be less than or equal to the length of a session's fact log!");
}
self.session.fact_log.truncate(checkpoint.index);
let mut facts =
Arc::get_mut(&mut self.session.current_facts).map_or_else(BTreeMap::new, mem::take);
facts.clear();
for (n, k, v) in self.session.fact_log.iter().cloned() {
facts.entry(n).or_default().insert(k, v);
}
self.session.current_facts = Arc::new(facts);
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
// Checks that `QueryIterator` merges stored facts with session writes: session
// entries override or hide stored facts, and a storage error is passed through.
#[test]
fn test_query_iterator() {
#![allow(clippy::type_complexity)]
// Stored facts in ascending key order, followed by a storage error.
let prior: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
Ok((&[b"a"], b"a0")),
Ok((&[b"c"], b"c0")),
Ok((&[b"d"], b"d0")),
Ok((&[b"f"], b"f0")),
Err(StorageError::IoError),
];
// Session writes: "a" is deleted, "b" is new, "e" and "j" are deletions of keys
// that are not among the stored facts.
let current: Vec<([Box<[u8]>; 1], Option<&[u8]>)> = vec![
([Box::new(*b"a")], None),
([Box::new(*b"b")], Some(b"b1")),
([Box::new(*b"e")], None),
([Box::new(*b"j")], None),
];
// Expected result: "a" is hidden, "b" appears, the deletions yield nothing, and
// the error stays at the end.
let merged: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
Ok((&[b"b"], b"b1")),
Ok((&[b"c"], b"c0")),
Ok((&[b"d"], b"d0")),
Ok((&[b"f"], b"f0")),
Err(StorageError::IoError),
];
let got: Vec<_> = QueryIterator::new(
prior.into_iter().map(|r| {
r.map(|(k, v)| Fact {
key: k.into(),
value: v.into(),
})
}),
current
.into_iter()
.map(|(k, v)| (k.into_iter().collect(), v.map(Box::from))),
)
.collect();
let want: Vec<_> = merged
.into_iter()
.map(|r| {
r.map(|(k, v)| Fact {
key: k.into(),
value: v.into(),
})
})
.collect();
assert_eq!(got, want);
}
}