aranya_runtime/client/
session.rs

//! Ephemeral sessions for off-graph commands.
//!
//! See [`ClientState::session`] and [`Session`].
//!
//! Design doc: [Aranya Sessions](https://github.com/aranya-project/aranya-docs/blob/main/src/Aranya-Sessions-note.md)

7use alloc::{
8    boxed::Box,
9    collections::{BTreeMap, btree_map},
10    string::String,
11    sync::Arc,
12    vec::Vec,
13};
14use core::{cmp::Ordering, iter::Peekable, marker::PhantomData, mem, ops::Bound};
15
16use buggy::{Bug, BugExt as _, bug};
17use serde::{Deserialize, Serialize};
18use tracing::warn;
19use yoke::{Yoke, Yokeable};
20
21use crate::{
22    Address, Checkpoint, ClientError, ClientState, CmdId, Command, Engine, Fact, FactPerspective,
23    GraphId, Keys, NullSink, Perspective, Policy, PolicyId, Prior, Priority, Query, QueryMut,
24    Revertable, Segment as _, Sink, Storage, StorageError, StorageProvider,
25    engine::{ActionPlacement, CommandPlacement},
26};
27
28type Bytes = Box<[u8]>;
29
30/// Ephemeral session used to handle/generate off-graph commands.
31pub struct Session<SP: StorageProvider, E> {
32    /// The ID of the associated storage.
33    storage_id: GraphId,
34    /// The policy ID for the session.
35    policy_id: PolicyId,
36
37    /// The prior facts from the graph head.
38    base_facts: <SP::Storage as Storage>::FactIndex,
39    /// The log of facts in insertion order.
40    fact_log: Vec<(String, Keys, Option<Bytes>)>,
41    /// The current facts of the session, relative to `base_facts`.
42    current_facts: Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>,
43
44    /// Tag for associated engine.
45    _engine: PhantomData<E>,
46
47    head: Address,
48}
49
50struct SessionPerspective<'a, SP: StorageProvider, E, MS> {
51    session: &'a mut Session<SP, E>,
52    message_sink: &'a mut MS,
53}
54
55impl<SP: StorageProvider, E> Session<SP, E> {
56    pub(super) fn new(provider: &mut SP, storage_id: GraphId) -> Result<Self, ClientError> {
57        let storage = provider.get_storage(storage_id)?;
58        let head_loc = storage.get_head()?;
59        let seg = storage.get_segment(head_loc)?;
60        let command = seg.get_command(head_loc).assume("location must exist")?;
61
62        let base_facts = seg.facts()?;
63
64        let result = Self {
65            storage_id,
66            policy_id: seg.policy(),
67            base_facts,
68            fact_log: Vec::new(),
69            current_facts: Arc::default(),
70            _engine: PhantomData,
71            head: command.address()?,
72        };
73
74        Ok(result)
75    }
76}
77
78impl<SP: StorageProvider, E: Engine> Session<SP, E> {
79    /// Evaluate an action on the ephemeral session and generate serialized
80    /// commands, so another client can [`Session::receive`] them.
81    pub fn action<ES, MS>(
82        &mut self,
83        client: &ClientState<E, SP>,
84        effect_sink: &mut ES,
85        message_sink: &mut MS,
86        action: <E::Policy as Policy>::Action<'_>,
87    ) -> Result<(), ClientError>
88    where
89        ES: Sink<E::Effect>,
90        MS: for<'b> Sink<&'b [u8]>,
91    {
92        let policy = client.engine.get_policy(self.policy_id)?;
93
94        // Use a special perspective so we can send to the message sink.
95        let mut perspective = SessionPerspective {
96            session: self,
97            message_sink,
98        };
99        let checkpoint = perspective.checkpoint();
100        effect_sink.begin();
101
102        // Try to perform action.
103        match policy.call_action(
104            action,
105            &mut perspective,
106            effect_sink,
107            ActionPlacement::OffGraph,
108        ) {
109            Ok(()) => {
110                // Success, commit effects
111                effect_sink.commit();
112                Ok(())
113            }
114            Err(e) => {
115                // Other error, revert all? See #513.
116                perspective.revert(checkpoint)?;
117                perspective.message_sink.rollback();
118                effect_sink.rollback();
119                Err(e.into())
120            }
121        }
122    }
123
124    /// Handle a command from another client generated by [`Session::action`].
125    ///
126    /// You do NOT need to reprocess the commands from actions generated in the
127    /// same session.
128    pub fn receive(
129        &mut self,
130        client: &ClientState<E, SP>,
131        sink: &mut impl Sink<E::Effect>,
132        command_bytes: &[u8],
133    ) -> Result<(), ClientError> {
134        let command: SessionCommand<'_> =
135            postcard::from_bytes(command_bytes).map_err(ClientError::SessionDeserialize)?;
136
137        if command.storage_id != self.storage_id {
138            warn!(%command.storage_id, %self.storage_id, "ephemeral commands must be run on the same graph");
139            return Err(ClientError::NotAuthorized);
140        }
141
142        let policy = client.engine.get_policy(self.policy_id)?;
143
144        // Use a special perspective which doesn't check the head
145        let mut perspective = SessionPerspective {
146            session: self,
147            message_sink: &mut NullSink,
148        };
149
150        // Try to evaluate command.
151        sink.begin();
152        let checkpoint = perspective.checkpoint();
153        if let Err(e) =
154            policy.call_rule(&command, &mut perspective, sink, CommandPlacement::OffGraph)
155        {
156            perspective.revert(checkpoint)?;
157            sink.rollback();
158            return Err(e.into());
159        }
160        sink.commit();
161
162        Ok(())
163    }
164}
165
166#[derive(Serialize, Deserialize)]
167/// Used for serializing session commands
168struct SessionCommand<'a> {
169    storage_id: GraphId,
170    priority: u32, // Priority::Basic
171    id: CmdId,
172    parent: Address, // Prior::Single
173    #[serde(borrow)]
174    data: &'a [u8],
175}
176
177impl Command for SessionCommand<'_> {
178    fn priority(&self) -> Priority {
179        Priority::Basic(self.priority)
180    }
181
182    fn id(&self) -> CmdId {
183        self.id
184    }
185
186    fn parent(&self) -> Prior<Address> {
187        Prior::Single(self.parent)
188    }
189
190    fn policy(&self) -> Option<&[u8]> {
191        // Session commands should never have policy?
192        None
193    }
194
195    fn bytes(&self) -> &[u8] {
196        self.data
197    }
198}
199
200impl<'sc> SessionCommand<'sc> {
201    fn from_cmd(storage_id: GraphId, command: &'sc impl Command) -> Result<Self, Bug> {
202        if command.policy().is_some() {
203            bug!("session command should have no policy")
204        }
205        Ok(SessionCommand {
206            storage_id,
207            priority: match command.priority() {
208                Priority::Basic(p) => p,
209                _ => bug!("wrong command type"),
210            },
211            id: command.id(),
212            parent: match command.parent() {
213                Prior::Single(p) => p,
214                _ => bug!("wrong command type"),
215            },
216            data: command.bytes(),
217        })
218    }
219}
220
221/// Query iterator for SessionPerspective which wraps an inner query iterator
222struct QueryIterator<I1: Iterator, I2: Iterator> {
223    prior: Peekable<I1>,
224    current: Peekable<I2>,
225}
226
227impl<I1, I2> QueryIterator<I1, I2>
228where
229    I1: Iterator<Item = Result<Fact, StorageError>>,
230    I2: Iterator<Item = (Keys, Option<Bytes>)>,
231{
232    fn new(prior: I1, current: I2) -> Self {
233        Self {
234            prior: prior.peekable(),
235            current: current.peekable(),
236        }
237    }
238}
239
240impl<I1, I2> Iterator for QueryIterator<I1, I2>
241where
242    I1: Iterator<Item = Result<Fact, StorageError>>,
243    I2: Iterator<Item = (Keys, Option<Bytes>)>,
244{
245    type Item = Result<Fact, StorageError>;
246
247    fn next(&mut self) -> Option<Self::Item> {
248        // We find the next lowest item between the two iterators,
249        // while also ensuring that newer entries overwrite older.
250        // We loop so we can skip over deleted facts.
251
252        loop {
253            let Some(new) = self.current.peek() else {
254                // If current has run out, just use prior.
255                return self.prior.next();
256            };
257            if let Some(old) = self.prior.peek() {
258                let Ok(old) = old else {
259                    // Bubble up errors as soon as possible, instead of returning `new`.
260                    return self.prior.next();
261                };
262                match new.0.cmp(&old.key) {
263                    Ordering::Equal => {
264                        // new overwrites old.
265                        let _ = self.prior.next();
266                    }
267                    Ordering::Greater => {
268                        // old comes next in sorted order.
269                        return self.prior.next();
270                    }
271                    Ordering::Less => {
272                        // new comes next in sorted order.
273                    }
274                }
275            }
276            let Some(slot) = self.current.next() else {
277                bug!("expected Some after peek")
278            };
279            if let (k, Some(v)) = slot {
280                return Some(Ok(Fact {
281                    key: k.iter().cloned().collect(),
282                    value: v,
283                }));
284            }
285        }
286    }
287}
288
289impl<SP, E, MS> FactPerspective for SessionPerspective<'_, SP, E, MS> where SP: StorageProvider {}
290
291impl<SP, E, MS> Query for SessionPerspective<'_, SP, E, MS>
292where
293    SP: StorageProvider,
294{
295    fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
296        if let Some(slot) = self
297            .session
298            .current_facts
299            .get(name)
300            .and_then(|m| m.get(keys))
301        {
302            return Ok(slot.clone());
303        }
304        self.session.base_facts.query(name, keys)
305    }
306
307    type QueryIterator = QueryIterator<
308        <<SP::Storage as Storage>::FactIndex as Query>::QueryIterator,
309        YokeIter<PrefixIter<'static>, Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>>,
310    >;
311    fn query_prefix(
312        &self,
313        name: &str,
314        prefix: &[Box<[u8]>],
315    ) -> Result<Self::QueryIterator, StorageError> {
316        let prior = self.session.base_facts.query_prefix(name, prefix)?;
317        let current = Yoke::<PrefixIter<'static>, _>::attach_to_cart(
318            Arc::clone(&self.session.current_facts),
319            |map| match map.get(name) {
320                Some(facts) => PrefixIter::new(facts, prefix.iter().cloned().collect()),
321                None => PrefixIter::default(),
322            },
323        );
324        Ok(QueryIterator::new(prior, YokeIter::new(current)))
325    }
326}
327
328/// Iterator over matching prefix of a [`BTreeMap`].
329///
330/// Equivalent to `map.range(&prefix..).take_while(move |(k, _)| k.starts_with(prefix))`,
331/// but nameable and [`Yokeable`].
332#[derive(Default, Yokeable)]
333struct PrefixIter<'map> {
334    range: btree_map::Range<'map, Keys, Option<Bytes>>,
335    prefix: Keys,
336}
337
338impl<'map> PrefixIter<'map> {
339    fn new(map: &'map BTreeMap<Keys, Option<Bytes>>, prefix: Keys) -> Self {
340        let range =
341            map.range::<[Box<[u8]>], _>((Bound::Included(prefix.as_ref()), Bound::Unbounded));
342        Self { range, prefix }
343    }
344}
345
346impl Iterator for PrefixIter<'_> {
347    type Item = (Keys, Option<Bytes>);
348
349    fn next(&mut self) -> Option<Self::Item> {
350        self.range
351            .next()
352            .filter(|(k, _)| k.starts_with(&self.prefix))
353            .map(|(k, v)| (k.clone(), v.clone()))
354    }
355}
356
357/// Wrapper around [`Yoke`] which implements [`Iterator`].
358struct YokeIter<I: for<'a> Yokeable<'a>, C>(Option<Yoke<I, C>>);
359
360impl<I: for<'a> Yokeable<'a>, C> YokeIter<I, C> {
361    fn new(yoke: Yoke<I, C>) -> Self {
362        Self(Some(yoke))
363    }
364}
365
366impl<I, C> Iterator for YokeIter<I, C>
367where
368    I: Iterator + for<'a> Yokeable<'a>,
369    for<'a> <I as Yokeable<'a>>::Output: Iterator<Item = I::Item>,
370{
371    type Item = I::Item;
372
373    fn next(&mut self) -> Option<Self::Item> {
374        // `Yoke::map_project` is currently the only way to mutate something in a yoke and get out a value.
375        // It takes the yoke by value though, so we have to `take` it so we can own it temporarily.
376        let mut item = None;
377        self.0 = Some(self.0.take()?.map_project::<I, _>(|mut it, _| {
378            item = it.next();
379            it
380        }));
381        item
382    }
383}
384
385impl<SP: StorageProvider, E, MS> QueryMut for SessionPerspective<'_, SP, E, MS> {
386    fn insert(&mut self, name: String, keys: Keys, value: Box<[u8]>) {
387        self.session
388            .fact_log
389            .push((name.clone(), keys.clone(), Some(value.clone())));
390        Arc::make_mut(&mut self.session.current_facts)
391            .entry(name)
392            .or_default()
393            .insert(keys, Some(value));
394    }
395
396    fn delete(&mut self, name: String, keys: Keys) {
397        self.session
398            .fact_log
399            .push((name.clone(), keys.clone(), None));
400        Arc::make_mut(&mut self.session.current_facts)
401            .entry(name)
402            .or_default()
403            .insert(keys, None);
404    }
405}
406
407impl<SP, E, MS> Perspective for SessionPerspective<'_, SP, E, MS>
408where
409    SP: StorageProvider,
410    MS: for<'b> Sink<&'b [u8]>,
411{
412    fn policy(&self) -> PolicyId {
413        self.session.policy_id
414    }
415
416    fn add_command(&mut self, command: &impl Command) -> Result<usize, StorageError> {
417        let command = SessionCommand::from_cmd(self.session.storage_id, command)?;
418        self.session.head = command.address()?;
419        let bytes = postcard::to_allocvec(&command).assume("serialize session command")?;
420        self.message_sink.consume(&bytes);
421
422        Ok(0)
423    }
424
425    fn includes(&self, _id: CmdId) -> bool {
426        debug_assert!(false, "only used in transactions");
427
428        false
429    }
430
431    fn head_address(&self) -> Result<Prior<Address>, Bug> {
432        Ok(Prior::Single(self.session.head))
433    }
434}
435
436impl<SP, E, MS> Revertable for SessionPerspective<'_, SP, E, MS>
437where
438    SP: StorageProvider,
439{
440    fn checkpoint(&self) -> Checkpoint {
441        Checkpoint {
442            index: self.session.fact_log.len(),
443        }
444    }
445
446    fn revert(&mut self, checkpoint: Checkpoint) -> Result<(), Bug> {
447        if checkpoint.index == self.session.fact_log.len() {
448            return Ok(());
449        }
450
451        if checkpoint.index > self.session.fact_log.len() {
452            bug!(
453                "A checkpoint's index should always be less than or equal to the length of a session's fact log!"
454            );
455        }
456
457        self.session.fact_log.truncate(checkpoint.index);
458        // Create empty map, but reuse allocation if not shared
459        let mut facts =
460            Arc::get_mut(&mut self.session.current_facts).map_or_else(BTreeMap::new, mem::take);
461        facts.clear();
462        for (n, k, v) in self.session.fact_log.iter().cloned() {
463            facts.entry(n).or_default().insert(k, v);
464        }
465        self.session.current_facts = Arc::new(facts);
466
467        Ok(())
468    }
469}
470
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a [`Fact`] from borrowed key segments and a value.
    fn fact(key: &[&[u8]], value: &[u8]) -> Fact {
        Fact {
            key: key.into(),
            value: value.into(),
        }
    }

    /// Builds a session entry with a single key segment; `None` marks a deletion.
    fn overlay(key: &[u8], value: Option<&[u8]>) -> (Keys, Option<Bytes>) {
        let keys: Keys = core::iter::once(Box::<[u8]>::from(key)).collect();
        (keys, value.map(Box::from))
    }

    #[test]
    fn test_query_iterator() {
        // Base facts, ending with a storage error.
        let prior = vec![
            Ok(fact(&[b"a"], b"a0")),
            Ok(fact(&[b"c"], b"c0")),
            Ok(fact(&[b"d"], b"d0")),
            Ok(fact(&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];
        // Session entries: delete "a", "e" and "j"; add "b".
        let current = vec![
            overlay(b"a", None),
            overlay(b"b", Some(b"b1")),
            overlay(b"e", None),
            overlay(b"j", None),
        ];
        let want = vec![
            Ok(fact(&[b"b"], b"b1")),
            Ok(fact(&[b"c"], b"c0")),
            Ok(fact(&[b"d"], b"d0")),
            Ok(fact(&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];

        let got: Vec<_> = QueryIterator::new(prior.into_iter(), current.into_iter()).collect();

        assert_eq!(got, want);
    }
}