aranya_runtime/client/
session.rs

1//! Ephemeral sessions for off-graph commands.
2//!
3//! See [`ClientState::session`] and [`Session`].
4//!
5//! Design doc: [Aranya Sessions](https://github.com/aranya-project/aranya-docs/blob/main/src/Aranya-Sessions-note.md)
6
7use alloc::{
8    boxed::Box,
9    collections::{BTreeMap, btree_map},
10    string::String,
11    sync::Arc,
12    vec::Vec,
13};
14use core::{cmp::Ordering, iter::Peekable, marker::PhantomData, mem, ops::Bound};
15
16use buggy::{Bug, BugExt, bug};
17use serde::{Deserialize, Serialize};
18use tracing::warn;
19use yoke::{Yoke, Yokeable};
20
21use crate::{
22    Address, Checkpoint, ClientError, ClientState, CmdId, Command, CommandRecall, Engine, Fact,
23    FactPerspective, GraphId, Keys, NullSink, Perspective, Policy, PolicyId, Prior, Priority,
24    Query, QueryMut, Revertable, Segment, Sink, Storage, StorageError, StorageProvider,
25};
26
/// Owned byte buffer used for fact values held by a [`Session`]
/// (see `fact_log` and `current_facts`).
type Bytes = Box<[u8]>;
28
29/// Ephemeral session used to handle/generate off-graph commands.
30pub struct Session<SP: StorageProvider, E> {
31    /// The ID of the associated storage.
32    storage_id: GraphId,
33    /// The policy ID for the session.
34    policy_id: PolicyId,
35
36    /// The prior facts from the graph head.
37    base_facts: <SP::Storage as Storage>::FactIndex,
38    /// The log of facts in insertion order.
39    fact_log: Vec<(String, Keys, Option<Bytes>)>,
40    /// The current facts of the session, relative to `base_facts`.
41    current_facts: Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>,
42
43    /// Tag for associated engine.
44    _engine: PhantomData<E>,
45
46    head: Address,
47}
48
49struct SessionPerspective<'a, SP: StorageProvider, E, MS> {
50    session: &'a mut Session<SP, E>,
51    message_sink: &'a mut MS,
52}
53
54impl<SP: StorageProvider, E> Session<SP, E> {
55    pub(super) fn new(provider: &mut SP, storage_id: GraphId) -> Result<Self, ClientError> {
56        let storage = provider.get_storage(storage_id)?;
57        let head_loc = storage.get_head()?;
58        let seg = storage.get_segment(head_loc)?;
59        let command = seg.get_command(head_loc).assume("location must exist")?;
60
61        let base_facts = seg.facts()?;
62
63        let result = Self {
64            storage_id,
65            policy_id: seg.policy(),
66            base_facts,
67            fact_log: Vec::new(),
68            current_facts: Arc::default(),
69            _engine: PhantomData,
70            head: command.address()?,
71        };
72
73        Ok(result)
74    }
75}
76
77impl<SP: StorageProvider, E: Engine> Session<SP, E> {
78    /// Evaluate an action on the ephemeral session and generate serialized
79    /// commands, so another client can [`Session::receive`] them.
80    pub fn action<ES, MS>(
81        &mut self,
82        client: &ClientState<E, SP>,
83        effect_sink: &mut ES,
84        message_sink: &mut MS,
85        action: <E::Policy as Policy>::Action<'_>,
86    ) -> Result<(), ClientError>
87    where
88        ES: Sink<E::Effect>,
89        MS: for<'b> Sink<&'b [u8]>,
90    {
91        let policy = client.engine.get_policy(self.policy_id)?;
92
93        // Use a special perspective so we can send to the message sink.
94        let mut perspective = SessionPerspective {
95            session: self,
96            message_sink,
97        };
98        let checkpoint = perspective.checkpoint();
99        effect_sink.begin();
100
101        // Try to perform action.
102        match policy.call_action(action, &mut perspective, effect_sink) {
103            Ok(_) => {
104                // Success, commit effects
105                effect_sink.commit();
106                Ok(())
107            }
108            Err(e) => {
109                // Other error, revert all? See #513.
110                perspective.revert(checkpoint)?;
111                perspective.message_sink.rollback();
112                effect_sink.rollback();
113                Err(e.into())
114            }
115        }
116    }
117
118    /// Handle a command from another client generated by [`Session::action`].
119    ///
120    /// You do NOT need to reprocess the commands from actions generated in the
121    /// same session.
122    pub fn receive(
123        &mut self,
124        client: &ClientState<E, SP>,
125        sink: &mut impl Sink<E::Effect>,
126        command_bytes: &[u8],
127    ) -> Result<(), ClientError> {
128        let command: SessionCommand<'_> =
129            postcard::from_bytes(command_bytes).map_err(ClientError::SessionDeserialize)?;
130
131        if command.storage_id != self.storage_id {
132            warn!(%command.storage_id, %self.storage_id, "ephemeral commands must be run on the same graph");
133            return Err(ClientError::NotAuthorized);
134        }
135
136        let policy = client.engine.get_policy(self.policy_id)?;
137
138        // Use a special perspective which doesn't check the head
139        let mut perspective = SessionPerspective {
140            session: self,
141            message_sink: &mut NullSink,
142        };
143
144        // Try to evaluate command.
145        sink.begin();
146        let checkpoint = perspective.checkpoint();
147        if let Err(e) = policy.call_rule(&command, &mut perspective, sink, CommandRecall::None) {
148            perspective.revert(checkpoint)?;
149            sink.rollback();
150            return Err(e.into());
151        }
152        sink.commit();
153
154        Ok(())
155    }
156}
157
/// Used for serializing session commands
///
/// Produced (postcard-encoded) by `SessionPerspective::add_command` and decoded
/// by [`Session::receive`].
///
/// NOTE(review): postcard is not self-describing, so field order presumably
/// defines the wire format shared with peers; confirm before reordering fields.
#[derive(Serialize, Deserialize)]
struct SessionCommand<'a> {
    /// Graph the command belongs to; `receive` rejects commands for other graphs.
    storage_id: GraphId,
    /// Value carried by `Priority::Basic` (the only priority accepted by `from_cmd`).
    priority: u32, // Priority::Basic
    /// ID of the command.
    id: CmdId,
    /// Address carried by `Prior::Single` (the only parent shape accepted by `from_cmd`).
    parent: Address, // Prior::Single
    /// Command payload; borrowed from the input buffer when deserialized.
    #[serde(borrow)]
    data: &'a [u8],
}
168
169impl Command for SessionCommand<'_> {
170    fn priority(&self) -> Priority {
171        Priority::Basic(self.priority)
172    }
173
174    fn id(&self) -> CmdId {
175        self.id
176    }
177
178    fn parent(&self) -> Prior<Address> {
179        Prior::Single(self.parent)
180    }
181
182    fn policy(&self) -> Option<&[u8]> {
183        // Session commands should never have policy?
184        None
185    }
186
187    fn bytes(&self) -> &[u8] {
188        self.data
189    }
190}
191
192impl<'sc> SessionCommand<'sc> {
193    fn from_cmd(storage_id: GraphId, command: &'sc impl Command) -> Result<Self, Bug> {
194        if command.policy().is_some() {
195            bug!("session command should have no policy")
196        }
197        Ok(SessionCommand {
198            storage_id,
199            priority: match command.priority() {
200                Priority::Basic(p) => p,
201                _ => bug!("wrong command type"),
202            },
203            id: command.id(),
204            parent: match command.parent() {
205                Prior::Single(p) => p,
206                _ => bug!("wrong command type"),
207            },
208            data: command.bytes(),
209        })
210    }
211}
212
/// Merging query iterator used by `SessionPerspective::query_prefix`.
///
/// Interleaves base facts (`prior`) with the session's own changes
/// (`current`) in key order; on equal keys the session entry wins.
struct QueryIterator<I1, I2>
where
    I1: Iterator,
    I2: Iterator,
{
    /// Facts from the session's base fact index.
    prior: Peekable<I1>,
    /// Session-local entries; a `None` value marks a deletion.
    current: Peekable<I2>,
}
218
219impl<I1, I2> QueryIterator<I1, I2>
220where
221    I1: Iterator<Item = Result<Fact, StorageError>>,
222    I2: Iterator<Item = (Keys, Option<Bytes>)>,
223{
224    fn new(prior: I1, current: I2) -> Self {
225        Self {
226            prior: prior.peekable(),
227            current: current.peekable(),
228        }
229    }
230}
231
232impl<I1, I2> Iterator for QueryIterator<I1, I2>
233where
234    I1: Iterator<Item = Result<Fact, StorageError>>,
235    I2: Iterator<Item = (Keys, Option<Bytes>)>,
236{
237    type Item = Result<Fact, StorageError>;
238
239    fn next(&mut self) -> Option<Self::Item> {
240        // We find the next lowest item between the two iterators,
241        // while also ensuring that newer entries overwrite older.
242        // We loop so we can skip over deleted facts.
243
244        loop {
245            let Some(new) = self.current.peek() else {
246                // If current has run out, just use prior.
247                return self.prior.next();
248            };
249            if let Some(old) = self.prior.peek() {
250                let Ok(old) = old else {
251                    // Bubble up errors as soon as possible, instead of returning `new`.
252                    return self.prior.next();
253                };
254                match new.0.cmp(&old.key) {
255                    Ordering::Equal => {
256                        // new overwrites old.
257                        let _ = self.prior.next();
258                    }
259                    Ordering::Greater => {
260                        // old comes next in sorted order.
261                        return self.prior.next();
262                    }
263                    Ordering::Less => {
264                        // new comes next in sorted order.
265                    }
266                }
267            }
268            let Some(slot) = self.current.next() else {
269                bug!("expected Some after peek")
270            };
271            if let (k, Some(v)) = slot {
272                return Some(Ok(Fact {
273                    key: k.iter().cloned().collect(),
274                    value: v,
275                }));
276            }
277        }
278    }
279}
280
281impl<SP, E, MS> FactPerspective for SessionPerspective<'_, SP, E, MS> where SP: StorageProvider {}
282
283impl<SP, E, MS> Query for SessionPerspective<'_, SP, E, MS>
284where
285    SP: StorageProvider,
286{
287    fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
288        if let Some(slot) = self
289            .session
290            .current_facts
291            .get(name)
292            .and_then(|m| m.get(keys))
293        {
294            return Ok(slot.clone());
295        }
296        self.session.base_facts.query(name, keys)
297    }
298
299    type QueryIterator = QueryIterator<
300        <<SP::Storage as Storage>::FactIndex as Query>::QueryIterator,
301        YokeIter<PrefixIter<'static>, Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>>,
302    >;
303    fn query_prefix(
304        &self,
305        name: &str,
306        prefix: &[Box<[u8]>],
307    ) -> Result<Self::QueryIterator, StorageError> {
308        let prior = self.session.base_facts.query_prefix(name, prefix)?;
309        let current = Yoke::<PrefixIter<'static>, _>::attach_to_cart(
310            Arc::clone(&self.session.current_facts),
311            |map| match map.get(name) {
312                Some(facts) => PrefixIter::new(facts, prefix.iter().cloned().collect()),
313                None => PrefixIter::default(),
314            },
315        );
316        Ok(QueryIterator::new(prior, YokeIter::new(current)))
317    }
318}
319
320/// Iterator over matching prefix of a [`BTreeMap`].
321///
322/// Equivalent to `map.range(&prefix..).take_while(move |(k, _)| k.starts_with(prefix))`,
323/// but nameable and [`Yokeable`].
324#[derive(Default, Yokeable)]
325struct PrefixIter<'map> {
326    range: btree_map::Range<'map, Keys, Option<Bytes>>,
327    prefix: Keys,
328}
329
330impl<'map> PrefixIter<'map> {
331    fn new(map: &'map BTreeMap<Keys, Option<Bytes>>, prefix: Keys) -> Self {
332        let range =
333            map.range::<[Box<[u8]>], _>((Bound::Included(prefix.as_ref()), Bound::Unbounded));
334        Self { range, prefix }
335    }
336}
337
338impl Iterator for PrefixIter<'_> {
339    type Item = (Keys, Option<Bytes>);
340
341    fn next(&mut self) -> Option<Self::Item> {
342        self.range
343            .next()
344            .filter(|(k, _)| k.starts_with(&self.prefix))
345            .map(|(k, v)| (k.clone(), v.clone()))
346    }
347}
348
349/// Wrapper around [`Yoke`] which implements [`Iterator`].
350struct YokeIter<I: for<'a> Yokeable<'a>, C>(Option<Yoke<I, C>>);
351
352impl<I: for<'a> Yokeable<'a>, C> YokeIter<I, C> {
353    fn new(yoke: Yoke<I, C>) -> Self {
354        Self(Some(yoke))
355    }
356}
357
358impl<I, C> Iterator for YokeIter<I, C>
359where
360    I: Iterator + for<'a> Yokeable<'a>,
361    for<'a> <I as Yokeable<'a>>::Output: Iterator<Item = I::Item>,
362{
363    type Item = I::Item;
364
365    fn next(&mut self) -> Option<Self::Item> {
366        // `Yoke::map_project` is currently the only way to mutate something in a yoke and get out a value.
367        // It takes the yoke by value though, so we have to `take` it so we can own it temporarily.
368        let mut item = None;
369        self.0 = Some(self.0.take()?.map_project::<I, _>(|mut it, _| {
370            item = it.next();
371            it
372        }));
373        item
374    }
375}
376
377impl<SP: StorageProvider, E, MS> QueryMut for SessionPerspective<'_, SP, E, MS> {
378    fn insert(&mut self, name: String, keys: Keys, value: Box<[u8]>) {
379        self.session
380            .fact_log
381            .push((name.clone(), keys.clone(), Some(value.clone())));
382        Arc::make_mut(&mut self.session.current_facts)
383            .entry(name)
384            .or_default()
385            .insert(keys, Some(value));
386    }
387
388    fn delete(&mut self, name: String, keys: Keys) {
389        self.session
390            .fact_log
391            .push((name.clone(), keys.clone(), None));
392        Arc::make_mut(&mut self.session.current_facts)
393            .entry(name)
394            .or_default()
395            .insert(keys, None);
396    }
397}
398
399impl<SP, E, MS> Perspective for SessionPerspective<'_, SP, E, MS>
400where
401    SP: StorageProvider,
402    MS: for<'b> Sink<&'b [u8]>,
403{
404    fn policy(&self) -> PolicyId {
405        self.session.policy_id
406    }
407
408    fn add_command(&mut self, command: &impl Command) -> Result<usize, StorageError> {
409        let command = SessionCommand::from_cmd(self.session.storage_id, command)?;
410        self.session.head = command.address()?;
411        let bytes = postcard::to_allocvec(&command).assume("serialize session command")?;
412        self.message_sink.consume(&bytes);
413
414        Ok(0)
415    }
416
417    fn includes(&self, _id: CmdId) -> bool {
418        debug_assert!(false, "only used in transactions");
419
420        false
421    }
422
423    fn head_address(&self) -> Result<Prior<Address>, Bug> {
424        Ok(Prior::Single(self.session.head))
425    }
426}
427
428impl<SP, E, MS> Revertable for SessionPerspective<'_, SP, E, MS>
429where
430    SP: StorageProvider,
431{
432    fn checkpoint(&self) -> Checkpoint {
433        Checkpoint {
434            index: self.session.fact_log.len(),
435        }
436    }
437
438    fn revert(&mut self, checkpoint: Checkpoint) -> Result<(), Bug> {
439        if checkpoint.index == self.session.fact_log.len() {
440            return Ok(());
441        }
442
443        if checkpoint.index > self.session.fact_log.len() {
444            bug!(
445                "A checkpoint's index should always be less than or equal to the length of a session's fact log!"
446            );
447        }
448
449        self.session.fact_log.truncate(checkpoint.index);
450        // Create empty map, but reuse allocation if not shared
451        let mut facts =
452            Arc::get_mut(&mut self.session.current_facts).map_or_else(BTreeMap::new, mem::take);
453        facts.clear();
454        for (n, k, v) in self.session.fact_log.iter().cloned() {
455            facts.entry(n).or_default().insert(k, v);
456        }
457        self.session.current_facts = Arc::new(facts);
458
459        Ok(())
460    }
461}
462
#[cfg(test)]
mod test {
    use super::*;

    /// Converts a borrowed `(key parts, value)` pair into an owned [`Fact`].
    fn to_fact((key, value): (&[&[u8]], &[u8])) -> Fact {
        Fact {
            key: key.into(),
            value: value.into(),
        }
    }

    #[test]
    fn test_query_iterator() {
        #![allow(clippy::type_complexity)]

        // Facts already in storage; the trailing error must still surface.
        let prior: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"a"], b"a0")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];
        // Session overlay: `None` deletes, `Some` inserts or overwrites.
        let current: Vec<([Box<[u8]>; 1], Option<&[u8]>)> = vec![
            ([Box::new(*b"a")], None),
            ([Box::new(*b"b")], Some(b"b1")),
            ([Box::new(*b"e")], None),
            ([Box::new(*b"j")], None),
        ];
        // "a" is deleted, "b" is added, and deletions of absent keys are no-ops.
        let expected: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"b"], b"b1")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];

        let merged = QueryIterator::new(
            prior.into_iter().map(|entry| entry.map(to_fact)),
            current
                .into_iter()
                .map(|(keys, value)| (keys.into_iter().collect(), value.map(Box::from))),
        );
        let got: Vec<_> = merged.collect();
        let want: Vec<_> = expected
            .into_iter()
            .map(|entry| entry.map(to_fact))
            .collect();

        assert_eq!(got, want);
    }
}