1use std::collections::HashSet;
2use std::sync::Arc;
3
4use anyhow::Result;
5use locus_core_rs::ContextQueryService;
6use locus_core_rs::domain::contracts::NodeStore;
7use locus_core_rs::domain::models::{AvecState, SttpNode};
8
9use crate::application::memory_filters::{build_session_filter, node_matches_common_filters};
10use crate::domain::memory::{
11 FallbackPolicy, MemoryExplainRequest, MemoryExplainResult, MemoryExplainStage, RetrievalPath,
12 clamp_limit,
13};
14
15pub struct MemoryExplainService {
16 context_query: ContextQueryService,
17}
18
19impl MemoryExplainService {
20 pub fn new(store: Arc<dyn NodeStore>) -> Self {
22 Self {
23 context_query: ContextQueryService::new(store),
24 }
25 }
26
27 pub async fn execute(&self, request: &MemoryExplainRequest) -> Result<MemoryExplainResult> {
32 let recall = &request.recall;
33 let limit = clamp_limit(recall.page.limit);
34 let expanded_limit = (limit.saturating_mul(5)).clamp(1, 200);
35
36 let current = recall.current_avec.unwrap_or_else(AvecState::zero);
37 let session_scope = recall
38 .scope
39 .session_ids
40 .as_deref()
41 .filter(|sessions| sessions.len() == 1)
42 .and_then(|sessions| sessions.first().map(String::as_str));
43 let session_filter = build_session_filter(&recall.scope);
44
45 let mut stages = Vec::new();
46 let mut path = if recall.query_embedding.is_some() {
47 RetrievalPath::Hybrid
48 } else {
49 RetrievalPath::ResonanceOnly
50 };
51 let mut fallback_triggered = false;
52 let mut fallback_reason = None;
53
54 let primary = if let Some(query_embedding) = recall.query_embedding.as_deref() {
55 self.context_query
56 .get_context_hybrid_scoped_filtered_async(
57 session_scope,
58 current.stability,
59 current.friction,
60 current.logic,
61 current.autonomy,
62 recall.scope.from_utc,
63 recall.scope.to_utc,
64 recall.scope.tiers.as_deref(),
65 Some(query_embedding),
66 recall.scoring.alpha,
67 recall.scoring.beta,
68 limit,
69 )
70 .await
71 } else {
72 self.context_query
73 .get_context_scoped_filtered_async(
74 session_scope,
75 current.stability,
76 current.friction,
77 current.logic,
78 current.autonomy,
79 recall.scope.from_utc,
80 recall.scope.to_utc,
81 recall.scope.tiers.as_deref(),
82 limit,
83 )
84 .await
85 };
86
87 stages.push(MemoryExplainStage {
88 stage: "primary_retrieval".to_string(),
89 count: primary.nodes.len(),
90 });
91
92 let filtered_primary = filter_nodes(primary.nodes, recall, session_filter.as_ref());
93 stages.push(MemoryExplainStage {
94 stage: "after_common_filter".to_string(),
95 count: filtered_primary.len(),
96 });
97
98 if let Some(query_text) = recall.query_text.as_deref() {
99 let need_fallback = match recall.scoring.fallback_policy {
100 FallbackPolicy::Never => false,
101 FallbackPolicy::OnEmpty => filtered_primary.is_empty(),
102 FallbackPolicy::Always => true,
103 };
104
105 if need_fallback {
106 fallback_triggered = true;
107 fallback_reason = Some(match recall.scoring.fallback_policy {
108 FallbackPolicy::Never => "never".to_string(),
109 FallbackPolicy::OnEmpty => {
110 "fallback_policy=on_empty and primary result set is empty".to_string()
111 }
112 FallbackPolicy::Always => "fallback_policy=always".to_string(),
113 });
114
115 let fallback = self
116 .context_query
117 .get_context_scoped_filtered_async(
118 session_scope,
119 current.stability,
120 current.friction,
121 current.logic,
122 current.autonomy,
123 recall.scope.from_utc,
124 recall.scope.to_utc,
125 recall.scope.tiers.as_deref(),
126 expanded_limit,
127 )
128 .await;
129
130 stages.push(MemoryExplainStage {
131 stage: "fallback_retrieval".to_string(),
132 count: fallback.nodes.len(),
133 });
134
135 let filtered_fallback =
136 filter_nodes(fallback.nodes, recall, session_filter.as_ref());
137 stages.push(MemoryExplainStage {
138 stage: "fallback_after_common_filter".to_string(),
139 count: filtered_fallback.len(),
140 });
141
142 let lexical = lexical_filter(filtered_fallback, query_text);
143 stages.push(MemoryExplainStage {
144 stage: "lexical_filter".to_string(),
145 count: lexical.len(),
146 });
147
148 path = RetrievalPath::LexicalFallback;
149 }
150 }
151
152 Ok(MemoryExplainResult {
153 retrieval_path: path,
154 fallback_triggered,
155 fallback_reason,
156 stages,
157 scoring: recall.scoring.clone(),
158 })
159 }
160}
161
162fn filter_nodes(
163 nodes: Vec<SttpNode>,
164 request: &crate::domain::memory::MemoryRecallRequest,
165 session_filter: Option<&HashSet<String>>,
166) -> Vec<SttpNode> {
167 nodes.into_iter()
168 .filter(|node| {
169 node_matches_common_filters(node, &request.scope, &request.filter, session_filter)
170 })
171 .collect()
172}
173
174fn lexical_filter(nodes: Vec<SttpNode>, query_text: &str) -> Vec<SttpNode> {
175 let needle = query_text.trim().to_ascii_lowercase();
176 if needle.is_empty() {
177 return nodes;
178 }
179
180 let mut scored = nodes
181 .into_iter()
182 .filter_map(|node| {
183 let summary = node
184 .context_summary
185 .as_deref()
186 .unwrap_or_default()
187 .to_ascii_lowercase();
188 let session = node.session_id.to_ascii_lowercase();
189 let raw = node.raw.to_ascii_lowercase();
190
191 let mut score = 0usize;
192 if summary.contains(&needle) {
193 score += 3;
194 }
195 if session.contains(&needle) {
196 score += 2;
197 }
198 if raw.contains(&needle) {
199 score += 1;
200 }
201
202 if score > 0 {
203 Some((score, node.timestamp, node))
204 } else {
205 None
206 }
207 })
208 .collect::<Vec<_>>();
209
210 scored.sort_by(|left, right| right.0.cmp(&left.0).then_with(|| right.1.cmp(&left.1)));
211
212 scored.into_iter().map(|(_, _, node)| node).collect()
213}
214
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use chrono::Utc;
    use locus_core_rs::domain::models::{AvecState, SttpNode};
    use locus_core_rs::{InMemoryNodeStore, NodeStore};

    use super::MemoryExplainService;
    use crate::domain::memory::{
        FallbackPolicy, MemoryExplainRequest, MemoryExplainResult, MemoryFilter,
        MemoryRecallRequest, MemoryScoring, RetrievalPath,
    };

    #[tokio::test]
    async fn explain_marks_fallback_when_on_empty_and_no_primary_results() {
        let request = explain_request(Some("nonexistent-token"), FallbackPolicy::OnEmpty);
        let result = explain_with_single_node(&request).await;

        assert!(result.fallback_triggered);
        assert_eq!(result.retrieval_path, RetrievalPath::LexicalFallback);
        assert_eq!(
            result.fallback_reason.as_deref(),
            Some("fallback_policy=on_empty and primary result set is empty")
        );
        assert_eq!(
            stage_names(&result),
            [
                "primary_retrieval",
                "after_common_filter",
                "fallback_retrieval",
                "fallback_after_common_filter",
                "lexical_filter",
            ]
        );
        // OnEmpty only fires when nothing survived the primary common filter.
        assert_eq!(stage_count(&result, "after_common_filter"), Some(0));
        // The query token appears in none of the node's fields, so the lexical stage is empty.
        assert_eq!(stage_count(&result, "lexical_filter"), Some(0));
    }

    #[tokio::test]
    async fn explain_skips_fallback_when_policy_is_never() {
        let request = explain_request(Some("nonexistent-token"), FallbackPolicy::Never);
        let result = explain_with_single_node(&request).await;

        assert!(!result.fallback_triggered);
        assert!(result.fallback_reason.is_none());
        assert_ne!(result.retrieval_path, RetrievalPath::LexicalFallback);
        assert_eq!(stage_names(&result), ["primary_retrieval", "after_common_filter"]);
    }

    #[tokio::test]
    async fn explain_skips_fallback_without_query_text() {
        // Even an `Always` policy cannot run the lexical fallback without text to match.
        let request = explain_request(None, FallbackPolicy::Always);
        let result = explain_with_single_node(&request).await;

        assert!(!result.fallback_triggered);
        assert!(result.fallback_reason.is_none());
        assert_ne!(result.retrieval_path, RetrievalPath::LexicalFallback);
        assert_eq!(stage_names(&result), ["primary_retrieval", "after_common_filter"]);
    }

    /// Recall request that only accepts embedded nodes, so the fixture node (which has no
    /// embedding) never survives the common filter.
    fn explain_request(
        query_text: Option<&str>,
        fallback_policy: FallbackPolicy,
    ) -> MemoryExplainRequest {
        MemoryExplainRequest {
            recall: MemoryRecallRequest {
                query_text: query_text.map(str::to_string),
                filter: MemoryFilter {
                    has_embedding: Some(true),
                    ..Default::default()
                },
                scoring: MemoryScoring {
                    fallback_policy,
                    ..Default::default()
                },
                ..Default::default()
            },
        }
    }

    /// Runs `request` against an in-memory store seeded with one unrelated node.
    async fn explain_with_single_node(request: &MemoryExplainRequest) -> MemoryExplainResult {
        let store: Arc<dyn NodeStore> = Arc::new(InMemoryNodeStore::new());
        store
            .upsert_node_async(test_node("s-explain", "raw", "some unrelated payload"))
            .await
            .expect("upsert should succeed");

        MemoryExplainService::new(store)
            .execute(request)
            .await
            .expect("explain should succeed")
    }

    fn stage_names(result: &MemoryExplainResult) -> Vec<&str> {
        result
            .stages
            .iter()
            .map(|stage| stage.stage.as_str())
            .collect()
    }

    fn stage_count(result: &MemoryExplainResult, name: &str) -> Option<usize> {
        result
            .stages
            .iter()
            .find(|stage| stage.stage == name)
            .map(|stage| stage.count)
    }

    fn test_node(session_id: &str, tier: &str, raw: &str) -> SttpNode {
        let now = Utc::now();
        let user = AvecState {
            stability: 0.6,
            friction: 0.4,
            logic: 0.8,
            autonomy: 0.7,
        };

        SttpNode {
            raw: raw.to_string(),
            session_id: session_id.to_string(),
            tier: tier.to_string(),
            timestamp: now,
            compression_depth: 1,
            parent_node_id: None,
            sync_key: format!(
                "{session_id}:{tier}:{}",
                now.timestamp_nanos_opt().unwrap_or_default()
            ),
            updated_at: now,
            source_metadata: None,
            context_summary: Some("summary".to_string()),
            embedding_dimensions: None,
            embedding_model: None,
            embedding: None,
            embedded_at: None,
            user_avec: user,
            model_avec: user,
            compression_avec: Some(user),
            rho: 0.9,
            kappa: 0.8,
            psi: 2.5,
        }
    }
}