use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::Mutex;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::messages::Message;
use cognis_core::vectorstores::base::VectorStore;
use super::BaseMemory;
/// Conversation memory backed by a [`VectorStore`].
///
/// Every remembered text (an explicit `add_memory` call or a saved human/AI exchange) is written
/// to the vector store as one document. Loading memory runs a similarity search against the most
/// recently recorded query and returns the top `k` hits.
pub struct VectorStoreMemory {
    // --- configuration -------------------------------------------------------------------------
    /// Store that holds every remembered text.
    vectorstore: Arc<dyn VectorStore>,
    /// Default number of documents fetched per retrieval.
    k: usize,
    /// Key under which retrieved history is exposed by `load_memory_variables`.
    memory_key: String,
    /// Configured input-variable name.
    /// NOTE(review): nothing in this file reads it; `save_context` uses the message text directly.
    input_key: String,
    /// `true` returns serialized documents; `false` returns their contents joined by newlines.
    return_docs: bool,

    // --- runtime state (shared behind async mutexes) ---------------------------------------------
    /// Ids of every document this memory added, so `clear` deletes exactly those.
    stored_ids: Arc<Mutex<Vec<String>>>,
    /// Query used for the next retrieval; `None` until something sets it.
    last_query: Arc<Mutex<Option<String>>>,
}
impl VectorStoreMemory {
    /// Build a memory over `vectorstore`.
    ///
    /// Defaults: memory key `"history"`, input key `"input"`, `k = 4`, text output
    /// (`return_docs = false`), no tracked ids and no query yet.
    pub fn new(vectorstore: Arc<dyn VectorStore>) -> Self {
        let stored_ids = Arc::new(Mutex::new(Vec::new()));
        let last_query = Arc::new(Mutex::new(None));
        Self {
            vectorstore,
            memory_key: String::from("history"),
            input_key: String::from("input"),
            k: 4,
            return_docs: false,
            stored_ids,
            last_query,
        }
    }

    /// Use `key` as the variable name under which retrieved history is returned.
    pub fn with_memory_key(self, key: impl Into<String>) -> Self {
        Self {
            memory_key: key.into(),
            ..self
        }
    }

    /// Record the input-variable name.
    ///
    /// NOTE(review): the value is stored, but nothing in this file currently reads it.
    pub fn with_input_key(self, key: impl Into<String>) -> Self {
        Self {
            input_key: key.into(),
            ..self
        }
    }

    /// Set the default number of documents retrieved per lookup.
    pub fn with_k(self, k: usize) -> Self {
        Self { k, ..self }
    }

    /// Choose between serialized documents (`true`) and newline-joined text (`false`).
    pub fn with_return_docs(self, return_docs: bool) -> Self {
        Self { return_docs, ..self }
    }

    /// Run a similarity search for `query`.
    ///
    /// `k` overrides the configured default when `Some`.
    ///
    /// # Errors
    /// Propagates whatever the underlying vector store returns.
    pub async fn retrieve_relevant(&self, query: &str, k: Option<usize>) -> Result<Vec<Document>> {
        let limit = match k {
            Some(n) => n,
            None => self.k,
        };
        self.vectorstore.similarity_search(query, limit).await
    }

    /// Store `text` (with optional `metadata`) as a new document and track its id.
    ///
    /// This does not change the retrieval query; call [`VectorStoreMemory::set_query`] for that.
    ///
    /// # Errors
    /// Propagates errors from the vector store's `add_texts`; on error no id is tracked.
    pub async fn add_memory(
        &self,
        text: &str,
        metadata: Option<HashMap<String, Value>>,
    ) -> Result<()> {
        let batch = vec![text.to_owned()];
        // One metadata map per text, matching the single-element batch.
        let per_text_metadata = metadata.map(|m| vec![m]);
        let new_ids = self
            .vectorstore
            .add_texts(&batch, per_text_metadata.as_deref(), None)
            .await?;
        self.stored_ids.lock().await.extend(new_ids);
        Ok(())
    }

    /// Replace the query used by the next `load_memory_variables` call.
    pub async fn set_query(&self, query: &str) {
        *self.last_query.lock().await = Some(query.to_owned());
    }
}
#[async_trait]
impl BaseMemory for VectorStoreMemory {
    /// Retrieve history relevant to the most recently recorded query.
    ///
    /// The returned map has a single entry keyed by `memory_key`:
    /// - no query set yet: an empty string, or an empty array when `return_docs` is set;
    /// - otherwise the top-`k` documents, either serialized to JSON values (`return_docs`)
    ///   or their `page_content` joined with `\n`.
    ///
    /// # Errors
    /// Propagates errors from the vector store's similarity search.
    async fn load_memory_variables(&self) -> Result<HashMap<String, Value>> {
        // Copy the query out so the lock is released before the (async) search runs.
        let query = self.last_query.lock().await.clone();

        let history = match query {
            None if self.return_docs => Value::Array(Vec::new()),
            None => Value::String(String::new()),
            Some(query) => {
                let docs = self.retrieve_relevant(&query, None).await?;
                if self.return_docs {
                    // NOTE(review): a document that fails to serialize becomes `null` instead of
                    // an error; kept as-is for backward compatibility.
                    Value::Array(
                        docs.iter()
                            .map(|d| serde_json::to_value(d).unwrap_or(Value::Null))
                            .collect(),
                    )
                } else {
                    Value::String(
                        docs.iter()
                            .map(|d| d.page_content.as_str())
                            .collect::<Vec<_>>()
                            .join("\n"),
                    )
                }
            }
        };

        let mut vars = HashMap::new();
        vars.insert(self.memory_key.clone(), history);
        Ok(vars)
    }

    /// Persist one exchange as `"Human: <input>\nAI: <output>"` and make the human input the
    /// query for the next retrieval.
    ///
    /// # Errors
    /// Propagates errors from the vector store's `add_texts`. On error neither the tracked ids
    /// nor the retrieval query are modified.
    async fn save_context(&self, input: &Message, output: &Message) -> Result<()> {
        let input_text = input.content().text();
        let output_text = output.content().text();
        let texts = vec![format!("Human: {}\nAI: {}", input_text, output_text)];

        let ids = self.vectorstore.add_texts(&texts, None, None).await?;
        self.stored_ids.lock().await.extend(ids);

        // Advance the query only once the exchange is actually stored, so a failed save does
        // not leave retrieval pointing at an exchange that was never persisted.
        *self.last_query.lock().await = Some(input_text);
        Ok(())
    }

    /// Delete every document this memory added and reset the retrieval query.
    ///
    /// # Errors
    /// Propagates errors from the vector store's `delete`. In that case the tracked ids are
    /// kept (so a later `clear` can retry) and the query is left unchanged.
    async fn clear(&self) -> Result<()> {
        let mut stored = self.stored_ids.lock().await;
        if !stored.is_empty() {
            // Delete first and forget the ids only on success. Draining before the call would
            // lose track of the documents if the delete failed, orphaning them in the store.
            self.vectorstore.delete(Some(&*stored)).await?;
            stored.clear();
        }
        *self.last_query.lock().await = None;
        Ok(())
    }

    /// Name of the variable under which history is returned.
    fn memory_key(&self) -> &str {
        &self.memory_key
    }
}
#[cfg(test)]
mod tests {
    //! Tests for `VectorStoreMemory` against the in-memory vector store.

    use super::*;
    use crate::vectorstores::in_memory::InMemoryVectorStore;
    use cognis_core::embeddings_fake::DeterministicFakeEmbedding;
    use cognis_core::messages::Message;

    /// Fresh, empty in-memory store with 16-dimensional fake embeddings.
    fn make_store() -> Arc<InMemoryVectorStore> {
        let embeddings = Arc::new(DeterministicFakeEmbedding::new(16));
        Arc::new(InMemoryVectorStore::new(embeddings))
    }

    /// Convenience: the `history` entry as a string slice.
    fn history_text(vars: &HashMap<String, Value>) -> &str {
        vars.get("history").unwrap().as_str().unwrap()
    }

    #[tokio::test]
    async fn test_save_and_retrieve_memory() {
        let store = make_store();
        let mem = VectorStoreMemory::new(store.clone());
        mem.add_memory("The user's favorite color is blue.", None)
            .await
            .unwrap();
        let docs = mem
            .retrieve_relevant("What is the user's favorite color?", None)
            .await
            .unwrap();
        assert_eq!(docs.len(), 1);
        assert!(docs[0].page_content.contains("blue"));
    }

    #[tokio::test]
    async fn test_save_multiple_and_retrieve_most_relevant() {
        // NOTE: the fake embeddings are presumably not semantically meaningful, so only the
        // result count (bounded by `k`) is asserted, not which documents are returned.
        let store = make_store();
        let mem = VectorStoreMemory::new(store.clone()).with_k(2);
        mem.add_memory("I love programming in Rust.", None)
            .await
            .unwrap();
        mem.add_memory("The weather in Paris is sunny today.", None)
            .await
            .unwrap();
        mem.add_memory("Rust has zero-cost abstractions.", None)
            .await
            .unwrap();
        let docs = mem
            .retrieve_relevant("Tell me about Rust programming", None)
            .await
            .unwrap();
        assert_eq!(docs.len(), 2);
    }

    #[tokio::test]
    async fn test_retrieve_with_custom_k() {
        let store = make_store();
        let mem = VectorStoreMemory::new(store.clone());
        mem.add_memory("Fact one", None).await.unwrap();
        mem.add_memory("Fact two", None).await.unwrap();
        mem.add_memory("Fact three", None).await.unwrap();
        let docs = mem.retrieve_relevant("facts", Some(1)).await.unwrap();
        assert_eq!(docs.len(), 1);
        let docs = mem.retrieve_relevant("facts", Some(3)).await.unwrap();
        assert_eq!(docs.len(), 3);
    }

    #[tokio::test]
    async fn test_save_context_formats_correctly() {
        let store = make_store();
        let mem = VectorStoreMemory::new(store.clone());
        let human = Message::human("What is Rust?");
        let ai = Message::ai("Rust is a systems programming language.");
        mem.save_context(&human, &ai).await.unwrap();
        let docs = mem
            .retrieve_relevant("Rust programming language", None)
            .await
            .unwrap();
        assert_eq!(docs.len(), 1);
        assert!(docs[0].page_content.contains("Human: What is Rust?"));
        assert!(docs[0]
            .page_content
            .contains("AI: Rust is a systems programming language."));
    }

    #[tokio::test]
    async fn test_clear_tracks_ids_for_deletion() {
        let store = make_store();
        let mem = VectorStoreMemory::new(store.clone());
        mem.add_memory("Memory to be cleared", None).await.unwrap();
        mem.save_context(&Message::human("Hi"), &Message::ai("Hello"))
            .await
            .unwrap();
        {
            let stored = mem.stored_ids.lock().await;
            assert_eq!(stored.len(), 2);
        }
        mem.clear().await.unwrap();
        {
            let stored = mem.stored_ids.lock().await;
            assert!(stored.is_empty());
        }
        let docs = mem.retrieve_relevant("anything", Some(10)).await.unwrap();
        assert!(docs.is_empty());
        // `clear` also resets the query, so loading yields empty history again.
        let vars = mem.load_memory_variables().await.unwrap();
        assert!(history_text(&vars).is_empty());
    }

    #[tokio::test]
    async fn test_clear_on_empty_memory_is_ok() {
        let mem = VectorStoreMemory::new(make_store());
        mem.clear().await.unwrap();
        assert!(mem.stored_ids.lock().await.is_empty());
    }

    #[tokio::test]
    async fn test_load_memory_variables_after_setting_query() {
        let store = make_store();
        let mem = VectorStoreMemory::new(store.clone());
        let vars = mem.load_memory_variables().await.unwrap();
        assert!(history_text(&vars).is_empty());
        mem.save_context(
            &Message::human("Tell me about Rust"),
            &Message::ai("Rust is great!"),
        )
        .await
        .unwrap();
        let vars = mem.load_memory_variables().await.unwrap();
        let history = history_text(&vars);
        assert!(history.contains("Human: Tell me about Rust"));
        assert!(history.contains("AI: Rust is great!"));
        // With a single stored document and k = 4, any query must surface that document.
        mem.set_query("custom query").await;
        let vars = mem.load_memory_variables().await.unwrap();
        assert!(history_text(&vars).contains("Human: Tell me about Rust"));
    }

    #[tokio::test]
    async fn test_add_memory_does_not_set_query() {
        let mem = VectorStoreMemory::new(make_store());
        mem.add_memory("The deadline is Friday.", None).await.unwrap();
        // No query recorded yet, so nothing is loaded even though the store is non-empty.
        let vars = mem.load_memory_variables().await.unwrap();
        assert!(history_text(&vars).is_empty());
        mem.set_query("When is the deadline?").await;
        let vars = mem.load_memory_variables().await.unwrap();
        assert!(history_text(&vars).contains("The deadline is Friday."));
    }

    #[tokio::test]
    async fn test_load_memory_variables_return_docs() {
        let mem = VectorStoreMemory::new(make_store()).with_return_docs(true);
        let vars = mem.load_memory_variables().await.unwrap();
        assert!(vars.get("history").unwrap().as_array().unwrap().is_empty());
        mem.save_context(&Message::human("Hi"), &Message::ai("Hello"))
            .await
            .unwrap();
        let vars = mem.load_memory_variables().await.unwrap();
        assert_eq!(vars.get("history").unwrap().as_array().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_custom_memory_key() {
        let mem = VectorStoreMemory::new(make_store()).with_memory_key("chat_history");
        assert_eq!(mem.memory_key(), "chat_history");
        let vars = mem.load_memory_variables().await.unwrap();
        assert!(vars.contains_key("chat_history"));
        assert!(!vars.contains_key("history"));
    }
}