1use alloc::{
8 boxed::Box,
9 collections::{BTreeMap, btree_map},
10 string::String,
11 sync::Arc,
12 vec::Vec,
13};
14use core::{cmp::Ordering, iter::Peekable, marker::PhantomData, mem, ops::Bound};
15
16use buggy::{Bug, BugExt as _, bug};
17use serde::{Deserialize, Serialize};
18use tracing::warn;
19use yoke::{Yoke, Yokeable};
20
21use crate::{
22 Address, Checkpoint, ClientError, ClientState, CmdId, Command, Fact, FactPerspective, GraphId,
23 Keys, NullSink, Perspective, Policy, PolicyId, PolicyStore, Prior, Priority, Query, QueryMut,
24 Revertable, Segment as _, Sink, Storage, StorageError, StorageProvider,
25 policy::{ActionPlacement, CommandPlacement},
26};
27
/// Owned byte buffer used for fact values stored in the session overlay.
type Bytes = Box<[u8]>;
29
/// An ephemeral session layered on top of a graph's head state.
///
/// Facts written during the session are held in memory by this type
/// (`fact_log` / `current_facts`); reads consult that in-memory overlay first
/// and fall back to `base_facts`.
pub struct Session<SP: StorageProvider, PS> {
    /// The graph this session was opened on; `receive` rejects commands whose
    /// `graph_id` differs.
    graph_id: GraphId,
    /// Policy of the head segment at the time the session was created.
    policy_id: PolicyId,

    /// Fact index of the head segment at session creation (read-only base).
    base_facts: <SP::Storage as Storage>::FactIndex,
    /// Ordered log of every fact write made in this session; a `None` value
    /// records a delete. Replayed to rebuild `current_facts` on revert.
    fact_log: Vec<(String, Keys, Option<Bytes>)>,
    /// Materialized overlay: fact name -> keys -> value (`None` = deleted).
    /// Wrapped in `Arc` so prefix-query iterators can keep a snapshot while
    /// writers go through `Arc::make_mut`.
    current_facts: Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>,

    _policy_store: PhantomData<PS>,

    /// Address of the most recent session command. Starts at the storage head
    /// address and is advanced by `add_command`.
    head: Address,
}
49
/// A `Perspective` over a [`Session`] used while evaluating actions/commands.
///
/// Fact reads and writes go to the session's overlay. Commands produced by
/// `add_command` are serialized and handed to `message_sink`.
struct SessionPerspective<'a, SP: StorageProvider, PS, MS> {
    /// The session whose fact overlay and head are read and updated.
    session: &'a mut Session<SP, PS>,
    /// Receives serialized session commands emitted during evaluation.
    message_sink: &'a mut MS,
}
54
55impl<SP: StorageProvider, PS> Session<SP, PS> {
56 pub(super) fn new(provider: &mut SP, graph_id: GraphId) -> Result<Self, ClientError> {
57 let storage = provider.get_storage(graph_id)?;
58 let head_loc = storage.get_head()?;
59 let seg = storage.get_segment(head_loc)?;
60
61 let base_facts = seg.facts()?;
62
63 let result = Self {
64 graph_id,
65 policy_id: seg.policy(),
66 base_facts,
67 fact_log: Vec::new(),
68 current_facts: Arc::default(),
69 _policy_store: PhantomData,
70 head: storage.get_head_address()?,
71 };
72
73 Ok(result)
74 }
75}
76
77impl<SP: StorageProvider, PS: PolicyStore> Session<SP, PS> {
78 pub fn action<ES, MS>(
81 &mut self,
82 client: &ClientState<PS, SP>,
83 effect_sink: &mut ES,
84 message_sink: &mut MS,
85 action: <PS::Policy as Policy>::Action<'_>,
86 ) -> Result<(), ClientError>
87 where
88 ES: Sink<PS::Effect>,
89 MS: for<'b> Sink<&'b [u8]>,
90 {
91 let policy = client.policy_store.get_policy(self.policy_id)?;
92
93 let mut perspective = SessionPerspective {
95 session: self,
96 message_sink,
97 };
98 let checkpoint = perspective.checkpoint();
99 effect_sink.begin();
100
101 match policy.call_action(
103 action,
104 &mut perspective,
105 effect_sink,
106 ActionPlacement::OffGraph,
107 ) {
108 Ok(()) => {
109 effect_sink.commit();
111 Ok(())
112 }
113 Err(e) => {
114 perspective.revert(checkpoint)?;
116 perspective.message_sink.rollback();
117 effect_sink.rollback();
118 Err(e.into())
119 }
120 }
121 }
122
123 pub fn receive(
128 &mut self,
129 client: &ClientState<PS, SP>,
130 sink: &mut impl Sink<PS::Effect>,
131 command_bytes: &[u8],
132 ) -> Result<(), ClientError> {
133 let command: SessionCommand<'_> =
134 postcard::from_bytes(command_bytes).map_err(ClientError::SessionDeserialize)?;
135
136 if command.graph_id != self.graph_id {
137 warn!(%command.graph_id, %self.graph_id, "ephemeral commands must be run on the same graph");
138 return Err(ClientError::NotAuthorized);
139 }
140
141 let policy = client.policy_store.get_policy(self.policy_id)?;
142
143 let mut perspective = SessionPerspective {
145 session: self,
146 message_sink: &mut NullSink,
147 };
148
149 sink.begin();
151 let checkpoint = perspective.checkpoint();
152 if let Err(e) =
153 policy.call_rule(&command, &mut perspective, sink, CommandPlacement::OffGraph)
154 {
155 perspective.revert(checkpoint)?;
156 sink.rollback();
157 return Err(e.into());
158 }
159 sink.commit();
160
161 Ok(())
162 }
163}
164
/// Wire form of a command emitted inside a session.
///
/// Serialized with postcard by `add_command` and decoded by `Session::receive`.
/// The field order below is the serialized layout.
#[derive(Serialize, Deserialize)]
struct SessionCommand<'a> {
    /// Graph the command belongs to; checked against the receiving session.
    graph_id: GraphId,
    /// Value of `Priority::Basic` for this command.
    priority: u32,
    id: CmdId,
    /// The single parent address (`Prior::Single`).
    parent: Address,
    /// Command payload, borrowed from the input buffer when deserializing.
    #[serde(borrow)]
    data: &'a [u8],
}
175
176impl Command for SessionCommand<'_> {
177 fn priority(&self) -> Priority {
178 Priority::Basic(self.priority)
179 }
180
181 fn id(&self) -> CmdId {
182 self.id
183 }
184
185 fn parent(&self) -> Prior<Address> {
186 Prior::Single(self.parent)
187 }
188
189 fn policy(&self) -> Option<&[u8]> {
190 None
192 }
193
194 fn bytes(&self) -> &[u8] {
195 self.data
196 }
197}
198
199impl<'sc> SessionCommand<'sc> {
200 fn from_cmd(graph_id: GraphId, command: &'sc impl Command) -> Result<Self, Bug> {
201 if command.policy().is_some() {
202 bug!("session command should have no policy")
203 }
204 Ok(SessionCommand {
205 graph_id,
206 priority: match command.priority() {
207 Priority::Basic(p) => p,
208 _ => bug!("wrong command type"),
209 },
210 id: command.id(),
211 parent: match command.parent() {
212 Prior::Single(p) => p,
213 _ => bug!("wrong command type"),
214 },
215 data: command.bytes(),
216 })
217 }
218}
219
/// Merges base-index facts (`prior`) with session overlay entries (`current`).
///
/// The merge in `next` compares keys of the two streams and emits the smaller
/// one first, so it assumes both iterators yield entries in ascending key
/// order. Overlay entries shadow base entries with equal keys.
struct QueryIterator<I1: Iterator, I2: Iterator> {
    /// Facts from the underlying fact index (may yield storage errors).
    prior: Peekable<I1>,
    /// Overlay entries; a `None` value hides the matching base fact.
    current: Peekable<I2>,
}
225
226impl<I1, I2> QueryIterator<I1, I2>
227where
228 I1: Iterator<Item = Result<Fact, StorageError>>,
229 I2: Iterator<Item = (Keys, Option<Bytes>)>,
230{
231 fn new(prior: I1, current: I2) -> Self {
232 Self {
233 prior: prior.peekable(),
234 current: current.peekable(),
235 }
236 }
237}
238
239impl<I1, I2> Iterator for QueryIterator<I1, I2>
240where
241 I1: Iterator<Item = Result<Fact, StorageError>>,
242 I2: Iterator<Item = (Keys, Option<Bytes>)>,
243{
244 type Item = Result<Fact, StorageError>;
245
246 fn next(&mut self) -> Option<Self::Item> {
247 loop {
252 let Some(new) = self.current.peek() else {
253 return self.prior.next();
255 };
256 if let Some(old) = self.prior.peek() {
257 let Ok(old) = old else {
258 return self.prior.next();
260 };
261 match new.0.cmp(&old.key) {
262 Ordering::Equal => {
263 let _ = self.prior.next();
265 }
266 Ordering::Greater => {
267 return self.prior.next();
269 }
270 Ordering::Less => {
271 }
273 }
274 }
275 let Some(slot) = self.current.next() else {
276 bug!("expected Some after peek")
277 };
278 if let (k, Some(v)) = slot {
279 return Some(Ok(Fact {
280 key: k.iter().cloned().collect(),
281 value: v,
282 }));
283 }
284 }
285 }
286}
287
// Empty impl: this trait has no items to provide here. Fact access comes from
// the `Query`/`QueryMut` impls for `SessionPerspective` (presumably required
// as supertraits; confirm).
impl<SP, PS, MS> FactPerspective for SessionPerspective<'_, SP, PS, MS> where SP: StorageProvider {}
289
290impl<SP, PS, MS> Query for SessionPerspective<'_, SP, PS, MS>
291where
292 SP: StorageProvider,
293{
294 fn query(&self, name: &str, keys: &[Box<[u8]>]) -> Result<Option<Box<[u8]>>, StorageError> {
295 if let Some(slot) = self
296 .session
297 .current_facts
298 .get(name)
299 .and_then(|m| m.get(keys))
300 {
301 return Ok(slot.clone());
302 }
303 self.session.base_facts.query(name, keys)
304 }
305
306 type QueryIterator = QueryIterator<
307 <<SP::Storage as Storage>::FactIndex as Query>::QueryIterator,
308 YokeIter<PrefixIter<'static>, Arc<BTreeMap<String, BTreeMap<Keys, Option<Bytes>>>>>,
309 >;
310 fn query_prefix(
311 &self,
312 name: &str,
313 prefix: &[Box<[u8]>],
314 ) -> Result<Self::QueryIterator, StorageError> {
315 let prior = self.session.base_facts.query_prefix(name, prefix)?;
316 let current = Yoke::<PrefixIter<'static>, _>::attach_to_cart(
317 Arc::clone(&self.session.current_facts),
318 |map| match map.get(name) {
319 Some(facts) => PrefixIter::new(facts, prefix.iter().cloned().collect()),
320 None => PrefixIter::default(),
321 },
322 );
323 Ok(QueryIterator::new(prior, YokeIter::new(current)))
324 }
325}
326
/// Iterator over the entries of one fact map whose keys start with `prefix`.
///
/// Borrows the map for `'map`. It derives `Yokeable` so it can be stored
/// alongside an `Arc` of the map in a `Yoke`.
#[derive(Default, Yokeable)]
struct PrefixIter<'map> {
    /// Entries from `prefix` (inclusive) to the end of the map.
    range: btree_map::Range<'map, Keys, Option<Bytes>>,
    /// Prefix every yielded key must start with.
    prefix: Keys,
}
336
337impl<'map> PrefixIter<'map> {
338 fn new(map: &'map BTreeMap<Keys, Option<Bytes>>, prefix: Keys) -> Self {
339 let range =
340 map.range::<[Box<[u8]>], _>((Bound::Included(prefix.as_ref()), Bound::Unbounded));
341 Self { range, prefix }
342 }
343}
344
345impl Iterator for PrefixIter<'_> {
346 type Item = (Keys, Option<Bytes>);
347
348 fn next(&mut self) -> Option<Self::Item> {
349 self.range
350 .next()
351 .filter(|(k, _)| k.starts_with(&self.prefix))
352 .map(|(k, v)| (k.clone(), v.clone()))
353 }
354}
355
/// Adapts a yoked iterator (`Yoke<I, C>`) into a plain `Iterator`.
///
/// The yoke is kept in an `Option` so `next` can take it out, advance it via
/// `map_project`, and put it back.
struct YokeIter<I: for<'a> Yokeable<'a>, C>(Option<Yoke<I, C>>);
358
359impl<I: for<'a> Yokeable<'a>, C> YokeIter<I, C> {
360 fn new(yoke: Yoke<I, C>) -> Self {
361 Self(Some(yoke))
362 }
363}
364
365impl<I, C> Iterator for YokeIter<I, C>
366where
367 I: Iterator + for<'a> Yokeable<'a>,
368 for<'a> <I as Yokeable<'a>>::Output: Iterator<Item = I::Item>,
369{
370 type Item = I::Item;
371
372 fn next(&mut self) -> Option<Self::Item> {
373 let mut item = None;
376 self.0 = Some(self.0.take()?.map_project::<I, _>(|mut it, _| {
377 item = it.next();
378 it
379 }));
380 item
381 }
382}
383
384impl<SP: StorageProvider, PS, MS> QueryMut for SessionPerspective<'_, SP, PS, MS> {
385 fn insert(&mut self, name: String, keys: Keys, value: Box<[u8]>) {
386 self.session
387 .fact_log
388 .push((name.clone(), keys.clone(), Some(value.clone())));
389 Arc::make_mut(&mut self.session.current_facts)
390 .entry(name)
391 .or_default()
392 .insert(keys, Some(value));
393 }
394
395 fn delete(&mut self, name: String, keys: Keys) {
396 self.session
397 .fact_log
398 .push((name.clone(), keys.clone(), None));
399 Arc::make_mut(&mut self.session.current_facts)
400 .entry(name)
401 .or_default()
402 .insert(keys, None);
403 }
404}
405
406impl<SP, PS, MS> Perspective for SessionPerspective<'_, SP, PS, MS>
407where
408 SP: StorageProvider,
409 MS: for<'b> Sink<&'b [u8]>,
410{
411 fn policy(&self) -> PolicyId {
412 self.session.policy_id
413 }
414
415 fn add_command(&mut self, command: &impl Command) -> Result<usize, StorageError> {
416 let command = SessionCommand::from_cmd(self.session.graph_id, command)?;
417 self.session.head = command.address()?;
418 let bytes = postcard::to_allocvec(&command).assume("serialize session command")?;
419 self.message_sink.consume(&bytes);
420
421 Ok(0)
422 }
423
424 fn includes(&self, _id: CmdId) -> bool {
425 debug_assert!(false, "only used in transactions");
426
427 false
428 }
429
430 fn head_address(&self) -> Result<Prior<Address>, Bug> {
431 Ok(Prior::Single(self.session.head))
432 }
433}
434
435impl<SP, PS, MS> Revertable for SessionPerspective<'_, SP, PS, MS>
436where
437 SP: StorageProvider,
438{
439 fn checkpoint(&self) -> Checkpoint {
440 Checkpoint {
441 index: self.session.fact_log.len(),
442 }
443 }
444
445 fn revert(&mut self, checkpoint: Checkpoint) -> Result<(), Bug> {
446 if checkpoint.index == self.session.fact_log.len() {
447 return Ok(());
448 }
449
450 if checkpoint.index > self.session.fact_log.len() {
451 bug!(
452 "A checkpoint's index should always be less than or equal to the length of a session's fact log!"
453 );
454 }
455
456 self.session.fact_log.truncate(checkpoint.index);
457 let mut facts =
459 Arc::get_mut(&mut self.session.current_facts).map_or_else(BTreeMap::new, mem::take);
460 facts.clear();
461 for (n, k, v) in self.session.fact_log.iter().cloned() {
462 facts.entry(n).or_default().insert(k, v);
463 }
464 self.session.current_facts = Arc::new(facts);
465
466 Ok(())
467 }
468}
469
#[cfg(test)]
mod test {
    use super::*;

    /// Checks how `QueryIterator` merges overlay entries with base facts:
    /// overlay values are interleaved in key order, `None` overlay entries
    /// hide equal-keyed base facts, and base errors are passed through.
    #[test]
    fn test_query_iterator() {
        #![allow(clippy::type_complexity)]

        // Base facts, sorted by key, ending in a storage error.
        let prior: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"a"], b"a0")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];
        // Overlay entries: `a` deletes a base fact, `b` is new, and `e`/`j`
        // delete keys the base does not have.
        let current: Vec<([Box<[u8]>; 1], Option<&[u8]>)> = vec![
            ([Box::new(*b"a")], None),
            ([Box::new(*b"b")], Some(b"b1")),
            ([Box::new(*b"e")], None),
            ([Box::new(*b"j")], None),
        ];
        // Expected merge: `a` hidden, `b` inserted, base error still surfaced.
        let merged: Vec<Result<(&[&[u8]], &[u8]), _>> = vec![
            Ok((&[b"b"], b"b1")),
            Ok((&[b"c"], b"c0")),
            Ok((&[b"d"], b"d0")),
            Ok((&[b"f"], b"f0")),
            Err(StorageError::IoError),
        ];

        let got: Vec<_> = QueryIterator::new(
            prior.into_iter().map(|r| {
                r.map(|(k, v)| Fact {
                    key: k.into(),
                    value: v.into(),
                })
            }),
            current
                .into_iter()
                .map(|(k, v)| (k.into_iter().collect(), v.map(Box::from))),
        )
        .collect();
        let want: Vec<_> = merged
            .into_iter()
            .map(|r| {
                r.map(|(k, v)| Fact {
                    key: k.into(),
                    value: v.into(),
                })
            })
            .collect();

        assert_eq!(got, want);
    }
}