1use super::{Conflict, ConflictType, ThreeWayMerge};
4use crate::error::Result;
5use crate::types::Memory;
6use chrono::{DateTime, Utc};
7use rusqlite::Connection;
8use serde::{Deserialize, Serialize};
9use std::collections::VecDeque;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum ResolutionStrategy {
14 KeepLocal,
16 KeepRemote,
18 ThreeWayMerge,
20 KeepBoth,
22 TakeNewer,
24 TakeLonger,
26 CustomMerge,
28 AutoMerge,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct Resolution {
35 pub strategy: ResolutionStrategy,
37 pub resolved_memory: Memory,
39 pub resolved_at: DateTime<Utc>,
41 pub resolved_by: String,
43 pub notes: Option<String>,
45}
46
47impl Resolution {
48 pub fn new(
50 strategy: ResolutionStrategy,
51 memory: Memory,
52 resolved_by: impl Into<String>,
53 ) -> Self {
54 Self {
55 strategy,
56 resolved_memory: memory,
57 resolved_at: Utc::now(),
58 resolved_by: resolved_by.into(),
59 notes: None,
60 }
61 }
62
63 pub fn with_notes(mut self, notes: impl Into<String>) -> Self {
65 self.notes = Some(notes.into());
66 self
67 }
68}
69
70pub struct ConflictResolver {
72 merger: ThreeWayMerge,
73}
74
75impl Default for ConflictResolver {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl ConflictResolver {
82 pub fn new() -> Self {
84 Self {
85 merger: ThreeWayMerge::new(),
86 }
87 }
88
89 pub fn resolve(&self, conflict: &Conflict, strategy: ResolutionStrategy) -> Result<Resolution> {
91 let resolved_memory = match strategy {
92 ResolutionStrategy::KeepLocal => conflict.local.memory.clone(),
93 ResolutionStrategy::KeepRemote => conflict.remote.memory.clone(),
94 ResolutionStrategy::ThreeWayMerge => self.three_way_merge(conflict)?,
95 ResolutionStrategy::KeepBoth => {
96 conflict.local.memory.clone()
98 }
99 ResolutionStrategy::TakeNewer => {
100 if conflict.local.created_at > conflict.remote.created_at {
101 conflict.local.memory.clone()
102 } else {
103 conflict.remote.memory.clone()
104 }
105 }
106 ResolutionStrategy::TakeLonger => {
107 if conflict.local.memory.content.len() >= conflict.remote.memory.content.len() {
108 conflict.local.memory.clone()
109 } else {
110 conflict.remote.memory.clone()
111 }
112 }
113 ResolutionStrategy::AutoMerge => self.auto_merge(conflict)?,
114 ResolutionStrategy::CustomMerge => {
115 conflict.local.memory.clone()
117 }
118 };
119
120 Ok(Resolution::new(strategy, resolved_memory, "system"))
121 }
122
123 fn three_way_merge(&self, conflict: &Conflict) -> Result<Memory> {
125 let base_content = conflict
126 .base
127 .as_ref()
128 .map(|b| b.memory.content.as_str())
129 .unwrap_or("");
130
131 let merge_result = self.merger.merge(
132 base_content,
133 &conflict.local.memory.content,
134 &conflict.remote.memory.content,
135 );
136
137 let mut result = conflict.local.memory.clone();
138 result.content = merge_result.content;
139 result.updated_at = Utc::now();
140
141 let base_tags: Vec<String> = conflict
143 .base
144 .as_ref()
145 .map(|b| b.memory.tags.clone())
146 .unwrap_or_default();
147
148 result.tags = self.merger.merge_tags(
149 &base_tags,
150 &conflict.local.memory.tags,
151 &conflict.remote.memory.tags,
152 );
153
154 let base_meta = conflict.base.as_ref().map(|b| &b.memory.metadata);
156 result.metadata = self.merger.merge_metadata_map(
157 base_meta,
158 &conflict.local.memory.metadata,
159 &conflict.remote.memory.metadata,
160 );
161
162 Ok(result)
163 }
164
165 fn auto_merge(&self, conflict: &Conflict) -> Result<Memory> {
167 match conflict.conflict_type {
168 ConflictType::MetadataOnly => {
169 let mut result = conflict.local.memory.clone();
170 let base_meta = conflict.base.as_ref().map(|b| &b.memory.metadata);
171 result.metadata = self.merger.merge_metadata_map(
172 base_meta,
173 &conflict.local.memory.metadata,
174 &conflict.remote.memory.metadata,
175 );
176 result.updated_at = Utc::now();
177 Ok(result)
178 }
179 ConflictType::TagsOnly => {
180 let mut result = conflict.local.memory.clone();
181 let base_tags: Vec<String> = conflict
182 .base
183 .as_ref()
184 .map(|b| b.memory.tags.clone())
185 .unwrap_or_default();
186 result.tags = self.merger.merge_tags(
187 &base_tags,
188 &conflict.local.memory.tags,
189 &conflict.remote.memory.tags,
190 );
191 result.updated_at = Utc::now();
192 Ok(result)
193 }
194 ConflictType::NonOverlapping => self.three_way_merge(conflict),
195 _ => {
196 self.three_way_merge(conflict)
198 }
199 }
200 }
201
202 pub fn suggest_strategy(&self, conflict: &Conflict) -> ResolutionStrategy {
204 match conflict.conflict_type {
205 ConflictType::MetadataOnly => ResolutionStrategy::AutoMerge,
206 ConflictType::TagsOnly => ResolutionStrategy::AutoMerge,
207 ConflictType::NonOverlapping => ResolutionStrategy::ThreeWayMerge,
208 ConflictType::ContentConflict => {
209 let local_len = conflict.local.memory.content.len();
211 let remote_len = conflict.remote.memory.content.len();
212
213 if local_len > remote_len * 2 {
214 ResolutionStrategy::KeepLocal
215 } else if remote_len > local_len * 2 {
216 ResolutionStrategy::KeepRemote
217 } else {
218 ResolutionStrategy::ThreeWayMerge
219 }
220 }
221 ConflictType::DeleteModify => ResolutionStrategy::TakeNewer,
222 ConflictType::CreateCreate => ResolutionStrategy::KeepBoth,
223 }
224 }
225}
226
227pub struct ConflictQueue {
229 conflicts: VecDeque<Conflict>,
231 max_size: usize,
233}
234
235impl Default for ConflictQueue {
236 fn default() -> Self {
237 Self::new(1000)
238 }
239}
240
241impl ConflictQueue {
242 pub fn new(max_size: usize) -> Self {
244 Self {
245 conflicts: VecDeque::new(),
246 max_size,
247 }
248 }
249
250 pub fn push(&mut self, conflict: Conflict) -> bool {
252 if self.conflicts.len() >= self.max_size {
253 return false;
254 }
255 self.conflicts.push_back(conflict);
256 true
257 }
258
259 pub fn pop(&mut self) -> Option<Conflict> {
261 self.conflicts.pop_front()
262 }
263
264 pub fn peek(&self) -> Option<&Conflict> {
266 self.conflicts.front()
267 }
268
269 pub fn get(&self, id: &str) -> Option<&Conflict> {
271 self.conflicts.iter().find(|c| c.id == id)
272 }
273
274 pub fn remove(&mut self, id: &str) -> Option<Conflict> {
276 let pos = self.conflicts.iter().position(|c| c.id == id)?;
277 self.conflicts.remove(pos)
278 }
279
280 pub fn len(&self) -> usize {
282 self.conflicts.len()
283 }
284
285 pub fn is_empty(&self) -> bool {
287 self.conflicts.is_empty()
288 }
289
290 pub fn all(&self) -> impl Iterator<Item = &Conflict> {
292 self.conflicts.iter()
293 }
294
295 pub fn by_memory_id(&self, memory_id: i64) -> Vec<&Conflict> {
297 self.conflicts
298 .iter()
299 .filter(|c| c.memory_id == memory_id)
300 .collect()
301 }
302
303 pub fn auto_resolvable(&self) -> Vec<&Conflict> {
305 self.conflicts
306 .iter()
307 .filter(|c| c.can_auto_resolve())
308 .collect()
309 }
310
311 pub fn clear(&mut self) {
313 self.conflicts.clear();
314 }
315}
316
317#[allow(dead_code)]
319pub fn init_conflict_tables(conn: &Connection) -> Result<()> {
320 conn.execute_batch(
321 r#"
322 CREATE TABLE IF NOT EXISTS conflicts (
323 id TEXT PRIMARY KEY,
324 memory_id INTEGER NOT NULL,
325 base_version TEXT,
326 local_version TEXT NOT NULL,
327 remote_version TEXT NOT NULL,
328 conflict_type TEXT NOT NULL,
329 detected_at TEXT NOT NULL,
330 resolved INTEGER NOT NULL DEFAULT 0,
331 resolution TEXT,
332 created_at TEXT NOT NULL DEFAULT (datetime('now'))
333 );
334
335 CREATE INDEX IF NOT EXISTS idx_conflicts_memory ON conflicts(memory_id);
336 CREATE INDEX IF NOT EXISTS idx_conflicts_resolved ON conflicts(resolved);
337 "#,
338 )?;
339 Ok(())
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sync::conflict::SyncMemoryVersion;
    use crate::types::{MemoryType, Visibility};
    use std::collections::HashMap;

    /// Builds a memory fixture; only `content` varies between callers.
    fn create_test_memory(content: &str) -> Memory {
        Memory {
            id: 1,
            content: content.to_string(),
            memory_type: MemoryType::Note,
            tags: vec!["test".to_string()],
            metadata: HashMap::new(),
            importance: 0.5,
            access_count: 0,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            last_accessed_at: None,
            owner_id: None,
            visibility: Visibility::Private,
            scope: crate::types::MemoryScope::Global,
            workspace: "default".to_string(),
            tier: crate::types::MemoryTier::Permanent,
            version: 1,
            has_embedding: false,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: crate::types::LifecycleState::Active,
        }
    }

    /// Builds a base-less content conflict for memory 1 from two content strings.
    fn create_conflict(local_content: &str, remote_content: &str) -> Conflict {
        let local_version = SyncMemoryVersion::new(create_test_memory(local_content), "local");
        let remote_version = SyncMemoryVersion::new(create_test_memory(remote_content), "remote");
        Conflict::new(
            1,
            None,
            local_version,
            remote_version,
            ConflictType::ContentConflict,
        )
    }

    #[test]
    fn test_resolve_keep_local() {
        let conflict = create_conflict("Local content", "Remote content");
        let resolver = ConflictResolver::new();

        let outcome = resolver
            .resolve(&conflict, ResolutionStrategy::KeepLocal)
            .unwrap();

        assert_eq!(outcome.resolved_memory.content, "Local content");
        assert_eq!(outcome.strategy, ResolutionStrategy::KeepLocal);
    }

    #[test]
    fn test_resolve_keep_remote() {
        let conflict = create_conflict("Local content", "Remote content");
        let resolver = ConflictResolver::new();

        let outcome = resolver
            .resolve(&conflict, ResolutionStrategy::KeepRemote)
            .unwrap();

        assert_eq!(outcome.resolved_memory.content, "Remote content");
    }

    #[test]
    fn test_resolve_take_longer() {
        let conflict = create_conflict("Short", "This is much longer content");
        let resolver = ConflictResolver::new();

        let outcome = resolver
            .resolve(&conflict, ResolutionStrategy::TakeLonger)
            .unwrap();

        // The remote side carries the longer content, so it must win.
        assert_eq!(
            outcome.resolved_memory.content,
            "This is much longer content"
        );
    }

    #[test]
    fn test_conflict_queue() {
        let first = create_conflict("A", "B");
        let second = create_conflict("C", "D");
        let first_id = first.id.clone();

        let mut queue = ConflictQueue::new(10);
        queue.push(first);
        queue.push(second);

        assert_eq!(queue.len(), 2);
        assert!(queue.get(&first_id).is_some());

        // FIFO order: the first conflict pushed is the first one popped.
        let popped = queue.pop().unwrap();
        assert_eq!(popped.id, first_id);
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn test_suggest_strategy() {
        let resolver = ConflictResolver::new();

        // Same content on both sides; only the value under metadata key "a" differs.
        let mut local_mem = create_test_memory("Same");
        local_mem
            .metadata
            .insert("a".to_string(), serde_json::json!(1));

        let mut remote_mem = create_test_memory("Same");
        remote_mem
            .metadata
            .insert("a".to_string(), serde_json::json!(2));

        let local_version = SyncMemoryVersion::new(local_mem, "local");
        let remote_version = SyncMemoryVersion::new(remote_mem, "remote");
        let conflict = Conflict::new(
            1,
            None,
            local_version,
            remote_version,
            ConflictType::MetadataOnly,
        );

        let suggested = resolver.suggest_strategy(&conflict);
        assert_eq!(suggested, ResolutionStrategy::AutoMerge);
    }
}