Skip to main content

locus_sdk/application/
memory_explain.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use anyhow::Result;
5use locus_core_rs::ContextQueryService;
6use locus_core_rs::domain::contracts::NodeStore;
7use locus_core_rs::domain::models::{AvecState, SttpNode};
8
9use crate::application::memory_filters::{build_session_filter, node_matches_common_filters};
10use crate::domain::memory::{
11    FallbackPolicy, MemoryExplainRequest, MemoryExplainResult, MemoryExplainStage, RetrievalPath,
12    clamp_limit,
13};
14
15pub struct MemoryExplainService {
16    context_query: ContextQueryService,
17}
18
19impl MemoryExplainService {
20    /// Create an explanation service for retrieval-stage introspection.
21    pub fn new(store: Arc<dyn NodeStore>) -> Self {
22        Self {
23            context_query: ContextQueryService::new(store),
24        }
25    }
26
27    /// Explain retrieval behavior for a recall request.
28    ///
29    /// Returns per-stage counts, retrieval path, and fallback diagnostics
30    /// without mutating stored nodes.
31    pub async fn execute(&self, request: &MemoryExplainRequest) -> Result<MemoryExplainResult> {
32        let recall = &request.recall;
33        let limit = clamp_limit(recall.page.limit);
34        let expanded_limit = (limit.saturating_mul(5)).clamp(1, 200);
35
36        let current = recall.current_avec.unwrap_or_else(AvecState::zero);
37        let session_scope = recall
38            .scope
39            .session_ids
40            .as_deref()
41            .filter(|sessions| sessions.len() == 1)
42            .and_then(|sessions| sessions.first().map(String::as_str));
43        let session_filter = build_session_filter(&recall.scope);
44
45        let mut stages = Vec::new();
46        let mut path = if recall.query_embedding.is_some() {
47            RetrievalPath::Hybrid
48        } else {
49            RetrievalPath::ResonanceOnly
50        };
51        let mut fallback_triggered = false;
52        let mut fallback_reason = None;
53
54        let primary = if let Some(query_embedding) = recall.query_embedding.as_deref() {
55            self.context_query
56                .get_context_hybrid_scoped_filtered_async(
57                    session_scope,
58                    current.stability,
59                    current.friction,
60                    current.logic,
61                    current.autonomy,
62                    recall.scope.from_utc,
63                    recall.scope.to_utc,
64                    recall.scope.tiers.as_deref(),
65                    Some(query_embedding),
66                    recall.scoring.alpha,
67                    recall.scoring.beta,
68                    limit,
69                )
70                .await
71        } else {
72            self.context_query
73                .get_context_scoped_filtered_async(
74                    session_scope,
75                    current.stability,
76                    current.friction,
77                    current.logic,
78                    current.autonomy,
79                    recall.scope.from_utc,
80                    recall.scope.to_utc,
81                    recall.scope.tiers.as_deref(),
82                    limit,
83                )
84                .await
85        };
86
87        stages.push(MemoryExplainStage {
88            stage: "primary_retrieval".to_string(),
89            count: primary.nodes.len(),
90        });
91
92        let filtered_primary = filter_nodes(primary.nodes, recall, session_filter.as_ref());
93        stages.push(MemoryExplainStage {
94            stage: "after_common_filter".to_string(),
95            count: filtered_primary.len(),
96        });
97
98        if let Some(query_text) = recall.query_text.as_deref() {
99            let need_fallback = match recall.scoring.fallback_policy {
100                FallbackPolicy::Never => false,
101                FallbackPolicy::OnEmpty => filtered_primary.is_empty(),
102                FallbackPolicy::Always => true,
103            };
104
105            if need_fallback {
106                fallback_triggered = true;
107                fallback_reason = Some(match recall.scoring.fallback_policy {
108                    FallbackPolicy::Never => "never".to_string(),
109                    FallbackPolicy::OnEmpty => {
110                        "fallback_policy=on_empty and primary result set is empty".to_string()
111                    }
112                    FallbackPolicy::Always => "fallback_policy=always".to_string(),
113                });
114
115                let fallback = self
116                    .context_query
117                    .get_context_scoped_filtered_async(
118                        session_scope,
119                        current.stability,
120                        current.friction,
121                        current.logic,
122                        current.autonomy,
123                        recall.scope.from_utc,
124                        recall.scope.to_utc,
125                        recall.scope.tiers.as_deref(),
126                        expanded_limit,
127                    )
128                    .await;
129
130                stages.push(MemoryExplainStage {
131                    stage: "fallback_retrieval".to_string(),
132                    count: fallback.nodes.len(),
133                });
134
135                let filtered_fallback =
136                    filter_nodes(fallback.nodes, recall, session_filter.as_ref());
137                stages.push(MemoryExplainStage {
138                    stage: "fallback_after_common_filter".to_string(),
139                    count: filtered_fallback.len(),
140                });
141
142                let lexical = lexical_filter(filtered_fallback, query_text);
143                stages.push(MemoryExplainStage {
144                    stage: "lexical_filter".to_string(),
145                    count: lexical.len(),
146                });
147
148                path = RetrievalPath::LexicalFallback;
149            }
150        }
151
152        Ok(MemoryExplainResult {
153            retrieval_path: path,
154            fallback_triggered,
155            fallback_reason,
156            stages,
157            scoring: recall.scoring.clone(),
158        })
159    }
160}
161
162fn filter_nodes(
163    nodes: Vec<SttpNode>,
164    request: &crate::domain::memory::MemoryRecallRequest,
165    session_filter: Option<&HashSet<String>>,
166) -> Vec<SttpNode> {
167    nodes.into_iter()
168        .filter(|node| {
169            node_matches_common_filters(node, &request.scope, &request.filter, session_filter)
170        })
171        .collect()
172}
173
174fn lexical_filter(nodes: Vec<SttpNode>, query_text: &str) -> Vec<SttpNode> {
175    let needle = query_text.trim().to_ascii_lowercase();
176    if needle.is_empty() {
177        return nodes;
178    }
179
180    let mut scored = nodes
181        .into_iter()
182        .filter_map(|node| {
183            let summary = node
184                .context_summary
185                .as_deref()
186                .unwrap_or_default()
187                .to_ascii_lowercase();
188            let session = node.session_id.to_ascii_lowercase();
189            let raw = node.raw.to_ascii_lowercase();
190
191            let mut score = 0usize;
192            if summary.contains(&needle) {
193                score += 3;
194            }
195            if session.contains(&needle) {
196                score += 2;
197            }
198            if raw.contains(&needle) {
199                score += 1;
200            }
201
202            if score > 0 {
203                Some((score, node.timestamp, node))
204            } else {
205                None
206            }
207        })
208        .collect::<Vec<_>>();
209
210    scored.sort_by(|left, right| right.0.cmp(&left.0).then_with(|| right.1.cmp(&left.1)));
211
212    scored.into_iter().map(|(_, _, node)| node).collect()
213}
214
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use chrono::Utc;
    use locus_core_rs::domain::models::{AvecState, SttpNode};
    use locus_core_rs::{InMemoryNodeStore, NodeStore};

    use super::MemoryExplainService;
    use crate::domain::memory::{
        FallbackPolicy, MemoryExplainRequest, MemoryExplainResult, MemoryFilter,
        MemoryRecallRequest, MemoryScoring, RetrievalPath,
    };

    const ON_EMPTY_REASON: &str = "fallback_policy=on_empty and primary result set is empty";

    #[tokio::test]
    async fn explain_marks_fallback_when_on_empty_and_no_primary_results() {
        let result =
            explain_with_empty_primary(FallbackPolicy::OnEmpty, Some("nonexistent-token")).await;

        assert!(result.fallback_triggered);
        assert_eq!(result.fallback_reason.as_deref(), Some(ON_EMPTY_REASON));
        assert_eq!(result.retrieval_path, RetrievalPath::LexicalFallback);
        assert_eq!(
            stage_names(&result),
            [
                "primary_retrieval",
                "after_common_filter",
                "fallback_retrieval",
                "fallback_after_common_filter",
                "lexical_filter",
            ]
        );
        // The query token appears nowhere in the stored node, so nothing survives
        // the lexical pass.
        let lexical = result
            .stages
            .iter()
            .find(|stage| stage.stage == "lexical_filter")
            .expect("lexical stage should be reported");
        assert_eq!(lexical.count, 0);
    }

    #[tokio::test]
    async fn explain_skips_fallback_when_policy_is_never() {
        let result =
            explain_with_empty_primary(FallbackPolicy::Never, Some("nonexistent-token")).await;

        assert!(!result.fallback_triggered);
        assert!(result.fallback_reason.is_none());
        assert_eq!(result.retrieval_path, RetrievalPath::ResonanceOnly);
        assert_eq!(
            stage_names(&result),
            ["primary_retrieval", "after_common_filter"]
        );
    }

    #[tokio::test]
    async fn explain_skips_fallback_without_query_text_even_when_always() {
        let result = explain_with_empty_primary(FallbackPolicy::Always, None).await;

        assert!(!result.fallback_triggered);
        assert!(result.fallback_reason.is_none());
        assert_eq!(result.retrieval_path, RetrievalPath::ResonanceOnly);
        assert_eq!(
            stage_names(&result),
            ["primary_retrieval", "after_common_filter"]
        );
    }

    /// Run `execute` against a store holding one embedding-less node while the
    /// recall filter requires embeddings, so the primary result set is empty.
    async fn explain_with_empty_primary(
        policy: FallbackPolicy,
        query_text: Option<&str>,
    ) -> MemoryExplainResult {
        let store: Arc<dyn NodeStore> = Arc::new(InMemoryNodeStore::new());
        store
            .upsert_node_async(test_node("s-explain", "raw", "some unrelated payload"))
            .await
            .expect("upsert should succeed");

        let service = MemoryExplainService::new(store);
        let request = MemoryExplainRequest {
            recall: MemoryRecallRequest {
                query_text: query_text.map(str::to_string),
                filter: MemoryFilter {
                    has_embedding: Some(true),
                    ..Default::default()
                },
                scoring: MemoryScoring {
                    fallback_policy: policy,
                    ..Default::default()
                },
                ..Default::default()
            },
        };

        service.execute(&request).await.expect("explain should succeed")
    }

    /// Stage labels in the order `execute` reported them.
    fn stage_names(result: &MemoryExplainResult) -> Vec<&str> {
        result.stages.iter().map(|stage| stage.stage.as_str()).collect()
    }

    fn test_node(session_id: &str, tier: &str, raw: &str) -> SttpNode {
        let now = Utc::now();
        let user = AvecState {
            stability: 0.6,
            friction: 0.4,
            logic: 0.8,
            autonomy: 0.7,
        };

        SttpNode {
            raw: raw.to_string(),
            session_id: session_id.to_string(),
            tier: tier.to_string(),
            timestamp: now,
            compression_depth: 1,
            parent_node_id: None,
            sync_key: format!(
                "{session_id}:{tier}:{}",
                now.timestamp_nanos_opt().unwrap_or_default()
            ),
            updated_at: now,
            source_metadata: None,
            context_summary: Some("summary".to_string()),
            embedding_dimensions: None,
            embedding_model: None,
            embedding: None,
            embedded_at: None,
            user_avec: user,
            model_avec: user,
            compression_avec: Some(user),
            rho: 0.9,
            kappa: 0.8,
            psi: 2.5,
        }
    }
}