Skip to main content

sochdb_query/
context_compiler.rs

1//! Context compiler — hard-budget context assembly as a query primitive.
2//!
3//! Composes exact BPE counting, temporal decay, weighted RRF fusion, and MMR
4//! diversity into a single entry point returning a packed block ≤ budget B.
5
6use crate::exact_token_counter::count_tokens_exact;
7use crate::temporal_decay::{TemporalDecayConfig, TemporalScorer};
8use crate::unified_fusion::fuse_rrf_weighted;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// Identifier of a retrievable document; keys the candidate map and every ranked lane.
pub type DocId = u64;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ContextSpec {
16    pub budget: usize,
17    pub bm25_weight: f32,
18    pub trigram_weight: f32,
19    pub vector_weight: f32,
20    pub mmr_lambda: f32,
21    pub decay_half_life_secs: f64,
22    pub template: ContextTemplate,
23}
24
25impl Default for ContextSpec {
26    fn default() -> Self {
27        Self {
28            budget: 4096,
29            bm25_weight: 0.4,
30            trigram_weight: 0.2,
31            vector_weight: 0.4,
32            mmr_lambda: 0.7,
33            decay_half_life_secs: 86_400.0 * 7.0,
34            template: ContextTemplate::Markdown,
35        }
36    }
37}
38
39#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
40pub enum ContextTemplate {
41    Markdown,
42    Toon,
43    Plain,
44}
45
46#[derive(Debug, Clone)]
47pub struct ContextCandidate {
48    pub doc_id: DocId,
49    pub text: String,
50    pub relevance: f32,
51    pub timestamp_secs: f64,
52    pub episode_id: Option<u64>,
53    pub t_valid_from: Option<u64>,
54    pub t_valid_to: Option<u64>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct CompiledFact {
59    pub text: String,
60    pub tokens: usize,
61    pub episode_id: Option<u64>,
62    pub t_valid_from: Option<u64>,
63    pub t_valid_to: Option<u64>,
64    pub trust_hint: f32,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct CompiledContext {
69    pub body: String,
70    pub exact_tokens: usize,
71    pub budget: usize,
72    pub facts: Vec<CompiledFact>,
73    pub truncated: bool,
74}
75
76/// Greedy MMR selection with exact BPE running sum; stops when budget exhausted.
77pub struct ContextCompiler {
78    decay: TemporalScorer,
79}
80
81impl ContextCompiler {
82    pub fn new(spec: &ContextSpec) -> Self {
83        let decay_cfg = TemporalDecayConfig {
84            half_life_secs: spec.decay_half_life_secs,
85            ..TemporalDecayConfig::default()
86        };
87        Self {
88            decay: TemporalScorer::new(decay_cfg),
89        }
90    }
91
92    pub fn compile(
93        &self,
94        spec: &ContextSpec,
95        bm25: &[(DocId, f32)],
96        trigram: &[(DocId, f32)],
97        vector: &[(DocId, f32)],
98        texts: &HashMap<DocId, ContextCandidate>,
99    ) -> CompiledContext {
100        let fused = self.fuse_lanes(spec, bm25, trigram, vector);
101        let mut decayed: Vec<(DocId, f32)> = fused
102            .into_iter()
103            .filter_map(|(id, score)| {
104                texts.get(&id).map(|c| {
105                    let final_score = self.decay.final_score(score, c.timestamp_secs);
106                    (id, final_score)
107                })
108            })
109            .collect();
110        decayed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
111
112        let selected = self.mmr_select(&decayed, texts, spec.budget, spec.mmr_lambda);
113        self.render(spec, &selected, texts)
114    }
115
116    fn fuse_lanes(
117        &self,
118        spec: &ContextSpec,
119        bm25: &[(DocId, f32)],
120        trigram: &[(DocId, f32)],
121        vector: &[(DocId, f32)],
122    ) -> HashMap<DocId, f32> {
123        use crate::filtered_vector_search::ScoredResult;
124        use crate::unified_fusion::RankedList;
125
126        let to_scored = |hits: &[(DocId, f32)]| {
127            hits.iter()
128                .map(|(id, score)| ScoredResult::new(*id, *score))
129                .collect::<Vec<_>>()
130        };
131
132        let bm25_scored = to_scored(bm25);
133        let trigram_scored = to_scored(trigram);
134        let vector_scored = to_scored(vector);
135
136        let mut lists = Vec::new();
137        if !bm25_scored.is_empty() {
138            lists.push(RankedList {
139                results: &bm25_scored,
140                weight: spec.bm25_weight,
141            });
142        }
143        if !trigram_scored.is_empty() {
144            lists.push(RankedList {
145                results: &trigram_scored,
146                weight: spec.trigram_weight,
147            });
148        }
149        if !vector_scored.is_empty() {
150            lists.push(RankedList {
151                results: &vector_scored,
152                weight: spec.vector_weight,
153            });
154        }
155        fuse_rrf_weighted(&lists, 60.0)
156            .into_iter()
157            .map(|(id, score)| (id.0, score))
158            .collect()
159    }
160
161    fn mmr_select(
162        &self,
163        ranked: &[(DocId, f32)],
164        texts: &HashMap<DocId, ContextCandidate>,
165        budget: usize,
166        lambda: f32,
167    ) -> Vec<DocId> {
168        let mut selected: Vec<DocId> = Vec::new();
169        let mut used_tokens = 0usize;
170        let mut remaining: Vec<(DocId, f32)> = ranked.to_vec();
171
172        while !remaining.is_empty() && used_tokens < budget {
173            let mut best_idx = 0;
174            let mut best_mmr = f32::NEG_INFINITY;
175            for (i, (id, rel)) in remaining.iter().enumerate() {
176                let Some(cand) = texts.get(id) else { continue };
177                let tok = count_tokens_exact(&cand.text);
178                if used_tokens + tok > budget {
179                    continue;
180                }
181                let max_sim = selected
182                    .iter()
183                    .filter_map(|sid| texts.get(sid))
184                    .map(|s| jaccard(&cand.text, &s.text))
185                    .fold(0.0f32, f32::max);
186                let mmr = lambda * rel - (1.0 - lambda) * max_sim;
187                if mmr > best_mmr {
188                    best_mmr = mmr;
189                    best_idx = i;
190                }
191            }
192            let (id, _) = remaining.remove(best_idx);
193            let Some(cand) = texts.get(&id) else { break };
194            used_tokens += count_tokens_exact(&cand.text);
195            selected.push(id);
196        }
197        selected
198    }
199
200    fn render(
201        &self,
202        spec: &ContextSpec,
203        selected: &[DocId],
204        texts: &HashMap<DocId, ContextCandidate>,
205    ) -> CompiledContext {
206        let mut facts = Vec::new();
207        let mut body_parts = Vec::new();
208        let mut total_tokens = 0usize;
209
210        for id in selected {
211            let Some(c) = texts.get(id) else { continue };
212            let tok = count_tokens_exact(&c.text);
213            if total_tokens + tok > spec.budget {
214                break;
215            }
216            total_tokens += tok;
217            let rendered = match spec.template {
218                ContextTemplate::Markdown => {
219                    format!("### Memory {}\n{}\n", id, c.text)
220                }
221                ContextTemplate::Toon => format!("mem{}|{}\n", id, c.text.replace('\n', " ")),
222                ContextTemplate::Plain => c.text.clone(),
223            };
224            body_parts.push(rendered);
225            facts.push(CompiledFact {
226                text: c.text.clone(),
227                tokens: tok,
228                episode_id: c.episode_id,
229                t_valid_from: c.t_valid_from,
230                t_valid_to: c.t_valid_to,
231                trust_hint: c.relevance,
232            });
233        }
234
235        let truncated = facts.len() < selected.len();
236        CompiledContext {
237            body: body_parts.join("\n"),
238            exact_tokens: total_tokens,
239            budget: spec.budget,
240            facts,
241            truncated,
242        }
243    }
244}
245
/// Jaccard similarity of the whitespace-separated word sets of `a` and `b`.
///
/// Returns a value in `[0, 1]`; two inputs without any words score `0.0`.
fn jaccard(a: &str, b: &str) -> f32 {
    use std::collections::HashSet;

    let left: HashSet<&str> = a.split_whitespace().collect();
    let right: HashSet<&str> = b.split_whitespace().collect();

    let shared = left.iter().filter(|word| right.contains(*word)).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|; it is zero only when both sets are empty.
    let total = left.len() + right.len() - shared;
    if total == 0 {
        return 0.0;
    }
    shared as f32 / total as f32
}
256
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a candidate stamped at t = 0 with no validity window; the episode id
    /// mirrors the doc id.
    fn candidate(doc_id: DocId, text: &str, relevance: f32) -> ContextCandidate {
        ContextCandidate {
            doc_id,
            text: text.into(),
            relevance,
            timestamp_secs: 0.0,
            episode_id: Some(doc_id),
            t_valid_from: None,
            t_valid_to: None,
        }
    }

    /// Two BM25 hits under a 50-token budget: the output must stay within the budget and
    /// contain at least one fact.
    #[test]
    fn compile_respects_exact_budget() {
        let spec = ContextSpec {
            budget: 50,
            ..ContextSpec::default()
        };
        let compiler = ContextCompiler::new(&spec);

        let texts: HashMap<DocId, ContextCandidate> = vec![
            candidate(1, "short memory", 1.0),
            candidate(
                2,
                "a much longer memory entry that would exceed the token budget if included",
                0.8,
            ),
        ]
        .into_iter()
        .map(|c| (c.doc_id, c))
        .collect();
        let bm25 = vec![(1u64, 1.0), (2, 0.8)];

        let out = compiler.compile(&spec, &bm25, &[], &[], &texts);

        assert!(out.exact_tokens <= spec.budget);
        assert!(!out.facts.is_empty());
    }
}