1use alloc::{
8 boxed::Box,
9 collections::{BTreeMap, btree_map},
10 string::String,
11 sync::Arc,
12 vec::Vec,
13};
14use core::{cmp::Ordering, iter::Peekable, marker::PhantomData, mem, ops::Bound};
15
16use buggy::{Bug, BugExt as _, bug};
17use serde::{Deserialize, Serialize};
18use tracing::warn;
19use yoke::{Yoke, Yokeable};
20
21use crate::{
22 Address, Checkpoint, ClientError, ClientState, CmdId, Command, Engine, Fact, FactPerspective,
23 GraphId, Keys, NullSink, Perspective, Policy, PolicyId, Prior, Priority, Query, QueryMut,
24 Revertable, Segment as _, Sink, Storage, StorageError, StorageProvider,
25 engine::{ActionPlacement, CommandPlacement},
26};
27
/// Owned, immutable byte buffer used for fact values in this module.
type Bytes = Box<[u8]>;
29
30pub struct Session<SP: StorageProvider, E> {
32 storage_id: GraphId,
34 policy_id: PolicyId,
36
37 base_facts: <SP::Storage as Storage>::FactIndex,
39 fact_log: Vec<(String, Keys, Option<Bytes>)>,
41 current_facts: Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>,
43
44 _engine: PhantomData<E>,
46
47 head: Address,
48}
49
50struct SessionPerspective<'a, SP: StorageProvider, E, MS> {
51 session: &'a mut Session<SP, E>,
52 message_sink: &'a mut MS,
53}
54
55impl<SP: StorageProvider, E> Session<SP, E> {
56 pub(super) fn new(provider: &mut SP, storage_id: GraphId) -> Result<Self, ClientError> {
57 let storage = provider.get_storage(storage_id)?;
58 let head_loc = storage.get_head()?;
59 let seg = storage.get_segment(head_loc)?;
60 let command = seg.get_command(head_loc).assume("location must exist")?;
61
62 let base_facts = seg.facts()?;
63
64 let result = Self {
65 storage_id,
66 policy_id: seg.policy(),
67 base_facts,
68 fact_log: Vec::new(),
69 current_facts: Arc::default(),
70 _engine: PhantomData,
71 head: command.address()?,
72 };
73
74 Ok(result)
75 }
76}
77
78impl<SP: StorageProvider, E: Engine> Session<SP, E> {
79 pub fn action<ES, MS>(
82 &mut self,
83 client: &ClientState<E, SP>,
84 effect_sink: &mut ES,
85 message_sink: &mut MS,
86 action: <E::Policy as Policy>::Action<'_>,
87 ) -> Result<(), ClientError>
88 where
89 ES: Sink<E::Effect>,
90 MS: for<'b> Sink<&'b [u8]>,
91 {
92 let policy = client.engine.get_policy(self.policy_id)?;
93
94 let mut perspective = SessionPerspective {
96 session: self,
97 message_sink,
98 };
99 let checkpoint = perspective.checkpoint();
100 effect_sink.begin();
101
102 match policy.call_action(
104 action,
105 &mut perspective,
106 effect_sink,
107 ActionPlacement::OffGraph,
108 ) {
109 Ok(()) => {
110 effect_sink.commit();
112 Ok(())
113 }
114 Err(e) => {
115 perspective.revert(checkpoint)?;
117 perspective.message_sink.rollback();
118 effect_sink.rollback();
119 Err(e.into())
120 }
121 }
122 }
123
124 pub fn receive(
129 &mut self,
130 client: &ClientState<E, SP>,
131 sink: &mut impl Sink<E::Effect>,
132 command_bytes: &[u8],
133 ) -> Result<(), ClientError> {
134 let command: SessionCommand<'_> =
135 postcard::from_bytes(command_bytes).map_err(ClientError::SessionDeserialize)?;
136
137 if command.storage_id != self.storage_id {
138 warn!(%command.storage_id, %self.storage_id, "ephemeral commands must be run on the same graph");
139 return Err(ClientError::NotAuthorized);
140 }
141
142 let policy = client.engine.get_policy(self.policy_id)?;
143
144 let mut perspective = SessionPerspective {
146 session: self,
147 message_sink: &mut NullSink,
148 };
149
150 sink.begin();
152 let checkpoint = perspective.checkpoint();
153 if let Err(e) =
154 policy.call_rule(&command, &mut perspective, sink, CommandPlacement::OffGraph)
155 {
156 perspective.revert(checkpoint)?;
157 sink.rollback();
158 return Err(e.into());
159 }
160 sink.commit();
161
162 Ok(())
163 }
164}
165
166#[derive(Serialize, Deserialize)]
167struct SessionCommand<'a> {
169 storage_id: GraphId,
170 priority: u32, id: CmdId,
172 parent: Address, #[serde(borrow)]
174 data: &'a [u8],
175}
176
177impl Command for SessionCommand<'_> {
178 fn priority(&self) -> Priority {
179 Priority::Basic(self.priority)
180 }
181
182 fn id(&self) -> CmdId {
183 self.id
184 }
185
186 fn parent(&self) -> Prior<Address> {
187 Prior::Single(self.parent)
188 }
189
190 fn policy(&self) -> Option<&[u8]> {
191 None
193 }
194
195 fn bytes(&self) -> &[u8] {
196 self.data
197 }
198}
199
200impl<'sc> SessionCommand<'sc> {
201 fn from_cmd(storage_id: GraphId, command: &'sc impl Command) -> Result<Self, Bug> {
202 if command.policy().is_some() {
203 bug!("session command should have no policy")
204 }
205 Ok(SessionCommand {
206 storage_id,
207 priority: match command.priority() {
208 Priority::Basic(p) => p,
209 _ => bug!("wrong command type"),
210 },
211 id: command.id(),
212 parent: match command.parent() {
213 Prior::Single(p) => p,
214 _ => bug!("wrong command type"),
215 },
216 data: command.bytes(),
217 })
218 }
219}
220
/// Merges a stream of stored facts with a stream of session changes.
///
/// The merge in `next` compares the front keys of both streams and never
/// looks back, so it relies on both yielding keys in ascending order.
struct QueryIterator<I1, I2>
where
    I1: Iterator,
    I2: Iterator,
{
    /// Facts from the underlying fact index.
    prior: Peekable<I1>,
    /// Session changes; a `None` value marks a deleted key.
    current: Peekable<I2>,
}
226
227impl<I1, I2> QueryIterator<I1, I2>
228where
229 I1: Iterator<Item = Result<Fact, StorageError>>,
230 I2: Iterator<Item = (Keys, Option<Bytes>)>,
231{
232 fn new(prior: I1, current: I2) -> Self {
233 Self {
234 prior: prior.peekable(),
235 current: current.peekable(),
236 }
237 }
238}
239
240impl<I1, I2> Iterator for QueryIterator<I1, I2>
241where
242 I1: Iterator<Item = Result<Fact, StorageError>>,
243 I2: Iterator<Item = (Keys, Option<Bytes>)>,
244{
245 type Item = Result<Fact, StorageError>;
246
247 fn next(&mut self) -> Option<Self::Item> {
248 loop {
253 let Some(new) = self.current.peek() else {
254 return self.prior.next();
256 };
257 if let Some(old) = self.prior.peek() {
258 let Ok(old) = old else {
259 return self.prior.next();
261 };
262 match new.0.cmp(&old.key) {
263 Ordering::Equal => {
264 let _ = self.prior.next();
266 }
267 Ordering::Greater => {
268 return self.prior.next();
270 }
271 Ordering::Less => {
272 }
274 }
275 }
276 let Some(slot) = self.current.next() else {
277 bug!("expected Some after peek")
278 };
279 if let (k, Some(v)) = slot {
280 return Some(Ok(Fact {
281 key: k.iter().cloned().collect(),
282 value: v,
283 }));
284 }
285 }
286 }
287}
288
289impl<SP, E, MS> FactPerspective for SessionPerspective<'_, SP, E, MS> where SP: StorageProvider {}
290
291impl<SP, E, MS> Query for SessionPerspective<'_, SP, E, MS>
292where
293 SP: StorageProvider,
294{
295 fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
296 if let Some(slot) = self
297 .session
298 .current_facts
299 .get(name)
300 .and_then(|m| m.get(keys))
301 {
302 return Ok(slot.clone());
303 }
304 self.session.base_facts.query(name, keys)
305 }
306
307 type QueryIterator = QueryIterator<
308 <<SP::Storage as Storage>::FactIndex as Query>::QueryIterator,
309 YokeIter<PrefixIter<'static>, Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>>,
310 >;
311 fn query_prefix(
312 &self,
313 name: &str,
314 prefix: &[Box<[u8]>],
315 ) -> Result<Self::QueryIterator, StorageError> {
316 let prior = self.session.base_facts.query_prefix(name, prefix)?;
317 let current = Yoke::<PrefixIter<'static>, _>::attach_to_cart(
318 Arc::clone(&self.session.current_facts),
319 |map| match map.get(name) {
320 Some(facts) => PrefixIter::new(facts, prefix.iter().cloned().collect()),
321 None => PrefixIter::default(),
322 },
323 );
324 Ok(QueryIterator::new(prior, YokeIter::new(current)))
325 }
326}
327
328#[derive(Default, Yokeable)]
333struct PrefixIter<'map> {
334 range: btree_map::Range<'map, Keys, Option<Bytes>>,
335 prefix: Keys,
336}
337
338impl<'map> PrefixIter<'map> {
339 fn new(map: &'map BTreeMap<Keys, Option<Bytes>>, prefix: Keys) -> Self {
340 let range =
341 map.range::<[Box<[u8]>], _>((Bound::Included(prefix.as_ref()), Bound::Unbounded));
342 Self { range, prefix }
343 }
344}
345
346impl Iterator for PrefixIter<'_> {
347 type Item = (Keys, Option<Bytes>);
348
349 fn next(&mut self) -> Option<Self::Item> {
350 self.range
351 .next()
352 .filter(|(k, _)| k.starts_with(&self.prefix))
353 .map(|(k, v)| (k.clone(), v.clone()))
354 }
355}
356
357struct YokeIter<I: for<'a> Yokeable<'a>, C>(Option<Yoke<I, C>>);
359
360impl<I: for<'a> Yokeable<'a>, C> YokeIter<I, C> {
361 fn new(yoke: Yoke<I, C>) -> Self {
362 Self(Some(yoke))
363 }
364}
365
366impl<I, C> Iterator for YokeIter<I, C>
367where
368 I: Iterator + for<'a> Yokeable<'a>,
369 for<'a> <I as Yokeable<'a>>::Output: Iterator<Item = I::Item>,
370{
371 type Item = I::Item;
372
373 fn next(&mut self) -> Option<Self::Item> {
374 let mut item = None;
377 self.0 = Some(self.0.take()?.map_project::<I, _>(|mut it, _| {
378 item = it.next();
379 it
380 }));
381 item
382 }
383}
384
385impl<SP: StorageProvider, E, MS> QueryMut for SessionPerspective<'_, SP, E, MS> {
386 fn insert(&mut self, name: String, keys: Keys, value: Box<[u8]>) {
387 self.session
388 .fact_log
389 .push((name.clone(), keys.clone(), Some(value.clone())));
390 Arc::make_mut(&mut self.session.current_facts)
391 .entry(name)
392 .or_default()
393 .insert(keys, Some(value));
394 }
395
396 fn delete(&mut self, name: String, keys: Keys) {
397 self.session
398 .fact_log
399 .push((name.clone(), keys.clone(), None));
400 Arc::make_mut(&mut self.session.current_facts)
401 .entry(name)
402 .or_default()
403 .insert(keys, None);
404 }
405}
406
407impl<SP, E, MS> Perspective for SessionPerspective<'_, SP, E, MS>
408where
409 SP: StorageProvider,
410 MS: for<'b> Sink<&'b [u8]>,
411{
412 fn policy(&self) -> PolicyId {
413 self.session.policy_id
414 }
415
416 fn add_command(&mut self, command: &impl Command) -> Result<usize, StorageError> {
417 let command = SessionCommand::from_cmd(self.session.storage_id, command)?;
418 self.session.head = command.address()?;
419 let bytes = postcard::to_allocvec(&command).assume("serialize session command")?;
420 self.message_sink.consume(&bytes);
421
422 Ok(0)
423 }
424
425 fn includes(&self, _id: CmdId) -> bool {
426 debug_assert!(false, "only used in transactions");
427
428 false
429 }
430
431 fn head_address(&self) -> Result<Prior<Address>, Bug> {
432 Ok(Prior::Single(self.session.head))
433 }
434}
435
436impl<SP, E, MS> Revertable for SessionPerspective<'_, SP, E, MS>
437where
438 SP: StorageProvider,
439{
440 fn checkpoint(&self) -> Checkpoint {
441 Checkpoint {
442 index: self.session.fact_log.len(),
443 }
444 }
445
446 fn revert(&mut self, checkpoint: Checkpoint) -> Result<(), Bug> {
447 if checkpoint.index == self.session.fact_log.len() {
448 return Ok(());
449 }
450
451 if checkpoint.index > self.session.fact_log.len() {
452 bug!(
453 "A checkpoint's index should always be less than or equal to the length of a session's fact log!"
454 );
455 }
456
457 self.session.fact_log.truncate(checkpoint.index);
458 let mut facts =
460 Arc::get_mut(&mut self.session.current_facts).map_or_else(BTreeMap::new, mem::take);
461 facts.clear();
462 for (n, k, v) in self.session.fact_log.iter().cloned() {
463 facts.entry(n).or_default().insert(k, v);
464 }
465 self.session.current_facts = Arc::new(facts);
466
467 Ok(())
468 }
469}
470
#[cfg(test)]
mod test {
    use super::*;

    /// Turns a borrowed `(keys, value)` row into an owned `Fact`, passing
    /// errors through unchanged.
    fn to_fact(row: Result<(&[&[u8]], &[u8]), StorageError>) -> Result<Fact, StorageError> {
        row.map(|(keys, value)| Fact {
            key: keys.into(),
            value: value.into(),
        })
    }

    /// Merging stored facts with session changes must honour overrides and
    /// deletions, keep key order, and pass storage errors through.
    #[test]
    fn test_query_iterator() {
        #![allow(clippy::type_complexity)]

        // Stored facts, ending in an error.
        let prior: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"a"], b"a0")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];
        // Session changes: delete "a", add "b", delete absent "e" and "j".
        let current: Vec<([Box<[u8]>; 1], Option<&[u8]>)> = vec![
            ([Box::new(*b"a")], None),
            ([Box::new(*b"b")], Some(b"b1")),
            ([Box::new(*b"e")], None),
            ([Box::new(*b"j")], None),
        ];
        // Expected merge result.
        let merged: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"b"], b"b1")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];

        let stored = prior.into_iter().map(to_fact);
        let changes = current
            .into_iter()
            .map(|(keys, slot)| (keys.into_iter().collect(), slot.map(Box::from)));

        let got: Vec<_> = QueryIterator::new(stored, changes).collect();
        let want: Vec<_> = merged.into_iter().map(to_fact).collect();

        assert_eq!(got, want);
    }
}