1use std::cmp::Ordering;
10
11use crate::budget::BudgetPolicy;
12use crate::capability::CapabilityProbe;
13use crate::embedder::{cosine, Embedder};
14use crate::metrics::ContextCompilerMetrics;
15use crate::relevance::{HeuristicScorer, RelevanceScore, RelevanceScorer};
16use crate::segment::{Role, Segment, SegmentKind};
17use crate::summarizer::{AnchoredSummary, Summarizer};
18use crate::{ContextCompilerEvent, ContextEmissionSink, SinkRef};
19use ainl_compression::{compress, EfficientMode};
20use ainl_contracts::CognitiveVitals;
21use serde::{Deserialize, Serialize};
22use std::sync::Arc;
23use std::time::Instant;
24use tracing::{debug, warn};
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ComposedPrompt {
29 pub segments: Vec<Segment>,
31 pub anchored_summary: AnchoredSummary,
33 pub telemetry: ContextCompilerMetrics,
35}
36
37pub struct ContextCompiler {
42 scorer: Arc<dyn RelevanceScorer>,
43 budget: BudgetPolicy,
44 summarizer: Option<Arc<dyn Summarizer>>,
45 embedder: Option<Arc<dyn Embedder>>,
46 sink: SinkRef,
47}
48
49impl ContextCompiler {
50 #[must_use]
52 pub fn new(scorer: Arc<dyn RelevanceScorer>, budget: BudgetPolicy) -> Self {
53 Self {
54 scorer,
55 budget,
56 summarizer: None,
57 embedder: None,
58 sink: None,
59 }
60 }
61
62 #[must_use]
64 pub fn with_defaults() -> Self {
65 Self::new(Arc::new(HeuristicScorer::new()), BudgetPolicy::default())
66 }
67
68 #[must_use]
70 pub fn with_summarizer(mut self, summarizer: Arc<dyn Summarizer>) -> Self {
71 self.summarizer = Some(summarizer);
72 self
73 }
74
75 #[must_use]
77 pub fn with_sink(mut self, sink: Arc<dyn ContextEmissionSink>) -> Self {
78 self.sink = Some(sink);
79 self
80 }
81
82 #[must_use]
84 pub fn with_embedder(mut self, embedder: Arc<dyn Embedder>) -> Self {
85 self.embedder = Some(embedder);
86 self
87 }
88
89 #[must_use]
91 pub fn probe(&self) -> CapabilityProbe {
92 CapabilityProbe {
93 summarizer: self.summarizer.is_some(),
94 embedder: self.embedder.is_some(),
95 }
96 }
97
98 fn emit(&self, event: ContextCompilerEvent) {
99 if let Some(sink) = &self.sink {
100 sink.emit(event);
101 }
102 }
103
104 pub fn compose(
117 &self,
118 latest_user_query: &str,
119 segments: Vec<Segment>,
120 existing_summary: Option<&AnchoredSummary>,
121 vitals: Option<&CognitiveVitals>,
122 ) -> ComposedPrompt {
123 let t0 = Instant::now();
124 let probe = self.probe();
125 let tier = probe.active_tier();
126 self.emit(ContextCompilerEvent::TierSelected {
127 tier,
128 reason: probe.reason(),
129 });
130
131 let mut metrics = ContextCompilerMetrics::new(tier, self.budget.total_window);
132 let _low_trust = self.budget.vitals_aware && vitals.is_some_and(|v| v.trust < 0.5);
136 let default_mode = EfficientMode::Balanced;
137
138 let mut scored: Vec<(usize, RelevanceScore)> = segments
141 .iter()
142 .enumerate()
143 .map(|(idx, s)| (idx, self.scorer.score(s, latest_user_query, vitals)))
144 .collect();
145 scored.sort_by(|a, b| {
147 b.1 .0
148 .partial_cmp(&a.1 .0)
149 .unwrap_or(std::cmp::Ordering::Equal)
150 });
151
152 let mut flexible_budget = self.budget.flexible_budget();
155 self.emit(ContextCompilerEvent::BudgetAllocated {
156 total: self.budget.total_window,
157 per_kind: vec![
158 (SegmentKind::SystemPrompt, self.budget.system_budget()),
159 (SegmentKind::ToolDefinitions, self.budget.tool_def_budget()),
160 (SegmentKind::UserPrompt, self.budget.user_prompt_budget()),
161 ],
162 });
163
164 let recent_pin_threshold = self.budget.recent_turns_keep_verbatim as u32;
167 let pinned_idx: std::collections::HashSet<usize> = segments
168 .iter()
169 .enumerate()
170 .filter(|(_, s)| {
171 s.kind.is_always_keep()
172 || (s.kind == SegmentKind::RecentTurn && s.age_index < recent_pin_threshold)
173 || s.kind == SegmentKind::ToolDefinitions
174 })
175 .map(|(i, _)| i)
176 .collect();
177
178 if let Some(emb) = &self.embedder {
179 if let Ok(qv) = emb.embed(latest_user_query) {
180 let pinned_order: Vec<(usize, RelevanceScore)> = scored
181 .iter()
182 .filter(|(i, _)| pinned_idx.contains(i))
183 .copied()
184 .collect();
185 let mut unpin: Vec<(usize, RelevanceScore)> = scored
186 .iter()
187 .filter(|(i, _)| !pinned_idx.contains(i))
188 .copied()
189 .collect();
190 unpin.sort_by(|(ia, _), (ib, _)| {
191 let a_sim = emb
192 .embed(&segments[*ia].content)
193 .map(|v| cosine(&v, &qv))
194 .unwrap_or(0.0);
195 let b_sim = emb
196 .embed(&segments[*ib].content)
197 .map(|v| cosine(&v, &qv))
198 .unwrap_or(0.0);
199 b_sim.partial_cmp(&a_sim).unwrap_or(Ordering::Equal)
200 });
201 scored = pinned_order;
202 scored.extend(unpin);
203 }
204 }
205
206 let mut keep: Vec<Option<Segment>> = (0..segments.len()).map(|_| None).collect();
208 let mut summarizer_calls: u32 = 0;
209 let mut summarizer_failures: u32 = 0;
210 let mut dropped_for_summarization: Vec<Segment> = Vec::new();
211 let mut anchored = existing_summary
212 .cloned()
213 .unwrap_or_else(AnchoredSummary::empty);
214
215 for &(idx, _score) in scored
217 .iter()
218 .filter(|(i, _)| pinned_idx.contains(i))
219 .collect::<Vec<_>>()
220 .iter()
221 {
222 let original = &segments[*idx];
223 let original_tok = original.token_estimate();
224 keep[*idx] = Some(original.clone());
226 metrics.record_segment(original.kind, original_tok, original_tok, false);
227 self.emit(ContextCompilerEvent::BlockEmitted {
228 source: source_label(original.kind),
229 kind: original.kind,
230 original_tokens: original_tok,
231 kept_tokens: original_tok,
232 });
233 }
234
235 for &(idx, _score) in scored.iter().filter(|(i, _)| !pinned_idx.contains(i)) {
237 let seg = &segments[idx];
238 let original_tok = seg.token_estimate();
239
240 let mode = match seg.kind {
244 SegmentKind::ToolResult => EfficientMode::Aggressive,
245 SegmentKind::OlderTurn => default_mode,
246 SegmentKind::MemoryBlock | SegmentKind::AnchoredSummaryRecall => default_mode,
247 SegmentKind::RecentTurn => default_mode,
248 _ => EfficientMode::Off,
249 };
250
251 let compressed = if mode == EfficientMode::Off {
253 seg.content.clone()
254 } else {
255 compress(&seg.content, mode).text
256 };
257 let compressed_tok = ainl_compression::tokenize_estimate(&compressed);
258
259 if compressed_tok <= flexible_budget {
260 let mut kept = seg.clone();
261 kept.content = compressed;
262 keep[idx] = Some(kept);
263 flexible_budget = flexible_budget.saturating_sub(compressed_tok);
264 metrics.record_segment(seg.kind, original_tok, compressed_tok, false);
265 self.emit(ContextCompilerEvent::BlockEmitted {
266 source: source_label(seg.kind),
267 kind: seg.kind,
268 original_tokens: original_tok,
269 kept_tokens: compressed_tok,
270 });
271 } else {
272 if seg.kind == SegmentKind::OlderTurn {
274 dropped_for_summarization.push(seg.clone());
275 }
276 metrics.record_segment(seg.kind, original_tok, 0, true);
277 debug!(
278 kind = ?seg.kind,
279 original_tok,
280 flexible_budget,
281 "context_compiler: dropped (over budget)"
282 );
283 }
284 }
285
286 if let Some(summ) = &self.summarizer {
288 if !dropped_for_summarization.is_empty() {
289 let s0 = Instant::now();
290 summarizer_calls += 1;
291 match summ.summarize(&dropped_for_summarization, Some(&anchored)) {
292 Ok(new_summary) => {
293 let summary_tokens =
294 ainl_compression::tokenize_estimate(&new_summary.to_prompt_text());
295 anchored = new_summary;
296 anchored.token_estimate = summary_tokens;
297 anchored.iteration = anchored.iteration.saturating_add(1);
298 self.emit(ContextCompilerEvent::SummarizerInvoked {
299 duration_ms: s0.elapsed().as_millis() as u64,
300 segments_in: dropped_for_summarization.len(),
301 summary_tokens,
302 });
303 }
304 Err(e) => {
305 summarizer_failures += 1;
306 warn!(error = %e, "context_compiler: summarizer failed, degrading to Tier 0 for this turn");
307 self.emit(ContextCompilerEvent::SummarizerFailed {
308 duration_ms: s0.elapsed().as_millis() as u64,
309 error_kind: e.kind(),
310 });
311 }
312 }
313 }
314 }
315
316 let mut composed: Vec<Segment> = keep.into_iter().flatten().collect();
321
322 if !anchored.is_empty() {
325 let recall = Segment {
326 kind: SegmentKind::AnchoredSummaryRecall,
327 role: Role::System,
328 content: anchored.to_prompt_text(),
329 age_index: 0,
330 tool_name: None,
331 base_importance: 1.5,
332 #[cfg(feature = "freshness")]
333 freshness: None,
334 };
335 let insert_at = composed
337 .iter()
338 .position(|s| {
339 !matches!(s.kind, SegmentKind::SystemPrompt | SegmentKind::MemoryBlock)
340 })
341 .unwrap_or(composed.len());
342 composed.insert(insert_at, recall);
343 }
344
345 let total_kept_tokens: usize = composed.iter().map(|s| s.token_estimate()).sum();
347 if total_kept_tokens > self.budget.soft_total_cap {
348 self.emit(ContextCompilerEvent::BudgetExceeded {
349 overage: total_kept_tokens.saturating_sub(self.budget.soft_total_cap),
350 });
351 }
352
353 metrics.summarizer_calls = summarizer_calls;
354 metrics.summarizer_failures = summarizer_failures;
355 metrics.elapsed_ms = t0.elapsed().as_millis() as u64;
356
357 ComposedPrompt {
358 segments: composed,
359 anchored_summary: anchored,
360 telemetry: metrics,
361 }
362 }
363}
364
365const fn source_label(kind: SegmentKind) -> &'static str {
366 match kind {
367 SegmentKind::SystemPrompt => "system_prompt",
368 SegmentKind::OlderTurn => "older_turn",
369 SegmentKind::RecentTurn => "recent_turn",
370 SegmentKind::ToolDefinitions => "tool_definitions",
371 SegmentKind::ToolResult => "tool_result",
372 SegmentKind::UserPrompt => "user_prompt",
373 SegmentKind::AnchoredSummaryRecall => "anchored_summary_recall",
374 SegmentKind::MemoryBlock => "memory_block",
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::summarizer::SummarizerError;
    use std::sync::Mutex;

    /// Sink that records every emitted event so tests can inspect them afterwards.
    #[derive(Default)]
    struct CapturingSink {
        events: Mutex<Vec<ContextCompilerEvent>>,
    }

    impl ContextEmissionSink for CapturingSink {
        fn emit(&self, event: ContextCompilerEvent) {
            let mut recorded = self.events.lock().expect("lock");
            recorded.push(event);
        }
    }

    /// Summarizer stub that always succeeds and reports how many segments it was given.
    struct MockSummarizer;

    impl Summarizer for MockSummarizer {
        fn summarize(
            &self,
            segments: &[Segment],
            _existing: Option<&AnchoredSummary>,
        ) -> Result<AnchoredSummary, SummarizerError> {
            let mut summary = AnchoredSummary::empty();
            summary.sections[0].content = format!("Summarized {} segments.", segments.len());
            Ok(summary)
        }
    }

    /// Summarizer stub that always fails with a timeout.
    struct FailingSummarizer;

    impl Summarizer for FailingSummarizer {
        fn summarize(
            &self,
            _segments: &[Segment],
            _existing: Option<&AnchoredSummary>,
        ) -> Result<AnchoredSummary, SummarizerError> {
            Err(SummarizerError::Timeout)
        }
    }

    /// Builds `n` numbered sentences, each starting with `prefix`.
    fn long_text(prefix: &str, n: usize) -> String {
        (0..n)
            .map(|i| format!("{prefix} sentence {i}. "))
            .collect()
    }

    /// True when the composed prompt contains at least one segment of `kind`.
    fn has_kind(prompt: &ComposedPrompt, kind: SegmentKind) -> bool {
        prompt.segments.iter().any(|s| s.kind == kind)
    }

    #[test]
    fn tier0_compose_keeps_system_and_user_verbatim() {
        let query = "Help me debug a tokio runtime issue.";
        let input = vec![
            Segment::system_prompt("You are a helpful assistant."),
            Segment::user_prompt("Help me debug a tokio runtime issue."),
        ];

        let out = ContextCompiler::with_defaults().compose(query, input, None, None);

        assert_eq!(out.segments.len(), 2);
        assert!(has_kind(&out, SegmentKind::SystemPrompt));
        assert!(has_kind(&out, SegmentKind::UserPrompt));
        assert_eq!(out.telemetry.tier, "heuristic");
        assert_eq!(out.telemetry.summarizer_calls, 0);
    }

    #[test]
    fn tier0_compresses_long_older_turns_within_budget() {
        let policy = BudgetPolicy {
            total_window: 4_000,
            ..BudgetPolicy::default()
        };
        let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), policy);
        let verbose_turn = Segment::older_turn(
            Role::Assistant,
            long_text("rust borrow checker tokio", 200),
            10,
        );
        let input = vec![
            Segment::system_prompt("system"),
            verbose_turn,
            Segment::user_prompt("rust tokio"),
        ];

        let out = compiler.compose("rust tokio", input, None, None);

        assert!(has_kind(&out, SegmentKind::SystemPrompt));
        assert!(has_kind(&out, SegmentKind::UserPrompt));
        assert!(out.telemetry.total_original_tokens > 0);
    }

    #[test]
    fn sink_receives_tier_and_block_events() {
        let sink = Arc::new(CapturingSink::default());
        let compiler = ContextCompiler::with_defaults().with_sink(sink.clone());
        let input = vec![Segment::system_prompt("sys"), Segment::user_prompt("hi")];

        let _ = compiler.compose("hi", input, None, None);

        let events = sink.events.lock().unwrap();
        let saw_tier = events
            .iter()
            .any(|e| matches!(e, ContextCompilerEvent::TierSelected { .. }));
        assert!(saw_tier);
        let saw_block = events
            .iter()
            .any(|e| matches!(e, ContextCompilerEvent::BlockEmitted { .. }));
        assert!(saw_block);
        let saw_budget = events
            .iter()
            .any(|e| matches!(e, ContextCompilerEvent::BudgetAllocated { .. }));
        assert!(saw_budget);
    }

    #[test]
    fn tier1_summarizer_invoked_on_dropped_older_turns() {
        let policy = BudgetPolicy {
            total_window: 2_000,
            ..BudgetPolicy::default()
        };
        let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), policy)
            .with_summarizer(Arc::new(MockSummarizer));

        // System prompt, thirty long older turns, then the user prompt.
        let mut input = vec![Segment::system_prompt("sys")];
        input.extend(
            (0..30).map(|i| Segment::older_turn(Role::Assistant, long_text("rust", 100), i + 5)),
        );
        input.push(Segment::user_prompt("rust"));

        let out = compiler.compose("rust", input, None, None);

        assert_eq!(out.telemetry.tier, "heuristic_summarization");
        assert!(out.telemetry.summarizer_calls > 0);
        assert!(!out.anchored_summary.is_empty());
        assert!(has_kind(&out, SegmentKind::AnchoredSummaryRecall));
    }

    #[test]
    fn with_embedder_reranks_unpinned_without_panic() {
        use crate::embedder::PlaceholderEmbedder;

        let policy = BudgetPolicy {
            total_window: 1_000,
            ..BudgetPolicy::default()
        };
        let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), policy)
            .with_embedder(Arc::new(PlaceholderEmbedder::new()));
        let input = vec![
            Segment::system_prompt("sys"),
            Segment::older_turn(Role::User, "unrelated zzz", 4),
            Segment::older_turn(Role::Assistant, "the answer is forty two", 3),
            Segment::user_prompt("forty two"),
        ];

        let out = compiler.compose("forty two", input, None, None);

        assert!(!out.segments.is_empty());
        assert_eq!(out.telemetry.tier, "heuristic_summarization_embedding");
    }

    #[test]
    fn summarizer_failure_degrades_gracefully() {
        let sink = Arc::new(CapturingSink::default());
        let policy = BudgetPolicy {
            total_window: 1_500,
            ..BudgetPolicy::default()
        };
        let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), policy)
            .with_summarizer(Arc::new(FailingSummarizer))
            .with_sink(sink.clone());

        // System prompt, twenty long older turns, then the user prompt.
        let mut input = vec![Segment::system_prompt("sys")];
        input.extend(
            (0..20).map(|i| Segment::older_turn(Role::Assistant, long_text("rust", 80), i + 5)),
        );
        input.push(Segment::user_prompt("rust"));

        let out = compiler.compose("rust", input, None, None);

        assert!(out.telemetry.summarizer_failures > 0);
        let events = sink.events.lock().unwrap();
        let saw_failure = events
            .iter()
            .any(|e| matches!(e, ContextCompilerEvent::SummarizerFailed { .. }));
        assert!(saw_failure);
    }
}