1use std::collections::{HashMap, HashSet};
8
9use crate::pattern::Pattern;
10use crate::subject::{Subject, Symbol};
11
/// Access to the identity of a value.
///
/// `I` is the identity type. Every use in this module instantiates it as
/// [`Symbol`], which is then used as a `HashMap`/`HashSet` key when grouping
/// pattern nodes.
pub trait HasIdentity<V, I: Ord> {
    /// Returns a reference to the identity of `v`.
    fn identity(v: &V) -> &I;
}
20
/// Values that can be combined into one according to a strategy.
pub trait Mergeable {
    /// Configuration selecting how two values are combined.
    type MergeStrategy;

    /// Merges `a` and `b` into a single value using `strategy`.
    ///
    /// The reconciler folds occurrences left to right with this function, so
    /// `a` is the accumulated value and `b` the next occurrence.
    fn merge(strategy: &Self::MergeStrategy, a: Self, b: Self) -> Self;
}
28
/// Values that support a refinement relation between two instances.
pub trait Refinable {
    /// Returns `true` if `sup` and `sub` satisfy the implementor's refinement
    /// relation. For `Subject`, this means both share an identity and `sup`
    /// contains every label and property of `sub`.
    fn is_refinement_of(sup: &Self, sub: &Self) -> bool;
}
34
/// How [`reconcile`] resolves multiple occurrences of the same identity within
/// a pattern.
///
/// `S` is the value-level merge strategy used by [`ReconciliationPolicy::Merge`]
/// (the `MergeStrategy` of the value type being reconciled).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReconciliationPolicy<S> {
    /// Keep the value of the last occurrence (in pre-order traversal order);
    /// elements from all occurrences are unioned by identity.
    LastWriteWins,
    /// Keep the value of the first occurrence (in pre-order traversal order);
    /// elements from all occurrences are unioned by identity.
    FirstWriteWins,
    /// Fold all occurrence values together with `Mergeable::merge` using `S`,
    /// and combine their elements with the given [`ElementMergeStrategy`].
    Merge(ElementMergeStrategy, S),
    /// Fail if two occurrences of an identity have different values;
    /// otherwise behave like `FirstWriteWins`.
    Strict,
}
51
/// How the element lists of several occurrences are combined.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ElementMergeStrategy {
    /// Keep only the last occurrence's elements.
    ReplaceElements,
    /// Concatenate every occurrence's elements, keeping duplicates.
    AppendElements,
    /// Keep one element per identity (the first one encountered).
    UnionElements,
}
62
/// Merge strategy for [`Subject`] values: how labels and properties combine.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubjectMergeStrategy {
    /// How the two label sets are combined.
    pub label_merge: LabelMerge,
    /// How the two property maps are combined.
    pub property_merge: PropertyMerge,
}
73
/// Strategy for combining two subjects' label sets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LabelMerge {
    /// Labels present in either subject.
    UnionLabels,
    /// Labels present in both subjects.
    IntersectLabels,
    /// The second subject's labels only.
    ReplaceLabels,
}
81
/// Strategy for combining two subjects' property maps.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PropertyMerge {
    /// The second subject's properties only.
    ReplaceProperties,
    /// Union of keys; on a key collision the second subject's value wins.
    ShallowMerge,
    /// Presumably intended to merge nested values recursively.
    /// NOTE(review): `merge_properties` currently treats this exactly like
    /// `ShallowMerge`.
    DeepMerge,
}
89
90pub fn default_subject_merge_strategy() -> SubjectMergeStrategy {
92 SubjectMergeStrategy {
93 label_merge: LabelMerge::UnionLabels,
94 property_merge: PropertyMerge::ShallowMerge,
95 }
96}
97
impl HasIdentity<Subject, Symbol> for Subject {
    /// A subject's identity is its `identity` field.
    fn identity(v: &Subject) -> &Symbol {
        &v.identity
    }
}
107
108impl Mergeable for Subject {
109 type MergeStrategy = SubjectMergeStrategy;
110
111 fn merge(strategy: &SubjectMergeStrategy, a: Subject, b: Subject) -> Subject {
112 let merged_labels = merge_labels(&strategy.label_merge, &a.labels, &b.labels);
113 let merged_props = merge_properties(&strategy.property_merge, a.properties, b.properties);
114 Subject {
115 identity: a.identity,
116 labels: merged_labels,
117 properties: merged_props,
118 }
119 }
120}
121
122impl Refinable for Subject {
123 fn is_refinement_of(sup: &Subject, sub: &Subject) -> bool {
124 sup.identity == sub.identity
125 && sub.labels.is_subset(&sup.labels)
126 && sub
127 .properties
128 .iter()
129 .all(|(k, v)| sup.properties.get(k) == Some(v))
130 }
131}
132
133fn merge_labels(
134 strategy: &LabelMerge,
135 l1: &HashSet<String>,
136 l2: &HashSet<String>,
137) -> HashSet<String> {
138 match strategy {
139 LabelMerge::UnionLabels => l1.union(l2).cloned().collect(),
140 LabelMerge::IntersectLabels => l1.intersection(l2).cloned().collect(),
141 LabelMerge::ReplaceLabels => l2.clone(),
142 }
143}
144
145fn merge_properties(
146 strategy: &PropertyMerge,
147 p1: HashMap<String, crate::subject::Value>,
148 p2: HashMap<String, crate::subject::Value>,
149) -> HashMap<String, crate::subject::Value> {
150 match strategy {
151 PropertyMerge::ReplaceProperties => p2,
152 PropertyMerge::ShallowMerge | PropertyMerge::DeepMerge => {
153 let mut merged = p1;
155 merged.extend(p2);
156 merged
157 }
158 }
159}
160
/// Error returned by [`reconcile`] when a pattern cannot be reconciled under
/// the requested policy. Within this module, only the `Strict` policy produces it.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconcileError {
    /// Human-readable description of the failure.
    pub message: String,
}

// Displaying the error shows just its message, so callers can print it or
// propagate it with `?` into `Box<dyn std::error::Error>`.
impl std::fmt::Display for ReconcileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for ReconcileError {}
170
171pub fn reconcile<V>(
180 policy: &ReconciliationPolicy<V::MergeStrategy>,
181 pattern: &Pattern<V>,
182) -> Result<Pattern<V>, ReconcileError>
183where
184 V: HasIdentity<V, Symbol> + Mergeable + Refinable + PartialEq + Clone,
185{
186 match policy {
187 ReconciliationPolicy::Strict => reconcile_strict(pattern),
188 _ => Ok(reconcile_non_strict(policy, pattern)),
189 }
190}
191
192fn reconcile_non_strict<V>(
193 policy: &ReconciliationPolicy<V::MergeStrategy>,
194 pattern: &Pattern<V>,
195) -> Pattern<V>
196where
197 V: HasIdentity<V, Symbol> + Mergeable + Refinable + Clone,
198{
199 let occurrence_map = collect_by_identity(pattern);
200 let canonical_map: HashMap<Symbol, Pattern<V>> = occurrence_map
201 .into_iter()
202 .map(|(id, occurrences)| {
203 let canonical = reconcile_occurrences(policy, occurrences);
204 (id, canonical)
205 })
206 .collect();
207
208 let (rebuilt, _) = rebuild_pattern(&mut HashSet::new(), &canonical_map, pattern);
209 rebuilt
210}
211
212fn reconcile_occurrences<V>(
213 policy: &ReconciliationPolicy<V::MergeStrategy>,
214 occurrences: Vec<Pattern<V>>,
215) -> Pattern<V>
216where
217 V: HasIdentity<V, Symbol> + Mergeable + Clone,
218{
219 match policy {
220 ReconciliationPolicy::LastWriteWins => {
221 let v = occurrences.last().unwrap().value.clone();
222 let all_elements = merge_elements_union(occurrences.iter().map(|p| &p.elements));
223 Pattern {
224 value: v,
225 elements: all_elements,
226 }
227 }
228 ReconciliationPolicy::FirstWriteWins => {
229 let v = occurrences.first().unwrap().value.clone();
230 let all_elements = merge_elements_union(occurrences.iter().map(|p| &p.elements));
231 Pattern {
232 value: v,
233 elements: all_elements,
234 }
235 }
236 ReconciliationPolicy::Merge(elem_strat, val_strat) => {
237 let merged_val = occurrences
238 .iter()
239 .skip(1)
240 .fold(occurrences[0].value.clone(), |acc, p| {
241 V::merge(val_strat, acc, p.value.clone())
242 });
243 let merged_elements =
244 merge_elements(elem_strat, occurrences.iter().map(|p| &p.elements));
245 Pattern {
246 value: merged_val,
247 elements: merged_elements,
248 }
249 }
250 ReconciliationPolicy::Strict => unreachable!("Strict handled separately"),
251 }
252}
253
254fn merge_elements_union<'a, V, I>(lists: I) -> Vec<Pattern<V>>
255where
256 V: HasIdentity<V, Symbol> + Clone,
257 I: Iterator<Item = &'a Vec<Pattern<V>>>,
258 V: 'a,
259{
260 merge_elements(&ElementMergeStrategy::UnionElements, lists)
261}
262
263fn merge_elements<'a, V, I>(strategy: &ElementMergeStrategy, lists: I) -> Vec<Pattern<V>>
264where
265 V: HasIdentity<V, Symbol> + Clone,
266 I: Iterator<Item = &'a Vec<Pattern<V>>>,
267 V: 'a,
268{
269 let all: Vec<Vec<Pattern<V>>> = lists.cloned().collect();
270 match strategy {
271 ElementMergeStrategy::ReplaceElements => all.into_iter().last().unwrap_or_default(),
272 ElementMergeStrategy::AppendElements => all.into_iter().flatten().collect(),
273 ElementMergeStrategy::UnionElements => {
274 let mut seen: HashMap<Symbol, Pattern<V>> = HashMap::new();
275 for elem in all.into_iter().flatten() {
276 let id = V::identity(&elem.value).clone();
277 seen.entry(id).or_insert(elem);
278 }
279 seen.into_values().collect()
280 }
281 }
282}
283
284fn rebuild_pattern<V>(
285 visited: &mut HashSet<Symbol>,
286 canonical_map: &HashMap<Symbol, Pattern<V>>,
287 pattern: &Pattern<V>,
288) -> (Pattern<V>, ())
289where
290 V: HasIdentity<V, Symbol> + Clone,
291{
292 let v_id = V::identity(&pattern.value).clone();
293 let source = canonical_map.get(&v_id).unwrap_or(pattern);
294 visited.insert(v_id);
295
296 let mut rebuilt_elems = Vec::new();
297 for elem in &source.elements {
298 let elem_id = V::identity(&elem.value).clone();
299 if !visited.contains(&elem_id) {
300 let (rebuilt, _) = rebuild_pattern(visited, canonical_map, elem);
301 rebuilt_elems.push(rebuilt);
302 }
303 }
304
305 (
306 Pattern {
307 value: source.value.clone(),
308 elements: rebuilt_elems,
309 },
310 (),
311 )
312}
313
314fn collect_by_identity<V>(pattern: &Pattern<V>) -> HashMap<Symbol, Vec<Pattern<V>>>
315where
316 V: HasIdentity<V, Symbol> + Clone,
317{
318 let mut map: HashMap<Symbol, Vec<Pattern<V>>> = HashMap::new();
319 collect_recursive(pattern, &mut map);
320 map
321}
322
323fn collect_recursive<V>(pattern: &Pattern<V>, map: &mut HashMap<Symbol, Vec<Pattern<V>>>)
324where
325 V: HasIdentity<V, Symbol> + Clone,
326{
327 let id = V::identity(&pattern.value).clone();
328 map.entry(id).or_default().push(pattern.clone());
329 for elem in &pattern.elements {
330 collect_recursive(elem, map);
331 }
332}
333
334fn reconcile_strict<V>(pattern: &Pattern<V>) -> Result<Pattern<V>, ReconcileError>
335where
336 V: HasIdentity<V, Symbol> + Mergeable + Refinable + PartialEq + Clone,
337{
338 let occurrence_map = collect_by_identity(pattern);
339 for occurrences in occurrence_map.values() {
340 if occurrences.len() > 1 {
341 let first = &occurrences[0];
342 for other in &occurrences[1..] {
343 if other.value != first.value {
344 return Err(ReconcileError {
345 message: "Duplicate identities with different content".to_string(),
346 });
347 }
348 }
349 }
350 }
351 Ok(reconcile_non_strict(
352 &ReconciliationPolicy::FirstWriteWins,
353 pattern,
354 ))
355}