use dashmap::DashMap;
use std::collections::VecDeque;
use std::sync::Mutex;
use std::time::Instant;
/// Key type for conversations held by the cache.
pub type ConversationId = String;

/// One message in a conversation together with cache bookkeeping.
#[derive(Debug, Clone)]
pub struct Turn {
    /// Caller-chosen identifier of the turn.
    pub id: String,
    /// Speaker role as free-form text, e.g. `"user"` or `"assistant"`.
    pub role: String,
    /// Message body.
    pub content: String,
    /// Creation time, taken from `Instant::now()` in [`Turn::new`].
    pub timestamp: Instant,
    /// Rough token estimate: whitespace-separated word count × 4 / 3, rounded down.
    pub token_count: usize,
    /// Optional JSON payload, set via [`Turn::with_metadata`].
    pub metadata: Option<serde_json::Value>,
}

impl Turn {
    /// Builds a turn timestamped now, with an estimated token count and no metadata.
    pub fn new(
        id: impl Into<String>,
        role: impl Into<String>,
        content: impl Into<String>,
    ) -> Self {
        let body: String = content.into();
        let word_count = body.split_whitespace().count();
        Self {
            id: id.into(),
            role: role.into(),
            token_count: word_count * 4 / 3,
            content: body,
            timestamp: Instant::now(),
            metadata: None,
        }
    }

    /// Returns `self` with `metadata` attached, replacing any previous value.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Turn {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Approximate footprint in bytes: the byte lengths of `id`, `role` and
    /// `content` plus a flat 64 (presumably per-turn overhead — TODO confirm).
    /// `metadata` is not counted.
    pub fn size(&self) -> usize {
        let text_bytes: usize = [&self.id, &self.role, &self.content]
            .iter()
            .map(|field| field.len())
            .sum();
        text_bytes + 64
    }
}
#[derive(Debug)]
pub struct ConversationContext {
pub id: ConversationId,
pub turns: VecDeque<Turn>,
max_turns: usize,
total_tokens: usize,
last_access: Instant,
}
impl ConversationContext {
fn new(id: ConversationId, max_turns: usize) -> Self {
Self {
id,
turns: VecDeque::with_capacity(max_turns),
max_turns,
total_tokens: 0,
last_access: Instant::now(),
}
}
fn append(&mut self, turn: Turn) {
self.total_tokens += turn.token_count;
self.turns.push_back(turn);
while self.turns.len() > self.max_turns {
if let Some(removed) = self.turns.pop_front() {
self.total_tokens = self.total_tokens.saturating_sub(removed.token_count);
}
}
self.last_access = Instant::now();
}
fn get_recent(&self, count: usize) -> Vec<Turn> {
self.turns.iter()
.rev()
.take(count)
.rev()
.cloned()
.collect()
}
fn size(&self) -> usize {
self.turns.iter().map(|t| t.size()).sum()
}
}
/// Recency order of cached conversation ids: least recently used at the front,
/// most recently used at the back.
///
/// Lookups are linear scans, so `touch` and `remove` are O(n) in the number of
/// tracked ids. The exception is the fast path in `touch` for the id already at
/// the back.
struct LruTracker {
    order: Mutex<VecDeque<ConversationId>>,
}

impl LruTracker {
    /// Upper bound on slots reserved at construction; the queue grows on demand.
    const MAX_INITIAL_CAPACITY: usize = 64;

    /// Creates an empty tracker expected to hold roughly `max_size` ids.
    fn new(max_size: usize) -> Self {
        // Cap the up-front reservation: `with_capacity(max_size)` fails for very
        // large limits and wastes memory while the cache is sparsely filled.
        Self {
            order: Mutex::new(VecDeque::with_capacity(
                max_size.min(Self::MAX_INITIAL_CAPACITY),
            )),
        }
    }

    /// Locks the queue. A poisoned lock is recovered rather than propagated:
    /// the queue only orders eviction candidates, so a panic in one caller
    /// should not make every later cache call panic too.
    fn lock(&self) -> std::sync::MutexGuard<'_, VecDeque<ConversationId>> {
        self.order
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Marks `id` as most recently used, adding it if it is not tracked yet.
    fn touch(&self, id: &str) {
        let mut order = self.lock();
        // Fast path: repeated activity on the same conversation.
        if order.back().map(String::as_str) == Some(id) {
            return;
        }
        let pos = order.iter().position(|x| x == id);
        // Reuse the existing String allocation when the id is already tracked.
        let entry = pos
            .and_then(|p| order.remove(p))
            .unwrap_or_else(|| id.to_owned());
        order.push_back(entry);
    }

    /// Stops tracking `id`; a no-op if it is not tracked.
    fn remove(&self, id: &str) {
        self.lock().retain(|x| x != id);
    }

    /// Removes and returns the least recently used id, or `None` if empty.
    fn evict_oldest(&self) -> Option<ConversationId> {
        self.lock().pop_front()
    }
}

/// LRU-bounded cache of per-conversation turn histories, usable through `&self`
/// (a `DashMap` of contexts plus a mutex-guarded recency queue).
///
/// Each conversation keeps at most `max_turns` turns; see `append_turn` for how
/// `max_conversations` is enforced.
pub struct ConversationContextCache {
    contexts: DashMap<ConversationId, ConversationContext>,
    lru: LruTracker,
    max_turns: usize,
    max_conversations: usize,
}

impl ConversationContextCache {
    /// Creates an empty cache for up to `max_conversations` conversations, each
    /// trimmed to its most recent `max_turns` turns.
    pub fn new(max_conversations: usize, max_turns: usize) -> Self {
        Self {
            contexts: DashMap::new(),
            lru: LruTracker::new(max_conversations),
            max_turns,
            max_conversations,
        }
    }

    /// Returns clones of the last `max_turns` turns of `conv_id` (oldest first),
    /// or `None` if the conversation is not cached. A hit marks it recently used.
    pub fn get_context(&self, conv_id: &str, max_turns: usize) -> Option<Vec<Turn>> {
        // The map guard is a temporary dropped at the end of this statement, so
        // the LRU mutex is never taken while a shard lock is held.
        let recent = self.contexts.get(conv_id)?.get_recent(max_turns);
        // Touch only on a hit: touching unknown ids would grow the tracker forever.
        self.lru.touch(conv_id);
        Some(recent)
    }

    /// Returns clones of every retained turn of `conv_id` (oldest first), or
    /// `None` if the conversation is not cached. A hit marks it recently used.
    pub fn get_full_context(&self, conv_id: &str) -> Option<Vec<Turn>> {
        let turns: Vec<Turn> = self.contexts.get(conv_id)?.turns.iter().cloned().collect();
        self.lru.touch(conv_id);
        Some(turns)
    }

    /// Appends `turn` to `conv_id`, creating the conversation if needed.
    ///
    /// Before creating a new conversation, least-recently-used conversations are
    /// evicted until fewer than `max_conversations` remain, so a limit of 0
    /// behaves like 1. Appending to an existing conversation never evicts
    /// anything. The existence check and the insert are not atomic, so
    /// concurrent creations may transiently exceed the limit.
    pub fn append_turn(&self, conv_id: &str, turn: Turn) {
        if !self.contexts.contains_key(conv_id) {
            self.evict_for_new_conversation();
        }
        // Touch after evicting so the conversation being written is never the victim.
        self.lru.touch(conv_id);
        self.contexts
            .entry(conv_id.to_string())
            .or_insert_with(|| ConversationContext::new(conv_id.to_string(), self.max_turns))
            .append(turn);
    }

    /// Evicts least-recently-used conversations until there is room for one
    /// more. Tracked ids whose conversation is already gone are discarded along
    /// the way. Stops early if the tracker runs dry.
    fn evict_for_new_conversation(&self) {
        while self.contexts.len() >= self.max_conversations {
            match self.lru.evict_oldest() {
                Some(old_id) => {
                    self.contexts.remove(&old_id);
                }
                None => break,
            }
        }
    }

    /// Removes `conv_id` and stops tracking its recency. Unknown ids are ignored.
    pub fn clear_conversation(&self, conv_id: &str) {
        self.contexts.remove(conv_id);
        self.lru.remove(conv_id);
    }

    /// Number of conversations currently cached.
    pub fn conversation_count(&self) -> usize {
        self.contexts.len()
    }

    /// Sum of the estimated token counts across all cached conversations.
    pub fn total_tokens(&self) -> usize {
        self.contexts.iter().map(|ctx| ctx.total_tokens).sum()
    }

    /// Aggregates conversation, turn, size and token counts in a single pass over
    /// the map. Every field therefore describes the same iteration rather than
    /// separate readings taken at different times.
    pub fn stats(&self) -> ConversationCacheStats {
        let mut stats = ConversationCacheStats {
            conversations: 0,
            total_turns: 0,
            total_size_bytes: 0,
            total_tokens: 0,
        };
        for ctx in self.contexts.iter() {
            stats.conversations += 1;
            stats.total_turns += ctx.turns.len();
            stats.total_size_bytes += ctx.size();
            stats.total_tokens += ctx.total_tokens;
        }
        stats
    }
}
/// Aggregate counters returned by `ConversationContextCache::stats`.
#[derive(Debug, Clone)]
pub struct ConversationCacheStats {
    /// Number of cached conversations.
    pub conversations: usize,
    /// Total number of retained turns across all cached conversations.
    pub total_turns: usize,
    /// Sum of the approximate per-turn sizes reported by `Turn::size`.
    pub total_size_bytes: usize,
    /// Sum of the estimated token counts of all retained turns.
    pub total_tokens: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_turn_creation() {
        let turn = Turn::new("1", "user", "Hello, how are you?");
        assert_eq!(turn.role, "user");
        assert!(turn.token_count > 0);
    }

    #[test]
    fn test_token_estimate() {
        // Estimate is words * 4 / 3, rounded down.
        assert_eq!(Turn::new("1", "user", "").token_count, 0);
        assert_eq!(Turn::new("1", "user", "one").token_count, 1);
        assert_eq!(Turn::new("1", "user", "one two three").token_count, 4);
        assert_eq!(Turn::new("1", "user", "  a\tb\n").token_count, 2);
    }

    #[test]
    fn test_turn_size_and_metadata() {
        let turn = Turn::new("1", "user", "Hello");
        assert_eq!(turn.size(), 1 + 4 + 5 + 64);
        assert!(turn.metadata.is_none());
        let turn = turn.with_metadata(serde_json::json!({ "k": 1 }));
        assert_eq!(turn.metadata, Some(serde_json::json!({ "k": 1 })));
    }

    #[test]
    fn test_append_and_get_context() {
        let cache = ConversationContextCache::new(100, 50);
        cache.append_turn("conv-1", Turn::new("1", "user", "Hello"));
        cache.append_turn("conv-1", Turn::new("2", "assistant", "Hi there!"));
        cache.append_turn("conv-1", Turn::new("3", "user", "How are you?"));
        let context = cache.get_context("conv-1", 2).unwrap();
        assert_eq!(context.len(), 2);
        assert_eq!(context[0].content, "Hi there!");
        assert_eq!(context[1].content, "How are you?");
    }

    #[test]
    fn test_get_context_count_larger_than_history() {
        let cache = ConversationContextCache::new(10, 10);
        cache.append_turn("conv-1", Turn::new("1", "user", "first"));
        cache.append_turn("conv-1", Turn::new("2", "assistant", "second"));
        let context = cache.get_context("conv-1", 100).unwrap();
        let contents: Vec<&str> = context.iter().map(|t| t.content.as_str()).collect();
        assert_eq!(contents, ["first", "second"]);
        assert!(cache.get_context("conv-1", 0).unwrap().is_empty());
    }

    #[test]
    fn test_max_turns_limit() {
        let cache = ConversationContextCache::new(100, 3);
        for i in 0..5 {
            cache.append_turn(
                "conv-1",
                Turn::new(format!("{}", i), "user", format!("Message {}", i)),
            );
        }
        let context = cache.get_full_context("conv-1").unwrap();
        assert_eq!(context.len(), 3);
        assert_eq!(context[0].content, "Message 2");
    }

    #[test]
    fn test_token_accounting_after_trim() {
        let cache = ConversationContextCache::new(10, 2);
        cache.append_turn("conv-1", Turn::new("1", "user", "a b c")); // 4 tokens
        cache.append_turn("conv-1", Turn::new("2", "user", "d")); // 1 token
        cache.append_turn("conv-1", Turn::new("3", "user", "e f")); // 2 tokens
        // The oldest turn was trimmed, so its tokens no longer count.
        assert_eq!(cache.total_tokens(), 3);
        assert_eq!(cache.stats().total_turns, 2);
    }

    #[test]
    fn test_lru_eviction() {
        let cache = ConversationContextCache::new(2, 10);
        cache.append_turn("conv-1", Turn::new("1", "user", "Hello 1"));
        cache.append_turn("conv-2", Turn::new("1", "user", "Hello 2"));
        cache.append_turn("conv-3", Turn::new("1", "user", "Hello 3"));
        assert!(cache.get_context("conv-1", 1).is_none());
        assert!(cache.get_context("conv-2", 1).is_some());
        assert!(cache.get_context("conv-3", 1).is_some());
    }

    #[test]
    fn test_unknown_conversation() {
        let cache = ConversationContextCache::new(10, 10);
        assert!(cache.get_context("missing", 5).is_none());
        assert!(cache.get_full_context("missing").is_none());
        assert_eq!(cache.conversation_count(), 0);
    }

    #[test]
    fn test_clear_conversation() {
        let cache = ConversationContextCache::new(10, 10);
        cache.append_turn("conv-1", Turn::new("1", "user", "Hello"));
        cache.append_turn("conv-2", Turn::new("1", "user", "Hi"));
        cache.clear_conversation("conv-1");
        assert!(cache.get_full_context("conv-1").is_none());
        assert!(cache.get_full_context("conv-2").is_some());
        assert_eq!(cache.conversation_count(), 1);
        // Clearing an unknown id is a no-op.
        cache.clear_conversation("missing");
        assert_eq!(cache.conversation_count(), 1);
    }

    #[test]
    fn test_stats() {
        let cache = ConversationContextCache::new(100, 50);
        cache.append_turn("conv-1", Turn::new("1", "user", "Hello"));
        cache.append_turn("conv-1", Turn::new("2", "assistant", "Hi"));
        cache.append_turn("conv-2", Turn::new("1", "user", "Test"));
        let stats = cache.stats();
        assert_eq!(stats.conversations, 2);
        assert_eq!(stats.total_turns, 3);
        // One word each -> 1 token each.
        assert_eq!(stats.total_tokens, 3);
        // (1 + 4 + 5 + 64) + (1 + 9 + 2 + 64) + (1 + 4 + 4 + 64)
        assert_eq!(stats.total_size_bytes, 74 + 76 + 73);
    }
}