use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use crate::error::{Result, RuvLLMError};
use super::agentic_memory::{AgenticMemory, AgenticMemoryConfig, MemoryType, RetrievedMemory};
use super::semantic_cache::{SemanticCacheConfig, SemanticToolCache};
/// Known context-window sizes for supported models, plus a custom override.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelTokenLimit {
    Haiku,
    Sonnet,
    Opus,
    Custom(usize),
}

impl ModelTokenLimit {
    /// Maximum number of tokens the model's context window can hold.
    pub fn max_tokens(&self) -> usize {
        match self {
            // All three named tiers currently share a 200k window.
            Self::Haiku | Self::Sonnet | Self::Opus => 200_000,
            Self::Custom(limit) => *limit,
        }
    }

    /// Usable budget for assembled context: 80% of the full window,
    /// leaving headroom for the model's own generation.
    pub fn context_budget(&self) -> usize {
        (self.max_tokens() as f32 * 0.8) as usize
    }
}
/// Configuration for [`IntelligentContextManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextManagerConfig {
    /// Settings for the backing agentic memory store.
    pub memory: AgenticMemoryConfig,
    /// Settings for the semantic tool-result cache.
    pub cache: SemanticCacheConfig,
    /// Model whose token limit is used when `prepare_context` is given `None`.
    pub default_model: ModelTokenLimit,
    /// Average characters per token, used for rough token estimation.
    pub chars_per_token: f32,
    /// Cap on memories requested per retrieval (passed to the memory lookup).
    pub max_elements: usize,
    /// Retrieved memories scoring below this relevance are dropped.
    pub min_relevance: f32,
    /// Whether over-budget high-priority elements may be summarized to fit.
    pub enable_summarization: bool,
    /// Target compression ratio handed to the summarizer.
    pub summarization_ratio: f32,
    /// Weighting used by the priority scorer.
    pub priority_weights: PriorityWeights,
}
/// Relative weights for the four components of an element's priority score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PriorityWeights {
    /// Weight of the exponential time-decay recency component.
    pub recency: f32,
    /// Weight of the element's relevance score.
    pub relevance: f32,
    /// Weight of the importance flag (scored 1.0 if important, else 0.5).
    pub importance: f32,
    /// Weight of the log-scaled access-frequency component.
    pub frequency: f32,
}
impl Default for PriorityWeights {
fn default() -> Self {
Self {
recency: 0.3,
relevance: 0.4,
importance: 0.2,
frequency: 0.1,
}
}
}
impl Default for ContextManagerConfig {
fn default() -> Self {
Self {
memory: AgenticMemoryConfig::default(),
cache: SemanticCacheConfig::default(),
default_model: ModelTokenLimit::Sonnet,
chars_per_token: 4.0,
max_elements: 100,
min_relevance: 0.1,
enable_summarization: true,
summarization_ratio: 0.5,
priority_weights: PriorityWeights::default(),
}
}
}
/// A single candidate piece of context (message, memory, cached result, …)
/// with the bookkeeping needed for priority ranking and budget packing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextElement {
    /// Stable identifier (e.g. "msg-0", a memory id, "cache-<tool>").
    pub id: String,
    /// Where this element came from.
    pub element_type: ContextElementType,
    /// Text that will be placed into the prompt.
    pub content: String,
    /// Rough token count derived from content length.
    pub estimated_tokens: usize,
    /// Ranking score in [0, 1]; 0.0 doubles as a "not yet scored" sentinel.
    pub priority: f32,
    /// Relevance to the current query.
    pub relevance: f32,
    /// Age in seconds (negative if the source timestamp is in the future).
    pub recency_seconds: i64,
    /// Feeds the scorer's importance component (scored 1.0 vs 0.5).
    pub is_important: bool,
    /// How many times this element has been accessed.
    pub access_count: u64,
    /// Free-form metadata carried along from the source.
    pub metadata: HashMap<String, String>,
}
/// Source category of a [`ContextElement`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ContextElementType {
    /// System prompt message.
    System,
    /// User message.
    User,
    /// Assistant (model) message.
    Assistant,
    /// Tool invocation output.
    Tool,
    /// Retrieved long-term memory.
    Memory,
    /// File content.
    File,
    /// Semantically-cached tool result.
    Cached,
}
/// Coarse priority bucket derived from a numeric score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ElementPriority {
    Critical,
    High,
    Medium,
    Low,
    Optional,
}

impl ElementPriority {
    /// Numeric weight for each bucket, stepping down from 1.0 (Critical)
    /// to 0.2 (Optional) in increments of 0.2.
    pub fn value(&self) -> f32 {
        match *self {
            ElementPriority::Critical => 1.0,
            ElementPriority::High => 0.8,
            ElementPriority::Medium => 0.6,
            ElementPriority::Low => 0.4,
            ElementPriority::Optional => 0.2,
        }
    }
}
/// Computes weighted priority scores for context elements and maps them
/// onto coarse priority buckets.
pub struct PriorityScorer {
    weights: PriorityWeights,
}

impl PriorityScorer {
    /// Builds a scorer using the supplied weighting scheme.
    pub fn new(weights: PriorityWeights) -> Self {
        Self { weights }
    }

    /// Weighted blend of four signals, clamped into [0, 1]:
    /// recency (exponential decay, one-day time constant), relevance
    /// (taken as-is), importance (1.0 / 0.5 flag) and frequency
    /// (log of access count, capped at 1.0).
    pub fn score(&self, element: &ContextElement) -> f32 {
        let w = &self.weights;
        let recency = (-element.recency_seconds as f32 / 86400.0).exp();
        let importance = if element.is_important { 1.0 } else { 0.5 };
        let frequency = ((element.access_count as f32 + 1.0).ln() / 10.0).min(1.0);
        let total = w.recency * recency
            + w.relevance * element.relevance
            + w.importance * importance
            + w.frequency * frequency;
        total.min(1.0).max(0.0)
    }

    /// Buckets a score: >= 0.9 Critical, >= 0.7 High, >= 0.5 Medium,
    /// >= 0.3 Low, otherwise Optional.
    pub fn assign_priority(&self, score: f32) -> ElementPriority {
        match () {
            _ if score >= 0.9 => ElementPriority::Critical,
            _ if score >= 0.7 => ElementPriority::High,
            _ if score >= 0.5 => ElementPriority::Medium,
            _ if score >= 0.3 => ElementPriority::Low,
            _ => ElementPriority::Optional,
        }
    }
}
/// Compresses oversized text so it fits a token budget, keeping the head
/// (and, for larger budgets, the tail) of the content.
pub struct MemorySummarizer {
    /// Fraction of the character budget the summary should occupy.
    target_ratio: f32,
}

impl MemorySummarizer {
    /// Creates a summarizer targeting `target_ratio` of the allowed size.
    pub fn new(target_ratio: f32) -> Self {
        Self { target_ratio }
    }

    /// Largest index `<= idx` that lies on a UTF-8 char boundary of `s`.
    ///
    /// Slicing `&str` at an arbitrary byte offset panics when the offset
    /// falls inside a multi-byte character, so all truncation below goes
    /// through this (or `ceil_char_boundary`).
    fn floor_char_boundary(s: &str, idx: usize) -> usize {
        if idx >= s.len() {
            return s.len();
        }
        let mut i = idx;
        while !s.is_char_boundary(i) {
            i -= 1;
        }
        i
    }

    /// Smallest index `>= idx` that lies on a UTF-8 char boundary of `s`.
    fn ceil_char_boundary(s: &str, idx: usize) -> usize {
        let mut i = idx;
        while i < s.len() && !s.is_char_boundary(i) {
            i += 1;
        }
        i.min(s.len())
    }

    /// Truncates `content` to roughly `max_tokens * chars_per_token *
    /// target_ratio` characters. Small targets keep only the head; larger
    /// targets keep 2/3 head and 1/3 tail around a `[truncated]` marker.
    /// Content already within budget is returned unchanged.
    pub fn summarize(&self, content: &str, max_tokens: usize, chars_per_token: f32) -> String {
        let max_chars = (max_tokens as f32 * chars_per_token) as usize;
        if content.len() <= max_chars {
            return content.to_string();
        }
        let target_len = (max_chars as f32 * self.target_ratio) as usize;
        if target_len < 100 {
            // Head-only truncation, snapped to a char boundary so slicing
            // cannot panic on multi-byte UTF-8 (bug fix).
            let cut = Self::floor_char_boundary(content, target_len.min(content.len()));
            format!("{}...", &content[..cut])
        } else {
            let keep_start = target_len * 2 / 3;
            let keep_end = target_len / 3;
            // Boundary-safe head and tail cuts (bug fix: the previous byte
            // slicing could panic on multi-byte UTF-8 content).
            let start_cut = Self::floor_char_boundary(content, keep_start.min(content.len()));
            let start = &content[..start_cut];
            let end_start =
                Self::ceil_char_boundary(content, content.len().saturating_sub(keep_end));
            let end = &content[end_start..];
            format!("{}...[truncated]...{}", start, end)
        }
    }

    /// Joins `memories` into a single summary, giving each an equal share
    /// of the character budget and separating entries with `---` dividers.
    /// Stops early once the budget is reached.
    pub fn summarize_memories(
        &self,
        memories: &[RetrievedMemory],
        max_tokens: usize,
        chars_per_token: f32,
    ) -> String {
        let max_chars = (max_tokens as f32 * chars_per_token) as usize;
        let mut summary = String::with_capacity(max_chars);
        let chars_per_memory = max_chars / memories.len().max(1);
        for (i, mem) in memories.iter().enumerate() {
            let mem_summary = if mem.content.len() > chars_per_memory {
                // Char-boundary-safe truncation (bug fix: byte slicing could
                // panic on multi-byte UTF-8).
                let cut = Self::floor_char_boundary(&mem.content, chars_per_memory);
                format!("{}...", &mem.content[..cut])
            } else {
                mem.content.clone()
            };
            if i > 0 {
                summary.push_str("\n---\n");
            }
            summary.push_str(&format!("[{}] {}", mem.id, mem_summary));
            if summary.len() >= max_chars {
                break;
            }
        }
        summary
    }
}
/// Result of `prepare_context`: the packed elements plus packing statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreparedContext {
    /// Included elements, ordered by descending priority.
    pub elements: Vec<ContextElement>,
    /// Sum of the included elements' estimated tokens.
    pub total_tokens: usize,
    /// Fraction of the model's context budget consumed.
    pub budget_used: f32,
    /// Number of elements that were summarized to fit.
    pub summarized_count: usize,
    /// Number of elements dropped for lack of budget.
    pub excluded_count: usize,
    /// Wall-clock preparation time in microseconds.
    pub preparation_time_us: u64,
}
// Implementing Display (instead of the previous inherent `to_string`)
// follows clippy's `inherent_to_string` guidance; callers still get
// `.to_string()` for free via the `ToString` blanket impl.
impl std::fmt::Display for PreparedContext {
    /// Renders the prepared context as the element contents joined by
    /// blank lines — the final prompt text handed to the model.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut first = true;
        for element in &self.elements {
            if !first {
                f.write_str("\n\n")?;
            }
            f.write_str(&element.content)?;
            first = false;
        }
        Ok(())
    }
}

impl PreparedContext {
    /// Returns references to all included elements of the given type.
    pub fn get_by_type(&self, element_type: ContextElementType) -> Vec<&ContextElement> {
        self.elements
            .iter()
            .filter(|e| e.element_type == element_type)
            .collect()
    }
}
/// Point-in-time snapshot of the manager's usage counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextManagerStats {
    /// Number of `prepare_context` calls so far.
    pub total_preparations: u64,
    /// Mean tokens per preparation (0 before any preparation).
    pub avg_tokens: u64,
    /// Mean preparation time in microseconds (0 before any preparation).
    pub avg_preparation_time_us: u64,
    /// Elements summarized to fit the budget.
    pub summarizations: u64,
    /// Semantic tool-cache hits.
    pub cache_hits: u64,
    /// Memory retrieval operations performed.
    pub memory_retrievals: u64,
}
/// Orchestrates context assembly: ranks conversation messages, retrieved
/// memories and cached tool results by priority and packs them into a
/// model-specific token budget.
pub struct IntelligentContextManager {
    /// Behavior knobs: model default, thresholds, scorer weights.
    config: ContextManagerConfig,
    /// Long-term agentic memory backend.
    memory: AgenticMemory,
    /// Embedding-keyed cache of tool results.
    cache: SemanticToolCache,
    /// Scores elements from recency/relevance/importance/frequency.
    scorer: PriorityScorer,
    /// Truncation-based summarizer for over-budget elements.
    summarizer: MemorySummarizer,
    /// Lock-free usage counters.
    stats: ContextManagerStatsInternal,
}
/// Atomic counters backing [`ContextManagerStats`]; all accesses use
/// `Ordering::SeqCst`.
#[derive(Debug, Default)]
struct ContextManagerStatsInternal {
    /// Number of `prepare_context` calls.
    preparations: AtomicU64,
    /// Total tokens across all preparations (for averaging).
    total_tokens: AtomicU64,
    /// Total preparation time in microseconds (for averaging).
    total_time_us: AtomicU64,
    /// Elements summarized to fit the budget.
    summarizations: AtomicU64,
    /// Semantic cache hits.
    cache_hits: AtomicU64,
    /// Memory retrieval operations performed.
    memory_retrievals: AtomicU64,
}
impl IntelligentContextManager {
    /// Creates a manager, constructing the agentic memory, semantic tool
    /// cache, priority scorer and summarizer from `config`.
    ///
    /// # Errors
    /// Propagates failures from `AgenticMemory::new` / `SemanticToolCache::new`.
    pub fn new(config: ContextManagerConfig) -> Result<Self> {
        let memory = AgenticMemory::new(config.memory.clone())?;
        let cache = SemanticToolCache::new(config.cache.clone())?;
        let scorer = PriorityScorer::new(config.priority_weights.clone());
        let summarizer = MemorySummarizer::new(config.summarization_ratio);
        Ok(Self {
            config,
            memory,
            cache,
            scorer,
            summarizer,
            stats: ContextManagerStatsInternal::default(),
        })
    }

    /// Assembles a token-budgeted context from conversation `messages`,
    /// plus — when `query_embedding` is given — relevant memories and a
    /// cached tool result.
    ///
    /// Pipeline:
    /// 1. Messages become elements (system messages pinned at priority 1.0).
    /// 2. Memories above `min_relevance` and a cache hit are appended.
    /// 3. Unscored elements are scored; all are sorted by descending priority.
    /// 4. Elements are greedily packed into the model's context budget,
    ///    optionally summarizing high-priority overflow; the rest is excluded.
    ///
    /// # Errors
    /// Propagates memory-retrieval or cache-lookup failures.
    pub fn prepare_context(
        &self,
        messages: &[Message],
        query_embedding: Option<&[f32]>,
        model: Option<ModelTokenLimit>,
    ) -> Result<PreparedContext> {
        let start = std::time::Instant::now();
        self.stats.preparations.fetch_add(1, Ordering::SeqCst);
        let model = model.unwrap_or(self.config.default_model);
        let budget = model.context_budget();
        let mut elements: Vec<ContextElement> = Vec::new();
        let now = Utc::now();
        // 1. Convert conversation messages into context elements.
        for (i, msg) in messages.iter().enumerate() {
            let element_type = match msg.role {
                MessageRole::System => ContextElementType::System,
                MessageRole::User => ContextElementType::User,
                MessageRole::Assistant => ContextElementType::Assistant,
            };
            let estimated_tokens = self.estimate_tokens(&msg.content);
            let recency = (now - msg.timestamp).num_seconds();
            let element = ContextElement {
                id: format!("msg-{}", i),
                element_type,
                content: msg.content.clone(),
                estimated_tokens,
                // System prompts get top priority; other messages start high
                // (0.8) but below system. These fixed priorities skip the
                // scorer (only the 0.0 sentinel is re-scored below).
                priority: if element_type == ContextElementType::System {
                    1.0
                } else {
                    0.8
                },
                relevance: 1.0, recency_seconds: recency,
                is_important: element_type == ContextElementType::System,
                access_count: 1,
                metadata: HashMap::new(),
            };
            elements.push(element);
        }
        // 2a. Pull memories relevant to the query embedding.
        if let Some(embedding) = query_embedding {
            self.stats.memory_retrievals.fetch_add(1, Ordering::SeqCst);
            let memories = self
                .memory
                .get_relevant(embedding, self.config.max_elements)?;
            for mem in memories {
                // Drop weakly-related memories below the configured floor.
                if mem.score < self.config.min_relevance {
                    continue;
                }
                let estimated_tokens = self.estimate_tokens(&mem.content);
                let element = ContextElement {
                    id: mem.id.clone(),
                    element_type: ContextElementType::Memory,
                    content: mem.content,
                    estimated_tokens,
                    // priority 0.0 is the "score me later" sentinel;
                    // recency is a fixed 1-hour placeholder, not real age.
                    priority: 0.0, relevance: mem.score,
                    recency_seconds: 3600, is_important: false,
                    access_count: 1,
                    metadata: mem.metadata,
                };
                elements.push(element);
            }
        }
        // 2b. Include a semantically-cached tool result, if one matches.
        if let Some(embedding) = query_embedding {
            if let Some(cached) = self.cache.get(embedding)? {
                self.stats.cache_hits.fetch_add(1, Ordering::SeqCst);
                let estimated_tokens = self.estimate_tokens(&cached.result);
                let element = ContextElement {
                    id: format!("cache-{}", cached.tool_name),
                    element_type: ContextElementType::Cached,
                    content: format!("[Cached {}] {}", cached.tool_name, cached.result),
                    estimated_tokens,
                    priority: 0.7,
                    relevance: cached.similarity,
                    recency_seconds: (now - cached.cached_at).num_seconds(),
                    is_important: false,
                    access_count: cached.access_count,
                    metadata: HashMap::new(),
                };
                elements.push(element);
            }
        }
        // 3. Score sentinel-priority elements. NOTE(review): an exact float
        // compare marks "unscored"; a genuinely computed 0.0 would also be
        // re-scored, which is harmless since re-scoring is deterministic.
        for element in &mut elements {
            if element.priority == 0.0 {
                element.priority = self.scorer.score(element);
            }
        }
        // Sort highest-priority first; NaN priorities compare as equal.
        elements.sort_by(|a, b| {
            b.priority
                .partial_cmp(&a.priority)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        // 4. Greedy fill against the token budget.
        let mut total_tokens = 0usize;
        let mut included = Vec::new();
        let mut summarized_count = 0usize;
        let mut excluded_count = 0usize;
        for element in elements {
            if total_tokens + element.estimated_tokens <= budget {
                total_tokens += element.estimated_tokens;
                included.push(element);
            } else if self.config.enable_summarization && element.priority > 0.5 {
                // Element overflows: try to squeeze a high-priority one into
                // whatever budget is left (only if >50 tokens remain).
                let remaining_budget = budget - total_tokens;
                if remaining_budget > 50 {
                    let summarized_content = self.summarizer.summarize(
                        &element.content,
                        remaining_budget,
                        self.config.chars_per_token,
                    );
                    let summarized_tokens = self.estimate_tokens(&summarized_content);
                    if summarized_tokens <= remaining_budget {
                        let mut summarized_element = element;
                        summarized_element.content = summarized_content;
                        summarized_element.estimated_tokens = summarized_tokens;
                        total_tokens += summarized_tokens;
                        included.push(summarized_element);
                        summarized_count += 1;
                        self.stats.summarizations.fetch_add(1, Ordering::SeqCst);
                    } else {
                        excluded_count += 1;
                    }
                } else {
                    excluded_count += 1;
                }
            } else {
                excluded_count += 1;
            }
        }
        let elapsed = start.elapsed().as_micros() as u64;
        self.stats
            .total_tokens
            .fetch_add(total_tokens as u64, Ordering::SeqCst);
        self.stats
            .total_time_us
            .fetch_add(elapsed, Ordering::SeqCst);
        Ok(PreparedContext {
            elements: included,
            total_tokens,
            budget_used: total_tokens as f32 / budget as f32,
            summarized_count,
            excluded_count,
            preparation_time_us: elapsed,
        })
    }

    /// Shared access to the underlying agentic memory.
    pub fn memory(&self) -> &AgenticMemory {
        &self.memory
    }

    /// Exclusive access to the underlying agentic memory.
    pub fn memory_mut(&mut self) -> &mut AgenticMemory {
        &mut self.memory
    }

    /// Shared access to the semantic tool cache.
    pub fn cache(&self) -> &SemanticToolCache {
        &self.cache
    }

    /// Stores `content` under `key` in agentic memory; returns the id
    /// assigned by the memory store.
    ///
    /// # Errors
    /// Propagates any storage failure from the memory backend.
    pub fn store_memory(
        &self,
        key: &str,
        content: &str,
        embedding: Vec<f32>,
        memory_type: MemoryType,
    ) -> Result<String> {
        self.memory.store(key, content, embedding, memory_type)
    }

    /// Caches a tool invocation result keyed by its semantic embedding.
    ///
    /// # Errors
    /// Propagates any storage failure from the cache backend.
    pub fn cache_tool_result(
        &self,
        tool_name: &str,
        input: &str,
        result: &str,
        embedding: Vec<f32>,
    ) -> Result<()> {
        self.cache.store(tool_name, input, result, embedding)
    }

    /// Snapshot of usage counters; averages are 0 before any preparation.
    pub fn stats(&self) -> ContextManagerStats {
        let preps = self.stats.preparations.load(Ordering::SeqCst);
        let total_tokens = self.stats.total_tokens.load(Ordering::SeqCst);
        let total_time = self.stats.total_time_us.load(Ordering::SeqCst);
        ContextManagerStats {
            total_preparations: preps,
            avg_tokens: total_tokens.checked_div(preps).unwrap_or(0),
            avg_preparation_time_us: total_time.checked_div(preps).unwrap_or(0),
            summarizations: self.stats.summarizations.load(Ordering::SeqCst),
            cache_hits: self.stats.cache_hits.load(Ordering::SeqCst),
            memory_retrievals: self.stats.memory_retrievals.load(Ordering::SeqCst),
        }
    }

    /// Rough token estimate from byte length and the configured
    /// characters-per-token ratio, rounded up.
    fn estimate_tokens(&self, content: &str) -> usize {
        (content.len() as f32 / self.config.chars_per_token).ceil() as usize
    }
}
/// A single conversation turn.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced the message.
    pub role: MessageRole,
    /// Message text.
    pub content: String,
    /// When the message was created (UTC); drives recency scoring.
    pub timestamp: DateTime<Utc>,
}
/// Author of a [`Message`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    System,
    User,
    Assistant,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: shrink every embedding dimension to 128.
    fn test_config() -> ContextManagerConfig {
        let memory = AgenticMemoryConfig {
            semantic_dim: 128,
            episodic: super::super::episodic_memory::EpisodicMemoryConfig {
                embedding_dim: 128,
                ..Default::default()
            },
            ..Default::default()
        };
        let cache = SemanticCacheConfig {
            embedding_dim: 128,
            ..Default::default()
        };
        ContextManagerConfig {
            memory,
            cache,
            ..Default::default()
        }
    }

    #[test]
    fn test_context_manager_creation() {
        let mgr = IntelligentContextManager::new(test_config()).unwrap();
        assert_eq!(mgr.stats().total_preparations, 0);
    }

    #[test]
    fn test_prepare_context_basic() {
        let mgr = IntelligentContextManager::new(test_config()).unwrap();
        let system = Message {
            role: MessageRole::System,
            content: "You are a helpful assistant.".to_string(),
            timestamp: Utc::now(),
        };
        let user = Message {
            role: MessageRole::User,
            content: "Hello!".to_string(),
            timestamp: Utc::now(),
        };
        let prepared = mgr.prepare_context(&[system, user], None, None).unwrap();
        assert_eq!(prepared.elements.len(), 2);
        assert!(prepared.total_tokens > 0);
        assert!(prepared.budget_used < 1.0);
    }

    #[test]
    fn test_prepare_context_with_memory() {
        let mgr = IntelligentContextManager::new(test_config()).unwrap();
        let embedding = vec![0.1; 128];
        mgr.store_memory("fact-1", "Test fact", embedding.clone(), MemoryType::Semantic)
            .unwrap();
        let msgs = [Message {
            role: MessageRole::User,
            content: "Tell me about the test.".to_string(),
            timestamp: Utc::now(),
        }];
        let prepared = mgr.prepare_context(&msgs, Some(&embedding), None).unwrap();
        assert!(prepared.elements.len() >= 1);
    }

    #[test]
    fn test_priority_scorer() {
        let scorer = PriorityScorer::new(PriorityWeights::default());
        let element = ContextElement {
            id: "test".to_string(),
            element_type: ContextElementType::Memory,
            content: "Test content".to_string(),
            estimated_tokens: 10,
            priority: 0.0,
            relevance: 0.9,
            recency_seconds: 60,
            is_important: true,
            access_count: 10,
            metadata: HashMap::new(),
        };
        let score = scorer.score(&element);
        assert!(score > 0.5);
        assert!(score <= 1.0);
        let bucket = scorer.assign_priority(score);
        assert!(matches!(
            bucket,
            ElementPriority::High | ElementPriority::Critical
        ));
    }

    #[test]
    fn test_memory_summarizer() {
        let summarizer = MemorySummarizer::new(0.5);
        let long_content = "A".repeat(1000);
        let short = summarizer.summarize(&long_content, 50, 4.0);
        assert!(short.len() < long_content.len());
        assert!(short.contains("..."));
    }

    #[test]
    fn test_model_token_limits() {
        for model in [
            ModelTokenLimit::Haiku,
            ModelTokenLimit::Sonnet,
            ModelTokenLimit::Opus,
        ] {
            assert_eq!(model.max_tokens(), 200_000);
        }
        assert_eq!(ModelTokenLimit::Custom(100_000).max_tokens(), 100_000);
        assert!(ModelTokenLimit::Sonnet.context_budget() < ModelTokenLimit::Sonnet.max_tokens());
    }
}