use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use uuid::Uuid;
use crate::{AgentId, MultiAgentError, MultiAgentResult};
/// A single unit of context (message, document, memory, ...) held by an [`AgentContext`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextItem {
    /// Unique identifier; `ContextItem::new` generates a random v4 UUID.
    pub id: Uuid,
    /// Kind of item; selects the label `AgentContext::format_for_llm` prefixes it with.
    pub item_type: ContextItemType,
    /// Raw text of the item.
    pub content: String,
    /// Auxiliary metadata, including the `pinned` flag consulted during eviction.
    pub metadata: ContextMetadata,
    /// Token cost counted against `AgentContext::max_tokens`.
    /// NOTE(review): supplied by the caller; it is not derived from `content` here.
    pub token_count: u64,
    /// Ranking score; higher values are preferred when selecting or retaining items.
    pub relevance_score: f64,
    /// Creation time; `ContextItem::new` sets it to `Utc::now()`.
    pub added_at: DateTime<Utc>,
}
impl ContextItem {
    /// Builds an item with a freshly generated id, default metadata and `added_at` set to
    /// the current time.
    pub fn new(
        item_type: ContextItemType,
        content: String,
        token_count: u64,
        relevance_score: f64,
    ) -> Self {
        let id = Uuid::new_v4();
        let added_at = Utc::now();
        Self {
            id,
            item_type,
            content,
            metadata: ContextMetadata::default(),
            token_count,
            relevance_score,
            added_at,
        }
    }

    /// Builder-style setter: consumes the item and returns it with `metadata` replaced.
    pub fn with_metadata(self, metadata: ContextMetadata) -> Self {
        Self { metadata, ..self }
    }
}
/// Role or provenance of a [`ContextItem`].
///
/// `AgentContext::format_for_llm` uses it to choose the label printed before the item's
/// content. The enum carries no data, so it is fully comparable and hashable and can be used
/// as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ContextItemType {
    /// System-level instruction; formatted as `System:`.
    System,
    /// User message; formatted as `User:`.
    User,
    /// Assistant message; formatted as `Assistant:`.
    Assistant,
    /// Tool output; formatted as `Tool Result:`.
    Tool,
    /// Document content; formatted as `Document:`.
    Document,
    /// Concept entry; formatted as `Concept:`.
    Concept,
    /// Memory entry; formatted as `Memory:`.
    Memory,
    /// Task description; formatted as `Task:`.
    Task,
    /// Lesson entry; formatted as `Lesson:`.
    Lesson,
}
/// Optional descriptive metadata attached to a [`ContextItem`].
///
/// Of these fields, only `pinned` is read by the context-management code in this module; the
/// others are carried along unchanged.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ContextMetadata {
    /// Free-form origin of the item (not interpreted here).
    pub source: Option<String>,
    /// Identifier of the originating document, if any (not interpreted here).
    pub document_id: Option<String>,
    /// Identifiers of related concepts (not interpreted here).
    pub concept_ids: Vec<String>,
    /// Free-form tags (not interpreted here).
    pub tags: Vec<String>,
    /// Optional quality rating. NOTE(review): its scale is not defined in this module.
    pub quality_score: Option<f64>,
    /// When `true`, `AgentContext::clear` keeps the item and token-budget eviction skips it.
    pub pinned: bool,
}
/// Bounded per-agent working context.
///
/// Holds a sequence of [`ContextItem`]s kept within a token budget (`max_tokens`) and an item
/// budget (`max_items`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentContext {
    /// Agent that owns this context.
    pub agent_id: AgentId,
    /// Held items; `add_item` appends new items at the back.
    pub items: VecDeque<ContextItem>,
    /// Token budget; `add_item` evicts unpinned items when adding would exceed it.
    pub max_tokens: u64,
    /// Sum of `token_count` over `items`, maintained by the mutating methods.
    /// NOTE(review): the field is public, so external writes can desynchronize it.
    pub current_tokens: u64,
    /// Item budget; exceeding it makes `add_item` apply `strategy`.
    pub max_items: usize,
    /// Eviction policy applied when `items.len()` exceeds `max_items`.
    pub strategy: ContextStrategy,
    /// Set to the current time by `add_item`, `remove_item` and `clear`.
    pub last_updated: DateTime<Utc>,
}
impl AgentContext {
pub fn new(agent_id: AgentId, max_tokens: u64, max_items: usize) -> Self {
Self {
agent_id,
items: VecDeque::new(),
max_tokens,
current_tokens: 0,
max_items,
strategy: ContextStrategy::RelevanceFirst,
last_updated: Utc::now(),
}
}
pub fn add_item(&mut self, item: ContextItem) -> MultiAgentResult<()> {
if self.current_tokens + item.token_count > self.max_tokens {
self.make_space(item.token_count)?;
}
self.current_tokens += item.token_count;
self.items.push_back(item);
self.last_updated = Utc::now();
if self.items.len() > self.max_items {
self.apply_strategy()?;
}
Ok(())
}
pub fn add_items(&mut self, items: Vec<ContextItem>) -> MultiAgentResult<()> {
for item in items {
self.add_item(item)?;
}
Ok(())
}
pub fn remove_item(&mut self, item_id: Uuid) -> MultiAgentResult<ContextItem> {
let position = self
.items
.iter()
.position(|item| item.id == item_id)
.ok_or_else(|| MultiAgentError::ContextError(format!("Item {} not found", item_id)))?;
let item = self.items.remove(position).unwrap();
self.current_tokens -= item.token_count;
self.last_updated = Utc::now();
Ok(item)
}
pub fn clear(&mut self) {
let pinned_items: VecDeque<ContextItem> = self
.items
.drain(..)
.filter(|item| item.metadata.pinned)
.collect();
self.current_tokens = pinned_items.iter().map(|item| item.token_count).sum();
self.items = pinned_items;
self.last_updated = Utc::now();
}
pub fn get_items_by_type(&self, item_type: ContextItemType) -> Vec<&ContextItem> {
self.items
.iter()
.filter(|item| item.item_type == item_type)
.collect()
}
pub fn get_relevant_items(&self, max_tokens: u64) -> Vec<&ContextItem> {
let mut items: Vec<&ContextItem> = self.items.iter().collect();
items.sort_by(|a, b| b.relevance_score.partial_cmp(&a.relevance_score).unwrap());
let mut selected_items = Vec::new();
let mut token_count = 0;
for item in items {
if token_count + item.token_count <= max_tokens {
selected_items.push(item);
token_count += item.token_count;
}
}
selected_items
}
pub fn get_items_by_relevance(
&self,
threshold: f64,
limit: Option<usize>,
) -> Vec<&ContextItem> {
let mut items: Vec<&ContextItem> = self
.items
.iter()
.filter(|item| item.relevance_score >= threshold)
.collect();
items.sort_by(|a, b| b.relevance_score.partial_cmp(&a.relevance_score).unwrap());
if let Some(limit) = limit {
items.truncate(limit);
}
items
}
pub fn format_for_llm(&self) -> String {
let mut formatted = String::new();
for item in &self.items {
match item.item_type {
ContextItemType::System => {
formatted.push_str(&format!("System: {}\n\n", item.content));
}
ContextItemType::User => {
formatted.push_str(&format!("User: {}\n\n", item.content));
}
ContextItemType::Assistant => {
formatted.push_str(&format!("Assistant: {}\n\n", item.content));
}
ContextItemType::Document => {
formatted.push_str(&format!("Document: {}\n\n", item.content));
}
ContextItemType::Concept => {
formatted.push_str(&format!("Concept: {}\n\n", item.content));
}
ContextItemType::Memory => {
formatted.push_str(&format!("Memory: {}\n\n", item.content));
}
ContextItemType::Task => {
formatted.push_str(&format!("Task: {}\n\n", item.content));
}
ContextItemType::Lesson => {
formatted.push_str(&format!("Lesson: {}\n\n", item.content));
}
ContextItemType::Tool => {
formatted.push_str(&format!("Tool Result: {}\n\n", item.content));
}
}
}
formatted
}
fn make_space(&mut self, needed_tokens: u64) -> MultiAgentResult<()> {
let mut tokens_to_free =
needed_tokens.saturating_sub(self.max_tokens - self.current_tokens);
while tokens_to_free > 0 && !self.items.is_empty() {
let (index, _) = self
.items
.iter()
.enumerate()
.filter(|(_, item)| !item.metadata.pinned)
.min_by(|(_, a), (_, b)| a.relevance_score.partial_cmp(&b.relevance_score).unwrap())
.ok_or_else(|| {
MultiAgentError::ContextError("No removable items found".to_string())
})?;
let removed_item = self.items.remove(index).unwrap();
self.current_tokens -= removed_item.token_count;
tokens_to_free = tokens_to_free.saturating_sub(removed_item.token_count);
}
Ok(())
}
fn apply_strategy(&mut self) -> MultiAgentResult<()> {
match self.strategy {
ContextStrategy::RelevanceFirst => {
let mut items: Vec<ContextItem> = self.items.drain(..).collect();
items.sort_by(|a, b| b.relevance_score.partial_cmp(&a.relevance_score).unwrap());
self.items = items.into_iter().take(self.max_items).collect();
}
ContextStrategy::ChronologicalRecent => {
while self.items.len() > self.max_items {
if let Some(item) = self.items.pop_front() {
if !item.metadata.pinned {
self.current_tokens -= item.token_count;
} else {
self.items.push_front(item);
let pos = self.items.iter().position(|item| !item.metadata.pinned);
if let Some(pos) = pos {
let removed = self.items.remove(pos).unwrap();
self.current_tokens -= removed.token_count;
} else {
break; }
}
}
}
}
ContextStrategy::Balanced => {
self.apply_balanced_strategy()?;
}
}
self.current_tokens = self.items.iter().map(|item| item.token_count).sum();
self.last_updated = Utc::now();
Ok(())
}
fn apply_balanced_strategy(&mut self) -> MultiAgentResult<()> {
if self.items.len() <= self.max_items {
return Ok(());
}
let items: Vec<ContextItem> = self.items.drain(..).collect();
let pinned: Vec<ContextItem> = items
.iter()
.filter(|item| item.metadata.pinned)
.cloned()
.collect();
let mut non_pinned: Vec<ContextItem> = items
.into_iter()
.filter(|item| !item.metadata.pinned)
.collect();
let available_slots = self.max_items.saturating_sub(pinned.len());
if non_pinned.len() <= available_slots {
self.items = pinned.into_iter().chain(non_pinned).collect();
} else {
non_pinned.sort_by(|a, b| b.relevance_score.partial_cmp(&a.relevance_score).unwrap());
let relevance_count = (available_slots as f64 * 0.7) as usize;
let recency_count = available_slots - relevance_count;
let mut selected = Vec::new();
selected.extend(non_pinned.iter().take(relevance_count).cloned());
let remaining: Vec<ContextItem> =
non_pinned.into_iter().skip(relevance_count).collect();
let mut recent = remaining;
recent.sort_by_key(|i| std::cmp::Reverse(i.added_at));
selected.extend(recent.into_iter().take(recency_count));
self.items = pinned.into_iter().chain(selected).collect();
}
Ok(())
}
}
/// Policy an [`AgentContext`] uses to shed items once it holds more than `max_items`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContextStrategy {
    /// Retain the items with the highest `relevance_score`.
    RelevanceFirst,
    /// Drop the oldest unpinned items first.
    ChronologicalRecent,
    /// Keep pinned items. Fill roughly 70% of the remaining slots by relevance and the rest
    /// with the most recently added items.
    Balanced,
}
/// Point-in-time copy of an [`AgentContext`]'s items and token total, produced by
/// `ContextSnapshot::from_context`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSnapshot {
    /// Unique snapshot id; `from_context` generates a random v4 UUID.
    pub id: Uuid,
    /// Agent whose context was captured.
    pub agent_id: AgentId,
    /// Capture time; `from_context` uses `Utc::now()`.
    pub timestamp: DateTime<Utc>,
    /// Cloned items, in the order the context held them.
    pub items: Vec<ContextItem>,
    /// Copy of the context's `current_tokens` at capture time.
    pub token_count: u64,
    /// What caused the snapshot to be taken.
    pub trigger: SnapshotTrigger,
}
impl ContextSnapshot {
pub fn from_context(context: &AgentContext, trigger: SnapshotTrigger) -> Self {
Self {
id: Uuid::new_v4(),
agent_id: context.agent_id,
timestamp: Utc::now(),
items: context.items.iter().cloned().collect(),
token_count: context.current_tokens,
trigger,
}
}
}
/// Reason a [`ContextSnapshot`] was taken.
///
/// NOTE(review): variant meanings below are inferred from their names. The code that chooses
/// a trigger lives outside this module, so confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SnapshotTrigger {
    /// Explicitly requested (presumably by a user or operator).
    Manual,
    /// Presumably taken just before the context is modified.
    PreChange,
    /// Presumably taken when a task finishes.
    TaskComplete,
    /// Presumably taken on a schedule.
    Periodic,
    /// Presumably taken before items are cleared or evicted.
    PreCleanup,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unpinned item with default metadata.
    fn make_item(kind: ContextItemType, text: &str, tokens: u64, score: f64) -> ContextItem {
        ContextItem::new(kind, text.to_string(), tokens, score)
    }

    /// Fresh context owned by a random agent.
    fn make_context(max_tokens: u64, max_items: usize) -> AgentContext {
        AgentContext::new(AgentId::new_v4(), max_tokens, max_items)
    }

    #[test]
    fn test_context_item_creation() {
        let created = make_item(ContextItemType::User, "Hello world", 10, 0.8);
        assert_eq!(created.item_type, ContextItemType::User);
        assert_eq!(created.content, "Hello world");
        assert_eq!(created.token_count, 10);
        assert_eq!(created.relevance_score, 0.8);
    }

    #[test]
    fn test_agent_context() {
        let mut ctx = make_context(100, 10);
        ctx.add_item(make_item(ContextItemType::User, "Test message", 20, 0.9))
            .unwrap();
        assert_eq!(ctx.items.len(), 1);
        assert_eq!(ctx.current_tokens, 20);
    }

    #[test]
    fn test_context_token_limit() {
        let mut ctx = make_context(50, 10);
        ctx.add_item(make_item(ContextItemType::User, "Large message", 40, 0.9))
            .unwrap();
        // 40 + 30 exceeds the 50-token budget, so the first item must be evicted.
        ctx.add_item(make_item(ContextItemType::User, "Another message", 30, 0.8))
            .unwrap();
        assert!(ctx.current_tokens <= ctx.max_tokens);
    }

    #[test]
    fn test_pinned_items() {
        let mut ctx = make_context(100, 2);
        let pinned = ContextMetadata {
            pinned: true,
            ..ContextMetadata::default()
        };
        let system = make_item(ContextItemType::System, "System prompt", 30, 1.0)
            .with_metadata(pinned);
        ctx.add_item(system).unwrap();
        ctx.add_item(make_item(ContextItemType::User, "Message 1", 20, 0.5))
            .unwrap();
        // A third item exceeds max_items = 2 and triggers the eviction strategy.
        ctx.add_item(make_item(ContextItemType::User, "Message 2", 20, 0.6))
            .unwrap();
        assert_eq!(ctx.items.len(), 2);
        assert!(ctx.items.iter().any(|held| held.metadata.pinned));
    }

    #[test]
    fn test_context_formatting() {
        let mut ctx = make_context(100, 10);
        ctx.add_item(make_item(ContextItemType::User, "Hello", 5, 0.9))
            .unwrap();
        ctx.add_item(make_item(ContextItemType::Assistant, "Hi there!", 8, 0.9))
            .unwrap();
        let transcript = ctx.format_for_llm();
        assert!(transcript.contains("User: Hello"));
        assert!(transcript.contains("Assistant: Hi there!"));
    }
}