use crate::types::{PatternId, Query, SearchResult, SubstrateTime, TemporalPattern, TimeRange};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
13
/// Tuning parameters for [`LongTermStore`].
#[derive(Debug, Clone)]
pub struct LongTermConfig {
    /// Decay rate carried with the configuration. `LongTermStore::decay_low_salience`
    /// takes its rate as an explicit argument and does not read this field.
    pub decay_rate: f32,
    /// Eviction threshold: after a decay step, patterns whose salience is strictly below
    /// this value are removed from the store.
    pub min_salience: f32,
}
22
23impl Default for LongTermConfig {
24 fn default() -> Self {
25 Self {
26 decay_rate: 0.01,
27 min_salience: 0.1,
28 }
29 }
30}
31
32pub struct LongTermStore {
34 patterns: DashMap<PatternId, TemporalPattern>,
36 temporal_index: Arc<RwLock<Vec<(SubstrateTime, PatternId)>>>,
38 index_dirty: AtomicBool,
40 config: LongTermConfig,
42}
43
44impl LongTermStore {
45 pub fn new(config: LongTermConfig) -> Self {
47 Self {
48 patterns: DashMap::new(),
49 temporal_index: Arc::new(RwLock::new(Vec::new())),
50 index_dirty: AtomicBool::new(false),
51 config,
52 }
53 }
54
55 pub fn integrate(&self, temporal_pattern: TemporalPattern) {
57 let id = temporal_pattern.pattern.id;
58 let timestamp = temporal_pattern.pattern.timestamp;
59
60 self.patterns.insert(id, temporal_pattern);
62
63 let mut index = self.temporal_index.write();
65 index.push((timestamp, id));
66 self.index_dirty.store(true, Ordering::Relaxed);
67 }
68
69 pub fn integrate_batch(&self, patterns: Vec<TemporalPattern>) {
71 let mut index = self.temporal_index.write();
72
73 for temporal_pattern in patterns {
74 let id = temporal_pattern.pattern.id;
75 let timestamp = temporal_pattern.pattern.timestamp;
76 self.patterns.insert(id, temporal_pattern);
77 index.push((timestamp, id));
78 }
79
80 index.sort_by_key(|(t, _)| *t);
82 self.index_dirty.store(false, Ordering::Relaxed);
83 }
84
85 fn ensure_sorted(&self) {
87 if self.index_dirty.load(Ordering::Relaxed) {
88 let mut index = self.temporal_index.write();
89 index.sort_by_key(|(t, _)| *t);
90 self.index_dirty.store(false, Ordering::Relaxed);
91 }
92 }
93
94 pub fn get(&self, id: &PatternId) -> Option<TemporalPattern> {
96 self.patterns.get(id).map(|p| p.clone())
97 }
98
99 pub fn update(&self, temporal_pattern: TemporalPattern) -> bool {
101 let id = temporal_pattern.pattern.id;
102 self.patterns.insert(id, temporal_pattern).is_some()
103 }
104
105 pub fn search(&self, query: &Query) -> Vec<SearchResult> {
107 let k = query.k;
108 let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
109
110 for entry in self.patterns.iter() {
111 let temporal_pattern = entry.value();
112 let score =
113 cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
114
115 if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
117 continue;
118 }
119
120 results.push(SearchResult {
121 id: temporal_pattern.pattern.id,
122 pattern: temporal_pattern.clone(),
123 score,
124 });
125
126 if results.len() > k {
128 results.sort_by(|a, b| {
129 b.score
130 .partial_cmp(&a.score)
131 .unwrap_or(std::cmp::Ordering::Equal)
132 });
133 results.truncate(k);
134 }
135 }
136
137 results.sort_by(|a, b| {
139 b.score
140 .partial_cmp(&a.score)
141 .unwrap_or(std::cmp::Ordering::Equal)
142 });
143 results
144 }
145
146 pub fn search_with_time_range(
148 &self,
149 query: &Query,
150 time_range: TimeRange,
151 ) -> Vec<SearchResult> {
152 let k = query.k;
153 let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
154
155 for entry in self.patterns.iter() {
156 let temporal_pattern = entry.value();
157
158 if !time_range.contains(&temporal_pattern.pattern.timestamp) {
160 continue;
161 }
162
163 let score =
164 cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
165
166 if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
168 continue;
169 }
170
171 results.push(SearchResult {
172 id: temporal_pattern.pattern.id,
173 pattern: temporal_pattern.clone(),
174 score,
175 });
176
177 if results.len() > k {
178 results.sort_by(|a, b| {
179 b.score
180 .partial_cmp(&a.score)
181 .unwrap_or(std::cmp::Ordering::Equal)
182 });
183 results.truncate(k);
184 }
185 }
186
187 results.sort_by(|a, b| {
188 b.score
189 .partial_cmp(&a.score)
190 .unwrap_or(std::cmp::Ordering::Equal)
191 });
192 results
193 }
194
195 pub fn filter_by_time(&self, time_range: TimeRange) -> Vec<TemporalPattern> {
197 self.ensure_sorted();
198 let index = self.temporal_index.read();
199
200 let start_idx = index
202 .binary_search_by_key(&time_range.start, |(t, _)| *t)
203 .unwrap_or_else(|i| i);
204
205 let end_idx = index
207 .binary_search_by_key(&time_range.end, |(t, _)| *t)
208 .unwrap_or_else(|i| i);
209
210 index[start_idx..=end_idx.min(index.len().saturating_sub(1))]
212 .iter()
213 .filter_map(|(_, id)| self.patterns.get(id).map(|p| p.clone()))
214 .collect()
215 }
216
217 pub fn decay_low_salience(&self, decay_rate: f32) {
219 let mut to_remove = Vec::new();
220
221 for mut entry in self.patterns.iter_mut() {
222 let temporal_pattern = entry.value_mut();
223
224 temporal_pattern.pattern.salience *= 1.0 - decay_rate;
226
227 if temporal_pattern.pattern.salience < self.config.min_salience {
229 to_remove.push(temporal_pattern.pattern.id);
230 }
231 }
232
233 for id in to_remove {
235 self.remove(&id);
236 }
237 }
238
239 pub fn remove(&self, id: &PatternId) -> Option<TemporalPattern> {
241 let temporal_pattern = self.patterns.remove(id).map(|(_, p)| p)?;
243
244 let mut index = self.temporal_index.write();
246 index.retain(|(_, pid)| pid != id);
247
248 Some(temporal_pattern)
249 }
250
251 pub fn len(&self) -> usize {
253 self.patterns.len()
254 }
255
256 pub fn is_empty(&self) -> bool {
258 self.patterns.is_empty()
259 }
260
261 pub fn clear(&self) {
263 self.patterns.clear();
264 self.temporal_index.write().clear();
265 }
266
267 pub fn all(&self) -> Vec<TemporalPattern> {
269 self.patterns.iter().map(|e| e.value().clone()).collect()
270 }
271
272 pub fn stats(&self) -> LongTermStats {
274 let size = self.patterns.len();
275
276 let total_salience: f32 = self
278 .patterns
279 .iter()
280 .map(|e| e.value().pattern.salience)
281 .sum();
282 let avg_salience = if size > 0 {
283 total_salience / size as f32
284 } else {
285 0.0
286 };
287
288 let mut min_salience = f32::MAX;
290 let mut max_salience = f32::MIN;
291
292 for entry in self.patterns.iter() {
293 let salience = entry.value().pattern.salience;
294 min_salience = min_salience.min(salience);
295 max_salience = max_salience.max(salience);
296 }
297
298 if size == 0 {
299 min_salience = 0.0;
300 max_salience = 0.0;
301 }
302
303 LongTermStats {
304 size,
305 avg_salience,
306 min_salience,
307 max_salience,
308 }
309 }
310}
311
312impl Default for LongTermStore {
313 fn default() -> Self {
314 Self::new(LongTermConfig::default())
315 }
316}
317
/// Snapshot of store-wide salience statistics, as produced by `LongTermStore::stats`.
#[derive(Debug, Clone)]
pub struct LongTermStats {
    /// Number of patterns included in the snapshot.
    pub size: usize,
    /// Mean salience, or `0.0` for an empty store.
    pub avg_salience: f32,
    /// Lowest salience, or `0.0` for an empty store.
    pub min_salience: f32,
    /// Highest salience, or `0.0` for an empty store.
    pub max_salience: f32,
}
330
/// Cosine similarity of `a` and `b`: `dot(a, b) / (|a| * |b|)`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero magnitude.
/// Despite the name, no explicit SIMD intrinsics are used. The main loop is unrolled four
/// lanes wide over `chunks_exact`, which gives the optimizer fixed-size chunks to work with
/// (typically allowing bounds checks to be elided) without any `unsafe`.
#[inline]
fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0f32;
    let mut mag_a = 0.0f32;
    let mut mag_b = 0.0f32;

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    // The remainders borrow the input slices, not the iterators, so they outlive the zip.
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    for (ca, cb) in a_chunks.zip(b_chunks) {
        dot += ca[0] * cb[0] + ca[1] * cb[1] + ca[2] * cb[2] + ca[3] * cb[3];
        mag_a += ca[0] * ca[0] + ca[1] * ca[1] + ca[2] * ca[2] + ca[3] * ca[3];
        mag_b += cb[0] * cb[0] + cb[1] * cb[1] + cb[2] * cb[2] + cb[3] * cb[3];
    }

    for (&x, &y) in a_tail.iter().zip(b_tail) {
        dot += x * y;
        mag_a += x * x;
        mag_b += y * y;
    }

    // Take each root separately. Multiplying the squared magnitudes first overflows to
    // infinity for vectors with norms around 1e10 and above, or underflows to zero for tiny
    // ones, which wrongly yields a similarity of 0.
    let denom = mag_a.sqrt() * mag_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    dot / denom
}
381
382#[allow(dead_code)]
384#[inline]
385fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
386 cosine_similarity_simd(a, b)
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Metadata;

    #[test]
    fn test_long_term_store() {
        let store = LongTermStore::default();

        let temporal_pattern =
            TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        let id = temporal_pattern.pattern.id;

        store.integrate(temporal_pattern);

        assert_eq!(store.len(), 1);
        assert!(store.get(&id).is_some());
    }

    #[test]
    fn test_search() {
        let store = LongTermStore::default();

        let p1 = TemporalPattern::from_embedding(vec![1.0, 0.0, 0.0], Metadata::new());
        let p2 = TemporalPattern::from_embedding(vec![0.0, 1.0, 0.0], Metadata::new());
        let p1_id = p1.pattern.id;

        store.integrate(p1);
        store.integrate(p2);

        let query = Query::from_embedding(vec![0.9, 0.1, 0.0]).with_k(1);
        let results = store.search(&query);

        assert_eq!(results.len(), 1);
        // p1 points almost the same way as the query; p2 is nearly orthogonal to it.
        assert!(results[0].id == p1_id);
        assert!(results[0].score > 0.5);
    }

    #[test]
    fn test_decay() {
        let store = LongTermStore::default();

        let mut temporal_pattern =
            TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        temporal_pattern.pattern.salience = 0.15;
        let id = temporal_pattern.pattern.id;

        store.integrate(temporal_pattern);
        assert_eq!(store.len(), 1);

        // 0.15 * (1.0 - 0.5) = 0.075, which is below the default min_salience of 0.1.
        store.decay_low_salience(0.5);
        assert_eq!(store.len(), 0);
        assert!(store.get(&id).is_none());
    }
}