1use std::collections::HashMap;
2use std::sync::Arc;
3
4use anyhow::Result;
5use locus_core_rs::domain::contracts::NodeStore;
6use locus_core_rs::domain::models::{AvecState, NodeQuery, SttpNode};
7
8use crate::application::memory_filters::{build_session_filter, node_matches_common_filters};
9use crate::domain::memory::{
10 MemoryAggregateGroup, MemoryAggregateRequest, MemoryAggregateResult, MemoryGroupBy, NumericStats,
11 clamp_groups, clamp_nodes,
12};
13
14pub struct MemoryAggregateService {
15 store: Arc<dyn NodeStore>,
16}
17
18impl MemoryAggregateService {
19 pub fn new(store: Arc<dyn NodeStore>) -> Self {
21 Self { store }
22 }
23
24 pub async fn execute(&self, request: &MemoryAggregateRequest) -> Result<MemoryAggregateResult> {
29 let max_nodes = clamp_nodes(if request.max_nodes == 0 {
30 5000
31 } else {
32 request.max_nodes
33 });
34 let max_groups = clamp_groups(if request.max_groups == 0 {
35 500
36 } else {
37 request.max_groups
38 });
39
40 let single_session = request
41 .scope
42 .session_ids
43 .as_deref()
44 .filter(|sessions| sessions.len() == 1)
45 .and_then(|sessions| sessions.first().cloned());
46
47 let nodes = self
48 .store
49 .query_nodes_async(NodeQuery {
50 limit: max_nodes,
51 session_id: single_session,
52 from_utc: request.scope.from_utc,
53 to_utc: request.scope.to_utc,
54 tiers: request.scope.tiers.clone(),
55 })
56 .await?;
57
58 let session_filter = build_session_filter(&request.scope);
59
60 let filtered = nodes
61 .into_iter()
62 .filter(|node| {
63 node_matches_common_filters(node, &request.scope, &request.filter, session_filter.as_ref())
64 })
65 .collect::<Vec<_>>();
66
67 let scanned_nodes = filtered.len();
68
69 let mut grouped: HashMap<String, Vec<SttpNode>> = HashMap::new();
70 for node in filtered {
71 let key = group_key(&node, request.group_by);
72 grouped.entry(key).or_default().push(node);
73 }
74
75 let mut groups = grouped
76 .into_iter()
77 .map(|(key, nodes)| to_group(key, &nodes))
78 .collect::<Vec<_>>();
79
80 groups.sort_by(|left, right| {
81 right
82 .node_count
83 .cmp(&left.node_count)
84 .then_with(|| left.key.cmp(&right.key))
85 });
86
87 let total_groups = groups.len();
88 groups.truncate(max_groups);
89
90 Ok(MemoryAggregateResult {
91 groups,
92 total_groups,
93 scanned_nodes,
94 })
95 }
96}
97
98fn group_key(node: &SttpNode, group_by: MemoryGroupBy) -> String {
99 match group_by {
100 MemoryGroupBy::SessionId => node.session_id.clone(),
101 MemoryGroupBy::Tier => node.tier.clone(),
102 MemoryGroupBy::EmbeddingModel => node
103 .embedding_model
104 .clone()
105 .unwrap_or_else(|| "none".to_string()),
106 MemoryGroupBy::DateDay => node.timestamp.date_naive().to_string(),
107 }
108}
109
110fn to_group(key: String, nodes: &[SttpNode]) -> MemoryAggregateGroup {
111 let node_count = nodes.len();
112
113 let embedding_count = nodes
114 .iter()
115 .filter(|node| node.embedding.as_ref().is_some_and(|values| !values.is_empty()))
116 .count();
117
118 let embedding_coverage = if node_count == 0 {
119 0.0
120 } else {
121 embedding_count as f32 / node_count as f32
122 };
123
124 let avg_user_avec = average_avec(nodes.iter().map(|node| node.user_avec).collect::<Vec<_>>().as_slice());
125 let avg_model_avec =
126 average_avec(nodes.iter().map(|node| node.model_avec).collect::<Vec<_>>().as_slice());
127
128 let compression_states = nodes
129 .iter()
130 .filter_map(|node| node.compression_avec)
131 .collect::<Vec<_>>();
132
133 let avg_compression_avec = if compression_states.is_empty() {
134 None
135 } else {
136 Some(average_avec(compression_states.as_slice()))
137 };
138
139 let psi_stats = average_metric(nodes.iter().map(|node| node.psi).collect::<Vec<_>>().as_slice());
140 let rho_stats = average_metric(nodes.iter().map(|node| node.rho).collect::<Vec<_>>().as_slice());
141 let kappa_stats =
142 average_metric(nodes.iter().map(|node| node.kappa).collect::<Vec<_>>().as_slice());
143
144 MemoryAggregateGroup {
145 key,
146 node_count,
147 embedding_coverage,
148 avg_user_avec,
149 avg_model_avec,
150 avg_compression_avec,
151 psi_stats,
152 rho_stats,
153 kappa_stats,
154 }
155}
156
157fn average_avec(values: &[AvecState]) -> AvecState {
158 if values.is_empty() {
159 return AvecState::zero();
160 }
161
162 let mut stability = 0.0_f32;
163 let mut friction = 0.0_f32;
164 let mut logic = 0.0_f32;
165 let mut autonomy = 0.0_f32;
166
167 for value in values {
168 stability += value.stability;
169 friction += value.friction;
170 logic += value.logic;
171 autonomy += value.autonomy;
172 }
173
174 let count = values.len() as f32;
175
176 AvecState {
177 stability: stability / count,
178 friction: friction / count,
179 logic: logic / count,
180 autonomy: autonomy / count,
181 }
182}
183
184fn average_metric(values: &[f32]) -> NumericStats {
185 if values.is_empty() {
186 return NumericStats::default();
187 }
188
189 let (min, max, sum) = values.iter().fold(
190 (f32::MAX, f32::MIN, 0.0_f32),
191 |(min, max, sum), value| (min.min(*value), max.max(*value), sum + *value),
192 );
193
194 NumericStats {
195 min,
196 max,
197 average: sum / values.len() as f32,
198 }
199}
200
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use chrono::{Duration, Utc};
    use locus_core_rs::{InMemoryNodeStore, NodeStore};
    use locus_core_rs::domain::models::{AvecState, SttpNode};

    use super::MemoryAggregateService;
    use crate::domain::memory::{MemoryAggregateRequest, MemoryAggregateResult, MemoryGroupBy};

    /// Builds a store holding two `s-1` nodes (one with an embedding) and one embedded
    /// `s-2` node, all in the `raw` tier.
    async fn seeded_store() -> Arc<dyn NodeStore> {
        let store: Arc<dyn NodeStore> = Arc::new(InMemoryNodeStore::new());
        let now = Utc::now();

        let nodes = vec![
            test_node("s-1", "raw", now - Duration::minutes(2), Some(vec![0.1, 0.2])),
            test_node("s-1", "raw", now - Duration::minutes(1), None),
            test_node("s-2", "raw", now, Some(vec![0.3, 0.4])),
        ];
        for node in nodes {
            store
                .upsert_node_async(node)
                .await
                .expect("upsert should succeed");
        }

        store
    }

    /// Runs an unfiltered aggregation over the seeded store with the given grouping.
    async fn aggregate(group_by: MemoryGroupBy) -> MemoryAggregateResult {
        let service = MemoryAggregateService::new(seeded_store().await);
        let request = MemoryAggregateRequest {
            group_by,
            max_groups: 10,
            max_nodes: 100,
            ..Default::default()
        };

        service
            .execute(&request)
            .await
            .expect("aggregate should succeed")
    }

    #[tokio::test]
    async fn aggregates_nodes_by_session_with_coverage() {
        let result = aggregate(MemoryGroupBy::SessionId).await;

        assert_eq!(result.total_groups, 2);
        assert_eq!(result.scanned_nodes, 3);

        // Groups are ordered by descending size, so the two-node session comes first.
        let keys: Vec<&str> = result.groups.iter().map(|group| group.key.as_str()).collect();
        assert_eq!(keys, ["s-1", "s-2"]);

        let s1 = &result.groups[0];
        assert_eq!(s1.node_count, 2);
        assert!((s1.embedding_coverage - 0.5).abs() < f32::EPSILON);
        assert!((s1.psi_stats.average - 2.5).abs() < f32::EPSILON);
        assert!(s1.avg_compression_avec.is_some());

        let s2 = &result.groups[1];
        assert_eq!(s2.node_count, 1);
        assert!((s2.embedding_coverage - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn groups_missing_embedding_model_under_none() {
        let result = aggregate(MemoryGroupBy::EmbeddingModel).await;

        assert_eq!(result.total_groups, 2);
        let keys: Vec<&str> = result.groups.iter().map(|group| group.key.as_str()).collect();
        assert_eq!(keys, ["test-model", "none"]);

        assert_eq!(result.groups[0].node_count, 2);
        assert!((result.groups[0].embedding_coverage - 1.0).abs() < f32::EPSILON);
        assert_eq!(result.groups[1].node_count, 1);
        assert!(result.groups[1].embedding_coverage.abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn groups_all_nodes_by_shared_tier() {
        let result = aggregate(MemoryGroupBy::Tier).await;

        assert_eq!(result.total_groups, 1);
        assert_eq!(result.groups[0].key, "raw");
        assert_eq!(result.groups[0].node_count, 3);
    }

    fn test_node(session_id: &str, tier: &str, timestamp: chrono::DateTime<Utc>, embedding: Option<Vec<f32>>) -> SttpNode {
        let user = AvecState {
            stability: 0.6,
            friction: 0.4,
            logic: 0.8,
            autonomy: 0.7,
        };
        let model = AvecState {
            stability: 0.5,
            friction: 0.3,
            logic: 0.9,
            autonomy: 0.6,
        };

        SttpNode {
            raw: format!("raw:{session_id}:{tier}:{timestamp}"),
            session_id: session_id.to_string(),
            tier: tier.to_string(),
            timestamp,
            compression_depth: 1,
            parent_node_id: None,
            sync_key: format!("{}:{}:{}", session_id, tier, timestamp.timestamp_nanos_opt().unwrap_or_default()),
            updated_at: timestamp,
            source_metadata: None,
            context_summary: Some("summary".to_string()),
            embedding_dimensions: embedding.as_ref().map(|v| v.len()),
            embedding_model: embedding.as_ref().map(|_| "test-model".to_string()),
            embedding,
            embedded_at: None,
            user_avec: user,
            model_avec: model,
            compression_avec: Some(model),
            rho: 0.9,
            kappa: 0.8,
            psi: 2.5,
        }
    }
}