peat_protocol/policy/
resolver.rs1use crate::error::Result;
6use crate::policy::conflictable::{ConflictResult, Conflictable};
7use crate::policy::policies::ResolutionPolicy;
8use std::collections::HashMap;
9use std::marker::PhantomData;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13pub struct GenericConflictResolver<T: Conflictable> {
36 active_items: Arc<RwLock<HashMap<String, Vec<T>>>>,
39 _phantom: PhantomData<T>,
40}
41
42impl<T: Conflictable> GenericConflictResolver<T> {
43 pub fn new() -> Self {
45 Self {
46 active_items: Arc::new(RwLock::new(HashMap::new())),
47 _phantom: PhantomData,
48 }
49 }
50
51 pub async fn check_conflict(&self, item: &T) -> ConflictResult<T> {
56 let keys = item.conflict_keys();
57 let items = self.active_items.read().await;
58
59 let mut conflicting = Vec::new();
60
61 for key in keys {
62 if let Some(existing) = items.get(&key) {
63 for existing_item in existing {
65 if !conflicting.iter().any(|c: &T| c.id() == existing_item.id()) {
66 conflicting.push(existing_item.clone());
67 }
68 }
69 }
70 }
71
72 if conflicting.is_empty() {
73 ConflictResult::NoConflict
74 } else {
75 ConflictResult::Conflict(conflicting)
76 }
77 }
78
79 pub fn resolve(&self, items: Vec<T>, policy: &dyn ResolutionPolicy<T>) -> Result<T> {
84 tracing::debug!(
85 "Resolving conflict between {} items using policy: {}",
86 items.len(),
87 policy.name()
88 );
89
90 policy.resolve(items)
91 }
92
93 pub async fn register(&self, item: &T) -> Result<()> {
97 let keys = item.conflict_keys();
98 let mut items = self.active_items.write().await;
99
100 for key in keys {
101 items.entry(key).or_default().push(item.clone());
102 }
103
104 Ok(())
105 }
106
107 pub async fn unregister(&self, item_id: &str) -> Result<()> {
112 let mut items = self.active_items.write().await;
113
114 for (_, item_list) in items.iter_mut() {
116 item_list.retain(|item| item.id() != item_id);
117 }
118
119 items.retain(|_, item_list| !item_list.is_empty());
121
122 Ok(())
123 }
124
125 pub async fn get_all_active(&self) -> Vec<T> {
127 let items = self.active_items.read().await;
128 let mut all_items = Vec::new();
129 let mut seen_ids = std::collections::HashSet::new();
130
131 for item_list in items.values() {
132 for item in item_list {
133 if seen_ids.insert(item.id()) {
134 all_items.push(item.clone());
135 }
136 }
137 }
138
139 all_items
140 }
141
142 pub async fn active_count(&self) -> usize {
144 self.get_all_active().await.len()
145 }
146
147 pub async fn get_by_key(&self, key: &str) -> Vec<T> {
149 self.active_items
150 .read()
151 .await
152 .get(key)
153 .cloned()
154 .unwrap_or_default()
155 }
156}
157
158impl<T: Conflictable> Default for GenericConflictResolver<T> {
159 fn default() -> Self {
160 Self::new()
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::conflictable::AttributeValue;
    use crate::policy::policies::{HighestAttributeWinsPolicy, LastWriteWinsPolicy};
    use std::collections::HashMap;

    /// Minimal `Conflictable` fixture: its only conflict key is `resource`,
    /// and `priority` is exposed as the `"priority"` attribute.
    #[derive(Clone, Debug, PartialEq)]
    struct TestItem {
        id: String,
        resource: String,
        timestamp: u64,
        priority: i64,
    }

    impl Conflictable for TestItem {
        fn id(&self) -> String {
            self.id.clone()
        }

        fn conflict_keys(&self) -> Vec<String> {
            vec![self.resource.clone()]
        }

        fn timestamp(&self) -> Option<u64> {
            Some(self.timestamp)
        }

        fn attributes(&self) -> HashMap<String, AttributeValue> {
            HashMap::from([("priority".to_string(), AttributeValue::Int(self.priority))])
        }
    }

    /// Builds a `TestItem` without repeating the struct literal in every test.
    fn make_item(id: &str, resource: &str, timestamp: u64, priority: i64) -> TestItem {
        TestItem {
            id: id.to_string(),
            resource: resource.to_string(),
            timestamp,
            priority,
        }
    }

    #[tokio::test]
    async fn test_no_conflict_different_resources() {
        let resolver = GenericConflictResolver::<TestItem>::new();

        let registered = make_item("item-1", "resource-a", 1000, 1);
        resolver.register(&registered).await.unwrap();

        let candidate = make_item("item-2", "resource-b", 1001, 2);
        let result = resolver.check_conflict(&candidate).await;

        assert!(!result.is_conflict());
    }

    #[tokio::test]
    async fn test_conflict_same_resource() {
        let resolver = GenericConflictResolver::<TestItem>::new();

        let registered = make_item("item-1", "resource-a", 1000, 1);
        resolver.register(&registered).await.unwrap();

        let candidate = make_item("item-2", "resource-a", 1001, 2);
        let result = resolver.check_conflict(&candidate).await;

        assert!(result.is_conflict());

        if let ConflictResult::Conflict(found) = result {
            assert_eq!(found.len(), 1);
            assert_eq!(found[0].id, "item-1");
        }
    }

    #[tokio::test]
    async fn test_resolve_last_write_wins() {
        let resolver = GenericConflictResolver::<TestItem>::new();

        let older = make_item("item-1", "resource-a", 1000, 1);
        let newer = make_item("item-2", "resource-a", 2000, 2);

        let policy = LastWriteWinsPolicy;
        let winner = resolver.resolve(vec![older, newer], &policy).unwrap();

        assert_eq!(winner.id, "item-2");
    }

    #[tokio::test]
    async fn test_resolve_highest_priority() {
        let resolver = GenericConflictResolver::<TestItem>::new();

        let low_priority = make_item("item-1", "resource-a", 2000, 1);
        let high_priority = make_item("item-2", "resource-a", 1000, 5);

        let policy = HighestAttributeWinsPolicy::new("priority");
        let winner = resolver
            .resolve(vec![low_priority, high_priority], &policy)
            .unwrap();

        assert_eq!(winner.id, "item-2");
    }

    #[tokio::test]
    async fn test_unregister() {
        let resolver = GenericConflictResolver::<TestItem>::new();

        let only_item = make_item("item-1", "resource-a", 1000, 1);

        resolver.register(&only_item).await.unwrap();
        assert_eq!(resolver.active_count().await, 1);

        resolver.unregister("item-1").await.unwrap();
        assert_eq!(resolver.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_get_by_key() {
        let resolver = GenericConflictResolver::<TestItem>::new();

        let first = make_item("item-1", "resource-a", 1000, 1);
        let second = make_item("item-2", "resource-a", 1001, 2);

        resolver.register(&first).await.unwrap();
        resolver.register(&second).await.unwrap();

        let stored = resolver.get_by_key("resource-a").await;
        assert_eq!(stored.len(), 2);
    }
}