Skip to main content

mentedb_core/
conflict.rs

1//! Conflict Resolution: detect and resolve concurrent-write conflicts.
2
3use serde::{Deserialize, Serialize};
4
5use crate::types::{AgentId, MemoryId, Timestamp};
6
7/// One agent's version of a memory that may be in conflict.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ConflictVersion {
10    /// The agent that produced this version.
11    pub agent_id: AgentId,
12    /// The memory content for this version.
13    pub content: String,
14    /// Confidence score assigned by the authoring agent.
15    pub confidence: f32,
16    /// When this version was created.
17    pub timestamp: Timestamp,
18}
19
20/// How a conflict was (or should be) resolved.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub enum Resolution {
23    KeepLatest,
24    KeepHighestConfidence,
25    Merge(String),
26    Manual(String),
27}
28
29/// A detected conflict on a single memory.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Conflict {
32    /// The memory that has conflicting versions.
33    pub memory_id: MemoryId,
34    /// The competing versions of this memory.
35    pub versions: Vec<ConflictVersion>,
36    /// The chosen resolution, if any.
37    pub resolution: Option<Resolution>,
38}
39
/// Stateless helper that detects and resolves write conflicts between agents.
#[derive(Default, Debug)]
pub struct ConflictResolver;
43
44/// Threshold in microseconds: versions within this window count as concurrent.
45const CONFLICT_WINDOW_US: Timestamp = 1_000_000; // 1 second
46
47impl ConflictResolver {
48    /// Creates a new conflict resolver.
49    pub fn new() -> Self {
50        Self
51    }
52
53    /// Detect a conflict: two or more versions by different agents within 1 second.
54    pub fn detect_conflict(
55        &self,
56        memory_id: MemoryId,
57        versions: &[ConflictVersion],
58    ) -> Option<Conflict> {
59        if versions.len() < 2 {
60            return None;
61        }
62
63        // Collect versions whose timestamps are within the conflict window of any other
64        // version written by a different agent.
65        let mut dominated = vec![false; versions.len()];
66        for i in 0..versions.len() {
67            for j in (i + 1)..versions.len() {
68                let dt = versions[i].timestamp.abs_diff(versions[j].timestamp);
69                if dt <= CONFLICT_WINDOW_US && versions[i].agent_id != versions[j].agent_id {
70                    dominated[i] = true;
71                    dominated[j] = true;
72                }
73            }
74        }
75
76        let conflicting: Vec<ConflictVersion> = versions
77            .iter()
78            .zip(dominated.iter())
79            .filter(|&(_, d)| *d)
80            .map(|(v, _)| v.clone())
81            .collect();
82
83        if conflicting.len() >= 2 {
84            Some(Conflict {
85                memory_id,
86                versions: conflicting,
87                resolution: None,
88            })
89        } else {
90            None
91        }
92    }
93
94    /// Resolve a conflict using the given strategy, returning the winning version.
95    pub fn auto_resolve(&self, conflict: &Conflict, strategy: Resolution) -> ConflictVersion {
96        match &strategy {
97            Resolution::KeepLatest => self.resolve_keep_latest(conflict),
98            Resolution::KeepHighestConfidence => self.resolve_keep_highest_confidence(conflict),
99            Resolution::Merge(merged) => ConflictVersion {
100                agent_id: conflict.versions[0].agent_id,
101                content: merged.clone(),
102                confidence: conflict
103                    .versions
104                    .iter()
105                    .map(|v| v.confidence)
106                    .fold(0.0_f32, f32::max),
107                timestamp: conflict
108                    .versions
109                    .iter()
110                    .map(|v| v.timestamp)
111                    .max()
112                    .unwrap_or(0),
113            },
114            Resolution::Manual(text) => ConflictVersion {
115                agent_id: conflict.versions[0].agent_id,
116                content: text.clone(),
117                confidence: 1.0,
118                timestamp: conflict
119                    .versions
120                    .iter()
121                    .map(|v| v.timestamp)
122                    .max()
123                    .unwrap_or(0),
124            },
125        }
126    }
127
128    /// Pick the version with the highest timestamp.
129    pub fn resolve_keep_latest(&self, conflict: &Conflict) -> ConflictVersion {
130        conflict
131            .versions
132            .iter()
133            .max_by_key(|v| v.timestamp)
134            .cloned()
135            .expect("conflict must have at least one version")
136    }
137
138    /// Pick the version with the highest confidence score.
139    pub fn resolve_keep_highest_confidence(&self, conflict: &Conflict) -> ConflictVersion {
140        conflict
141            .versions
142            .iter()
143            .max_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap())
144            .cloned()
145            .expect("conflict must have at least one version")
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ConflictVersion` from its four fields.
    fn make_version(
        agent: AgentId,
        content: &str,
        confidence: f32,
        ts: Timestamp,
    ) -> ConflictVersion {
        ConflictVersion {
            agent_id: agent,
            content: content.to_owned(),
            confidence,
            timestamp: ts,
        }
    }

    /// Wraps two versions in an unresolved `Conflict` on a fresh memory id.
    fn unresolved(first: ConflictVersion, second: ConflictVersion) -> Conflict {
        Conflict {
            memory_id: MemoryId::new(),
            versions: vec![first, second],
            resolution: None,
        }
    }

    /// A lone version can never be in conflict.
    #[test]
    fn no_conflict_single_version() {
        let resolver = ConflictResolver::new();
        let memory = MemoryId::new();
        let only = make_version(AgentId::new(), "a", 0.9, 100);
        let detected = resolver.detect_conflict(memory, &[only]);
        assert!(detected.is_none());
    }

    /// Versions written by the same agent never conflict, however close in time.
    #[test]
    fn no_conflict_same_agent() {
        let resolver = ConflictResolver::new();
        let memory = MemoryId::new();
        let author = AgentId::new();
        let history = [
            make_version(author, "a", 0.9, 100),
            make_version(author, "b", 0.8, 200),
        ];
        assert!(resolver.detect_conflict(memory, &history).is_none());
    }

    /// Two agents writing half a second apart produce a two-version conflict.
    #[test]
    fn detect_conflict_different_agents() {
        let resolver = ConflictResolver::new();
        let memory = MemoryId::new();
        let first_agent = AgentId::new();
        let second_agent = AgentId::new();
        let writes = [
            make_version(first_agent, "v1", 0.8, 1_000_000),
            make_version(second_agent, "v2", 0.9, 1_500_000),
        ];
        let detected = resolver.detect_conflict(memory, &writes);
        assert!(detected.is_some());
        assert_eq!(detected.map(|c| c.versions.len()), Some(2));
    }

    /// `KeepLatest` selects the version with the larger timestamp.
    #[test]
    fn resolve_keep_latest() {
        let resolver = ConflictResolver::new();
        let first_agent = AgentId::new();
        let second_agent = AgentId::new();
        let conflict = unresolved(
            make_version(first_agent, "old", 0.9, 100),
            make_version(second_agent, "new", 0.5, 200),
        );
        let winner = resolver.auto_resolve(&conflict, Resolution::KeepLatest);
        assert_eq!(winner.content, "new");
    }

    /// `KeepHighestConfidence` selects the version with the larger confidence.
    #[test]
    fn resolve_keep_highest_confidence() {
        let resolver = ConflictResolver::new();
        let first_agent = AgentId::new();
        let second_agent = AgentId::new();
        let conflict = unresolved(
            make_version(first_agent, "confident", 0.95, 100),
            make_version(second_agent, "unsure", 0.3, 200),
        );
        let winner = resolver.auto_resolve(&conflict, Resolution::KeepHighestConfidence);
        assert_eq!(winner.content, "confident");
    }
}
231}