orion_core/context.rs
1use crate::error::{CoreError, CoreResult};
2use crate::messages::{Message, Role};
3use crate::template::ChatTemplate;
4use crate::tools::ToolSchema;
5
/// How the context pipeline reacts when a conversation no longer fits the
/// token budget.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PruneStrategy {
    /// Discard the oldest message pairs, keeping the system block and the
    /// most recent turns.
    SlidingWindow,
    /// Replace the oldest turns with a single pinned summary message rather
    /// than discarding them outright. Producing the summary is the agent's
    /// job (it requires the LLM backend); once the summary is in place the
    /// context pipeline still prunes with a sliding window.
    Summarize,
}
17
18/// Configuration for context management.
19#[derive(Debug, Clone)]
20pub struct ContextConfig {
21 /// Total context window size in tokens (prompt + reserved response).
22 pub max_context_tokens: u32,
23 /// Tokens reserved for the response; deducted from the prune budget.
24 pub max_response_tokens: u32,
25 /// How to handle a conversation that overflows the budget.
26 pub prune_strategy: PruneStrategy,
27}
28
29impl Default for ContextConfig {
30 fn default() -> Self {
31 Self {
32 max_context_tokens: 4096,
33 max_response_tokens: 2048,
34 prune_strategy: PruneStrategy::SlidingWindow,
35 }
36 }
37}
38
/// Output of context preparation: the rendered prompt plus bookkeeping about
/// what was kept and what was dropped.
#[derive(Clone, Debug)]
pub struct PreparedContext {
    /// Fully formatted prompt string, ready to hand to the backend.
    pub prompt: String,
    /// Token count of the whole `prompt`.
    pub token_count: u32,
    /// How many conversation messages made it into the prompt.
    pub messages_included: u32,
    /// How many conversation messages were left out to fit the budget.
    pub messages_pruned: u32,
}
51
/// The outcome of prune planning: which turns stay and which go, each given
/// as an index range into the messages slice, in original order.
///
/// Produced by [`plan_prune`]. Under [`PruneStrategy::Summarize`] the agent
/// reads `dropped` to decide what to summarize.
#[derive(Clone, Debug)]
pub struct PrunePlan {
    /// Index ranges of the turns that survive pruning.
    pub kept: Vec<std::ops::Range<usize>>,
    /// Index ranges of the turns removed to fit the budget.
    pub dropped: Vec<std::ops::Range<usize>>,
}
62
63/// Group conversation messages into turns for pair-wise pruning.
64///
65/// A turn starts with a User message and includes all subsequent non-User
66/// messages (Assistant, ToolCall, ToolResult) until the next User message.
67/// Returns index ranges into the messages slice.
68fn group_into_turns(messages: &[Message]) -> Vec<std::ops::Range<usize>> {
69 let mut turns = Vec::new();
70 let mut turn_start: Option<usize> = None;
71
72 for (i, msg) in messages.iter().enumerate() {
73 if msg.role == Role::User {
74 if let Some(start) = turn_start {
75 turns.push(start..i);
76 }
77 turn_start = Some(i);
78 }
79 }
80 if let Some(start) = turn_start {
81 turns.push(start..messages.len());
82 }
83
84 turns
85}
86
87/// Plan which turns survive pruning to fit the token budget.
88///
89/// 1. Deducts system prompt + tools + assistant-prefix overhead from the budget
90/// 2. Groups messages into turns (user + following non-user messages)
91/// 3. Always keeps the most recent turn and every *pinned* turn
92/// 4. Fills the remaining budget with the most-recent non-pinned turns backward
93///
94/// A turn is pinned if any of its messages is `pinned`. Returns
95/// `CoreError::Context` if the system block, the latest turn, or the pinned
96/// turns alone exceed the available budget.
97pub fn plan_prune(
98 template: &dyn ChatTemplate,
99 system_prompt: &str,
100 messages: &[Message],
101 tools: &[ToolSchema],
102 config: &ContextConfig,
103 token_counter: &dyn Fn(&str) -> u32,
104) -> CoreResult<PrunePlan> {
105 let available = config
106 .max_context_tokens
107 .saturating_sub(config.max_response_tokens);
108
109 // Fixed overhead: system block (system prompt + tools) + assistant prefix.
110 let system_block = template.format_system(system_prompt, tools);
111 let fixed_overhead = token_counter(&system_block) + token_counter(template.assistant_prefix());
112
113 if fixed_overhead >= available {
114 return Err(CoreError::Context(format!(
115 "System prompt and tools ({fixed_overhead} tokens) exceed \
116 available context budget ({available} tokens)"
117 )));
118 }
119 let mut budget = available - fixed_overhead;
120
121 let turns = group_into_turns(messages);
122 if turns.is_empty() {
123 return Ok(PrunePlan {
124 kept: vec![],
125 dropped: vec![],
126 });
127 }
128
129 let turn_costs: Vec<u32> = turns
130 .iter()
131 .map(|range| {
132 messages[range.clone()]
133 .iter()
134 .map(|msg| token_counter(&template.format_message(msg)))
135 .sum()
136 })
137 .collect();
138 let turn_pinned: Vec<bool> = turns
139 .iter()
140 .map(|range| messages[range.clone()].iter().any(|m| m.pinned))
141 .collect();
142
143 let last = turns.len() - 1;
144 let mut keep = vec![false; turns.len()];
145
146 // The latest turn must fit — otherwise context overflow.
147 if turn_costs[last] > budget {
148 return Err(CoreError::Context(format!(
149 "Latest message ({} tokens) plus system prompt \
150 ({fixed_overhead} tokens) exceeds context budget ({available} tokens). \
151 Clear the conversation or increase context size.",
152 turn_costs[last]
153 )));
154 }
155 budget -= turn_costs[last];
156 keep[last] = true;
157
158 // Pinned turns always survive, regardless of recency.
159 for i in 0..last {
160 if turn_pinned[i] {
161 if turn_costs[i] > budget {
162 let pinned_total: u32 = (0..turns.len())
163 .filter(|&j| turn_pinned[j])
164 .map(|j| turn_costs[j])
165 .sum();
166 return Err(CoreError::Context(format!(
167 "Pinned messages ({pinned_total} tokens) exceed the available \
168 context budget ({available} tokens). Unpin some messages or \
169 increase context size."
170 )));
171 }
172 budget -= turn_costs[i];
173 keep[i] = true;
174 }
175 }
176
177 // Fill the remaining budget with the most-recent non-pinned turns, walking
178 // backward. Stop at the first non-pinned turn that doesn't fit (sliding
179 // window); already-pinned turns are skipped without stopping the walk.
180 for i in (0..last).rev() {
181 if keep[i] {
182 continue;
183 }
184 if turn_costs[i] <= budget {
185 budget -= turn_costs[i];
186 keep[i] = true;
187 } else {
188 break;
189 }
190 }
191
192 let mut kept = Vec::new();
193 let mut dropped = Vec::new();
194 for (i, range) in turns.iter().enumerate() {
195 if keep[i] {
196 kept.push(range.clone());
197 } else {
198 dropped.push(range.clone());
199 }
200 }
201 Ok(PrunePlan { kept, dropped })
202}
203
204/// Prepare context: prune to fit the budget, apply the template, return the
205/// formatted prompt. Thin wrapper over [`plan_prune`] that formats the kept
206/// turns. Pinned messages always survive (see `plan_prune`).
207///
208/// The agent calls this automatically before each LLM call; call it directly
209/// only when you want custom control.
210///
211/// ```
212/// use orion_core::{ChatMLTemplate, Message};
213/// use orion_core::context::{prepare_context, ContextConfig};
214///
215/// // A real backend tokenizes; here we approximate with a word count.
216/// let token_counter = |text: &str| -> u32 { text.split_whitespace().count() as u32 };
217/// let messages = vec![
218/// Message::user("1", "Hello"),
219/// Message::assistant("2", "Hi there!"),
220/// ];
221///
222/// let prepared = prepare_context(
223/// &ChatMLTemplate, // any `ChatTemplate` impl
224/// "You are helpful.", // system prompt
225/// &messages, // full conversation history
226/// &[], // tool schemas to inject (may be empty)
227/// &ContextConfig::default(),
228/// &token_counter,
229/// )?;
230///
231/// assert!(prepared.prompt.contains("Hi there!"));
232/// assert_eq!(prepared.messages_included, 2);
233/// assert_eq!(prepared.messages_pruned, 0);
234/// # Ok::<(), orion_core::CoreError>(())
235/// ```
236pub fn prepare_context(
237 template: &dyn ChatTemplate,
238 system_prompt: &str,
239 messages: &[Message],
240 tools: &[ToolSchema],
241 config: &ContextConfig,
242 token_counter: &dyn Fn(&str) -> u32,
243) -> CoreResult<PreparedContext> {
244 let plan = plan_prune(
245 template,
246 system_prompt,
247 messages,
248 tools,
249 config,
250 token_counter,
251 )?;
252
253 // Collect kept messages in original order (kept ranges may be non-contiguous
254 // when an old pinned turn survives alongside the recent window).
255 let kept: Vec<Message> = plan
256 .kept
257 .iter()
258 .flat_map(|range| messages[range.clone()].iter().cloned())
259 .collect();
260 let kept_count = kept.len() as u32;
261 let pruned = messages.len() as u32 - kept_count;
262
263 let prompt = template.format(system_prompt, &kept, tools);
264 let token_count = token_counter(&prompt);
265
266 Ok(PreparedContext {
267 prompt,
268 token_count,
269 messages_included: kept_count,
270 messages_pruned: pruned,
271 })
272}