use crate::content::Content;
use crate::id::OperatorId;
use crate::lifecycle::CompactionPolicy;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
/// Bookkeeping attached to a message held in a context.
///
/// `source` and `salience` are optional and are omitted from the serialized
/// form when unset. `version` starts at 0 and is bumped by
/// [`OperatorContext::transform`].
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMeta {
    /// How compaction should treat the message.
    pub policy: CompactionPolicy,
    /// Optional free-form label describing where the message came from.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    /// Optional importance score; no range is enforced here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub salience: Option<f64>,
    /// Revision counter.
    pub version: u64,
}

impl Default for MessageMeta {
    /// Metadata with [`CompactionPolicy::Normal`], no source, no salience and
    /// version 0.
    fn default() -> Self {
        Self::with_policy(CompactionPolicy::Normal)
    }
}

impl MessageMeta {
    /// Fresh metadata (no source, no salience, version 0) using `policy`.
    pub fn with_policy(policy: CompactionPolicy) -> Self {
        Self {
            policy,
            source: None,
            salience: None,
            version: 0,
        }
    }

    /// Builder-style: returns `self` with `source` set.
    pub fn set_source(self, source: impl Into<String>) -> Self {
        Self {
            source: Some(source.into()),
            ..self
        }
    }

    /// Builder-style: returns `self` with `salience` set.
    pub fn set_salience(self, salience: f64) -> Self {
        Self {
            salience: Some(salience),
            ..self
        }
    }
}
/// Author of a [`Message`].
///
/// Variant names are serialized in `snake_case` (serde `rename_all`).
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    /// System role. [`Context::snapshot`] reports `has_system` when any
    /// message carries this role.
    System,
    /// User role.
    User,
    /// Assistant role.
    Assistant,
    /// Tool role; presumably `name` identifies the tool and `call_id` the
    /// invocation this message answers (TODO confirm with callers).
    Tool {
        name: String,
        call_id: String,
    },
}
/// A single message: who wrote it, what it says, and its metadata.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Body: plain text or a list of content blocks.
    pub content: Content,
    /// Compaction policy and bookkeeping.
    pub meta: MessageMeta,
}

impl Message {
    /// Creates a message carrying [`MessageMeta::default`] metadata.
    pub fn new(role: Role, content: Content) -> Self {
        let meta = MessageMeta::default();
        Self {
            role,
            content,
            meta,
        }
    }

    /// Creates a message whose policy is [`CompactionPolicy::Pinned`]; all
    /// other metadata is default. [`Context::compact_by_policy`] keeps such
    /// messages.
    pub fn pinned(role: Role, content: Content) -> Self {
        let mut message = Self::new(role, content);
        message.meta.policy = CompactionPolicy::Pinned;
        message
    }

    /// Rough token estimate: one token per four bytes of text, plus a fixed
    /// overhead of four tokens per message.
    ///
    /// Image blocks count as a flat 1000 tokens. Tool-use inputs and custom
    /// data are measured via their `to_string()` rendering.
    pub fn estimated_tokens(&self) -> usize {
        use crate::content::ContentBlock;

        const BYTES_PER_TOKEN: usize = 4;
        const PER_MESSAGE_OVERHEAD: usize = 4;
        const IMAGE_TOKENS: usize = 1000;

        let body_tokens = match &self.content {
            Content::Text(text) => text.len() / BYTES_PER_TOKEN,
            Content::Blocks(blocks) => {
                let mut total = 0;
                for block in blocks.iter() {
                    total += match block {
                        ContentBlock::Text { text } => text.len() / BYTES_PER_TOKEN,
                        ContentBlock::ToolUse { input, .. } => {
                            input.to_string().len() / BYTES_PER_TOKEN
                        }
                        ContentBlock::ToolResult { content, .. } => {
                            content.len() / BYTES_PER_TOKEN
                        }
                        ContentBlock::Image { .. } => IMAGE_TOKENS,
                        ContentBlock::Custom { data, .. } => {
                            data.to_string().len() / BYTES_PER_TOKEN
                        }
                    };
                }
                total
            }
        };
        body_tokens + PER_MESSAGE_OVERHEAD
    }

    /// The readable text of the message.
    ///
    /// Plain-text content is returned as-is. For block content, the text of
    /// `Text` and `ToolResult` blocks is joined with single spaces; every
    /// other block kind is skipped.
    pub fn text_content(&self) -> String {
        use crate::content::ContentBlock;

        let blocks = match &self.content {
            Content::Text(text) => return text.clone(),
            Content::Blocks(blocks) => blocks,
        };
        let mut pieces: Vec<&str> = Vec::new();
        for block in blocks.iter() {
            match block {
                ContentBlock::Text { text } => pieces.push(text.as_str()),
                ContentBlock::ToolResult { content, .. } => pieces.push(content.as_str()),
                _ => {}
            }
        }
        pieces.join(" ")
    }
}
/// A message of caller-chosen type `M` paired with its [`MessageMeta`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMessage<M> {
    /// The wrapped payload.
    pub message: M,
    /// Compaction policy and bookkeeping for `message`.
    pub meta: MessageMeta,
}
/// Where a new message is placed by [`OperatorContext::inject`] or
/// [`Context::insert`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Position {
    /// After the last message.
    Back,
    /// Before the first message (index 0).
    Front,
    /// At the given index. Valid indices are `0..=len`; a larger index makes
    /// the insert fail with [`ContextError::OutOfBounds`].
    At(usize),
}
/// A watcher's answer to a proposed context mutation.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum WatcherVerdict {
    /// Let the operation proceed.
    Allow,
    /// Veto the operation. Where the verdict is honored, the operation fails
    /// with [`ContextError::Rejected`] carrying `reason`.
    Reject {
        reason: String,
    },
}
/// Observer consulted before, and notified after, context mutations.
///
/// Every hook has a default implementation that allows the operation (or does
/// nothing), so implementors override only what they care about. Watchers
/// must be `Send + Sync`.
pub trait ContextWatcher: Send + Sync {
    /// Called before a message is inserted at `pos`. The message is exposed
    /// only through its `Debug` representation.
    fn on_inject(&self, msg: &dyn fmt::Debug, pos: Position) -> WatcherVerdict {
        let _ = msg;
        let _ = pos;
        WatcherVerdict::Allow
    }

    /// Called before `count` messages are removed by truncation or by
    /// predicate-based removal.
    fn on_remove(&self, count: usize) -> WatcherVerdict {
        let _ = count;
        WatcherVerdict::Allow
    }

    /// Called before a compaction or wholesale replacement of a context that
    /// currently holds `message_count` messages.
    fn on_pre_compact(&self, message_count: usize) -> WatcherVerdict {
        let _ = message_count;
        WatcherVerdict::Allow
    }

    /// Called after a compaction with the number of messages removed and the
    /// number remaining. Purely informational.
    fn on_post_compact(&self, removed: usize, remaining: usize) {
        let _ = removed;
        let _ = remaining;
    }
}
/// Serializable summary of a context's state at one point in time.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSnapshot {
    /// Number of messages in the context.
    pub message_count: usize,
    /// Metadata of every message, in order.
    pub message_metas: Vec<MessageMeta>,
    /// For [`OperatorContext`]: whether a system prompt is set. For
    /// [`Context`]: whether any message has [`Role::System`].
    pub has_system: bool,
    /// Operator that owns the context.
    pub operator_id: OperatorId,
    /// Heuristic token count; the two context types compute it differently.
    pub estimated_tokens: usize,
}
/// Errors returned by context mutations.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// A watcher vetoed the operation.
    #[error("rejected by watcher: {reason}")]
    Rejected {
        reason: String,
    },
    /// An index or count exceeded the current number of messages (`len`).
    #[error("index {index} is out of bounds (len = {len})")]
    OutOfBounds {
        index: usize,
        len: usize,
    },
}
/// An ordered, watcher-guarded list of messages of caller-chosen type `M`,
/// plus an optional system prompt stored separately from the list.
pub struct OperatorContext<M: Clone + fmt::Debug> {
    operator_id: OperatorId,
    messages: Vec<ContextMessage<M>>,
    system: Option<String>,
    watchers: Vec<Arc<dyn ContextWatcher>>,
}

impl<M: Clone + fmt::Debug> OperatorContext<M> {
    /// Creates an empty context owned by `operator_id`, with no system prompt
    /// and no watchers.
    pub fn new(operator_id: OperatorId) -> Self {
        Self {
            operator_id,
            messages: Vec::new(),
            system: None,
            watchers: Vec::new(),
        }
    }

    /// Registers a watcher. Watchers are consulted in registration order and
    /// the first rejection wins.
    pub fn add_watcher(&mut self, watcher: Arc<dyn ContextWatcher>) {
        self.watchers.push(watcher);
    }

    /// The messages, in order.
    pub fn messages(&self) -> &[ContextMessage<M>] {
        &self.messages
    }

    /// Number of messages (the system prompt is not counted).
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// Whether there are no messages (the system prompt is not considered).
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// The system prompt, if one is set.
    pub fn system(&self) -> Option<&str> {
        self.system.as_deref()
    }

    /// Operator that owns this context.
    pub fn operator_id(&self) -> &OperatorId {
        &self.operator_id
    }

    /// Builds a [`ContextSnapshot`] of the current state.
    ///
    /// The token estimate is `(system prompt bytes + bytes of each message
    /// payload's Debug rendering) / 4`; it is only a heuristic.
    pub fn snapshot(&self) -> ContextSnapshot {
        let system_chars = self.system.as_ref().map_or(0, String::len);
        let message_chars: usize = self
            .messages
            .iter()
            .map(|m| format!("{:?}", m.message).len())
            .sum();
        let estimated_tokens = (system_chars + message_chars) / 4;
        ContextSnapshot {
            message_count: self.messages.len(),
            message_metas: self.messages.iter().map(|m| m.meta.clone()).collect(),
            has_system: self.system.is_some(),
            operator_id: self.operator_id.clone(),
            estimated_tokens,
        }
    }

    /// Sets (or replaces) the system prompt. Watchers are not consulted.
    pub fn set_system(&mut self, system: impl Into<String>) {
        self.system = Some(system.into());
    }

    /// Removes the system prompt. Watchers are not consulted.
    pub fn clear_system(&mut self) {
        self.system = None;
    }

    /// Inserts `msg` at `pos`.
    ///
    /// # Errors
    ///
    /// - [`ContextError::OutOfBounds`] if `pos` is `Position::At(i)` with
    ///   `i > len`. This is checked *before* watchers run, so watchers are
    ///   never asked about an insertion that cannot happen.
    /// - [`ContextError::Rejected`] if any watcher's `on_inject` rejects.
    ///
    /// On error the context is unchanged.
    pub fn inject(&mut self, msg: ContextMessage<M>, pos: Position) -> Result<(), ContextError> {
        let index = self.insertion_index(pos)?;
        self.consult_watchers(|watcher| watcher.on_inject(&msg, pos))?;
        self.messages.insert(index, msg);
        Ok(())
    }

    /// Removes and returns the last `count` messages, in their original order.
    ///
    /// # Errors
    ///
    /// [`ContextError::OutOfBounds`] if `count > len`, or
    /// [`ContextError::Rejected`] if a watcher's `on_remove` rejects (watchers
    /// are not consulted when `count == 0`). On error nothing is removed.
    pub fn truncate_back(&mut self, count: usize) -> Result<Vec<ContextMessage<M>>, ContextError> {
        self.approve_removal(count)?;
        let split_at = self.messages.len() - count;
        Ok(self.messages.drain(split_at..).collect())
    }

    /// Removes and returns the first `count` messages, in their original order.
    ///
    /// # Errors
    ///
    /// Same as [`Self::truncate_back`].
    pub fn truncate_front(&mut self, count: usize) -> Result<Vec<ContextMessage<M>>, ContextError> {
        self.approve_removal(count)?;
        Ok(self.messages.drain(..count).collect())
    }

    /// Removes and returns every message matching `pred`; the relative order
    /// of both removed and kept messages is preserved.
    ///
    /// `pred` is evaluated twice per message (once to count, once to split),
    /// so it should be deterministic.
    ///
    /// # Errors
    ///
    /// [`ContextError::Rejected`] if a watcher's `on_remove` rejects (not
    /// consulted when nothing matches). On error nothing is removed.
    pub fn remove_where(
        &mut self,
        pred: impl Fn(&ContextMessage<M>) -> bool,
    ) -> Result<Vec<ContextMessage<M>>, ContextError> {
        let count = self.messages.iter().filter(|m| pred(m)).count();
        self.approve_removal(count)?;
        let (removed, kept): (Vec<_>, Vec<_>) = std::mem::take(&mut self.messages)
            .into_iter()
            .partition(|m| pred(m));
        self.messages = kept;
        Ok(removed)
    }

    /// Applies `f` to every message, then increments that message's
    /// `meta.version` by one. Watchers are not consulted.
    pub fn transform(&mut self, mut f: impl FnMut(&mut ContextMessage<M>)) {
        for msg in &mut self.messages {
            f(msg);
            msg.meta.version += 1;
        }
    }

    /// References to the messages matching `pred`, in order; nothing is removed.
    pub fn extract(&self, pred: impl Fn(&ContextMessage<M>) -> bool) -> Vec<&ContextMessage<M>> {
        self.messages.iter().filter(|m| pred(m)).collect()
    }

    /// Direct mutable access to the message list. Changes made through it
    /// bypass watchers and do not bump `version`.
    pub fn messages_mut(&mut self) -> &mut Vec<ContextMessage<M>> {
        &mut self.messages
    }

    /// Replaces the whole message list with `new` and returns the previous list.
    ///
    /// Watchers get `on_pre_compact` first and, if all allow,
    /// `on_post_compact(removed, remaining)`. `removed` is the net shrinkage,
    /// saturating at 0.
    ///
    /// # Errors
    ///
    /// [`ContextError::Rejected`] if any watcher's `on_pre_compact` rejects;
    /// the context is then unchanged.
    pub fn replace_messages(
        &mut self,
        new: Vec<ContextMessage<M>>,
    ) -> Result<Vec<ContextMessage<M>>, ContextError> {
        let old_count = self.messages.len();
        self.consult_watchers(|watcher| watcher.on_pre_compact(old_count))?;
        let new_count = new.len();
        let old = std::mem::replace(&mut self.messages, new);
        let removed = old_count.saturating_sub(new_count);
        for watcher in &self.watchers {
            watcher.on_post_compact(removed, new_count);
        }
        Ok(old)
    }

    /// Resolves `pos` to a concrete insertion index in `0..=len`.
    fn insertion_index(&self, pos: Position) -> Result<usize, ContextError> {
        let len = self.messages.len();
        match pos {
            Position::Back => Ok(len),
            Position::Front => Ok(0),
            Position::At(index) if index <= len => Ok(index),
            Position::At(index) => Err(ContextError::OutOfBounds { index, len }),
        }
    }

    /// Bounds-checks a removal of `count` messages, then asks watchers via
    /// `on_remove` (skipped when `count == 0`).
    fn approve_removal(&self, count: usize) -> Result<(), ContextError> {
        let len = self.messages.len();
        if count > len {
            return Err(ContextError::OutOfBounds { index: count, len });
        }
        if count == 0 {
            return Ok(());
        }
        self.consult_watchers(|watcher| watcher.on_remove(count))
    }

    /// Asks each watcher in registration order, stopping at the first rejection.
    fn consult_watchers(
        &self,
        ask: impl Fn(&dyn ContextWatcher) -> WatcherVerdict,
    ) -> Result<(), ContextError> {
        for watcher in &self.watchers {
            match ask(&**watcher) {
                WatcherVerdict::Allow => {}
                WatcherVerdict::Reject { reason } => {
                    return Err(ContextError::Rejected { reason });
                }
            }
        }
        Ok(())
    }
}
pub struct Context {
operator_id: OperatorId,
messages: Vec<Message>,
watchers: Vec<Arc<dyn ContextWatcher>>,
}
impl Context {
pub fn new(operator_id: OperatorId) -> Self {
Self {
operator_id,
messages: Vec::new(),
watchers: Vec::new(),
}
}
pub fn add_watcher(&mut self, watcher: Arc<dyn ContextWatcher>) {
self.watchers.push(watcher);
}
pub fn messages(&self) -> &[Message] {
&self.messages
}
pub fn len(&self) -> usize {
self.messages.len()
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
pub fn operator_id(&self) -> &OperatorId {
&self.operator_id
}
pub fn estimated_tokens(&self) -> usize {
self.messages.iter().map(|m| m.estimated_tokens()).sum()
}
pub fn push(&mut self, msg: Message) -> Result<(), ContextError> {
for watcher in &self.watchers {
match watcher.on_inject(&msg, Position::Back) {
WatcherVerdict::Allow => {}
WatcherVerdict::Reject { reason } => {
return Err(ContextError::Rejected { reason });
}
}
}
self.messages.push(msg);
Ok(())
}
pub fn insert(&mut self, msg: Message, pos: Position) -> Result<(), ContextError> {
for watcher in &self.watchers {
match watcher.on_inject(&msg, pos) {
WatcherVerdict::Allow => {}
WatcherVerdict::Reject { reason } => {
return Err(ContextError::Rejected { reason });
}
}
}
match pos {
Position::Back => self.messages.push(msg),
Position::Front => self.messages.insert(0, msg),
Position::At(idx) => {
if idx > self.messages.len() {
return Err(ContextError::OutOfBounds {
index: idx,
len: self.messages.len(),
});
}
self.messages.insert(idx, msg);
}
}
Ok(())
}
pub fn compact_truncate(&mut self, keep: usize) -> Vec<Message> {
if keep >= self.messages.len() {
return Vec::new();
}
let old_count = self.messages.len();
for watcher in &self.watchers {
watcher.on_pre_compact(old_count);
}
let split = self.messages.len() - keep;
let removed: Vec<Message> = self.messages.drain(..split).collect();
for watcher in &self.watchers {
watcher.on_post_compact(removed.len(), self.messages.len());
}
removed
}
pub fn compact_by_policy(&mut self) -> Vec<Message> {
let old_count = self.messages.len();
for watcher in &self.watchers {
watcher.on_pre_compact(old_count);
}
let mut kept = Vec::new();
let mut removed = Vec::new();
for msg in self.messages.drain(..) {
if matches!(msg.meta.policy, CompactionPolicy::Pinned) {
kept.push(msg);
} else {
removed.push(msg);
}
}
self.messages = kept;
for watcher in &self.watchers {
watcher.on_post_compact(removed.len(), self.messages.len());
}
removed
}
pub fn compact_with(&mut self, f: impl FnOnce(&[Message]) -> Vec<Message>) -> Vec<Message> {
let old_count = self.messages.len();
for watcher in &self.watchers {
watcher.on_pre_compact(old_count);
}
let new_messages = f(&self.messages);
let old = std::mem::replace(&mut self.messages, new_messages);
let removed_count = old.len().saturating_sub(self.messages.len());
let removed = old;
for watcher in &self.watchers {
watcher.on_post_compact(removed_count, self.messages.len());
}
removed
}
pub fn messages_mut(&mut self) -> &mut Vec<Message> {
&mut self.messages
}
pub fn snapshot(&self) -> ContextSnapshot {
let estimated_tokens = self.estimated_tokens();
ContextSnapshot {
message_count: self.messages.len(),
message_metas: self.messages.iter().map(|m| m.meta.clone()).collect(),
has_system: self.messages.iter().any(|m| matches!(m.role, Role::System)),
operator_id: self.operator_id.clone(),
estimated_tokens,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    type TestMsg = String;

    /// Wraps `text` in a `ContextMessage` carrying default metadata.
    fn make_msg(text: &str) -> ContextMessage<TestMsg> {
        ContextMessage {
            message: String::from(text),
            meta: MessageMeta::default(),
        }
    }

    /// An `OperatorContext` (operator id `"a"`) holding `texts` in order.
    fn ctx_with(texts: &[&str]) -> OperatorContext<TestMsg> {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        for text in texts {
            ctx.inject(make_msg(text), Position::Back).unwrap();
        }
        ctx
    }

    /// Borrows the payload strings of `msgs`, preserving order.
    fn payloads(msgs: &[ContextMessage<TestMsg>]) -> Vec<&str> {
        msgs.iter().map(|m| m.message.as_str()).collect()
    }

    /// A `Role::User` message with plain-text content.
    fn user(text: impl Into<String>) -> Message {
        let text: String = text.into();
        Message::new(Role::User, Content::text(text))
    }

    /// Pushes `count` user messages reading "msg 0", "msg 1", ...
    fn push_numbered(ctx: &mut Context, count: usize) {
        for i in 0..count {
            ctx.push(user(format!("msg {}", i))).unwrap();
        }
    }

    #[test]
    fn new_context_is_empty() {
        let ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("agent-1"));
        assert!(ctx.is_empty());
        assert_eq!(ctx.len(), 0);
        assert!(ctx.messages().is_empty());
    }

    #[test]
    fn inject_back_appends_in_order() {
        let ctx = ctx_with(&["first", "second"]);
        assert_eq!(payloads(ctx.messages()), ["first", "second"]);
    }

    #[test]
    fn inject_front_prepends() {
        let mut ctx = ctx_with(&["first"]);
        ctx.inject(make_msg("second"), Position::Front).unwrap();
        assert_eq!(payloads(ctx.messages()), ["second", "first"]);
    }

    #[test]
    fn inject_at_inserts_at_index() {
        let mut ctx = ctx_with(&["a", "c"]);
        ctx.inject(make_msg("b"), Position::At(1)).unwrap();
        assert_eq!(payloads(ctx.messages()), ["a", "b", "c"]);
    }

    #[test]
    fn inject_out_of_bounds_returns_error() {
        let mut ctx = ctx_with(&[]);
        let err = ctx.inject(make_msg("x"), Position::At(5)).unwrap_err();
        assert!(matches!(err, ContextError::OutOfBounds { index: 5, len: 0 }));
        assert!(ctx.is_empty());
    }

    #[test]
    fn truncate_back_removes_from_end() {
        let mut ctx = ctx_with(&["a", "b", "c"]);
        let removed = ctx.truncate_back(2).unwrap();
        assert_eq!(payloads(&removed), ["b", "c"]);
        assert_eq!(payloads(ctx.messages()), ["a"]);
    }

    #[test]
    fn truncate_back_out_of_bounds_returns_error() {
        let mut ctx = ctx_with(&["a"]);
        let err = ctx.truncate_back(5).unwrap_err();
        assert!(matches!(err, ContextError::OutOfBounds { index: 5, len: 1 }));
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn truncate_front_removes_from_start() {
        let mut ctx = ctx_with(&["a", "b", "c"]);
        let removed = ctx.truncate_front(2).unwrap();
        assert_eq!(payloads(&removed), ["a", "b"]);
        assert_eq!(payloads(ctx.messages()), ["c"]);
    }

    #[test]
    fn truncate_front_out_of_bounds_returns_error() {
        let mut ctx = ctx_with(&["a"]);
        let err = ctx.truncate_front(5).unwrap_err();
        assert!(matches!(err, ContextError::OutOfBounds { index: 5, len: 1 }));
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn watcher_can_reject_inject() {
        struct RejectAll;
        impl ContextWatcher for RejectAll {
            fn on_inject(&self, _msg: &dyn fmt::Debug, _pos: Position) -> WatcherVerdict {
                WatcherVerdict::Reject {
                    reason: "policy violation".into(),
                }
            }
        }

        let mut ctx = ctx_with(&[]);
        ctx.add_watcher(Arc::new(RejectAll));
        let err = ctx.inject(make_msg("blocked"), Position::Back).unwrap_err();
        assert!(matches!(err, ContextError::Rejected { .. }));
        assert!(ctx.is_empty());
    }

    #[test]
    fn snapshot_captures_state() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("my-agent"));
        ctx.set_system("You are helpful.");
        ctx.inject(make_msg("hello"), Position::Back).unwrap();

        let snap = ctx.snapshot();
        assert_eq!(snap.message_count, 1);
        assert!(snap.has_system);
        assert_eq!(snap.operator_id.as_str(), "my-agent");
        assert_eq!(snap.message_metas.len(), 1);
    }

    #[test]
    fn transform_increments_version() {
        let mut ctx = ctx_with(&["msg"]);
        let version = |c: &OperatorContext<TestMsg>| c.messages()[0].meta.version;
        assert_eq!(version(&ctx), 0);
        for expected in 1..=2 {
            ctx.transform(|_| {});
            assert_eq!(version(&ctx), expected);
        }
    }

    #[test]
    fn replace_messages_fires_compact_watchers() {
        struct CompactWatcher {
            pre: Arc<AtomicBool>,
            post: Arc<AtomicBool>,
        }
        impl ContextWatcher for CompactWatcher {
            fn on_pre_compact(&self, _message_count: usize) -> WatcherVerdict {
                self.pre.store(true, Ordering::SeqCst);
                WatcherVerdict::Allow
            }
            fn on_post_compact(&self, _removed: usize, _remaining: usize) {
                self.post.store(true, Ordering::SeqCst);
            }
        }

        let pre_called = Arc::new(AtomicBool::new(false));
        let post_called = Arc::new(AtomicBool::new(false));
        let mut ctx = ctx_with(&["old"]);
        ctx.add_watcher(Arc::new(CompactWatcher {
            pre: Arc::clone(&pre_called),
            post: Arc::clone(&post_called),
        }));

        let old = ctx.replace_messages(vec![make_msg("new")]).unwrap();
        assert!(pre_called.load(Ordering::SeqCst), "on_pre_compact not called");
        assert!(post_called.load(Ordering::SeqCst), "on_post_compact not called");
        assert_eq!(payloads(&old), ["old"]);
        assert_eq!(payloads(ctx.messages()), ["new"]);
    }

    #[test]
    fn remove_where_filters_correctly() {
        let mut ctx = ctx_with(&["keep", "remove_me", "also keep"]);
        let removed = ctx.remove_where(|m| m.message.contains("remove")).unwrap();
        assert_eq!(payloads(&removed), ["remove_me"]);
        assert_eq!(payloads(ctx.messages()), ["keep", "also keep"]);
    }

    #[test]
    fn extract_is_non_destructive() {
        let ctx = ctx_with(&["a", "b", "c"]);
        let found: Vec<&str> = ctx
            .extract(|m| m.message != "b")
            .into_iter()
            .map(|m| m.message.as_str())
            .collect();
        assert_eq!(found, ["a", "c"]);
        assert_eq!(ctx.len(), 3);
    }

    #[test]
    fn system_prompt_lifecycle() {
        let mut ctx = ctx_with(&[]);
        assert!(ctx.system().is_none());
        ctx.set_system("Hello, system!");
        assert_eq!(ctx.system(), Some("Hello, system!"));
        ctx.clear_system();
        assert!(ctx.system().is_none());
    }

    #[test]
    fn message_construction_and_role_variants() {
        let msg = Message {
            role: Role::User,
            content: Content::text("hello"),
            meta: MessageMeta::default(),
        };
        assert!(matches!(msg.role, Role::User));

        let tool_msg = Message {
            role: Role::Tool {
                name: "shell".into(),
                call_id: "tc_1".into(),
            },
            content: Content::text("output"),
            meta: MessageMeta::default(),
        };
        assert!(matches!(tool_msg.role, Role::Tool { .. }));

        let pinned = Message::pinned(Role::System, Content::text("system"));
        assert!(matches!(pinned.meta.policy, CompactionPolicy::Pinned));
    }

    #[test]
    fn message_serde_roundtrip() {
        let original = Message {
            role: Role::Assistant,
            content: Content::text("hi"),
            meta: MessageMeta::default(),
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(decoded.role, Role::Assistant));
    }

    #[test]
    fn message_estimated_tokens() {
        // 20 bytes / 4 = 5 content tokens, plus 4 tokens of per-message overhead.
        assert_eq!(user("12345678901234567890").estimated_tokens(), 9);
    }

    #[test]
    fn message_text_content_extraction() {
        assert_eq!(user("hello world").text_content(), "hello world");
    }

    #[test]
    fn context_push_and_read() {
        let mut ctx = Context::new(OperatorId::from("agent-1"));
        ctx.push(user("hello")).unwrap();
        ctx.push(Message::new(Role::Assistant, Content::text("hi")))
            .unwrap();
        assert_eq!(ctx.len(), 2);
        assert!(matches!(ctx.messages()[0].role, Role::User));
        assert!(matches!(ctx.messages()[1].role, Role::Assistant));
    }

    #[test]
    fn context_compact_truncate() {
        let mut ctx = Context::new(OperatorId::from("a"));
        push_numbered(&mut ctx, 10);
        let removed = ctx.compact_truncate(3);
        assert_eq!(removed.len(), 7);
        assert_eq!(ctx.len(), 3);
    }

    #[test]
    fn context_compact_by_policy_preserves_pinned() {
        let mut ctx = Context::new(OperatorId::from("a"));
        ctx.push(Message::pinned(Role::System, Content::text("you are helpful")))
            .unwrap();
        push_numbered(&mut ctx, 5);

        let removed = ctx.compact_by_policy();
        assert_eq!(ctx.len(), 1);
        assert!(matches!(ctx.messages()[0].role, Role::System));
        assert_eq!(removed.len(), 5);
    }

    #[test]
    fn context_compact_with_closure() {
        let mut ctx = Context::new(OperatorId::from("a"));
        push_numbered(&mut ctx, 6);
        // Keep every other message (indices 0, 2, 4).
        let removed = ctx.compact_with(|msgs| msgs.iter().step_by(2).cloned().collect());
        assert_eq!(ctx.len(), 3);
        assert_eq!(removed.len(), 6);
    }

    #[test]
    fn context_snapshot() {
        let mut ctx = Context::new(OperatorId::from("my-agent"));
        ctx.push(Message::pinned(Role::System, Content::text("system")))
            .unwrap();
        ctx.push(user("hello")).unwrap();

        let snap = ctx.snapshot();
        assert_eq!(snap.message_count, 2);
        assert!(snap.has_system);
        assert_eq!(snap.operator_id.as_str(), "my-agent");
        assert_eq!(snap.message_metas.len(), 2);
    }

    #[test]
    fn context_estimated_tokens() {
        let mut ctx = Context::new(OperatorId::from("a"));
        for _ in 0..2 {
            ctx.push(user("12345678901234567890")).unwrap();
        }
        assert_eq!(ctx.estimated_tokens(), 18);
    }
}