1use crate::types::{TemporalPattern, PatternId, Query, SearchResult, SubstrateTime, TimeRange};
9use dashmap::DashMap;
10use parking_lot::RwLock;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13
/// Tuning parameters for `LongTermStore`.
#[derive(Clone, Debug)]
pub struct LongTermConfig {
    /// Configured decay rate (default `0.01`).
    ///
    /// NOTE(review): `LongTermStore::decay_low_salience` takes its rate as an argument
    /// and does not read this field; presumably callers pass it through — confirm.
    pub decay_rate: f32,
    /// Eviction threshold: `LongTermStore::decay_low_salience` removes patterns whose
    /// decayed salience is below this value (default `0.1`).
    pub min_salience: f32,
}

impl Default for LongTermConfig {
    /// Returns `decay_rate = 0.01` and `min_salience = 0.1`.
    fn default() -> Self {
        const DECAY_RATE: f32 = 0.01;
        const MIN_SALIENCE: f32 = 0.1;

        Self {
            decay_rate: DECAY_RATE,
            min_salience: MIN_SALIENCE,
        }
    }
}
31
32pub struct LongTermStore {
34 patterns: DashMap<PatternId, TemporalPattern>,
36 temporal_index: Arc<RwLock<Vec<(SubstrateTime, PatternId)>>>,
38 index_dirty: AtomicBool,
40 config: LongTermConfig,
42}
43
44impl LongTermStore {
45 pub fn new(config: LongTermConfig) -> Self {
47 Self {
48 patterns: DashMap::new(),
49 temporal_index: Arc::new(RwLock::new(Vec::new())),
50 index_dirty: AtomicBool::new(false),
51 config,
52 }
53 }
54
55 pub fn integrate(&self, temporal_pattern: TemporalPattern) {
57 let id = temporal_pattern.pattern.id;
58 let timestamp = temporal_pattern.pattern.timestamp;
59
60 self.patterns.insert(id, temporal_pattern);
62
63 let mut index = self.temporal_index.write();
65 index.push((timestamp, id));
66 self.index_dirty.store(true, Ordering::Relaxed);
67 }
68
69 pub fn integrate_batch(&self, patterns: Vec<TemporalPattern>) {
71 let mut index = self.temporal_index.write();
72
73 for temporal_pattern in patterns {
74 let id = temporal_pattern.pattern.id;
75 let timestamp = temporal_pattern.pattern.timestamp;
76 self.patterns.insert(id, temporal_pattern);
77 index.push((timestamp, id));
78 }
79
80 index.sort_by_key(|(t, _)| *t);
82 self.index_dirty.store(false, Ordering::Relaxed);
83 }
84
85 fn ensure_sorted(&self) {
87 if self.index_dirty.load(Ordering::Relaxed) {
88 let mut index = self.temporal_index.write();
89 index.sort_by_key(|(t, _)| *t);
90 self.index_dirty.store(false, Ordering::Relaxed);
91 }
92 }
93
94 pub fn get(&self, id: &PatternId) -> Option<TemporalPattern> {
96 self.patterns.get(id).map(|p| p.clone())
97 }
98
99 pub fn update(&self, temporal_pattern: TemporalPattern) -> bool {
101 let id = temporal_pattern.pattern.id;
102 self.patterns.insert(id, temporal_pattern).is_some()
103 }
104
105 pub fn search(&self, query: &Query) -> Vec<SearchResult> {
107 let k = query.k;
108 let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
109
110 for entry in self.patterns.iter() {
111 let temporal_pattern = entry.value();
112 let score = cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
113
114 if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
116 continue;
117 }
118
119 results.push(SearchResult {
120 id: temporal_pattern.pattern.id,
121 pattern: temporal_pattern.clone(),
122 score,
123 });
124
125 if results.len() > k {
127 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
128 results.truncate(k);
129 }
130 }
131
132 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
134 results
135 }
136
137 pub fn search_with_time_range(&self, query: &Query, time_range: TimeRange) -> Vec<SearchResult> {
139 let k = query.k;
140 let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
141
142 for entry in self.patterns.iter() {
143 let temporal_pattern = entry.value();
144
145 if !time_range.contains(&temporal_pattern.pattern.timestamp) {
147 continue;
148 }
149
150 let score = cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
151
152 if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
154 continue;
155 }
156
157 results.push(SearchResult {
158 id: temporal_pattern.pattern.id,
159 pattern: temporal_pattern.clone(),
160 score,
161 });
162
163 if results.len() > k {
164 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
165 results.truncate(k);
166 }
167 }
168
169 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
170 results
171 }
172
173 pub fn filter_by_time(&self, time_range: TimeRange) -> Vec<TemporalPattern> {
175 self.ensure_sorted();
176 let index = self.temporal_index.read();
177
178 let start_idx = index
180 .binary_search_by_key(&time_range.start, |(t, _)| *t)
181 .unwrap_or_else(|i| i);
182
183 let end_idx = index
185 .binary_search_by_key(&time_range.end, |(t, _)| *t)
186 .unwrap_or_else(|i| i);
187
188 index[start_idx..=end_idx.min(index.len().saturating_sub(1))]
190 .iter()
191 .filter_map(|(_, id)| self.patterns.get(id).map(|p| p.clone()))
192 .collect()
193 }
194
195 pub fn decay_low_salience(&self, decay_rate: f32) {
197 let mut to_remove = Vec::new();
198
199 for mut entry in self.patterns.iter_mut() {
200 let temporal_pattern = entry.value_mut();
201
202 temporal_pattern.pattern.salience *= 1.0 - decay_rate;
204
205 if temporal_pattern.pattern.salience < self.config.min_salience {
207 to_remove.push(temporal_pattern.pattern.id);
208 }
209 }
210
211 for id in to_remove {
213 self.remove(&id);
214 }
215 }
216
217 pub fn remove(&self, id: &PatternId) -> Option<TemporalPattern> {
219 let temporal_pattern = self.patterns.remove(id).map(|(_, p)| p)?;
221
222 let mut index = self.temporal_index.write();
224 index.retain(|(_, pid)| pid != id);
225
226 Some(temporal_pattern)
227 }
228
229 pub fn len(&self) -> usize {
231 self.patterns.len()
232 }
233
234 pub fn is_empty(&self) -> bool {
236 self.patterns.is_empty()
237 }
238
239 pub fn clear(&self) {
241 self.patterns.clear();
242 self.temporal_index.write().clear();
243 }
244
245 pub fn all(&self) -> Vec<TemporalPattern> {
247 self.patterns.iter().map(|e| e.value().clone()).collect()
248 }
249
250 pub fn stats(&self) -> LongTermStats {
252 let size = self.patterns.len();
253
254 let total_salience: f32 = self.patterns.iter().map(|e| e.value().pattern.salience).sum();
256 let avg_salience = if size > 0 {
257 total_salience / size as f32
258 } else {
259 0.0
260 };
261
262 let mut min_salience = f32::MAX;
264 let mut max_salience = f32::MIN;
265
266 for entry in self.patterns.iter() {
267 let salience = entry.value().pattern.salience;
268 min_salience = min_salience.min(salience);
269 max_salience = max_salience.max(salience);
270 }
271
272 if size == 0 {
273 min_salience = 0.0;
274 max_salience = 0.0;
275 }
276
277 LongTermStats {
278 size,
279 avg_salience,
280 min_salience,
281 max_salience,
282 }
283 }
284}
285
286impl Default for LongTermStore {
287 fn default() -> Self {
288 Self::new(LongTermConfig::default())
289 }
290}
291
/// Salience summary returned by `LongTermStore::stats`.
#[derive(Clone, Debug)]
pub struct LongTermStats {
    /// Number of stored patterns.
    pub size: usize,
    /// Mean salience across stored patterns; `0.0` for an empty store.
    pub avg_salience: f32,
    /// Lowest salience; `0.0` for an empty store.
    pub min_salience: f32,
    /// Highest salience; `0.0` for an empty store.
    pub max_salience: f32,
}
304
/// Cosine similarity of `a` and `b`: `dot(a, b) / (|a| * |b|)`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero
/// magnitude. The accumulation is unrolled four lanes at a time over `chunks_exact`
/// (which lets the compiler elide per-element bounds checks and vectorize); no explicit
/// SIMD intrinsics or `unsafe` are used.
#[inline]
fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0f32;
    let mut mag_a = 0.0f32;
    let mut mag_b = 0.0f32;

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    // `remainder` borrows the underlying slices, so it can be taken before `zip`
    // consumes the chunk iterators.
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    for (ca, cb) in a_chunks.zip(b_chunks) {
        dot += ca[0] * cb[0] + ca[1] * cb[1] + ca[2] * cb[2] + ca[3] * cb[3];
        mag_a += ca[0] * ca[0] + ca[1] * ca[1] + ca[2] * ca[2] + ca[3] * ca[3];
        mag_b += cb[0] * cb[0] + cb[1] * cb[1] + cb[2] * cb[2] + cb[3] * cb[3];
    }

    for (&x, &y) in a_tail.iter().zip(b_tail) {
        dot += x * y;
        mag_a += x * x;
        mag_b += y * y;
    }

    // Take each root separately: the product `mag_a * mag_b` overflows to infinity
    // (or underflows to zero) for norms well inside the `f32` range, e.g. ~1e15.
    let mag = mag_a.sqrt() * mag_b.sqrt();
    if mag == 0.0 {
        return 0.0;
    }

    dot / mag
}
355
356#[inline]
358fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
359 cosine_similarity_simd(a, b)
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Metadata;

    #[test]
    fn test_long_term_store() {
        let store = LongTermStore::default();

        let temporal_pattern =
            TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        let id = temporal_pattern.pattern.id;

        store.integrate(temporal_pattern);

        assert_eq!(store.len(), 1);
        assert!(store.get(&id).is_some());
    }

    #[test]
    fn test_search() {
        let store = LongTermStore::default();

        let p1 = TemporalPattern::from_embedding(vec![1.0, 0.0, 0.0], Metadata::new());
        let p2 = TemporalPattern::from_embedding(vec![0.0, 1.0, 0.0], Metadata::new());

        store.integrate(p1);
        store.integrate(p2);

        let query = Query::from_embedding(vec![0.9, 0.1, 0.0]).with_k(1);
        let results = store.search(&query);

        assert_eq!(results.len(), 1);
        assert!(results[0].score > 0.5);
    }

    #[test]
    fn test_decay() {
        let store = LongTermStore::default();

        let mut temporal_pattern =
            TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        // 0.15 * (1.0 - 0.5) = 0.075, below the default `min_salience` of 0.1.
        temporal_pattern.pattern.salience = 0.15;
        let id = temporal_pattern.pattern.id;

        store.integrate(temporal_pattern);
        assert_eq!(store.len(), 1);

        store.decay_low_salience(0.5);
        assert_eq!(store.len(), 0);
        assert!(store.get(&id).is_none());
    }

    #[test]
    fn test_remove() {
        let store = LongTermStore::default();

        let temporal_pattern = TemporalPattern::from_embedding(vec![1.0, 0.0], Metadata::new());
        let id = temporal_pattern.pattern.id;
        store.integrate(temporal_pattern);

        assert!(store.remove(&id).is_some());
        assert!(store.remove(&id).is_none());
        assert!(store.is_empty());
    }

    #[test]
    fn test_clear() {
        let store = LongTermStore::default();
        store.integrate(TemporalPattern::from_embedding(vec![1.0, 0.0], Metadata::new()));

        store.clear();

        assert!(store.is_empty());
        assert!(store.all().is_empty());
    }

    #[test]
    fn test_stats_empty_store() {
        let stats = LongTermStore::default().stats();

        assert_eq!(stats.size, 0);
        assert_eq!(stats.avg_salience, 0.0);
        assert_eq!(stats.min_salience, 0.0);
        assert_eq!(stats.max_salience, 0.0);
    }
}