use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::RwLock;
use cognis_core::error::Result;
use cognis_core::messages::{get_buffer_string, Message};
use super::BaseMemory;
/// Strategy for estimating how many tokens a piece of text consumes.
///
/// Implementations must be `Send + Sync` so a counter can be boxed and
/// shared by async memory types.
pub trait TokenCounter: Send + Sync {
    /// Returns the estimated token count for `text`.
    fn count_tokens(&self, text: &str) -> usize;
}

/// Word-based heuristic counter: roughly 4 tokens per 3 whitespace-separated
/// words (integer arithmetic, rounding down).
#[derive(Debug, Clone, Copy, Default)]
pub struct SimpleTokenCounter;

impl SimpleTokenCounter {
    /// Constructs the (stateless) counter.
    pub fn new() -> Self {
        SimpleTokenCounter
    }
}

impl TokenCounter for SimpleTokenCounter {
    fn count_tokens(&self, text: &str) -> usize {
        let words = text.split_whitespace().count();
        words * 4 / 3
    }
}

/// Counter that assumes a fixed average number of characters per token
/// (e.g. ~4.0 is a common approximation for English text).
#[derive(Debug, Clone, Copy)]
pub struct CharBasedTokenCounter {
    // Average characters per token used as the divisor.
    chars_per_token: f64,
}

impl CharBasedTokenCounter {
    /// Creates a counter with the given characters-per-token ratio.
    /// NOTE(review): a zero or negative ratio yields meaningless counts
    /// (division saturates) — confirm callers always pass a positive value.
    pub fn new(chars_per_token: f64) -> Self {
        CharBasedTokenCounter { chars_per_token }
    }
}

impl TokenCounter for CharBasedTokenCounter {
    fn count_tokens(&self, text: &str) -> usize {
        let chars = text.chars().count() as f64;
        (chars / self.chars_per_token).ceil() as usize
    }
}
/// Chat-message buffer that evicts the oldest messages once the estimated
/// token total exceeds `max_token_limit`.
pub struct TokenBufferMemory {
    // Buffered conversation history, oldest first; shared via Arc + async RwLock.
    messages: Arc<RwLock<Vec<Message>>>,
    // Token budget enforced after every insertion.
    max_token_limit: usize,
    // Pluggable token-estimation strategy.
    token_counter: Box<dyn TokenCounter>,
    // Key under which the history is exposed in load_memory_variables.
    memory_key: String,
    // true: expose history as a JSON array of messages; false: as one formatted string.
    return_messages: bool,
}
impl TokenBufferMemory {
    /// Fixed per-message token overhead charged on top of the content tokens.
    /// NOTE(review): 3 mirrors typical chat-format framing overhead — confirm
    /// against the target model's tokenizer. Named once so counting and
    /// trimming cannot drift apart.
    const MESSAGE_OVERHEAD_TOKENS: usize = 3;

    /// Creates a memory with the defaults: a 2000-token budget,
    /// `SimpleTokenCounter`, memory key `"history"`, and message output.
    pub fn new() -> Self {
        Self {
            messages: Arc::new(RwLock::new(Vec::new())),
            max_token_limit: 2000,
            token_counter: Box::new(SimpleTokenCounter),
            memory_key: "history".to_string(),
            return_messages: true,
        }
    }

    /// Sets the maximum token budget (builder-style).
    pub fn with_max_tokens(mut self, limit: usize) -> Self {
        self.max_token_limit = limit;
        self
    }

    /// Replaces the token-counting strategy (builder-style).
    pub fn with_counter(mut self, counter: impl TokenCounter + 'static) -> Self {
        self.token_counter = Box::new(counter);
        self
    }

    /// Sets the key under which history is exposed (builder-style).
    pub fn with_memory_key(mut self, key: impl Into<String>) -> Self {
        self.memory_key = key.into();
        self
    }

    /// Chooses between message-array and formatted-string output (builder-style).
    pub fn with_return_messages(mut self, return_messages: bool) -> Self {
        self.return_messages = return_messages;
        self
    }

    /// Returns a builder for step-wise configuration.
    pub fn builder() -> TokenBufferMemoryBuilder {
        TokenBufferMemoryBuilder::default()
    }

    /// Appends a single message, then trims the buffer to the token budget.
    /// Note: a message that alone exceeds the budget is evicted immediately,
    /// matching the trim policy below.
    pub async fn add_message(&self, msg: Message) {
        let mut messages = self.messages.write().await;
        messages.push(msg);
        self.trim_messages(&mut messages);
    }

    /// Returns a snapshot (clone) of the buffered messages.
    pub async fn get_messages(&self) -> Vec<Message> {
        self.messages.read().await.clone()
    }

    /// Estimated token total of the current buffer, including per-message overhead.
    pub async fn total_tokens(&self) -> usize {
        let messages = self.messages.read().await;
        self.count_messages_tokens(&messages)
    }

    /// Removes all buffered messages.
    pub async fn clear_messages(&self) {
        self.messages.write().await.clear();
    }

    /// Sums the estimated tokens over `messages`, charging
    /// `MESSAGE_OVERHEAD_TOKENS` extra per message.
    fn count_messages_tokens(&self, messages: &[Message]) -> usize {
        messages
            .iter()
            .map(|m| {
                let text = m.content().text();
                self.token_counter.count_tokens(&text) + Self::MESSAGE_OVERHEAD_TOKENS
            })
            .sum()
    }

    /// Drops the oldest messages until the estimated total fits within
    /// `max_token_limit`.
    ///
    /// Counts every message exactly once, subtracts the cost of each evicted
    /// message from a running total, and removes the evicted prefix with a
    /// single `drain`. The previous version re-counted the whole buffer and
    /// called `Vec::remove(0)` on every iteration, which was O(n^2).
    fn trim_messages(&self, messages: &mut Vec<Message>) {
        let mut total = self.count_messages_tokens(messages);
        let mut evict = 0;
        while evict < messages.len() && total > self.max_token_limit {
            let text = messages[evict].content().text();
            total -= self.token_counter.count_tokens(&text) + Self::MESSAGE_OVERHEAD_TOKENS;
            evict += 1;
        }
        if evict > 0 {
            messages.drain(..evict);
        }
    }
}
impl Default for TokenBufferMemory {
fn default() -> Self {
Self::new()
}
}
/// Builder for [`TokenBufferMemory`]; unset fields fall back to defaults at build time.
#[derive(Default)]
pub struct TokenBufferMemoryBuilder {
    // Optional override for the token budget.
    max_tokens: Option<usize>,
    // Optional override for the counting strategy.
    counter: Option<Box<dyn TokenCounter>>,
    // Optional override for the memory key.
    memory_key: Option<String>,
    // Optional override for the output form (messages vs. string).
    return_messages: Option<bool>,
}
impl TokenBufferMemoryBuilder {
pub fn max_tokens(mut self, limit: usize) -> Self {
self.max_tokens = Some(limit);
self
}
pub fn counter(mut self, counter: impl TokenCounter + 'static) -> Self {
self.counter = Some(Box::new(counter));
self
}
pub fn memory_key(mut self, key: impl Into<String>) -> Self {
self.memory_key = Some(key.into());
self
}
pub fn return_messages(mut self, val: bool) -> Self {
self.return_messages = Some(val);
self
}
pub fn build(self) -> TokenBufferMemory {
TokenBufferMemory {
messages: Arc::new(RwLock::new(Vec::new())),
max_token_limit: self.max_tokens.unwrap_or(2000),
token_counter: self.counter.unwrap_or_else(|| Box::new(SimpleTokenCounter)),
memory_key: self.memory_key.unwrap_or_else(|| "history".to_string()),
return_messages: self.return_messages.unwrap_or(true),
}
}
}
#[async_trait]
impl BaseMemory for TokenBufferMemory {
    /// Exposes the buffered history under `memory_key`: a JSON array of
    /// serialized messages when `return_messages` is set, otherwise a single
    /// formatted transcript string.
    async fn load_memory_variables(&self) -> Result<HashMap<String, Value>> {
        let messages = self.messages.read().await;
        let value = if self.return_messages {
            // Best-effort serialization: a message that fails to serialize
            // degrades to JSON null rather than failing the whole call.
            let serialized = messages
                .iter()
                .map(|m| serde_json::to_value(m).unwrap_or(Value::Null))
                .collect();
            Value::Array(serialized)
        } else {
            Value::String(get_buffer_string(&messages, "Human", "AI"))
        };
        let mut vars = HashMap::new();
        vars.insert(self.memory_key.clone(), value);
        Ok(vars)
    }

    /// Appends the input/output pair, then trims back to the token budget.
    async fn save_context(&self, input: &Message, output: &Message) -> Result<()> {
        let mut buffer = self.messages.write().await;
        buffer.extend([input.clone(), output.clone()]);
        self.trim_messages(&mut buffer);
        Ok(())
    }

    /// Empties the buffer.
    async fn clear(&self) -> Result<()> {
        self.messages.write().await.clear();
        Ok(())
    }

    fn memory_key(&self) -> &str {
        self.memory_key.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::messages::Message;

    // SimpleTokenCounter: no words -> zero tokens.
    #[test]
    fn test_simple_counter_empty() {
        let counter = SimpleTokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    // One word: 1 * 4 / 3 == 1 (integer division).
    #[test]
    fn test_simple_counter_single_word() {
        let counter = SimpleTokenCounter::new();
        assert_eq!(counter.count_tokens("hello"), 1);
    }

    // Four words: 4 * 4 / 3 == 5.
    #[test]
    fn test_simple_counter_multiple_words() {
        let counter = SimpleTokenCounter::new();
        let tokens = counter.count_tokens("hello world foo bar");
        assert_eq!(tokens, 5);
    }

    // 5 chars at 4.0 chars/token -> ceil(1.25) == 2.
    #[test]
    fn test_char_based_counter() {
        let counter = CharBasedTokenCounter::new(4.0);
        assert_eq!(counter.count_tokens("hello"), 2);
    }

    // 20 chars at 4.0 chars/token -> exactly 5 (no rounding needed).
    #[test]
    fn test_char_based_counter_longer() {
        let counter = CharBasedTokenCounter::new(4.0);
        assert_eq!(counter.count_tokens("12345678901234567890"), 5);
    }

    // A fresh memory holds no messages and reports zero tokens.
    #[tokio::test]
    async fn test_new_memory_empty() {
        let mem = TokenBufferMemory::new();
        let msgs = mem.get_messages().await;
        assert!(msgs.is_empty());
        assert_eq!(mem.total_tokens().await, 0);
    }

    // add_message appends; large budget prevents trimming interference.
    #[tokio::test]
    async fn test_add_and_get_messages() {
        let mem = TokenBufferMemory::new().with_max_tokens(10000);
        mem.add_message(Message::human("Hello")).await;
        mem.add_message(Message::ai("Hi there")).await;
        let msgs = mem.get_messages().await;
        assert_eq!(msgs.len(), 2);
    }

    // save_context stores both the input and the output message.
    #[tokio::test]
    async fn test_save_context_adds_two() {
        let mem = TokenBufferMemory::new().with_max_tokens(10000);
        mem.save_context(&Message::human("Hey"), &Message::ai("Hello"))
            .await
            .unwrap();
        let msgs = mem.get_messages().await;
        assert_eq!(msgs.len(), 2);
    }

    // Any non-empty message must contribute a positive token estimate
    // (the +3 per-message overhead alone guarantees this).
    #[tokio::test]
    async fn test_total_tokens_nonzero() {
        let mem = TokenBufferMemory::new().with_max_tokens(10000);
        mem.add_message(Message::human("Hello world")).await;
        assert!(mem.total_tokens().await > 0);
    }

    // With a tiny budget, trimming evicts oldest messages until the total
    // fits (possibly leaving the buffer empty).
    #[tokio::test]
    async fn test_trimming_by_token_count() {
        let mem = TokenBufferMemory::new().with_max_tokens(10);
        mem.save_context(
            &Message::human("This is a fairly long message that should use many tokens"),
            &Message::ai("This is also a long response with many tokens in it"),
        )
        .await
        .unwrap();
        assert!(mem.total_tokens().await <= 10);
    }

    // The BaseMemory::clear() entry point empties the buffer.
    #[tokio::test]
    async fn test_clear() {
        let mem = TokenBufferMemory::new();
        mem.save_context(&Message::human("Hi"), &Message::ai("Hello"))
            .await
            .unwrap();
        mem.clear().await.unwrap();
        let msgs = mem.get_messages().await;
        assert!(msgs.is_empty());
        assert_eq!(mem.total_tokens().await, 0);
    }

    // The inherent clear_messages() method behaves the same as clear().
    #[tokio::test]
    async fn test_clear_messages_method() {
        let mem = TokenBufferMemory::new();
        mem.add_message(Message::human("Test")).await;
        mem.clear_messages().await;
        assert!(mem.get_messages().await.is_empty());
    }

    // Builder with no overrides must match the new() defaults.
    #[tokio::test]
    async fn test_builder_default() {
        let mem = TokenBufferMemory::builder().build();
        assert_eq!(mem.max_token_limit, 2000);
        assert_eq!(mem.memory_key, "history");
        assert!(mem.return_messages);
    }

    // Builder overrides are honored for every field.
    #[tokio::test]
    async fn test_builder_custom() {
        let mem = TokenBufferMemory::builder()
            .max_tokens(500)
            .memory_key("chat")
            .return_messages(false)
            .counter(SimpleTokenCounter::new())
            .build();
        assert_eq!(mem.max_token_limit, 500);
        assert_eq!(mem.memory_key, "chat");
        assert!(!mem.return_messages);
    }

    // A swapped-in counter strategy is actually used for totals.
    #[tokio::test]
    async fn test_with_custom_counter() {
        let mem = TokenBufferMemory::new()
            .with_max_tokens(100)
            .with_counter(CharBasedTokenCounter::new(4.0));
        mem.add_message(Message::human("Hello")).await;
        assert!(mem.total_tokens().await > 0);
    }

    // return_messages = true (default): history is a JSON array of messages.
    #[tokio::test]
    async fn test_load_as_json() {
        let mem = TokenBufferMemory::new().with_max_tokens(10000);
        mem.save_context(&Message::human("Hello"), &Message::ai("Hi"))
            .await
            .unwrap();
        let vars = mem.load_memory_variables().await.unwrap();
        let history = vars.get("history").unwrap().as_array().unwrap();
        assert_eq!(history.len(), 2);
    }

    // return_messages = false: history is a single formatted transcript string.
    #[tokio::test]
    async fn test_load_as_string() {
        let mem = TokenBufferMemory::new()
            .with_max_tokens(10000)
            .with_return_messages(false);
        mem.save_context(&Message::human("Hello"), &Message::ai("World"))
            .await
            .unwrap();
        let vars = mem.load_memory_variables().await.unwrap();
        let history = vars.get("history").unwrap().as_str().unwrap();
        assert!(history.contains("Hello"));
        assert!(history.contains("World"));
    }

    // A custom memory key replaces "history" entirely.
    #[tokio::test]
    async fn test_custom_memory_key() {
        let mem = TokenBufferMemory::new()
            .with_max_tokens(10000)
            .with_memory_key("chat_log");
        mem.save_context(&Message::human("Hi"), &Message::ai("Hey"))
            .await
            .unwrap();
        let vars = mem.load_memory_variables().await.unwrap();
        assert!(vars.contains_key("chat_log"));
        assert!(!vars.contains_key("history"));
    }

    // memory_key() accessor reflects the configured key.
    #[tokio::test]
    async fn test_memory_key_method() {
        let mem = TokenBufferMemory::new().with_memory_key("custom");
        assert_eq!(mem.memory_key(), "custom");
    }
}