aranya_runtime/
client.rs

1use alloc::{collections::BinaryHeap, vec::Vec};
2
3use buggy::{Bug, BugExt};
4use tracing::trace;
5
6use crate::{
7    Address, Command, CommandId, Engine, EngineError, GraphId, Location, PeerCache, Perspective,
8    Policy, Prior, Priority, Segment, Sink, Storage, StorageError, StorageProvider,
9};
10
11mod session;
12mod transaction;
13
14pub use self::{session::Session, transaction::Transaction};
15
16/// An error returned by the runtime client.
17#[derive(Debug, thiserror::Error)]
18pub enum ClientError {
19    #[error("no such parent: {0}")]
20    NoSuchParent(CommandId),
21    #[error("engine error: {0}")]
22    EngineError(EngineError),
23    #[error("storage error: {0}")]
24    StorageError(#[from] StorageError),
25    #[error("init error")]
26    InitError,
27    #[error("not authorized")]
28    NotAuthorized,
29    #[error("session deserialize error: {0}")]
30    SessionDeserialize(#[from] postcard::Error),
31    #[error(transparent)]
32    Bug(#[from] Bug),
33}
34
35impl From<EngineError> for ClientError {
36    fn from(error: EngineError) -> Self {
37        match error {
38            EngineError::Check => Self::NotAuthorized,
39            _ => Self::EngineError(error),
40        }
41    }
42}
43
/// Client-side handle that tracks graph state.
///
/// Type parameters:
/// - `E`: the policy engine; expected to implement [`Engine`].
/// - `SP`: the storage backend; expected to implement [`StorageProvider`].
///
/// The struct itself carries no trait bounds; the `impl` blocks that need them
/// declare them.
#[derive(Debug)]
pub struct ClientState<E, SP> {
    /// Engine used to register/load policies and run actions.
    engine: E,
    /// Provider that owns the storage for each graph.
    provider: SP,
}
53
54impl<E, SP> ClientState<E, SP> {
55    /// Creates a `ClientState`.
56    pub const fn new(engine: E, provider: SP) -> ClientState<E, SP> {
57        ClientState { engine, provider }
58    }
59
60    /// Provide access to the [`StorageProvider`].
61    pub fn provider(&mut self) -> &mut SP {
62        &mut self.provider
63    }
64}
65
66impl<E, SP> ClientState<E, SP>
67where
68    E: Engine,
69    SP: StorageProvider,
70{
71    /// Create a new graph (AKA Team). This graph will start with the initial policy
72    /// provided which must be compatible with the engine E. The `payload` is the initial
73    /// init message that will bootstrap the graph facts. Effects produced when processing
74    /// the payload are emitted to the sink.
75    pub fn new_graph(
76        &mut self,
77        policy_data: &[u8],
78        action: <E::Policy as Policy>::Action<'_>,
79        sink: &mut impl Sink<E::Effect>,
80    ) -> Result<GraphId, ClientError> {
81        let policy_id = self.engine.add_policy(policy_data)?;
82        let policy = self.engine.get_policy(policy_id)?;
83
84        let mut perspective = self.provider.new_perspective(policy_id);
85        sink.begin();
86        policy
87            .call_action(action, &mut perspective, sink)
88            .inspect_err(|_| sink.rollback())?;
89        sink.commit();
90
91        let (graph_id, _) = self.provider.new_storage(perspective)?;
92
93        Ok(graph_id)
94    }
95
96    /// Commit the [`Transaction`] to storage, after merging all temporary heads.
97    pub fn commit(
98        &mut self,
99        trx: &mut Transaction<SP, E>,
100        sink: &mut impl Sink<E::Effect>,
101    ) -> Result<(), ClientError> {
102        trx.commit(&mut self.provider, &mut self.engine, sink)?;
103        Ok(())
104    }
105
106    /// Add commands to the transaction, writing the results to
107    /// `sink`.
108    /// Returns the number of commands that were added.
109    pub fn add_commands(
110        &mut self,
111        trx: &mut Transaction<SP, E>,
112        sink: &mut impl Sink<E::Effect>,
113        commands: &[impl Command],
114    ) -> Result<usize, ClientError> {
115        let count = trx.add_commands(commands, &mut self.provider, &mut self.engine, sink)?;
116        Ok(count)
117    }
118
119    pub fn update_heads(
120        &mut self,
121        storage_id: GraphId,
122        addrs: impl IntoIterator<Item = Address>,
123        request_heads: &mut PeerCache,
124    ) -> Result<(), ClientError> {
125        let storage = self.provider.get_storage(storage_id)?;
126        for address in addrs {
127            if let Some(loc) = storage.get_location(address)? {
128                request_heads.add_command(storage, address, loc)?;
129            }
130        }
131        Ok(())
132    }
133
134    /// Performs an `action`, writing the results to `sink`.
135    pub fn action(
136        &mut self,
137        storage_id: GraphId,
138        sink: &mut impl Sink<E::Effect>,
139        action: <E::Policy as Policy>::Action<'_>,
140    ) -> Result<(), ClientError> {
141        let storage = self.provider.get_storage(storage_id)?;
142
143        let head = storage.get_head()?;
144
145        let mut perspective = storage
146            .get_linear_perspective(head)?
147            .assume("can always get perspective at head")?;
148
149        let policy_id = perspective.policy();
150        let policy = self.engine.get_policy(policy_id)?;
151
152        // No need to checkpoint the perspective since it is only for this action.
153        // Must checkpoint once we add action transactions.
154
155        sink.begin();
156        match policy.call_action(action, &mut perspective, sink) {
157            Ok(_) => {
158                let segment = storage.write(perspective)?;
159                storage.commit(segment)?;
160                sink.commit();
161                Ok(())
162            }
163            Err(e) => {
164                sink.rollback();
165                Err(e.into())
166            }
167        }
168    }
169}
170
171impl<E, SP> ClientState<E, SP>
172where
173    SP: StorageProvider,
174{
175    /// Create a new [`Transaction`], used to receive [`Command`]s when syncing.
176    pub fn transaction(&mut self, storage_id: GraphId) -> Transaction<SP, E> {
177        Transaction::new(storage_id)
178    }
179
180    /// Create an ephemeral [`Session`] associated with this client.
181    pub fn session(&mut self, storage_id: GraphId) -> Result<Session<SP, E>, ClientError> {
182        Session::new(&mut self.provider, storage_id)
183    }
184}
185
186/// Returns the last common ancestor of two Locations.
187///
188/// This walks the graph backwards until the two locations meet. This
189/// ensures that you can jump to the last common ancestor from
190/// the merge command created using left and right and know that you
191/// won't be jumping into a branch.
192fn last_common_ancestor<S: Storage>(
193    storage: &mut S,
194    left: Location,
195    right: Location,
196) -> Result<(Location, usize), ClientError> {
197    trace!(%left, %right, "finding least common ancestor");
198    let mut left = left;
199    let mut right = right;
200    while left != right {
201        let left_seg = storage.get_segment(left)?;
202        let left_cmd = left_seg.get_command(left).assume("location must exist")?;
203        let right_seg = storage.get_segment(right)?;
204        let right_cmd = right_seg.get_command(right).assume("location must exist")?;
205        // The command with the lower max cut could be our least common ancestor
206        // so we keeping following the command with the higher max cut until
207        // both sides converge.
208        if left_cmd.max_cut()? > right_cmd.max_cut()? {
209            left = if let Some(previous) = left.previous() {
210                previous
211            } else {
212                match left_seg.prior() {
213                    Prior::None => left,
214                    Prior::Single(s) => s,
215                    Prior::Merge(_, _) => {
216                        assert!(left.command == 0);
217                        if let Some((l, _)) = left_seg.skip_list().last() {
218                            // If the storage supports skip lists we return the
219                            // last common ancestor of this command.
220                            *l
221                        } else {
222                            // This case will only be hit if the storage doesn't
223                            // support skip lists so we can return anything
224                            // because it won't be used.
225                            return Ok((left, left_cmd.max_cut()?));
226                        }
227                    }
228                }
229            };
230        } else {
231            right = if let Some(previous) = right.previous() {
232                previous
233            } else {
234                match right_seg.prior() {
235                    Prior::None => right,
236                    Prior::Single(s) => s,
237                    Prior::Merge(_, _) => {
238                        assert!(right.command == 0);
239                        if let Some((r, _)) = right_seg.skip_list().last() {
240                            // If the storage supports skip lists we return the
241                            // last common ancestor of this command.
242                            *r
243                        } else {
244                            // This case will only be hit if the storage doesn't
245                            // support skip lists so we can return anything
246                            // because it won't be used.
247                            return Ok((right, right_cmd.max_cut()?));
248                        }
249                    }
250                }
251            };
252        }
253    }
254    let left_seg = storage.get_segment(left)?;
255    let left_cmd = left_seg.get_command(left).assume("location must exist")?;
256    Ok((left, left_cmd.max_cut()?))
257}
258
259/// Enforces deterministic ordering for a set of [`Command`]s in a graph.
260/// Returns the ordering.
261pub fn braid<S: Storage>(
262    storage: &mut S,
263    left: Location,
264    right: Location,
265) -> Result<Vec<Location>, ClientError> {
266    struct Strand<S> {
267        key: (Priority, CommandId),
268        next: Location,
269        segment: S,
270    }
271
272    impl<S: Segment> Strand<S> {
273        fn new(
274            storage: &mut impl Storage<Segment = S>,
275            location: Location,
276            cached_segment: Option<S>,
277        ) -> Result<Self, ClientError> {
278            let segment = cached_segment.map_or_else(|| storage.get_segment(location), Ok)?;
279
280            let key = {
281                let cmd = segment
282                    .get_command(location)
283                    .ok_or(StorageError::CommandOutOfBounds(location))?;
284                (cmd.priority(), cmd.id())
285            };
286
287            Ok(Strand {
288                key,
289                next: location,
290                segment,
291            })
292        }
293    }
294
295    impl<S> Eq for Strand<S> {}
296    impl<S> PartialEq for Strand<S> {
297        fn eq(&self, other: &Self) -> bool {
298            self.key == other.key
299        }
300    }
301    impl<S> Ord for Strand<S> {
302        fn cmp(&self, other: &Self) -> core::cmp::Ordering {
303            self.key.cmp(&other.key).reverse()
304        }
305    }
306    impl<S> PartialOrd for Strand<S> {
307        fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
308            Some(self.cmp(other))
309        }
310    }
311
312    let mut braid = Vec::new();
313    let mut strands = BinaryHeap::new();
314
315    trace!(%left, %right, "braiding");
316
317    for head in [left, right] {
318        strands.push(Strand::new(storage, head, None)?);
319    }
320
321    // Get latest command
322    while let Some(strand) = strands.pop() {
323        // Consume another command off the strand
324        let (prior, mut maybe_cached_segment) = if let Some(previous) = strand.next.previous() {
325            (Prior::Single(previous), Some(strand.segment))
326        } else {
327            (strand.segment.prior(), None)
328        };
329        if matches!(prior, Prior::Merge(..)) {
330            trace!("skipping merge command");
331        } else {
332            trace!("adding {}", strand.next);
333            braid.push(strand.next);
334        }
335
336        // Continue processing prior if not accessible from other strands.
337        'location: for location in prior {
338            for other in &strands {
339                trace!("checking {}", other.next);
340                if (location.same_segment(other.next) && location.command <= other.next.command)
341                    || storage.is_ancestor(location, &other.segment)?
342                {
343                    trace!("found ancestor");
344                    continue 'location;
345                }
346            }
347
348            trace!("strand at {location}");
349            strands.push(Strand::new(
350                storage,
351                location,
352                // Taking is OK here because `maybe_cached_segment` is `Some` when
353                // the current strand has a single parent that is in the same segment
354                Option::take(&mut maybe_cached_segment),
355            )?);
356        }
357        if strands.len() == 1 {
358            // No concurrency left, done.
359            let next = strands.pop().assume("strands not empty")?.next;
360            trace!("adding {}", strand.next);
361            braid.push(next);
362            break;
363        }
364    }
365
366    braid.reverse();
367    Ok(braid)
368}