1use crate::content::Content;
17use crate::id::OperatorId;
18use crate::lifecycle::CompactionPolicy;
19use serde::{Deserialize, Serialize};
20use std::fmt;
21use std::sync::Arc;
22
/// Per-message bookkeeping stored alongside a message's payload.
///
/// Serialized with serde. `source` and `salience` are omitted from the output
/// when they are `None` and default to `None` when absent from the input.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMeta {
    /// Compaction policy for the message. `Context::compact_by_policy` keeps
    /// only messages whose policy is `CompactionPolicy::Pinned`.
    pub policy: CompactionPolicy,

    /// Optional free-form origin label. Nothing in this file interprets it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,

    /// Optional salience score. Nothing in this file reads it and no valid
    /// range is enforced here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub salience: Option<f64>,

    /// Revision counter. Starts at 0 in `Default` and is incremented by
    /// `OperatorContext::transform` for every message it visits.
    pub version: u64,
}
44
45impl Default for MessageMeta {
46 fn default() -> Self {
47 Self {
48 policy: CompactionPolicy::Normal,
49 source: None,
50 salience: None,
51 version: 0,
52 }
53 }
54}
55
56impl MessageMeta {
57 pub fn with_policy(policy: CompactionPolicy) -> Self {
59 Self {
60 policy,
61 ..Default::default()
62 }
63 }
64
65 pub fn set_source(mut self, source: impl Into<String>) -> Self {
67 self.source = Some(source.into());
68 self
69 }
70
71 pub fn set_salience(mut self, salience: f64) -> Self {
73 self.salience = Some(salience);
74 self
75 }
76}
77
/// Author role attached to a [`Message`].
///
/// Variant names are serialized in `snake_case`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    /// System-authored message. `Context::snapshot` reports `has_system`
    /// when at least one message carries this role.
    System,
    /// User-authored message.
    User,
    /// Assistant-authored message.
    Assistant,
    /// Message associated with a tool.
    Tool {
        /// Tool name.
        name: String,
        /// Call identifier; presumably ties the message to the tool
        /// invocation it answers (confirm against callers).
        call_id: String,
    },
}
97
/// A single conversation entry: who said it, what was said, and the
/// bookkeeping metadata that travels with it.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Payload: plain text or a list of content blocks.
    pub content: Content,
    /// Compaction policy, optional source/salience, and version counter.
    pub meta: MessageMeta,
}
112
113impl Message {
114 pub fn new(role: Role, content: Content) -> Self {
116 Self {
117 role,
118 content,
119 meta: MessageMeta::default(),
120 }
121 }
122
123 pub fn pinned(role: Role, content: Content) -> Self {
125 Self {
126 role,
127 content,
128 meta: MessageMeta {
129 policy: CompactionPolicy::Pinned,
130 ..Default::default()
131 },
132 }
133 }
134
135 pub fn estimated_tokens(&self) -> usize {
137 use crate::content::ContentBlock;
138 let content_tokens = match &self.content {
139 Content::Text(s) => s.len() / 4,
140 Content::Blocks(blocks) => blocks
141 .iter()
142 .map(|b| match b {
143 ContentBlock::Text { text } => text.len() / 4,
144 ContentBlock::ToolUse { input, .. } => input.to_string().len() / 4,
145 ContentBlock::ToolResult { content, .. } => content.len() / 4,
146 ContentBlock::Image { .. } => 1000,
147 ContentBlock::Custom { data, .. } => data.to_string().len() / 4,
148 })
149 .sum(),
150 };
151 content_tokens + 4 }
153
154 pub fn text_content(&self) -> String {
156 use crate::content::ContentBlock;
157 match &self.content {
158 Content::Text(s) => s.clone(),
159 Content::Blocks(blocks) => blocks
160 .iter()
161 .filter_map(|b| match b {
162 ContentBlock::Text { text } => Some(text.as_str()),
163 ContentBlock::ToolResult { content, .. } => Some(content.as_str()),
164 _ => None,
165 })
166 .collect::<Vec<_>>()
167 .join(" "),
168 }
169 }
170}
171
/// Pairs an arbitrary message payload `M` with its [`MessageMeta`].
///
/// This is the element type stored by [`OperatorContext`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMessage<M> {
    /// The wrapped payload.
    pub message: M,

    /// Bookkeeping metadata for the payload.
    pub meta: MessageMeta,
}
184
/// Where a new message is placed in a context's message list.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Position {
    /// Append after the last message.
    Back,

    /// Insert before the first message.
    Front,

    /// Insert at the given index. The insertion methods in this file accept
    /// any index up to and including the current length and report
    /// `ContextError::OutOfBounds` for anything larger.
    At(usize),
}
201
/// Answer returned by a [`ContextWatcher`] hook that can veto an operation.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum WatcherVerdict {
    /// Let the operation proceed.
    Allow,

    /// Block the operation. Callers that honor the verdict surface `reason`
    /// through `ContextError::Rejected`.
    Reject {
        /// Human-readable explanation of the rejection.
        reason: String,
    },
}
215
216pub trait ContextWatcher: Send + Sync {
231 fn on_inject(&self, msg: &dyn fmt::Debug, pos: Position) -> WatcherVerdict {
238 let _ = (msg, pos);
239 WatcherVerdict::Allow
240 }
241
242 fn on_remove(&self, count: usize) -> WatcherVerdict {
247 let _ = count;
248 WatcherVerdict::Allow
249 }
250
251 fn on_pre_compact(&self, message_count: usize) -> WatcherVerdict {
256 let _ = message_count;
257 WatcherVerdict::Allow
258 }
259
260 fn on_post_compact(&self, removed: usize, remaining: usize) {
265 let _ = (removed, remaining);
266 }
267}
268
/// Serializable summary of a context, produced by
/// `OperatorContext::snapshot` and `Context::snapshot`.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSnapshot {
    /// Number of messages held at snapshot time.
    pub message_count: usize,

    /// Clone of each message's metadata, in message order.
    pub message_metas: Vec<MessageMeta>,

    /// Whether a system prompt is present. `OperatorContext` reports its
    /// dedicated system string; `Context` reports whether any message has
    /// `Role::System`.
    pub has_system: bool,

    /// Identifier of the operator that owns the context.
    pub operator_id: OperatorId,

    /// Rough token estimate. The two context types compute it differently:
    /// `OperatorContext` divides the byte length of the system prompt plus
    /// each message's `Debug` rendering by four, while `Context` sums
    /// `Message::estimated_tokens`.
    pub estimated_tokens: usize,
}
291
/// Errors returned by fallible context operations.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// A watcher returned `WatcherVerdict::Reject`.
    #[error("rejected by watcher: {reason}")]
    Rejected {
        /// Reason supplied by the rejecting watcher.
        reason: String,
    },

    /// A position or count exceeded the current number of messages.
    #[error("index {index} is out of bounds (len = {len})")]
    OutOfBounds {
        /// The offending value. The truncate methods report the requested
        /// removal count here rather than an index.
        index: usize,

        /// Number of messages held when the check failed.
        len: usize,
    },
}
313
/// Watcher-guarded message list for one operator, generic over the message
/// payload type `M`, with a separately stored optional system prompt.
pub struct OperatorContext<M: Clone + fmt::Debug> {
    /// Identifier of the owning operator.
    operator_id: OperatorId,
    /// Messages in order, each paired with its metadata.
    messages: Vec<ContextMessage<M>>,
    /// Optional system prompt, kept outside the message list.
    system: Option<String>,
    /// Registered watchers, consulted in registration order.
    watchers: Vec<Arc<dyn ContextWatcher>>,
}
338
339impl<M: Clone + fmt::Debug> OperatorContext<M> {
340 pub fn new(operator_id: OperatorId) -> Self {
342 Self {
343 operator_id,
344 messages: Vec::new(),
345 system: None,
346 watchers: Vec::new(),
347 }
348 }
349
350 pub fn add_watcher(&mut self, watcher: Arc<dyn ContextWatcher>) {
352 self.watchers.push(watcher);
353 }
354
355 pub fn messages(&self) -> &[ContextMessage<M>] {
357 &self.messages
358 }
359
360 pub fn len(&self) -> usize {
362 self.messages.len()
363 }
364
365 pub fn is_empty(&self) -> bool {
367 self.messages.is_empty()
368 }
369
370 pub fn system(&self) -> Option<&str> {
372 self.system.as_deref()
373 }
374
375 pub fn operator_id(&self) -> &OperatorId {
377 &self.operator_id
378 }
379
380 pub fn snapshot(&self) -> ContextSnapshot {
385 let system_chars = self.system.as_ref().map(|s| s.len()).unwrap_or(0);
386 let message_chars: usize = self
387 .messages
388 .iter()
389 .map(|m| format!("{:?}", m.message).len())
390 .sum();
391 let estimated_tokens = (system_chars + message_chars) / 4;
392
393 ContextSnapshot {
394 message_count: self.messages.len(),
395 message_metas: self.messages.iter().map(|m| m.meta.clone()).collect(),
396 has_system: self.system.is_some(),
397 operator_id: self.operator_id.clone(),
398 estimated_tokens,
399 }
400 }
401
402 pub fn set_system(&mut self, system: impl Into<String>) {
404 self.system = Some(system.into());
405 }
406
407 pub fn clear_system(&mut self) {
409 self.system = None;
410 }
411
412 pub fn inject(&mut self, msg: ContextMessage<M>, pos: Position) -> Result<(), ContextError> {
419 for watcher in &self.watchers {
420 match watcher.on_inject(&msg, pos) {
421 WatcherVerdict::Allow => {}
422 WatcherVerdict::Reject { reason } => {
423 return Err(ContextError::Rejected { reason });
424 }
425 }
426 }
427
428 match pos {
429 Position::Back => self.messages.push(msg),
430 Position::Front => self.messages.insert(0, msg),
431 Position::At(idx) => {
432 if idx > self.messages.len() {
433 return Err(ContextError::OutOfBounds {
434 index: idx,
435 len: self.messages.len(),
436 });
437 }
438 self.messages.insert(idx, msg);
439 }
440 }
441
442 Ok(())
443 }
444
445 pub fn truncate_back(&mut self, count: usize) -> Result<Vec<ContextMessage<M>>, ContextError> {
450 if count > self.messages.len() {
451 return Err(ContextError::OutOfBounds {
452 index: count,
453 len: self.messages.len(),
454 });
455 }
456
457 if count > 0 {
458 for watcher in &self.watchers {
459 match watcher.on_remove(count) {
460 WatcherVerdict::Allow => {}
461 WatcherVerdict::Reject { reason } => {
462 return Err(ContextError::Rejected { reason });
463 }
464 }
465 }
466 }
467
468 let split_at = self.messages.len() - count;
469 Ok(self.messages.drain(split_at..).collect())
470 }
471
472 pub fn truncate_front(&mut self, count: usize) -> Result<Vec<ContextMessage<M>>, ContextError> {
477 if count > self.messages.len() {
478 return Err(ContextError::OutOfBounds {
479 index: count,
480 len: self.messages.len(),
481 });
482 }
483
484 if count > 0 {
485 for watcher in &self.watchers {
486 match watcher.on_remove(count) {
487 WatcherVerdict::Allow => {}
488 WatcherVerdict::Reject { reason } => {
489 return Err(ContextError::Rejected { reason });
490 }
491 }
492 }
493 }
494
495 Ok(self.messages.drain(..count).collect())
496 }
497
498 pub fn remove_where(
503 &mut self,
504 pred: impl Fn(&ContextMessage<M>) -> bool,
505 ) -> Result<Vec<ContextMessage<M>>, ContextError> {
506 let count = self.messages.iter().filter(|m| pred(m)).count();
507
508 if count > 0 {
509 for watcher in &self.watchers {
510 match watcher.on_remove(count) {
511 WatcherVerdict::Allow => {}
512 WatcherVerdict::Reject { reason } => {
513 return Err(ContextError::Rejected { reason });
514 }
515 }
516 }
517 }
518
519 let mut removed = Vec::new();
520 let mut kept = Vec::new();
521 for msg in self.messages.drain(..) {
522 if pred(&msg) {
523 removed.push(msg);
524 } else {
525 kept.push(msg);
526 }
527 }
528 self.messages = kept;
529 Ok(removed)
530 }
531
532 pub fn transform(&mut self, mut f: impl FnMut(&mut ContextMessage<M>)) {
538 for msg in &mut self.messages {
539 f(msg);
540 msg.meta.version += 1;
541 }
542 }
543
544 pub fn extract(&self, pred: impl Fn(&ContextMessage<M>) -> bool) -> Vec<&ContextMessage<M>> {
548 self.messages.iter().filter(|m| pred(m)).collect()
549 }
550
551 pub fn messages_mut(&mut self) -> &mut Vec<ContextMessage<M>> {
557 &mut self.messages
558 }
559
560 pub fn replace_messages(
569 &mut self,
570 new: Vec<ContextMessage<M>>,
571 ) -> Result<Vec<ContextMessage<M>>, ContextError> {
572 let old_count = self.messages.len();
573
574 for watcher in &self.watchers {
575 match watcher.on_pre_compact(old_count) {
576 WatcherVerdict::Allow => {}
577 WatcherVerdict::Reject { reason } => {
578 return Err(ContextError::Rejected { reason });
579 }
580 }
581 }
582
583 let new_count = new.len();
584 let old = std::mem::replace(&mut self.messages, new);
585 let removed = old_count.saturating_sub(new_count);
586
587 for watcher in &self.watchers {
588 watcher.on_post_compact(removed, new_count);
589 }
590
591 Ok(old)
592 }
593}
594
/// Watcher-guarded list of [`Message`]s for one operator.
///
/// Unlike [`OperatorContext`], there is no separate system-prompt slot;
/// `snapshot` looks for a message with `Role::System` instead.
pub struct Context {
    /// Identifier of the owning operator.
    operator_id: OperatorId,
    /// Messages in order.
    messages: Vec<Message>,
    /// Registered watchers, consulted in registration order.
    watchers: Vec<Arc<dyn ContextWatcher>>,
}
616
617impl Context {
618 pub fn new(operator_id: OperatorId) -> Self {
620 Self {
621 operator_id,
622 messages: Vec::new(),
623 watchers: Vec::new(),
624 }
625 }
626
627 pub fn add_watcher(&mut self, watcher: Arc<dyn ContextWatcher>) {
629 self.watchers.push(watcher);
630 }
631
632 pub fn messages(&self) -> &[Message] {
634 &self.messages
635 }
636
637 pub fn len(&self) -> usize {
639 self.messages.len()
640 }
641
642 pub fn is_empty(&self) -> bool {
644 self.messages.is_empty()
645 }
646
647 pub fn operator_id(&self) -> &OperatorId {
649 &self.operator_id
650 }
651
652 pub fn estimated_tokens(&self) -> usize {
654 self.messages.iter().map(|m| m.estimated_tokens()).sum()
655 }
656
657 pub fn push(&mut self, msg: Message) -> Result<(), ContextError> {
662 for watcher in &self.watchers {
663 match watcher.on_inject(&msg, Position::Back) {
664 WatcherVerdict::Allow => {}
665 WatcherVerdict::Reject { reason } => {
666 return Err(ContextError::Rejected { reason });
667 }
668 }
669 }
670 self.messages.push(msg);
671 Ok(())
672 }
673
674 pub fn insert(&mut self, msg: Message, pos: Position) -> Result<(), ContextError> {
680 for watcher in &self.watchers {
681 match watcher.on_inject(&msg, pos) {
682 WatcherVerdict::Allow => {}
683 WatcherVerdict::Reject { reason } => {
684 return Err(ContextError::Rejected { reason });
685 }
686 }
687 }
688 match pos {
689 Position::Back => self.messages.push(msg),
690 Position::Front => self.messages.insert(0, msg),
691 Position::At(idx) => {
692 if idx > self.messages.len() {
693 return Err(ContextError::OutOfBounds {
694 index: idx,
695 len: self.messages.len(),
696 });
697 }
698 self.messages.insert(idx, msg);
699 }
700 }
701 Ok(())
702 }
703
704 pub fn compact_truncate(&mut self, keep: usize) -> Vec<Message> {
709 if keep >= self.messages.len() {
710 return Vec::new();
711 }
712 let old_count = self.messages.len();
713 for watcher in &self.watchers {
714 watcher.on_pre_compact(old_count);
715 }
716 let split = self.messages.len() - keep;
717 let removed: Vec<Message> = self.messages.drain(..split).collect();
718 for watcher in &self.watchers {
719 watcher.on_post_compact(removed.len(), self.messages.len());
720 }
721 removed
722 }
723
724 pub fn compact_by_policy(&mut self) -> Vec<Message> {
728 let old_count = self.messages.len();
729 for watcher in &self.watchers {
730 watcher.on_pre_compact(old_count);
731 }
732 let mut kept = Vec::new();
733 let mut removed = Vec::new();
734 for msg in self.messages.drain(..) {
735 if matches!(msg.meta.policy, CompactionPolicy::Pinned) {
736 kept.push(msg);
737 } else {
738 removed.push(msg);
739 }
740 }
741 self.messages = kept;
742 for watcher in &self.watchers {
743 watcher.on_post_compact(removed.len(), self.messages.len());
744 }
745 removed
746 }
747
748 pub fn compact_with(&mut self, f: impl FnOnce(&[Message]) -> Vec<Message>) -> Vec<Message> {
753 let old_count = self.messages.len();
754 for watcher in &self.watchers {
755 watcher.on_pre_compact(old_count);
756 }
757 let new_messages = f(&self.messages);
758 let old = std::mem::replace(&mut self.messages, new_messages);
759 let removed_count = old.len().saturating_sub(self.messages.len());
761 let removed = old;
762 for watcher in &self.watchers {
763 watcher.on_post_compact(removed_count, self.messages.len());
764 }
765 removed
766 }
767
768 pub fn messages_mut(&mut self) -> &mut Vec<Message> {
772 &mut self.messages
773 }
774
775 pub fn snapshot(&self) -> ContextSnapshot {
777 let estimated_tokens = self.estimated_tokens();
778 ContextSnapshot {
779 message_count: self.messages.len(),
780 message_metas: self.messages.iter().map(|m| m.meta.clone()).collect(),
781 has_system: self.messages.iter().any(|m| matches!(m.role, Role::System)),
782 operator_id: self.operator_id.clone(),
783 estimated_tokens,
784 }
785 }
786}
787
// Unit tests for `OperatorContext` (with a `String` payload), `Message`, and
// `Context`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    // Payload type used for every `OperatorContext` test.
    type TestMsg = String;

    // Wraps `s` in a `ContextMessage` with default metadata.
    fn make_msg(s: &str) -> ContextMessage<TestMsg> {
        ContextMessage {
            message: s.to_string(),
            meta: MessageMeta::default(),
        }
    }

    // A freshly built context holds no messages.
    #[test]
    fn new_context_is_empty() {
        let ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("agent-1"));
        assert!(ctx.is_empty());
        assert_eq!(ctx.len(), 0);
        assert!(ctx.messages().is_empty());
    }

    // `Position::Back` preserves insertion order.
    #[test]
    fn inject_back_appends_in_order() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("first"), Position::Back).unwrap();
        ctx.inject(make_msg("second"), Position::Back).unwrap();
        assert_eq!(ctx.messages()[0].message, "first");
        assert_eq!(ctx.messages()[1].message, "second");
    }

    // `Position::Front` places the new message before existing ones.
    #[test]
    fn inject_front_prepends() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("first"), Position::Back).unwrap();
        ctx.inject(make_msg("second"), Position::Front).unwrap();
        assert_eq!(ctx.messages()[0].message, "second");
        assert_eq!(ctx.messages()[1].message, "first");
    }

    // `Position::At(i)` inserts at index `i`, shifting later messages.
    #[test]
    fn inject_at_inserts_at_index() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("a"), Position::Back).unwrap();
        ctx.inject(make_msg("c"), Position::Back).unwrap();
        ctx.inject(make_msg("b"), Position::At(1)).unwrap();
        assert_eq!(ctx.messages()[0].message, "a");
        assert_eq!(ctx.messages()[1].message, "b");
        assert_eq!(ctx.messages()[2].message, "c");
    }

    // An index past the end is reported and leaves the context untouched.
    #[test]
    fn inject_out_of_bounds_returns_error() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        let err = ctx.inject(make_msg("x"), Position::At(5)).unwrap_err();
        assert!(matches!(
            err,
            ContextError::OutOfBounds { index: 5, len: 0 }
        ));
        assert!(ctx.is_empty());
    }

    // `truncate_back` returns the tail in original order.
    #[test]
    fn truncate_back_removes_from_end() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("a"), Position::Back).unwrap();
        ctx.inject(make_msg("b"), Position::Back).unwrap();
        ctx.inject(make_msg("c"), Position::Back).unwrap();

        let removed = ctx.truncate_back(2).unwrap();
        assert_eq!(removed.len(), 2);
        assert_eq!(removed[0].message, "b");
        assert_eq!(removed[1].message, "c");
        assert_eq!(ctx.len(), 1);
        assert_eq!(ctx.messages()[0].message, "a");
    }

    // Asking for more than is available fails and removes nothing.
    #[test]
    fn truncate_back_out_of_bounds_returns_error() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("a"), Position::Back).unwrap();
        let err = ctx.truncate_back(5).unwrap_err();
        assert!(matches!(
            err,
            ContextError::OutOfBounds { index: 5, len: 1 }
        ));
        assert_eq!(ctx.len(), 1);
    }

    // `truncate_front` returns the head in original order.
    #[test]
    fn truncate_front_removes_from_start() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("a"), Position::Back).unwrap();
        ctx.inject(make_msg("b"), Position::Back).unwrap();
        ctx.inject(make_msg("c"), Position::Back).unwrap();

        let removed = ctx.truncate_front(2).unwrap();
        assert_eq!(removed.len(), 2);
        assert_eq!(removed[0].message, "a");
        assert_eq!(removed[1].message, "b");
        assert_eq!(ctx.len(), 1);
        assert_eq!(ctx.messages()[0].message, "c");
    }

    // Asking for more than is available fails and removes nothing.
    #[test]
    fn truncate_front_out_of_bounds_returns_error() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("a"), Position::Back).unwrap();
        let err = ctx.truncate_front(5).unwrap_err();
        assert!(matches!(
            err,
            ContextError::OutOfBounds { index: 5, len: 1 }
        ));
        assert_eq!(ctx.len(), 1);
    }

    // A watcher that rejects every injection keeps the context empty.
    #[test]
    fn watcher_can_reject_inject() {
        struct RejectAll;

        impl ContextWatcher for RejectAll {
            fn on_inject(&self, _msg: &dyn fmt::Debug, _pos: Position) -> WatcherVerdict {
                WatcherVerdict::Reject {
                    reason: "policy violation".into(),
                }
            }
        }

        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.add_watcher(Arc::new(RejectAll));

        let err = ctx.inject(make_msg("blocked"), Position::Back).unwrap_err();
        assert!(matches!(err, ContextError::Rejected { .. }));
        assert!(ctx.is_empty());
    }

    // The snapshot reflects message count, system prompt, and operator id.
    #[test]
    fn snapshot_captures_state() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("my-agent"));
        ctx.set_system("You are helpful.");
        ctx.inject(make_msg("hello"), Position::Back).unwrap();

        let snap = ctx.snapshot();
        assert_eq!(snap.message_count, 1);
        assert!(snap.has_system);
        assert_eq!(snap.operator_id.as_str(), "my-agent");
        assert_eq!(snap.message_metas.len(), 1);
    }

    // Each `transform` pass bumps the version, even with a no-op closure.
    #[test]
    fn transform_increments_version() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("msg"), Position::Back).unwrap();
        assert_eq!(ctx.messages()[0].meta.version, 0);

        ctx.transform(|_| {});
        assert_eq!(ctx.messages()[0].meta.version, 1);

        ctx.transform(|_| {});
        assert_eq!(ctx.messages()[0].meta.version, 2);
    }

    // `replace_messages` fires both compaction hooks and swaps the list.
    #[test]
    fn replace_messages_fires_compact_watchers() {
        let pre_called = Arc::new(AtomicBool::new(false));
        let post_called = Arc::new(AtomicBool::new(false));

        // Records which compaction hooks were invoked.
        struct CompactWatcher {
            pre: Arc<AtomicBool>,
            post: Arc<AtomicBool>,
        }

        impl ContextWatcher for CompactWatcher {
            fn on_pre_compact(&self, _message_count: usize) -> WatcherVerdict {
                self.pre.store(true, Ordering::SeqCst);
                WatcherVerdict::Allow
            }

            fn on_post_compact(&self, _removed: usize, _remaining: usize) {
                self.post.store(true, Ordering::SeqCst);
            }
        }

        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.add_watcher(Arc::new(CompactWatcher {
            pre: Arc::clone(&pre_called),
            post: Arc::clone(&post_called),
        }));

        ctx.inject(make_msg("old"), Position::Back).unwrap();
        let old = ctx.replace_messages(vec![make_msg("new")]).unwrap();

        assert!(
            pre_called.load(Ordering::SeqCst),
            "on_pre_compact not called"
        );
        assert!(
            post_called.load(Ordering::SeqCst),
            "on_post_compact not called"
        );
        assert_eq!(old.len(), 1);
        assert_eq!(old[0].message, "old");
        assert_eq!(ctx.messages()[0].message, "new");
    }

    // Only matching messages are removed; the rest keep their order.
    #[test]
    fn remove_where_filters_correctly() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("keep"), Position::Back).unwrap();
        ctx.inject(make_msg("remove_me"), Position::Back).unwrap();
        ctx.inject(make_msg("also keep"), Position::Back).unwrap();

        let removed = ctx.remove_where(|m| m.message.contains("remove")).unwrap();
        assert_eq!(removed.len(), 1);
        assert_eq!(removed[0].message, "remove_me");
        assert_eq!(ctx.len(), 2);
        assert_eq!(ctx.messages()[0].message, "keep");
        assert_eq!(ctx.messages()[1].message, "also keep");
    }

    // `extract` returns matches without removing anything.
    #[test]
    fn extract_is_non_destructive() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        ctx.inject(make_msg("a"), Position::Back).unwrap();
        ctx.inject(make_msg("b"), Position::Back).unwrap();
        ctx.inject(make_msg("c"), Position::Back).unwrap();

        let found = ctx.extract(|m| m.message != "b");
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].message, "a");
        assert_eq!(found[1].message, "c");
        assert_eq!(ctx.len(), 3);
    }

    // The system prompt can be set, read back, and cleared.
    #[test]
    fn system_prompt_lifecycle() {
        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
        assert!(ctx.system().is_none());

        ctx.set_system("Hello, system!");
        assert_eq!(ctx.system(), Some("Hello, system!"));

        ctx.clear_system();
        assert!(ctx.system().is_none());
    }

    // Messages can be built by struct literal and via `Message::pinned`.
    #[test]
    fn message_construction_and_role_variants() {
        use crate::content::Content;
        use crate::lifecycle::CompactionPolicy;

        let msg = Message {
            role: Role::User,
            content: Content::text("hello"),
            meta: MessageMeta::default(),
        };
        assert!(matches!(msg.role, Role::User));

        let tool_msg = Message {
            role: Role::Tool {
                name: "shell".into(),
                call_id: "tc_1".into(),
            },
            content: Content::text("output"),
            meta: MessageMeta::default(),
        };
        assert!(matches!(tool_msg.role, Role::Tool { .. }));

        let pinned = Message::pinned(Role::System, Content::text("system"));
        assert!(matches!(pinned.meta.policy, CompactionPolicy::Pinned));
    }

    // A message survives a JSON round trip with its role intact.
    #[test]
    fn message_serde_roundtrip() {
        use crate::content::Content;

        let msg = Message {
            role: Role::Assistant,
            content: Content::text("hi"),
            meta: MessageMeta::default(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let rt: Message = serde_json::from_str(&json).unwrap();
        assert!(matches!(rt.role, Role::Assistant));
    }

    // 20 bytes of text / 4 = 5, plus the flat 4 added per message = 9.
    #[test]
    fn message_estimated_tokens() {
        use crate::content::Content;

        let msg = Message::new(Role::User, Content::text("12345678901234567890"));
        assert_eq!(msg.estimated_tokens(), 9);
    }

    // Plain-text content is returned unchanged.
    #[test]
    fn message_text_content_extraction() {
        use crate::content::Content;

        let msg = Message::new(Role::User, Content::text("hello world"));
        assert_eq!(msg.text_content(), "hello world");
    }

    // `Context::push` appends in order.
    #[test]
    fn context_push_and_read() {
        use crate::content::Content;

        let mut ctx = Context::new(OperatorId::from("agent-1"));
        ctx.push(Message::new(Role::User, Content::text("hello")))
            .unwrap();
        ctx.push(Message::new(Role::Assistant, Content::text("hi")))
            .unwrap();
        assert_eq!(ctx.len(), 2);
        assert!(matches!(ctx.messages()[0].role, Role::User));
        assert!(matches!(ctx.messages()[1].role, Role::Assistant));
    }

    // Keeping 3 of 10 messages removes the other 7.
    #[test]
    fn context_compact_truncate() {
        use crate::content::Content;

        let mut ctx = Context::new(OperatorId::from("a"));
        for i in 0..10 {
            ctx.push(Message::new(
                Role::User,
                Content::text(format!("msg {}", i)),
            ))
            .unwrap();
        }
        let removed = ctx.compact_truncate(3);
        assert_eq!(removed.len(), 7);
        assert_eq!(ctx.len(), 3);
    }

    // Policy-based compaction keeps only the pinned message.
    #[test]
    fn context_compact_by_policy_preserves_pinned() {
        use crate::content::Content;

        let mut ctx = Context::new(OperatorId::from("a"));
        ctx.push(Message::pinned(
            Role::System,
            Content::text("you are helpful"),
        ))
        .unwrap();
        for i in 0..5 {
            ctx.push(Message::new(
                Role::User,
                Content::text(format!("msg {}", i)),
            ))
            .unwrap();
        }
        let removed = ctx.compact_by_policy();
        assert_eq!(ctx.len(), 1);
        assert!(matches!(ctx.messages()[0].role, Role::System));
        assert_eq!(removed.len(), 5);
    }

    // `compact_with` installs the closure's result and hands back the whole
    // previous list (all 6 messages), not only the ones left out.
    #[test]
    fn context_compact_with_closure() {
        use crate::content::Content;

        let mut ctx = Context::new(OperatorId::from("a"));
        for i in 0..6 {
            ctx.push(Message::new(
                Role::User,
                Content::text(format!("msg {}", i)),
            ))
            .unwrap();
        }
        let removed = ctx.compact_with(|msgs| {
            msgs.iter()
                .enumerate()
                .filter(|(i, _)| i % 2 == 0)
                .map(|(_, m)| m.clone())
                .collect()
        });
        assert_eq!(ctx.len(), 3);
        assert_eq!(removed.len(), 6);
    }

    // `Context::snapshot` detects the system message by role.
    #[test]
    fn context_snapshot() {
        use crate::content::Content;

        let mut ctx = Context::new(OperatorId::from("my-agent"));
        ctx.push(Message::pinned(Role::System, Content::text("system")))
            .unwrap();
        ctx.push(Message::new(Role::User, Content::text("hello")))
            .unwrap();

        let snap = ctx.snapshot();
        assert_eq!(snap.message_count, 2);
        assert!(snap.has_system);
        assert_eq!(snap.operator_id.as_str(), "my-agent");
        assert_eq!(snap.message_metas.len(), 2);
    }

    // Two messages at 9 estimated tokens each sum to 18.
    #[test]
    fn context_estimated_tokens() {
        use crate::content::Content;

        let mut ctx = Context::new(OperatorId::from("a"));
        ctx.push(Message::new(
            Role::User,
            Content::text("12345678901234567890"),
        ))
        .unwrap();
        ctx.push(Message::new(
            Role::User,
            Content::text("12345678901234567890"),
        ))
        .unwrap();
        assert_eq!(ctx.estimated_tokens(), 18);
    }
}