use std::sync::{Arc, Mutex};
use crate::context::CompactionReport;
use crate::context_transformer::ContextTransformer;
use crate::types::{AgentMessage, LlmMessage};
/// A saved snapshot of messages that were removed from the live context,
/// optionally accompanied by a textual summary of them.
///
/// `VersioningTransformer` creates one of these each time its inner
/// transformer reports dropped messages.
#[derive(Clone, Debug)]
pub struct ContextVersion {
    /// Number under which the snapshot is saved and later loaded from a
    /// [`ContextVersionStore`].
    pub version: u64,
    /// Turn counter value recorded alongside the snapshot.
    pub turn: u64,
    /// Creation time, as returned by `crate::util::now_timestamp` (units are
    /// defined by that helper).
    pub timestamp: u64,
    /// The messages that were dropped from the context.
    pub messages: Vec<LlmMessage>,
    /// Summary of `messages`, present only when a [`ContextSummarizer`] was
    /// configured and produced one.
    pub summary: Option<String>,
}
/// Lightweight description of a stored [`ContextVersion`], as returned by
/// [`ContextVersionStore::list_versions`].
///
/// Carries counts and flags instead of the message bodies, so listing a store
/// does not require cloning every snapshot's messages. All fields are plain
/// scalars, so the type is comparable for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextVersionMeta {
    /// Version number of the described snapshot.
    pub version: u64,
    /// Turn counter value recorded with the snapshot.
    pub turn: u64,
    /// Timestamp recorded when the snapshot was created.
    pub timestamp: u64,
    /// Number of messages held by the snapshot.
    pub message_count: usize,
    /// Whether the snapshot carries a summary.
    pub has_summary: bool,
}
/// Storage backend for [`ContextVersion`] snapshots.
///
/// A store is shared through `Arc<dyn ContextVersionStore>`, hence the
/// `Send + Sync` bound.
pub trait ContextVersionStore: Send + Sync {
    /// Persists `version`.
    fn save_version(&self, version: &ContextVersion);

    /// Returns the snapshot saved under `version`, if any.
    fn load_version(&self, version: u64) -> Option<ContextVersion>;

    /// Returns metadata for every stored snapshot.
    ///
    /// This trait does not mandate an ordering; callers that need the newest
    /// snapshot should use [`latest_version`](Self::latest_version).
    fn list_versions(&self) -> Vec<ContextVersionMeta>;

    /// Returns the snapshot with the highest version number, or `None` when
    /// the store is empty.
    ///
    /// The maximum `version` is selected rather than the last listed entry.
    /// A store whose listing is not sorted, or whose saves arrived out of
    /// order, still reports the newest snapshot. If several entries share the
    /// maximum number, the last of them in listing order is used.
    fn latest_version(&self) -> Option<ContextVersion> {
        self.list_versions()
            .into_iter()
            .max_by_key(|meta| meta.version)
            .and_then(|meta| self.load_version(meta.version))
    }
}
/// Produces a textual summary of messages removed from the context.
///
/// Implementations are shared via `Arc`, hence `Send + Sync`.
pub trait ContextSummarizer: Send + Sync {
    /// Summarizes `messages`; returning `None` means no summary is recorded.
    fn summarize(&self, messages: &[LlmMessage]) -> Option<String>;
}
/// A [`ContextVersionStore`] that keeps every snapshot in a `Vec` guarded by
/// a mutex. Nothing is persisted; contents live only as long as the store.
pub struct InMemoryVersionStore {
    versions: Mutex<Vec<ContextVersion>>,
}

impl InMemoryVersionStore {
    /// Creates a store holding no snapshots.
    #[must_use]
    pub const fn new() -> Self {
        let versions = Mutex::new(Vec::new());
        Self { versions }
    }

    /// Number of snapshots currently held.
    ///
    /// A poisoned mutex is recovered rather than propagated as a panic.
    pub fn len(&self) -> usize {
        match self.versions.lock() {
            Ok(guard) => guard.len(),
            Err(poisoned) => poisoned.into_inner().len(),
        }
    }

    /// Returns `true` when no snapshot has been saved.
    pub fn is_empty(&self) -> bool {
        0 == self.len()
    }
}

impl Default for InMemoryVersionStore {
    /// Equivalent to [`InMemoryVersionStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl ContextVersionStore for InMemoryVersionStore {
    /// Appends a clone of `version`. Existing entries are never replaced, so
    /// several entries may share a version number.
    fn save_version(&self, version: &ContextVersion) {
        let mut guard = self
            .versions
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        guard.push(version.clone());
    }

    /// Returns a clone of the most recently saved snapshot with this number.
    ///
    /// Duplicates are possible, for example when two `VersioningTransformer`s
    /// share one store (each starts numbering at 1). The search therefore runs
    /// from the back, so the newest save wins. This keeps the result consistent
    /// with `list_versions`, whose last entry for a number is also the newest.
    fn load_version(&self, version: u64) -> Option<ContextVersion> {
        let guard = self
            .versions
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        guard.iter().rev().find(|v| v.version == version).cloned()
    }

    /// Lists metadata for every snapshot, in the order they were saved.
    fn list_versions(&self) -> Vec<ContextVersionMeta> {
        let guard = self
            .versions
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        guard
            .iter()
            .map(|v| ContextVersionMeta {
                version: v.version,
                turn: v.turn,
                timestamp: v.timestamp,
                message_count: v.messages.len(),
                has_summary: v.summary.is_some(),
            })
            .collect()
    }
}
/// A [`ContextTransformer`] decorator that records the messages its wrapped
/// transformer drops as [`ContextVersion`] snapshots in a shared store.
pub struct VersioningTransformer {
    inner: Box<dyn ContextTransformer>,
    store: Arc<dyn ContextVersionStore>,
    summarizer: Option<Arc<dyn ContextSummarizer>>,
    state: Mutex<VersioningState>,
}

/// Counters guarded by the transformer's `state` mutex.
struct VersioningState {
    /// Number assigned to the next saved snapshot.
    next_version: u64,
    /// Incremented once per recorded snapshot by `transform`.
    turn_counter: u64,
}

impl VersioningState {
    /// Counters for a transformer that has not recorded anything yet:
    /// numbering starts at 1 and no turns have been counted.
    const fn initial() -> Self {
        Self {
            next_version: 1,
            turn_counter: 0,
        }
    }
}

impl VersioningTransformer {
    /// Wraps `inner` so that messages it drops are saved to `store`.
    ///
    /// No summarizer is attached; add one with
    /// [`with_summarizer`](Self::with_summarizer).
    pub fn new(
        inner: impl ContextTransformer + 'static,
        store: Arc<dyn ContextVersionStore>,
    ) -> Self {
        let boxed: Box<dyn ContextTransformer> = Box::new(inner);
        Self {
            inner: boxed,
            store,
            summarizer: None,
            state: Mutex::new(VersioningState::initial()),
        }
    }

    /// Attaches `summarizer`, replacing any previously attached one, so that
    /// each saved snapshot can carry a summary of its messages.
    #[must_use]
    pub fn with_summarizer(self, summarizer: Arc<dyn ContextSummarizer>) -> Self {
        Self {
            summarizer: Some(summarizer),
            ..self
        }
    }

    /// Borrows the store that snapshots are written to.
    pub fn store(&self) -> &Arc<dyn ContextVersionStore> {
        &self.store
    }
}
impl ContextTransformer for VersioningTransformer {
    /// Runs the wrapped transformer. When it reports dropped messages, those
    /// messages are saved to the store as a new [`ContextVersion`].
    ///
    /// The inner transformer's report is returned unchanged. Nothing is saved
    /// when the inner transformer returns `None` or its report has no dropped
    /// messages.
    ///
    /// Locking:
    /// * The summarizer runs before the state lock is taken. A slow summarizer
    ///   therefore cannot stall other callers waiting to allocate a version.
    /// * The lock is then held across `save_version`. Version numbers thus
    ///   reach the store in the same strictly increasing order they are
    ///   allocated in, even under concurrent use.
    fn transform(
        &self,
        messages: &mut Vec<AgentMessage>,
        overflow: bool,
    ) -> Option<CompactionReport> {
        let report = self.inner.transform(messages, overflow)?;
        if report.dropped_messages.is_empty() {
            return Some(report);
        }

        // Work that does not touch the counters happens outside the lock.
        let summary = self
            .summarizer
            .as_ref()
            .and_then(|s| s.summarize(&report.dropped_messages));
        // The report is returned to the caller, so the snapshot needs its own copy.
        let dropped = report.dropped_messages.clone();

        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        state.turn_counter += 1;
        let version = ContextVersion {
            version: state.next_version,
            turn: state.turn_counter,
            timestamp: crate::util::now_timestamp(),
            messages: dropped,
            summary,
        };
        state.next_version += 1;
        // Save before releasing the lock so concurrent callers cannot deliver
        // a higher version number to the store ahead of this one.
        self.store.save_version(&version);
        drop(state);

        Some(report)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context_transformer::SlidingWindowTransformer;
    use crate::types::{ContentBlock, UserMessage};

    /// Builds a user message containing a single text block.
    fn text_message(text: &str) -> AgentMessage {
        let block = ContentBlock::Text {
            text: text.to_owned(),
        };
        let user = UserMessage {
            content: vec![block],
            timestamp: 0,
            cache_hint: None,
        };
        AgentMessage::Llm(LlmMessage::User(user))
    }

    /// Four identical text messages, the history most tests feed in.
    fn four_messages(body: &str) -> Vec<AgentMessage> {
        (0..4).map(|_| text_message(body)).collect()
    }

    /// A fresh in-memory store, plus a versioning transformer wrapping
    /// `SlidingWindowTransformer::new(250, 100, 1)` that writes into it.
    fn compacting_setup() -> (Arc<dyn ContextVersionStore>, VersioningTransformer) {
        let store: Arc<dyn ContextVersionStore> = Arc::new(InMemoryVersionStore::new());
        let window = SlidingWindowTransformer::new(250, 100, 1);
        let transformer = VersioningTransformer::new(window, Arc::clone(&store));
        (store, transformer)
    }

    #[test]
    fn versioning_captures_dropped_messages() {
        let (store, transformer) = compacting_setup();
        let body = "x".repeat(400);
        let mut messages = four_messages(&body);

        let report = transformer.transform(&mut messages, false);
        assert!(report.is_some());
        assert_eq!(messages.len(), 2);

        let listed = store.list_versions();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].version, 1);
        assert_eq!(listed[0].message_count, 2);

        let snapshot = store.load_version(1).unwrap();
        assert_eq!(snapshot.messages.len(), 2);
        assert!(snapshot.summary.is_none());
    }

    #[test]
    fn versioning_with_summarizer() {
        struct TestSummarizer;
        impl ContextSummarizer for TestSummarizer {
            fn summarize(&self, messages: &[LlmMessage]) -> Option<String> {
                Some(format!("Summary of {} messages", messages.len()))
            }
        }

        let (store, transformer) = compacting_setup();
        let transformer = transformer.with_summarizer(Arc::new(TestSummarizer));
        let body = "x".repeat(400);
        let mut messages = four_messages(&body);

        transformer.transform(&mut messages, false);

        let snapshot = store.load_version(1).unwrap();
        assert_eq!(snapshot.summary.as_deref(), Some("Summary of 2 messages"));
    }

    #[test]
    fn no_compaction_no_version_saved() {
        let store: Arc<dyn ContextVersionStore> = Arc::new(InMemoryVersionStore::new());
        let roomy = SlidingWindowTransformer::new(10_000, 5_000, 1);
        let transformer = VersioningTransformer::new(roomy, Arc::clone(&store));
        let mut messages = vec![text_message("hello"), text_message("world")];

        let report = transformer.transform(&mut messages, false);

        assert!(report.is_none());
        assert!(store.list_versions().is_empty());
    }

    #[test]
    fn multiple_compactions_increment_version() {
        let (store, transformer) = compacting_setup();
        let body = "x".repeat(400);
        for _ in 0..2 {
            let mut messages = four_messages(&body);
            transformer.transform(&mut messages, false);
        }

        let listed = store.list_versions();
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].version, 1);
        assert_eq!(listed[1].version, 2);
    }

    #[test]
    fn latest_version_returns_most_recent() {
        let (store, transformer) = compacting_setup();
        let body = "x".repeat(400);
        for _ in 0..3 {
            let mut messages = four_messages(&body);
            transformer.transform(&mut messages, false);
        }

        let newest = store.latest_version().unwrap();
        assert_eq!(newest.version, 3);
    }

    #[test]
    fn in_memory_store_load_nonexistent() {
        let store = InMemoryVersionStore::new();
        assert!(store.load_version(999).is_none());
        assert!(store.is_empty());
    }

    #[test]
    fn version_meta_fields_correct() {
        let (store, transformer) = compacting_setup();
        let body = "x".repeat(400);
        let mut messages = four_messages(&body);

        transformer.transform(&mut messages, false);

        let meta = &store.list_versions()[0];
        assert_eq!(meta.version, 1);
        assert_eq!(meta.turn, 1);
        assert!(!meta.has_summary);
        assert!(meta.timestamp > 0);
        assert_eq!(meta.message_count, 2);
    }

    #[test]
    fn store_accessor() {
        let (_store, transformer) = compacting_setup();
        assert!(transformer.store().list_versions().is_empty());
    }

    #[test]
    fn report_dropped_messages_populated_by_compaction() {
        use crate::context::compact_sliding_window_with;

        let body = "x".repeat(400);
        let mut messages = four_messages(&body);

        let report = compact_sliding_window_with(&mut messages, 250, 1, None).unwrap();

        assert_eq!(report.dropped_messages.len(), 2);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn versioning_uses_report_dropped_messages_not_debug_diff() {
        let (store, transformer) = compacting_setup();
        let body_a = "a".repeat(400);
        let body_b = "b".repeat(400);
        let mut messages = vec![
            text_message(&body_a),
            text_message(&body_a),
            text_message(&body_b),
            text_message(&body_b),
        ];

        let report = transformer.transform(&mut messages, false);
        assert!(report.is_some());

        let snapshot = store.load_version(1).unwrap();
        assert_eq!(snapshot.messages.len(), 2);
        // The oldest messages (the "a" bodies) are the ones recorded as dropped.
        match &snapshot.messages[0] {
            LlmMessage::User(user) => match &user.content[0] {
                ContentBlock::Text { text } => assert_eq!(text, &body_a),
                _ => panic!("expected text block"),
            },
            _ => panic!("expected user message"),
        }
    }

    #[test]
    fn custom_messages_excluded_from_dropped_messages() {
        use crate::context::compact_sliding_window_with;
        use crate::types::CustomMessage;

        #[derive(Debug)]
        struct Marker;
        impl CustomMessage for Marker {
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }

        let body = "x".repeat(400);
        let mut messages = vec![
            text_message(&body),
            AgentMessage::Custom(Box::new(Marker)),
            text_message(&body),
            text_message(&body),
        ];

        let report = compact_sliding_window_with(&mut messages, 250, 1, None).unwrap();

        assert_eq!(report.dropped_messages.len(), 1);
    }
}