use crate::types::{PatternId, Query, SearchResult, SubstrateTime, TemporalPattern, TimeRange};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Tunable parameters for the long-term pattern store.
#[derive(Debug, Clone)]
pub struct LongTermConfig {
    /// Configured decay rate. The decay routine in this file takes its rate
    /// as an explicit argument, so this value is available to callers rather
    /// than being applied automatically.
    pub decay_rate: f32,
    /// Salience floor: a pattern whose salience drops below this value during
    /// decay is removed from the store.
    pub min_salience: f32,
}

impl Default for LongTermConfig {
    /// Builds the stock configuration: decay rate `0.01`, salience floor `0.1`.
    fn default() -> Self {
        const DEFAULT_DECAY_RATE: f32 = 0.01;
        const DEFAULT_MIN_SALIENCE: f32 = 0.1;

        LongTermConfig {
            decay_rate: DEFAULT_DECAY_RATE,
            min_salience: DEFAULT_MIN_SALIENCE,
        }
    }
}
/// Concurrent store of temporal patterns with a secondary index ordered by
/// timestamp.
pub struct LongTermStore {
/// Primary storage: every pattern keyed by its id.
patterns: DashMap<PatternId, TemporalPattern>,
/// `(timestamp, id)` pairs used for time-range lookups. Entries are appended
/// on insert and sorted lazily (see `index_dirty`).
temporal_index: Arc<RwLock<Vec<(SubstrateTime, PatternId)>>>,
/// Set when `temporal_index` may be out of timestamp order and needs a sort
/// before it can be binary-searched.
index_dirty: AtomicBool,
/// Settings supplied at construction time.
config: LongTermConfig,
}
impl LongTermStore {
pub fn new(config: LongTermConfig) -> Self {
Self {
patterns: DashMap::new(),
temporal_index: Arc::new(RwLock::new(Vec::new())),
index_dirty: AtomicBool::new(false),
config,
}
}
/// Stores `temporal_pattern` and records it in the temporal index.
///
/// If a pattern with the same id is already present it is replaced, and its
/// previous index entry is dropped so the id is indexed exactly once.
/// Otherwise time-range queries would return the pattern several times.
pub fn integrate(&self, temporal_pattern: TemporalPattern) {
    let id = temporal_pattern.pattern.id;
    let timestamp = temporal_pattern.pattern.timestamp;

    // Hold the index lock across both updates so the map and the index change
    // together (same lock order as `integrate_batch`).
    let mut index = self.temporal_index.write();
    if self.patterns.insert(id, temporal_pattern).is_some() {
        index.retain(|(_, existing)| *existing != id);
    }

    // Appending at or after the current tail cannot break the ordering, so the
    // index only needs a re-sort when the new entry lands before the tail.
    let breaks_order = index
        .last()
        .map_or(false, |(latest, _)| timestamp < *latest);
    index.push((timestamp, id));
    if breaks_order {
        self.index_dirty.store(true, Ordering::Relaxed);
    }
}
/// Stores every pattern in `patterns` and re-sorts the temporal index once.
///
/// A pattern whose id is already present replaces the stored one, and the
/// stale index entry for that id is dropped so each id is indexed exactly
/// once. An empty batch is a no-op.
pub fn integrate_batch(&self, patterns: Vec<TemporalPattern>) {
    if patterns.is_empty() {
        return;
    }

    // One write lock for the whole batch; the index is sorted before release.
    let mut index = self.temporal_index.write();
    index.reserve(patterns.len());

    for temporal_pattern in patterns {
        let id = temporal_pattern.pattern.id;
        let timestamp = temporal_pattern.pattern.timestamp;
        if self.patterns.insert(id, temporal_pattern).is_some() {
            // Replacement: remove the old entry (linear scan, only paid for
            // ids that were already stored).
            index.retain(|(_, existing)| *existing != id);
        }
        index.push((timestamp, id));
    }

    index.sort_by_key(|(t, _)| *t);
    self.index_dirty.store(false, Ordering::Relaxed);
}
/// Sorts the temporal index by timestamp if it has been flagged as dirty,
/// then clears the flag. Does nothing when the flag is not set.
fn ensure_sorted(&self) {
    if !self.index_dirty.load(Ordering::Relaxed) {
        return;
    }

    let mut index = self.temporal_index.write();
    index.sort_by_key(|&(timestamp, _)| timestamp);
    self.index_dirty.store(false, Ordering::Relaxed);
}
/// Returns a clone of the pattern stored under `id`, or `None` if the id is
/// not present.
pub fn get(&self, id: &PatternId) -> Option<TemporalPattern> {
    let entry = self.patterns.get(id)?;
    Some(entry.value().clone())
}
pub fn update(&self, temporal_pattern: TemporalPattern) -> bool {
let id = temporal_pattern.pattern.id;
self.patterns.insert(id, temporal_pattern).is_some()
}
pub fn search(&self, query: &Query) -> Vec<SearchResult> {
let k = query.k;
let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
for entry in self.patterns.iter() {
let temporal_pattern = entry.value();
let score =
cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
continue;
}
results.push(SearchResult {
id: temporal_pattern.pattern.id,
pattern: temporal_pattern.clone(),
score,
});
if results.len() > k {
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(k);
}
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results
}
pub fn search_with_time_range(
&self,
query: &Query,
time_range: TimeRange,
) -> Vec<SearchResult> {
let k = query.k;
let mut results: Vec<SearchResult> = Vec::with_capacity(k + 1);
for entry in self.patterns.iter() {
let temporal_pattern = entry.value();
if !time_range.contains(&temporal_pattern.pattern.timestamp) {
continue;
}
let score =
cosine_similarity_simd(&query.embedding, &temporal_pattern.pattern.embedding);
if results.len() >= k && score <= results.last().map(|r| r.score).unwrap_or(0.0) {
continue;
}
results.push(SearchResult {
id: temporal_pattern.pattern.id,
pattern: temporal_pattern.clone(),
score,
});
if results.len() > k {
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(k);
}
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results
}
pub fn filter_by_time(&self, time_range: TimeRange) -> Vec<TemporalPattern> {
self.ensure_sorted();
let index = self.temporal_index.read();
let start_idx = index
.binary_search_by_key(&time_range.start, |(t, _)| *t)
.unwrap_or_else(|i| i);
let end_idx = index
.binary_search_by_key(&time_range.end, |(t, _)| *t)
.unwrap_or_else(|i| i);
index[start_idx..=end_idx.min(index.len().saturating_sub(1))]
.iter()
.filter_map(|(_, id)| self.patterns.get(id).map(|p| p.clone()))
.collect()
}
/// Multiplies every pattern's salience by `1.0 - decay_rate` and evicts the
/// patterns whose salience then falls below the configured `min_salience`.
///
/// The rate comes from the `decay_rate` argument; the `decay_rate` field of
/// the store's config is not consulted here.
pub fn decay_low_salience(&self, decay_rate: f32) {
    let floor = self.config.min_salience;
    let mut evicted = std::collections::HashSet::new();

    // Decay and evict in one pass over the map.
    self.patterns.retain(|id, temporal_pattern| {
        temporal_pattern.pattern.salience *= 1.0 - decay_rate;
        let expired = temporal_pattern.pattern.salience < floor;
        if expired {
            evicted.insert(*id);
        }
        !expired
    });

    // Purge all evicted ids from the index with a single scan rather than one
    // full scan per removed pattern.
    if !evicted.is_empty() {
        self.temporal_index
            .write()
            .retain(|(_, id)| !evicted.contains(id));
    }
}
/// Removes the pattern stored under `id` and drops its temporal index
/// entries.
///
/// Returns the removed pattern, or `None` (leaving the index untouched) when
/// the id is not present.
pub fn remove(&self, id: &PatternId) -> Option<TemporalPattern> {
    let (_, removed) = self.patterns.remove(id)?;
    self.temporal_index
        .write()
        .retain(|(_, indexed)| indexed != id);
    Some(removed)
}
/// Returns the number of patterns currently held in the store.
pub fn len(&self) -> usize {
self.patterns.len()
}
/// Returns `true` when the store holds no patterns.
pub fn is_empty(&self) -> bool {
self.patterns.is_empty()
}
/// Removes every pattern from the store and empties the temporal index.
pub fn clear(&self) {
    self.patterns.clear();

    let mut index = self.temporal_index.write();
    index.clear();
}
/// Returns clones of every stored pattern, in the map's iteration order.
pub fn all(&self) -> Vec<TemporalPattern> {
    let mut snapshot = Vec::with_capacity(self.patterns.len());
    for entry in self.patterns.iter() {
        snapshot.push(entry.value().clone());
    }
    snapshot
}
/// Summarises the store: pattern count plus average, minimum and maximum
/// salience.
///
/// All figures come from a single pass over the map, so they describe the
/// same set of patterns even while other threads are writing. For an empty
/// store every figure is zero.
pub fn stats(&self) -> LongTermStats {
    let mut size = 0usize;
    let mut total_salience = 0.0f32;
    let mut min_salience = f32::MAX;
    let mut max_salience = f32::MIN;

    for entry in self.patterns.iter() {
        let salience = entry.value().pattern.salience;
        size += 1;
        total_salience += salience;
        min_salience = min_salience.min(salience);
        max_salience = max_salience.max(salience);
    }

    if size == 0 {
        return LongTermStats {
            size: 0,
            avg_salience: 0.0,
            min_salience: 0.0,
            max_salience: 0.0,
        };
    }

    LongTermStats {
        size,
        avg_salience: total_salience / size as f32,
        min_salience,
        max_salience,
    }
}
}
impl Default for LongTermStore {
    /// Creates an empty store using the default [`LongTermConfig`].
    fn default() -> Self {
        let config = LongTermConfig::default();
        Self::new(config)
    }
}
/// Summary figures describing the contents of a long-term store.
#[derive(Debug, Clone)]
pub struct LongTermStats {
/// Number of patterns in the store.
pub size: usize,
/// Mean salience over all patterns (`0.0` for an empty store).
pub avg_salience: f32,
/// Smallest salience among the patterns (`0.0` for an empty store).
pub min_salience: f32,
/// Largest salience among the patterns (`0.0` for an empty store).
pub max_salience: f32,
}
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either one
/// has zero magnitude.
///
/// The main loop handles four elements per step and the tail is processed
/// one element at a time; no explicit SIMD intrinsics are used.
#[inline]
fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0f32;
    let mut norm_sq_a = 0.0f32;
    let mut norm_sq_b = 0.0f32;

    // `chunks_exact` guarantees four elements per chunk, so the indexing
    // below needs no `unsafe`.
    let mut lanes_a = a.chunks_exact(4);
    let mut lanes_b = b.chunks_exact(4);
    for (x, y) in lanes_a.by_ref().zip(lanes_b.by_ref()) {
        dot += x[0] * y[0] + x[1] * y[1] + x[2] * y[2] + x[3] * y[3];
        norm_sq_a += x[0] * x[0] + x[1] * x[1] + x[2] * x[2] + x[3] * x[3];
        norm_sq_b += y[0] * y[0] + y[1] * y[1] + y[2] * y[2] + y[3] * y[3];
    }
    for (&x, &y) in lanes_a.remainder().iter().zip(lanes_b.remainder()) {
        dot += x * y;
        norm_sq_a += x * x;
        norm_sq_b += y * y;
    }

    // Take the square roots before multiplying: the product of the two
    // squared norms can overflow to infinity or underflow to zero in f32
    // even when each norm on its own is representable.
    let magnitude = norm_sq_a.sqrt() * norm_sq_b.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    dot / magnitude
}
/// Cosine similarity of `a` and `b`; forwards directly to
/// `cosine_similarity_simd` and returns its result unchanged.
// NOTE(review): the dead_code allowance suggests nothing in this file calls
// this wrapper — confirm whether it is still needed.
#[allow(dead_code)]
#[inline]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
cosine_similarity_simd(a, b)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Metadata;

    /// A pattern that has been integrated can be counted and fetched by id.
    #[test]
    fn test_long_term_store() {
        let store = LongTermStore::default();

        let stored = TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        let stored_id = stored.pattern.id;
        store.integrate(stored);

        assert_eq!(store.len(), 1);
        assert!(store.get(&stored_id).is_some());
    }

    /// With `k = 1`, search returns a single result that scores above 0.5.
    #[test]
    fn test_search() {
        let store = LongTermStore::default();

        let along_x = TemporalPattern::from_embedding(vec![1.0, 0.0, 0.0], Metadata::new());
        let along_y = TemporalPattern::from_embedding(vec![0.0, 1.0, 0.0], Metadata::new());
        store.integrate(along_x);
        store.integrate(along_y);

        let query = Query::from_embedding(vec![0.9, 0.1, 0.0]).with_k(1);
        let hits = store.search(&query);

        assert_eq!(hits.len(), 1);
        assert!(hits[0].score > 0.5);
    }

    /// A pattern whose salience decays below the default floor is evicted.
    #[test]
    fn test_decay() {
        let store = LongTermStore::default();

        let mut fading = TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        fading.pattern.salience = 0.15;
        let id = fading.pattern.id;
        store.integrate(fading);
        assert_eq!(store.len(), 1);

        // 0.15 * (1 - 0.5) = 0.075, which is below the default floor of 0.1.
        store.decay_low_salience(0.5);
        assert_eq!(store.len(), 0);
    }
}