Skip to main content

engram/
context.rs

//! Context assembly — token-budgeted context building for LLM prompts.
//!
//! The context assembler retrieves facts via hybrid search, ranks them by
//! tier priority, fills a token budget greedily, and formats the output
//! as an XML block for system prompts, a Markdown summary, or raw JSON.
6
7use crate::embedding::EmbeddingProvider;
8use crate::fact::{Fact, MemoryTier};
9use crate::graph::GraphStore;
10use crate::retrieve::{HybridRetriever, RetrievalConfig};
11use crate::scope::Scope;
12use crate::store::{FactStore, MemoryError};
13use crate::vector::VectorStore;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17
18// ---------------------------------------------------------------------------
19// TokenEstimator
20// ---------------------------------------------------------------------------
21
/// Strategy for approximating how many LLM tokens a piece of text occupies.
///
/// Implementors trade exactness for independence from any concrete tokenizer
/// crate. The `Send + Sync` supertraits allow boxed estimators to be shared
/// across threads.
pub trait TokenEstimator: Send + Sync {
    /// Return the approximate token count of `text`.
    fn estimate(&self, text: &str) -> usize;
}
28
29/// Character-based token estimator (~4 chars per token).
30/// Accurate to within ~10% for English text across GPT/Claude models.
31pub struct CharTokenEstimator;
32
33impl TokenEstimator for CharTokenEstimator {
34    fn estimate(&self, text: &str) -> usize {
35        // ~4 characters per token is a widely-used heuristic
36        text.len().div_ceil(4)
37    }
38}
39
40// ---------------------------------------------------------------------------
41// OutputFormat
42// ---------------------------------------------------------------------------
43
44/// Format for the assembled context block.
45#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
46#[serde(rename_all = "snake_case")]
47pub enum OutputFormat {
48    /// XML-tagged block for system prompt injection.
49    #[default]
50    SystemPrompt,
51    /// Human-readable Markdown summary.
52    Markdown,
53    /// Structured JSON with all metadata.
54    Raw,
55}
56
57// ---------------------------------------------------------------------------
58// ContextConfig
59// ---------------------------------------------------------------------------
60
61/// Configuration for context assembly.
62#[derive(Debug, Clone)]
63pub struct ContextConfig {
64    /// Maximum number of tokens to include. Default: 2000.
65    pub token_budget: usize,
66    /// Output format. Default: SystemPrompt.
67    pub format: OutputFormat,
68    /// Maximum number of candidate facts to retrieve before budget filtering.
69    /// Default: 50.
70    pub max_candidates: usize,
71}
72
73impl Default for ContextConfig {
74    fn default() -> Self {
75        Self {
76            token_budget: 2000,
77            format: OutputFormat::SystemPrompt,
78            max_candidates: 50,
79        }
80    }
81}
82
83// ---------------------------------------------------------------------------
84// ContextBlock
85// ---------------------------------------------------------------------------
86
87/// The assembled context block returned by `ContextBuilder::build()`.
88#[derive(Debug, Clone, Serialize)]
89pub struct ContextBlock {
90    /// Formatted text ready for prompt injection.
91    pub text: String,
92    /// Estimated token count of the `text` field.
93    pub token_count: usize,
94    /// Number of facts included in the output.
95    pub facts_included: usize,
96    /// Number of facts that were retrieved but omitted due to budget.
97    pub facts_omitted: usize,
98    /// Breakdown of included facts by tier.
99    pub tier_breakdown: HashMap<String, usize>,
100}
101
102// ---------------------------------------------------------------------------
103// Tier priority
104// ---------------------------------------------------------------------------
105
106/// Returns the priority order for `MemoryTier`. Lower value = higher priority.
107fn tier_priority(tier: &MemoryTier) -> u8 {
108    match tier {
109        MemoryTier::Working => 0,
110        MemoryTier::Conversation => 1,
111        MemoryTier::Knowledge => 2,
112    }
113}
114
115/// Sort facts by tier priority (Working first, then Conversation, then Knowledge).
116/// Within the same tier, preserve the original order (which is by retrieval score).
117pub fn sort_by_tier_priority(facts: &mut [Fact]) {
118    facts.sort_by_key(|f| tier_priority(&f.tier));
119}
120
121// ---------------------------------------------------------------------------
122// Formatters
123// ---------------------------------------------------------------------------
124
125/// Format facts as an XML-tagged block for system prompt injection.
126pub fn format_system_prompt(facts: &[Fact]) -> String {
127    let mut working = Vec::new();
128    let mut conversation = Vec::new();
129    let mut knowledge = Vec::new();
130
131    for fact in facts {
132        let line = format!("- {}", fact.text);
133        match fact.tier {
134            MemoryTier::Working => working.push(line),
135            MemoryTier::Conversation => conversation.push(line),
136            MemoryTier::Knowledge => knowledge.push(line),
137        }
138    }
139
140    let mut out = String::from("<memory>\n");
141
142    if !working.is_empty() {
143        out.push_str("<working>\n");
144        for line in &working {
145            out.push_str(line);
146            out.push('\n');
147        }
148        out.push_str("</working>\n");
149    }
150    if !conversation.is_empty() {
151        out.push_str("<conversation>\n");
152        for line in &conversation {
153            out.push_str(line);
154            out.push('\n');
155        }
156        out.push_str("</conversation>\n");
157    }
158    if !knowledge.is_empty() {
159        out.push_str("<knowledge>\n");
160        for line in &knowledge {
161            out.push_str(line);
162            out.push('\n');
163        }
164        out.push_str("</knowledge>\n");
165    }
166
167    out.push_str("</memory>");
168    out
169}
170
171/// Format facts as human-readable Markdown.
172pub fn format_markdown(facts: &[Fact]) -> String {
173    let mut out = String::from("## Memory Context\n\n");
174
175    let mut current_tier: Option<&MemoryTier> = None;
176    for fact in facts {
177        if current_tier != Some(&fact.tier) {
178            let label = match fact.tier {
179                MemoryTier::Working => "Working Memory",
180                MemoryTier::Conversation => "Conversation",
181                MemoryTier::Knowledge => "Knowledge",
182            };
183            out.push_str(&format!("### {label}\n\n"));
184            current_tier = Some(&fact.tier);
185        }
186        out.push_str(&format!("- {}\n", fact.text));
187    }
188
189    out
190}
191
192/// Format facts as raw JSON (array of objects with text, tier, category, confidence).
193pub fn format_raw(facts: &[Fact]) -> String {
194    let entries: Vec<serde_json::Value> = facts
195        .iter()
196        .map(|f| {
197            serde_json::json!({
198                "id": f.id.to_string(),
199                "text": f.text,
200                "tier": f.tier,
201                "category": f.category,
202                "confidence": f.confidence,
203            })
204        })
205        .collect();
206    serde_json::to_string_pretty(&entries).unwrap_or_else(|_| "[]".to_string())
207}
208
209// ---------------------------------------------------------------------------
210// ContextBuilder
211// ---------------------------------------------------------------------------
212
213/// Token-budgeted context assembler.
214///
215/// Retrieves facts via hybrid search, ranks by tier priority, fills
216/// a token budget greedily, and returns a `ContextBlock` ready for
217/// LLM prompt injection.
218pub struct ContextBuilder {
219    fact_store: Arc<dyn FactStore>,
220    vector_store: Arc<dyn VectorStore>,
221    graph_store: Arc<dyn GraphStore>,
222    embedding: Arc<dyn EmbeddingProvider>,
223    estimator: Box<dyn TokenEstimator>,
224    config: ContextConfig,
225}
226
227impl ContextBuilder {
228    pub fn new(
229        fact_store: Arc<dyn FactStore>,
230        vector_store: Arc<dyn VectorStore>,
231        graph_store: Arc<dyn GraphStore>,
232        embedding: Arc<dyn EmbeddingProvider>,
233        config: ContextConfig,
234    ) -> Self {
235        Self {
236            fact_store,
237            vector_store,
238            graph_store,
239            embedding,
240            estimator: Box::new(CharTokenEstimator),
241            config,
242        }
243    }
244
245    /// Override the default token estimator.
246    pub fn with_estimator(mut self, estimator: Box<dyn TokenEstimator>) -> Self {
247        self.estimator = estimator;
248        self
249    }
250
251    /// Build a context block for the given query and scope.
252    pub async fn build(&self, query: &str, scope: &Scope) -> Result<ContextBlock, MemoryError> {
253        // 1. Retrieve candidates via hybrid search
254        let retriever = HybridRetriever::new(
255            self.fact_store.clone(),
256            self.vector_store.clone(),
257            self.graph_store.clone(),
258            self.embedding.clone(),
259            RetrievalConfig::default(),
260        );
261
262        let scored = retriever
263            .search(query, scope, self.config.max_candidates)
264            .await?;
265
266        // 2. Extract facts and sort by tier priority
267        let mut facts: Vec<Fact> = scored.into_iter().map(|sf| sf.fact).collect();
268        sort_by_tier_priority(&mut facts);
269
270        // 3. Greedy budget filling
271        let mut included: Vec<Fact> = Vec::new();
272        let mut token_count: usize = 0;
273        let overhead = self.format_overhead();
274        let budget_for_facts = self.config.token_budget.saturating_sub(overhead);
275
276        for fact in &facts {
277            let fact_tokens = self.estimator.estimate(&fact.text) + 2; // "- " prefix + newline
278            if token_count + fact_tokens > budget_for_facts {
279                continue; // Skip this fact, try smaller ones
280            }
281            token_count += fact_tokens;
282            included.push(fact.clone());
283        }
284
285        let facts_omitted = facts.len() - included.len();
286
287        // 4. Build tier breakdown
288        let mut tier_breakdown: HashMap<String, usize> = HashMap::new();
289        for fact in &included {
290            let tier_name = match fact.tier {
291                MemoryTier::Working => "working",
292                MemoryTier::Conversation => "conversation",
293                MemoryTier::Knowledge => "knowledge",
294            };
295            *tier_breakdown.entry(tier_name.to_string()).or_insert(0) += 1;
296        }
297
298        // 5. Format output
299        let text = match self.config.format {
300            OutputFormat::SystemPrompt => format_system_prompt(&included),
301            OutputFormat::Markdown => format_markdown(&included),
302            OutputFormat::Raw => format_raw(&included),
303        };
304
305        let total_tokens = self.estimator.estimate(&text);
306
307        Ok(ContextBlock {
308            text,
309            token_count: total_tokens,
310            facts_included: included.len(),
311            facts_omitted,
312            tier_breakdown,
313        })
314    }
315
316    /// Estimate the token overhead of the format wrapper (tags, headers, etc.).
317    fn format_overhead(&self) -> usize {
318        match self.config.format {
319            // <memory>\n</memory> + tier tags ≈ 60 chars ≈ 15 tokens
320            OutputFormat::SystemPrompt => 15,
321            // ## Memory Context\n\n### Tier\n\n ≈ 50 chars ≈ 13 tokens
322            OutputFormat::Markdown => 13,
323            // JSON brackets and formatting ≈ 20 chars ≈ 5 tokens
324            OutputFormat::Raw => 5,
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scope::Scope;

    /// Shorthand for a fact in the shared "test" org scope with the given tier.
    fn fact(text: &str, tier: MemoryTier) -> Fact {
        Fact::new(text, Scope::org("test")).with_tier(tier)
    }

    #[test]
    fn char_estimator_basic() {
        let est = CharTokenEstimator;
        let hundred = "a".repeat(100);
        // (input, expected) pairs: ceil(len / 4).
        let cases: [(&str, usize); 3] = [("hello world", 3), ("", 0), (hundred.as_str(), 25)];
        for &(input, expected) in &cases {
            assert_eq!(est.estimate(input), expected, "input of length {}", input.len());
        }
    }

    #[test]
    fn context_config_defaults() {
        let ContextConfig {
            token_budget,
            format,
            max_candidates,
        } = ContextConfig::default();
        assert_eq!(token_budget, 2000);
        assert_eq!(format, OutputFormat::SystemPrompt);
        assert_eq!(max_candidates, 50);
    }

    #[test]
    fn tier_priority_ordering() {
        let ranks = [
            MemoryTier::Working,
            MemoryTier::Conversation,
            MemoryTier::Knowledge,
        ]
        .map(|tier| tier_priority(&tier));
        assert!(ranks.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn sort_by_tier_groups_correctly() {
        let mut facts = vec![
            fact("knowledge fact", MemoryTier::Knowledge),
            fact("working fact", MemoryTier::Working),
            fact("convo fact", MemoryTier::Conversation),
        ];
        sort_by_tier_priority(&mut facts);
        let tiers: Vec<&MemoryTier> = facts.iter().map(|f| &f.tier).collect();
        assert_eq!(
            tiers,
            [
                &MemoryTier::Working,
                &MemoryTier::Conversation,
                &MemoryTier::Knowledge
            ]
        );
    }

    #[test]
    fn format_system_prompt_groups_by_tier() {
        let facts = vec![
            fact("ephemeral note", MemoryTier::Working),
            fact("user likes pizza", MemoryTier::Conversation),
            fact("company policy", MemoryTier::Knowledge),
        ];
        let output = format_system_prompt(&facts);
        assert!(output.starts_with("<memory>"));
        assert!(output.ends_with("</memory>"));
        let expected_fragments = [
            "<working>",
            "- ephemeral note",
            "<conversation>",
            "- user likes pizza",
            "<knowledge>",
            "- company policy",
        ];
        for needle in expected_fragments {
            assert!(output.contains(needle), "missing {needle:?} in {output}");
        }
    }

    #[test]
    fn format_system_prompt_omits_empty_tiers() {
        let output = format_system_prompt(&[fact("just a convo fact", MemoryTier::Conversation)]);
        assert!(!output.contains("<working>"));
        assert!(output.contains("<conversation>"));
        assert!(!output.contains("<knowledge>"));
    }

    #[test]
    fn format_markdown_produces_sections() {
        let facts = vec![
            fact("user prefers dark mode", MemoryTier::Knowledge),
            fact("user asked about Rust", MemoryTier::Knowledge),
        ];
        let output = format_markdown(&facts);
        for needle in ["## Memory Context", "### Knowledge", "- user prefers dark mode"] {
            assert!(output.contains(needle), "missing {needle:?} in {output}");
        }
    }

    #[test]
    fn format_raw_produces_valid_json() {
        let output = format_raw(&[fact("test fact", MemoryTier::Conversation)]);
        let parsed: Vec<serde_json::Value> =
            serde_json::from_str(&output).expect("format_raw must emit valid JSON");
        assert_eq!(parsed.len(), 1);
        let entry = &parsed[0];
        assert_eq!(entry["text"], "test fact");
        assert_eq!(entry["tier"], "conversation");
    }
}