1use std::{cmp::Ordering, collections::HashMap};
2
3use uuid::Uuid;
4
5use crate::{MemoryRecord, MemoryScope};
6
/// Tuning knobs for `StoreManager`.
#[derive(Clone, Debug)]
pub struct StoreManagerConfig {
    /// Maximum number of merged candidates to return.
    /// `StoreManager::top_k` clamps a value of `0` up to `1`.
    pub top_k: usize,
}

impl Default for StoreManagerConfig {
    /// Default configuration: keep at most 8 candidates.
    fn default() -> Self {
        StoreManagerConfig { top_k: 8 }
    }
}
17
18#[derive(Debug, Clone)]
19pub struct StoreManager {
20 config: StoreManagerConfig,
21}
22
/// Identity of the caller a retrieval is performed for; checked against each record by
/// `StoreManager::scope_allows`.
#[derive(Clone, Debug)]
pub struct ScopeQuery {
    /// The caller's session; always present.
    pub session_id: String,
    /// The caller's workspace, if any. `None` never matches a record's workspace.
    pub workspace_id: Option<String>,
    /// The caller's user, if any. `None` never matches a record's user.
    pub user_id: Option<String>,
}
29
30#[derive(Debug, Clone)]
31pub struct RetrievalCandidate {
32 pub record: MemoryRecord,
33 pub score: f32,
34 pub similarity: f32,
35 pub source: String,
36 pub selected_from_scope: MemoryScope,
37 pub selection_reason: String,
38}
39
/// Counters reported by `StoreManager::merge_scoped_candidates`.
#[derive(Clone, Debug, Default)]
pub struct MergeStats {
    /// Candidates dropped as entity-conflict losers or cut by the `top_k` limit.
    pub filtered_candidates: usize,
    /// Number of candidates actually returned.
    pub kept_candidates: usize,
}
45
46impl StoreManager {
47 pub fn new(config: StoreManagerConfig) -> Self {
48 Self { config }
49 }
50
51 pub fn top_k(&self) -> usize {
52 self.config.top_k.max(1)
53 }
54
55 pub fn scope_allows(&self, scope: &ScopeQuery, record: &MemoryRecord) -> bool {
56 let session_match = record.session_id.as_deref() == Some(scope.session_id.as_str());
57 let workspace_match = scope
58 .workspace_id
59 .as_deref()
60 .is_some_and(|workspace_id| record.workspace_id.as_deref() == Some(workspace_id));
61 let user_match = scope
62 .user_id
63 .as_deref()
64 .is_some_and(|user_id| record.user_id.as_deref() == Some(user_id));
65 session_match || workspace_match || user_match || record.scope == MemoryScope::Global
66 }
67
68 pub fn merge_scoped_candidates(
69 &self,
70 candidates: Vec<RetrievalCandidate>,
71 ) -> (Vec<RetrievalCandidate>, MergeStats) {
72 let mut by_id = HashMap::<Uuid, RetrievalCandidate>::new();
73 for candidate in candidates {
74 by_id.entry(candidate.record.memory_id).or_insert(candidate);
75 }
76
77 let mut by_entity = HashMap::<String, Vec<RetrievalCandidate>>::new();
78 let mut kept = Vec::new();
79 let mut filtered = 0usize;
80
81 for candidate in by_id.into_values() {
82 if let Some(entity) = normalized_entity_key(&candidate.record) {
83 by_entity.entry(entity).or_default().push(candidate);
84 } else {
85 kept.push(mark_non_conflicting(candidate));
86 }
87 }
88
89 for (_, mut conflicts) in by_entity {
90 if conflicts.is_empty() {
91 continue;
92 }
93 conflicts.sort_by(compare_scope_then_score);
94 let mut winner = conflicts.remove(0);
95 if let Some(second) = conflicts.first() {
96 winner.selection_reason = selection_reason(second, &winner);
97 } else {
98 winner.selection_reason = "non_conflicting".to_string();
99 }
100 filtered = filtered.saturating_add(conflicts.len());
101 kept.push(winner);
102 }
103
104 apply_scope_precedence(&mut kept);
105 let top_k = self.top_k();
106 if kept.len() > top_k {
107 filtered = filtered.saturating_add(kept.len() - top_k);
108 kept.truncate(top_k);
109 }
110 let kept_count = kept.len();
111 (
112 kept,
113 MergeStats {
114 filtered_candidates: filtered,
115 kept_candidates: kept_count,
116 },
117 )
118 }
119
120 pub fn scope_rank(scope: &MemoryScope) -> usize {
121 scope_rank(scope)
122 }
123}
124
125fn apply_scope_precedence(items: &mut [RetrievalCandidate]) {
126 items.sort_by(compare_scope_then_score);
127}
128
129fn compare_scope_then_score(left: &RetrievalCandidate, right: &RetrievalCandidate) -> Ordering {
130 let left_rank = scope_rank(&left.record.scope);
131 let right_rank = scope_rank(&right.record.scope);
132 left_rank
133 .cmp(&right_rank)
134 .then_with(|| {
135 right
136 .score
137 .partial_cmp(&left.score)
138 .unwrap_or(Ordering::Equal)
139 })
140 .then_with(|| {
141 left.record
142 .memory_id
143 .as_u128()
144 .cmp(&right.record.memory_id.as_u128())
145 })
146}
147
148fn selection_reason(existing: &RetrievalCandidate, replacement: &RetrievalCandidate) -> String {
149 let existing_rank = scope_rank(&existing.record.scope);
150 let replacement_rank = scope_rank(&replacement.record.scope);
151 if replacement_rank < existing_rank {
152 "higher_precedence".to_string()
153 } else {
154 "higher_score".to_string()
155 }
156}
157
158fn mark_non_conflicting(mut candidate: RetrievalCandidate) -> RetrievalCandidate {
159 candidate.selection_reason = "non_conflicting".to_string();
160 candidate
161}
162
163fn scope_rank(scope: &MemoryScope) -> usize {
164 match scope {
165 MemoryScope::Workspace => 0,
166 MemoryScope::Session => 1,
167 MemoryScope::User => 2,
168 MemoryScope::Global => 3,
169 }
170}
171
172fn normalized_entity_key(record: &MemoryRecord) -> Option<String> {
173 if let Some(entity) = record
174 .json_fields
175 .get("entity")
176 .and_then(|value| value.as_str())
177 .map(str::trim)
178 .filter(|v| !v.is_empty())
179 {
180 return Some(normalize_entity_text(entity));
181 }
182
183 extract_entity_from_text(&record.text).map(normalize_entity_text)
184}
185
/// Heuristically extracts the subject of a statement such as `"owner is team atlas"` or
/// `"timezone: UTC"`.
///
/// The text is trimmed, then split at the first `" is "`. If that yields nothing
/// non-blank, it is split at the first `':'`. The trimmed left-hand side is returned when
/// non-empty. Returns `None` when neither delimiter produces a non-blank subject,
/// including for empty or whitespace-only input.
fn extract_entity_from_text(text: &str) -> Option<&str> {
    // Checked in priority order: " is " first, then ':'.
    const DELIMITERS: [&str; 2] = [" is ", ":"];

    let body = text.trim();
    DELIMITERS.iter().find_map(|delimiter| {
        let (head, _) = body.split_once(*delimiter)?;
        let head = head.trim();
        (!head.is_empty()).then_some(head)
    })
}
205
/// Canonicalizes an entity label into a grouping key.
///
/// Letters and digits from any script are lowercased and kept. Every run of other
/// characters (whitespace, punctuation, symbols) collapses into a single space. Leading
/// and trailing separators are dropped. Returns an empty string when `value` contains
/// no alphanumeric characters.
///
/// This is Unicode-aware on purpose. An ASCII-only filter erases non-Latin names (both
/// `"名前"` and `"場所"` became `""`) and truncates accented ones (`"Café"` became `"caf"`).
/// Unrelated entities would then collide on the same key.
fn normalize_entity_text(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    // A separator is only emitted once the next alphanumeric arrives, which avoids
    // leading/trailing spaces without a final trim.
    let mut pending_separator = false;
    for c in value.chars() {
        if c.is_alphanumeric() {
            if pending_separator && !out.is_empty() {
                out.push(' ');
            }
            pending_separator = false;
            out.extend(c.to_lowercase());
        } else {
            pending_separator = true;
        }
    }
    out
}
220
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::MemoryType;

    use super::*;

    /// Builds a `Policy` record. When `entity` is given it is stored as
    /// `{"entity": ...}` in `json_fields`; otherwise `json_fields` is JSON `null`.
    fn make_record(id: u128, scope: MemoryScope, text: &str, entity: Option<&str>) -> MemoryRecord {
        let mut built = MemoryRecord::new(
            Uuid::from_u128(id),
            scope,
            MemoryType::Policy,
            text.to_string(),
        );
        built.json_fields = match entity {
            Some(name) => json!({ "entity": name }),
            None => serde_json::Value::Null,
        };
        built
    }

    /// Wraps `record` in a candidate with the given score and fixed test metadata.
    fn make_candidate(record: MemoryRecord, score: f32) -> RetrievalCandidate {
        let selected_from_scope = record.scope.clone();
        RetrievalCandidate {
            record,
            score,
            similarity: 0.8,
            source: "test".to_string(),
            selected_from_scope,
            selection_reason: "non_conflicting".to_string(),
        }
    }

    #[test]
    fn merge_prefers_workspace_over_global_for_same_entity() {
        let manager = StoreManager::new(StoreManagerConfig { top_k: 8 });
        let global_owner = make_candidate(
            make_record(1, MemoryScope::Global, "owner is team atlas", Some("owner")),
            0.95,
        );
        let workspace_owner = make_candidate(
            make_record(2, MemoryScope::Workspace, "owner is team zeus", Some("owner")),
            0.80,
        );

        let (merged, stats) = manager.merge_scoped_candidates(vec![global_owner, workspace_owner]);

        assert_eq!(merged.len(), 1);
        let winner = &merged[0];
        assert_eq!(winner.record.memory_id, Uuid::from_u128(2));
        assert_eq!(winner.selection_reason, "higher_precedence");
        assert_eq!(stats.filtered_candidates, 1);
    }

    #[test]
    fn merge_uses_uuid_tiebreaker_deterministically() {
        let manager = StoreManager::new(StoreManagerConfig { top_k: 8 });
        let first = make_candidate(make_record(1, MemoryScope::Session, "a", None), 1.0);
        let second = make_candidate(make_record(2, MemoryScope::Session, "b", None), 1.0);

        let (merged, _) = manager.merge_scoped_candidates(vec![second, first]);
        let ids: Vec<Uuid> = merged.iter().map(|c| c.record.memory_id).collect();

        assert_eq!(ids, vec![Uuid::from_u128(1), Uuid::from_u128(2)]);
    }

    #[test]
    fn scope_filter_blocks_cross_workspace_leakage() {
        let manager = StoreManager::new(StoreManagerConfig { top_k: 8 });
        let query = ScopeQuery {
            session_id: "s1".to_string(),
            workspace_id: Some("wa".to_string()),
            user_id: Some("u1".to_string()),
        };

        let mut foreign_workspace =
            make_record(11, MemoryScope::Workspace, "owner is team zeus", Some("owner"));
        foreign_workspace.workspace_id = Some("wb".to_string());
        assert!(!manager.scope_allows(&query, &foreign_workspace));

        let mut global = make_record(12, MemoryScope::Global, "timezone is UTC", Some("timezone"));
        global.workspace_id = None;
        assert!(manager.scope_allows(&query, &global));
    }
}
286
287