Skip to main content

orion_core/
context.rs

1use crate::error::{CoreError, CoreResult};
2use crate::messages::{Message, Role};
3use crate::template::ChatTemplate;
4use crate::tools::ToolSchema;
5
/// How the context pipeline reacts when a conversation no longer fits the
/// token budget.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PruneStrategy {
    /// Discard the oldest turns first, so that the system prompt and the most
    /// recent turns remain.
    SlidingWindow,
    /// Fold the oldest turns into one pinned summary message rather than
    /// discarding them outright. Producing the summary is the agent's job,
    /// because it requires the LLM backend; once the summary exists, the
    /// context pipeline still applies sliding-window pruning.
    Summarize,
}
17
18/// Configuration for context management.
19#[derive(Debug, Clone)]
20pub struct ContextConfig {
21    /// Total context window size in tokens (prompt + reserved response).
22    pub max_context_tokens: u32,
23    /// Tokens reserved for the response; deducted from the prune budget.
24    pub max_response_tokens: u32,
25    /// How to handle a conversation that overflows the budget.
26    pub prune_strategy: PruneStrategy,
27}
28
29impl Default for ContextConfig {
30    fn default() -> Self {
31        Self {
32            max_context_tokens: 4096,
33            max_response_tokens: 2048,
34            prune_strategy: PruneStrategy::SlidingWindow,
35        }
36    }
37}
38
/// Outcome of preparing a conversation for the backend.
///
/// Implements `PartialEq`/`Eq` so results can be compared directly (for
/// example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreparedContext {
    /// The fully formatted prompt string to feed the backend.
    pub prompt: String,
    /// Token count of `prompt`, as reported by the caller-supplied token
    /// counter applied to the whole prompt string.
    pub token_count: u32,
    /// Number of conversation messages that were kept in the prompt.
    pub messages_included: u32,
    /// Number of conversation messages that were left out of the prompt
    /// (total messages supplied minus `messages_included`).
    pub messages_pruned: u32,
}
51
/// Which turns survive pruning and which are dropped, as index ranges into the
/// messages slice (in original order). Produced by [`plan_prune`]; the agent
/// uses `dropped` to decide what to summarize under [`PruneStrategy::Summarize`].
///
/// Implements `PartialEq`/`Eq` so plans can be compared directly, and
/// `Default` for the empty plan (nothing kept, nothing dropped).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PrunePlan {
    /// Turns that survive pruning, as half-open index ranges into the
    /// messages slice.
    pub kept: Vec<std::ops::Range<usize>>,
    /// Turns that are dropped to fit the budget, as half-open index ranges
    /// into the messages slice.
    pub dropped: Vec<std::ops::Range<usize>>,
}
62
63/// Group conversation messages into turns for pair-wise pruning.
64///
65/// A turn starts with a User message and includes all subsequent non-User
66/// messages (Assistant, ToolCall, ToolResult) until the next User message.
67/// Returns index ranges into the messages slice.
68fn group_into_turns(messages: &[Message]) -> Vec<std::ops::Range<usize>> {
69    let mut turns = Vec::new();
70    let mut turn_start: Option<usize> = None;
71
72    for (i, msg) in messages.iter().enumerate() {
73        if msg.role == Role::User {
74            if let Some(start) = turn_start {
75                turns.push(start..i);
76            }
77            turn_start = Some(i);
78        }
79    }
80    if let Some(start) = turn_start {
81        turns.push(start..messages.len());
82    }
83
84    turns
85}
86
87/// Plan which turns survive pruning to fit the token budget.
88///
89/// 1. Deducts system prompt + tools + assistant-prefix overhead from the budget
90/// 2. Groups messages into turns (user + following non-user messages)
91/// 3. Always keeps the most recent turn and every *pinned* turn
92/// 4. Fills the remaining budget with the most-recent non-pinned turns backward
93///
94/// A turn is pinned if any of its messages is `pinned`. Returns
95/// `CoreError::Context` if the system block, the latest turn, or the pinned
96/// turns alone exceed the available budget.
97pub fn plan_prune(
98    template: &dyn ChatTemplate,
99    system_prompt: &str,
100    messages: &[Message],
101    tools: &[ToolSchema],
102    config: &ContextConfig,
103    token_counter: &dyn Fn(&str) -> u32,
104) -> CoreResult<PrunePlan> {
105    let available = config
106        .max_context_tokens
107        .saturating_sub(config.max_response_tokens);
108
109    // Fixed overhead: system block (system prompt + tools) + assistant prefix.
110    let system_block = template.format_system(system_prompt, tools);
111    let fixed_overhead = token_counter(&system_block) + token_counter(template.assistant_prefix());
112
113    if fixed_overhead >= available {
114        return Err(CoreError::Context(format!(
115            "System prompt and tools ({fixed_overhead} tokens) exceed \
116             available context budget ({available} tokens)"
117        )));
118    }
119    let mut budget = available - fixed_overhead;
120
121    let turns = group_into_turns(messages);
122    if turns.is_empty() {
123        return Ok(PrunePlan {
124            kept: vec![],
125            dropped: vec![],
126        });
127    }
128
129    let turn_costs: Vec<u32> = turns
130        .iter()
131        .map(|range| {
132            messages[range.clone()]
133                .iter()
134                .map(|msg| token_counter(&template.format_message(msg)))
135                .sum()
136        })
137        .collect();
138    let turn_pinned: Vec<bool> = turns
139        .iter()
140        .map(|range| messages[range.clone()].iter().any(|m| m.pinned))
141        .collect();
142
143    let last = turns.len() - 1;
144    let mut keep = vec![false; turns.len()];
145
146    // The latest turn must fit — otherwise context overflow.
147    if turn_costs[last] > budget {
148        return Err(CoreError::Context(format!(
149            "Latest message ({} tokens) plus system prompt \
150             ({fixed_overhead} tokens) exceeds context budget ({available} tokens). \
151             Clear the conversation or increase context size.",
152            turn_costs[last]
153        )));
154    }
155    budget -= turn_costs[last];
156    keep[last] = true;
157
158    // Pinned turns always survive, regardless of recency.
159    for i in 0..last {
160        if turn_pinned[i] {
161            if turn_costs[i] > budget {
162                let pinned_total: u32 = (0..turns.len())
163                    .filter(|&j| turn_pinned[j])
164                    .map(|j| turn_costs[j])
165                    .sum();
166                return Err(CoreError::Context(format!(
167                    "Pinned messages ({pinned_total} tokens) exceed the available \
168                     context budget ({available} tokens). Unpin some messages or \
169                     increase context size."
170                )));
171            }
172            budget -= turn_costs[i];
173            keep[i] = true;
174        }
175    }
176
177    // Fill the remaining budget with the most-recent non-pinned turns, walking
178    // backward. Stop at the first non-pinned turn that doesn't fit (sliding
179    // window); already-pinned turns are skipped without stopping the walk.
180    for i in (0..last).rev() {
181        if keep[i] {
182            continue;
183        }
184        if turn_costs[i] <= budget {
185            budget -= turn_costs[i];
186            keep[i] = true;
187        } else {
188            break;
189        }
190    }
191
192    let mut kept = Vec::new();
193    let mut dropped = Vec::new();
194    for (i, range) in turns.iter().enumerate() {
195        if keep[i] {
196            kept.push(range.clone());
197        } else {
198            dropped.push(range.clone());
199        }
200    }
201    Ok(PrunePlan { kept, dropped })
202}
203
204/// Prepare context: prune to fit the budget, apply the template, return the
205/// formatted prompt. Thin wrapper over [`plan_prune`] that formats the kept
206/// turns. Pinned messages always survive (see `plan_prune`).
207///
208/// The agent calls this automatically before each LLM call; call it directly
209/// only when you want custom control.
210///
211/// ```
212/// use orion_core::{ChatMLTemplate, Message};
213/// use orion_core::context::{prepare_context, ContextConfig};
214///
215/// // A real backend tokenizes; here we approximate with a word count.
216/// let token_counter = |text: &str| -> u32 { text.split_whitespace().count() as u32 };
217/// let messages = vec![
218///     Message::user("1", "Hello"),
219///     Message::assistant("2", "Hi there!"),
220/// ];
221///
222/// let prepared = prepare_context(
223///     &ChatMLTemplate,           // any `ChatTemplate` impl
224///     "You are helpful.",        // system prompt
225///     &messages,                 // full conversation history
226///     &[],                       // tool schemas to inject (may be empty)
227///     &ContextConfig::default(),
228///     &token_counter,
229/// )?;
230///
231/// assert!(prepared.prompt.contains("Hi there!"));
232/// assert_eq!(prepared.messages_included, 2);
233/// assert_eq!(prepared.messages_pruned, 0);
234/// # Ok::<(), orion_core::CoreError>(())
235/// ```
236pub fn prepare_context(
237    template: &dyn ChatTemplate,
238    system_prompt: &str,
239    messages: &[Message],
240    tools: &[ToolSchema],
241    config: &ContextConfig,
242    token_counter: &dyn Fn(&str) -> u32,
243) -> CoreResult<PreparedContext> {
244    let plan = plan_prune(
245        template,
246        system_prompt,
247        messages,
248        tools,
249        config,
250        token_counter,
251    )?;
252
253    // Collect kept messages in original order (kept ranges may be non-contiguous
254    // when an old pinned turn survives alongside the recent window).
255    let kept: Vec<Message> = plan
256        .kept
257        .iter()
258        .flat_map(|range| messages[range.clone()].iter().cloned())
259        .collect();
260    let kept_count = kept.len() as u32;
261    let pruned = messages.len() as u32 - kept_count;
262
263    let prompt = template.format(system_prompt, &kept, tools);
264    let token_count = token_counter(&prompt);
265
266    Ok(PreparedContext {
267        prompt,
268        token_count,
269        messages_included: kept_count,
270        messages_pruned: pruned,
271    })
272}