use crate::core::SessionId;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::Read as IoRead;
/// Tunables controlling a session's token budget and history compression.
///
/// `#[serde(default)]` on the struct lets partially-specified configs
/// deserialize, filling any missing field from `Default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct SessionConfig {
/// Token budget for the active (uncompressed) conversation history.
pub max_tokens: usize,
/// Number of newest messages that are never compressed.
pub keep_recent_messages: usize,
/// zstd compression level passed through to `zstd::encode_all`.
pub compression_level: i32,
/// Fraction of `max_tokens` (0.0..=1.0) at which `compress_context` fires.
pub compression_threshold: f32,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
max_tokens: 100_000, keep_recent_messages: 20, compression_level: 3, compression_threshold: 0.8, }
}
}
/// Everything the agent knows about one session: the (possibly compressed)
/// conversation, the current task, agent/workspace state, and free-form
/// metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionContext {
/// Identifier of the owning session.
pub session_id: SessionId,
/// Token-budgeted message history; limits are synced from `config` in `new`.
pub conversation_history: TokenEfficientHistory,
/// Description of the task currently being worked on.
pub task_context: TaskContext,
/// Agent runtime state (state string, capabilities, metrics).
pub agent_state: AgentState,
/// Tracked files and recent changes in the workspace.
pub workspace_state: WorkspaceState,
/// Arbitrary JSON metadata attached to the session.
pub metadata: HashMap<String, serde_json::Value>,
/// Configuration this context was built with.
pub config: SessionConfig,
}
impl SessionContext {
    /// Build a fresh context for `session_id`, copying the default
    /// `SessionConfig` limits into the conversation history so both stay in
    /// agreement.
    pub fn new(session_id: SessionId) -> Self {
        let config = SessionConfig::default();
        let mut history = TokenEfficientHistory::new();
        history.max_tokens = config.max_tokens;
        history.keep_recent = config.keep_recent_messages;
        history.compression_level = config.compression_level;
        Self {
            session_id,
            conversation_history: history,
            task_context: TaskContext::default(),
            agent_state: AgentState::default(),
            workspace_state: WorkspaceState::default(),
            metadata: HashMap::new(),
            config,
        }
    }

    /// Append an already-constructed message to the history.
    pub fn add_message(&mut self, message: Message) {
        self.conversation_history.add_message_struct(message);
    }

    /// Append a message built from a role and raw content string.
    pub fn add_message_raw(&mut self, role: MessageRole, content: String) {
        self.conversation_history.add_message(role, content);
    }

    /// Number of messages currently held uncompressed.
    pub fn get_message_count(&self) -> usize {
        self.conversation_history.messages.len()
    }

    /// Estimated token count of the active (uncompressed) history.
    pub fn get_total_tokens(&self) -> usize {
        self.conversation_history.current_tokens
    }

    /// Borrow the `n` most recent messages (all of them when `n` meets or
    /// exceeds the count), oldest first.
    pub fn get_recent_messages(&self, n: usize) -> Vec<&Message> {
        let messages = &self.conversation_history.messages;
        let start = messages.len().saturating_sub(n);
        messages[start..].iter().collect()
    }

    /// Compress older messages when active-token usage exceeds the threshold
    /// (history `max_tokens` scaled by `config.compression_threshold`).
    /// Returns `true` if a compression pass was attempted.
    pub async fn compress_context(&mut self) -> bool {
        let limit =
            (self.conversation_history.max_tokens as f32 * self.config.compression_threshold)
                as usize;
        let over_budget = self.conversation_history.current_tokens > limit;
        if over_budget {
            self.conversation_history.compress_old_messages();
        }
        over_budget
    }

    /// Replace the current task context wholesale.
    pub fn update_task(&mut self, task: TaskContext) {
        self.task_context = task;
    }

    /// Produce a lightweight snapshot of the session for display/logging.
    pub fn summarize(&self) -> ContextSummary {
        ContextSummary {
            session_id: self.session_id.clone(),
            message_count: self.conversation_history.messages.len(),
            current_task: self.task_context.name.clone(),
            agent_state: self.agent_state.state.clone(),
            workspace_files: self.workspace_state.tracked_files.len(),
        }
    }

    /// Forward the history's compression statistics.
    pub fn get_compression_stats(&self) -> CompressionStats {
        self.conversation_history.get_compression_stats()
    }
}
/// Conversation history that keeps recent messages verbatim and folds older
/// ones into a single zstd-compressed blob once the token budget is exceeded.
///
/// Per-field `#[serde(default)]`s let older serialized forms deserialize;
/// the `default_*` helpers mirror the values used by `new()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenEfficientHistory {
/// Active (uncompressed) messages, oldest first.
#[serde(default)]
pub messages: Vec<Message>,
/// Accumulated compressed batch of older messages, if any compression ran.
#[serde(default)]
pub compressed_history: Option<CompressedHistory>,
/// Token budget; exceeding it triggers `compress_old_messages`.
#[serde(default = "default_max_tokens")]
pub max_tokens: usize,
/// Estimated tokens held by the active messages (plus a capped allowance
/// for the compression summary).
#[serde(default)]
pub current_tokens: usize,
/// Number of newest messages exempt from compression.
#[serde(default = "default_keep_recent")]
pub keep_recent: usize,
/// zstd level used when compressing.
#[serde(default = "default_compression_level")]
pub compression_level: i32,
/// Lifetime count of messages ever added (never decremented).
#[serde(default)]
pub total_messages_added: usize,
/// Cumulative tokens moved out of the active window by compression.
#[serde(default)]
pub tokens_saved_by_compression: usize,
}
// Serde default helpers — these must stay in sync with the literals in
// `TokenEfficientHistory::new()`.
fn default_max_tokens() -> usize {
100_000
}
fn default_keep_recent() -> usize {
20
}
fn default_compression_level() -> i32 {
3
}
/// `Default` simply delegates to `new()` so both construction paths agree.
impl Default for TokenEfficientHistory {
fn default() -> Self {
Self::new()
}
}
impl TokenEfficientHistory {
/// Construct an empty history with the built-in defaults (mirrored by the
/// `default_*` serde helpers above).
pub fn new() -> Self {
Self {
messages: Vec::new(),
compressed_history: None,
max_tokens: 100_000,
current_tokens: 0,
keep_recent: 20,
compression_level: 3,
total_messages_added: 0,
tokens_saved_by_compression: 0,
}
}
/// Append a message built from `role` and `content`, estimating its token
/// cost heuristically, then compress older messages if the budget is blown.
pub fn add_message(&mut self, role: MessageRole, content: String) {
let token_estimate = estimate_tokens(&content);
let message = Message {
role,
content,
timestamp: Utc::now(),
token_count: token_estimate,
};
self.messages.push(message);
self.current_tokens += token_estimate;
self.total_messages_added += 1;
// Compress eagerly as soon as the budget is exceeded.
if self.current_tokens > self.max_tokens {
self.compress_old_messages();
}
}
/// Append a pre-built message, trusting its `token_count` as-is, then
/// compress if over budget.
pub fn add_message_struct(&mut self, message: Message) {
self.current_tokens += message.token_count;
self.messages.push(message);
self.total_messages_added += 1;
if self.current_tokens > self.max_tokens {
self.compress_old_messages();
}
}
/// Move every message except the `keep_recent` newest into the compressed
/// blob. On serialization/compression failure the drained messages are
/// restored in their original order and the history is left unchanged.
pub fn compress_old_messages(&mut self) {
// Nothing to do if the protected window covers everything.
if self.messages.len() <= self.keep_recent {
return;
}
// Drain the oldest `len - keep_recent` messages for compression.
let split_point = self.messages.len() - self.keep_recent;
let messages_to_compress: Vec<Message> = self.messages.drain(..split_point).collect();
if messages_to_compress.is_empty() {
return;
}
let tokens_to_compress: usize = messages_to_compress.iter().map(|m| m.token_count).sum();
let json_data = match serde_json::to_vec(&messages_to_compress) {
Ok(data) => data,
Err(e) => {
tracing::warn!("Failed to serialize messages for compression: {}", e);
// Restore: drained messages go back in front of the kept ones.
let mut restored = messages_to_compress;
restored.append(&mut self.messages);
self.messages = restored;
return;
}
};
let compressed_data = match zstd::encode_all(json_data.as_slice(), self.compression_level) {
Ok(data) => data,
Err(e) => {
tracing::warn!("Failed to compress messages: {}", e);
// Same restore path as the serialization failure above.
let mut restored = messages_to_compress;
restored.append(&mut self.messages);
self.messages = restored;
return;
}
};
let original_size = json_data.len();
let compressed_size = compressed_data.len();
// Fraction of bytes saved: 1 - compressed/original (0.0 if original empty).
let compression_ratio = if original_size > 0 {
1.0 - (compressed_size as f64 / original_size as f64)
} else {
0.0
};
let summary = create_compression_summary(&messages_to_compress);
// Either merge into the existing compressed batch or start a new one.
// NOTE(review): `merge_compressed_data` decompresses and re-compresses the
// whole accumulated history on every pass, so repeated compaction costs
// O(total history) each time.
let new_compressed = if let Some(existing) = self.compressed_history.take() {
CompressedHistory {
compressed_data: merge_compressed_data(
&existing.compressed_data,
&compressed_data,
self.compression_level,
),
summary: format!("{}\n---\n{}", existing.summary, summary),
message_count: existing.message_count + messages_to_compress.len(),
original_tokens: existing.original_tokens + tokens_to_compress,
compressed_bytes: existing.compressed_bytes + compressed_size,
// Unweighted average of the two ratios — an approximation, not the
// true combined ratio.
compression_ratio: (existing.compression_ratio + compression_ratio) / 2.0,
}
} else {
CompressedHistory {
compressed_data,
summary,
message_count: messages_to_compress.len(),
original_tokens: tokens_to_compress,
compressed_bytes: compressed_size,
compression_ratio,
}
};
self.compressed_history = Some(new_compressed);
self.current_tokens -= tokens_to_compress;
self.tokens_saved_by_compression += tokens_to_compress;
// Charge the (growing, concatenated) summary back to the active budget,
// capped at 100 tokens per pass.
// NOTE(review): this re-adds up to 100 tokens on every compression pass
// even though earlier passes already charged for the same summary text —
// looks like it can inflate `current_tokens` over time; confirm intent.
let summary_tokens = estimate_tokens(
self.compressed_history
.as_ref()
.map(|h| h.summary.as_str())
.unwrap_or(""),
);
self.current_tokens += summary_tokens.min(100);
tracing::info!(
"Compressed {} messages ({} tokens) with {:.1}% ratio",
messages_to_compress.len(),
tokens_to_compress,
compression_ratio * 100.0
);
}
/// Decompress and deserialize the compressed batch. Returns `None` when
/// there is no compressed history or when any decode step fails (errors are
/// logged, not propagated).
pub fn decompress_history(&self) -> Option<Vec<Message>> {
let compressed = self.compressed_history.as_ref()?;
let mut decompressed = Vec::new();
let mut decoder = match zstd::Decoder::new(compressed.compressed_data.as_slice()) {
Ok(d) => d,
Err(e) => {
tracing::error!("Failed to create zstd decoder: {}", e);
return None;
}
};
if let Err(e) = decoder.read_to_end(&mut decompressed) {
tracing::error!("Failed to decompress history: {}", e);
return None;
}
match serde_json::from_slice(&decompressed) {
Ok(messages) => Some(messages),
Err(e) => {
tracing::error!("Failed to deserialize decompressed messages: {}", e);
None
}
}
}
/// Full history, oldest first: decompressed older messages (empty on
/// failure) followed by clones of the active ones.
pub fn get_all_messages(&self) -> Vec<Message> {
let mut all_messages = self.decompress_history().unwrap_or_default();
all_messages.extend(self.messages.clone());
all_messages
}
/// Borrow the newest run of ACTIVE messages that fits within `token_limit`
/// (compressed history is not considered). Walks newest-to-oldest and stops
/// at the first message that would overflow, so an early large message can
/// end the run even if later (older) ones would fit.
pub fn get_messages_within_limit(&self, token_limit: usize) -> Vec<&Message> {
let mut messages = Vec::new();
let mut tokens = 0;
for message in self.messages.iter().rev() {
if tokens + message.token_count <= token_limit {
messages.push(message);
tokens += message.token_count;
} else {
break;
}
}
// Collected newest-first; flip back to chronological order.
messages.reverse();
messages
}
/// Snapshot of compression bookkeeping; compressed-batch fields fall back
/// to zero when nothing has been compressed yet.
pub fn get_compression_stats(&self) -> CompressionStats {
let compressed_stats = self.compressed_history.as_ref().map(|h| {
(
h.message_count,
h.original_tokens,
h.compressed_bytes,
h.compression_ratio,
)
});
CompressionStats {
total_messages_added: self.total_messages_added,
active_messages: self.messages.len(),
compressed_messages: compressed_stats.map(|(c, _, _, _)| c).unwrap_or(0),
active_tokens: self.current_tokens,
tokens_saved: self.tokens_saved_by_compression,
compressed_bytes: compressed_stats.map(|(_, _, b, _)| b).unwrap_or(0),
compression_ratio: compressed_stats.map(|(_, _, _, r)| r).unwrap_or(0.0),
}
}
}
/// Heuristic token estimate for `content` — not a real tokenizer.
///
/// Takes the larger of a word-based guess (~1.3 tokens per whitespace-split
/// word, plus one per punctuation character, plus a fixed overhead of 2) and
/// a character-based guess (one token per four chars), never less than 1.
fn estimate_tokens(content: &str) -> usize {
    // An empty message still costs at least one token.
    if content.is_empty() {
        return 1;
    }
    let words = content.split_whitespace().count();
    // Count total characters and "special" (non-alphanumeric, non-whitespace)
    // characters in a single pass.
    let (chars, specials) = content.chars().fold((0usize, 0usize), |(total, special), ch| {
        let is_special = !ch.is_alphanumeric() && !ch.is_whitespace();
        (total + 1, special + is_special as usize)
    });
    let word_based = (words as f64 * 1.3) as usize + specials + 2;
    let char_based = chars / 4;
    word_based.max(char_based).max(1)
}
/// Build a one-line human-readable summary of a batch of messages that is
/// about to be compressed: total count, user/assistant breakdown, and the
/// first/last timestamps (HH:MM:SS). Returns an empty string for an empty
/// batch.
fn create_compression_summary(messages: &[Message]) -> String {
    let (Some(first), Some(last)) = (messages.first(), messages.last()) else {
        return String::new();
    };
    // Tally user/assistant roles in one pass; other roles are not reported.
    let mut user_count = 0usize;
    let mut assistant_count = 0usize;
    for message in messages {
        match message.role {
            MessageRole::User => user_count += 1,
            MessageRole::Assistant => assistant_count += 1,
            _ => {}
        }
    }
    format!(
        "[Compressed: {} messages ({} user, {} assistant) from {} to {}]",
        messages.len(),
        user_count,
        assistant_count,
        first.timestamp.format("%H:%M:%S"),
        last.timestamp.format("%H:%M:%S")
    )
}
fn merge_compressed_data(existing: &[u8], new: &[u8], level: i32) -> Vec<u8> {
let mut existing_decompressed = Vec::new();
if let Ok(mut decoder) = zstd::Decoder::new(existing) {
let _ = decoder.read_to_end(&mut existing_decompressed);
}
let mut new_decompressed = Vec::new();
if let Ok(mut decoder) = zstd::Decoder::new(new) {
let _ = decoder.read_to_end(&mut new_decompressed);
}
let existing_messages: Vec<Message> =
serde_json::from_slice(&existing_decompressed).unwrap_or_default();
let new_messages: Vec<Message> = serde_json::from_slice(&new_decompressed).unwrap_or_default();
let mut merged = existing_messages;
merged.extend(new_messages);
let json_data = serde_json::to_vec(&merged).unwrap_or_default();
zstd::encode_all(json_data.as_slice(), level).unwrap_or_else(|_| new.to_vec())
}
/// A single conversation turn with a cached token-count estimate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Who produced the message.
pub role: MessageRole,
/// Raw message text.
pub content: String,
/// UTC timestamp; set at insertion time by `add_message`.
pub timestamp: DateTime<Utc>,
/// Estimated token cost (see `estimate_tokens`) — a heuristic, not an
/// exact tokenizer count.
pub token_count: usize,
}
/// Originator of a conversation message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
System,
User,
Assistant,
Tool,
}
/// The compressed portion of a conversation history plus its bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedHistory {
/// zstd-compressed JSON of a `Vec<Message>`; serialized as base64 text so
/// the struct itself stays representable in JSON.
#[serde(with = "base64_serde")]
pub compressed_data: Vec<u8>,
/// Human-readable summaries of each compressed batch, joined with "---".
pub summary: String,
/// Total number of messages in the compressed blob.
pub message_count: usize,
/// Estimated token count of the messages before compression.
pub original_tokens: usize,
/// Cumulative compressed byte size across batches.
pub compressed_bytes: usize,
/// Fraction of bytes saved (1 - compressed/original). After merges this is
/// an unweighted average of batch ratios, i.e. approximate.
pub compression_ratio: f64,
}
/// Serde helpers (for `#[serde(with = "base64_serde")]`) that persist binary
/// blobs as standard base64 strings, keeping them representable in JSON.
mod base64_serde {
    use base64::{Engine, engine::general_purpose::STANDARD};
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    /// Serialize raw bytes as a base64 (standard alphabet) string.
    ///
    /// Takes `&[u8]` rather than `&Vec<u8>` per Rust convention; the
    /// `Vec<u8>` field deref-coerces at the call site serde generates.
    pub fn serialize<S>(data: &[u8], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = STANDARD.encode(data);
        encoded.serialize(serializer)
    }
    /// Deserialize a base64 string back into raw bytes, surfacing decode
    /// failures as serde errors.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let encoded = String::deserialize(deserializer)?;
        STANDARD.decode(&encoded).map_err(serde::de::Error::custom)
    }
}
/// Snapshot of a history's compression bookkeeping, as returned by
/// `TokenEfficientHistory::get_compression_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionStats {
/// Lifetime count of messages ever added.
pub total_messages_added: usize,
/// Messages currently held uncompressed.
pub active_messages: usize,
/// Messages folded into the compressed blob (0 if none).
pub compressed_messages: usize,
/// Estimated tokens in the active window.
pub active_tokens: usize,
/// Cumulative tokens moved out of the active window by compression.
pub tokens_saved: usize,
/// Compressed blob size in bytes (0 if none).
pub compressed_bytes: usize,
/// Approximate compression ratio of the blob (0.0 if none).
pub compression_ratio: f64,
}
/// Description of the task a session is working on. All fields optional;
/// `Default` yields an empty (no-task) context.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskContext {
pub id: Option<String>,
pub name: Option<String>,
pub description: Option<String>,
pub task_type: Option<String>,
pub priority: Option<TaskPriority>,
pub started_at: Option<DateTime<Utc>>,
/// Arbitrary JSON metadata attached to the task.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Priority level for a task, lowest to highest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskPriority {
Low,
Medium,
High,
Critical,
}
/// Runtime state of the agent owning a session.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentState {
/// Free-form state label (stringly typed; empty by default).
pub state: String,
pub capabilities: Vec<String>,
/// Named numeric metrics.
pub metrics: HashMap<String, f64>,
/// Most recent error message, if any.
pub last_error: Option<String>,
}
/// Filesystem view of the session's workspace.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkspaceState {
pub working_directory: String,
/// Tracked files keyed by path (presumably the same path stored in
/// `FileState.path` — confirm against callers).
pub tracked_files: HashMap<String, FileState>,
pub recent_changes: Vec<FileChange>,
}
/// Snapshot of a single tracked file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileState {
pub path: String,
pub last_modified: DateTime<Utc>,
/// Content hash; algorithm not specified here — determined by the producer.
pub hash: String,
pub is_modified: bool,
}
/// One recorded filesystem change event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileChange {
pub path: String,
pub change_type: FileChangeType,
pub timestamp: DateTime<Utc>,
}
/// Kind of filesystem change recorded in a `FileChange`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FileChangeType {
Created,
Modified,
Deleted,
Renamed,
}
/// Lightweight session snapshot produced by `SessionContext::summarize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSummary {
pub session_id: SessionId,
/// Count of ACTIVE (uncompressed) messages only.
pub message_count: usize,
pub current_task: Option<String>,
pub agent_state: String,
pub workspace_files: usize,
}
#[cfg(test)]
mod tests {
use super::*;
// Adding past the budget should trigger compression while keeping the token
// count near the budget (the +150 slack covers the capped summary charge).
#[test]
fn test_token_efficient_history() {
let mut history = TokenEfficientHistory::new();
history.max_tokens = 200; history.keep_recent = 3;
for i in 0..10 {
history.add_message(
MessageRole::User,
format!("Test message number {}", i), );
}
// NOTE(review): `len() <= 10` is trivially true after 10 adds — weak check.
assert!(history.messages.len() <= 10);
assert!(
history.compressed_history.is_some() || history.current_tokens <= history.max_tokens
);
assert!(
history.current_tokens <= history.max_tokens + 150,
"current_tokens {} exceeded max_tokens {} + 150",
history.current_tokens,
history.max_tokens
);
}
// A tiny budget forces compression; the blob must round-trip back to
// non-empty messages.
#[test]
fn test_zstd_compression() {
let mut history = TokenEfficientHistory::new();
history.max_tokens = 50;
history.keep_recent = 2;
for i in 0..20 {
history.add_message(MessageRole::User, format!("Test message number {}", i));
}
assert!(history.compressed_history.is_some());
let decompressed = history.decompress_history();
assert!(decompressed.is_some());
let messages = decompressed.unwrap();
assert!(!messages.is_empty());
}
// Lifetime counter is exact; compressed/active split depends on thresholds.
#[test]
fn test_compression_stats() {
let mut history = TokenEfficientHistory::new();
history.max_tokens = 30;
history.keep_recent = 2;
for i in 0..10 {
history.add_message(MessageRole::User, format!("Message {}", i));
}
let stats = history.get_compression_stats();
assert_eq!(stats.total_messages_added, 10);
assert!(stats.compressed_messages > 0 || stats.active_messages == 10);
}
// summarize() reflects the session id and active message count.
#[test]
fn test_context_summary() {
let session_id = SessionId::new();
let mut context = SessionContext::new(session_id.clone());
context.add_message_raw(MessageRole::User, "Hello".to_string());
context.add_message_raw(MessageRole::Assistant, "Hi there!".to_string());
let summary = context.summarize();
assert_eq!(summary.session_id, session_id);
assert_eq!(summary.message_count, 2);
}
// Accessor methods: counts, token totals, recent-message slicing, and the
// default config propagated by SessionContext::new.
#[test]
fn test_new_api_methods() {
let session_id = SessionId::new();
let mut context = SessionContext::new(session_id.clone());
assert_eq!(context.get_message_count(), 0);
let message = Message {
role: MessageRole::User,
content: "Test message".to_string(),
timestamp: Utc::now(),
token_count: 3,
};
context.add_message(message);
assert_eq!(context.get_message_count(), 1);
assert_eq!(context.get_total_tokens(), 3);
let recent = context.get_recent_messages(1);
assert_eq!(recent.len(), 1);
assert_eq!(recent[0].content, "Test message");
assert_eq!(context.config.max_tokens, 100_000);
}
// Driving adds past a small budget should have compressed something; the
// disjunction keeps the test robust to estimator changes.
#[tokio::test]
async fn test_compress_context() {
let session_id = SessionId::new();
let mut context = SessionContext::new(session_id);
context.config.max_tokens = 50;
context.config.compression_threshold = 0.5;
context.conversation_history.max_tokens = 50;
context.conversation_history.keep_recent = 3;
for i in 0..10 {
let message = Message {
role: MessageRole::User,
content: format!("Message {}", i),
timestamp: Utc::now(),
token_count: 10,
};
context.add_message(message);
}
let stats = context.get_compression_stats();
assert!(stats.compressed_messages > 0 || stats.total_messages_added == 10);
}
// Heuristic estimator bounds: empty floor, word scaling, punctuation bumps.
#[test]
fn test_estimate_tokens() {
let tokens = estimate_tokens("Hello world");
assert!(tokens >= 2);
let tokens = estimate_tokens("");
assert_eq!(tokens, 1);
let tokens = estimate_tokens("Hello, world! How are you?");
assert!(tokens >= 5);
let tokens = estimate_tokens(&"word ".repeat(100));
assert!(tokens >= 100);
}
// Compressed + active messages together must preserve every message added.
#[test]
fn test_get_all_messages() {
let mut history = TokenEfficientHistory::new();
history.max_tokens = 20;
history.keep_recent = 2;
for i in 0..5 {
history.add_message(MessageRole::User, format!("Message {}", i));
}
let all = history.get_all_messages();
assert_eq!(all.len(), 5);
}
}