1use alloc::{collections::BinaryHeap, vec::Vec};
2
3use buggy::{Bug, BugExt};
4use tracing::trace;
5
6use crate::{
7 Address, Command, CommandId, Engine, EngineError, GraphId, Location, PeerCache, Perspective,
8 Policy, Prior, Priority, Segment, Sink, Storage, StorageError, StorageProvider,
9};
10
11mod session;
12mod transaction;
13
14pub use self::{session::Session, transaction::Transaction};
15
/// Errors returned by client operations ([`ClientState`], [`Transaction`],
/// [`Session`]).
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
    /// A command referenced a parent that is not present in storage.
    #[error("no such parent: {0}")]
    NoSuchParent(CommandId),
    /// The policy engine failed (any [`EngineError`] other than `Check`;
    /// see the `From<EngineError>` impl below).
    #[error("engine error: {0}")]
    EngineError(EngineError),
    /// The storage layer failed.
    #[error("storage error: {0}")]
    StorageError(#[from] StorageError),
    /// Graph/storage initialization failed.
    #[error("init error")]
    InitError,
    /// A policy check rejected the operation (`EngineError::Check`).
    #[error("not authorized")]
    NotAuthorized,
    /// A serialized session could not be deserialized.
    #[error("session deserialize error: {0}")]
    SessionDeserialize(#[from] postcard::Error),
    /// An internal invariant was violated.
    #[error(transparent)]
    Bug(#[from] Bug),
}
34
35impl From<EngineError> for ClientError {
36 fn from(error: EngineError) -> Self {
37 match error {
38 EngineError::Check => Self::NotAuthorized,
39 _ => Self::EngineError(error),
40 }
41 }
42}
43
/// Top-level client state: pairs a policy engine with a storage provider
/// for the graphs it manages.
#[derive(Debug)]
pub struct ClientState<E, SP> {
    // Policy engine used to register policies and evaluate actions/commands.
    engine: E,
    // Storage provider backing the graphs this client operates on.
    provider: SP,
}
53
54impl<E, SP> ClientState<E, SP> {
55 pub const fn new(engine: E, provider: SP) -> ClientState<E, SP> {
57 ClientState { engine, provider }
58 }
59
60 pub fn provider(&mut self) -> &mut SP {
62 &mut self.provider
63 }
64}
65
impl<E, SP> ClientState<E, SP>
where
    E: Engine,
    SP: StorageProvider,
{
    /// Creates a new graph: registers `policy_data` with the engine, runs
    /// `action` in a fresh perspective under that policy, and persists the
    /// result as new storage, returning the new graph's id.
    ///
    /// Effects produced by the action are emitted through `sink`; on action
    /// failure the sink is rolled back and nothing is persisted.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError`] if the policy cannot be added or fetched, the
    /// action fails, or storage creation fails.
    pub fn new_graph(
        &mut self,
        policy_data: &[u8],
        action: <E::Policy as Policy>::Action<'_>,
        sink: &mut impl Sink<E::Effect>,
    ) -> Result<GraphId, ClientError> {
        let policy_id = self.engine.add_policy(policy_data)?;
        let policy = self.engine.get_policy(policy_id)?;

        let mut perspective = self.provider.new_perspective(policy_id);
        sink.begin();
        policy
            .call_action(action, &mut perspective, sink)
            // Roll the sink back before propagating the action's error.
            .inspect_err(|_| sink.rollback())?;
        sink.commit();

        let (graph_id, _) = self.provider.new_storage(perspective)?;

        Ok(graph_id)
    }

    /// Commits the transaction's accumulated state to storage, emitting any
    /// resulting effects through `sink`.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError`] if [`Transaction::commit`] fails.
    pub fn commit(
        &mut self,
        trx: &mut Transaction<SP, E>,
        sink: &mut impl Sink<E::Effect>,
    ) -> Result<(), ClientError> {
        trx.commit(&mut self.provider, &mut self.engine, sink)?;
        Ok(())
    }

    /// Adds `commands` to the transaction, returning the count reported by
    /// [`Transaction::add_commands`].
    ///
    /// # Errors
    ///
    /// Returns [`ClientError`] if the transaction rejects the commands.
    pub fn add_commands(
        &mut self,
        trx: &mut Transaction<SP, E>,
        sink: &mut impl Sink<E::Effect>,
        commands: &[impl Command],
    ) -> Result<usize, ClientError> {
        let count = trx.add_commands(commands, &mut self.provider, &mut self.engine, sink)?;
        Ok(count)
    }

    /// Records into `request_heads` the storage locations of each address in
    /// `addrs` that is present in graph `storage_id`. Addresses not found in
    /// storage are silently skipped.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError`] if storage lookups or the cache update fail.
    pub fn update_heads(
        &mut self,
        storage_id: GraphId,
        addrs: impl IntoIterator<Item = Address>,
        request_heads: &mut PeerCache,
    ) -> Result<(), ClientError> {
        let storage = self.provider.get_storage(storage_id)?;
        for address in addrs {
            if let Some(loc) = storage.get_location(address)? {
                request_heads.add_command(storage, address, loc)?;
            }
        }
        Ok(())
    }

    /// Runs `action` against the current head of graph `storage_id`.
    ///
    /// On success the resulting perspective is written as a new segment,
    /// committed as the new head, and `sink` is committed. On failure `sink`
    /// is rolled back and nothing is written.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError`] if storage access, policy lookup, the action
    /// itself, or the final write/commit fails.
    pub fn action(
        &mut self,
        storage_id: GraphId,
        sink: &mut impl Sink<E::Effect>,
        action: <E::Policy as Policy>::Action<'_>,
    ) -> Result<(), ClientError> {
        let storage = self.provider.get_storage(storage_id)?;

        let head = storage.get_head()?;

        let mut perspective = storage
            .get_linear_perspective(head)?
            .assume("can always get perspective at head")?;

        // Evaluate the action under the policy recorded in the perspective.
        let policy_id = perspective.policy();
        let policy = self.engine.get_policy(policy_id)?;

        sink.begin();
        match policy.call_action(action, &mut perspective, sink) {
            Ok(_) => {
                let segment = storage.write(perspective)?;
                storage.commit(segment)?;
                sink.commit();
                Ok(())
            }
            Err(e) => {
                sink.rollback();
                Err(e.into())
            }
        }
    }
}
170
171impl<E, SP> ClientState<E, SP>
172where
173 SP: StorageProvider,
174{
175 pub fn transaction(&mut self, storage_id: GraphId) -> Transaction<SP, E> {
177 Transaction::new(storage_id)
178 }
179
180 pub fn session(&mut self, storage_id: GraphId) -> Result<Session<SP, E>, ClientError> {
182 Session::new(&mut self.provider, storage_id)
183 }
184}
185
/// Finds the common ancestor of `left` and `right` by walking both cursors
/// backwards until they meet, returning the ancestor's location and its
/// max-cut value.
///
/// At each step the cursor whose command has the larger max-cut is moved one
/// command back: within its segment when possible, otherwise into the prior
/// segment. For a merge segment the walk follows the last skip-list entry;
/// if the skip list is empty, the current location and max-cut are returned
/// directly.
///
/// NOTE(review): termination relies on `left` and `right` belonging to the
/// same graph (sharing an ancestor) — the `Prior::None => left` arm would
/// otherwise spin; confirm callers guarantee this.
///
/// # Errors
///
/// Returns [`ClientError`] if a segment or command cannot be loaded.
fn last_common_ancestor<S: Storage>(
    storage: &mut S,
    left: Location,
    right: Location,
) -> Result<(Location, usize), ClientError> {
    trace!(%left, %right, "finding least common ancestor");
    let mut left = left;
    let mut right = right;
    while left != right {
        let left_seg = storage.get_segment(left)?;
        let left_cmd = left_seg.get_command(left).assume("location must exist")?;
        let right_seg = storage.get_segment(right)?;
        let right_cmd = right_seg.get_command(right).assume("location must exist")?;
        // Retreat whichever side is "deeper" (larger max-cut) so both
        // cursors converge on the same depth.
        if left_cmd.max_cut()? > right_cmd.max_cut()? {
            left = if let Some(previous) = left.previous() {
                previous
            } else {
                match left_seg.prior() {
                    Prior::None => left,
                    Prior::Single(s) => s,
                    Prior::Merge(_, _) => {
                        // A merge command is always the first in its segment.
                        assert!(left.command == 0);
                        if let Some((l, _)) = left_seg.skip_list().last() {
                            *l
                        } else {
                            return Ok((left, left_cmd.max_cut()?));
                        }
                    }
                }
            };
        } else {
            right = if let Some(previous) = right.previous() {
                previous
            } else {
                match right_seg.prior() {
                    Prior::None => right,
                    Prior::Single(s) => s,
                    Prior::Merge(_, _) => {
                        // A merge command is always the first in its segment.
                        assert!(right.command == 0);
                        if let Some((r, _)) = right_seg.skip_list().last() {
                            *r
                        } else {
                            return Ok((right, right_cmd.max_cut()?));
                        }
                    }
                }
            };
        }
    }
    // Cursors have met; re-fetch the command to report its max-cut.
    let left_seg = storage.get_segment(left)?;
    let left_cmd = left_seg.get_command(left).assume("location must exist")?;
    Ok((left, left_cmd.max_cut()?))
}
258
259pub fn braid<S: Storage>(
262 storage: &mut S,
263 left: Location,
264 right: Location,
265) -> Result<Vec<Location>, ClientError> {
266 struct Strand<S> {
267 key: (Priority, CommandId),
268 next: Location,
269 segment: S,
270 }
271
272 impl<S: Segment> Strand<S> {
273 fn new(
274 storage: &mut impl Storage<Segment = S>,
275 location: Location,
276 cached_segment: Option<S>,
277 ) -> Result<Self, ClientError> {
278 let segment = cached_segment.map_or_else(|| storage.get_segment(location), Ok)?;
279
280 let key = {
281 let cmd = segment
282 .get_command(location)
283 .ok_or(StorageError::CommandOutOfBounds(location))?;
284 (cmd.priority(), cmd.id())
285 };
286
287 Ok(Strand {
288 key,
289 next: location,
290 segment,
291 })
292 }
293 }
294
295 impl<S> Eq for Strand<S> {}
296 impl<S> PartialEq for Strand<S> {
297 fn eq(&self, other: &Self) -> bool {
298 self.key == other.key
299 }
300 }
301 impl<S> Ord for Strand<S> {
302 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
303 self.key.cmp(&other.key).reverse()
304 }
305 }
306 impl<S> PartialOrd for Strand<S> {
307 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
308 Some(self.cmp(other))
309 }
310 }
311
312 let mut braid = Vec::new();
313 let mut strands = BinaryHeap::new();
314
315 trace!(%left, %right, "braiding");
316
317 for head in [left, right] {
318 strands.push(Strand::new(storage, head, None)?);
319 }
320
321 while let Some(strand) = strands.pop() {
323 let (prior, mut maybe_cached_segment) = if let Some(previous) = strand.next.previous() {
325 (Prior::Single(previous), Some(strand.segment))
326 } else {
327 (strand.segment.prior(), None)
328 };
329 if matches!(prior, Prior::Merge(..)) {
330 trace!("skipping merge command");
331 } else {
332 trace!("adding {}", strand.next);
333 braid.push(strand.next);
334 }
335
336 'location: for location in prior {
338 for other in &strands {
339 trace!("checking {}", other.next);
340 if (location.same_segment(other.next) && location.command <= other.next.command)
341 || storage.is_ancestor(location, &other.segment)?
342 {
343 trace!("found ancestor");
344 continue 'location;
345 }
346 }
347
348 trace!("strand at {location}");
349 strands.push(Strand::new(
350 storage,
351 location,
352 Option::take(&mut maybe_cached_segment),
355 )?);
356 }
357 if strands.len() == 1 {
358 let next = strands.pop().assume("strands not empty")?.next;
360 trace!("adding {}", strand.next);
361 braid.push(next);
362 break;
363 }
364 }
365
366 braid.reverse();
367 Ok(braid)
368}