1use serde::{Deserialize, Serialize};
4
5use crate::types::{AgentId, MemoryId, Timestamp};
6
/// One agent's version of a memory's content, as examined during conflict detection
/// and resolution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConflictVersion {
    /// Agent that wrote this version; versions from the same agent never conflict.
    pub agent_id: AgentId,
    /// Content written by the agent.
    pub content: String,
    /// Agent-reported confidence in `content`, compared when resolving by highest
    /// confidence. NOTE(review): presumably in `0.0..=1.0` (manual resolutions use `1.0`)
    /// — confirm with producers.
    pub confidence: f32,
    /// When this version was written. Gaps between timestamps are compared against
    /// `CONFLICT_WINDOW_US`, which suggests microseconds — confirm.
    pub timestamp: Timestamp,
}
19
/// Strategy for collapsing the versions of a [`Conflict`] into a single version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Resolution {
    /// Keep the version with the latest timestamp.
    KeepLatest,
    /// Keep the version with the highest confidence.
    KeepHighestConfidence,
    /// Replace all versions with the supplied merged content.
    Merge(String),
    /// Replace all versions with the supplied, manually chosen content.
    Manual(String),
}
28
/// A set of versions of one memory written by different agents close together in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conflict {
    /// Memory whose versions disagree.
    pub memory_id: MemoryId,
    /// The versions taking part in the conflict.
    pub versions: Vec<ConflictVersion>,
    /// How the conflict was (or will be) resolved; `None` while unresolved.
    pub resolution: Option<Resolution>,
}
39
/// Detects conflicts between versions of a memory and resolves them.
///
/// This is a unit struct carrying no configuration, so every instance behaves the same.
#[derive(Debug, Default)]
pub struct ConflictResolver;
43
44const CONFLICT_WINDOW_US: Timestamp = 1_000_000; impl ConflictResolver {
48 pub fn new() -> Self {
50 Self
51 }
52
53 pub fn detect_conflict(
55 &self,
56 memory_id: MemoryId,
57 versions: &[ConflictVersion],
58 ) -> Option<Conflict> {
59 if versions.len() < 2 {
60 return None;
61 }
62
63 let mut dominated = vec![false; versions.len()];
66 for i in 0..versions.len() {
67 for j in (i + 1)..versions.len() {
68 let dt = versions[i].timestamp.abs_diff(versions[j].timestamp);
69 if dt <= CONFLICT_WINDOW_US && versions[i].agent_id != versions[j].agent_id {
70 dominated[i] = true;
71 dominated[j] = true;
72 }
73 }
74 }
75
76 let conflicting: Vec<ConflictVersion> = versions
77 .iter()
78 .zip(dominated.iter())
79 .filter(|&(_, d)| *d)
80 .map(|(v, _)| v.clone())
81 .collect();
82
83 if conflicting.len() >= 2 {
84 Some(Conflict {
85 memory_id,
86 versions: conflicting,
87 resolution: None,
88 })
89 } else {
90 None
91 }
92 }
93
94 pub fn auto_resolve(&self, conflict: &Conflict, strategy: Resolution) -> ConflictVersion {
96 match &strategy {
97 Resolution::KeepLatest => self.resolve_keep_latest(conflict),
98 Resolution::KeepHighestConfidence => self.resolve_keep_highest_confidence(conflict),
99 Resolution::Merge(merged) => ConflictVersion {
100 agent_id: conflict.versions[0].agent_id,
101 content: merged.clone(),
102 confidence: conflict
103 .versions
104 .iter()
105 .map(|v| v.confidence)
106 .fold(0.0_f32, f32::max),
107 timestamp: conflict
108 .versions
109 .iter()
110 .map(|v| v.timestamp)
111 .max()
112 .unwrap_or(0),
113 },
114 Resolution::Manual(text) => ConflictVersion {
115 agent_id: conflict.versions[0].agent_id,
116 content: text.clone(),
117 confidence: 1.0,
118 timestamp: conflict
119 .versions
120 .iter()
121 .map(|v| v.timestamp)
122 .max()
123 .unwrap_or(0),
124 },
125 }
126 }
127
128 pub fn resolve_keep_latest(&self, conflict: &Conflict) -> ConflictVersion {
130 conflict
131 .versions
132 .iter()
133 .max_by_key(|v| v.timestamp)
134 .cloned()
135 .expect("conflict must have at least one version")
136 }
137
138 pub fn resolve_keep_highest_confidence(&self, conflict: &Conflict) -> ConflictVersion {
140 conflict
141 .versions
142 .iter()
143 .max_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap())
144 .cloned()
145 .expect("conflict must have at least one version")
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`ConflictVersion`] from its parts.
    fn make_version(
        agent: AgentId,
        content: &str,
        confidence: f32,
        ts: Timestamp,
    ) -> ConflictVersion {
        ConflictVersion {
            agent_id: agent,
            content: content.to_string(),
            confidence,
            timestamp: ts,
        }
    }

    /// Wraps `versions` in an unresolved [`Conflict`] for a fresh memory id.
    fn make_conflict(versions: Vec<ConflictVersion>) -> Conflict {
        Conflict {
            memory_id: MemoryId::new(),
            versions,
            resolution: None,
        }
    }

    #[test]
    fn no_conflict_single_version() {
        let r = ConflictResolver::new();
        let mid = MemoryId::new();
        let v = make_version(AgentId::new(), "a", 0.9, 100);
        assert!(r.detect_conflict(mid, &[v]).is_none());
    }

    #[test]
    fn no_conflict_same_agent() {
        let r = ConflictResolver::new();
        let mid = MemoryId::new();
        let a = AgentId::new();
        let v1 = make_version(a, "a", 0.9, 100);
        let v2 = make_version(a, "b", 0.8, 200);
        assert!(r.detect_conflict(mid, &[v1, v2]).is_none());
    }

    #[test]
    fn no_conflict_outside_window() {
        let r = ConflictResolver::new();
        let v1 = make_version(AgentId::new(), "a", 0.9, 0);
        let v2 = make_version(AgentId::new(), "b", 0.8, CONFLICT_WINDOW_US + 1);
        assert!(r.detect_conflict(MemoryId::new(), &[v1, v2]).is_none());
    }

    #[test]
    fn detect_conflict_different_agents() {
        let r = ConflictResolver::new();
        let mid = MemoryId::new();
        let a1 = AgentId::new();
        let a2 = AgentId::new();
        let v1 = make_version(a1, "v1", 0.8, 1_000_000);
        let v2 = make_version(a2, "v2", 0.9, 1_500_000);
        let conflict = r.detect_conflict(mid, &[v1, v2]);
        assert!(conflict.is_some());
        assert_eq!(conflict.unwrap().versions.len(), 2);
    }

    #[test]
    fn detect_conflict_at_window_boundary() {
        // The window is inclusive: a gap of exactly CONFLICT_WINDOW_US still conflicts.
        let r = ConflictResolver::new();
        let v1 = make_version(AgentId::new(), "a", 0.9, 0);
        let v2 = make_version(AgentId::new(), "b", 0.8, CONFLICT_WINDOW_US);
        let conflict = r
            .detect_conflict(MemoryId::new(), &[v1, v2])
            .expect("a gap equal to the window conflicts");
        assert_eq!(conflict.versions.len(), 2);
    }

    #[test]
    fn detect_conflict_keeps_only_involved_versions() {
        let r = ConflictResolver::new();
        let versions = [
            make_version(AgentId::new(), "a", 0.9, 0),
            make_version(AgentId::new(), "b", 0.8, CONFLICT_WINDOW_US / 2),
            make_version(AgentId::new(), "far", 0.5, CONFLICT_WINDOW_US * 10),
        ];
        let conflict = r
            .detect_conflict(MemoryId::new(), &versions)
            .expect("a and b conflict");
        let contents: Vec<&str> = conflict.versions.iter().map(|v| v.content.as_str()).collect();
        assert_eq!(contents, ["a", "b"]);
        assert!(conflict.resolution.is_none());
    }

    #[test]
    fn resolve_keep_latest() {
        let r = ConflictResolver::new();
        let a1 = AgentId::new();
        let a2 = AgentId::new();
        let conflict = make_conflict(vec![
            make_version(a1, "old", 0.9, 100),
            make_version(a2, "new", 0.5, 200),
        ]);
        let winner = r.auto_resolve(&conflict, Resolution::KeepLatest);
        assert_eq!(winner.content, "new");
    }

    #[test]
    fn resolve_keep_highest_confidence() {
        let r = ConflictResolver::new();
        let a1 = AgentId::new();
        let a2 = AgentId::new();
        let conflict = make_conflict(vec![
            make_version(a1, "confident", 0.95, 100),
            make_version(a2, "unsure", 0.3, 200),
        ]);
        let winner = r.auto_resolve(&conflict, Resolution::KeepHighestConfidence);
        assert_eq!(winner.content, "confident");
    }

    #[test]
    fn resolve_merge() {
        let r = ConflictResolver::new();
        let a1 = AgentId::new();
        let a2 = AgentId::new();
        let conflict = make_conflict(vec![
            make_version(a1, "x", 0.4, 100),
            make_version(a2, "y", 0.7, 300),
        ]);
        let winner = r.auto_resolve(&conflict, Resolution::Merge("xy".to_string()));
        assert_eq!(winner.content, "xy");
        assert_eq!(winner.agent_id, a1);
        assert_eq!(winner.confidence, 0.7);
        assert_eq!(winner.timestamp, 300);
    }

    #[test]
    fn resolve_manual() {
        let r = ConflictResolver::new();
        let a1 = AgentId::new();
        let a2 = AgentId::new();
        let conflict = make_conflict(vec![
            make_version(a1, "x", 0.4, 100),
            make_version(a2, "y", 0.7, 300),
        ]);
        let winner = r.auto_resolve(&conflict, Resolution::Manual("chosen".to_string()));
        assert_eq!(winner.content, "chosen");
        assert_eq!(winner.agent_id, a1);
        assert_eq!(winner.confidence, 1.0);
        assert_eq!(winner.timestamp, 300);
    }
}