1use alloc::{collections::BinaryHeap, vec::Vec};
2use core::fmt;
3
4use aranya_buggy::{Bug, BugExt};
5use tracing::trace;
6
7use crate::{
8 Command, CommandId, Engine, EngineError, GraphId, Location, PeerCache, Perspective, Policy,
9 Prior, Priority, Segment, Sink, Storage, StorageError, StorageProvider,
10};
11
12mod session;
13mod transaction;
14
15pub use self::{session::Session, transaction::Transaction};
16
17#[derive(Debug)]
19pub enum ClientError {
20 NoSuchParent(CommandId),
21 EngineError(EngineError),
22 StorageError(StorageError),
23 InitError,
24 NotAuthorized,
25 SessionDeserialize(postcard::Error),
26 Bug(Bug),
27}
28
29impl fmt::Display for ClientError {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 Self::NoSuchParent(id) => write!(f, "no such parent: {id}"),
33 Self::EngineError(e) => write!(f, "engine error: {e}"),
34 Self::StorageError(e) => write!(f, "storage error: {e}"),
35 Self::InitError => write!(f, "init error"),
36 Self::NotAuthorized => write!(f, "not authorized"),
37 Self::SessionDeserialize(e) => write!(f, "session deserialize error: {e}"),
38 Self::Bug(bug) => write!(f, "{bug}"),
39 }
40 }
41}
42
43impl core::error::Error for ClientError {
44 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
45 match self {
46 Self::EngineError(e) => Some(e),
47 Self::StorageError(e) => Some(e),
48 Self::Bug(e) => Some(e),
49 _ => None,
50 }
51 }
52}
53
54impl From<EngineError> for ClientError {
55 fn from(error: EngineError) -> Self {
56 match error {
57 EngineError::Check => Self::NotAuthorized,
58 _ => Self::EngineError(error),
59 }
60 }
61}
62
63impl From<StorageError> for ClientError {
64 fn from(error: StorageError) -> Self {
65 ClientError::StorageError(error)
66 }
67}
68
69impl From<Bug> for ClientError {
70 fn from(error: Bug) -> Self {
71 ClientError::Bug(error)
72 }
73}
74
/// Client-side state: a policy engine paired with a storage provider.
#[derive(Debug)]
pub struct ClientState<E, SP> {
    /// Engine used to load policies and run actions.
    engine: E,
    /// Provider from which graph storage is created and retrieved.
    provider: SP,
}
84
85impl<E, SP> ClientState<E, SP> {
86 pub const fn new(engine: E, provider: SP) -> ClientState<E, SP> {
88 ClientState { engine, provider }
89 }
90
91 pub fn provider(&mut self) -> &mut SP {
93 &mut self.provider
94 }
95}
96
97impl<E, SP> ClientState<E, SP>
98where
99 E: Engine,
100 SP: StorageProvider,
101{
102 pub fn new_graph(
107 &mut self,
108 policy_data: &[u8],
109 action: <E::Policy as Policy>::Action<'_>,
110 sink: &mut impl Sink<E::Effect>,
111 ) -> Result<GraphId, ClientError> {
112 let policy_id = self.engine.add_policy(policy_data)?;
113 let policy = self.engine.get_policy(policy_id)?;
114
115 let mut perspective = self.provider.new_perspective(policy_id);
116 sink.begin();
117 policy
118 .call_action(action, &mut perspective, sink)
119 .inspect_err(|_| sink.rollback())?;
120 sink.commit();
121
122 let (graph_id, _) = self.provider.new_storage(perspective)?;
123
124 Ok(graph_id)
125 }
126
127 pub fn commit(
129 &mut self,
130 trx: &mut Transaction<SP, E>,
131 sink: &mut impl Sink<E::Effect>,
132 ) -> Result<(), ClientError> {
133 trx.commit(&mut self.provider, &mut self.engine, sink)
134 }
135
136 pub fn add_commands(
140 &mut self,
141 trx: &mut Transaction<SP, E>,
142 sink: &mut impl Sink<E::Effect>,
143 commands: &[impl Command],
144 request_heads: &mut PeerCache,
145 ) -> Result<usize, ClientError> {
146 let count = trx.add_commands(
147 commands,
148 &mut self.provider,
149 &mut self.engine,
150 sink,
151 request_heads,
152 )?;
153 Ok(count)
154 }
155
156 pub fn action(
158 &mut self,
159 storage_id: GraphId,
160 sink: &mut impl Sink<E::Effect>,
161 action: <E::Policy as Policy>::Action<'_>,
162 ) -> Result<(), ClientError> {
163 let storage = self.provider.get_storage(storage_id)?;
164
165 let head = storage.get_head()?;
166
167 let mut perspective = storage
168 .get_linear_perspective(head)?
169 .assume("can always get perspective at head")?;
170
171 let policy_id = perspective.policy();
172 let policy = self.engine.get_policy(policy_id)?;
173
174 sink.begin();
178 match policy.call_action(action, &mut perspective, sink) {
179 Ok(_) => {
180 let segment = storage.write(perspective)?;
181 storage.commit(segment)?;
182 sink.commit();
183 Ok(())
184 }
185 Err(e) => {
186 sink.rollback();
187 Err(e.into())
188 }
189 }
190 }
191}
192
193impl<E, SP> ClientState<E, SP>
194where
195 SP: StorageProvider,
196{
197 pub fn transaction(&mut self, storage_id: GraphId) -> Transaction<SP, E> {
199 Transaction::new(storage_id)
200 }
201
202 pub fn session(&mut self, storage_id: GraphId) -> Result<Session<SP, E>, ClientError> {
204 Session::new(&mut self.provider, storage_id)
205 }
206}
207
208fn last_common_ancestor<S: Storage>(
215 storage: &mut S,
216 left: Location,
217 right: Location,
218) -> Result<(Location, usize), ClientError> {
219 trace!(%left, %right, "finding least common ancestor");
220 let mut left = left;
221 let mut right = right;
222 while left != right {
223 let left_seg = storage.get_segment(left)?;
224 let left_cmd = left_seg.get_command(left).assume("location must exist")?;
225 let right_seg = storage.get_segment(right)?;
226 let right_cmd = right_seg.get_command(right).assume("location must exist")?;
227 if left_cmd.max_cut()? > right_cmd.max_cut()? {
231 left = if let Some(previous) = left.previous() {
232 previous
233 } else {
234 match left_seg.prior() {
235 Prior::None => left,
236 Prior::Single(s) => s,
237 Prior::Merge(_, _) => {
238 assert!(left.command == 0);
239 if let Some((l, _)) = left_seg.skip_list().last() {
240 *l
243 } else {
244 return Ok((left, left_cmd.max_cut()?));
248 }
249 }
250 }
251 };
252 } else {
253 right = if let Some(previous) = right.previous() {
254 previous
255 } else {
256 match right_seg.prior() {
257 Prior::None => right,
258 Prior::Single(s) => s,
259 Prior::Merge(_, _) => {
260 assert!(right.command == 0);
261 if let Some((r, _)) = right_seg.skip_list().last() {
262 *r
265 } else {
266 return Ok((right, right_cmd.max_cut()?));
270 }
271 }
272 }
273 };
274 }
275 }
276 let left_seg = storage.get_segment(left)?;
277 let left_cmd = left_seg.get_command(left).assume("location must exist")?;
278 Ok((left, left_cmd.max_cut()?))
279}
280
281pub fn braid<S: Storage>(
284 storage: &mut S,
285 left: Location,
286 right: Location,
287) -> Result<Vec<Location>, ClientError> {
288 struct Strand<S> {
289 key: (Priority, CommandId),
290 next: Location,
291 segment: S,
292 }
293
294 impl<S: Segment> Strand<S> {
295 fn new(
296 storage: &mut impl Storage<Segment = S>,
297 location: Location,
298 cached_segment: Option<S>,
299 ) -> Result<Self, ClientError> {
300 let segment = cached_segment.map_or_else(|| storage.get_segment(location), Ok)?;
301
302 let key = {
303 let cmd = segment
304 .get_command(location)
305 .ok_or(StorageError::CommandOutOfBounds(location))?;
306 (cmd.priority(), cmd.id())
307 };
308
309 Ok(Strand {
310 key,
311 next: location,
312 segment,
313 })
314 }
315 }
316
317 impl<S> Eq for Strand<S> {}
318 impl<S> PartialEq for Strand<S> {
319 fn eq(&self, other: &Self) -> bool {
320 self.key == other.key
321 }
322 }
323 impl<S> Ord for Strand<S> {
324 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
325 self.key.cmp(&other.key).reverse()
326 }
327 }
328 impl<S> PartialOrd for Strand<S> {
329 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
330 Some(self.cmp(other))
331 }
332 }
333
334 let mut braid = Vec::new();
335 let mut strands = BinaryHeap::new();
336
337 trace!(%left, %right, "braiding");
338
339 for head in [left, right] {
340 strands.push(Strand::new(storage, head, None)?);
341 }
342
343 while let Some(strand) = strands.pop() {
345 let (prior, mut maybe_cached_segment) = if let Some(previous) = strand.next.previous() {
347 (Prior::Single(previous), Some(strand.segment))
348 } else {
349 (strand.segment.prior(), None)
350 };
351 if matches!(prior, Prior::Merge(..)) {
352 trace!("skipping merge command");
353 } else {
354 trace!("adding {}", strand.next);
355 braid.push(strand.next);
356 }
357
358 'location: for location in prior {
360 for other in &strands {
361 trace!("checking {}", other.next);
362 if (location.same_segment(other.next) && location.command <= other.next.command)
363 || storage.is_ancestor(location, &other.segment)?
364 {
365 trace!("found ancestor");
366 continue 'location;
367 }
368 }
369
370 trace!("strand at {location}");
371 strands.push(Strand::new(
372 storage,
373 location,
374 Option::take(&mut maybe_cached_segment),
377 )?);
378 }
379 if strands.len() == 1 {
380 let next = strands.pop().assume("strands not empty")?.next;
382 trace!("adding {}", strand.next);
383 braid.push(next);
384 break;
385 }
386 }
387
388 braid.reverse();
389 Ok(braid)
390}