//! `locus_sdk/application/memory_recall.rs`: memory recall application service.

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use anyhow::Result;
5use locus_core_rs::ContextQueryService;
6use locus_core_rs::domain::contracts::NodeStore;
7use locus_core_rs::domain::models::{AvecState, PsiRange, SttpNode};
8
9use crate::application::memory_filters::{build_session_filter, node_matches_common_filters};
10use crate::domain::memory::{
11    FallbackPolicy, MemoryRecallRequest, MemoryRecallResult, RetrievalPath, clamp_limit,
12};
13
14pub struct MemoryRecallService {
15    context_query: ContextQueryService,
16}
17
18impl MemoryRecallService {
19    /// Create a recall service backed by the core resonance query pipeline.
20    pub fn new(store: Arc<dyn NodeStore>) -> Self {
21        Self {
22            context_query: ContextQueryService::new(store),
23        }
24    }
25
26    /// Retrieve context nodes using resonance or hybrid scoring,
27    /// with optional lexical fallback when configured.
28    pub async fn execute(&self, request: &MemoryRecallRequest) -> Result<MemoryRecallResult> {
29        let limit = clamp_limit(request.page.limit);
30        let expanded_limit = (limit.saturating_mul(5)).clamp(1, 200);
31
32        let current = request.current_avec.unwrap_or_else(AvecState::zero);
33        let session_scope = request
34            .scope
35            .session_ids
36            .as_deref()
37            .filter(|sessions| sessions.len() == 1)
38            .and_then(|sessions| sessions.first().map(String::as_str));
39        let session_filter = build_session_filter(&request.scope);
40
41        let mut path = if request.query_embedding.is_some() {
42            RetrievalPath::Hybrid
43        } else {
44            RetrievalPath::ResonanceOnly
45        };
46
47        let primary = if let Some(query_embedding) = request.query_embedding.as_deref() {
48            self.context_query
49                .get_context_hybrid_scoped_filtered_async(
50                    session_scope,
51                    current.stability,
52                    current.friction,
53                    current.logic,
54                    current.autonomy,
55                    request.scope.from_utc,
56                    request.scope.to_utc,
57                    request.scope.tiers.as_deref(),
58                    Some(query_embedding),
59                    request.scoring.alpha,
60                    request.scoring.beta,
61                    limit,
62                )
63                .await
64        } else {
65            self.context_query
66                .get_context_scoped_filtered_async(
67                    session_scope,
68                    current.stability,
69                    current.friction,
70                    current.logic,
71                    current.autonomy,
72                    request.scope.from_utc,
73                    request.scope.to_utc,
74                    request.scope.tiers.as_deref(),
75                    limit,
76                )
77                .await
78        };
79
80        let mut nodes = filter_nodes(primary.nodes, request, session_filter.as_ref());
81
82        if let Some(query_text) = request.query_text.as_deref() {
83            let need_fallback = match request.scoring.fallback_policy {
84                FallbackPolicy::Never => false,
85                FallbackPolicy::OnEmpty => nodes.is_empty(),
86                FallbackPolicy::Always => true,
87            };
88
89            if need_fallback {
90                let fallback_result = self
91                    .context_query
92                    .get_context_scoped_filtered_async(
93                        session_scope,
94                        current.stability,
95                        current.friction,
96                        current.logic,
97                        current.autonomy,
98                        request.scope.from_utc,
99                        request.scope.to_utc,
100                        request.scope.tiers.as_deref(),
101                        expanded_limit,
102                    )
103                    .await;
104
105                let lexical = lexical_filter(
106                    filter_nodes(fallback_result.nodes, request, session_filter.as_ref()),
107                    query_text,
108                );
109
110                if request.scoring.fallback_policy == FallbackPolicy::Always && !nodes.is_empty() {
111                    nodes = merge_unique(nodes, lexical);
112                } else {
113                    nodes = lexical;
114                }
115
116                path = RetrievalPath::LexicalFallback;
117            }
118        }
119
120        let has_more = nodes.len() > limit;
121        nodes.truncate(limit);
122
123        let next_cursor = nodes
124            .last()
125            .map(|node| format!("{}|{}", node.updated_at.to_rfc3339(), node.sync_key));
126
127        let psi_range = psi_range_from_nodes(&nodes);
128
129        Ok(MemoryRecallResult {
130            retrieved: nodes.len(),
131            nodes,
132            psi_range,
133            retrieval_path: path,
134            has_more,
135            next_cursor,
136        })
137    }
138}
139
140fn filter_nodes(
141    nodes: Vec<SttpNode>,
142    request: &MemoryRecallRequest,
143    session_filter: Option<&HashSet<String>>,
144) -> Vec<SttpNode> {
145    nodes.into_iter()
146        .filter(|node| {
147            node_matches_common_filters(node, &request.scope, &request.filter, session_filter)
148        })
149        .collect()
150}
151
152fn lexical_filter(nodes: Vec<SttpNode>, query_text: &str) -> Vec<SttpNode> {
153    let needle = query_text.trim().to_ascii_lowercase();
154    if needle.is_empty() {
155        return nodes;
156    }
157
158    let mut scored = nodes
159        .into_iter()
160        .filter_map(|node| {
161            let summary = node
162                .context_summary
163                .as_deref()
164                .unwrap_or_default()
165                .to_ascii_lowercase();
166            let session = node.session_id.to_ascii_lowercase();
167            let raw = node.raw.to_ascii_lowercase();
168
169            let mut score = 0usize;
170            if summary.contains(&needle) {
171                score += 3;
172            }
173            if session.contains(&needle) {
174                score += 2;
175            }
176            if raw.contains(&needle) {
177                score += 1;
178            }
179
180            if score > 0 {
181                Some((score, node.timestamp, node))
182            } else {
183                None
184            }
185        })
186        .collect::<Vec<_>>();
187
188    scored.sort_by(|left, right| right.0.cmp(&left.0).then_with(|| right.1.cmp(&left.1)));
189
190    scored.into_iter().map(|(_, _, node)| node).collect()
191}
192
193fn merge_unique(primary: Vec<SttpNode>, secondary: Vec<SttpNode>) -> Vec<SttpNode> {
194    let mut merged = Vec::with_capacity(primary.len() + secondary.len());
195    let mut seen = HashSet::new();
196
197    for node in primary.into_iter().chain(secondary.into_iter()) {
198        if seen.insert(node.sync_key.clone()) {
199            merged.push(node);
200        }
201    }
202
203    merged
204}
205
206fn psi_range_from_nodes(nodes: &[SttpNode]) -> PsiRange {
207    if nodes.is_empty() {
208        return PsiRange::default();
209    }
210
211    let (min, max, sum) = nodes
212        .iter()
213        .fold((f32::MAX, f32::MIN, 0.0_f32), |(min, max, sum), node| {
214            (min.min(node.psi), max.max(node.psi), sum + node.psi)
215        });
216
217    PsiRange {
218        min,
219        max,
220        average: sum / nodes.len() as f32,
221    }
222}