use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use super::config::VectorStoreConfig;
use super::embedding_client::{cosine_similarity, EmbeddingClient};
use super::types::InteractionRecord;
use crate::Result;
/// In-memory store of [`InteractionRecord`]s grouped by session id, with optional
/// embedding-based retrieval.
///
/// When an embedding client is configured, stored records are annotated with an embedding
/// and retrieval ranks them by cosine similarity to the query; otherwise retrieval falls
/// back to most-recent-first order.
pub struct VectorMemoryStore {
/// Records per session id, in insertion order; shared behind an async read/write lock.
storage: Arc<RwLock<HashMap<String, Vec<InteractionRecord>>>>,
/// Client used to embed records and queries; `None` when `config.enabled` is false.
embedding_client: Option<Arc<EmbeddingClient>>,
/// Configuration the store was built from (provider, model, similarity threshold).
config: VectorStoreConfig,
}
impl VectorMemoryStore {
    /// Creates an empty store driven by `config`.
    ///
    /// An embedding client is built only when `config.enabled` is true. Without one, records
    /// are stored without embeddings and retrieval returns the most recent records.
    pub fn new(config: VectorStoreConfig) -> Self {
        let embedding_client = config.enabled.then(|| {
            Arc::new(EmbeddingClient::new(
                config.embedding_provider.clone(),
                config.embedding_model.clone(),
                None,
                None,
            ))
        });

        Self {
            storage: Arc::new(RwLock::new(HashMap::new())),
            embedding_client,
            config,
        }
    }

    /// Appends a copy of `interaction` to the history of `session_id`.
    ///
    /// When an embedding client is configured, the record's summary is embedded first and the
    /// embedding is attached to the stored copy. An embedding failure is logged and the record
    /// is stored without an embedding (best effort), so this method does not fail because of it.
    pub async fn store_interaction(
        &self,
        session_id: &str,
        interaction: &InteractionRecord,
    ) -> Result<()> {
        let mut record = interaction.clone();

        // The embedding is computed before the write lock is taken so that readers and other
        // writers are not blocked while the embedding request is awaited.
        if let Some(client) = &self.embedding_client {
            let summary = interaction.summary();
            match client.generate_embedding(&summary).await {
                Ok(embedding) => {
                    record.embedding = Some(embedding);
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {}", e);
                }
            }
        }

        self.storage
            .write()
            .await
            .entry(session_id.to_string())
            .or_default()
            .push(record);
        Ok(())
    }

    /// Returns up to `limit` records of `session_id` that are relevant to `query`.
    ///
    /// * With an embedding client: records that carry an embedding are scored by cosine
    ///   similarity against the query embedding, filtered by
    ///   `config.similarity_threshold`, and returned best match first. Records stored without
    ///   an embedding are not considered on this path.
    /// * Without an embedding client, or when the query embedding cannot be generated (the
    ///   failure is logged): the most recent records are returned, newest first.
    ///
    /// An unknown session or `limit == 0` yields an empty vector.
    pub async fn retrieve_context(
        &self,
        session_id: &str,
        query: &str,
        limit: usize,
    ) -> Result<Vec<InteractionRecord>> {
        // Every path truncates to `limit`, so nothing can be returned; skip the lock and the
        // embedding request altogether.
        if limit == 0 {
            return Ok(Vec::new());
        }

        let client = match &self.embedding_client {
            Some(client) => client,
            None => {
                // Recency fallback: clone only the records that are actually returned.
                let storage = self.storage.read().await;
                return Ok(storage
                    .get(session_id)
                    .map(|records| records.iter().rev().take(limit).cloned().collect::<Vec<_>>())
                    .unwrap_or_default());
            }
        };

        // Snapshot the session and release the read guard *before* awaiting the embedding
        // request; holding it across that await would stall every writer in the meantime.
        let interactions = {
            let storage = self.storage.read().await;
            storage.get(session_id).cloned().unwrap_or_default()
        };
        if interactions.is_empty() {
            return Ok(interactions);
        }

        let query_embedding = match client.generate_embedding(query).await {
            Ok(embedding) => embedding,
            Err(e) => {
                tracing::warn!("Failed to generate query embedding: {}", e);
                return Ok(Self::newest_first(interactions, limit));
            }
        };

        let threshold = self.config.similarity_threshold;
        let mut scored: Vec<(InteractionRecord, f32)> = interactions
            .into_iter()
            .filter_map(|interaction| {
                // Records without an embedding cannot be scored and are skipped.
                let score = cosine_similarity(&query_embedding, interaction.embedding.as_ref()?);
                // A NaN score fails this comparison and is dropped as well.
                if score >= threshold {
                    Some((interaction, score))
                } else {
                    None
                }
            })
            .collect();

        // Highest similarity first; the sort is stable, so ties keep insertion order.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        Ok(scored
            .into_iter()
            .take(limit)
            .map(|(interaction, _)| interaction)
            .collect())
    }

    /// Returns a copy of every record stored for `session_id`, in insertion order.
    ///
    /// An unknown session yields an empty vector.
    pub async fn get_session_interactions(
        &self,
        session_id: &str,
    ) -> Result<Vec<InteractionRecord>> {
        let storage = self.storage.read().await;
        Ok(storage.get(session_id).cloned().unwrap_or_default())
    }

    /// Removes every record stored for `session_id`; a no-op for an unknown session.
    pub async fn clear_session(&self, session_id: &str) -> Result<()> {
        self.storage.write().await.remove(session_id);
        Ok(())
    }

    /// Removes the records of all sessions.
    pub async fn clear_all(&self) -> Result<()> {
        self.storage.write().await.clear();
        Ok(())
    }

    /// Consumes `records` and returns at most `limit` of them, newest (last inserted) first.
    fn newest_first(records: Vec<InteractionRecord>, limit: usize) -> Vec<InteractionRecord> {
        records.into_iter().rev().take(limit).collect()
    }
}
impl Default for VectorMemoryStore {
    /// Builds a store from the default [`VectorStoreConfig`].
    fn default() -> Self {
        let config = VectorStoreConfig::default();
        Self::new(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored interaction can be read back through `retrieve_context`.
    #[tokio::test]
    async fn test_store_and_retrieve() {
        let store = VectorMemoryStore::new(VectorStoreConfig::default());

        let request_body = serde_json::json!({"name": "Alice"});
        let response_body = serde_json::json!({"id": "user_1", "name": "Alice"});
        let record = InteractionRecord::new(
            "POST",
            "/api/users",
            Some(request_body),
            201,
            Some(response_body),
        );

        store.store_interaction("session_1", &record).await.unwrap();

        let hits = store
            .retrieve_context("session_1", "user creation", 10)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].method, "POST");
    }

    /// Clearing a session leaves no interactions behind for it.
    #[tokio::test]
    async fn test_clear_session() {
        let store = VectorMemoryStore::new(VectorStoreConfig::default());
        let record = InteractionRecord::new("GET", "/api/users", None, 200, None);

        store.store_interaction("session_1", &record).await.unwrap();
        store.clear_session("session_1").await.unwrap();

        let remaining = store.get_session_interactions("session_1").await.unwrap();
        assert!(remaining.is_empty());
    }
}