use crate::error::{SwarmError, SwarmResult};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Mutex;
/// Tuning knobs for a [`VectorMemory::search`] call.
///
/// [`RetrievalPolicy::new`] range-checks the two weights; building the struct
/// literally (the fields are public) or through deserialization does not go
/// through that check.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RetrievalPolicy {
    /// Maximum number of entries a search returns.
    pub top_k: usize,
    /// Minimum score an entry must reach to be included in the results.
    pub score_threshold: f32,
    /// Share of the final score contributed by recency; the remainder
    /// (`1.0 - recency_weight`) comes from semantic similarity.
    pub recency_weight: f32,
}
impl Default for RetrievalPolicy {
    /// Returns a policy with `top_k = 5`, a score threshold of `0.0`, and a
    /// recency weight of `0.0` (ranking by semantic similarity only).
    fn default() -> Self {
        let top_k = 5;
        let score_threshold = 0.0;
        let recency_weight = 0.0;
        Self {
            top_k,
            score_threshold,
            recency_weight,
        }
    }
}
impl RetrievalPolicy {
    /// Builds a policy after range-checking its weights.
    ///
    /// `top_k` is accepted as-is (including `0`).
    ///
    /// # Errors
    ///
    /// Returns [`SwarmError::ValidationError`] when `score_threshold` or
    /// `recency_weight` lies outside `[0.0, 1.0]`. `NaN` is outside the range
    /// and is therefore rejected too. `score_threshold` is checked first.
    pub fn new(top_k: usize, score_threshold: f32, recency_weight: f32) -> SwarmResult<Self> {
        // `contains` is false for NaN, so NaN counts as "outside".
        let outside_unit_interval = |value: f32| !(0.0..=1.0).contains(&value);
        let rejected = |reason: &str| -> SwarmResult<Self> {
            Err(SwarmError::ValidationError(reason.to_string()))
        };

        if outside_unit_interval(score_threshold) {
            return rejected("score_threshold must be in [0.0, 1.0]");
        }
        if outside_unit_interval(recency_weight) {
            return rejected("recency_weight must be in [0.0, 1.0]");
        }

        Ok(Self {
            top_k,
            score_threshold,
            recency_weight,
        })
    }
}
/// A stored item as returned from [`VectorMemory::search`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Identifier the entry was stored under.
    pub id: String,
    /// Text payload of the entry.
    pub text: String,
    /// Embedding vector stored alongside the text.
    pub embedding: Vec<f32>,
    /// Arbitrary JSON metadata supplied when the entry was stored.
    pub metadata: Value,
    /// UTC timestamp recorded for the entry.
    pub stored_at: DateTime<Utc>,
    /// Score assigned by the search that produced this entry.
    pub score: f32,
}
/// Asynchronous interface to a store of embedded text entries.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait VectorMemory: Send + Sync {
    /// Saves `text` with its `embedding` and `metadata` under `id`.
    async fn store(
        &self,
        id: &str,
        text: &str,
        embedding: Vec<f32>,
        metadata: Value,
    ) -> SwarmResult<()>;

    /// Looks up entries for `query_embedding`, shaped by `policy`.
    async fn search(
        &self,
        query_embedding: Vec<f32>,
        policy: RetrievalPolicy,
    ) -> SwarmResult<Vec<MemoryEntry>>;

    /// Removes the entry stored under `id`.
    async fn delete(&self, id: &str) -> SwarmResult<()>;

    /// Reports how many entries the store holds.
    async fn len(&self) -> SwarmResult<usize>;

    /// Reports whether the store holds no entries.
    ///
    /// The default implementation delegates to [`VectorMemory::len`] and
    /// forwards any error it returns.
    async fn is_empty(&self) -> SwarmResult<bool> {
        let count = self.len().await?;
        Ok(count == 0)
    }
}
/// Internal record kept by [`InMemoryVectorStore`]; it has the same fields as
/// [`MemoryEntry`] except for the per-query `score`.
struct StoredEntry {
    /// Identifier the entry was stored under.
    id: String,
    /// Text payload.
    text: String,
    /// Embedding vector compared against query embeddings.
    embedding: Vec<f32>,
    /// Arbitrary JSON metadata supplied by the caller.
    metadata: Value,
    /// UTC time at which the entry was last written.
    stored_at: DateTime<Utc>,
}
/// A [`VectorMemory`] implementation that keeps every entry in a
/// mutex-guarded `Vec` in process memory.
pub struct InMemoryVectorStore {
    /// All stored entries, in insertion order.
    entries: Mutex<Vec<StoredEntry>>,
}
impl InMemoryVectorStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            entries: Mutex::new(Vec::new()),
        }
    }

    /// Computes the cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
    ///
    /// Returns `0.0` when the slices differ in length, when they are empty, or
    /// when either vector has zero magnitude. `NaN` or infinite components
    /// produce a `NaN` result.
    ///
    /// The dot product and squared norms are accumulated in `f64`: squaring an
    /// `f32` overflows to infinity for components around `1e20` and underflows
    /// to zero around `1e-30`, which would turn the similarity of perfectly
    /// valid vectors into `NaN` or `0.0`.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        // Single pass over both vectors: (a·b, |a|², |b|²).
        let (dot, norm_sq_a, norm_sq_b) = a.iter().zip(b.iter()).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(dot, sq_a, sq_b), (&x, &y)| {
                let (x, y) = (f64::from(x), f64::from(y));
                (dot + x * y, sq_a + x * x, sq_b + y * y)
            },
        );

        if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
            return 0.0;
        }

        // Rounding can push the ratio marginally outside [-1, 1]; clamp it
        // back before narrowing to f32.
        let cosine = dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt());
        cosine.clamp(-1.0, 1.0) as f32
    }
}
impl Default for InMemoryVectorStore {
fn default() -> Self {
Self::new()
}
}
// Every method takes the std mutex, does its work synchronously, and releases
// the guard before returning; no guard is held across an `.await`.
#[async_trait]
impl VectorMemory for InMemoryVectorStore {
    /// Inserts a new entry, or replaces the existing entry that has the same
    /// `id`. Either way the entry is stamped with the current UTC time.
    ///
    /// # Errors
    ///
    /// Returns [`SwarmError::Other`] if the internal mutex is poisoned.
    async fn store(
        &self,
        id: &str,
        text: &str,
        embedding: Vec<f32>,
        metadata: Value,
    ) -> SwarmResult<()> {
        let mut guard = self
            .entries
            .lock()
            .map_err(|e| SwarmError::Other(format!("InMemoryVectorStore lock poisoned: {}", e)))?;

        let fresh = StoredEntry {
            id: id.to_string(),
            text: text.to_string(),
            embedding,
            metadata,
            stored_at: Utc::now(),
        };

        match guard.iter().position(|existing| existing.id == id) {
            Some(index) => guard[index] = fresh,
            None => guard.push(fresh),
        }
        Ok(())
    }

    /// Ranks stored entries against `query_embedding` and returns the best
    /// `policy.top_k` of them, highest score first.
    ///
    /// Each entry's score is
    /// `(1 - recency_weight) * cosine_similarity + recency_weight * recency`,
    /// where `recency = exp(-age_in_seconds / 86_400)` and the age is measured
    /// in whole seconds (never negative). Entries scoring below
    /// `policy.score_threshold` are dropped before ranking.
    ///
    /// # Errors
    ///
    /// Returns [`SwarmError::Other`] if the internal mutex is poisoned.
    async fn search(
        &self,
        query_embedding: Vec<f32>,
        policy: RetrievalPolicy,
    ) -> SwarmResult<Vec<MemoryEntry>> {
        let guard = self
            .entries
            .lock()
            .map_err(|e| SwarmError::Other(format!("InMemoryVectorStore lock poisoned: {}", e)))?;
        if guard.is_empty() {
            return Ok(Vec::new());
        }

        let reference_time = Utc::now();
        let semantic_weight = 1.0 - policy.recency_weight;

        // (score, index into `guard`) for every entry that clears the threshold.
        let mut ranked: Vec<(f32, usize)> = Vec::new();
        for (index, stored) in guard.iter().enumerate() {
            let semantic = Self::cosine_similarity(&query_embedding, &stored.embedding);
            let age_secs = (reference_time - stored.stored_at).num_seconds().max(0) as f32;
            let recency = (-age_secs / 86_400.0).exp();
            let combined = semantic_weight * semantic + policy.recency_weight * recency;
            if combined >= policy.score_threshold {
                ranked.push((combined, index));
            }
        }

        // Descending by score; incomparable pairs are treated as equal.
        ranked.sort_unstable_by(|lhs, rhs| {
            rhs.0
                .partial_cmp(&lhs.0)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Only the surviving top-k entries are cloned out of the store.
        Ok(ranked
            .into_iter()
            .take(policy.top_k)
            .map(|(score, index)| {
                let stored = &guard[index];
                MemoryEntry {
                    id: stored.id.clone(),
                    text: stored.text.clone(),
                    embedding: stored.embedding.clone(),
                    metadata: stored.metadata.clone(),
                    stored_at: stored.stored_at,
                    score,
                }
            })
            .collect())
    }

    /// Removes every entry whose id equals `id`; succeeds even when nothing
    /// matches.
    ///
    /// # Errors
    ///
    /// Returns [`SwarmError::Other`] if the internal mutex is poisoned.
    async fn delete(&self, id: &str) -> SwarmResult<()> {
        let mut guard = self
            .entries
            .lock()
            .map_err(|e| SwarmError::Other(format!("InMemoryVectorStore lock poisoned: {}", e)))?;
        guard.retain(|stored| stored.id != id);
        Ok(())
    }

    /// Returns the number of stored entries.
    ///
    /// # Errors
    ///
    /// Returns [`SwarmError::Other`] if the internal mutex is poisoned.
    async fn len(&self) -> SwarmResult<usize> {
        self.entries
            .lock()
            .map(|guard| guard.len())
            .map_err(|e| SwarmError::Other(format!("InMemoryVectorStore lock poisoned: {}", e)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a two-component embedding.
    fn vec2(x: f32, y: f32) -> Vec<f32> {
        vec![x, y]
    }

    /// The entry aligned with the query ranks first with a score of ~1.
    #[tokio::test]
    async fn test_store_and_search() {
        let memory = InMemoryVectorStore::new();
        memory
            .store("a", "text a", vec2(1.0, 0.0), json!({"tag": "a"}))
            .await
            .unwrap();
        memory
            .store("b", "text b", vec2(0.0, 1.0), json!({"tag": "b"}))
            .await
            .unwrap();

        let hits = memory
            .search(vec2(1.0, 0.0), RetrievalPolicy::default())
            .await
            .unwrap();

        assert_eq!(hits[0].id, "a");
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    /// Storing twice under one id keeps a single entry holding the later text.
    #[tokio::test]
    async fn test_upsert() {
        let memory = InMemoryVectorStore::new();
        memory
            .store("x", "v1", vec2(1.0, 0.0), json!({}))
            .await
            .unwrap();
        memory
            .store("x", "v2", vec2(1.0, 0.0), json!({}))
            .await
            .unwrap();

        assert_eq!(memory.len().await.unwrap(), 1);

        let hits = memory
            .search(vec2(1.0, 0.0), RetrievalPolicy::default())
            .await
            .unwrap();
        assert_eq!(hits[0].text, "v2");
    }

    /// Deleting the only entry leaves the store empty.
    #[tokio::test]
    async fn test_delete() {
        let memory = InMemoryVectorStore::new();
        memory
            .store("d", "text", vec2(1.0, 0.0), json!({}))
            .await
            .unwrap();

        memory.delete("d").await.unwrap();

        assert_eq!(memory.len().await.unwrap(), 0);
    }

    /// A 0.9 threshold keeps the aligned entry and drops the orthogonal one.
    #[tokio::test]
    async fn test_score_threshold() {
        let memory = InMemoryVectorStore::new();
        memory
            .store("a", "text a", vec2(1.0, 0.0), json!({}))
            .await
            .unwrap();
        memory
            .store("b", "text b", vec2(0.0, 1.0), json!({}))
            .await
            .unwrap();

        let strict = RetrievalPolicy::new(5, 0.9, 0.0).unwrap();
        let hits = memory.search(vec2(1.0, 0.0), strict).await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, "a");
    }

    /// A zero-magnitude vector has similarity 0 with anything.
    #[test]
    fn test_cosine_zero_vector() {
        let similarity = InMemoryVectorStore::cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]);
        assert_eq!(similarity, 0.0);
    }
}