Skip to main content

inc_complete/
accumulate.rs

1use std::{collections::BTreeSet, marker::PhantomData};
2
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5
6use crate::{Cell, Computation, Run, Storage, StorageFor};
7
8/// An accumulator is a collection which can accumulate a given cell key associated with multiple
9/// values of a given type. `Accumulator<MyItem>` is an example of such a type.
10///
11/// This is most often a hashmap or similar map
12pub trait Accumulate<Item> {
13    /// Push an item to the context of the given cell
14    fn accumulate(&self, cell: Cell, item: Item);
15
16    /// Retrieve all items associated with the given cell.
17    /// Note that this should only include the exact cell given, not any
18    /// values accumulated from dependencies.
19    fn get_accumulated<Items>(&self, cell: Cell) -> Items
20    where
21        Items: FromIterator<Item>;
22}
23
24pub struct Accumulator<Item> {
25    map: DashMap<Cell, Vec<Item>>,
26}
27
28impl<Item> Default for Accumulator<Item> {
29    fn default() -> Self {
30        Self {
31            map: Default::default(),
32        }
33    }
34}
35
36impl<Item> Accumulator<Item> {
37    pub fn clear(&self, cell: Cell) {
38        self.map.remove(&cell);
39    }
40}
41
42impl<Item: Clone> Accumulate<Item> for Accumulator<Item> {
43    fn accumulate(&self, cell: Cell, item: Item) {
44        self.map.entry(cell).or_default().push(item);
45    }
46
47    fn get_accumulated<Items>(&self, cell: Cell) -> Items
48    where
49        Items: FromIterator<Item>,
50    {
51        if let Some(items) = self.map.get(&cell) {
52            FromIterator::from_iter(items.iter().cloned())
53        } else {
54            FromIterator::from_iter(std::iter::empty())
55        }
56    }
57}
58
59#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
60#[serde(transparent)]
61pub struct Accumulated<Item> {
62    cell: Cell,
63    _item: std::marker::PhantomData<Item>,
64}
65
66impl<Item> Accumulated<Item> {
67    pub(crate) fn new(cell: Cell) -> Self {
68        Self {
69            cell,
70            _item: PhantomData,
71        }
72    }
73}
74
75// Arbitrary semi-random value meant to not be easily accidentally used for ids in user code.
76// Must be unique across all computation IDs.
77pub(crate) const ACCUMULATED_COMPUTATION_ID: u32 = 0x54325243;
78
79impl<Item: 'static> Computation for Accumulated<Item> {
80    type Output = BTreeSet<Item>;
81    const IS_INPUT: bool = false;
82    const ASSUME_CHANGED: bool = false;
83
84    fn computation_id() -> u32 {
85        ACCUMULATED_COMPUTATION_ID
86    }
87}
88
89impl<S, Item> Run<S> for Accumulated<Item>
90where
91    Item: 'static + Ord,
92    S: Storage + StorageFor<Accumulated<Item>> + Accumulate<Item>,
93{
94    fn run(&self, db: &crate::DbHandle<S>) -> Self::Output {
95        db.get_accumulated_with_cell(self.cell)
96    }
97}
98
99impl<Item: Serialize + Clone> Serialize for Accumulator<Item> {
100    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
101    where
102        S: serde::Serializer,
103    {
104        let vec: Vec<(Cell, Vec<Item>)> = self
105            .map
106            .iter()
107            .map(|entry| (*entry.key(), entry.value().clone()))
108            .collect();
109
110        vec.serialize(serializer)
111    }
112}
113
114impl<'de, Item: Deserialize<'de>> Deserialize<'de> for Accumulator<Item> {
115    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
116    where
117        D: serde::Deserializer<'de>,
118    {
119        let vec: Vec<(Cell, Vec<Item>)> = Deserialize::deserialize(deserializer)?;
120        let map = vec.into_iter().collect();
121        Ok(Accumulator { map })
122    }
123}