use std::collections::VecDeque;
use tokio::sync::Mutex;
use crate::error::Result;
use crate::memory::traits::Memory;
use crate::model::types::Message;
/// Returns the byte length `v` would have when serialized as compact JSON
/// (no whitespace between tokens), without building the string.
///
/// Strings and object keys include their surrounding quotes and any escape
/// sequences (`\"`, `\\`, `\n`, `\u00XX`, ...). Arrays and objects count one
/// comma *between* consecutive elements, not one per element. The default
/// token estimator uses this to size tool-call arguments.
fn json_value_len(v: &serde_json::Value) -> usize {
    match v {
        serde_json::Value::Null => 4,
        serde_json::Value::Bool(b) => {
            if *b {
                4
            } else {
                5
            }
        }
        serde_json::Value::Number(n) => {
            let mut buf = itoa::Buffer::new();
            if let Some(i) = n.as_i64() {
                buf.format(i).len()
            } else if let Some(u) = n.as_u64() {
                buf.format(u).len()
            } else {
                let mut fbuf = ryu::Buffer::new();
                fbuf.format(n.as_f64().unwrap_or(0.0)).len()
            }
        }
        serde_json::Value::String(s) => json_string_len(s),
        serde_json::Value::Array(arr) => {
            // `[` + `]`, plus a comma between each pair of neighbouring elements.
            let commas = arr.len().saturating_sub(1);
            2 + commas + arr.iter().map(json_value_len).sum::<usize>()
        }
        serde_json::Value::Object(map) => {
            // `{` + `}`, plus a comma between each pair of neighbouring entries.
            let commas = map.len().saturating_sub(1);
            // Each entry is serialized as `"key":value`.
            let entries: usize = map
                .iter()
                .map(|(k, v)| json_string_len(k) + 1 + json_value_len(v))
                .sum();
            2 + commas + entries
        }
    }
}

/// Byte length of `s` as a JSON string literal: two quotes plus the escaped
/// form of each byte. `"`, `\` and the control characters with short escapes
/// (`\b`, `\f`, `\n`, `\r`, `\t`) take two bytes; any other control byte
/// below 0x20 takes six (`\u00XX`). All other bytes, including non-ASCII
/// UTF-8 bytes, are emitted verbatim.
fn json_string_len(s: &str) -> usize {
    2 + s
        .bytes()
        .map(|b| match b {
            b'"' | b'\\' | b'\n' | b'\r' | b'\t' | 0x08 | 0x0c => 2,
            0x00..=0x1f => 6,
            _ => 1,
        })
        .sum::<usize>()
}
fn default_estimate_tokens(msg: &Message) -> usize {
let mut chars = 0usize;
if let Some(ref content) = msg.content {
chars += content.len();
}
for tc in &msg.tool_calls {
chars += tc.name.len();
chars += json_value_len(&tc.arguments);
}
if let Some(ref id) = msg.tool_call_id {
chars += id.len();
}
chars += 6;
chars.div_ceil(4)
}
/// Boxed, thread-shareable callback returning the token cost of one message.
type TokenCounterFn = Box<dyn Fn(&Message) -> usize + Send + Sync>;
/// Conversation memory bounded by a token budget.
///
/// Messages are kept in insertion order. When the running token total exceeds
/// `max_tokens`, the oldest messages are dropped until it fits again; the
/// newest message is always kept.
pub struct TokenWindowMemory {
    // Message history plus per-message token bookkeeping, behind an async mutex.
    inner: Mutex<TokenWindowInner>,
    // Token budget above which the oldest messages are evicted.
    max_tokens: usize,
    // Computes the token cost of each message as it is added.
    token_counter: TokenCounterFn,
}
/// Mutable state held behind `TokenWindowMemory::inner`.
struct TokenWindowInner {
    // Retained messages, oldest at the front.
    messages: VecDeque<Message>,
    // Token cost of each entry in `messages`, index-aligned with it.
    token_counts: VecDeque<usize>,
    // Running sum of `token_counts`.
    total_tokens: usize,
}
impl TokenWindowMemory {
    /// Creates an empty memory with a budget of `max_tokens`.
    ///
    /// Messages are sized with the built-in estimator: the byte length of the
    /// content, tool calls and tool-call id, plus a small fixed overhead,
    /// divided by four and rounded up.
    #[must_use]
    pub fn new(max_tokens: usize) -> Self {
        Self {
            inner: Mutex::new(TokenWindowInner {
                messages: VecDeque::new(),
                token_counts: VecDeque::new(),
                total_tokens: 0,
            }),
            max_tokens,
            token_counter: Box::new(default_estimate_tokens),
        }
    }

    /// Replaces the function used to compute each message's token cost.
    ///
    /// Meant to be chained right after [`TokenWindowMemory::new`]. Costs of
    /// messages already stored are not recomputed. Calling this on a non-empty
    /// memory therefore leaves those messages counted by the previous counter.
    #[must_use]
    pub fn with_token_counter<F>(mut self, counter: F) -> Self
    where
        F: Fn(&Message) -> usize + Send + Sync + 'static,
    {
        self.token_counter = Box::new(counter);
        self
    }

    /// Returns the summed token cost of the messages currently retained.
    pub async fn current_tokens(&self) -> usize {
        self.inner.lock().await.total_tokens
    }
}
impl Memory for TokenWindowMemory {
async fn add_message(&self, message: Message) -> Result<()> {
let tokens = (self.token_counter)(&message);
let mut inner = self.inner.lock().await;
inner.messages.push_back(message);
inner.token_counts.push_back(tokens);
inner.total_tokens += tokens;
while inner.total_tokens > self.max_tokens && inner.messages.len() > 1 {
if let Some(removed_tokens) = inner.token_counts.pop_front() {
inner.messages.pop_front();
inner.total_tokens -= removed_tokens;
}
}
Ok(())
}
async fn get_messages(&self) -> Result<Vec<Message>> {
let mut inner = self.inner.lock().await;
Ok(inner.messages.make_contiguous().to_vec())
}
async fn clear(&self) -> Result<()> {
let mut inner = self.inner.lock().await;
inner.messages.clear();
inner.token_counts.clear();
inner.total_tokens = 0;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::types::Role;
    use crate::tool::ToolCall;

    /// Token counter charging one token per byte of content, so the budgets
    /// in these tests are exact and independent of the default estimator.
    fn content_len(msg: &Message) -> usize {
        msg.content.as_ref().map_or(0, |c| c.len())
    }

    #[tokio::test]
    async fn test_add_and_get_messages() {
        let memory = TokenWindowMemory::new(10_000);
        memory.add_message(Message::user("hello")).await.unwrap();
        memory.add_message(Message::assistant("hi")).await.unwrap();
        let msgs = memory.get_messages().await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, Role::User);
        assert_eq!(msgs[1].role, Role::Assistant);
    }

    #[tokio::test]
    async fn test_evicts_old_messages_when_over_budget() {
        let memory = TokenWindowMemory::new(20).with_token_counter(content_len);
        memory.add_message(Message::user("aaaaaaaaaa")).await.unwrap();
        memory.add_message(Message::user("bbbbbbbbbb")).await.unwrap();
        // Exactly at the budget: nothing is evicted.
        assert_eq!(memory.get_messages().await.unwrap().len(), 2);
        assert_eq!(memory.current_tokens().await, 20);

        memory.add_message(Message::user("cccccccccc")).await.unwrap();
        let msgs = memory.get_messages().await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].content.as_deref(), Some("bbbbbbbbbb"));
        assert_eq!(msgs[1].content.as_deref(), Some("cccccccccc"));
        assert_eq!(memory.current_tokens().await, 20);
    }

    #[tokio::test]
    async fn test_evicts_multiple_to_fit() {
        let memory = TokenWindowMemory::new(15).with_token_counter(content_len);
        memory.add_message(Message::user("aaa")).await.unwrap();
        memory.add_message(Message::user("bbb")).await.unwrap();
        memory.add_message(Message::user("ccc")).await.unwrap();
        memory.add_message(Message::user("dddddddd")).await.unwrap();
        let msgs = memory.get_messages().await.unwrap();
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[0].content.as_deref(), Some("bbb"));
        assert_eq!(msgs[1].content.as_deref(), Some("ccc"));
        assert_eq!(msgs[2].content.as_deref(), Some("dddddddd"));
        // 3 + 3 + 8 after evicting "aaa".
        assert_eq!(memory.current_tokens().await, 14);
    }

    #[tokio::test]
    async fn test_clear_resets_tokens() {
        let memory = TokenWindowMemory::new(100);
        memory.add_message(Message::user("hello")).await.unwrap();
        assert!(memory.current_tokens().await > 0);
        memory.clear().await.unwrap();
        assert_eq!(memory.current_tokens().await, 0);
        assert!(memory.get_messages().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_default_estimator_counts_tool_calls() {
        let memory = TokenWindowMemory::new(10_000);
        let msg = Message::assistant_with_tool_calls(vec![ToolCall {
            id: "tc_1".into(),
            name: "calculator".into(),
            arguments: serde_json::json!({"expression": "2+2"}),
        }]);
        memory.add_message(msg).await.unwrap();
        assert!(memory.current_tokens().await > 0);
    }

    #[tokio::test]
    async fn test_custom_token_counter() {
        let memory = TokenWindowMemory::new(5).with_token_counter(|_| 1);
        for i in 0..7 {
            memory
                .add_message(Message::user(format!("msg{i}")))
                .await
                .unwrap();
        }
        let msgs = memory.get_messages().await.unwrap();
        assert_eq!(msgs.len(), 5);
        assert_eq!(msgs[0].content.as_deref(), Some("msg2"));
        assert_eq!(memory.current_tokens().await, 5);
    }

    #[tokio::test]
    async fn test_single_message_exceeds_budget() {
        let memory = TokenWindowMemory::new(5).with_token_counter(content_len);
        memory.add_message(Message::user("short")).await.unwrap();
        memory
            .add_message(Message::user("this is a very long message"))
            .await
            .unwrap();
        let msgs = memory.get_messages().await.unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(
            msgs[0].content.as_deref(),
            Some("this is a very long message")
        );
        // The newest message is kept even though it alone exceeds the budget.
        assert_eq!(memory.current_tokens().await, 27);
    }

    #[tokio::test]
    async fn test_empty_memory() {
        let memory = TokenWindowMemory::new(100);
        assert_eq!(memory.current_tokens().await, 0);
        assert!(memory.get_messages().await.unwrap().is_empty());
    }
}