1use std::collections::HashSet;
2use std::sync::Arc;
3
4use anyhow::Result;
5use locus_core_rs::ContextQueryService;
6use locus_core_rs::domain::contracts::NodeStore;
7use locus_core_rs::domain::models::{AvecState, PsiRange, SttpNode};
8
9use crate::application::memory_filters::{build_session_filter, node_matches_common_filters};
10use crate::domain::memory::{
11 FallbackPolicy, MemoryRecallRequest, MemoryRecallResult, RetrievalPath, clamp_limit,
12};
13
/// Application service that answers [`MemoryRecallRequest`]s by querying stored nodes
/// through a [`ContextQueryService`].
pub struct MemoryRecallService {
    /// Query facade over the `NodeStore` supplied to `MemoryRecallService::new`.
    context_query: ContextQueryService,
}
17
18impl MemoryRecallService {
19 pub fn new(store: Arc<dyn NodeStore>) -> Self {
21 Self {
22 context_query: ContextQueryService::new(store),
23 }
24 }
25
26 pub async fn execute(&self, request: &MemoryRecallRequest) -> Result<MemoryRecallResult> {
29 let limit = clamp_limit(request.page.limit);
30 let expanded_limit = (limit.saturating_mul(5)).clamp(1, 200);
31
32 let current = request.current_avec.unwrap_or_else(AvecState::zero);
33 let session_scope = request
34 .scope
35 .session_ids
36 .as_deref()
37 .filter(|sessions| sessions.len() == 1)
38 .and_then(|sessions| sessions.first().map(String::as_str));
39 let session_filter = build_session_filter(&request.scope);
40
41 let mut path = if request.query_embedding.is_some() {
42 RetrievalPath::Hybrid
43 } else {
44 RetrievalPath::ResonanceOnly
45 };
46
47 let primary = if let Some(query_embedding) = request.query_embedding.as_deref() {
48 self.context_query
49 .get_context_hybrid_scoped_filtered_async(
50 session_scope,
51 current.stability,
52 current.friction,
53 current.logic,
54 current.autonomy,
55 request.scope.from_utc,
56 request.scope.to_utc,
57 request.scope.tiers.as_deref(),
58 Some(query_embedding),
59 request.scoring.alpha,
60 request.scoring.beta,
61 limit,
62 )
63 .await
64 } else {
65 self.context_query
66 .get_context_scoped_filtered_async(
67 session_scope,
68 current.stability,
69 current.friction,
70 current.logic,
71 current.autonomy,
72 request.scope.from_utc,
73 request.scope.to_utc,
74 request.scope.tiers.as_deref(),
75 limit,
76 )
77 .await
78 };
79
80 let mut nodes = filter_nodes(primary.nodes, request, session_filter.as_ref());
81
82 if let Some(query_text) = request.query_text.as_deref() {
83 let need_fallback = match request.scoring.fallback_policy {
84 FallbackPolicy::Never => false,
85 FallbackPolicy::OnEmpty => nodes.is_empty(),
86 FallbackPolicy::Always => true,
87 };
88
89 if need_fallback {
90 let fallback_result = self
91 .context_query
92 .get_context_scoped_filtered_async(
93 session_scope,
94 current.stability,
95 current.friction,
96 current.logic,
97 current.autonomy,
98 request.scope.from_utc,
99 request.scope.to_utc,
100 request.scope.tiers.as_deref(),
101 expanded_limit,
102 )
103 .await;
104
105 let lexical = lexical_filter(
106 filter_nodes(fallback_result.nodes, request, session_filter.as_ref()),
107 query_text,
108 );
109
110 if request.scoring.fallback_policy == FallbackPolicy::Always && !nodes.is_empty() {
111 nodes = merge_unique(nodes, lexical);
112 } else {
113 nodes = lexical;
114 }
115
116 path = RetrievalPath::LexicalFallback;
117 }
118 }
119
120 let has_more = nodes.len() > limit;
121 nodes.truncate(limit);
122
123 let next_cursor = nodes
124 .last()
125 .map(|node| format!("{}|{}", node.updated_at.to_rfc3339(), node.sync_key));
126
127 let psi_range = psi_range_from_nodes(&nodes);
128
129 Ok(MemoryRecallResult {
130 retrieved: nodes.len(),
131 nodes,
132 psi_range,
133 retrieval_path: path,
134 has_more,
135 next_cursor,
136 })
137 }
138}
139
140fn filter_nodes(
141 nodes: Vec<SttpNode>,
142 request: &MemoryRecallRequest,
143 session_filter: Option<&HashSet<String>>,
144) -> Vec<SttpNode> {
145 nodes.into_iter()
146 .filter(|node| {
147 node_matches_common_filters(node, &request.scope, &request.filter, session_filter)
148 })
149 .collect()
150}
151
152fn lexical_filter(nodes: Vec<SttpNode>, query_text: &str) -> Vec<SttpNode> {
153 let needle = query_text.trim().to_ascii_lowercase();
154 if needle.is_empty() {
155 return nodes;
156 }
157
158 let mut scored = nodes
159 .into_iter()
160 .filter_map(|node| {
161 let summary = node
162 .context_summary
163 .as_deref()
164 .unwrap_or_default()
165 .to_ascii_lowercase();
166 let session = node.session_id.to_ascii_lowercase();
167 let raw = node.raw.to_ascii_lowercase();
168
169 let mut score = 0usize;
170 if summary.contains(&needle) {
171 score += 3;
172 }
173 if session.contains(&needle) {
174 score += 2;
175 }
176 if raw.contains(&needle) {
177 score += 1;
178 }
179
180 if score > 0 {
181 Some((score, node.timestamp, node))
182 } else {
183 None
184 }
185 })
186 .collect::<Vec<_>>();
187
188 scored.sort_by(|left, right| right.0.cmp(&left.0).then_with(|| right.1.cmp(&left.1)));
189
190 scored.into_iter().map(|(_, _, node)| node).collect()
191}
192
193fn merge_unique(primary: Vec<SttpNode>, secondary: Vec<SttpNode>) -> Vec<SttpNode> {
194 let mut merged = Vec::with_capacity(primary.len() + secondary.len());
195 let mut seen = HashSet::new();
196
197 for node in primary.into_iter().chain(secondary.into_iter()) {
198 if seen.insert(node.sync_key.clone()) {
199 merged.push(node);
200 }
201 }
202
203 merged
204}
205
206fn psi_range_from_nodes(nodes: &[SttpNode]) -> PsiRange {
207 if nodes.is_empty() {
208 return PsiRange::default();
209 }
210
211 let (min, max, sum) = nodes
212 .iter()
213 .fold((f32::MAX, f32::MIN, 0.0_f32), |(min, max, sum), node| {
214 (min.min(node.psi), max.max(node.psi), sum + node.psi)
215 });
216
217 PsiRange {
218 min,
219 max,
220 average: sum / nodes.len() as f32,
221 }
222}