use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tracing::{debug, instrument};
use uuid::Uuid;
use crate::{
config::{ContextConfig, StorageBackend},
error::Error,
message::{Message, Response},
};
/// Conversation state tracked for a single context id.
///
/// Bundles the message history, information about the user, free-form variables
/// and bookkeeping metadata. Instances are handed by value to [`ContextStore`]
/// implementations and shared through [`ContextManager`]'s cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context {
    /// Identifier this context is created and cached under.
    pub id: String,
    /// History entries in insertion order; `trim_to_token_limit` drops from the front.
    pub history: VecDeque<ContextMessage>,
    /// Information about the user this context belongs to.
    pub user: UserContext,
    /// Arbitrary JSON values keyed by name (see `get_variable` / `set_variable`).
    pub variables: HashMap<String, serde_json::Value>,
    /// Timestamps, counters and tags describing this context.
    pub metadata: ContextMetadata,
    /// Running sum of the estimated token counts of the entries in `history`.
    pub token_count: usize,
}
impl Context {
#[must_use]
pub fn new(id: impl Into<String>) -> Self {
Self {
id: id.into(),
history: VecDeque::new(),
user: UserContext::default(),
variables: HashMap::new(),
metadata: ContextMetadata::new(),
token_count: 0,
}
}
pub fn add_message(&mut self, message: &Message) {
let context_msg = ContextMessage::from_message(message);
self.token_count += context_msg.estimated_tokens();
self.history.push_back(context_msg);
self.metadata.last_activity = Utc::now();
self.metadata.message_count += 1;
}
pub fn add_response(&mut self, response: &Response) {
let context_msg = ContextMessage::from_response(response);
self.token_count += context_msg.estimated_tokens();
self.history.push_back(context_msg);
self.metadata.last_activity = Utc::now();
self.metadata.message_count += 1;
if let Some(usage) = &response.usage {
self.metadata.total_tokens += usage.total_tokens;
self.metadata.total_cost += usage.estimated_cost;
}
}
pub fn trim_to_token_limit(&mut self, max_tokens: usize) {
while self.token_count > max_tokens && !self.history.is_empty() {
if let Some(removed) = self.history.pop_front() {
self.token_count = self.token_count.saturating_sub(removed.estimated_tokens());
}
}
}
pub fn get_variable(&self, key: &str) -> Option<&serde_json::Value> {
self.variables.get(key)
}
pub fn set_variable(&mut self, key: impl Into<String>, value: serde_json::Value) {
self.variables.insert(key.into(), value);
}
pub fn clear_history(&mut self) {
self.history.clear();
self.token_count = 0;
self.metadata.message_count = 0;
}
#[must_use]
pub fn age(&self) -> Duration {
let now = Utc::now();
(now - self.metadata.created_at)
.to_std()
.unwrap_or(Duration::ZERO)
}
#[must_use]
pub fn is_expired(&self, ttl: Duration) -> bool {
self.age() > ttl
}
#[must_use]
pub fn summary(&self) -> String {
format!(
"Context {} - Messages: {}, Tokens: {}, Age: {:?}",
self.id,
self.metadata.message_count,
self.token_count,
self.age()
)
}
}
/// A single entry in a [`Context`]'s history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMessage {
    /// Who authored the entry.
    pub role: MessageRole,
    /// Text of the entry.
    pub content: String,
    /// When the entry was produced (copied from the source message/response, or
    /// the creation time for system entries).
    pub timestamp: DateTime<Utc>,
    /// Id of the originating message/response; `None` for entries built with
    /// `ContextMessage::system`.
    pub message_id: Option<Uuid>,
}
impl ContextMessage {
    /// Snapshots an inbound [`Message`] as a user-role entry, keeping its id and
    /// timestamp.
    pub fn from_message(message: &Message) -> Self {
        let content = message.content.clone();
        Self {
            message_id: Some(message.id),
            timestamp: message.timestamp,
            content,
            role: MessageRole::User,
        }
    }

    /// Snapshots a [`Response`] as an assistant-role entry, keeping its id and
    /// timestamp.
    pub fn from_response(response: &Response) -> Self {
        let content = response.content.clone();
        Self {
            message_id: Some(response.id),
            timestamp: response.timestamp,
            content,
            role: MessageRole::Assistant,
        }
    }

    /// Builds a system-role entry stamped with the current time and no message id.
    pub fn system(content: impl Into<String>) -> Self {
        let content: String = content.into();
        let timestamp = Utc::now();
        Self {
            message_id: None,
            timestamp,
            content,
            role: MessageRole::System,
        }
    }

    /// Rough token estimate: one token per four bytes of content, rounded down.
    const fn estimated_tokens(&self) -> usize {
        let byte_len = self.content.len();
        byte_len / 4
    }
}
/// Author of a [`ContextMessage`].
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Entries created with `ContextMessage::system`.
    System,
    /// Entries created from an inbound [`Message`].
    User,
    /// Entries created from a [`Response`].
    Assistant,
}
/// Information about the user a [`Context`] belongs to.
///
/// Every field is empty by default; nothing in this module populates them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UserContext {
    /// Optional user identifier.
    pub id: Option<String>,
    /// Optional display name.
    pub name: Option<String>,
    /// Arbitrary JSON preference values keyed by name.
    pub preferences: HashMap<String, serde_json::Value>,
    /// Arbitrary string attributes keyed by name.
    pub attributes: HashMap<String, String>,
}
/// Bookkeeping attached to every [`Context`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMetadata {
    /// Creation time; `Context::age` and `Context::is_expired` measure from here.
    pub created_at: DateTime<Utc>,
    /// Time of the most recent appended message or response.
    pub last_activity: DateTime<Utc>,
    /// Entries appended since creation or the last `clear_history` (not reduced by trimming).
    pub message_count: usize,
    /// Sum of the token totals reported in response usage.
    pub total_tokens: usize,
    /// Sum of the estimated costs reported in response usage.
    pub total_cost: f64,
    /// Free-form labels; not read or written elsewhere in this module.
    pub tags: Vec<String>,
}
impl ContextMetadata {
    /// Fresh metadata: creation and last-activity stamped with the same instant,
    /// all counters zeroed and no tags.
    fn new() -> Self {
        let created_at = Utc::now();
        Self {
            last_activity: created_at,
            created_at,
            message_count: 0,
            total_tokens: 0,
            total_cost: 0.0,
            tags: vec![],
        }
    }
}
/// Looks up, creates, updates and expires [`Context`]s.
///
/// Live contexts are kept in an in-process cache; a [`ContextStore`] chosen from
/// the configuration backs them when persistence is enabled.
pub struct ContextManager {
    /// Manager settings (TTL, token limit, persistence flag, storage backend).
    config: ContextConfig,
    /// Backing store selected from `config.storage_backend`.
    store: Arc<dyn ContextStore>,
    /// Live contexts keyed by id; each sits behind an `RwLock` so callers can
    /// mutate the shared instance in place.
    cache: Arc<DashMap<String, Arc<RwLock<Context>>>>,
}
impl ContextManager {
    /// Creates a manager whose backing store is selected by `config.storage_backend`.
    ///
    /// # Errors
    ///
    /// Returns an error for the Redis, Postgres and SQLite backends, which are not
    /// implemented yet; only [`StorageBackend::Memory`] is currently accepted.
    #[instrument(skip(config))]
    pub async fn new(config: ContextConfig) -> Result<Self> {
        debug!("Creating context manager with config: {:?}", config);
        let store: Arc<dyn ContextStore> = match &config.storage_backend {
            StorageBackend::Memory => Arc::new(MemoryContextStore::new()),
            StorageBackend::Redis { url: _ } => {
                return Err(Error::new("Redis store not yet implemented").into());
            }
            StorageBackend::Postgres { url: _ } => {
                return Err(Error::new("Postgres store not yet implemented").into());
            }
            StorageBackend::Sqlite { path: _ } => {
                return Err(Error::new("SQLite store not yet implemented").into());
            }
        };
        Ok(Self {
            config,
            store,
            cache: Arc::new(DashMap::new()),
        })
    }

    /// Returns the live context for `id`, loading or creating it if necessary.
    ///
    /// The lookup tries the in-process cache, then the backing store, and finally
    /// creates a fresh [`Context`]. An expired cached context is evicted and an
    /// expired stored context is ignored. A newly created context is written to
    /// the store when `config.persist_context` is set.
    ///
    /// # Errors
    ///
    /// Propagates errors from the backing store.
    #[instrument(skip(self))]
    pub async fn get_or_create(&self, id: &str) -> Result<Arc<RwLock<Context>>> {
        let ttl = self.config.context_ttl;

        // Clone the `Arc` out of the map before doing anything else. A DashMap `Ref`
        // holds its shard's read lock, so calling `remove` on the same key while the
        // guard is alive would deadlock.
        let cached = self.cache.get(id).map(|entry| Arc::clone(entry.value()));
        if let Some(ctx) = cached {
            if ctx.read().is_expired(ttl) {
                debug!("Context {} is expired, removing", id);
                // Evict only the instance inspected above; another task may have
                // replaced it with a fresh context in the meantime.
                self.cache.remove_if(id, |_, current| Arc::ptr_eq(current, &ctx));
            } else {
                debug!("Found context {} in cache", id);
                return Ok(ctx);
            }
        }

        if let Some(context) = self.store.get(id).await? {
            if !context.is_expired(ttl) {
                debug!("Loaded context {} from store", id);
                let ctx = Arc::new(RwLock::new(context));
                self.cache.insert(id.to_string(), Arc::clone(&ctx));
                return Ok(ctx);
            }
        }

        debug!("Creating new context {}", id);
        let ctx = Arc::new(RwLock::new(Context::new(id)));
        self.cache.insert(id.to_string(), Arc::clone(&ctx));
        if self.config.persist_context {
            // Snapshot first: the lock guard is a temporary and is dropped before the await.
            let snapshot = ctx.read().clone();
            self.store.set(id, snapshot, ttl).await?;
        }
        Ok(ctx)
    }

    /// Trims `context` to `config.max_context_tokens`, caches it under `id`, and
    /// persists a snapshot when `config.persist_context` is set.
    ///
    /// # Errors
    ///
    /// Propagates errors from the backing store.
    #[instrument(skip(self, context))]
    pub async fn update(&self, id: &str, context: Arc<RwLock<Context>>) -> Result<()> {
        {
            // Scope the write guard so it is released before any await below.
            let mut ctx = context.write();
            ctx.trim_to_token_limit(self.config.max_context_tokens);
        }
        self.cache.insert(id.to_string(), Arc::clone(&context));
        if self.config.persist_context {
            let snapshot = context.read().clone();
            self.store.set(id, snapshot, self.config.context_ttl).await?;
        }
        Ok(())
    }

    /// Removes the context for `id` from both the cache and the backing store.
    ///
    /// # Errors
    ///
    /// Propagates errors from the backing store.
    #[instrument(skip(self))]
    pub async fn delete(&self, id: &str) -> Result<()> {
        debug!("Deleting context {}", id);
        self.cache.remove(id);
        self.store.delete(id).await?;
        Ok(())
    }

    /// Evicts every cached context older than `config.context_ttl`, deletes it from
    /// the backing store, and returns how many were removed.
    ///
    /// Only cached contexts are examined; entries that exist solely in the store
    /// are left alone.
    ///
    /// # Errors
    ///
    /// Propagates the first store error; contexts removed before it stay removed.
    #[instrument(skip(self))]
    pub async fn clear_expired(&self) -> Result<usize> {
        let ttl = self.config.context_ttl;
        // Collect first so no DashMap guard is held across the awaits below.
        let expired_keys: Vec<String> = self
            .cache
            .iter()
            .filter(|entry| entry.value().read().is_expired(ttl))
            .map(|entry| entry.key().clone())
            .collect();
        let mut removed = 0;
        for key in expired_keys {
            self.cache.remove(&key);
            self.store.delete(&key).await?;
            removed += 1;
        }
        debug!("Removed {} expired contexts", removed);
        Ok(removed)
    }

    /// Aggregates token and message counts over the cached contexts.
    ///
    /// `total_contexts` and `cache_size` both report the number of cached contexts.
    #[must_use]
    pub fn stats(&self) -> ContextStats {
        let total = self.cache.len();
        let mut total_tokens = 0;
        let mut total_messages = 0;
        for entry in self.cache.iter() {
            let ctx = entry.value().read();
            total_tokens += ctx.token_count;
            total_messages += ctx.metadata.message_count;
        }
        ContextStats {
            total_contexts: total,
            total_tokens,
            total_messages,
            cache_size: total,
        }
    }
}
/// Persistence backend for [`Context`] values, keyed by string.
///
/// Must be `Send + Sync`; [`ContextManager`] holds one behind `Arc<dyn ContextStore>`.
#[async_trait::async_trait]
pub trait ContextStore: Send + Sync {
    /// Fetches the context stored under `key`, or `None` if there is none.
    async fn get(&self, key: &str) -> Result<Option<Context>>;
    /// Stores `context` under `key`. The manager calls this repeatedly for the same
    /// key, so implementations are expected to overwrite. `ttl` is the lifetime the
    /// caller requests for the entry (the manager passes `ContextConfig::context_ttl`).
    async fn set(&self, key: &str, context: Context, ttl: Duration) -> Result<()>;
    /// Removes the entry under `key`. The manager may call this for keys that were
    /// never persisted, so presumably a missing key should not be an error — confirm
    /// for new backends.
    async fn delete(&self, key: &str) -> Result<()>;
    /// Lists stored keys matching `pattern`. How `pattern` is interpreted is up to
    /// the implementation (the in-memory store uses substring matching).
    async fn list_keys(&self, pattern: &str) -> Result<Vec<String>>;
}
/// Process-local [`ContextStore`] backed by a concurrent map; contents do not
/// outlive the process.
struct MemoryContextStore {
    /// Context plus its absolute expiry instant, computed by `set` as now + ttl.
    data: Arc<DashMap<String, (Context, DateTime<Utc>)>>,
}
impl MemoryContextStore {
    /// Creates an empty in-memory store.
    fn new() -> Self {
        let data = Arc::new(DashMap::new());
        Self { data }
    }
}
#[async_trait::async_trait]
impl ContextStore for MemoryContextStore {
async fn get(&self, key: &str) -> Result<Option<Context>> {
Ok(self.data.get(key).map(|entry| entry.0.clone()))
}
async fn set(&self, key: &str, context: Context, ttl: Duration) -> Result<()> {
let expiry = Utc::now() + chrono::Duration::from_std(ttl)?;
self.data.insert(key.to_string(), (context, expiry));
Ok(())
}
async fn delete(&self, key: &str) -> Result<()> {
self.data.remove(key);
Ok(())
}
async fn list_keys(&self, pattern: &str) -> Result<Vec<String>> {
let keys = self
.data
.iter()
.filter(|entry| entry.key().contains(pattern))
.map(|entry| entry.key().clone())
.collect();
Ok(keys)
}
}
/// Aggregate counters over the contexts currently held in a manager's cache.
///
/// `Default` yields all-zero counters; `PartialEq`/`Eq` allow direct comparison
/// of snapshots (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ContextStats {
    /// Number of cached contexts.
    pub total_contexts: usize,
    /// Sum of the cached contexts' estimated token counts.
    pub total_tokens: usize,
    /// Sum of the cached contexts' message counts.
    pub total_messages: usize,
    /// Number of entries in the cache.
    pub cache_size: usize,
}
#[cfg(test)]
mod tests {
    // Unit tests for `Context`, `ContextManager` and the in-memory store.
    use super::*;

    #[test]
    fn test_context_creation() {
        let ctx = Context::new("test-123");
        assert_eq!(ctx.id, "test-123");
        assert!(ctx.history.is_empty());
        assert_eq!(ctx.token_count, 0);
    }

    #[test]
    fn test_context_message_addition() {
        let mut ctx = Context::new("test");
        ctx.add_message(&Message::text("Hello"));
        assert_eq!(ctx.history.len(), 1);
        assert!(ctx.token_count > 0);
        assert_eq!(ctx.metadata.message_count, 1);
    }

    #[test]
    fn test_context_trimming() {
        let mut ctx = Context::new("test");
        (0..10).for_each(|i| ctx.add_message(&Message::text(format!("Message {i}"))));
        let before = ctx.history.len();
        ctx.trim_to_token_limit(10);
        assert!(ctx.history.len() < before);
        assert!(ctx.token_count <= 10);
    }

    #[test]
    fn test_context_variables() {
        let mut ctx = Context::new("test");
        let expected = serde_json::json!("value");
        ctx.set_variable("key", expected.clone());
        assert_eq!(ctx.get_variable("key"), Some(&expected));
        assert!(ctx.get_variable("missing").is_none());
    }

    #[test]
    fn test_context_expiry() {
        let one_hour = Duration::from_secs(60 * 60);
        assert!(!Context::new("test").is_expired(one_hour));
    }

    #[tokio::test]
    async fn test_context_manager() {
        let manager = ContextManager::new(ContextConfig::default()).await.unwrap();
        let first = manager.get_or_create("test-1").await.unwrap();
        let second = manager.get_or_create("test-1").await.unwrap();
        assert_eq!(first.read().id, second.read().id);
    }

    #[tokio::test]
    async fn test_memory_store() {
        let store = MemoryContextStore::new();
        store
            .set("test", Context::new("test"), Duration::from_secs(60))
            .await
            .unwrap();

        let loaded_id = store.get("test").await.unwrap().map(|ctx| ctx.id);
        assert_eq!(loaded_id.as_deref(), Some("test"));

        store.delete("test").await.unwrap();
        assert!(store.get("test").await.unwrap().is_none());
    }
}