1use alloc::{
8 boxed::Box,
9 collections::{BTreeMap, btree_map},
10 string::String,
11 sync::Arc,
12 vec::Vec,
13};
14use core::{cmp::Ordering, iter::Peekable, marker::PhantomData, mem, ops::Bound};
15
16use buggy::{Bug, BugExt, bug};
17use serde::{Deserialize, Serialize};
18use tracing::warn;
19use yoke::{Yoke, Yokeable};
20
21use crate::{
22 Address, Checkpoint, ClientError, ClientState, CmdId, Command, CommandRecall, Engine, Fact,
23 FactPerspective, GraphId, Keys, NullSink, Perspective, Policy, PolicyId, Prior, Priority,
24 Query, QueryMut, Revertable, Segment, Sink, Storage, StorageError, StorageProvider,
25};
26
/// Owned byte buffer; the value type stored in the session's in-memory fact overlay.
type Bytes = Box<[u8]>;
28
/// An ephemeral, off-graph session over a graph's facts.
///
/// A session starts from the fact index of the graph's head segment
/// (`base_facts`) and layers its own inserts/deletes on top of it in memory
/// (`fact_log` / `current_facts`). Commands produced while evaluating actions
/// are serialized and handed to a message sink instead of being added to storage.
pub struct Session<SP: StorageProvider, E> {
    /// ID of the graph this session was opened on.
    storage_id: GraphId,
    /// Policy of the head segment; used to evaluate actions and received commands.
    policy_id: PolicyId,

    /// Fact index of the head segment at the time the session was created.
    base_facts: <SP::Storage as Storage>::FactIndex,
    /// Ordered log of fact changes made in this session (`None` = delete).
    /// Replayed to rebuild `current_facts` when reverting to a checkpoint.
    fact_log: Vec<(String, Keys, Option<Bytes>)>,
    /// Overlay of fact changes, keyed by fact name and then by keys
    /// (`None` = deleted). Held in an `Arc` so prefix-query iterators can keep
    /// their own snapshot of it.
    current_facts: Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>,

    _engine: PhantomData<E>,

    /// Address of the most recent command: the graph head when the session was
    /// created, then the last command emitted through `add_command`.
    head: Address,
}
48
/// Perspective used while evaluating policy against a [`Session`].
///
/// Fact reads and writes go through the session's overlay, and each command
/// passed to `add_command` is serialized and forwarded to `message_sink`.
struct SessionPerspective<'a, SP: StorageProvider, E, MS> {
    session: &'a mut Session<SP, E>,
    message_sink: &'a mut MS,
}
53
54impl<SP: StorageProvider, E> Session<SP, E> {
55 pub(super) fn new(provider: &mut SP, storage_id: GraphId) -> Result<Self, ClientError> {
56 let storage = provider.get_storage(storage_id)?;
57 let head_loc = storage.get_head()?;
58 let seg = storage.get_segment(head_loc)?;
59 let command = seg.get_command(head_loc).assume("location must exist")?;
60
61 let base_facts = seg.facts()?;
62
63 let result = Self {
64 storage_id,
65 policy_id: seg.policy(),
66 base_facts,
67 fact_log: Vec::new(),
68 current_facts: Arc::default(),
69 _engine: PhantomData,
70 head: command.address()?,
71 };
72
73 Ok(result)
74 }
75}
76
77impl<SP: StorageProvider, E: Engine> Session<SP, E> {
78 pub fn action<ES, MS>(
81 &mut self,
82 client: &ClientState<E, SP>,
83 effect_sink: &mut ES,
84 message_sink: &mut MS,
85 action: <E::Policy as Policy>::Action<'_>,
86 ) -> Result<(), ClientError>
87 where
88 ES: Sink<E::Effect>,
89 MS: for<'b> Sink<&'b [u8]>,
90 {
91 let policy = client.engine.get_policy(self.policy_id)?;
92
93 let mut perspective = SessionPerspective {
95 session: self,
96 message_sink,
97 };
98 let checkpoint = perspective.checkpoint();
99 effect_sink.begin();
100
101 match policy.call_action(action, &mut perspective, effect_sink) {
103 Ok(_) => {
104 effect_sink.commit();
106 Ok(())
107 }
108 Err(e) => {
109 perspective.revert(checkpoint)?;
111 perspective.message_sink.rollback();
112 effect_sink.rollback();
113 Err(e.into())
114 }
115 }
116 }
117
118 pub fn receive(
123 &mut self,
124 client: &ClientState<E, SP>,
125 sink: &mut impl Sink<E::Effect>,
126 command_bytes: &[u8],
127 ) -> Result<(), ClientError> {
128 let command: SessionCommand<'_> =
129 postcard::from_bytes(command_bytes).map_err(ClientError::SessionDeserialize)?;
130
131 if command.storage_id != self.storage_id {
132 warn!(%command.storage_id, %self.storage_id, "ephemeral commands must be run on the same graph");
133 return Err(ClientError::NotAuthorized);
134 }
135
136 let policy = client.engine.get_policy(self.policy_id)?;
137
138 let mut perspective = SessionPerspective {
140 session: self,
141 message_sink: &mut NullSink,
142 };
143
144 sink.begin();
146 let checkpoint = perspective.checkpoint();
147 if let Err(e) = policy.call_rule(&command, &mut perspective, sink, CommandRecall::None) {
148 perspective.revert(checkpoint)?;
149 sink.rollback();
150 return Err(e.into());
151 }
152 sink.commit();
153
154 Ok(())
155 }
156}
157
/// Wire form of a command produced inside a session.
///
/// Encoded with `postcard` in `add_command` and decoded in [`Session::receive`].
/// postcard encodes struct fields positionally, so field order is part of the
/// format.
#[derive(Serialize, Deserialize)]
struct SessionCommand<'a> {
    /// Graph the command was created on; `receive` rejects mismatches.
    storage_id: GraphId,
    priority: u32,
    id: CmdId,
    parent: Address,
    /// Command payload; borrowed from the input buffer when deserializing.
    #[serde(borrow)]
    data: &'a [u8],
}
168
169impl Command for SessionCommand<'_> {
170 fn priority(&self) -> Priority {
171 Priority::Basic(self.priority)
172 }
173
174 fn id(&self) -> CmdId {
175 self.id
176 }
177
178 fn parent(&self) -> Prior<Address> {
179 Prior::Single(self.parent)
180 }
181
182 fn policy(&self) -> Option<&[u8]> {
183 None
185 }
186
187 fn bytes(&self) -> &[u8] {
188 self.data
189 }
190}
191
192impl<'sc> SessionCommand<'sc> {
193 fn from_cmd(storage_id: GraphId, command: &'sc impl Command) -> Result<Self, Bug> {
194 if command.policy().is_some() {
195 bug!("session command should have no policy")
196 }
197 Ok(SessionCommand {
198 storage_id,
199 priority: match command.priority() {
200 Priority::Basic(p) => p,
201 _ => bug!("wrong command type"),
202 },
203 id: command.id(),
204 parent: match command.parent() {
205 Prior::Single(p) => p,
206 _ => bug!("wrong command type"),
207 },
208 data: command.bytes(),
209 })
210 }
211}
212
/// Merges facts from the base index (`prior`) with session overlay entries (`current`).
///
/// The merge assumes both inputs yield keys in ascending order. Overlay entries
/// shadow prior facts with an equal key, and overlay `None` entries hide them.
struct QueryIterator<I1: Iterator, I2: Iterator> {
    prior: Peekable<I1>,
    current: Peekable<I2>,
}
218
219impl<I1, I2> QueryIterator<I1, I2>
220where
221 I1: Iterator<Item = Result<Fact, StorageError>>,
222 I2: Iterator<Item = (Keys, Option<Bytes>)>,
223{
224 fn new(prior: I1, current: I2) -> Self {
225 Self {
226 prior: prior.peekable(),
227 current: current.peekable(),
228 }
229 }
230}
231
impl<I1, I2> Iterator for QueryIterator<I1, I2>
where
    I1: Iterator<Item = Result<Fact, StorageError>>,
    I2: Iterator<Item = (Keys, Option<Bytes>)>,
{
    type Item = Result<Fact, StorageError>;

    /// Yields the next fact in key order.
    ///
    /// An overlay entry wins over a prior fact with the same key, and overlay
    /// tombstones (`None`) suppress the fact entirely. Errors from `prior` are
    /// yielded as items when they reach the front.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Overlay exhausted: everything left comes from the prior index.
            let Some(new) = self.current.peek() else {
                return self.prior.next();
            };
            if let Some(old) = self.prior.peek() {
                // Pass errors from the prior index straight through.
                let Ok(old) = old else {
                    return self.prior.next();
                };
                match new.0.cmp(&old.key) {
                    Ordering::Equal => {
                        // Same key: the overlay entry replaces the prior fact.
                        let _ = self.prior.next();
                    }
                    Ordering::Greater => {
                        // Prior key sorts first; yield it.
                        return self.prior.next();
                    }
                    Ordering::Less => {
                        // Overlay key sorts first; fall through and yield it.
                    }
                }
            }
            // `peek` above returned `Some`, so this is never `None`.
            let Some(slot) = self.current.next() else {
                bug!("expected Some after peek")
            };
            // Tombstones (`None`) produce nothing; loop for the next entry.
            if let (k, Some(v)) = slot {
                return Some(Ok(Fact {
                    key: k.iter().cloned().collect(),
                    value: v,
                }));
            }
        }
    }
}
280
/// `SessionPerspective` is a `FactPerspective`; this impl overrides no items.
impl<SP, E, MS> FactPerspective for SessionPerspective<'_, SP, E, MS> where SP: StorageProvider {}
282
283impl<SP, E, MS> Query for SessionPerspective<'_, SP, E, MS>
284where
285 SP: StorageProvider,
286{
287 fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
288 if let Some(slot) = self
289 .session
290 .current_facts
291 .get(name)
292 .and_then(|m| m.get(keys))
293 {
294 return Ok(slot.clone());
295 }
296 self.session.base_facts.query(name, keys)
297 }
298
299 type QueryIterator = QueryIterator<
300 <<SP::Storage as Storage>::FactIndex as Query>::QueryIterator,
301 YokeIter<PrefixIter<'static>, Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>>,
302 >;
303 fn query_prefix(
304 &self,
305 name: &str,
306 prefix: &[Box<[u8]>],
307 ) -> Result<Self::QueryIterator, StorageError> {
308 let prior = self.session.base_facts.query_prefix(name, prefix)?;
309 let current = Yoke::<PrefixIter<'static>, _>::attach_to_cart(
310 Arc::clone(&self.session.current_facts),
311 |map| match map.get(name) {
312 Some(facts) => PrefixIter::new(facts, prefix.iter().cloned().collect()),
313 None => PrefixIter::default(),
314 },
315 );
316 Ok(QueryIterator::new(prior, YokeIter::new(current)))
317 }
318}
319
/// Iterator over one fact name's overlay entries whose keys start with `prefix`.
///
/// Derives `Yokeable` so it can borrow from an `Arc`-owned overlay snapshot in
/// `query_prefix`. The `Default` value starts from an empty range and so yields
/// nothing.
#[derive(Default, Yokeable)]
struct PrefixIter<'map> {
    range: btree_map::Range<'map, Keys, Option<Bytes>>,
    prefix: Keys,
}
329
330impl<'map> PrefixIter<'map> {
331 fn new(map: &'map BTreeMap<Keys, Option<Bytes>>, prefix: Keys) -> Self {
332 let range =
333 map.range::<[Box<[u8]>], _>((Bound::Included(prefix.as_ref()), Bound::Unbounded));
334 Self { range, prefix }
335 }
336}
337
338impl Iterator for PrefixIter<'_> {
339 type Item = (Keys, Option<Bytes>);
340
341 fn next(&mut self) -> Option<Self::Item> {
342 self.range
343 .next()
344 .filter(|(k, _)| k.starts_with(&self.prefix))
345 .map(|(k, v)| (k.clone(), v.clone()))
346 }
347}
348
/// Drives a `Yoke` whose yokeable is an iterator as an owning iterator.
///
/// The yoke sits in an `Option` so `next` can take it out, advance it, and put it back.
struct YokeIter<I: for<'a> Yokeable<'a>, C>(Option<Yoke<I, C>>);
351
352impl<I: for<'a> Yokeable<'a>, C> YokeIter<I, C> {
353 fn new(yoke: Yoke<I, C>) -> Self {
354 Self(Some(yoke))
355 }
356}
357
358impl<I, C> Iterator for YokeIter<I, C>
359where
360 I: Iterator + for<'a> Yokeable<'a>,
361 for<'a> <I as Yokeable<'a>>::Output: Iterator<Item = I::Item>,
362{
363 type Item = I::Item;
364
365 fn next(&mut self) -> Option<Self::Item> {
366 let mut item = None;
369 self.0 = Some(self.0.take()?.map_project::<I, _>(|mut it, _| {
370 item = it.next();
371 it
372 }));
373 item
374 }
375}
376
377impl<SP: StorageProvider, E, MS> QueryMut for SessionPerspective<'_, SP, E, MS> {
378 fn insert(&mut self, name: String, keys: Keys, value: Box<[u8]>) {
379 self.session
380 .fact_log
381 .push((name.clone(), keys.clone(), Some(value.clone())));
382 Arc::make_mut(&mut self.session.current_facts)
383 .entry(name)
384 .or_default()
385 .insert(keys, Some(value));
386 }
387
388 fn delete(&mut self, name: String, keys: Keys) {
389 self.session
390 .fact_log
391 .push((name.clone(), keys.clone(), None));
392 Arc::make_mut(&mut self.session.current_facts)
393 .entry(name)
394 .or_default()
395 .insert(keys, None);
396 }
397}
398
399impl<SP, E, MS> Perspective for SessionPerspective<'_, SP, E, MS>
400where
401 SP: StorageProvider,
402 MS: for<'b> Sink<&'b [u8]>,
403{
404 fn policy(&self) -> PolicyId {
405 self.session.policy_id
406 }
407
408 fn add_command(&mut self, command: &impl Command) -> Result<usize, StorageError> {
409 let command = SessionCommand::from_cmd(self.session.storage_id, command)?;
410 self.session.head = command.address()?;
411 let bytes = postcard::to_allocvec(&command).assume("serialize session command")?;
412 self.message_sink.consume(&bytes);
413
414 Ok(0)
415 }
416
417 fn includes(&self, _id: CmdId) -> bool {
418 debug_assert!(false, "only used in transactions");
419
420 false
421 }
422
423 fn head_address(&self) -> Result<Prior<Address>, Bug> {
424 Ok(Prior::Single(self.session.head))
425 }
426}
427
428impl<SP, E, MS> Revertable for SessionPerspective<'_, SP, E, MS>
429where
430 SP: StorageProvider,
431{
432 fn checkpoint(&self) -> Checkpoint {
433 Checkpoint {
434 index: self.session.fact_log.len(),
435 }
436 }
437
438 fn revert(&mut self, checkpoint: Checkpoint) -> Result<(), Bug> {
439 if checkpoint.index == self.session.fact_log.len() {
440 return Ok(());
441 }
442
443 if checkpoint.index > self.session.fact_log.len() {
444 bug!(
445 "A checkpoint's index should always be less than or equal to the length of a session's fact log!"
446 );
447 }
448
449 self.session.fact_log.truncate(checkpoint.index);
450 let mut facts =
452 Arc::get_mut(&mut self.session.current_facts).map_or_else(BTreeMap::new, mem::take);
453 facts.clear();
454 for (n, k, v) in self.session.fact_log.iter().cloned() {
455 facts.entry(n).or_default().insert(k, v);
456 }
457 self.session.current_facts = Arc::new(facts);
458
459 Ok(())
460 }
461}
462
#[cfg(test)]
mod test {
    use super::*;

    /// Checks how `QueryIterator` merges overlay entries over prior facts.
    ///
    /// Overlay values add or replace facts. Overlay `None` entries hide prior
    /// facts with the same key, or produce nothing when there is no such fact.
    /// Prior errors are passed through in key order.
    #[test]
    fn test_query_iterator() {
        #![allow(clippy::type_complexity)]

        // Base-index facts in key order, ending with an error.
        let prior: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"a"], b"a0")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];
        // Overlay: `a` deleted, `b` added, `e` and `j` deleted (no prior fact).
        let current: Vec<([Box<[u8]>; 1], Option<&[u8]>)> = vec![
            ([Box::new(*b"a")], None),
            ([Box::new(*b"b")], Some(b"b1")),
            ([Box::new(*b"e")], None),
            ([Box::new(*b"j")], None),
        ];
        // Expected merged output.
        let merged: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"b"], b"b1")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];

        let got: Vec<_> = QueryIterator::new(
            prior.into_iter().map(|r| {
                r.map(|(k, v)| Fact {
                    key: k.into(),
                    value: v.into(),
                })
            }),
            current
                .into_iter()
                .map(|(k, v)| (k.into_iter().collect(), v.map(Box::from))),
        )
        .collect();
        let want: Vec<_> = merged
            .into_iter()
            .map(|r| {
                r.map(|(k, v)| Fact {
                    key: k.into(),
                    value: v.into(),
                })
            })
            .collect();

        assert_eq!(got, want);
    }
}