1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4};
5
6use petgraph::graph::NodeIndex;
7
/// A typed key for a [`DataBase`].
///
/// The implementing type itself is the key; `Value` is the type of the data
/// stored under it. (`InMemoryDb`, for instance, indexes its storage by
/// `TypeId::of::<K>()`.)
pub trait DbKey: 'static {
    /// Type of the value stored under this key.
    type Value: 'static;
}
11
12pub trait DataBase {
13 fn get<K: DbKey>(&self) -> Option<&K::Value>;
14 fn get_cloned<K: DbKey>(&self) -> Option<K::Value>
15 where
16 K::Value: Clone,
17 {
18 self.get::<K>().cloned()
19 }
20 fn put<K: DbKey>(&mut self, value: K::Value) -> Option<K::Value>;
21}
22
/// A [`DataBase`] kept entirely in process memory.
///
/// Every key `K` owns one slot, indexed by `TypeId::of::<K>()`, that holds a
/// boxed `K::Value`.
#[derive(Default)]
pub struct InMemoryDb {
    data: HashMap<TypeId, Box<dyn Any>>,
}

impl InMemoryDb {
    /// Creates an empty database; equivalent to [`InMemoryDb::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
34
35impl DataBase for InMemoryDb {
36 fn get<K: DbKey>(&self) -> Option<&K::Value> {
37 let t = TypeId::of::<K>();
38 self.data.get(&t).and_then(|v| v.downcast_ref::<K::Value>())
39 }
40
41 fn put<K: DbKey>(&mut self, value: K::Value) -> Option<K::Value> {
42 self.data
43 .insert(TypeId::of::<K>(), Box::new(value))
44 .and_then(|v| v.downcast::<K::Value>().ok().map(|v| *v))
45 }
46}
47
48pub trait Task<Db: DataBase> {
49 type Input: TaskInput<Db>;
50 type Output: TaskOutput<Db>;
51
52 fn execute(input: Self::Input) -> Self::Output;
53}
54
55impl DbKey for () {
56 type Value = ();
57}
58
59impl<Db: DataBase> TaskInput<Db> for () {
60 fn from_db(_db: &Db) -> Self {
61 ()
62 }
63}
64
65impl<Db: DataBase> TaskOutput<Db> for () {
66 fn to_db(&self, _db: &mut Db) {}
67}
68
69pub trait TaskInput<Db: DataBase>: DbKey<Value = Self>
70where
71 Self: Sized + 'static,
72{
73 fn from_db(db: &Db) -> Self;
74 fn dep_types() -> Vec<TypeId> {
75 vec![]
76 }
77}
78
79pub trait TaskOutput<Db: DataBase>: DbKey<Value = Self>
80where
81 Self: Sized + 'static,
82{
83 fn to_db(&self, db: &mut Db);
84 fn out_types() -> Vec<TypeId> {
85 vec![]
86 }
87}
88
89pub struct ExecutionGraph<Db: DataBase> {
90 tasks: petgraph::graph::DiGraph<TypeId, fn(&mut Db)>,
91 db: Db,
92}
93
94impl<Db: DataBase> ExecutionGraph<Db> {
95 pub fn new(db: Db) -> Self {
96 ExecutionGraph {
97 db,
98 tasks: petgraph::graph::DiGraph::new(),
99 }
100 }
101
102 fn contains_node(&self, ty: &TypeId) -> Option<NodeIndex> {
103 self.tasks.node_indices().find(|i| &self.tasks[*i] == ty)
104 }
105
106 pub fn execute<T: Task<Db>>(&mut self) -> T::Output {
107 for ty in T::Input::dep_types() {
108 if let None = self.contains_node(&ty) {
109 panic!("Missing dependency: {:?}", ty)
110 }
111 }
112 let input = T::Input::from_db(&self.db);
113 let output = T::execute(input);
114 output.to_db(&mut self.db);
115 output
116 }
117}
118
119pub struct ExecutionGraphBuilder<Db: DataBase> {
120 graph: ExecutionGraph<Db>,
121}
122
123impl<Db: DataBase> ExecutionGraphBuilder<Db> {
124 pub fn new(db: Db) -> Self {
125 ExecutionGraphBuilder {
126 graph: ExecutionGraph::new(db),
127 }
128 }
129
130 pub fn add_input<T: DbKey>(&mut self, value: T::Value) -> &mut Self {
131 self.graph.db.put::<T>(value);
132 self
133 }
134
135 pub fn add_task<T: Task<Db>>(&mut self) -> &mut Self {
136 let task_input_node = self.graph.tasks.add_node(TypeId::of::<T::Input>());
137 for dep_ty in T::Input::dep_types() {
138 let Some(in_node_id) = self.graph.contains_node(&dep_ty) else {
139 panic!("Missing dependency: {:?}", dep_ty)
140 };
141
142 self.graph
143 .tasks
144 .add_edge(in_node_id, task_input_node, |db| {
145 let input = T::Input::from_db(db);
146 db.put::<T::Input>(input);
147 });
148 }
149 let out_node = self.graph.tasks.add_node(TypeId::of::<T::Output>());
150 for out_ty in T::Output::out_types() {
151 match self.graph.contains_node(&out_ty) {
152 Some(_out_node_id) => {
153 panic!("Output already exists: {:?}", out_ty)
154 }
155 None => {
156 let out_ty_node = self.graph.tasks.add_node(out_ty);
157 self.graph.tasks.add_edge(out_node, out_ty_node, |_| {});
158 }
159 }
160 }
161 self
162 }
163
164 pub fn build(self) -> ExecutionGraph<Db> {
165 self.graph
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    struct MyKey;

    impl DbKey for MyKey {
        type Value = i32;
    }

    /// A second key with the same value type as `MyKey`. It checks that slots
    /// are selected by key type, not by value type.
    struct OtherKey;

    impl DbKey for OtherKey {
        type Value = i32;
    }

    #[test]
    fn test_in_memory_db() {
        let mut db = InMemoryDb::new();
        db.put::<MyKey>(42);
        assert_eq!(db.get::<MyKey>(), Some(&42));
    }

    #[test]
    fn test_in_memory_db_wrong_key() {
        let mut db = InMemoryDb::new();
        db.put::<MyKey>(42);
        // A key that was never written must not observe `MyKey`'s value,
        // even though both keys store an `i32`.
        assert_eq!(db.get::<OtherKey>(), None);
        assert_eq!(db.get::<MyKey>(), Some(&42));
    }

    #[test]
    fn test_in_memory_db_put_returns_previous() {
        let mut db = InMemoryDb::new();
        assert_eq!(db.put::<MyKey>(1), None);
        assert_eq!(db.put::<MyKey>(2), Some(1));
        assert_eq!(db.get_cloned::<MyKey>(), Some(2));
    }

    #[derive(Copy, Clone)]
    struct MyValue {
        x: i32,
    }

    impl DbKey for MyValue {
        type Value = MyValue;
    }

    impl<Db: DataBase> TaskInput<Db> for MyValue {
        fn from_db(db: &Db) -> Self {
            db.get_cloned::<MyValue>().unwrap()
        }
    }

    #[derive(Copy, Clone, PartialEq, Debug)]
    struct MyValue2 {
        x: i32,
    }

    impl DbKey for MyValue2 {
        type Value = MyValue2;
    }

    impl<Db: DataBase> TaskOutput<Db> for MyValue2 {
        fn to_db(&self, db: &mut Db) {
            db.put::<MyValue2>(*self);
        }
    }

    struct MyTask;

    impl Task<InMemoryDb> for MyTask {
        type Input = MyValue;
        type Output = MyValue2;

        fn execute(input: Self::Input) -> Self::Output {
            MyValue2 { x: input.x }
        }
    }

    #[test]
    fn test_execution_graph() {
        let mut builder = ExecutionGraphBuilder::new(InMemoryDb::new());
        builder.add_input::<MyValue>(MyValue { x: 42 });
        builder.add_task::<MyTask>();
        let mut graph = builder.build();
        let output = graph.execute::<MyTask>();
        // Both the returned value and the persisted value must match.
        assert_eq!(output, MyValue2 { x: 42 });
        assert_eq!(graph.db.get::<MyValue2>(), Some(&MyValue2 { x: 42 }));
    }
}