//! Token-aware conversation memory (`agentrs_memory/token_aware.rs`).
1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use agentrs_core::{Memory, Message, Result, Role};
6
7use crate::SearchableMemory;
8
/// Counts approximate tokens for a string.
pub trait Tokenizer: Send + Sync + 'static {
    /// Returns the estimated token count.
    fn count(&self, text: &str) -> usize;
}

/// Dependency-free tokenizer that estimates rather than tokenizes exactly.
#[derive(Debug, Clone, Copy, Default)]
pub struct ApproximateTokenizer;

impl Tokenizer for ApproximateTokenizer {
    fn count(&self, text: &str) -> usize {
        // Heuristic: roughly one token per 3.5 characters, rounded up.
        let char_total = text.chars().count() as f32;
        (char_total / 3.5).ceil() as usize
    }
}
24
/// Memory backend that trims history to fit a token budget.
pub struct TokenAwareMemory {
    // Conversation history in insertion order (oldest first).
    messages: Vec<Message>,
    // Upper bound on the summed token estimate across `messages`.
    max_tokens: usize,
    // Pluggable token estimator; `new` installs `ApproximateTokenizer`.
    tokenizer: Arc<dyn Tokenizer>,
}
31
32impl TokenAwareMemory {
33    /// Creates a token-aware backend with the default tokenizer.
34    pub fn new(max_tokens: usize) -> Self {
35        Self {
36            messages: Vec::new(),
37            max_tokens,
38            tokenizer: Arc::new(ApproximateTokenizer),
39        }
40    }
41
42    /// Creates a token-aware backend with a custom tokenizer.
43    pub fn with_tokenizer(max_tokens: usize, tokenizer: Arc<dyn Tokenizer>) -> Self {
44        Self {
45            messages: Vec::new(),
46            max_tokens,
47            tokenizer,
48        }
49    }
50
51    fn total_tokens(&self) -> usize {
52        self.messages
53            .iter()
54            .map(|message| self.tokenizer.count(&message.text_content()))
55            .sum()
56    }
57
58    fn trim_to_budget(&mut self) {
59        while self.total_tokens() > self.max_tokens && self.messages.len() > 1 {
60            if let Some(index) = self
61                .messages
62                .iter()
63                .position(|message| !matches!(message.role, Role::System))
64            {
65                self.messages.remove(index);
66            } else {
67                break;
68            }
69        }
70    }
71}
72
73#[async_trait]
74impl Memory for TokenAwareMemory {
75    async fn store(&mut self, _key: &str, value: Message) -> Result<()> {
76        self.messages.push(value);
77        self.trim_to_budget();
78        Ok(())
79    }
80
81    async fn retrieve(&self, query: &str, limit: usize) -> Result<Vec<Message>> {
82        let query = query.to_lowercase();
83        Ok(self
84            .messages
85            .iter()
86            .filter(|message| message.text_content().to_lowercase().contains(&query))
87            .take(limit)
88            .cloned()
89            .collect())
90    }
91
92    async fn history(&self) -> Result<Vec<Message>> {
93        Ok(self.messages.clone())
94    }
95
96    async fn clear(&mut self) -> Result<()> {
97        self.messages.clear();
98        Ok(())
99    }
100}
101
#[async_trait]
impl SearchableMemory for TokenAwareMemory {
    /// Reports the current estimated token usage of the stored history.
    async fn token_count(&self) -> Result<usize> {
        Ok(self.total_tokens())
    }
}