Skip to main content

abu_agent/memory/
hierarchical.rs

1use abu_base::chat::ChatMessage;
2use abu_provider::EmbedProvide;
3use super::{retrieval::RetrievalMemoryError, Memory, RetrievalMemory, SliceWindowMemory};
4
5pub struct HierarchicalMemory<P> {
6    working_memory: SliceWindowMemory,
7    long_term_memory: RetrievalMemory<P>,
8    promotion_keywords: Vec<&'static str>,
9}
10
11impl<P: EmbedProvide> HierarchicalMemory<P> {
12    pub fn new(window_size: usize, provider: P, model: impl Into<String>, top_k: usize) -> Self {
13        let working_memory = SliceWindowMemory::new(window_size);
14        let long_term_memory = RetrievalMemory::new(provider, model, top_k);
15        Self {
16            working_memory, 
17            long_term_memory,
18            promotion_keywords: vec![
19                "remember", "rule", "preference", "always", "never", "allergic"
20            ],
21        }
22    }
23}
24
25impl<P: EmbedProvide> Memory for HierarchicalMemory<P> {
26    type Error = RetrievalMemoryError;
27
28    async fn add(&mut self, user_input: &str, ai_response: &str) -> Result<(), Self::Error> {
29        self.working_memory.add(user_input, ai_response).await.unwrap();
30        
31        let user_input_lower = user_input.to_lowercase();
32        let has_keyword = self.promotion_keywords.iter()
33            .any(|&promotion_keyword| user_input_lower.contains(promotion_keyword));
34        if has_keyword {
35            self.long_term_memory.add(user_input, ai_response).await?;
36        }
37
38        Ok(())
39    }
40    
41    async fn search(&self, query: &str) -> Result<Vec<ChatMessage>, Self::Error> {
42        let working_messages = self.working_memory.search(query).await.unwrap();
43        let long_term_messages = self.long_term_memory.search(query).await?;
44        
45        let mut messages = vec![];
46        messages.push(ChatMessage::user("Retrieved Long-Term Memories:"));
47        messages.extend(long_term_messages);
48        messages.push(ChatMessage::user("Recent Conversation (Working Memory):"));
49        messages.extend(working_messages);
50
51        Ok(messages)
52    }
53
54    async fn clear(&mut self) -> Result<(), Self::Error> {
55        self.working_memory.clear().await.unwrap();
56        self.long_term_memory.clear().await?;
57        Ok(())
58    }
59}