// abu_agent/memory/retrieval.rs
use abu_base::chat::ChatMessage;
use abu_provider::EmbedProvide;
use abu_rag::{embed::{EmbedError, Embedder}, vectordb::{FlatL2Index, InMemoryStorage, VectorDB, VectorDBError, VectorId}};
use super::Memory;

6pub struct RetrievalMemory<P> {
7 embedder: Embedder<P>,
8 db: VectorDB<FlatL2Index<f32>, InMemoryStorage<String>>,
9 top_k: usize,
10}
11
12impl<P: EmbedProvide> RetrievalMemory<P> {
13 pub fn new(provider: P, model: impl Into<String>, top_k: usize) -> Self {
14 Self {
15 top_k,
16 embedder: Embedder::new(provider, model),
17 db: VectorDB::new(FlatL2Index::new(), InMemoryStorage::new())
18 }
19 }
20
21 fn new_id(&self) -> VectorId {
22 self.db.len() as VectorId
23 }
24}
25
26impl<P: EmbedProvide> Memory for RetrievalMemory<P> {
27 type Error = RetrievalMemoryError;
28
29 async fn add(&mut self, user_input: &str, ai_response: &str) -> Result<(), Self::Error> {
30 let content = format!("User: {}\nAI: {}", user_input, ai_response);
31 let embedding = self.embedder.embed_text(content.clone()).await?;
32 self.db.add(self.new_id(), embedding, content).await?;
33 Ok(())
34 }
35
36 async fn search(&self, query: &str) -> Result<Vec<ChatMessage>, Self::Error> {
37 let query_embedding = self.embedder.embed_text(query).await?;
38 let result = self.db.search(&query_embedding, self.top_k).await?;
39 let retrieved_docs = result.iter()
40 .map(|(_, s)| s.as_ref().as_str())
41 .collect::<Vec<_>>()
42 .join("\n\n");
43 let message = ChatMessage::user(format!("### Relevant Information Retrieved from Memory:\n{}", retrieved_docs));
44 Ok(vec![message])
45 }
46
47 async fn clear(&mut self) -> Result<(), Self::Error> {
48 self.db.clear().await?;
49 Ok(())
50 }
51}
52
53#[derive(Debug, thiserror::Error)]
54pub enum RetrievalMemoryError {
55 #[error(transparent)]
56 Embed(#[from] EmbedError),
57
58 #[error(transparent)]
59 VectorDB(#[from] VectorDBError)
60}