Skip to main content

aranya_runtime/client/
session.rs

1//! Ephemeral sessions for off-graph commands.
2//!
3//! See [`ClientState::session`] and [`Session`].
4//!
5//! Design doc: [Aranya Sessions](https://github.com/aranya-project/aranya-docs/blob/main/src/Aranya-Sessions-note.md)
6
7use alloc::{
8    boxed::Box,
9    collections::{BTreeMap, btree_map},
10    string::String,
11    sync::Arc,
12    vec::Vec,
13};
14use core::{cmp::Ordering, iter::Peekable, marker::PhantomData, mem, ops::Bound};
15
16use buggy::{Bug, BugExt as _, bug};
17use serde::{Deserialize, Serialize};
18use tracing::warn;
19use yoke::{Yoke, Yokeable};
20
21use crate::{
22    Address, Checkpoint, ClientError, ClientState, CmdId, Command, Fact, FactPerspective, GraphId,
23    Keys, NullSink, Perspective, Policy, PolicyId, PolicyStore, Prior, Priority, Query, QueryMut,
24    Revertable, Segment as _, Sink, Storage, StorageError, StorageProvider,
25    policy::{ActionPlacement, CommandPlacement},
26};
27
28type Bytes = Box<[u8]>;
29
30/// Ephemeral session used to handle/generate off-graph commands.
31pub struct Session<SP: StorageProvider, PS> {
32    /// The ID of the associated graph.
33    graph_id: GraphId,
34    /// The policy ID for the session.
35    policy_id: PolicyId,
36
37    /// The prior facts from the graph head.
38    base_facts: <SP::Storage as Storage>::FactIndex,
39    /// The log of facts in insertion order.
40    fact_log: Vec<(String, Keys, Option<Bytes>)>,
41    /// The current facts of the session, relative to `base_facts`.
42    current_facts: Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>,
43
44    /// Tag for associated policy store.
45    _policy_store: PhantomData<PS>,
46
47    head: Address,
48}
49
50struct SessionPerspective<'a, SP: StorageProvider, PS, MS> {
51    session: &'a mut Session<SP, PS>,
52    message_sink: &'a mut MS,
53}
54
55impl<SP: StorageProvider, PS> Session<SP, PS> {
56    pub(super) fn new(provider: &mut SP, graph_id: GraphId) -> Result<Self, ClientError> {
57        let storage = provider.get_storage(graph_id)?;
58        let head_loc = storage.get_head()?;
59        let seg = storage.get_segment(head_loc)?;
60
61        let base_facts = seg.facts()?;
62
63        let result = Self {
64            graph_id,
65            policy_id: seg.policy(),
66            base_facts,
67            fact_log: Vec::new(),
68            current_facts: Arc::default(),
69            _policy_store: PhantomData,
70            head: storage.get_head_address()?,
71        };
72
73        Ok(result)
74    }
75}
76
77impl<SP: StorageProvider, PS: PolicyStore> Session<SP, PS> {
78    /// Evaluate an action on the ephemeral session and generate serialized
79    /// commands, so another client can [`Session::receive`] them.
80    pub fn action<ES, MS>(
81        &mut self,
82        client: &ClientState<PS, SP>,
83        effect_sink: &mut ES,
84        message_sink: &mut MS,
85        action: <PS::Policy as Policy>::Action<'_>,
86    ) -> Result<(), ClientError>
87    where
88        ES: Sink<PS::Effect>,
89        MS: for<'b> Sink<&'b [u8]>,
90    {
91        let policy = client.policy_store.get_policy(self.policy_id)?;
92
93        // Use a special perspective so we can send to the message sink.
94        let mut perspective = SessionPerspective {
95            session: self,
96            message_sink,
97        };
98        let checkpoint = perspective.checkpoint();
99        effect_sink.begin();
100
101        // Try to perform action.
102        match policy.call_action(
103            action,
104            &mut perspective,
105            effect_sink,
106            ActionPlacement::OffGraph,
107        ) {
108            Ok(()) => {
109                // Success, commit effects
110                effect_sink.commit();
111                Ok(())
112            }
113            Err(e) => {
114                // Other error, revert all? See #513.
115                perspective.revert(checkpoint)?;
116                perspective.message_sink.rollback();
117                effect_sink.rollback();
118                Err(e.into())
119            }
120        }
121    }
122
123    /// Handle a command from another client generated by [`Session::action`].
124    ///
125    /// You do NOT need to reprocess the commands from actions generated in the
126    /// same session.
127    pub fn receive(
128        &mut self,
129        client: &ClientState<PS, SP>,
130        sink: &mut impl Sink<PS::Effect>,
131        command_bytes: &[u8],
132    ) -> Result<(), ClientError> {
133        let command: SessionCommand<'_> =
134            postcard::from_bytes(command_bytes).map_err(ClientError::SessionDeserialize)?;
135
136        if command.graph_id != self.graph_id {
137            warn!(%command.graph_id, %self.graph_id, "ephemeral commands must be run on the same graph");
138            return Err(ClientError::NotAuthorized);
139        }
140
141        let policy = client.policy_store.get_policy(self.policy_id)?;
142
143        // Use a special perspective which doesn't check the head
144        let mut perspective = SessionPerspective {
145            session: self,
146            message_sink: &mut NullSink,
147        };
148
149        // Try to evaluate command.
150        sink.begin();
151        let checkpoint = perspective.checkpoint();
152        if let Err(e) =
153            policy.call_rule(&command, &mut perspective, sink, CommandPlacement::OffGraph)
154        {
155            perspective.revert(checkpoint)?;
156            sink.rollback();
157            return Err(e.into());
158        }
159        sink.commit();
160
161        Ok(())
162    }
163}
164
165#[derive(Serialize, Deserialize)]
166/// Used for serializing session commands
167struct SessionCommand<'a> {
168    graph_id: GraphId,
169    priority: u32, // Priority::Basic
170    id: CmdId,
171    parent: Address, // Prior::Single
172    #[serde(borrow)]
173    data: &'a [u8],
174}
175
176impl Command for SessionCommand<'_> {
177    fn priority(&self) -> Priority {
178        Priority::Basic(self.priority)
179    }
180
181    fn id(&self) -> CmdId {
182        self.id
183    }
184
185    fn parent(&self) -> Prior<Address> {
186        Prior::Single(self.parent)
187    }
188
189    fn policy(&self) -> Option<&[u8]> {
190        // Session commands should never have policy?
191        None
192    }
193
194    fn bytes(&self) -> &[u8] {
195        self.data
196    }
197}
198
199impl<'sc> SessionCommand<'sc> {
200    fn from_cmd(graph_id: GraphId, command: &'sc impl Command) -> Result<Self, Bug> {
201        if command.policy().is_some() {
202            bug!("session command should have no policy")
203        }
204        Ok(SessionCommand {
205            graph_id,
206            priority: match command.priority() {
207                Priority::Basic(p) => p,
208                _ => bug!("wrong command type"),
209            },
210            id: command.id(),
211            parent: match command.parent() {
212                Prior::Single(p) => p,
213                _ => bug!("wrong command type"),
214            },
215            data: command.bytes(),
216        })
217    }
218}
219
220/// Query iterator for SessionPerspective which wraps an inner query iterator
221struct QueryIterator<I1: Iterator, I2: Iterator> {
222    prior: Peekable<I1>,
223    current: Peekable<I2>,
224}
225
226impl<I1, I2> QueryIterator<I1, I2>
227where
228    I1: Iterator<Item = Result<Fact, StorageError>>,
229    I2: Iterator<Item = (Keys, Option<Bytes>)>,
230{
231    fn new(prior: I1, current: I2) -> Self {
232        Self {
233            prior: prior.peekable(),
234            current: current.peekable(),
235        }
236    }
237}
238
/// Yields facts from `prior` and `current` merged by key, with `current`
/// taking precedence.
///
/// Merge rules:
/// - When both streams have an entry with the same key, the `prior` entry is
///   discarded and the `current` entry is used instead.
/// - A `current` entry whose value is `None` is a deletion. It is never
///   yielded, and by the previous rule it also hides a `prior` fact with the
///   same key.
/// - An `Err` at the front of `prior` is returned immediately rather than
///   being ordered against `current`.
/// - Once `current` is exhausted, the remaining `prior` items (including
///   errors) pass through unchanged.
///
/// NOTE(review): the merge assumes both inputs yield keys in ascending order.
/// This is not checked here.
impl<I1, I2> Iterator for QueryIterator<I1, I2>
where
    I1: Iterator<Item = Result<Fact, StorageError>>,
    I2: Iterator<Item = (Keys, Option<Bytes>)>,
{
    type Item = Result<Fact, StorageError>;

    fn next(&mut self) -> Option<Self::Item> {
        // We find the next lowest item between the two iterators,
        // while also ensuring that newer entries overwrite older.
        // We loop so we can skip over deleted facts.

        loop {
            let Some(new) = self.current.peek() else {
                // If current has run out, just use prior.
                return self.prior.next();
            };
            if let Some(old) = self.prior.peek() {
                let Ok(old) = old else {
                    // Bubble up errors as soon as possible, instead of returning `new`.
                    return self.prior.next();
                };
                match new.0.cmp(&old.key) {
                    Ordering::Equal => {
                        // new overwrites old.
                        // Drop the base fact; the session entry consumed below
                        // replaces it (or deletes it, if its value is `None`).
                        let _ = self.prior.next();
                    }
                    Ordering::Greater => {
                        // old comes next in sorted order.
                        return self.prior.next();
                    }
                    Ordering::Less => {
                        // new comes next in sorted order.
                    }
                }
            }
            // Consume the session entry that was just peeked.
            let Some(slot) = self.current.next() else {
                bug!("expected Some after peek")
            };
            // Only live entries are yielded. A `None` value is a deletion, so
            // loop around to the next candidate instead.
            if let (k, Some(v)) = slot {
                return Some(Ok(Fact {
                    key: k.iter().cloned().collect(),
                    value: v,
                }));
            }
        }
    }
}
287
288impl<SP, PS, MS> FactPerspective for SessionPerspective<'_, SP, PS, MS> where SP: StorageProvider {}
289
290impl<SP, PS, MS> Query for SessionPerspective<'_, SP, PS, MS>
291where
292    SP: StorageProvider,
293{
294    fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
295        if let Some(slot) = self
296            .session
297            .current_facts
298            .get(name)
299            .and_then(|m| m.get(keys))
300        {
301            return Ok(slot.clone());
302        }
303        self.session.base_facts.query(name, keys)
304    }
305
306    type QueryIterator = QueryIterator<
307        <<SP::Storage as Storage>::FactIndex as Query>::QueryIterator,
308        YokeIter<PrefixIter<'static>, Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>>,
309    >;
310    fn query_prefix(
311        &self,
312        name: &str,
313        prefix: &[Box<[u8]>],
314    ) -> Result<Self::QueryIterator, StorageError> {
315        let prior = self.session.base_facts.query_prefix(name, prefix)?;
316        let current = Yoke::<PrefixIter<'static>, _>::attach_to_cart(
317            Arc::clone(&self.session.current_facts),
318            |map| match map.get(name) {
319                Some(facts) => PrefixIter::new(facts, prefix.iter().cloned().collect()),
320                None => PrefixIter::default(),
321            },
322        );
323        Ok(QueryIterator::new(prior, YokeIter::new(current)))
324    }
325}
326
327/// Iterator over matching prefix of a [`BTreeMap`].
328///
329/// Equivalent to `map.range(&prefix..).take_while(move |(k, _)| k.starts_with(prefix))`,
330/// but nameable and [`Yokeable`].
331#[derive(Default, Yokeable)]
332struct PrefixIter<'map> {
333    range: btree_map::Range<'map, Keys, Option<Bytes>>,
334    prefix: Keys,
335}
336
337impl<'map> PrefixIter<'map> {
338    fn new(map: &'map BTreeMap<Keys, Option<Bytes>>, prefix: Keys) -> Self {
339        let range =
340            map.range::<[Box<[u8]>], _>((Bound::Included(prefix.as_ref()), Bound::Unbounded));
341        Self { range, prefix }
342    }
343}
344
345impl Iterator for PrefixIter<'_> {
346    type Item = (Keys, Option<Bytes>);
347
348    fn next(&mut self) -> Option<Self::Item> {
349        self.range
350            .next()
351            .filter(|(k, _)| k.starts_with(&self.prefix))
352            .map(|(k, v)| (k.clone(), v.clone()))
353    }
354}
355
356/// Wrapper around [`Yoke`] which implements [`Iterator`].
357struct YokeIter<I: for<'a> Yokeable<'a>, C>(Option<Yoke<I, C>>);
358
359impl<I: for<'a> Yokeable<'a>, C> YokeIter<I, C> {
360    fn new(yoke: Yoke<I, C>) -> Self {
361        Self(Some(yoke))
362    }
363}
364
/// Advances the yoked iterator by one step.
///
/// The second bound requires the borrowed form of `I` (its `Yokeable` output)
/// to also be an iterator with the same item type. This lets items be moved
/// out of the projection closure.
impl<I, C> Iterator for YokeIter<I, C>
where
    I: Iterator + for<'a> Yokeable<'a>,
    for<'a> <I as Yokeable<'a>>::Output: Iterator<Item = I::Item>,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        // `Yoke::map_project` is currently the only way to mutate something in a yoke and get out a value.
        // It takes the yoke by value though, so we have to `take` it so we can own it temporarily.
        // If there is no yoke (`self.0` is `None`), the `?` returns `None`.
        let mut item = None;
        // The closure advances the borrowed iterator and stashes the item in
        // `item`. The advanced iterator is re-yoked and stored back in `self.0`.
        self.0 = Some(self.0.take()?.map_project::<I, _>(|mut it, _| {
            item = it.next();
            it
        }));
        item
    }
}
383
384impl<SP: StorageProvider, PS, MS> QueryMut for SessionPerspective<'_, SP, PS, MS> {
385    fn insert(&mut self, name: String, keys: Keys, value: Box<[u8]>) {
386        self.session
387            .fact_log
388            .push((name.clone(), keys.clone(), Some(value.clone())));
389        Arc::make_mut(&mut self.session.current_facts)
390            .entry(name)
391            .or_default()
392            .insert(keys, Some(value));
393    }
394
395    fn delete(&mut self, name: String, keys: Keys) {
396        self.session
397            .fact_log
398            .push((name.clone(), keys.clone(), None));
399        Arc::make_mut(&mut self.session.current_facts)
400            .entry(name)
401            .or_default()
402            .insert(keys, None);
403    }
404}
405
406impl<SP, PS, MS> Perspective for SessionPerspective<'_, SP, PS, MS>
407where
408    SP: StorageProvider,
409    MS: for<'b> Sink<&'b [u8]>,
410{
411    fn policy(&self) -> PolicyId {
412        self.session.policy_id
413    }
414
415    fn add_command(&mut self, command: &impl Command) -> Result<usize, StorageError> {
416        let command = SessionCommand::from_cmd(self.session.graph_id, command)?;
417        self.session.head = command.address()?;
418        let bytes = postcard::to_allocvec(&command).assume("serialize session command")?;
419        self.message_sink.consume(&bytes);
420
421        Ok(0)
422    }
423
424    fn includes(&self, _id: CmdId) -> bool {
425        debug_assert!(false, "only used in transactions");
426
427        false
428    }
429
430    fn head_address(&self) -> Result<Prior<Address>, Bug> {
431        Ok(Prior::Single(self.session.head))
432    }
433}
434
435impl<SP, PS, MS> Revertable for SessionPerspective<'_, SP, PS, MS>
436where
437    SP: StorageProvider,
438{
439    fn checkpoint(&self) -> Checkpoint {
440        Checkpoint {
441            index: self.session.fact_log.len(),
442        }
443    }
444
445    fn revert(&mut self, checkpoint: Checkpoint) -> Result<(), Bug> {
446        if checkpoint.index == self.session.fact_log.len() {
447            return Ok(());
448        }
449
450        if checkpoint.index > self.session.fact_log.len() {
451            bug!(
452                "A checkpoint's index should always be less than or equal to the length of a session's fact log!"
453            );
454        }
455
456        self.session.fact_log.truncate(checkpoint.index);
457        // Create empty map, but reuse allocation if not shared
458        let mut facts =
459            Arc::get_mut(&mut self.session.current_facts).map_or_else(BTreeMap::new, mem::take);
460        facts.clear();
461        for (n, k, v) in self.session.fact_log.iter().cloned() {
462            facts.entry(n).or_default().insert(k, v);
463        }
464        self.session.current_facts = Arc::new(facts);
465
466        Ok(())
467    }
468}
469
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that `QueryIterator` merges base facts (`prior`) with session
    /// entries (`current`): session values win, `None` entries hide or drop
    /// facts, and a base-fact error is still surfaced.
    #[test]
    fn test_query_iterator() {
        #![allow(clippy::type_complexity)]

        // Base facts, in key order, ending with a storage error.
        let prior: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"a"], b"a0")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];
        // Session entries, in key order; `None` marks a deletion.
        // - "a" deletes an existing base fact.
        // - "b" adds a fact that is not in the base.
        // - "e" and "j" delete keys the base does not have, so they yield nothing.
        let current: Vec<([Box<[u8]>; 1], Option<&[u8]>)> = vec![
            ([Box::new(*b"a")], None),
            ([Box::new(*b"b")], Some(b"b1")),
            ([Box::new(*b"e")], None),
            ([Box::new(*b"j")], None),
        ];
        // Expected output: "a" is gone, "b" appears in key order, the other
        // base facts pass through, and the error is still reported.
        let merged: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"b"], b"b1")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];

        // Convert the fixtures into the item types `QueryIterator` consumes.
        let got: Vec<_> = QueryIterator::new(
            prior.into_iter().map(|r| {
                r.map(|(k, v)| Fact {
                    key: k.into(),
                    value: v.into(),
                })
            }),
            current
                .into_iter()
                .map(|(k, v)| (k.into_iter().collect(), v.map(Box::from))),
        )
        .collect();
        let want: Vec<_> = merged
            .into_iter()
            .map(|r| {
                r.map(|(k, v)| Fact {
                    key: k.into(),
                    value: v.into(),
                })
            })
            .collect();

        assert_eq!(got, want);
    }
}