1use std::collections::{HashMap, HashSet};
4
5use serde::{Deserialize, Serialize};
6
7mod hir;
8mod lir;
9mod mir;
10
11pub use hir::{CalciteId, CalcitePlan, Condition, Op, Operand, Rel};
12pub use lir::{LirCircuit, LirEdge, LirNode, LirNodeId, LirStreamId};
13pub use mir::{MirInput, MirNode, MirNodeId};
14use utoipa::ToSchema;
15
16#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
17#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
18pub struct SourcePosition {
19 pub start_line_number: usize,
20 pub start_column: usize,
21 pub end_line_number: usize,
22 pub end_column: usize,
23}
24
25#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)]
28pub enum Changes {
29 Added(String),
30 Removed(String),
31 Modified(String),
32}
33
34#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq, Eq, Clone)]
36pub struct Dataflow {
37 pub calcite_plan: HashMap<String, CalcitePlan>,
38 pub mir: HashMap<MirNodeId, MirNode>,
39}
40
41impl Dataflow {
42 pub fn new(
43 calcite_plan: HashMap<String, CalcitePlan>,
44 mir: HashMap<MirNodeId, MirNode>,
45 ) -> Self {
46 Self { calcite_plan, mir }
47 }
48
49 pub fn diff(&self, other: &Dataflow) -> HashSet<Changes> {
55 let mut changes = HashSet::new();
56 let relation_hashes = self.ids_to_hashes(&self.relation_with_ids());
57 let other_relation_hashes = other.ids_to_hashes(&other.relation_with_ids());
58
59 for (relation, hashes) in relation_hashes.iter() {
60 if let Some(other_hashes) = other_relation_hashes.get(relation) {
61 if hashes != other_hashes {
62 changes.insert(Changes::Modified(relation.clone()));
63 }
64 } else {
65 changes.insert(Changes::Removed(relation.clone()));
66 }
67 }
68
69 for (relation, _hashes) in other_relation_hashes {
70 if !relation_hashes.contains_key(&relation) {
71 changes.insert(Changes::Added(relation));
72 }
73 }
74
75 changes
76 }
77
78 fn relation_with_ids(&self) -> HashMap<String, HashSet<usize>> {
81 let mut relations_and_ids: HashMap<String, HashSet<usize>> = HashMap::new();
82 for (key, cp) in self.calcite_plan.iter() {
83 for rel in &cp.rels {
84 relations_and_ids
85 .entry(key.clone())
86 .or_default()
87 .insert(rel.id);
88 }
89 }
90
91 relations_and_ids
92 }
93
94 fn ids_to_hashes(
97 &self,
98 relations_and_ids: &HashMap<String, HashSet<usize>>,
99 ) -> HashMap<String, HashSet<String>> {
100 let mut hashes = HashMap::new();
101 for node in self.mir.values() {
102 let all_calcite_ids: Vec<usize> = node
103 .calcite
104 .as_ref()
105 .map(|cid| cid.clone().into())
106 .unwrap_or_default();
107 let mut all_calcite_ids_of_node = HashSet::with_capacity(all_calcite_ids.len());
108 for id in all_calcite_ids {
109 all_calcite_ids_of_node.insert(id);
110 }
111
112 if let Some(table) = node.table.as_ref() {
115 if let Some(persistent_id) = &node.persistent_id {
116 hashes
117 .entry(table.clone())
118 .or_insert_with(HashSet::new)
119 .insert(persistent_id.clone());
120 }
121 }
122
123 for (key, ids) in relations_and_ids {
124 let node_is_dependency_for_relation = !ids.is_disjoint(&all_calcite_ids_of_node);
125 if node_is_dependency_for_relation {
126 if let Some(persistent_id) = &node.persistent_id {
127 hashes
128 .entry(key.clone())
129 .or_insert_with(HashSet::new)
130 .insert(persistent_id.clone());
131 }
132 }
133 }
134 }
135 hashes
136 }
137}
138
#[cfg(test)]
mod tests {
    use crate::*;
    use std::collections::HashSet;

    /// A JSON fixture: `(name used in failure messages, file contents)`.
    type Sample = (&'static str, &'static str);

    const SAMPLE_A: Sample = ("sample_a", include_str!("../test/sample_a.json"));
    const SAMPLE_B: Sample = ("sample_b", include_str!("../test/sample_b.json"));
    const SAMPLE_B_MOD1: Sample =
        ("sample_b_mod1", include_str!("../test/sample_b_mod1.json"));
    const SAMPLE_B_MOD2: Sample =
        ("sample_b_mod2", include_str!("../test/sample_b_mod2.json"));
    const SAMPLE_B_MOD3: Sample =
        ("sample_b_mod3", include_str!("../test/sample_b_mod3.json"));
    const SAMPLE_C: Sample = ("sample_c", include_str!("../test/sample_c.json"));

    /// Every fixture, for the tests that must hold for all of them.
    const SAMPLES: &[Sample] = &[
        SAMPLE_A,
        SAMPLE_B,
        SAMPLE_B_MOD1,
        SAMPLE_B_MOD2,
        SAMPLE_B_MOD3,
        SAMPLE_C,
    ];

    /// Deserializes a fixture, naming the fixture in the panic message so a
    /// failure inside a loop over `SAMPLES` can be traced to its input.
    fn parse((name, json): Sample) -> Dataflow {
        serde_json::from_str(json)
            .unwrap_or_else(|err| panic!("sample `{}` failed to parse: {}", name, err))
    }

    /// Diffs two fixtures, with `before` as the receiver of `Dataflow::diff`.
    fn diff(before: Sample, after: Sample) -> HashSet<Changes> {
        parse(before).diff(&parse(after))
    }

    /// Builds an owned string set from string literals.
    fn string_set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| (*item).to_owned()).collect()
    }

    #[test]
    fn can_parse_ir() {
        for &sample in SAMPLES {
            let _plan = parse(sample);
        }
    }

    #[test]
    fn can_get_relations_and_ids() {
        let relations = parse(SAMPLE_B).relation_with_ids();
        assert_eq!(relations["error_view"], HashSet::from([0]));

        let relations = parse(SAMPLE_A).relation_with_ids();
        assert_eq!(relations["error_view"], HashSet::from([0]));
        assert_eq!(
            relations["group_can_read"],
            HashSet::from([1, 2, 3, 4, 5, 6, 7])
        );
        assert_eq!(
            relations["group_can_write"],
            HashSet::from([8, 9, 10, 11, 12, 13])
        );
        assert_eq!(
            relations["user_can_read"],
            HashSet::from([14, 15, 16, 17, 18, 19, 20, 21])
        );
        assert_eq!(
            relations["user_can_write"],
            HashSet::from([22, 23, 24, 25, 26, 27, 28, 29])
        );
    }

    #[test]
    fn can_find_hashes_in_a() {
        let plan = parse(SAMPLE_A);
        let hashes = plan.ids_to_hashes(&plan.relation_with_ids());

        assert_eq!(
            hashes["error_view"],
            string_set(&[
                "8b384059bdb44ad811ab341cc5e2a59697f39aac7b463cab027b185db8105e73",
                "933ebf782e1fe804fe85c4d0f3688bdb5234b386c2834892776e692acd9781d9",
            ])
        );

        assert_eq!(
            hashes["group_file_viewer"],
            string_set(&["44b862944cb9ff1772f75112d6b74d87bcbbe770502fe47f91b07c0bb3987bb3"])
        );

        assert_eq!(
            hashes["user_can_write"],
            string_set(&[
                "2f90ee4cdb4895d44ac7efb7104402dcf39a5fcfbe90492cc95311d4c70f623e",
                "db1532ae31ea981721261c4a3892a6f373f98ecce41c59b2b8a5f5186c3c7d69",
                "53944e28b6a21187dccb34ee9859bddbb4266157b21541feb2e4166a4034e907",
                "61a52a49c5285c66a9656f211205002d48bd282b9f7f48be666f7fe7c208a338",
                "739c3d0dafe5c2f650824df3529602ddae8acfa8789b35d80ea3eb7c3b156796",
                "e989408d6aaecac2943caac41fdab83aca539622526b4289303c6a4de6eb658f",
                "71d57c70dd7a5e3ae6da1adc565f9c119f93d4dcd50c61573a117d5d9aac3389",
                "a8918a1fd4c90f6091dade7d6a44d46bb72809da98262781240ca1be5d738271",
                "94b255b29c463d2918fe4a8c23cc75e943cb10e486d21942c8cb8c124c31eb7f",
            ])
        );
    }

    #[test]
    fn unchanged_diff_is_empty() {
        for &sample in SAMPLES {
            let plan = parse(sample);
            // `assert_eq!` prints the unexpected changes, the message names
            // the fixture that produced them.
            assert_eq!(
                plan.diff(&plan),
                HashSet::<Changes>::new(),
                "diffing sample `{}` against itself must report nothing",
                sample.0
            );
        }
    }

    #[test]
    fn change_only_view() {
        let changes = diff(SAMPLE_B, SAMPLE_B_MOD1);
        assert_eq!(
            changes,
            HashSet::from([Changes::Modified("example_count".to_string())])
        );
    }

    #[test]
    fn change_table() {
        let changes = diff(SAMPLE_B, SAMPLE_B_MOD2);
        assert_eq!(
            changes,
            HashSet::from([
                Changes::Modified("example".to_string()),
                Changes::Modified("example_count".to_string())
            ])
        );
    }

    #[test]
    fn add_table() {
        let changes = diff(SAMPLE_B, SAMPLE_B_MOD3);
        assert_eq!(
            changes,
            HashSet::from([
                Changes::Added("example_new".to_string()),
                Changes::Added("example_view_count".to_string())
            ])
        );
    }

    #[test]
    fn remove_table() {
        let changes = diff(SAMPLE_B_MOD3, SAMPLE_B);
        assert_eq!(
            changes,
            HashSet::from([
                Changes::Removed("example_new".to_string()),
                Changes::Removed("example_view_count".to_string())
            ])
        );
    }
}