Skip to main content

layer0/
context.rs

//! Context management for operator message histories.
//!
//! This module provides two watcher-guarded containers for an operator's message
//! history: the concrete [`Context`] (built on [`Message`]) and the generic
//! [`OperatorContext`] it replaces. Context is a first-class primitive: it tracks
//! messages, per-message metadata and (for [`OperatorContext`]) a system prompt,
//! and routes structural mutations through registered [`ContextWatcher`]s for
//! observation and approval.
//!
//! ## Core types
//!
//! - [`Context`] — concrete container over [`Message`] (push, insert, compaction)
//! - [`OperatorContext`] — the generic mutable container (own this, pass it to the turn loop)
//! - [`Message`] / [`Role`] — the concrete message type and its author role
//! - [`ContextMessage`] — a generic message paired with its [`MessageMeta`]
//! - [`ContextWatcher`] — observer / gatekeeper trait
//! - [`ContextSnapshot`] — read-only introspection view
//! - [`ContextError`] — mutation errors (rejected or out-of-bounds)
15
16use crate::content::Content;
17use crate::id::OperatorId;
18use crate::lifecycle::CompactionPolicy;
19use serde::{Deserialize, Serialize};
20use std::fmt;
21use std::sync::Arc;
22
23/// Per-message annotation attached to every message in an [`OperatorContext`].
24///
25/// All fields are public and directly settable. The [`Default`] implementation
26/// uses [`CompactionPolicy::Normal`] and zeros/nones for everything else.
27#[non_exhaustive]
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct MessageMeta {
30    /// Compaction policy governing how this message survives context reduction.
31    pub policy: CompactionPolicy,
32
33    /// Source of the message, e.g. `"user"` or `"tool:shell"`.
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub source: Option<String>,
36
37    /// Importance hint in the range 0.0–1.0. Higher values should survive compaction longer.
38    #[serde(default, skip_serializing_if = "Option::is_none")]
39    pub salience: Option<f64>,
40
41    /// Monotonic version counter, incremented on each mutation via [`OperatorContext::transform`].
42    pub version: u64,
43}
44
45impl Default for MessageMeta {
46    fn default() -> Self {
47        Self {
48            policy: CompactionPolicy::Normal,
49            source: None,
50            salience: None,
51            version: 0,
52        }
53    }
54}
55
56impl MessageMeta {
57    /// Create metadata with the given policy and defaults for all other fields.
58    pub fn with_policy(policy: CompactionPolicy) -> Self {
59        Self {
60            policy,
61            ..Default::default()
62        }
63    }
64
65    /// Set the source.
66    pub fn set_source(mut self, source: impl Into<String>) -> Self {
67        self.source = Some(source.into());
68        self
69    }
70
71    /// Set the salience score.
72    pub fn set_salience(mut self, salience: f64) -> Self {
73        self.salience = Some(salience);
74        self
75    }
76}
77
78/// Role of a message in the context window.
79#[non_exhaustive]
80#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
81#[serde(rename_all = "snake_case")]
82pub enum Role {
83    /// System instruction.
84    System,
85    /// Human message.
86    User,
87    /// Model response.
88    Assistant,
89    /// Tool/sub-operator result.
90    Tool {
91        /// Name of the tool/operator.
92        name: String,
93        /// Provider-specific call ID for correlation.
94        call_id: String,
95    },
96}
97
98/// A message in an operator's context window.
99///
100/// Concrete type — not generic. Every message has a role, content,
101/// and per-message metadata (compaction policy, salience, source).
102#[non_exhaustive]
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct Message {
105    /// Who produced this message.
106    pub role: Role,
107    /// The message payload.
108    pub content: Content,
109    /// Per-message annotation (compaction policy, salience, source, version).
110    pub meta: MessageMeta,
111}
112
113impl Message {
114    /// Create a new message with default metadata.
115    pub fn new(role: Role, content: Content) -> Self {
116        Self {
117            role,
118            content,
119            meta: MessageMeta::default(),
120        }
121    }
122
123    /// Create a message with `CompactionPolicy::Pinned`.
124    pub fn pinned(role: Role, content: Content) -> Self {
125        Self {
126            role,
127            content,
128            meta: MessageMeta {
129                policy: CompactionPolicy::Pinned,
130                ..Default::default()
131            },
132        }
133    }
134
135    /// Rough token estimate: chars/4 for text, 1000 for images, +4 overhead per message.
136    pub fn estimated_tokens(&self) -> usize {
137        use crate::content::ContentBlock;
138        let content_tokens = match &self.content {
139            Content::Text(s) => s.len() / 4,
140            Content::Blocks(blocks) => blocks
141                .iter()
142                .map(|b| match b {
143                    ContentBlock::Text { text } => text.len() / 4,
144                    ContentBlock::ToolUse { input, .. } => input.to_string().len() / 4,
145                    ContentBlock::ToolResult { content, .. } => content.len() / 4,
146                    ContentBlock::Image { .. } => 1000,
147                    ContentBlock::Custom { data, .. } => data.to_string().len() / 4,
148                })
149                .sum(),
150        };
151        content_tokens + 4 // per-message overhead
152    }
153
154    /// Extract all text content for similarity computation.
155    pub fn text_content(&self) -> String {
156        use crate::content::ContentBlock;
157        match &self.content {
158            Content::Text(s) => s.clone(),
159            Content::Blocks(blocks) => blocks
160                .iter()
161                .filter_map(|b| match b {
162                    ContentBlock::Text { text } => Some(text.as_str()),
163                    ContentBlock::ToolResult { content, .. } => Some(content.as_str()),
164                    _ => None,
165                })
166                .collect::<Vec<_>>()
167                .join(" "),
168        }
169    }
170}
171
172/// A message paired with its metadata.
173///
174/// Parameterised over `M`, the concrete message type used by a particular operator.
175/// Both fields are public so callers can construct messages directly.
176#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct ContextMessage<M> {
178    /// The message payload.
179    pub message: M,
180
181    /// Per-message metadata (compaction policy, source, salience, version).
182    pub meta: MessageMeta,
183}
184
/// Placement requested when adding a message to a context container.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Position {
    /// After the last existing message.
    Back,

    /// Before the first existing message.
    Front,

    /// At a 0-based index.
    ///
    /// `At(len)` behaves like [`Position::Back`]; any larger index is refused
    /// with [`ContextError::OutOfBounds`].
    At(usize),
}
201
/// Answer a [`ContextWatcher`] gives when asked about a proposed mutation.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum WatcherVerdict {
    /// Let the operation go ahead.
    Allow,

    /// Veto the operation.
    Reject {
        /// Human-readable explanation of the veto.
        reason: String,
    },
}
215
216/// Observer and gatekeeper for mutations to an [`OperatorContext`].
217///
218/// All methods have default implementations that approve the operation (Allow or no-op).
219/// Implementors override only the methods they care about.
220///
221/// Watchers are stored as `Arc<dyn ContextWatcher>` and **must** be `Send + Sync`.
222/// Long-running or I/O-heavy watcher logic will add latency to every mutation; keep
223/// implementations fast.
224///
225/// # Object safety
226///
227/// The trait is object-safe. `on_inject` receives the message as a type-erased
228/// [`fmt::Debug`] reference (the concrete `ContextMessage<M>`) so implementations
229/// can inspect its debug representation without requiring a generic method.
230pub trait ContextWatcher: Send + Sync {
231    /// Called before a message is injected.
232    ///
233    /// `msg` is the full [`ContextMessage<M>`] coerced to `&dyn fmt::Debug`.
234    /// `pos` is the requested injection position.
235    ///
236    /// Return [`WatcherVerdict::Reject`] to abort the inject.
237    fn on_inject(&self, msg: &dyn fmt::Debug, pos: Position) -> WatcherVerdict {
238        let _ = (msg, pos);
239        WatcherVerdict::Allow
240    }
241
242    /// Called before messages are removed (truncate or filter operations).
243    ///
244    /// `count` is the number of messages about to be removed.
245    /// Return [`WatcherVerdict::Reject`] to abort the removal.
246    fn on_remove(&self, count: usize) -> WatcherVerdict {
247        let _ = count;
248        WatcherVerdict::Allow
249    }
250
251    /// Called before a [`OperatorContext::replace_messages`] compaction runs.
252    ///
253    /// `message_count` is the number of messages currently in the context.
254    /// Return [`WatcherVerdict::Reject`] to abort the compaction.
255    fn on_pre_compact(&self, message_count: usize) -> WatcherVerdict {
256        let _ = message_count;
257        WatcherVerdict::Allow
258    }
259
260    /// Called after a [`OperatorContext::replace_messages`] compaction completes.
261    ///
262    /// `removed` is the number of messages dropped (`old_count - new_count`, clamped to 0).
263    /// `remaining` is the count of messages now in the context.
264    fn on_post_compact(&self, removed: usize, remaining: usize) {
265        let _ = (removed, remaining);
266    }
267}
268
269/// Read-only snapshot of an [`OperatorContext`] for introspection and logging.
270#[non_exhaustive]
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct ContextSnapshot {
273    /// Number of messages currently in the context.
274    pub message_count: usize,
275
276    /// Metadata for each message, in order.
277    pub message_metas: Vec<MessageMeta>,
278
279    /// Whether a system prompt is currently set.
280    pub has_system: bool,
281
282    /// The operator this context belongs to.
283    pub operator_id: OperatorId,
284
285    /// Rough token estimate derived from the debug representation length of messages.
286    ///
287    /// Uses the heuristic `total_chars / 4`. Not suitable for billing — use it for
288    /// soft pressure signals only.
289    pub estimated_tokens: usize,
290}
291
292/// Errors returned by [`OperatorContext`] mutation methods.
293#[non_exhaustive]
294#[derive(Debug, thiserror::Error)]
295pub enum ContextError {
296    /// A [`ContextWatcher`] rejected the operation.
297    #[error("rejected by watcher: {reason}")]
298    Rejected {
299        /// The rejection reason from the watcher.
300        reason: String,
301    },
302
303    /// A position index was past the end of the message list.
304    #[error("index {index} is out of bounds (len = {len})")]
305    OutOfBounds {
306        /// The index that was out of bounds.
307        index: usize,
308
309        /// The current length of the message list at the time of the error.
310        len: usize,
311    },
312}
313
314/// A watcher-guarded, typed container for an operator's message history.
315///
316/// `OperatorContext<M>` is the first-class primitive for managing what messages an
317/// operator sees. Every structural mutation (inject, truncate, remove, compact) routes
318/// through the registered [`ContextWatcher`]s in registration order before taking
319/// effect. A single `Reject` verdict from any watcher aborts the operation and
320/// returns [`ContextError::Rejected`].
321///
322/// # Type Parameter
323///
324/// `M` is the message type (e.g. an enum of user / assistant / tool messages).
325/// It must be `Clone + fmt::Debug`. The `Debug` bound is required so the context
326/// can pass messages to watchers and compute token estimates.
327///
328/// # Watcher invocation order
329///
330/// Watchers are invoked in registration order (first registered, first called).
331/// The first watcher to return `Reject` wins; later watchers are not consulted.
332pub struct OperatorContext<M: Clone + fmt::Debug> {
333    operator_id: OperatorId,
334    messages: Vec<ContextMessage<M>>,
335    system: Option<String>,
336    watchers: Vec<Arc<dyn ContextWatcher>>,
337}
338
339impl<M: Clone + fmt::Debug> OperatorContext<M> {
340    /// Create an empty context for the given operator.
341    pub fn new(operator_id: OperatorId) -> Self {
342        Self {
343            operator_id,
344            messages: Vec::new(),
345            system: None,
346            watchers: Vec::new(),
347        }
348    }
349
350    /// Register a watcher. Watchers are invoked in registration order.
351    pub fn add_watcher(&mut self, watcher: Arc<dyn ContextWatcher>) {
352        self.watchers.push(watcher);
353    }
354
355    /// Read-only slice of all messages, in order.
356    pub fn messages(&self) -> &[ContextMessage<M>] {
357        &self.messages
358    }
359
360    /// Number of messages currently in the context.
361    pub fn len(&self) -> usize {
362        self.messages.len()
363    }
364
365    /// Returns `true` if there are no messages.
366    pub fn is_empty(&self) -> bool {
367        self.messages.is_empty()
368    }
369
370    /// The current system prompt, if one is set.
371    pub fn system(&self) -> Option<&str> {
372        self.system.as_deref()
373    }
374
375    /// The operator this context belongs to.
376    pub fn operator_id(&self) -> &OperatorId {
377        &self.operator_id
378    }
379
380    /// Build a read-only snapshot for introspection or logging.
381    ///
382    /// Token estimation uses `total_debug_chars / 4`; treat the value as a
383    /// soft signal, not a billing figure.
384    pub fn snapshot(&self) -> ContextSnapshot {
385        let system_chars = self.system.as_ref().map(|s| s.len()).unwrap_or(0);
386        let message_chars: usize = self
387            .messages
388            .iter()
389            .map(|m| format!("{:?}", m.message).len())
390            .sum();
391        let estimated_tokens = (system_chars + message_chars) / 4;
392
393        ContextSnapshot {
394            message_count: self.messages.len(),
395            message_metas: self.messages.iter().map(|m| m.meta.clone()).collect(),
396            has_system: self.system.is_some(),
397            operator_id: self.operator_id.clone(),
398            estimated_tokens,
399        }
400    }
401
402    /// Set the system prompt, replacing any existing one.
403    pub fn set_system(&mut self, system: impl Into<String>) {
404        self.system = Some(system.into());
405    }
406
407    /// Remove the system prompt.
408    pub fn clear_system(&mut self) {
409        self.system = None;
410    }
411
412    /// Inject a message at the given position.
413    ///
414    /// Calls `on_inject` on all registered watchers first. Returns
415    /// [`ContextError::Rejected`] if any watcher rejects, or
416    /// [`ContextError::OutOfBounds`] if [`Position::At`] exceeds the current
417    /// length.
418    pub fn inject(&mut self, msg: ContextMessage<M>, pos: Position) -> Result<(), ContextError> {
419        for watcher in &self.watchers {
420            match watcher.on_inject(&msg, pos) {
421                WatcherVerdict::Allow => {}
422                WatcherVerdict::Reject { reason } => {
423                    return Err(ContextError::Rejected { reason });
424                }
425            }
426        }
427
428        match pos {
429            Position::Back => self.messages.push(msg),
430            Position::Front => self.messages.insert(0, msg),
431            Position::At(idx) => {
432                if idx > self.messages.len() {
433                    return Err(ContextError::OutOfBounds {
434                        index: idx,
435                        len: self.messages.len(),
436                    });
437                }
438                self.messages.insert(idx, msg);
439            }
440        }
441
442        Ok(())
443    }
444
445    /// Remove and return the last `count` messages.
446    ///
447    /// Returns [`ContextError::OutOfBounds`] if `count` exceeds the current length.
448    /// Fires `on_remove(count)` on all watchers when `count > 0`.
449    pub fn truncate_back(&mut self, count: usize) -> Result<Vec<ContextMessage<M>>, ContextError> {
450        if count > self.messages.len() {
451            return Err(ContextError::OutOfBounds {
452                index: count,
453                len: self.messages.len(),
454            });
455        }
456
457        if count > 0 {
458            for watcher in &self.watchers {
459                match watcher.on_remove(count) {
460                    WatcherVerdict::Allow => {}
461                    WatcherVerdict::Reject { reason } => {
462                        return Err(ContextError::Rejected { reason });
463                    }
464                }
465            }
466        }
467
468        let split_at = self.messages.len() - count;
469        Ok(self.messages.drain(split_at..).collect())
470    }
471
472    /// Remove and return the first `count` messages.
473    ///
474    /// Returns [`ContextError::OutOfBounds`] if `count` exceeds the current length.
475    /// Fires `on_remove(count)` on all watchers when `count > 0`.
476    pub fn truncate_front(&mut self, count: usize) -> Result<Vec<ContextMessage<M>>, ContextError> {
477        if count > self.messages.len() {
478            return Err(ContextError::OutOfBounds {
479                index: count,
480                len: self.messages.len(),
481            });
482        }
483
484        if count > 0 {
485            for watcher in &self.watchers {
486                match watcher.on_remove(count) {
487                    WatcherVerdict::Allow => {}
488                    WatcherVerdict::Reject { reason } => {
489                        return Err(ContextError::Rejected { reason });
490                    }
491                }
492            }
493        }
494
495        Ok(self.messages.drain(..count).collect())
496    }
497
498    /// Remove and return all messages matching `pred`.
499    ///
500    /// Fires `on_remove(n)` on all watchers when `n > 0` matching messages exist.
501    /// Preserves the relative order of messages that are not removed.
502    pub fn remove_where(
503        &mut self,
504        pred: impl Fn(&ContextMessage<M>) -> bool,
505    ) -> Result<Vec<ContextMessage<M>>, ContextError> {
506        let count = self.messages.iter().filter(|m| pred(m)).count();
507
508        if count > 0 {
509            for watcher in &self.watchers {
510                match watcher.on_remove(count) {
511                    WatcherVerdict::Allow => {}
512                    WatcherVerdict::Reject { reason } => {
513                        return Err(ContextError::Rejected { reason });
514                    }
515                }
516            }
517        }
518
519        let mut removed = Vec::new();
520        let mut kept = Vec::new();
521        for msg in self.messages.drain(..) {
522            if pred(&msg) {
523                removed.push(msg);
524            } else {
525                kept.push(msg);
526            }
527        }
528        self.messages = kept;
529        Ok(removed)
530    }
531
532    /// Apply `f` to every message, incrementing `meta.version` after each call.
533    ///
534    /// Version increments unconditionally regardless of whether `f` actually
535    /// mutated the message — callers should use this only when a real mutation
536    /// occurred.
537    pub fn transform(&mut self, mut f: impl FnMut(&mut ContextMessage<M>)) {
538        for msg in &mut self.messages {
539            f(msg);
540            msg.meta.version += 1;
541        }
542    }
543
544    /// Return references to all messages matching `pred` without removing them.
545    ///
546    /// Non-destructive: the context is unchanged after this call.
547    pub fn extract(&self, pred: impl Fn(&ContextMessage<M>) -> bool) -> Vec<&ContextMessage<M>> {
548        self.messages.iter().filter(|m| pred(m)).collect()
549    }
550
551    /// Direct mutable access to the underlying message vector.
552    ///
553    /// **Bypasses all watcher checks.** Prefer the typed mutation methods
554    /// (`inject`, `remove_where`, `transform`, …) unless you need fine-grained
555    /// control that those methods don't expose.
556    pub fn messages_mut(&mut self) -> &mut Vec<ContextMessage<M>> {
557        &mut self.messages
558    }
559
560    /// Replace the entire message list, firing compact watchers.
561    ///
562    /// Fires `on_pre_compact(old_count)` before the swap and
563    /// `on_post_compact(removed, new_count)` after, where
564    /// `removed = old_count.saturating_sub(new_count)`.
565    ///
566    /// Returns the old messages on success, or [`ContextError::Rejected`] if any
567    /// watcher's `on_pre_compact` rejects.
568    pub fn replace_messages(
569        &mut self,
570        new: Vec<ContextMessage<M>>,
571    ) -> Result<Vec<ContextMessage<M>>, ContextError> {
572        let old_count = self.messages.len();
573
574        for watcher in &self.watchers {
575            match watcher.on_pre_compact(old_count) {
576                WatcherVerdict::Allow => {}
577                WatcherVerdict::Reject { reason } => {
578                    return Err(ContextError::Rejected { reason });
579                }
580            }
581        }
582
583        let new_count = new.len();
584        let old = std::mem::replace(&mut self.messages, new);
585        let removed = old_count.saturating_sub(new_count);
586
587        for watcher in &self.watchers {
588            watcher.on_post_compact(removed, new_count);
589        }
590
591        Ok(old)
592    }
593}
594
595// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
596// Context — concrete replacement for OperatorContext<M>
597// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
598
599/// A concrete, watcher-guarded container for an operator's message history.
600///
601/// `Context` replaces the generic `OperatorContext<M>` with a concrete type
602/// using [`Message`] directly. Every structural mutation routes through
603/// registered [`ContextWatcher`]s before taking effect.
604///
605/// # Compaction
606///
607/// Three compaction strategies are available as methods:
608/// - [`compact_truncate`](Context::compact_truncate) — keep the last N messages
609/// - [`compact_by_policy`](Context::compact_by_policy) — remove Normal, keep Pinned
610/// - [`compact_with`](Context::compact_with) — caller-supplied closure
611pub struct Context {
612    operator_id: OperatorId,
613    messages: Vec<Message>,
614    watchers: Vec<Arc<dyn ContextWatcher>>,
615}
616
617impl Context {
618    /// Create an empty context for the given operator.
619    pub fn new(operator_id: OperatorId) -> Self {
620        Self {
621            operator_id,
622            messages: Vec::new(),
623            watchers: Vec::new(),
624        }
625    }
626
627    /// Register a watcher. Watchers are invoked in registration order.
628    pub fn add_watcher(&mut self, watcher: Arc<dyn ContextWatcher>) {
629        self.watchers.push(watcher);
630    }
631
632    /// Read-only slice of all messages, in order.
633    pub fn messages(&self) -> &[Message] {
634        &self.messages
635    }
636
637    /// Number of messages currently in the context.
638    pub fn len(&self) -> usize {
639        self.messages.len()
640    }
641
642    /// Returns `true` if there are no messages.
643    pub fn is_empty(&self) -> bool {
644        self.messages.is_empty()
645    }
646
647    /// The operator this context belongs to.
648    pub fn operator_id(&self) -> &OperatorId {
649        &self.operator_id
650    }
651
652    /// Rough token estimate for the entire context.
653    pub fn estimated_tokens(&self) -> usize {
654        self.messages.iter().map(|m| m.estimated_tokens()).sum()
655    }
656
657    /// Push a message to the end of the context.
658    ///
659    /// Calls `on_inject` on all registered watchers. Returns
660    /// [`ContextError::Rejected`] if any watcher rejects.
661    pub fn push(&mut self, msg: Message) -> Result<(), ContextError> {
662        for watcher in &self.watchers {
663            match watcher.on_inject(&msg, Position::Back) {
664                WatcherVerdict::Allow => {}
665                WatcherVerdict::Reject { reason } => {
666                    return Err(ContextError::Rejected { reason });
667                }
668            }
669        }
670        self.messages.push(msg);
671        Ok(())
672    }
673
674    /// Insert a message at the given position.
675    ///
676    /// Calls `on_inject` on all registered watchers. Returns
677    /// [`ContextError::Rejected`] if any watcher rejects, or
678    /// [`ContextError::OutOfBounds`] if the index exceeds the current length.
679    pub fn insert(&mut self, msg: Message, pos: Position) -> Result<(), ContextError> {
680        for watcher in &self.watchers {
681            match watcher.on_inject(&msg, pos) {
682                WatcherVerdict::Allow => {}
683                WatcherVerdict::Reject { reason } => {
684                    return Err(ContextError::Rejected { reason });
685                }
686            }
687        }
688        match pos {
689            Position::Back => self.messages.push(msg),
690            Position::Front => self.messages.insert(0, msg),
691            Position::At(idx) => {
692                if idx > self.messages.len() {
693                    return Err(ContextError::OutOfBounds {
694                        index: idx,
695                        len: self.messages.len(),
696                    });
697                }
698                self.messages.insert(idx, msg);
699            }
700        }
701        Ok(())
702    }
703
704    /// Keep the last `keep` messages, returning the removed ones.
705    ///
706    /// Fires compact watchers. Does not respect compaction policy —
707    /// use [`compact_by_policy`](Context::compact_by_policy) for that.
708    pub fn compact_truncate(&mut self, keep: usize) -> Vec<Message> {
709        if keep >= self.messages.len() {
710            return Vec::new();
711        }
712        let old_count = self.messages.len();
713        for watcher in &self.watchers {
714            watcher.on_pre_compact(old_count);
715        }
716        let split = self.messages.len() - keep;
717        let removed: Vec<Message> = self.messages.drain(..split).collect();
718        for watcher in &self.watchers {
719            watcher.on_post_compact(removed.len(), self.messages.len());
720        }
721        removed
722    }
723
724    /// Remove all messages with `CompactionPolicy::Normal`, keep `Pinned`.
725    ///
726    /// Returns the removed messages. Fires compact watchers.
727    pub fn compact_by_policy(&mut self) -> Vec<Message> {
728        let old_count = self.messages.len();
729        for watcher in &self.watchers {
730            watcher.on_pre_compact(old_count);
731        }
732        let mut kept = Vec::new();
733        let mut removed = Vec::new();
734        for msg in self.messages.drain(..) {
735            if matches!(msg.meta.policy, CompactionPolicy::Pinned) {
736                kept.push(msg);
737            } else {
738                removed.push(msg);
739            }
740        }
741        self.messages = kept;
742        for watcher in &self.watchers {
743            watcher.on_post_compact(removed.len(), self.messages.len());
744        }
745        removed
746    }
747
748    /// Compact using a caller-supplied closure.
749    ///
750    /// The closure receives `&[Message]` and returns the messages to keep.
751    /// Returns the removed messages. Fires compact watchers.
752    pub fn compact_with(&mut self, f: impl FnOnce(&[Message]) -> Vec<Message>) -> Vec<Message> {
753        let old_count = self.messages.len();
754        for watcher in &self.watchers {
755            watcher.on_pre_compact(old_count);
756        }
757        let new_messages = f(&self.messages);
758        let old = std::mem::replace(&mut self.messages, new_messages);
759        // Determine removed: old messages not in new set
760        let removed_count = old.len().saturating_sub(self.messages.len());
761        let removed = old;
762        for watcher in &self.watchers {
763            watcher.on_post_compact(removed_count, self.messages.len());
764        }
765        removed
766    }
767
768    /// Direct mutable access to the underlying message vector.
769    ///
770    /// **Bypasses all watcher checks.**
771    pub fn messages_mut(&mut self) -> &mut Vec<Message> {
772        &mut self.messages
773    }
774
775    /// Build a read-only snapshot for introspection or logging.
776    pub fn snapshot(&self) -> ContextSnapshot {
777        let estimated_tokens = self.estimated_tokens();
778        ContextSnapshot {
779            message_count: self.messages.len(),
780            message_metas: self.messages.iter().map(|m| m.meta.clone()).collect(),
781            has_system: self.messages.iter().any(|m| matches!(m.role, Role::System)),
782            operator_id: self.operator_id.clone(),
783            estimated_tokens,
784        }
785    }
786}
787
788#[cfg(test)]
789mod tests {
790    use super::*;
791    use std::fmt;
792    use std::sync::Arc;
793    use std::sync::atomic::{AtomicBool, Ordering};
794
795    type TestMsg = String;
796
797    fn make_msg(s: &str) -> ContextMessage<TestMsg> {
798        ContextMessage {
799            message: s.to_string(),
800            meta: MessageMeta::default(),
801        }
802    }
803
804    #[test]
805    fn new_context_is_empty() {
806        let ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("agent-1"));
807        assert!(ctx.is_empty());
808        assert_eq!(ctx.len(), 0);
809        assert!(ctx.messages().is_empty());
810    }
811
812    #[test]
813    fn inject_back_appends_in_order() {
814        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
815        ctx.inject(make_msg("first"), Position::Back).unwrap();
816        ctx.inject(make_msg("second"), Position::Back).unwrap();
817        assert_eq!(ctx.messages()[0].message, "first");
818        assert_eq!(ctx.messages()[1].message, "second");
819    }
820
821    #[test]
822    fn inject_front_prepends() {
823        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
824        ctx.inject(make_msg("first"), Position::Back).unwrap();
825        ctx.inject(make_msg("second"), Position::Front).unwrap();
826        assert_eq!(ctx.messages()[0].message, "second");
827        assert_eq!(ctx.messages()[1].message, "first");
828    }
829
830    #[test]
831    fn inject_at_inserts_at_index() {
832        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
833        ctx.inject(make_msg("a"), Position::Back).unwrap();
834        ctx.inject(make_msg("c"), Position::Back).unwrap();
835        ctx.inject(make_msg("b"), Position::At(1)).unwrap();
836        assert_eq!(ctx.messages()[0].message, "a");
837        assert_eq!(ctx.messages()[1].message, "b");
838        assert_eq!(ctx.messages()[2].message, "c");
839    }
840
841    #[test]
842    fn inject_out_of_bounds_returns_error() {
843        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
844        let err = ctx.inject(make_msg("x"), Position::At(5)).unwrap_err();
845        assert!(matches!(
846            err,
847            ContextError::OutOfBounds { index: 5, len: 0 }
848        ));
849        // Context must remain unchanged after the error.
850        assert!(ctx.is_empty());
851    }
852
853    #[test]
854    fn truncate_back_removes_from_end() {
855        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
856        ctx.inject(make_msg("a"), Position::Back).unwrap();
857        ctx.inject(make_msg("b"), Position::Back).unwrap();
858        ctx.inject(make_msg("c"), Position::Back).unwrap();
859
860        let removed = ctx.truncate_back(2).unwrap();
861        assert_eq!(removed.len(), 2);
862        assert_eq!(removed[0].message, "b");
863        assert_eq!(removed[1].message, "c");
864        assert_eq!(ctx.len(), 1);
865        assert_eq!(ctx.messages()[0].message, "a");
866    }
867
868    #[test]
869    fn truncate_back_out_of_bounds_returns_error() {
870        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
871        ctx.inject(make_msg("a"), Position::Back).unwrap();
872        let err = ctx.truncate_back(5).unwrap_err();
873        assert!(matches!(
874            err,
875            ContextError::OutOfBounds { index: 5, len: 1 }
876        ));
877        assert_eq!(ctx.len(), 1); // unchanged
878    }
879
880    #[test]
881    fn truncate_front_removes_from_start() {
882        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
883        ctx.inject(make_msg("a"), Position::Back).unwrap();
884        ctx.inject(make_msg("b"), Position::Back).unwrap();
885        ctx.inject(make_msg("c"), Position::Back).unwrap();
886
887        let removed = ctx.truncate_front(2).unwrap();
888        assert_eq!(removed.len(), 2);
889        assert_eq!(removed[0].message, "a");
890        assert_eq!(removed[1].message, "b");
891        assert_eq!(ctx.len(), 1);
892        assert_eq!(ctx.messages()[0].message, "c");
893    }
894
895    #[test]
896    fn truncate_front_out_of_bounds_returns_error() {
897        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
898        ctx.inject(make_msg("a"), Position::Back).unwrap();
899        let err = ctx.truncate_front(5).unwrap_err();
900        assert!(matches!(
901            err,
902            ContextError::OutOfBounds { index: 5, len: 1 }
903        ));
904        assert_eq!(ctx.len(), 1); // unchanged
905    }
906
907    #[test]
908    fn watcher_can_reject_inject() {
909        struct RejectAll;
910
911        impl ContextWatcher for RejectAll {
912            fn on_inject(&self, _msg: &dyn fmt::Debug, _pos: Position) -> WatcherVerdict {
913                WatcherVerdict::Reject {
914                    reason: "policy violation".into(),
915                }
916            }
917        }
918
919        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
920        ctx.add_watcher(Arc::new(RejectAll));
921
922        let err = ctx.inject(make_msg("blocked"), Position::Back).unwrap_err();
923        assert!(matches!(err, ContextError::Rejected { .. }));
924        // Injection must have been rolled back.
925        assert!(ctx.is_empty());
926    }
927
928    #[test]
929    fn snapshot_captures_state() {
930        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("my-agent"));
931        ctx.set_system("You are helpful.");
932        ctx.inject(make_msg("hello"), Position::Back).unwrap();
933
934        let snap = ctx.snapshot();
935        assert_eq!(snap.message_count, 1);
936        assert!(snap.has_system);
937        assert_eq!(snap.operator_id.as_str(), "my-agent");
938        assert_eq!(snap.message_metas.len(), 1);
939    }
940
941    #[test]
942    fn transform_increments_version() {
943        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
944        ctx.inject(make_msg("msg"), Position::Back).unwrap();
945        assert_eq!(ctx.messages()[0].meta.version, 0);
946
947        ctx.transform(|_| {});
948        assert_eq!(ctx.messages()[0].meta.version, 1);
949
950        ctx.transform(|_| {});
951        assert_eq!(ctx.messages()[0].meta.version, 2);
952    }
953
954    #[test]
955    fn replace_messages_fires_compact_watchers() {
956        let pre_called = Arc::new(AtomicBool::new(false));
957        let post_called = Arc::new(AtomicBool::new(false));
958
959        struct CompactWatcher {
960            pre: Arc<AtomicBool>,
961            post: Arc<AtomicBool>,
962        }
963
964        impl ContextWatcher for CompactWatcher {
965            fn on_pre_compact(&self, _message_count: usize) -> WatcherVerdict {
966                self.pre.store(true, Ordering::SeqCst);
967                WatcherVerdict::Allow
968            }
969
970            fn on_post_compact(&self, _removed: usize, _remaining: usize) {
971                self.post.store(true, Ordering::SeqCst);
972            }
973        }
974
975        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
976        ctx.add_watcher(Arc::new(CompactWatcher {
977            pre: Arc::clone(&pre_called),
978            post: Arc::clone(&post_called),
979        }));
980
981        ctx.inject(make_msg("old"), Position::Back).unwrap();
982        let old = ctx.replace_messages(vec![make_msg("new")]).unwrap();
983
984        assert!(
985            pre_called.load(Ordering::SeqCst),
986            "on_pre_compact not called"
987        );
988        assert!(
989            post_called.load(Ordering::SeqCst),
990            "on_post_compact not called"
991        );
992        assert_eq!(old.len(), 1);
993        assert_eq!(old[0].message, "old");
994        assert_eq!(ctx.messages()[0].message, "new");
995    }
996
997    #[test]
998    fn remove_where_filters_correctly() {
999        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
1000        ctx.inject(make_msg("keep"), Position::Back).unwrap();
1001        ctx.inject(make_msg("remove_me"), Position::Back).unwrap();
1002        ctx.inject(make_msg("also keep"), Position::Back).unwrap();
1003
1004        let removed = ctx.remove_where(|m| m.message.contains("remove")).unwrap();
1005        assert_eq!(removed.len(), 1);
1006        assert_eq!(removed[0].message, "remove_me");
1007        assert_eq!(ctx.len(), 2);
1008        assert_eq!(ctx.messages()[0].message, "keep");
1009        assert_eq!(ctx.messages()[1].message, "also keep");
1010    }
1011
1012    #[test]
1013    fn extract_is_non_destructive() {
1014        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
1015        ctx.inject(make_msg("a"), Position::Back).unwrap();
1016        ctx.inject(make_msg("b"), Position::Back).unwrap();
1017        ctx.inject(make_msg("c"), Position::Back).unwrap();
1018
1019        let found = ctx.extract(|m| m.message != "b");
1020        assert_eq!(found.len(), 2);
1021        assert_eq!(found[0].message, "a");
1022        assert_eq!(found[1].message, "c");
1023        // Context must be unchanged.
1024        assert_eq!(ctx.len(), 3);
1025    }
1026
1027    #[test]
1028    fn system_prompt_lifecycle() {
1029        let mut ctx: OperatorContext<TestMsg> = OperatorContext::new(OperatorId::from("a"));
1030        assert!(ctx.system().is_none());
1031
1032        ctx.set_system("Hello, system!");
1033        assert_eq!(ctx.system(), Some("Hello, system!"));
1034
1035        ctx.clear_system();
1036        assert!(ctx.system().is_none());
1037    }
1038
1039    #[test]
1040    fn message_construction_and_role_variants() {
1041        use crate::content::Content;
1042        use crate::lifecycle::CompactionPolicy;
1043
1044        let msg = Message {
1045            role: Role::User,
1046            content: Content::text("hello"),
1047            meta: MessageMeta::default(),
1048        };
1049        assert!(matches!(msg.role, Role::User));
1050
1051        let tool_msg = Message {
1052            role: Role::Tool {
1053                name: "shell".into(),
1054                call_id: "tc_1".into(),
1055            },
1056            content: Content::text("output"),
1057            meta: MessageMeta::default(),
1058        };
1059        assert!(matches!(tool_msg.role, Role::Tool { .. }));
1060
1061        let pinned = Message::pinned(Role::System, Content::text("system"));
1062        assert!(matches!(pinned.meta.policy, CompactionPolicy::Pinned));
1063    }
1064
1065    #[test]
1066    fn message_serde_roundtrip() {
1067        use crate::content::Content;
1068
1069        let msg = Message {
1070            role: Role::Assistant,
1071            content: Content::text("hi"),
1072            meta: MessageMeta::default(),
1073        };
1074        let json = serde_json::to_string(&msg).unwrap();
1075        let rt: Message = serde_json::from_str(&json).unwrap();
1076        assert!(matches!(rt.role, Role::Assistant));
1077    }
1078
1079    #[test]
1080    fn message_estimated_tokens() {
1081        use crate::content::Content;
1082
1083        // 20 chars / 4 = 5, + 4 overhead = 9
1084        let msg = Message::new(Role::User, Content::text("12345678901234567890"));
1085        assert_eq!(msg.estimated_tokens(), 9);
1086    }
1087
1088    #[test]
1089    fn message_text_content_extraction() {
1090        use crate::content::Content;
1091
1092        let msg = Message::new(Role::User, Content::text("hello world"));
1093        assert_eq!(msg.text_content(), "hello world");
1094    }
1095
1096    // --- Context tests (Phase 2) ---
1097
1098    #[test]
1099    fn context_push_and_read() {
1100        use crate::content::Content;
1101
1102        let mut ctx = Context::new(OperatorId::from("agent-1"));
1103        ctx.push(Message::new(Role::User, Content::text("hello")))
1104            .unwrap();
1105        ctx.push(Message::new(Role::Assistant, Content::text("hi")))
1106            .unwrap();
1107        assert_eq!(ctx.len(), 2);
1108        assert!(matches!(ctx.messages()[0].role, Role::User));
1109        assert!(matches!(ctx.messages()[1].role, Role::Assistant));
1110    }
1111
1112    #[test]
1113    fn context_compact_truncate() {
1114        use crate::content::Content;
1115
1116        let mut ctx = Context::new(OperatorId::from("a"));
1117        for i in 0..10 {
1118            ctx.push(Message::new(
1119                Role::User,
1120                Content::text(format!("msg {}", i)),
1121            ))
1122            .unwrap();
1123        }
1124        let removed = ctx.compact_truncate(3);
1125        assert_eq!(removed.len(), 7);
1126        assert_eq!(ctx.len(), 3);
1127    }
1128
1129    #[test]
1130    fn context_compact_by_policy_preserves_pinned() {
1131        use crate::content::Content;
1132
1133        let mut ctx = Context::new(OperatorId::from("a"));
1134        ctx.push(Message::pinned(
1135            Role::System,
1136            Content::text("you are helpful"),
1137        ))
1138        .unwrap();
1139        for i in 0..5 {
1140            ctx.push(Message::new(
1141                Role::User,
1142                Content::text(format!("msg {}", i)),
1143            ))
1144            .unwrap();
1145        }
1146        let removed = ctx.compact_by_policy();
1147        assert_eq!(ctx.len(), 1);
1148        assert!(matches!(ctx.messages()[0].role, Role::System));
1149        assert_eq!(removed.len(), 5);
1150    }
1151
1152    #[test]
1153    fn context_compact_with_closure() {
1154        use crate::content::Content;
1155
1156        let mut ctx = Context::new(OperatorId::from("a"));
1157        for i in 0..6 {
1158            ctx.push(Message::new(
1159                Role::User,
1160                Content::text(format!("msg {}", i)),
1161            ))
1162            .unwrap();
1163        }
1164        let removed = ctx.compact_with(|msgs| {
1165            msgs.iter()
1166                .enumerate()
1167                .filter(|(i, _)| i % 2 == 0)
1168                .map(|(_, m)| m.clone())
1169                .collect()
1170        });
1171        assert_eq!(ctx.len(), 3);
1172        // compact_with returns the old messages, not the removed ones
1173        assert_eq!(removed.len(), 6);
1174    }
1175
1176    #[test]
1177    fn context_snapshot() {
1178        use crate::content::Content;
1179
1180        let mut ctx = Context::new(OperatorId::from("my-agent"));
1181        ctx.push(Message::pinned(Role::System, Content::text("system")))
1182            .unwrap();
1183        ctx.push(Message::new(Role::User, Content::text("hello")))
1184            .unwrap();
1185
1186        let snap = ctx.snapshot();
1187        assert_eq!(snap.message_count, 2);
1188        assert!(snap.has_system);
1189        assert_eq!(snap.operator_id.as_str(), "my-agent");
1190        assert_eq!(snap.message_metas.len(), 2);
1191    }
1192
1193    #[test]
1194    fn context_estimated_tokens() {
1195        use crate::content::Content;
1196
1197        let mut ctx = Context::new(OperatorId::from("a"));
1198        // 20 chars / 4 = 5, + 4 overhead = 9 per message
1199        ctx.push(Message::new(
1200            Role::User,
1201            Content::text("12345678901234567890"),
1202        ))
1203        .unwrap();
1204        ctx.push(Message::new(
1205            Role::User,
1206            Content::text("12345678901234567890"),
1207        ))
1208        .unwrap();
1209        assert_eq!(ctx.estimated_tokens(), 18);
1210    }
1211}