Skip to main content

abu_agent/memory/
retrieval.rs

1use abu_base::chat::ChatMessage;
2use abu_provider::EmbedProvide;
3use abu_rag::{embed::{EmbedError, Embedder}, vectordb::{FlatL2Index, InMemoryStorage, VectorDB, VectorDBError, VectorId}};
4use super::Memory;
5
6pub struct RetrievalMemory<P> {
7    embedder: Embedder<P>,
8    db: VectorDB<FlatL2Index<f32>, InMemoryStorage<String>>,
9    top_k: usize,
10}
11
12impl<P: EmbedProvide> RetrievalMemory<P> {
13    pub fn new(provider: P, model: impl Into<String>, top_k: usize) -> Self {
14        Self { 
15            top_k, 
16            embedder: Embedder::new(provider, model), 
17            db: VectorDB::new(FlatL2Index::new(), InMemoryStorage::new()) 
18        }
19    }
20
21    fn new_id(&self) -> VectorId {
22        self.db.len() as VectorId
23    }
24}
25
26impl<P: EmbedProvide> Memory for RetrievalMemory<P> {
27    type Error = RetrievalMemoryError;
28
29    async fn add(&mut self, user_input: &str, ai_response: &str) -> Result<(), Self::Error> {
30        let content = format!("User: {}\nAI: {}", user_input, ai_response);
31        let embedding = self.embedder.embed_text(content.clone()).await?;
32        self.db.add(self.new_id(), embedding, content).await?;
33        Ok(())
34    }
35    
36    async fn search(&self, query: &str) -> Result<Vec<ChatMessage>, Self::Error> {
37        let query_embedding = self.embedder.embed_text(query).await?;
38        let result = self.db.search(&query_embedding, self.top_k).await?;
39        let retrieved_docs = result.iter()
40            .map(|(_, s)| s.as_ref().as_str())
41            .collect::<Vec<_>>()
42            .join("\n\n");
43        let message = ChatMessage::user(format!("### Relevant Information Retrieved from Memory:\n{}", retrieved_docs));
44        Ok(vec![message])
45    }
46
47    async fn clear(&mut self) -> Result<(), Self::Error> {
48        self.db.clear().await?;
49        Ok(())
50    }
51}
52
53#[derive(Debug, thiserror::Error)]
54pub enum RetrievalMemoryError {
55    #[error(transparent)]
56    Embed(#[from] EmbedError),
57
58    #[error(transparent)]
59    VectorDB(#[from] VectorDBError)
60}