Skip to main content

ainl_context_compiler/
orchestrator.rs

//! `ContextCompiler::compose` — orchestrates segment selection, budget allocation, compaction,
//! and per-segment compression, returning a [`ComposedPrompt`] plus telemetry.
//!
//! The Tier 0 path is deterministic and offline-safe: heuristic relevance scoring, greedy
//! budget fill, and per-segment [`ainl_compression`] passes. Higher tiers light up automatically
//! when the host injects a [`Summarizer`] (M2) or an [`crate::embedder::Embedder`] (M3). The
//! orchestrator never fails when those are absent, and a summarizer error only degrades the
//! current turn back to Tier 0 behaviour.
8
9use std::cmp::Ordering;
10
11use crate::budget::BudgetPolicy;
12use crate::capability::CapabilityProbe;
13use crate::embedder::{cosine, Embedder};
14use crate::metrics::ContextCompilerMetrics;
15use crate::relevance::{HeuristicScorer, RelevanceScore, RelevanceScorer};
16use crate::segment::{Role, Segment, SegmentKind};
17use crate::summarizer::{AnchoredSummary, Summarizer};
18use crate::{ContextCompilerEvent, ContextEmissionSink, SinkRef};
19use ainl_compression::{compress, EfficientMode};
20use ainl_contracts::CognitiveVitals;
21use serde::{Deserialize, Serialize};
22use std::sync::Arc;
23use std::time::Instant;
24use tracing::{debug, warn};
25
/// Result of one [`ContextCompiler::compose`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComposedPrompt {
    /// Segments to assemble into the final LLM input, in target order.
    ///
    /// Surviving input segments appear in their original input order (possibly with compressed
    /// content). When the anchored summary is non-empty, an `AnchoredSummaryRecall` segment is
    /// inserted right after the leading run of `SystemPrompt` / `MemoryBlock` segments.
    pub segments: Vec<Segment>,
    /// The anchored summary state after this compose call (may be `is_empty()` at Tier 0).
    /// Intended to be handed back as `existing_summary` on the next turn.
    pub anchored_summary: AnchoredSummary,
    /// Aggregate + per-segment telemetry.
    pub telemetry: ContextCompilerMetrics,
}
36
/// Builder + entry point for context-window assembly.
///
/// Hosts construct one `ContextCompiler` per session (or per agent) and reuse it across turns.
/// Optional dependencies (`summarizer`, `embedder`, `sink`) can be set once and shared via `Arc`.
pub struct ContextCompiler {
    /// Scores every segment against the latest user query during coarse selection.
    scorer: Arc<dyn RelevanceScorer>,
    /// Token window, per-kind reservations, soft cap, and recent-turn pinning policy.
    budget: BudgetPolicy,
    /// Optional Tier 1 summarizer for older turns that did not fit; `None` keeps Tier 0.
    summarizer: Option<Arc<dyn Summarizer>>,
    /// Optional embedder used to rerank non-pinned segments by similarity to the query.
    embedder: Option<Arc<dyn Embedder>>,
    /// Optional structured-event sink; when unset, emitted events are discarded.
    sink: SinkRef,
}
48
49impl ContextCompiler {
50    /// Construct a Tier 0 compiler with the supplied scorer and budget.
51    #[must_use]
52    pub fn new(scorer: Arc<dyn RelevanceScorer>, budget: BudgetPolicy) -> Self {
53        Self {
54            scorer,
55            budget,
56            summarizer: None,
57            embedder: None,
58            sink: None,
59        }
60    }
61
62    /// Convenience: build with the default heuristic scorer and default budget.
63    #[must_use]
64    pub fn with_defaults() -> Self {
65        Self::new(Arc::new(HeuristicScorer::new()), BudgetPolicy::default())
66    }
67
68    /// Inject a Tier 1 summarizer (M2). When absent, the orchestrator runs at Tier 0.
69    #[must_use]
70    pub fn with_summarizer(mut self, summarizer: Arc<dyn Summarizer>) -> Self {
71        self.summarizer = Some(summarizer);
72        self
73    }
74
75    /// Attach a structured-event sink (mirrors `ainl_compression::with_telemetry_callback`).
76    #[must_use]
77    pub fn with_sink(mut self, sink: Arc<dyn ContextEmissionSink>) -> Self {
78        self.sink = Some(sink);
79        self
80    }
81
82    /// Inject a Tier 2 / M3 embedder for relevance reranking of non-pinned segments.
83    #[must_use]
84    pub fn with_embedder(mut self, embedder: Arc<dyn Embedder>) -> Self {
85        self.embedder = Some(embedder);
86        self
87    }
88
89    /// Return the active capability probe based on injected dependencies.
90    #[must_use]
91    pub fn probe(&self) -> CapabilityProbe {
92        CapabilityProbe {
93            summarizer: self.summarizer.is_some(),
94            embedder: self.embedder.is_some(),
95        }
96    }
97
98    fn emit(&self, event: ContextCompilerEvent) {
99        if let Some(sink) = &self.sink {
100            sink.emit(event);
101        }
102    }
103
104    /// Compose a prompt window from `segments`, scored against `latest_user_query`, within
105    /// `self.budget`. `existing_summary` carries over from the prior turn (Tier ≥ 1 only).
106    /// `vitals` (when supplied) caps compression aggressiveness on low-trust turns per
107    /// SELF_LEARNING_INTEGRATION_MAP §15.3.
108    ///
109    /// Algorithm (per the plan):
110    /// 1. Coarse selection — always-keep + heuristic scoring.
111    /// 2. Budget allocation — apportion across kinds per `BudgetPolicy`.
112    /// 3. Older-history compaction — `Summarizer` when present, else heuristic compression.
113    /// 4. Tool-result fine-grained pruning.
114    /// 5. Per-segment compression via `ainl_compression`.
115    /// 6. Emit telemetry events.
116    pub fn compose(
117        &self,
118        latest_user_query: &str,
119        segments: Vec<Segment>,
120        existing_summary: Option<&AnchoredSummary>,
121        vitals: Option<&CognitiveVitals>,
122    ) -> ComposedPrompt {
123        let t0 = Instant::now();
124        let probe = self.probe();
125        let tier = probe.active_tier();
126        self.emit(ContextCompilerEvent::TierSelected {
127            tier,
128            reason: probe.reason(),
129        });
130
131        let mut metrics = ContextCompilerMetrics::new(tier, self.budget.total_window);
132        // Default mode by vitals: low-trust caps at Balanced, otherwise Balanced default.
133        // (Both branches resolve to Balanced today; placeholder while a tighter Aggressive
134        // tier is wired up — keep the conditional so the vitals signal stays observable.)
135        let _low_trust = self.budget.vitals_aware && vitals.is_some_and(|v| v.trust < 0.5);
136        let default_mode = EfficientMode::Balanced;
137
138        // ── 1. Coarse selection ─────────────────────────────────────────────────────────
139        // Score every segment, then split into always-keep and rankable.
140        let mut scored: Vec<(usize, RelevanceScore)> = segments
141            .iter()
142            .enumerate()
143            .map(|(idx, s)| (idx, self.scorer.score(s, latest_user_query, vitals)))
144            .collect();
145        // Sort highest-score first so the greedy fill picks most-relevant segments.
146        scored.sort_by(|a, b| {
147            b.1 .0
148                .partial_cmp(&a.1 .0)
149                .unwrap_or(std::cmp::Ordering::Equal)
150        });
151
152        // ── 2. Budget allocation ────────────────────────────────────────────────────────
153        // Always-keep slots come out of fixed reservations; the rest competes for `flexible_budget`.
154        let mut flexible_budget = self.budget.flexible_budget();
155        self.emit(ContextCompilerEvent::BudgetAllocated {
156            total: self.budget.total_window,
157            per_kind: vec![
158                (SegmentKind::SystemPrompt, self.budget.system_budget()),
159                (SegmentKind::ToolDefinitions, self.budget.tool_def_budget()),
160                (SegmentKind::UserPrompt, self.budget.user_prompt_budget()),
161            ],
162        });
163
164        // Recent-turns-keep-verbatim window: count from age_index = 0 upward; the N most recent
165        // RecentTurn segments are pinned regardless of their heuristic score.
166        let recent_pin_threshold = self.budget.recent_turns_keep_verbatim as u32;
167        let pinned_idx: std::collections::HashSet<usize> = segments
168            .iter()
169            .enumerate()
170            .filter(|(_, s)| {
171                s.kind.is_always_keep()
172                    || (s.kind == SegmentKind::RecentTurn && s.age_index < recent_pin_threshold)
173                    || s.kind == SegmentKind::ToolDefinitions
174            })
175            .map(|(i, _)| i)
176            .collect();
177
178        if let Some(emb) = &self.embedder {
179            if let Ok(qv) = emb.embed(latest_user_query) {
180                let pinned_order: Vec<(usize, RelevanceScore)> = scored
181                    .iter()
182                    .filter(|(i, _)| pinned_idx.contains(i))
183                    .copied()
184                    .collect();
185                let mut unpin: Vec<(usize, RelevanceScore)> = scored
186                    .iter()
187                    .filter(|(i, _)| !pinned_idx.contains(i))
188                    .copied()
189                    .collect();
190                unpin.sort_by(|(ia, _), (ib, _)| {
191                    let a_sim = emb
192                        .embed(&segments[*ia].content)
193                        .map(|v| cosine(&v, &qv))
194                        .unwrap_or(0.0);
195                    let b_sim = emb
196                        .embed(&segments[*ib].content)
197                        .map(|v| cosine(&v, &qv))
198                        .unwrap_or(0.0);
199                    b_sim.partial_cmp(&a_sim).unwrap_or(Ordering::Equal)
200                });
201                scored = pinned_order;
202                scored.extend(unpin);
203            }
204        }
205
206        // ── 3+4+5. Greedy fill with per-segment compression ─────────────────────────────
207        let mut keep: Vec<Option<Segment>> = (0..segments.len()).map(|_| None).collect();
208        let mut summarizer_calls: u32 = 0;
209        let mut summarizer_failures: u32 = 0;
210        let mut dropped_for_summarization: Vec<Segment> = Vec::new();
211        let mut anchored = existing_summary
212            .cloned()
213            .unwrap_or_else(AnchoredSummary::empty);
214
215        // Pinned segments first (always-keep + recent-pinned).
216        for &(idx, _score) in scored
217            .iter()
218            .filter(|(i, _)| pinned_idx.contains(i))
219            .collect::<Vec<_>>()
220            .iter()
221        {
222            let original = &segments[*idx];
223            let original_tok = original.token_estimate();
224            // Pinned segments are never compressed by default (system + user + tool defs + recent).
225            keep[*idx] = Some(original.clone());
226            metrics.record_segment(original.kind, original_tok, original_tok, false);
227            self.emit(ContextCompilerEvent::BlockEmitted {
228                source: source_label(original.kind),
229                kind: original.kind,
230                original_tokens: original_tok,
231                kept_tokens: original_tok,
232            });
233        }
234
235        // Now the rankable rest, in score order.
236        for &(idx, _score) in scored.iter().filter(|(i, _)| !pinned_idx.contains(i)) {
237            let seg = &segments[idx];
238            let original_tok = seg.token_estimate();
239
240            // Compression mode per kind: tool results get most aggressive treatment (highest
241            // savings ratio per industry consensus 10:1-20:1). Older turns get Balanced. Memory
242            // blocks stay verbatim or use Balanced if oversized.
243            let mode = match seg.kind {
244                SegmentKind::ToolResult => EfficientMode::Aggressive,
245                SegmentKind::OlderTurn => default_mode,
246                SegmentKind::MemoryBlock | SegmentKind::AnchoredSummaryRecall => default_mode,
247                SegmentKind::RecentTurn => default_mode,
248                _ => EfficientMode::Off,
249            };
250
251            // Try to compress first to see if it fits in the remaining flexible budget.
252            let compressed = if mode == EfficientMode::Off {
253                seg.content.clone()
254            } else {
255                compress(&seg.content, mode).text
256            };
257            let compressed_tok = ainl_compression::tokenize_estimate(&compressed);
258
259            if compressed_tok <= flexible_budget {
260                let mut kept = seg.clone();
261                kept.content = compressed;
262                keep[idx] = Some(kept);
263                flexible_budget = flexible_budget.saturating_sub(compressed_tok);
264                metrics.record_segment(seg.kind, original_tok, compressed_tok, false);
265                self.emit(ContextCompilerEvent::BlockEmitted {
266                    source: source_label(seg.kind),
267                    kind: seg.kind,
268                    original_tokens: original_tok,
269                    kept_tokens: compressed_tok,
270                });
271            } else {
272                // Doesn't fit — drop. If it's an older turn, queue it for summarization (Tier ≥ 1).
273                if seg.kind == SegmentKind::OlderTurn {
274                    dropped_for_summarization.push(seg.clone());
275                }
276                metrics.record_segment(seg.kind, original_tok, 0, true);
277                debug!(
278                    kind = ?seg.kind,
279                    original_tok,
280                    flexible_budget,
281                    "context_compiler: dropped (over budget)"
282                );
283            }
284        }
285
286        // ── Tier ≥ 1: anchored summarization of dropped older turns ─────────────────────
287        if let Some(summ) = &self.summarizer {
288            if !dropped_for_summarization.is_empty() {
289                let s0 = Instant::now();
290                summarizer_calls += 1;
291                match summ.summarize(&dropped_for_summarization, Some(&anchored)) {
292                    Ok(new_summary) => {
293                        let summary_tokens =
294                            ainl_compression::tokenize_estimate(&new_summary.to_prompt_text());
295                        anchored = new_summary;
296                        anchored.token_estimate = summary_tokens;
297                        anchored.iteration = anchored.iteration.saturating_add(1);
298                        self.emit(ContextCompilerEvent::SummarizerInvoked {
299                            duration_ms: s0.elapsed().as_millis() as u64,
300                            segments_in: dropped_for_summarization.len(),
301                            summary_tokens,
302                        });
303                    }
304                    Err(e) => {
305                        summarizer_failures += 1;
306                        warn!(error = %e, "context_compiler: summarizer failed, degrading to Tier 0 for this turn");
307                        self.emit(ContextCompilerEvent::SummarizerFailed {
308                            duration_ms: s0.elapsed().as_millis() as u64,
309                            error_kind: e.kind(),
310                        });
311                    }
312                }
313            }
314        }
315
316        // Assemble in stable original order: SystemPrompt → MemoryBlock → AnchoredSummaryRecall
317        // → OlderTurn (oldest→newest by age_index desc→asc) → RecentTurn (oldest→newest)
318        // → ToolDefinitions → UserPrompt. We preserve the host's intended ordering for now by
319        // emitting in original index order (the host arranges segments before calling compose).
320        let mut composed: Vec<Segment> = keep.into_iter().flatten().collect();
321
322        // If summarizer produced content, inject it as an AnchoredSummaryRecall segment near
323        // the top so the LLM sees it before older turns.
324        if !anchored.is_empty() {
325            let recall = Segment {
326                kind: SegmentKind::AnchoredSummaryRecall,
327                role: Role::System,
328                content: anchored.to_prompt_text(),
329                age_index: 0,
330                tool_name: None,
331                base_importance: 1.5,
332                #[cfg(feature = "freshness")]
333                freshness: None,
334            };
335            // Insert after SystemPrompt + MemoryBlock segments so it precedes turns.
336            let insert_at = composed
337                .iter()
338                .position(|s| {
339                    !matches!(s.kind, SegmentKind::SystemPrompt | SegmentKind::MemoryBlock)
340                })
341                .unwrap_or(composed.len());
342            composed.insert(insert_at, recall);
343        }
344
345        // Safety net: if we still exceed the soft cap, emit BudgetExceeded so dashboards surface it.
346        let total_kept_tokens: usize = composed.iter().map(|s| s.token_estimate()).sum();
347        if total_kept_tokens > self.budget.soft_total_cap {
348            self.emit(ContextCompilerEvent::BudgetExceeded {
349                overage: total_kept_tokens.saturating_sub(self.budget.soft_total_cap),
350            });
351        }
352
353        metrics.summarizer_calls = summarizer_calls;
354        metrics.summarizer_failures = summarizer_failures;
355        metrics.elapsed_ms = t0.elapsed().as_millis() as u64;
356
357        ComposedPrompt {
358            segments: composed,
359            anchored_summary: anchored,
360            telemetry: metrics,
361        }
362    }
363}
364
365const fn source_label(kind: SegmentKind) -> &'static str {
366    match kind {
367        SegmentKind::SystemPrompt => "system_prompt",
368        SegmentKind::OlderTurn => "older_turn",
369        SegmentKind::RecentTurn => "recent_turn",
370        SegmentKind::ToolDefinitions => "tool_definitions",
371        SegmentKind::ToolResult => "tool_result",
372        SegmentKind::UserPrompt => "user_prompt",
373        SegmentKind::AnchoredSummaryRecall => "anchored_summary_recall",
374        SegmentKind::MemoryBlock => "memory_block",
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::summarizer::SummarizerError;
    use std::sync::Mutex;

    /// Test sink that records every event emitted by the compiler, in emission order.
    #[derive(Default)]
    struct CapturingSink {
        events: Mutex<Vec<ContextCompilerEvent>>,
    }

    impl CapturingSink {
        /// Whether any captured event satisfies `pred`.
        fn saw(&self, pred: impl Fn(&ContextCompilerEvent) -> bool) -> bool {
            self.events.lock().unwrap().iter().any(pred)
        }
    }

    impl ContextEmissionSink for CapturingSink {
        fn emit(&self, event: ContextCompilerEvent) {
            let mut guard = self.events.lock().expect("lock");
            guard.push(event);
        }
    }

    /// Build `n` repetitions of `"<prefix> sentence <i>. "` — bulky text that pressures the budget.
    fn long_text(prefix: &str, n: usize) -> String {
        (0..n).map(|i| format!("{prefix} sentence {i}. ")).collect()
    }

    /// Whether the composed prompt contains at least one segment of `kind`.
    fn has_kind(out: &ComposedPrompt, kind: SegmentKind) -> bool {
        out.segments.iter().any(|seg| seg.kind == kind)
    }

    #[test]
    fn tier0_compose_keeps_system_and_user_verbatim() {
        let query = "Help me debug a tokio runtime issue.";
        let input = vec![
            Segment::system_prompt("You are a helpful assistant."),
            Segment::user_prompt(query),
        ];

        let out = ContextCompiler::with_defaults().compose(query, input, None, None);

        assert_eq!(out.segments.len(), 2);
        assert!(has_kind(&out, SegmentKind::SystemPrompt));
        assert!(has_kind(&out, SegmentKind::UserPrompt));
        assert_eq!(out.telemetry.tier, "heuristic");
        assert_eq!(out.telemetry.summarizer_calls, 0);
    }

    #[test]
    fn tier0_compresses_long_older_turns_within_budget() {
        // A tight window forces the bulky older turn to be compressed or dropped.
        let budget = BudgetPolicy {
            total_window: 4_000,
            ..BudgetPolicy::default()
        };
        let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), budget);
        let input = vec![
            Segment::system_prompt("system"),
            Segment::older_turn(
                Role::Assistant,
                long_text("rust borrow checker tokio", 200),
                10,
            ),
            Segment::user_prompt("rust tokio"),
        ];

        let out = compiler.compose("rust tokio", input, None, None);

        // System + user always survive; the older turn is accounted for in the metrics.
        assert!(has_kind(&out, SegmentKind::SystemPrompt));
        assert!(has_kind(&out, SegmentKind::UserPrompt));
        assert!(out.telemetry.total_original_tokens > 0);
    }

    #[test]
    fn sink_receives_tier_and_block_events() {
        let sink = Arc::new(CapturingSink::default());
        let compiler = ContextCompiler::with_defaults().with_sink(sink.clone());

        let _ = compiler.compose(
            "hi",
            vec![Segment::system_prompt("sys"), Segment::user_prompt("hi")],
            None,
            None,
        );

        assert!(sink.saw(|e| matches!(e, ContextCompilerEvent::TierSelected { .. })));
        assert!(sink.saw(|e| matches!(e, ContextCompilerEvent::BlockEmitted { .. })));
        assert!(sink.saw(|e| matches!(e, ContextCompilerEvent::BudgetAllocated { .. })));
    }

    #[test]
    fn tier1_summarizer_invoked_on_dropped_older_turns() {
        /// Summarizer that only records how many segments it was handed.
        struct MockSummarizer;
        impl Summarizer for MockSummarizer {
            fn summarize(
                &self,
                segments: &[Segment],
                _existing: Option<&AnchoredSummary>,
            ) -> Result<AnchoredSummary, SummarizerError> {
                let mut summary = AnchoredSummary::empty();
                summary.sections[0].content = format!("Summarized {} segments.", segments.len());
                Ok(summary)
            }
        }

        let budget = BudgetPolicy {
            total_window: 2_000,
            ..BudgetPolicy::default()
        };
        let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), budget)
            .with_summarizer(Arc::new(MockSummarizer));
        // Thirty bulky older turns are guaranteed to overflow the window, so the summarizer fires.
        let mut input = vec![Segment::system_prompt("sys")];
        input.extend(
            (0..30).map(|i| Segment::older_turn(Role::Assistant, long_text("rust", 100), i + 5)),
        );
        input.push(Segment::user_prompt("rust"));

        let out = compiler.compose("rust", input, None, None);

        assert_eq!(out.telemetry.tier, "heuristic_summarization");
        assert!(out.telemetry.summarizer_calls > 0);
        assert!(!out.anchored_summary.is_empty());
        // The anchored summary must surface in the composed prompt as a recall block.
        assert!(has_kind(&out, SegmentKind::AnchoredSummaryRecall));
    }

    #[test]
    fn with_embedder_reranks_unpinned_without_panic() {
        use crate::embedder::PlaceholderEmbedder;

        let budget = BudgetPolicy {
            total_window: 1_000,
            ..BudgetPolicy::default()
        };
        let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), budget)
            .with_embedder(Arc::new(PlaceholderEmbedder::new()));
        // Two older turns with different text; the user prompt matches the second one.
        let input = vec![
            Segment::system_prompt("sys"),
            Segment::older_turn(Role::User, "unrelated zzz", 4),
            Segment::older_turn(Role::Assistant, "the answer is forty two", 3),
            Segment::user_prompt("forty two"),
        ];

        let out = compiler.compose("forty two", input, None, None);

        assert!(!out.segments.is_empty());
        assert_eq!(out.telemetry.tier, "heuristic_summarization_embedding");
    }

    #[test]
    fn summarizer_failure_degrades_gracefully() {
        /// Summarizer that always times out.
        struct FailingSummarizer;
        impl Summarizer for FailingSummarizer {
            fn summarize(
                &self,
                _segments: &[Segment],
                _existing: Option<&AnchoredSummary>,
            ) -> Result<AnchoredSummary, SummarizerError> {
                Err(SummarizerError::Timeout)
            }
        }

        let sink = Arc::new(CapturingSink::default());
        let budget = BudgetPolicy {
            total_window: 1_500,
            ..BudgetPolicy::default()
        };
        let compiler = ContextCompiler::new(Arc::new(HeuristicScorer::new()), budget)
            .with_summarizer(Arc::new(FailingSummarizer))
            .with_sink(sink.clone());
        let mut input = vec![Segment::system_prompt("sys")];
        input.extend(
            (0..20).map(|i| Segment::older_turn(Role::Assistant, long_text("rust", 80), i + 5)),
        );
        input.push(Segment::user_prompt("rust"));

        let out = compiler.compose("rust", input, None, None);

        assert!(out.telemetry.summarizer_failures > 0);
        assert!(sink.saw(|e| matches!(e, ContextCompilerEvent::SummarizerFailed { .. })));
    }
}