aranya_runtime/
client.rs

1use alloc::{collections::BinaryHeap, vec::Vec};
2use core::fmt;
3
4use aranya_buggy::{Bug, BugExt};
5use tracing::trace;
6
7use crate::{
8    Command, CommandId, Engine, EngineError, GraphId, Location, PeerCache, Perspective, Policy,
9    Prior, Priority, Segment, Sink, Storage, StorageError, StorageProvider,
10};
11
12mod session;
13mod transaction;
14
15pub use self::{session::Session, transaction::Transaction};
16
/// An error returned by the runtime client.
#[derive(Debug)]
pub enum ClientError {
    /// The referenced parent command does not exist in storage.
    NoSuchParent(CommandId),
    /// The policy engine failed (for any reason other than a failed
    /// check, which is mapped to `NotAuthorized` instead).
    EngineError(EngineError),
    /// The storage layer reported an error.
    StorageError(StorageError),
    /// Graph initialization failed.
    InitError,
    /// A policy check rejected the operation.
    NotAuthorized,
    /// A session message could not be deserialized.
    SessionDeserialize(postcard::Error),
    /// An internal invariant was violated.
    Bug(Bug),
}
28
29impl fmt::Display for ClientError {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            Self::NoSuchParent(id) => write!(f, "no such parent: {id}"),
33            Self::EngineError(e) => write!(f, "engine error: {e}"),
34            Self::StorageError(e) => write!(f, "storage error: {e}"),
35            Self::InitError => write!(f, "init error"),
36            Self::NotAuthorized => write!(f, "not authorized"),
37            Self::SessionDeserialize(e) => write!(f, "session deserialize error: {e}"),
38            Self::Bug(bug) => write!(f, "{bug}"),
39        }
40    }
41}
42
43impl core::error::Error for ClientError {
44    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
45        match self {
46            Self::EngineError(e) => Some(e),
47            Self::StorageError(e) => Some(e),
48            Self::Bug(e) => Some(e),
49            _ => None,
50        }
51    }
52}
53
54impl From<EngineError> for ClientError {
55    fn from(error: EngineError) -> Self {
56        match error {
57            EngineError::Check => Self::NotAuthorized,
58            _ => Self::EngineError(error),
59        }
60    }
61}
62
63impl From<StorageError> for ClientError {
64    fn from(error: StorageError) -> Self {
65        ClientError::StorageError(error)
66    }
67}
68
69impl From<Bug> for ClientError {
70    fn from(error: Bug) -> Self {
71        ClientError::Bug(error)
72    }
73}
74
/// Keeps track of client graph state.
///
/// - `E` should be an implementation of [`Engine`].
/// - `SP` should be an implementation of [`StorageProvider`].
#[derive(Debug)]
pub struct ClientState<E, SP> {
    // Policy engine used to evaluate actions and commands.
    engine: E,
    // Storage backend holding the graph(s); accessible via `provider()`.
    provider: SP,
}
84
85impl<E, SP> ClientState<E, SP> {
86    /// Creates a `ClientState`.
87    pub const fn new(engine: E, provider: SP) -> ClientState<E, SP> {
88        ClientState { engine, provider }
89    }
90
91    /// Provide access to the [`StorageProvider`].
92    pub fn provider(&mut self) -> &mut SP {
93        &mut self.provider
94    }
95}
96
impl<E, SP> ClientState<E, SP>
where
    E: Engine,
    SP: StorageProvider,
{
    /// Create a new graph (AKA Team). This graph will start with the initial policy
    /// provided which must be compatible with the engine E. The `action` is the
    /// initial init action that will bootstrap the graph facts. Effects produced
    /// when processing the action are emitted to the sink.
    pub fn new_graph(
        &mut self,
        policy_data: &[u8],
        action: <E::Policy as Policy>::Action<'_>,
        sink: &mut impl Sink<E::Effect>,
    ) -> Result<GraphId, ClientError> {
        // Register the policy with the engine, then fetch a handle to it.
        let policy_id = self.engine.add_policy(policy_data)?;
        let policy = self.engine.get_policy(policy_id)?;

        // Evaluate the bootstrap action against a fresh perspective; roll the
        // sink back on failure so no partial effects are exposed.
        let mut perspective = self.provider.new_perspective(policy_id);
        sink.begin();
        policy
            .call_action(action, &mut perspective, sink)
            .inspect_err(|_| sink.rollback())?;
        sink.commit();

        // Persist the perspective as the root of a brand-new storage graph.
        let (graph_id, _) = self.provider.new_storage(perspective)?;

        Ok(graph_id)
    }

    /// Commit the [`Transaction`] to storage, after merging all temporary heads.
    pub fn commit(
        &mut self,
        trx: &mut Transaction<SP, E>,
        sink: &mut impl Sink<E::Effect>,
    ) -> Result<(), ClientError> {
        trx.commit(&mut self.provider, &mut self.engine, sink)
    }

    /// Add commands to the transaction, writing the results to
    /// `sink`.
    /// Returns the number of commands that were added.
    pub fn add_commands(
        &mut self,
        trx: &mut Transaction<SP, E>,
        sink: &mut impl Sink<E::Effect>,
        commands: &[impl Command],
        request_heads: &mut PeerCache,
    ) -> Result<usize, ClientError> {
        let count = trx.add_commands(
            commands,
            &mut self.provider,
            &mut self.engine,
            sink,
            request_heads,
        )?;
        Ok(count)
    }

    /// Performs an `action`, writing the results to `sink`.
    ///
    /// The action is evaluated against a perspective taken at the current
    /// graph head; on success the resulting segment is written and committed
    /// as the new head.
    pub fn action(
        &mut self,
        storage_id: GraphId,
        sink: &mut impl Sink<E::Effect>,
        action: <E::Policy as Policy>::Action<'_>,
    ) -> Result<(), ClientError> {
        let storage = self.provider.get_storage(storage_id)?;

        let head = storage.get_head()?;

        let mut perspective = storage
            .get_linear_perspective(head)?
            .assume("can always get perspective at head")?;

        // The perspective carries the policy that governs this graph.
        let policy_id = perspective.policy();
        let policy = self.engine.get_policy(policy_id)?;

        // No need to checkpoint the perspective since it is only for this action.
        // Must checkpoint once we add action transactions.

        sink.begin();
        match policy.call_action(action, &mut perspective, sink) {
            Ok(_) => {
                // NOTE(review): if `write` or `commit` fails here, the sink is
                // left open (neither `commit` nor `rollback` is called) —
                // confirm the `Sink` contract tolerates this.
                let segment = storage.write(perspective)?;
                storage.commit(segment)?;
                sink.commit();
                Ok(())
            }
            Err(e) => {
                // Policy rejected the action: discard any buffered effects.
                sink.rollback();
                Err(e.into())
            }
        }
    }
}
192
193impl<E, SP> ClientState<E, SP>
194where
195    SP: StorageProvider,
196{
197    /// Create a new [`Transaction`], used to receive [`Command`]s when syncing.
198    pub fn transaction(&mut self, storage_id: GraphId) -> Transaction<SP, E> {
199        Transaction::new(storage_id)
200    }
201
202    /// Create an ephemeral [`Session`] associated with this client.
203    pub fn session(&mut self, storage_id: GraphId) -> Result<Session<SP, E>, ClientError> {
204        Session::new(&mut self.provider, storage_id)
205    }
206}
207
/// Returns the last common ancestor of two Locations.
///
/// This walks the graph backwards until the two locations meet. This
/// ensures that you can jump to the last common ancestor from
/// the merge command created using left and right and know that you
/// won't be jumping into a branch.
///
/// Returns the ancestor's [`Location`] together with its max-cut value.
fn last_common_ancestor<S: Storage>(
    storage: &mut S,
    left: Location,
    right: Location,
) -> Result<(Location, usize), ClientError> {
    trace!(%left, %right, "finding least common ancestor");
    let mut left = left;
    let mut right = right;
    while left != right {
        // Re-resolve both cursors each iteration; the segments are needed
        // both for the commands' max cuts and for their prior links.
        let left_seg = storage.get_segment(left)?;
        let left_cmd = left_seg.get_command(left).assume("location must exist")?;
        let right_seg = storage.get_segment(right)?;
        let right_cmd = right_seg.get_command(right).assume("location must exist")?;
        // The command with the lower max cut could be our least common ancestor
        // so we keep following the command with the higher max cut until
        // both sides converge.
        if left_cmd.max_cut()? > right_cmd.max_cut()? {
            left = if let Some(previous) = left.previous() {
                previous
            } else {
                // `left` is the first command of its segment, so step across
                // the segment boundary via the segment's prior.
                match left_seg.prior() {
                    // Root of the graph: nothing prior to follow.
                    Prior::None => left,
                    Prior::Single(s) => s,
                    Prior::Merge(_, _) => {
                        // A merge is expected at the start of its segment.
                        assert!(left.command == 0);
                        if let Some((l, _)) = left_seg.skip_list().last() {
                            // If the storage supports skip lists we return the
                            // last common ancestor of this command.
                            *l
                        } else {
                            // This case will only be hit if the storage doesn't
                            // support skip lists so we can return anything
                            // because it won't be used.
                            return Ok((left, left_cmd.max_cut()?));
                        }
                    }
                }
            };
        } else {
            // Mirror of the branch above, advancing `right` instead.
            right = if let Some(previous) = right.previous() {
                previous
            } else {
                match right_seg.prior() {
                    Prior::None => right,
                    Prior::Single(s) => s,
                    Prior::Merge(_, _) => {
                        assert!(right.command == 0);
                        if let Some((r, _)) = right_seg.skip_list().last() {
                            // If the storage supports skip lists we return the
                            // last common ancestor of this command.
                            *r
                        } else {
                            // This case will only be hit if the storage doesn't
                            // support skip lists so we can return anything
                            // because it won't be used.
                            return Ok((right, right_cmd.max_cut()?));
                        }
                    }
                }
            };
        }
    }
    // Converged: `left == right` is the last common ancestor.
    let left_seg = storage.get_segment(left)?;
    let left_cmd = left_seg.get_command(left).assume("location must exist")?;
    Ok((left, left_cmd.max_cut()?))
}
280
281/// Enforces deterministic ordering for a set of [`Command`]s in a graph.
282/// Returns the ordering.
283pub fn braid<S: Storage>(
284    storage: &mut S,
285    left: Location,
286    right: Location,
287) -> Result<Vec<Location>, ClientError> {
288    struct Strand<S> {
289        key: (Priority, CommandId),
290        next: Location,
291        segment: S,
292    }
293
294    impl<S: Segment> Strand<S> {
295        fn new(
296            storage: &mut impl Storage<Segment = S>,
297            location: Location,
298            cached_segment: Option<S>,
299        ) -> Result<Self, ClientError> {
300            let segment = cached_segment.map_or_else(|| storage.get_segment(location), Ok)?;
301
302            let key = {
303                let cmd = segment
304                    .get_command(location)
305                    .ok_or(StorageError::CommandOutOfBounds(location))?;
306                (cmd.priority(), cmd.id())
307            };
308
309            Ok(Strand {
310                key,
311                next: location,
312                segment,
313            })
314        }
315    }
316
317    impl<S> Eq for Strand<S> {}
318    impl<S> PartialEq for Strand<S> {
319        fn eq(&self, other: &Self) -> bool {
320            self.key == other.key
321        }
322    }
323    impl<S> Ord for Strand<S> {
324        fn cmp(&self, other: &Self) -> core::cmp::Ordering {
325            self.key.cmp(&other.key).reverse()
326        }
327    }
328    impl<S> PartialOrd for Strand<S> {
329        fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
330            Some(self.cmp(other))
331        }
332    }
333
334    let mut braid = Vec::new();
335    let mut strands = BinaryHeap::new();
336
337    trace!(%left, %right, "braiding");
338
339    for head in [left, right] {
340        strands.push(Strand::new(storage, head, None)?);
341    }
342
343    // Get latest command
344    while let Some(strand) = strands.pop() {
345        // Consume another command off the strand
346        let (prior, mut maybe_cached_segment) = if let Some(previous) = strand.next.previous() {
347            (Prior::Single(previous), Some(strand.segment))
348        } else {
349            (strand.segment.prior(), None)
350        };
351        if matches!(prior, Prior::Merge(..)) {
352            trace!("skipping merge command");
353        } else {
354            trace!("adding {}", strand.next);
355            braid.push(strand.next);
356        }
357
358        // Continue processing prior if not accessible from other strands.
359        'location: for location in prior {
360            for other in &strands {
361                trace!("checking {}", other.next);
362                if (location.same_segment(other.next) && location.command <= other.next.command)
363                    || storage.is_ancestor(location, &other.segment)?
364                {
365                    trace!("found ancestor");
366                    continue 'location;
367                }
368            }
369
370            trace!("strand at {location}");
371            strands.push(Strand::new(
372                storage,
373                location,
374                // Taking is OK here because `maybe_cached_segment` is `Some` when
375                // the current strand has a single parent that is in the same segment
376                Option::take(&mut maybe_cached_segment),
377            )?);
378        }
379        if strands.len() == 1 {
380            // No concurrency left, done.
381            let next = strands.pop().assume("strands not empty")?.next;
382            trace!("adding {}", strand.next);
383            braid.push(next);
384            break;
385        }
386    }
387
388    braid.reverse();
389    Ok(braid)
390}