abu-agent 0.2.0

Agent development library
Documentation
use abu_base::chat::ChatMessage;
use abu_provider::EmbedProvide;
use abu_rag::{embed::{EmbedError, Embedder}, vectordb::{FlatL2Index, InMemoryStorage, VectorDB, VectorDBError, VectorId}};
use super::Memory;

/// Conversation memory that recalls past turns by embedding similarity.
///
/// Every stored turn is embedded with an [`Embedder`] and kept, alongside its original text,
/// in an in-process [`VectorDB`]. Queries are embedded the same way and the closest stored
/// turns are returned.
pub struct RetrievalMemory<P> {
    /// Converts text (stored turns and queries) into embedding vectors via provider `P`.
    embedder: Embedder<P>,
    /// Vector index plus storage holding the original text of each stored turn.
    /// (`FlatL2Index` by name suggests exhaustive L2-distance search — confirm in abu_rag.)
    db: VectorDB<
        FlatL2Index<f32>,
        InMemoryStorage<String>,
    >,
    /// Maximum number of stored turns requested from the database per search.
    top_k: usize,
}

impl<P: EmbedProvide> RetrievalMemory<P> {
    /// Builds an empty retrieval memory.
    ///
    /// * `provider` – embedding backend handed to [`Embedder::new`].
    /// * `model` – name of the embedding model the provider should use.
    /// * `top_k` – how many stored turns a search asks the vector database for.
    pub fn new(provider: P, model: impl Into<String>, top_k: usize) -> Self {
        let embedder = Embedder::new(provider, model);
        let db = VectorDB::new(FlatL2Index::new(), InMemoryStorage::new());
        Self { embedder, db, top_k }
    }

    /// Returns the id to assign to the next inserted turn.
    ///
    /// Ids are derived from the current number of stored entries. This stays collision-free
    /// only because this type never removes individual entries (it only adds and clears);
    /// NOTE(review): revisit if per-entry removal is ever added.
    fn new_id(&self) -> VectorId {
        let stored = self.db.len();
        stored as VectorId
    }
}

impl<P: EmbedProvide> Memory for RetrievalMemory<P> {
    type Error = RetrievalMemoryError;

    /// Stores one conversation turn as a single retrievable document.
    ///
    /// The turn is rendered as `"User: {user_input}\nAI: {ai_response}"`, embedded, and
    /// inserted into the vector database together with that text.
    ///
    /// # Errors
    ///
    /// Propagates failures from the embedder and from the vector-database insert.
    async fn add(&mut self, user_input: &str, ai_response: &str) -> Result<(), Self::Error> {
        let content = format!("User: {}\nAI: {}", user_input, ai_response);
        // The text is needed twice: once to embed, once as the stored payload.
        let embedding = self.embedder.embed_text(content.clone()).await?;
        self.db.add(self.new_id(), embedding, content).await?;
        Ok(())
    }

    /// Retrieves up to `top_k` stored turns most similar to `query`.
    ///
    /// Returns an empty vector when nothing can match or nothing was retrieved: the memory
    /// is empty, `top_k` is zero, or the search returned no documents. In the first two
    /// cases the embedding provider is not called at all. Otherwise the result is a single
    /// user message containing the retrieved turns, separated by blank lines.
    ///
    /// # Errors
    ///
    /// Propagates failures from embedding the query and from the vector-database search.
    async fn search(&self, query: &str) -> Result<Vec<ChatMessage>, Self::Error> {
        // Skip the embedding call (presumably a provider/network request) when no document
        // could possibly be returned.
        if self.top_k == 0 || self.db.len() == 0 {
            return Ok(Vec::new());
        }

        let query_embedding = self.embedder.embed_text(query).await?;
        let result = self.db.search(&query_embedding, self.top_k).await?;

        let docs: Vec<&str> = result
            .iter()
            .map(|(_, s)| s.as_ref().as_str())
            .collect();
        // Avoid injecting an empty "retrieved information" header into the conversation.
        if docs.is_empty() {
            return Ok(Vec::new());
        }

        let retrieved_docs = docs.join("\n\n");
        let message = ChatMessage::user(format!(
            "### Relevant Information Retrieved from Memory:\n{}",
            retrieved_docs
        ));
        Ok(vec![message])
    }

    /// Removes every stored turn from the vector database.
    ///
    /// # Errors
    ///
    /// Propagates failures from the vector-database clear operation.
    async fn clear(&mut self) -> Result<(), Self::Error> {
        self.db.clear().await?;
        Ok(())
    }
}

/// Errors produced by [`RetrievalMemory`] operations.
///
/// Both variants are transparent wrappers: their `Display` and `source` come straight from
/// the wrapped error.
#[derive(Debug, thiserror::Error)]
pub enum RetrievalMemoryError {
    /// Embedding text (a stored turn or a query) failed.
    #[error(transparent)]
    Embed(#[from] EmbedError),

    /// An operation on the underlying vector database failed.
    #[error(transparent)]
    VectorDB(#[from] VectorDBError),
}