Skip to main content

inc_complete/
accumulate.rs

1use std::marker::PhantomData;
2
3use dashmap::DashMap;
4
5use crate::{Cell, Computation, Run, Storage, StorageFor};
6
7/// An accumulator is a collection which can accumulate a given cell key associated with multiple
8/// values of a given type. `Accumulator<MyItem>` is an example of such a type.
9///
10/// This is most often a hashmap or similar map
11pub trait Accumulate<Item> {
12    /// Push an item to the context of the given cell
13    fn accumulate(&self, cell: Cell, item: Item);
14
15    /// Retrieve all items associated with the given cell.
16    /// Note that this should only include the exact cell given, not any
17    /// values accumulated from dependencies.
18    fn get_accumulated<Items>(&self, cell: Cell) -> Items
19        where Items: FromIterator<Item>;
20}
21
22pub struct Accumulator<Item> {
23    map: DashMap<Cell, Vec<Item>>,
24}
25
26impl<Item> Default for Accumulator<Item> {
27    fn default() -> Self {
28        Self { map: Default::default() }
29    }
30}
31
32impl<Item> Accumulator<Item> {
33    pub fn clear(&self, cell: Cell) {
34        self.map.remove(&cell);
35    }
36}
37
38impl<Item: Clone> Accumulate<Item> for Accumulator<Item> {
39    fn accumulate(&self, cell: Cell, item: Item) {
40        self.map.entry(cell).or_default().push(item);
41    }
42
43    fn get_accumulated<Items>(&self, cell: Cell) -> Items
44        where Items: FromIterator<Item>
45    {
46        if let Some(items) = self.map.get(&cell) {
47            FromIterator::from_iter(items.iter().cloned())
48        } else {
49            FromIterator::from_iter(std::iter::empty())
50        }
51    }
52}
53
54#[derive(serde::Serialize, serde::Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
55#[serde(transparent)]
56pub struct Accumulated<Item> {
57    cell: Cell,
58    _item: std::marker::PhantomData<Item>,
59}
60
61impl<Item> Accumulated<Item> {
62    pub(crate) fn new(cell: Cell) -> Self {
63        Self { cell, _item: PhantomData }
64    }
65}
66
67impl<Item: 'static> Computation for Accumulated<Item> {
68    type Output = Vec<Item>;
69    const IS_INPUT: bool = false;
70    const ASSUME_CHANGED: bool = false;
71
72    fn computation_id() -> u32 {
73        100000
74    }
75}
76
77impl<S, Item> Run<S> for Accumulated<Item> where
78    Item: 'static,
79    S: Storage + StorageFor<Accumulated<Item>> + Accumulate<Item>,
80{
81    fn run(&self, db: &crate::DbHandle<S>) -> Self::Output {
82        db.get_accumulated_with_cell(self.cell)
83    }
84}