Skip to main content

foxloom/
store_manager.rs

1use std::{cmp::Ordering, collections::HashMap};
2
3use uuid::Uuid;
4
5use crate::{MemoryRecord, MemoryScope};
6
/// Tuning knobs for [`StoreManager`].
#[derive(Clone, Debug)]
pub struct StoreManagerConfig {
    /// Maximum number of candidates kept after a merge. A value of `0` is clamped to `1`
    /// by [`StoreManager::top_k`].
    pub top_k: usize,
}

impl Default for StoreManagerConfig {
    /// Returns a configuration that keeps up to eight candidates.
    fn default() -> Self {
        StoreManagerConfig { top_k: 8 }
    }
}
17
18#[derive(Debug, Clone)]
19pub struct StoreManager {
20    config: StoreManagerConfig,
21}
22
/// Caller context that a memory record is checked against by `StoreManager::scope_allows`.
#[derive(Clone, Debug)]
pub struct ScopeQuery {
    /// Session the query originates from; always present.
    pub session_id: String,
    /// Workspace of the caller, if any.
    pub workspace_id: Option<String>,
    /// User issuing the query, if known.
    pub user_id: Option<String>,
}
29
30#[derive(Debug, Clone)]
31pub struct RetrievalCandidate {
32    pub record: MemoryRecord,
33    pub score: f32,
34    pub similarity: f32,
35    pub source: String,
36    pub selected_from_scope: MemoryScope,
37    pub selection_reason: String,
38}
39
/// Counters returned alongside the result of `StoreManager::merge_scoped_candidates`.
#[derive(Clone, Debug, Default)]
pub struct MergeStats {
    /// Candidates dropped as entity-conflict losers or by top-k truncation
    /// (duplicate `memory_id`s collapsed before conflict resolution are not counted).
    pub filtered_candidates: usize,
    /// Number of candidates in the returned list.
    pub kept_candidates: usize,
}
45
46impl StoreManager {
47    pub fn new(config: StoreManagerConfig) -> Self {
48        Self { config }
49    }
50
51    pub fn top_k(&self) -> usize {
52        self.config.top_k.max(1)
53    }
54
55    pub fn scope_allows(&self, scope: &ScopeQuery, record: &MemoryRecord) -> bool {
56        let session_match = record.session_id.as_deref() == Some(scope.session_id.as_str());
57        let workspace_match = scope
58            .workspace_id
59            .as_deref()
60            .is_some_and(|workspace_id| record.workspace_id.as_deref() == Some(workspace_id));
61        let user_match = scope
62            .user_id
63            .as_deref()
64            .is_some_and(|user_id| record.user_id.as_deref() == Some(user_id));
65        session_match || workspace_match || user_match || record.scope == MemoryScope::Global
66    }
67
68    pub fn merge_scoped_candidates(
69        &self,
70        candidates: Vec<RetrievalCandidate>,
71    ) -> (Vec<RetrievalCandidate>, MergeStats) {
72        let mut by_id = HashMap::<Uuid, RetrievalCandidate>::new();
73        for candidate in candidates {
74            by_id.entry(candidate.record.memory_id).or_insert(candidate);
75        }
76
77        let mut by_entity = HashMap::<String, Vec<RetrievalCandidate>>::new();
78        let mut kept = Vec::new();
79        let mut filtered = 0usize;
80
81        for candidate in by_id.into_values() {
82            if let Some(entity) = normalized_entity_key(&candidate.record) {
83                by_entity.entry(entity).or_default().push(candidate);
84            } else {
85                kept.push(mark_non_conflicting(candidate));
86            }
87        }
88
89        for (_, mut conflicts) in by_entity {
90            if conflicts.is_empty() {
91                continue;
92            }
93            conflicts.sort_by(compare_scope_then_score);
94            let mut winner = conflicts.remove(0);
95            if let Some(second) = conflicts.first() {
96                winner.selection_reason = selection_reason(second, &winner);
97            } else {
98                winner.selection_reason = "non_conflicting".to_string();
99            }
100            filtered = filtered.saturating_add(conflicts.len());
101            kept.push(winner);
102        }
103
104        apply_scope_precedence(&mut kept);
105        let top_k = self.top_k();
106        if kept.len() > top_k {
107            filtered = filtered.saturating_add(kept.len() - top_k);
108            kept.truncate(top_k);
109        }
110        let kept_count = kept.len();
111        (
112            kept,
113            MergeStats {
114                filtered_candidates: filtered,
115                kept_candidates: kept_count,
116            },
117        )
118    }
119
120    pub fn scope_rank(scope: &MemoryScope) -> usize {
121        scope_rank(scope)
122    }
123}
124
125fn apply_scope_precedence(items: &mut [RetrievalCandidate]) {
126    items.sort_by(compare_scope_then_score);
127}
128
129fn compare_scope_then_score(left: &RetrievalCandidate, right: &RetrievalCandidate) -> Ordering {
130    let left_rank = scope_rank(&left.record.scope);
131    let right_rank = scope_rank(&right.record.scope);
132    left_rank
133        .cmp(&right_rank)
134        .then_with(|| {
135            right
136                .score
137                .partial_cmp(&left.score)
138                .unwrap_or(Ordering::Equal)
139        })
140        .then_with(|| {
141            left.record
142                .memory_id
143                .as_u128()
144                .cmp(&right.record.memory_id.as_u128())
145        })
146}
147
148fn selection_reason(existing: &RetrievalCandidate, replacement: &RetrievalCandidate) -> String {
149    let existing_rank = scope_rank(&existing.record.scope);
150    let replacement_rank = scope_rank(&replacement.record.scope);
151    if replacement_rank < existing_rank {
152        "higher_precedence".to_string()
153    } else {
154        "higher_score".to_string()
155    }
156}
157
158fn mark_non_conflicting(mut candidate: RetrievalCandidate) -> RetrievalCandidate {
159    candidate.selection_reason = "non_conflicting".to_string();
160    candidate
161}
162
163fn scope_rank(scope: &MemoryScope) -> usize {
164    match scope {
165        MemoryScope::Workspace => 0,
166        MemoryScope::Session => 1,
167        MemoryScope::User => 2,
168        MemoryScope::Global => 3,
169    }
170}
171
172fn normalized_entity_key(record: &MemoryRecord) -> Option<String> {
173    if let Some(entity) = record
174        .json_fields
175        .get("entity")
176        .and_then(|value| value.as_str())
177        .map(str::trim)
178        .filter(|v| !v.is_empty())
179    {
180        return Some(normalize_entity_text(entity));
181    }
182
183    extract_entity_from_text(&record.text).map(normalize_entity_text)
184}
185
/// Heuristically extracts the subject of a statement from free text.
///
/// Looks first for `" is "` and then for `':'`, and returns the trimmed text before the
/// first delimiter that leaves a non-blank left-hand side. For example,
/// `"owner is team atlas"` gives `"owner"` and `"timezone: UTC"` gives `"timezone"`.
/// Returns `None` for blank text or when neither delimiter produces a subject.
fn extract_entity_from_text(text: &str) -> Option<&str> {
    let body = text.trim();
    [" is ", ":"].iter().copied().find_map(move |delimiter| {
        body.split_once(delimiter)
            .map(|(head, _)| head.trim())
            .filter(|head| !head.is_empty())
    })
}
205
/// Canonicalizes entity text so equivalent spellings compare equal.
///
/// Alphanumeric characters (Unicode-aware) are lowercased and kept. Every run of other
/// characters (whitespace, punctuation, symbols) collapses to a single space, and no
/// separator appears at the start or end. For example, `"  Owner -- Name "` becomes
/// `"owner name"` and `"Café"` becomes `"café"`. ASCII input normalizes exactly as with
/// an ASCII-only rule.
///
/// Non-ASCII letters and digits are kept rather than treated as separators. Treating
/// them as separators would collapse distinct entities such as `"東京"` and `"大阪"` to
/// the same empty key. No Unicode normalization (NFC/NFD) is applied, so decomposed
/// combining marks still act as separators.
fn normalize_entity_text(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    let mut pending_separator = false;
    for c in value.chars() {
        if c.is_alphanumeric() {
            // Emit one space per separator run, and only between kept characters.
            if pending_separator && !out.is_empty() {
                out.push(' ');
            }
            pending_separator = false;
            out.extend(c.to_lowercase());
        } else {
            pending_separator = true;
        }
    }
    out
}
220
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::MemoryType;

    use super::*;

    /// Builds a `Policy` record, optionally tagged with an explicit `"entity"` JSON field.
    fn record(id: u128, scope: MemoryScope, text: &str, entity: Option<&str>) -> MemoryRecord {
        let mut built =
            MemoryRecord::new(Uuid::from_u128(id), scope, MemoryType::Policy, text.to_string());
        built.json_fields = match entity {
            Some(name) => json!({ "entity": name }),
            None => serde_json::Value::Null,
        };
        built
    }

    /// Wraps `record` in a candidate with the given score and fixed test metadata.
    fn candidate(record: MemoryRecord, score: f32) -> RetrievalCandidate {
        let selected_from_scope = record.scope.clone();
        RetrievalCandidate {
            record,
            score,
            similarity: 0.8,
            source: "test".to_string(),
            selected_from_scope,
            selection_reason: "non_conflicting".to_string(),
        }
    }

    /// Manager with the default-sized result window used by every test.
    fn manager() -> StoreManager {
        StoreManager::new(StoreManagerConfig { top_k: 8 })
    }

    #[test]
    fn merge_prefers_workspace_over_global_for_same_entity() {
        let global = candidate(
            record(1, MemoryScope::Global, "owner is team atlas", Some("owner")),
            0.95,
        );
        let workspace = candidate(
            record(2, MemoryScope::Workspace, "owner is team zeus", Some("owner")),
            0.80,
        );

        let (merged, stats) = manager().merge_scoped_candidates(vec![global, workspace]);

        assert_eq!(merged.len(), 1);
        let winner = &merged[0];
        assert_eq!(winner.record.memory_id, Uuid::from_u128(2));
        assert_eq!(winner.selection_reason, "higher_precedence");
        assert_eq!(stats.filtered_candidates, 1);
    }

    #[test]
    fn merge_uses_uuid_tiebreaker_deterministically() {
        let first = candidate(record(1, MemoryScope::Session, "a", None), 1.0);
        let second = candidate(record(2, MemoryScope::Session, "b", None), 1.0);

        // Feed them in reverse id order to show the output does not follow input order.
        let (merged, _) = manager().merge_scoped_candidates(vec![second, first]);

        assert_eq!(merged[0].record.memory_id, Uuid::from_u128(1));
        assert_eq!(merged[1].record.memory_id, Uuid::from_u128(2));
    }

    #[test]
    fn scope_filter_blocks_cross_workspace_leakage() {
        let mgr = manager();
        let query = ScopeQuery {
            session_id: "s1".to_string(),
            workspace_id: Some("wa".to_string()),
            user_id: Some("u1".to_string()),
        };

        let mut foreign = record(11, MemoryScope::Workspace, "owner is team zeus", Some("owner"));
        foreign.workspace_id = Some("wb".to_string());
        assert!(!mgr.scope_allows(&query, &foreign));

        let mut global = record(12, MemoryScope::Global, "timezone is UTC", Some("timezone"));
        global.workspace_id = None;
        assert!(mgr.scope_allows(&query, &global));
    }
}
286
287