computation_graph/
lib.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4};
5
6use petgraph::graph::NodeIndex;
7
/// Type-level key naming one slot of a [`DataBase`].
///
/// The implementing type is only ever used as a key; `Value` is the type of
/// the data stored under it.
pub trait DbKey: 'static {
    /// Type of the value stored under this key.
    type Value: 'static;
}
11
12pub trait DataBase {
13    fn get<K: DbKey>(&self) -> Option<&K::Value>;
14    fn get_cloned<K: DbKey>(&self) -> Option<K::Value>
15    where
16        K::Value: Clone,
17    {
18        self.get::<K>().cloned()
19    }
20    fn put<K: DbKey>(&mut self, value: K::Value) -> Option<K::Value>;
21}
22
/// A [`DataBase`] that keeps every value in a process-local hash map.
///
/// Values are boxed as `dyn Any` and keyed by the `TypeId` of their
/// [`DbKey`] type, so each key type owns exactly one slot.
#[derive(Debug, Default)]
pub struct InMemoryDb {
    data: HashMap<TypeId, Box<dyn Any>>,
}

impl InMemoryDb {
    /// Creates an empty database.
    ///
    /// Equivalent to [`InMemoryDb::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
34
35impl DataBase for InMemoryDb {
36    fn get<K: DbKey>(&self) -> Option<&K::Value> {
37        let t = TypeId::of::<K>();
38        self.data.get(&t).and_then(|v| v.downcast_ref::<K::Value>())
39    }
40
41    fn put<K: DbKey>(&mut self, value: K::Value) -> Option<K::Value> {
42        self.data
43            .insert(TypeId::of::<K>(), Box::new(value))
44            .and_then(|v| v.downcast::<K::Value>().ok().map(|v| *v))
45    }
46}
47
48pub trait Task<Db: DataBase> {
49    type Input: TaskInput<Db>;
50    type Output: TaskOutput<Db>;
51
52    fn execute(input: Self::Input) -> Self::Output;
53}
54
55impl DbKey for () {
56    type Value = ();
57}
58
59impl<Db: DataBase> TaskInput<Db> for () {
60    fn from_db(_db: &Db) -> Self {
61        ()
62    }
63}
64
65impl<Db: DataBase> TaskOutput<Db> for () {
66    fn to_db(&self, _db: &mut Db) {}
67}
68
69pub trait TaskInput<Db: DataBase>: DbKey<Value = Self>
70where
71    Self: Sized + 'static,
72{
73    fn from_db(db: &Db) -> Self;
74    fn dep_types() -> Vec<TypeId> {
75        vec![]
76    }
77}
78
79pub trait TaskOutput<Db: DataBase>: DbKey<Value = Self>
80where
81    Self: Sized + 'static,
82{
83    fn to_db(&self, db: &mut Db);
84    fn out_types() -> Vec<TypeId> {
85        vec![]
86    }
87}
88
89pub struct ExecutionGraph<Db: DataBase> {
90    tasks: petgraph::graph::DiGraph<TypeId, fn(&mut Db)>,
91    db: Db,
92}
93
94impl<Db: DataBase> ExecutionGraph<Db> {
95    pub fn new(db: Db) -> Self {
96        ExecutionGraph {
97            db,
98            tasks: petgraph::graph::DiGraph::new(),
99        }
100    }
101
102    fn contains_node(&self, ty: &TypeId) -> Option<NodeIndex> {
103        self.tasks.node_indices().find(|i| &self.tasks[*i] == ty)
104    }
105
106    pub fn execute<T: Task<Db>>(&mut self) -> T::Output {
107        for ty in T::Input::dep_types() {
108            if let None = self.contains_node(&ty) {
109                panic!("Missing dependency: {:?}", ty)
110            }
111        }
112        let input = T::Input::from_db(&self.db);
113        let output = T::execute(input);
114        output.to_db(&mut self.db);
115        output
116    }
117}
118
119pub struct ExecutionGraphBuilder<Db: DataBase> {
120    graph: ExecutionGraph<Db>,
121}
122
123impl<Db: DataBase> ExecutionGraphBuilder<Db> {
124    pub fn new(db: Db) -> Self {
125        ExecutionGraphBuilder {
126            graph: ExecutionGraph::new(db),
127        }
128    }
129
130    pub fn add_input<T: DbKey>(&mut self, value: T::Value) -> &mut Self {
131        self.graph.db.put::<T>(value);
132        self
133    }
134
135    pub fn add_task<T: Task<Db>>(&mut self) -> &mut Self {
136        let task_input_node = self.graph.tasks.add_node(TypeId::of::<T::Input>());
137        for dep_ty in T::Input::dep_types() {
138            let Some(in_node_id) = self.graph.contains_node(&dep_ty) else {
139                panic!("Missing dependency: {:?}", dep_ty)
140            };
141
142            self.graph
143                .tasks
144                .add_edge(in_node_id, task_input_node, |db| {
145                    let input = T::Input::from_db(db);
146                    db.put::<T::Input>(input);
147                });
148        }
149        let out_node = self.graph.tasks.add_node(TypeId::of::<T::Output>());
150        for out_ty in T::Output::out_types() {
151            match self.graph.contains_node(&out_ty) {
152                Some(_out_node_id) => {
153                    panic!("Output already exists: {:?}", out_ty)
154                }
155                None => {
156                    let out_ty_node = self.graph.tasks.add_node(out_ty);
157                    self.graph.tasks.add_edge(out_node, out_ty_node, |_| {});
158                }
159            }
160        }
161        self
162    }
163
164    pub fn build(self) -> ExecutionGraph<Db> {
165        self.graph
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    struct MyKey;

    impl DbKey for MyKey {
        type Value = i32;
    }

    /// Has the same `Value` type as `MyKey`. Used to check that slots are
    /// keyed by the key type, not by the value type.
    struct OtherKey;

    impl DbKey for OtherKey {
        type Value = i32;
    }

    #[test]
    fn test_in_memory_db() {
        let mut db = InMemoryDb::new();
        db.put::<MyKey>(42);
        assert_eq!(db.get::<MyKey>(), Some(&42));
    }

    #[test]
    fn test_in_memory_db_wrong_key() {
        let mut db = InMemoryDb::new();
        db.put::<MyKey>(42);
        assert_eq!(db.get::<MyKey>(), Some(&42));
        // Same value type but a different key: must not see MyKey's value.
        assert_eq!(db.get::<OtherKey>(), None);
    }

    #[test]
    fn test_in_memory_db_put_returns_previous() {
        let mut db = InMemoryDb::new();
        assert_eq!(db.put::<MyKey>(1), None);
        assert_eq!(db.put::<MyKey>(2), Some(1));
        assert_eq!(db.get_cloned::<MyKey>(), Some(2));
    }

    #[derive(Copy, Clone)]
    struct MyValue {
        x: i32,
    }

    impl DbKey for MyValue {
        type Value = MyValue;
    }

    impl<Db: DataBase> TaskInput<Db> for MyValue {
        fn from_db(db: &Db) -> Self {
            db.get_cloned::<MyValue>().unwrap()
        }
    }

    #[derive(Copy, Clone, PartialEq, Debug)]
    struct MyValue2 {
        x: i32,
    }

    impl DbKey for MyValue2 {
        type Value = MyValue2;
    }

    impl<Db: DataBase> TaskOutput<Db> for MyValue2 {
        fn to_db(&self, db: &mut Db) {
            db.put::<MyValue2>(*self);
        }
    }

    struct MyTask;

    impl Task<InMemoryDb> for MyTask {
        type Input = MyValue;
        type Output = MyValue2;

        fn execute(input: Self::Input) -> Self::Output {
            MyValue2 { x: input.x }
        }
    }

    #[test]
    fn test_execution_graph() {
        let mut builder = ExecutionGraphBuilder::new(InMemoryDb::new());
        builder.add_input::<MyValue>(MyValue { x: 42 });
        builder.add_task::<MyTask>();
        let mut graph = builder.build();
        graph.execute::<MyTask>();
        assert_eq!(graph.db.get::<MyValue2>(), Some(&MyValue2 { x: 42 }));
    }
}