use std::cmp::Ordering;
use crate::budget::BudgetPolicy;
use crate::capability::CapabilityProbe;
use crate::embedder::{cosine, Embedder};
use crate::metrics::ContextCompilerMetrics;
use crate::relevance::{HeuristicScorer, RelevanceScore, RelevanceScorer};
use crate::segment::{Role, Segment, SegmentKind};
use crate::summarizer::{AnchoredSummary, Summarizer};
use crate::{ContextCompilerEvent, ContextEmissionSink, SinkRef};
use ainl_compression::{compress, EfficientMode};
use ainl_contracts::CognitiveVitals;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, warn};
/// Result of one [`ContextCompiler::compose`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComposedPrompt {
/// Segments that survived composition, including the anchored-summary
/// recall segment when the summary is non-empty.
pub segments: Vec<Segment>,
/// The anchored summary after this call: the summarizer's output when it
/// ran successfully, otherwise the caller's existing summary (or an empty
/// one when none was supplied).
pub anchored_summary: AnchoredSummary,
/// Per-call metrics recorded while composing (per-segment token
/// accounting, summarizer call/failure counts, elapsed milliseconds).
pub telemetry: ContextCompilerMetrics,
}
/// Composes a prompt from candidate segments under a token budget.
///
/// The scorer and budget are mandatory; the summarizer, embedder and event
/// sink are optional and attached through the `with_*` builder methods.
pub struct ContextCompiler {
/// Scores every segment against the latest user query.
scorer: Arc<dyn RelevanceScorer>,
/// Token budget configuration consulted during composition.
budget: BudgetPolicy,
/// Optional summarizer invoked on older turns dropped for budget reasons.
summarizer: Option<Arc<dyn Summarizer>>,
/// Optional embedder used to rerank non-pinned segments by cosine similarity.
embedder: Option<Arc<dyn Embedder>>,
/// Optional receiver for `ContextCompilerEvent`s emitted during composition.
sink: SinkRef,
}
impl ContextCompiler {
/// Builds a compiler from an explicit relevance scorer and budget policy.
///
/// The summarizer, embedder and event sink all start out unset; attach them
/// with the `with_*` builder methods.
#[must_use]
pub fn new(scorer: Arc<dyn RelevanceScorer>, budget: BudgetPolicy) -> Self {
    // Optional collaborators are absent until a builder method sets them.
    let summarizer = None;
    let embedder = None;
    let sink = None;
    Self {
        scorer,
        budget,
        summarizer,
        embedder,
        sink,
    }
}
#[must_use]
pub fn with_defaults() -> Self {
Self::new(Arc::new(HeuristicScorer::new()), BudgetPolicy::default())
}
#[must_use]
pub fn with_summarizer(mut self, summarizer: Arc<dyn Summarizer>) -> Self {
self.summarizer = Some(summarizer);
self
}
#[must_use]
pub fn with_sink(mut self, sink: Arc<dyn ContextEmissionSink>) -> Self {
self.sink = Some(sink);
self
}
#[must_use]
pub fn with_embedder(mut self, embedder: Arc<dyn Embedder>) -> Self {
self.embedder = Some(embedder);
self
}
/// Reports which optional collaborators are currently configured.
///
/// Each flag is `true` exactly when the corresponding field is `Some`.
#[must_use]
pub fn probe(&self) -> CapabilityProbe {
    let summarizer = self.summarizer.is_some();
    let embedder = self.embedder.is_some();
    CapabilityProbe {
        summarizer,
        embedder,
    }
}
/// Forwards `event` to the configured sink; a no-op when no sink is set.
fn emit(&self, event: ContextCompilerEvent) {
    let Some(sink) = self.sink.as_ref() else {
        return;
    };
    sink.emit(event);
}
pub fn compose(
&self,
latest_user_query: &str,
segments: Vec<Segment>,
existing_summary: Option<&AnchoredSummary>,
vitals: Option<&CognitiveVitals>,
) -> ComposedPrompt {
let t0 = Instant::now();
let probe = self.probe();
let tier = probe.active_tier();
self.emit(ContextCompilerEvent::TierSelected {
tier,
reason: probe.reason(),
});
let mut metrics = ContextCompilerMetrics::new(tier, self.budget.total_window);
let _low_trust = self.budget.vitals_aware && vitals.is_some_and(|v| v.trust < 0.5);
let default_mode = EfficientMode::Balanced;
let mut scored: Vec<(usize, RelevanceScore)> = segments
.iter()
.enumerate()
.map(|(idx, s)| (idx, self.scorer.score(s, latest_user_query, vitals)))
.collect();
scored.sort_by(|a, b| {
b.1 .0
.partial_cmp(&a.1 .0)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut flexible_budget = self.budget.flexible_budget();
self.emit(ContextCompilerEvent::BudgetAllocated {
total: self.budget.total_window,
per_kind: vec![
(SegmentKind::SystemPrompt, self.budget.system_budget()),
(SegmentKind::ToolDefinitions, self.budget.tool_def_budget()),
(SegmentKind::UserPrompt, self.budget.user_prompt_budget()),
],
});
let recent_pin_threshold = self.budget.recent_turns_keep_verbatim as u32;
let pinned_idx: std::collections::HashSet<usize> = segments
.iter()
.enumerate()
.filter(|(_, s)| {
s.kind.is_always_keep()
|| (s.kind == SegmentKind::RecentTurn && s.age_index < recent_pin_threshold)
|| s.kind == SegmentKind::ToolDefinitions
})
.map(|(i, _)| i)
.collect();
if let Some(emb) = &self.embedder {
if let Ok(qv) = emb.embed(latest_user_query) {
let pinned_order: Vec<(usize, RelevanceScore)> = scored
.iter()
.filter(|(i, _)| pinned_idx.contains(i))
.copied()
.collect();
let mut unpin: Vec<(usize, RelevanceScore)> = scored
.iter()
.filter(|(i, _)| !pinned_idx.contains(i))
.copied()
.collect();
unpin.sort_by(|(ia, _), (ib, _)| {
let a_sim = emb
.embed(&segments[*ia].content)
.map(|v| cosine(&v, &qv))
.unwrap_or(0.0);
let b_sim = emb
.embed(&segments[*ib].content)
.map(|v| cosine(&v, &qv))
.unwrap_or(0.0);
b_sim.partial_cmp(&a_sim).unwrap_or(Ordering::Equal)
});
scored = pinned_order;
scored.extend(unpin);
}
}
let mut keep: Vec<Option<Segment>> = (0..segments.len()).map(|_| None).collect();
let mut summarizer_calls: u32 = 0;
let mut summarizer_failures: u32 = 0;
let mut dropped_for_summarization: Vec<Segment> = Vec::new();
let mut anchored = existing_summary
.cloned()
.unwrap_or_else(AnchoredSummary::empty);
for &(idx, _score) in scored
.iter()
.filter(|(i, _)| pinned_idx.contains(i))
.collect::<Vec<_>>()
.iter()
{
let original = &segments[*idx];
let original_tok = original.token_estimate();
keep[*idx] = Some(original.clone());
metrics.record_segment(original.kind, original_tok, original_tok, false);
self.emit(ContextCompilerEvent::BlockEmitted {
source: source_label(original.kind),
kind: original.kind,
original_tokens: original_tok,
kept_tokens: original_tok,
});
}
for &(idx, _score) in scored.iter().filter(|(i, _)| !pinned_idx.contains(i)) {
let seg = &segments[idx];
let original_tok = seg.token_estimate();
let mode = match seg.kind {
SegmentKind::ToolResult => EfficientMode::Aggressive,
SegmentKind::OlderTurn => default_mode,
SegmentKind::MemoryBlock | SegmentKind::AnchoredSummaryRecall => default_mode,
SegmentKind::RecentTurn => default_mode,
_ => EfficientMode::Off,
};
let compressed = if mode == EfficientMode::Off {
seg.content.clone()
} else {
compress(&seg.content, mode).text
};
let compressed_tok = ainl_compression::tokenize_estimate(&compressed);
if compressed_tok <= flexible_budget {
let mut kept = seg.clone();
kept.content = compressed;
keep[idx] = Some(kept);
flexible_budget = flexible_budget.saturating_sub(compressed_tok);
metrics.record_segment(seg.kind, original_tok, compressed_tok, false);
self.emit(ContextCompilerEvent::BlockEmitted {
source: source_label(seg.kind),
kind: seg.kind,
original_tokens: original_tok,
kept_tokens: compressed_tok,
});
} else {
if seg.kind == SegmentKind::OlderTurn {
dropped_for_summarization.push(seg.clone());
}
metrics.record_segment(seg.kind, original_tok, 0, true);
debug!(
kind = ?seg.kind,
original_tok,
flexible_budget,
"context_compiler: dropped (over budget)"
);
}
}
if let Some(summ) = &self.summarizer {
if !dropped_for_summarization.is_empty() {
let s0 = Instant::now();
summarizer_calls += 1;
match summ.summarize(&dropped_for_summarization, Some(&anchored)) {
Ok(new_summary) => {
let summary_tokens =
ainl_compression::tokenize_estimate(&new_summary.to_prompt_text());
anchored = new_summary;
anchored.token_estimate = summary_tokens;
anchored.iteration = anchored.iteration.saturating_add(1);
self.emit(ContextCompilerEvent::SummarizerInvoked {
duration_ms: s0.elapsed().as_millis() as u64,
segments_in: dropped_for_summarization.len(),
summary_tokens,
});
}
Err(e) => {
summarizer_failures += 1;
warn!(error = %e, "context_compiler: summarizer failed, degrading to Tier 0 for this turn");
self.emit(ContextCompilerEvent::SummarizerFailed {
duration_ms: s0.elapsed().as_millis() as u64,
error_kind: e.kind(),
});
}
}
}
}
let mut composed: Vec<Segment> = keep.into_iter().flatten().collect();
if !anchored.is_empty() {
let recall = Segment {
kind: SegmentKind::AnchoredSummaryRecall,
role: Role::System,
content: anchored.to_prompt_text(),
age_index: 0,
tool_name: None,
base_importance: 1.5,
#[cfg(feature = "freshness")]
freshness: None,
};
let insert_at = composed
.iter()
.position(|s| {
!matches!(s.kind, SegmentKind::SystemPrompt | SegmentKind::MemoryBlock)
})
.unwrap_or(composed.len());
composed.insert(insert_at, recall);
}
let total_kept_tokens: usize = composed.iter().map(|s| s.token_estimate()).sum();
if total_kept_tokens > self.budget.soft_total_cap {
self.emit(ContextCompilerEvent::BudgetExceeded {
overage: total_kept_tokens.saturating_sub(self.budget.soft_total_cap),
});
}
metrics.summarizer_calls = summarizer_calls;
metrics.summarizer_failures = summarizer_failures;
metrics.elapsed_ms = t0.elapsed().as_millis() as u64;
ComposedPrompt {
segments: composed,
anchored_summary: anchored,
telemetry: metrics,
}
}
}
/// Returns the static label used as the `source` of a `BlockEmitted` event
/// for the given segment kind.
///
/// The match is exhaustive (no wildcard arm), so adding a `SegmentKind`
/// variant forces a label to be chosen here. Arms are listed alphabetically
/// by label.
const fn source_label(kind: SegmentKind) -> &'static str {
    match kind {
        SegmentKind::AnchoredSummaryRecall => "anchored_summary_recall",
        SegmentKind::MemoryBlock => "memory_block",
        SegmentKind::OlderTurn => "older_turn",
        SegmentKind::RecentTurn => "recent_turn",
        SegmentKind::SystemPrompt => "system_prompt",
        SegmentKind::ToolDefinitions => "tool_definitions",
        SegmentKind::ToolResult => "tool_result",
        SegmentKind::UserPrompt => "user_prompt",
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::summarizer::SummarizerError;
use std::sync::Mutex;
/// Test sink that appends every emitted event to an in-memory list so a
/// test can inspect what the compiler reported.
#[derive(Default)]
struct CapturingSink {
    /// Events in the order they were emitted.
    events: Mutex<Vec<ContextCompilerEvent>>,
}

impl ContextEmissionSink for CapturingSink {
    fn emit(&self, event: ContextCompilerEvent) {
        let mut recorded = self.events.lock().expect("lock");
        recorded.push(event);
    }
}
/// Builds filler text made of `n` numbered sentences, each of the form
/// `"<prefix> sentence <i>. "` for `i` in `0..n`.
fn long_text(prefix: &str, n: usize) -> String {
    (0..n)
        .map(|i| format!("{prefix} sentence {i}. "))
        .collect()
}
/// With the default compiler, a system prompt and a user prompt both
/// survive composition and no summarizer call is recorded.
#[test]
fn tier0_compose_keeps_system_and_user_verbatim() {
    let query = "Help me debug a tokio runtime issue.";
    let input = vec![
        Segment::system_prompt("You are a helpful assistant."),
        Segment::user_prompt(query),
    ];

    let out = ContextCompiler::with_defaults().compose(query, input, None, None);

    let has_kind = |wanted: SegmentKind| out.segments.iter().any(|s| s.kind == wanted);
    assert_eq!(out.segments.len(), 2);
    assert!(has_kind(SegmentKind::SystemPrompt));
    assert!(has_kind(SegmentKind::UserPrompt));
    assert_eq!(out.telemetry.tier, "heuristic");
    assert_eq!(out.telemetry.summarizer_calls, 0);
}
/// With a 4 000-token window and one very long older turn, the system and
/// user prompts are still present and original-token telemetry is non-zero.
#[test]
fn tier0_compresses_long_older_turns_within_budget() {
    let policy = BudgetPolicy {
        total_window: 4_000,
        ..BudgetPolicy::default()
    };
    let scorer = Arc::new(HeuristicScorer::new());
    let compiler = ContextCompiler::new(scorer, policy);

    let long_turn = Segment::older_turn(
        Role::Assistant,
        long_text("rust borrow checker tokio", 200),
        10,
    );
    let input = vec![
        Segment::system_prompt("system"),
        long_turn,
        Segment::user_prompt("rust tokio"),
    ];

    let out = compiler.compose("rust tokio", input, None, None);

    let has_kind = |wanted: SegmentKind| out.segments.iter().any(|s| s.kind == wanted);
    assert!(has_kind(SegmentKind::SystemPrompt));
    assert!(has_kind(SegmentKind::UserPrompt));
    assert!(out.telemetry.total_original_tokens > 0);
}
/// An attached sink observes at least one `TierSelected`, `BlockEmitted`
/// and `BudgetAllocated` event during a compose call.
#[test]
fn sink_receives_tier_and_block_events() {
    let sink = Arc::new(CapturingSink::default());
    let compiler = ContextCompiler::with_defaults().with_sink(sink.clone());
    let input = vec![Segment::system_prompt("sys"), Segment::user_prompt("hi")];

    let _ = compiler.compose("hi", input, None, None);

    let events = sink.events.lock().unwrap();
    let saw_tier_selected = events
        .iter()
        .any(|e| matches!(e, ContextCompilerEvent::TierSelected { .. }));
    let saw_block_emitted = events
        .iter()
        .any(|e| matches!(e, ContextCompilerEvent::BlockEmitted { .. }));
    let saw_budget_allocated = events
        .iter()
        .any(|e| matches!(e, ContextCompilerEvent::BudgetAllocated { .. }));
    assert!(saw_tier_selected);
    assert!(saw_block_emitted);
    assert!(saw_budget_allocated);
}
/// With a summarizer attached and far more older-turn text than a
/// 2 000-token window allows, the summarizer is called and its summary is
/// surfaced as an `AnchoredSummaryRecall` segment.
#[test]
fn tier1_summarizer_invoked_on_dropped_older_turns() {
    /// Summarizer stub that reports how many segments it was handed.
    struct MockSummarizer;
    impl Summarizer for MockSummarizer {
        fn summarize(
            &self,
            segments: &[Segment],
            _existing: Option<&AnchoredSummary>,
        ) -> Result<AnchoredSummary, SummarizerError> {
            let mut summary = AnchoredSummary::empty();
            summary.sections[0].content = format!("Summarized {} segments.", segments.len());
            Ok(summary)
        }
    }

    let policy = BudgetPolicy {
        total_window: 2_000,
        ..BudgetPolicy::default()
    };
    let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), policy)
        .with_summarizer(Arc::new(MockSummarizer));

    // system prompt, then 30 long older turns, then the user prompt.
    let older_turns =
        (0..30).map(|i| Segment::older_turn(Role::Assistant, long_text("rust", 100), i + 5));
    let input: Vec<Segment> = std::iter::once(Segment::system_prompt("sys"))
        .chain(older_turns)
        .chain(std::iter::once(Segment::user_prompt("rust")))
        .collect();

    let out = compiler.compose("rust", input, None, None);

    assert_eq!(out.telemetry.tier, "heuristic_summarization");
    assert!(out.telemetry.summarizer_calls > 0);
    assert!(!out.anchored_summary.is_empty());
    let has_recall = out
        .segments
        .iter()
        .any(|s| s.kind == SegmentKind::AnchoredSummaryRecall);
    assert!(has_recall);
}
/// Composing with a placeholder embedder attached completes, yields a
/// non-empty prompt, and reports the embedding tier.
#[test]
fn with_embedder_reranks_unpinned_without_panic() {
    use crate::embedder::PlaceholderEmbedder;

    let policy = BudgetPolicy {
        total_window: 1_000,
        ..BudgetPolicy::default()
    };
    let scorer = Arc::new(HeuristicScorer::new());
    let embedder = Arc::new(PlaceholderEmbedder::new());
    let compiler = ContextCompiler::new(scorer, policy).with_embedder(embedder);

    let input = vec![
        Segment::system_prompt("sys"),
        Segment::older_turn(Role::User, "unrelated zzz", 4),
        Segment::older_turn(Role::Assistant, "the answer is forty two", 3),
        Segment::user_prompt("forty two"),
    ];

    let out = compiler.compose("forty two", input, None, None);

    assert!(!out.segments.is_empty());
    assert_eq!(out.telemetry.tier, "heuristic_summarization_embedding");
}
/// When the summarizer always errors, compose still returns, the failure
/// is counted in telemetry, and a `SummarizerFailed` event reaches the sink.
#[test]
fn summarizer_failure_degrades_gracefully() {
    /// Summarizer stub that always times out.
    struct FailingSummarizer;
    impl Summarizer for FailingSummarizer {
        fn summarize(
            &self,
            _segments: &[Segment],
            _existing: Option<&AnchoredSummary>,
        ) -> Result<AnchoredSummary, SummarizerError> {
            Err(SummarizerError::Timeout)
        }
    }

    let sink = Arc::new(CapturingSink::default());
    let policy = BudgetPolicy {
        total_window: 1_500,
        ..BudgetPolicy::default()
    };
    let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), policy)
        .with_summarizer(Arc::new(FailingSummarizer))
        .with_sink(sink.clone());

    // system prompt, then 20 long older turns, then the user prompt.
    let older_turns =
        (0..20).map(|i| Segment::older_turn(Role::Assistant, long_text("rust", 80), i + 5));
    let input: Vec<Segment> = std::iter::once(Segment::system_prompt("sys"))
        .chain(older_turns)
        .chain(std::iter::once(Segment::user_prompt("rust")))
        .collect();

    let out = compiler.compose("rust", input, None, None);

    assert!(out.telemetry.summarizer_failures > 0);
    let events = sink.events.lock().unwrap();
    let saw_failure = events
        .iter()
        .any(|e| matches!(e, ContextCompilerEvent::SummarizerFailed { .. }));
    assert!(saw_failure);
}
}