// agentrs_memory/token_aware.rs
use std::sync::Arc;

use async_trait::async_trait;

use agentrs_core::{Memory, Message, Result, Role};

use crate::SearchableMemory;

/// Counts how many tokens a piece of text occupies.
///
/// Tokenizers are stored as `Arc<dyn Tokenizer>` inside token-aware memories,
/// so implementations must be thread-safe (`Send + Sync`) and own their data
/// (`'static`).
pub trait Tokenizer: Send + Sync + 'static {
    /// Returns the number of tokens `text` is estimated to occupy.
    fn count(&self, text: &str) -> usize;
}

/// Dependency-free heuristic tokenizer.
///
/// Assumes roughly 3.5 characters (Unicode scalar values, as yielded by
/// [`str::chars`]) per token and rounds up, so any non-empty text counts as at
/// least one token.
#[derive(Debug, Clone, Copy, Default)]
pub struct ApproximateTokenizer;

impl Tokenizer for ApproximateTokenizer {
    fn count(&self, text: &str) -> usize {
        let chars = text.chars().count();
        // ceil(chars / 3.5) == ceil(2 * chars / 7). Writing chars = 7q + r keeps
        // the arithmetic exact (no f32 rounding for long texts) and overflow-free:
        // ceil(2(7q + r) / 7) = 2q + ceil(2r / 7) = 2q + (2r + 6) / 7.
        let (q, r) = (chars / 7, chars % 7);
        2 * q + (2 * r + 6) / 7
    }
}

25pub struct TokenAwareMemory {
27 messages: Vec<Message>,
28 max_tokens: usize,
29 tokenizer: Arc<dyn Tokenizer>,
30}
31
32impl TokenAwareMemory {
33 pub fn new(max_tokens: usize) -> Self {
35 Self {
36 messages: Vec::new(),
37 max_tokens,
38 tokenizer: Arc::new(ApproximateTokenizer),
39 }
40 }
41
42 pub fn with_tokenizer(max_tokens: usize, tokenizer: Arc<dyn Tokenizer>) -> Self {
44 Self {
45 messages: Vec::new(),
46 max_tokens,
47 tokenizer,
48 }
49 }
50
51 fn total_tokens(&self) -> usize {
52 self.messages
53 .iter()
54 .map(|message| self.tokenizer.count(&message.text_content()))
55 .sum()
56 }
57
58 fn trim_to_budget(&mut self) {
59 while self.total_tokens() > self.max_tokens && self.messages.len() > 1 {
60 if let Some(index) = self
61 .messages
62 .iter()
63 .position(|message| !matches!(message.role, Role::System))
64 {
65 self.messages.remove(index);
66 } else {
67 break;
68 }
69 }
70 }
71}
72
73#[async_trait]
74impl Memory for TokenAwareMemory {
75 async fn store(&mut self, _key: &str, value: Message) -> Result<()> {
76 self.messages.push(value);
77 self.trim_to_budget();
78 Ok(())
79 }
80
81 async fn retrieve(&self, query: &str, limit: usize) -> Result<Vec<Message>> {
82 let query = query.to_lowercase();
83 Ok(self
84 .messages
85 .iter()
86 .filter(|message| message.text_content().to_lowercase().contains(&query))
87 .take(limit)
88 .cloned()
89 .collect())
90 }
91
92 async fn history(&self) -> Result<Vec<Message>> {
93 Ok(self.messages.clone())
94 }
95
96 async fn clear(&mut self) -> Result<()> {
97 self.messages.clear();
98 Ok(())
99 }
100}
101
102#[async_trait]
103impl SearchableMemory for TokenAwareMemory {
104 async fn token_count(&self) -> Result<usize> {
105 Ok(self.total_tokens())
106 }
107}