Skip to main content

peat_protocol/policy/
resolver.rs

1//! Generic conflict resolver implementation
2//!
3//! Provides conflict detection and resolution for any type implementing Conflictable.
4
5use crate::error::Result;
6use crate::policy::conflictable::{ConflictResult, Conflictable};
7use crate::policy::policies::ResolutionPolicy;
8use std::collections::HashMap;
9use std::marker::PhantomData;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13/// Generic conflict resolver for any `Conflictable` type
14///
15/// Maintains an index of active items by conflict key for efficient detection.
16/// Works with any type implementing the `Conflictable` trait.
17///
18/// ## Example
19///
20/// ```rust,ignore
21/// use peat_protocol::policy::{GenericConflictResolver, LastWriteWinsPolicy, Conflictable, ConflictResult};
22///
23/// // Create a resolver for your type
24/// let resolver = GenericConflictResolver::<MyType>::new();
25///
26/// // Check for conflicts
27/// let result = resolver.check_conflict(&my_item).await;
28///
29/// // Resolve if needed
30/// if let ConflictResult::Conflict(existing) = result {
31///     let policy = LastWriteWinsPolicy;
32///     let winner = resolver.resolve(vec![existing[0].clone(), my_item], &policy)?;
33/// }
34/// ```
35pub struct GenericConflictResolver<T: Conflictable> {
36    /// Active items indexed by conflict key
37    /// Key: conflict_key, Value: list of items with that key
38    active_items: Arc<RwLock<HashMap<String, Vec<T>>>>,
39    _phantom: PhantomData<T>,
40}
41
42impl<T: Conflictable> GenericConflictResolver<T> {
43    /// Create a new generic conflict resolver
44    pub fn new() -> Self {
45        Self {
46            active_items: Arc::new(RwLock::new(HashMap::new())),
47            _phantom: PhantomData,
48        }
49    }
50
51    /// Check if a new item conflicts with existing items
52    ///
53    /// Returns `ConflictResult::Conflict` if there are existing items
54    /// with overlapping conflict keys.
55    pub async fn check_conflict(&self, item: &T) -> ConflictResult<T> {
56        let keys = item.conflict_keys();
57        let items = self.active_items.read().await;
58
59        let mut conflicting = Vec::new();
60
61        for key in keys {
62            if let Some(existing) = items.get(&key) {
63                // Avoid duplicates
64                for existing_item in existing {
65                    if !conflicting.iter().any(|c: &T| c.id() == existing_item.id()) {
66                        conflicting.push(existing_item.clone());
67                    }
68                }
69            }
70        }
71
72        if conflicting.is_empty() {
73            ConflictResult::NoConflict
74        } else {
75            ConflictResult::Conflict(conflicting)
76        }
77    }
78
79    /// Resolve conflict using the specified policy
80    ///
81    /// Takes a list of conflicting items and returns the "winning" item
82    /// according to the policy's logic.
83    pub fn resolve(&self, items: Vec<T>, policy: &dyn ResolutionPolicy<T>) -> Result<T> {
84        tracing::debug!(
85            "Resolving conflict between {} items using policy: {}",
86            items.len(),
87            policy.name()
88        );
89
90        policy.resolve(items)
91    }
92
93    /// Register an item as active (after conflict resolution)
94    ///
95    /// Adds the item to the conflict index for future conflict checks.
96    pub async fn register(&self, item: &T) -> Result<()> {
97        let keys = item.conflict_keys();
98        let mut items = self.active_items.write().await;
99
100        for key in keys {
101            items.entry(key).or_default().push(item.clone());
102        }
103
104        Ok(())
105    }
106
107    /// Unregister an item from active tracking
108    ///
109    /// Removes the item from all conflict key indices.
110    /// Called when an item completes, expires, or is cancelled.
111    pub async fn unregister(&self, item_id: &str) -> Result<()> {
112        let mut items = self.active_items.write().await;
113
114        // Remove from all key lists
115        for (_, item_list) in items.iter_mut() {
116            item_list.retain(|item| item.id() != item_id);
117        }
118
119        // Clean up empty keys
120        items.retain(|_, item_list| !item_list.is_empty());
121
122        Ok(())
123    }
124
125    /// Get all active items
126    pub async fn get_all_active(&self) -> Vec<T> {
127        let items = self.active_items.read().await;
128        let mut all_items = Vec::new();
129        let mut seen_ids = std::collections::HashSet::new();
130
131        for item_list in items.values() {
132            for item in item_list {
133                if seen_ids.insert(item.id()) {
134                    all_items.push(item.clone());
135                }
136            }
137        }
138
139        all_items
140    }
141
142    /// Get count of active items
143    pub async fn active_count(&self) -> usize {
144        self.get_all_active().await.len()
145    }
146
147    /// Get active items by conflict key
148    pub async fn get_by_key(&self, key: &str) -> Vec<T> {
149        self.active_items
150            .read()
151            .await
152            .get(key)
153            .cloned()
154            .unwrap_or_default()
155    }
156}
157
158impl<T: Conflictable> Default for GenericConflictResolver<T> {
159    fn default() -> Self {
160        Self::new()
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::conflictable::AttributeValue;
    use crate::policy::policies::{HighestAttributeWinsPolicy, LastWriteWinsPolicy};
    use std::collections::HashMap;

    /// Minimal `Conflictable` fixture: one conflict key (`resource`), a
    /// timestamp, and a single integer `priority` attribute.
    #[derive(Clone, Debug, PartialEq)]
    struct TestItem {
        id: String,
        resource: String,
        timestamp: u64,
        priority: i64,
    }

    /// Build a `TestItem` from borrowed string parts.
    fn fixture(id: &str, resource: &str, timestamp: u64, priority: i64) -> TestItem {
        TestItem {
            id: id.to_string(),
            resource: resource.to_string(),
            timestamp,
            priority,
        }
    }

    impl Conflictable for TestItem {
        fn id(&self) -> String {
            self.id.clone()
        }

        fn conflict_keys(&self) -> Vec<String> {
            vec![self.resource.clone()]
        }

        fn timestamp(&self) -> Option<u64> {
            Some(self.timestamp)
        }

        fn attributes(&self) -> HashMap<String, AttributeValue> {
            HashMap::from([("priority".to_string(), AttributeValue::Int(self.priority))])
        }
    }

    #[tokio::test]
    async fn test_no_conflict_different_resources() {
        let resolver = GenericConflictResolver::<TestItem>::new();
        let active = fixture("item-1", "resource-a", 1000, 1);
        resolver.register(&active).await.unwrap();

        let candidate = fixture("item-2", "resource-b", 1001, 2);
        let outcome = resolver.check_conflict(&candidate).await;
        assert!(!outcome.is_conflict());
    }

    #[tokio::test]
    async fn test_conflict_same_resource() {
        let resolver = GenericConflictResolver::<TestItem>::new();
        let active = fixture("item-1", "resource-a", 1000, 1);
        resolver.register(&active).await.unwrap();

        let candidate = fixture("item-2", "resource-a", 1001, 2);
        let outcome = resolver.check_conflict(&candidate).await;
        assert!(outcome.is_conflict());

        match outcome {
            ConflictResult::Conflict(clashing) => {
                assert_eq!(clashing.len(), 1);
                assert_eq!(clashing[0].id, "item-1");
            }
            _ => unreachable!("is_conflict() was asserted above"),
        }
    }

    #[tokio::test]
    async fn test_resolve_last_write_wins() {
        let resolver = GenericConflictResolver::<TestItem>::new();
        let older = fixture("item-1", "resource-a", 1000, 1);
        let newer = fixture("item-2", "resource-a", 2000, 2);

        let policy = LastWriteWinsPolicy;
        let chosen = resolver.resolve(vec![older, newer], &policy).unwrap();

        assert_eq!(chosen.id, "item-2"); // Most recent
    }

    #[tokio::test]
    async fn test_resolve_highest_priority() {
        let resolver = GenericConflictResolver::<TestItem>::new();
        let low_priority = fixture("item-1", "resource-a", 2000, 1);
        let high_priority = fixture("item-2", "resource-a", 1000, 5);

        let policy = HighestAttributeWinsPolicy::new("priority");
        let chosen = resolver
            .resolve(vec![low_priority, high_priority], &policy)
            .unwrap();

        assert_eq!(chosen.id, "item-2"); // Highest priority
    }

    #[tokio::test]
    async fn test_unregister() {
        let resolver = GenericConflictResolver::<TestItem>::new();
        let active = fixture("item-1", "resource-a", 1000, 1);

        resolver.register(&active).await.unwrap();
        assert_eq!(resolver.active_count().await, 1);

        resolver.unregister("item-1").await.unwrap();
        assert_eq!(resolver.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_get_by_key() {
        let resolver = GenericConflictResolver::<TestItem>::new();

        for entry in [
            fixture("item-1", "resource-a", 1000, 1),
            fixture("item-2", "resource-a", 1001, 2),
        ] {
            resolver.register(&entry).await.unwrap();
        }

        let indexed = resolver.get_by_key("resource-a").await;
        assert_eq!(indexed.len(), 2);
    }
}