1use serde::{Deserialize, Serialize};
4
5use crate::types::{AgentId, MemoryId, Timestamp};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ConflictVersion {
10 pub agent_id: AgentId,
11 pub content: String,
12 pub confidence: f32,
13 pub timestamp: Timestamp,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub enum Resolution {
19 KeepLatest,
20 KeepHighestConfidence,
21 Merge(String),
22 Manual(String),
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct Conflict {
28 pub memory_id: MemoryId,
29 pub versions: Vec<ConflictVersion>,
30 pub resolution: Option<Resolution>,
31}
32
/// Stateless helper that detects conflicts between versions of a memory and
/// resolves them according to a [`Resolution`] strategy.
#[derive(Default, Debug)]
pub struct ConflictResolver;
36
37const CONFLICT_WINDOW_US: Timestamp = 1_000_000; impl ConflictResolver {
41 pub fn new() -> Self {
42 Self
43 }
44
45 pub fn detect_conflict(
47 &self,
48 memory_id: MemoryId,
49 versions: &[ConflictVersion],
50 ) -> Option<Conflict> {
51 if versions.len() < 2 {
52 return None;
53 }
54
55 let mut dominated = vec![false; versions.len()];
58 for i in 0..versions.len() {
59 for j in (i + 1)..versions.len() {
60 let dt = versions[i].timestamp.abs_diff(versions[j].timestamp);
61 if dt <= CONFLICT_WINDOW_US && versions[i].agent_id != versions[j].agent_id {
62 dominated[i] = true;
63 dominated[j] = true;
64 }
65 }
66 }
67
68 let conflicting: Vec<ConflictVersion> = versions
69 .iter()
70 .zip(dominated.iter())
71 .filter(|&(_, d)| *d)
72 .map(|(v, _)| v.clone())
73 .collect();
74
75 if conflicting.len() >= 2 {
76 Some(Conflict {
77 memory_id,
78 versions: conflicting,
79 resolution: None,
80 })
81 } else {
82 None
83 }
84 }
85
86 pub fn auto_resolve(&self, conflict: &Conflict, strategy: Resolution) -> ConflictVersion {
88 match &strategy {
89 Resolution::KeepLatest => self.resolve_keep_latest(conflict),
90 Resolution::KeepHighestConfidence => self.resolve_keep_highest_confidence(conflict),
91 Resolution::Merge(merged) => ConflictVersion {
92 agent_id: conflict.versions[0].agent_id,
93 content: merged.clone(),
94 confidence: conflict
95 .versions
96 .iter()
97 .map(|v| v.confidence)
98 .fold(0.0_f32, f32::max),
99 timestamp: conflict
100 .versions
101 .iter()
102 .map(|v| v.timestamp)
103 .max()
104 .unwrap_or(0),
105 },
106 Resolution::Manual(text) => ConflictVersion {
107 agent_id: conflict.versions[0].agent_id,
108 content: text.clone(),
109 confidence: 1.0,
110 timestamp: conflict
111 .versions
112 .iter()
113 .map(|v| v.timestamp)
114 .max()
115 .unwrap_or(0),
116 },
117 }
118 }
119
120 pub fn resolve_keep_latest(&self, conflict: &Conflict) -> ConflictVersion {
122 conflict
123 .versions
124 .iter()
125 .max_by_key(|v| v.timestamp)
126 .cloned()
127 .expect("conflict must have at least one version")
128 }
129
130 pub fn resolve_keep_highest_confidence(&self, conflict: &Conflict) -> ConflictVersion {
132 conflict
133 .versions
134 .iter()
135 .max_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap())
136 .cloned()
137 .expect("conflict must have at least one version")
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Builds a version with the given agent, content, confidence and timestamp.
    fn make_version(
        agent: AgentId,
        content: &str,
        confidence: f32,
        ts: Timestamp,
    ) -> ConflictVersion {
        ConflictVersion {
            agent_id: agent,
            content: content.to_string(),
            confidence,
            timestamp: ts,
        }
    }

    /// Wraps `versions` in an unresolved conflict for a fresh memory id.
    fn make_conflict(versions: Vec<ConflictVersion>) -> Conflict {
        Conflict {
            memory_id: Uuid::new_v4(),
            versions,
            resolution: None,
        }
    }

    #[test]
    fn no_conflict_single_version() {
        let r = ConflictResolver::new();
        let mid = Uuid::new_v4();
        let v = make_version(Uuid::new_v4(), "a", 0.9, 100);
        assert!(r.detect_conflict(mid, &[v]).is_none());
    }

    #[test]
    fn no_conflict_same_agent() {
        let r = ConflictResolver::new();
        let mid = Uuid::new_v4();
        let a = Uuid::new_v4();
        let v1 = make_version(a, "a", 0.9, 100);
        let v2 = make_version(a, "b", 0.8, 200);
        assert!(r.detect_conflict(mid, &[v1, v2]).is_none());
    }

    #[test]
    fn no_conflict_outside_window() {
        let r = ConflictResolver::new();
        let mid = Uuid::new_v4();
        let v1 = make_version(Uuid::new_v4(), "a", 0.9, 1_000_000);
        let v2 = make_version(
            Uuid::new_v4(),
            "b",
            0.8,
            1_000_000 + CONFLICT_WINDOW_US + 1,
        );
        assert!(r.detect_conflict(mid, &[v1, v2]).is_none());
    }

    #[test]
    fn detect_conflict_different_agents() {
        let r = ConflictResolver::new();
        let mid = Uuid::new_v4();
        let a1 = Uuid::new_v4();
        let a2 = Uuid::new_v4();
        let v1 = make_version(a1, "v1", 0.8, 1_000_000);
        let v2 = make_version(a2, "v2", 0.9, 1_500_000);
        let conflict = r.detect_conflict(mid, &[v1, v2]);
        assert!(conflict.is_some());
        assert_eq!(conflict.unwrap().versions.len(), 2);
    }

    #[test]
    fn detect_conflict_at_window_boundary() {
        // The window is inclusive: a gap exactly equal to it still conflicts.
        let r = ConflictResolver::new();
        let mid = Uuid::new_v4();
        let v1 = make_version(Uuid::new_v4(), "v1", 0.8, 1_000_000);
        let v2 = make_version(Uuid::new_v4(), "v2", 0.9, 1_000_000 + CONFLICT_WINDOW_US);
        let conflict = r
            .detect_conflict(mid, &[v1, v2])
            .expect("gap equal to the window must conflict");
        assert_eq!(conflict.versions.len(), 2);
        assert!(conflict.resolution.is_none());
    }

    #[test]
    fn detect_conflict_excludes_unrelated_versions() {
        let r = ConflictResolver::new();
        let mid = Uuid::new_v4();
        let near1 = make_version(Uuid::new_v4(), "near1", 0.8, 0);
        let near2 = make_version(Uuid::new_v4(), "near2", 0.9, 500_000);
        let far = make_version(Uuid::new_v4(), "far", 0.7, 10_000_000);
        let conflict = r
            .detect_conflict(mid, &[near1, near2, far])
            .expect("the two close versions must conflict");
        assert_eq!(conflict.versions.len(), 2);
        assert!(conflict.versions.iter().all(|v| v.content != "far"));
    }

    #[test]
    fn resolve_keep_latest() {
        let r = ConflictResolver::new();
        let conflict = make_conflict(vec![
            make_version(Uuid::new_v4(), "old", 0.9, 100),
            make_version(Uuid::new_v4(), "new", 0.5, 200),
        ]);
        let winner = r.auto_resolve(&conflict, Resolution::KeepLatest);
        assert_eq!(winner.content, "new");
    }

    #[test]
    fn resolve_keep_highest_confidence() {
        let r = ConflictResolver::new();
        let conflict = make_conflict(vec![
            make_version(Uuid::new_v4(), "confident", 0.95, 100),
            make_version(Uuid::new_v4(), "unsure", 0.3, 200),
        ]);
        let winner = r.auto_resolve(&conflict, Resolution::KeepHighestConfidence);
        assert_eq!(winner.content, "confident");
    }

    #[test]
    fn resolve_merge() {
        let r = ConflictResolver::new();
        let a1 = Uuid::new_v4();
        let conflict = make_conflict(vec![
            make_version(a1, "old", 0.9, 100),
            make_version(Uuid::new_v4(), "new", 0.5, 200),
        ]);
        let merged = r.auto_resolve(&conflict, Resolution::Merge("merged".to_string()));
        assert_eq!(merged.agent_id, a1);
        assert_eq!(merged.content, "merged");
        assert!((merged.confidence - 0.9).abs() < f32::EPSILON);
        assert_eq!(merged.timestamp, 200);
    }

    #[test]
    fn resolve_manual() {
        let r = ConflictResolver::new();
        let a1 = Uuid::new_v4();
        let conflict = make_conflict(vec![
            make_version(a1, "old", 0.9, 100),
            make_version(Uuid::new_v4(), "new", 0.5, 200),
        ]);
        let manual = r.auto_resolve(&conflict, Resolution::Manual("manual".to_string()));
        assert_eq!(manual.agent_id, a1);
        assert_eq!(manual.content, "manual");
        assert!((manual.confidence - 1.0).abs() < f32::EPSILON);
        assert_eq!(manual.timestamp, 200);
    }
}