1use crate::exact_token_counter::count_tokens_exact;
7use crate::temporal_decay::{TemporalDecayConfig, TemporalScorer};
8use crate::unified_fusion::fuse_rrf_weighted;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// Identifier used to key documents in the retrieval hit lists and in the candidate map.
pub type DocId = u64;
13
/// Tunable parameters for one context compilation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSpec {
    /// Token budget: the summed token count of the candidate texts admitted to
    /// the compiled context never exceeds this value.
    pub budget: usize,
    /// Weight attached to the BM25 hit list when the lanes are fused.
    pub bm25_weight: f32,
    /// Weight attached to the trigram hit list when the lanes are fused.
    pub trigram_weight: f32,
    /// Weight attached to the vector hit list when the lanes are fused.
    pub vector_weight: f32,
    /// Relevance/diversity trade-off used during selection: a candidate is
    /// scored as `mmr_lambda * score - (1 - mmr_lambda) * max_similarity`,
    /// where `max_similarity` is taken against the already selected texts.
    pub mmr_lambda: f32,
    /// Half-life, in seconds, handed to the temporal decay configuration.
    pub decay_half_life_secs: f64,
    /// Output format used when rendering the selected candidates.
    pub template: ContextTemplate,
}
24
25impl Default for ContextSpec {
26 fn default() -> Self {
27 Self {
28 budget: 4096,
29 bm25_weight: 0.4,
30 trigram_weight: 0.2,
31 vector_weight: 0.4,
32 mmr_lambda: 0.7,
33 decay_half_life_secs: 86_400.0 * 7.0,
34 template: ContextTemplate::Markdown,
35 }
36 }
37}
38
/// Output format applied to each selected candidate when the context body is rendered.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ContextTemplate {
    /// A `### Memory <id>` heading line followed by the candidate text.
    Markdown,
    /// One `mem<id>|<text>` line per candidate, with newlines inside the text
    /// replaced by spaces.
    Toon,
    /// The candidate text as-is, without any framing.
    Plain,
}
45
/// A document that may be placed into the compiled context.
#[derive(Debug, Clone)]
pub struct ContextCandidate {
    /// Document identifier. The compiler looks candidates up by map key and
    /// does not read this field; presumably both agree — verify against callers.
    pub doc_id: DocId,
    /// Text that is token-counted, compared for similarity and rendered.
    pub text: String,
    /// Caller-supplied relevance; copied into `CompiledFact::trust_hint` and
    /// not used for ranking by the compiler.
    pub relevance: f32,
    /// Timestamp passed to the temporal scorer together with the fused score.
    /// NOTE(review): the name suggests seconds; the epoch is not visible here.
    pub timestamp_secs: f64,
    /// Copied unchanged into the resulting `CompiledFact`.
    pub episode_id: Option<u64>,
    /// Copied unchanged into the resulting `CompiledFact`.
    pub t_valid_from: Option<u64>,
    /// Copied unchanged into the resulting `CompiledFact`.
    pub t_valid_to: Option<u64>,
}
56
/// One candidate that made it into the compiled context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompiledFact {
    /// The candidate's text, without template framing.
    pub text: String,
    /// Token count of `text` as reported by `count_tokens_exact`.
    pub tokens: usize,
    /// Carried over from the source candidate.
    pub episode_id: Option<u64>,
    /// Carried over from the source candidate.
    pub t_valid_from: Option<u64>,
    /// Carried over from the source candidate.
    pub t_valid_to: Option<u64>,
    /// The source candidate's `relevance` value.
    pub trust_hint: f32,
}
66
/// Result of a compilation: the rendered body plus per-fact metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompiledContext {
    /// Rendered entries, one per fact, joined with a newline.
    pub body: String,
    /// Sum of the `tokens` of all `facts`; template framing added while
    /// rendering is not part of this count.
    pub exact_tokens: usize,
    /// The budget the context was compiled against (copied from the spec).
    pub budget: usize,
    /// The facts included in `body`, in output order.
    pub facts: Vec<CompiledFact>,
    /// Set by the compiler when at least one candidate it would otherwise
    /// have rendered was left out.
    pub truncated: bool,
}
75
/// Compiles ranked retrieval hits into a token-budgeted context.
pub struct ContextCompiler {
    /// Temporal scorer applied to each fused score together with the
    /// candidate's timestamp.
    decay: TemporalScorer,
}
80
81impl ContextCompiler {
82 pub fn new(spec: &ContextSpec) -> Self {
83 let decay_cfg = TemporalDecayConfig {
84 half_life_secs: spec.decay_half_life_secs,
85 ..TemporalDecayConfig::default()
86 };
87 Self {
88 decay: TemporalScorer::new(decay_cfg),
89 }
90 }
91
92 pub fn compile(
93 &self,
94 spec: &ContextSpec,
95 bm25: &[(DocId, f32)],
96 trigram: &[(DocId, f32)],
97 vector: &[(DocId, f32)],
98 texts: &HashMap<DocId, ContextCandidate>,
99 ) -> CompiledContext {
100 let fused = self.fuse_lanes(spec, bm25, trigram, vector);
101 let mut decayed: Vec<(DocId, f32)> = fused
102 .into_iter()
103 .filter_map(|(id, score)| {
104 texts.get(&id).map(|c| {
105 let final_score = self.decay.final_score(score, c.timestamp_secs);
106 (id, final_score)
107 })
108 })
109 .collect();
110 decayed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
111
112 let selected = self.mmr_select(&decayed, texts, spec.budget, spec.mmr_lambda);
113 self.render(spec, &selected, texts)
114 }
115
116 fn fuse_lanes(
117 &self,
118 spec: &ContextSpec,
119 bm25: &[(DocId, f32)],
120 trigram: &[(DocId, f32)],
121 vector: &[(DocId, f32)],
122 ) -> HashMap<DocId, f32> {
123 use crate::filtered_vector_search::ScoredResult;
124 use crate::unified_fusion::RankedList;
125
126 let to_scored = |hits: &[(DocId, f32)]| {
127 hits.iter()
128 .map(|(id, score)| ScoredResult::new(*id, *score))
129 .collect::<Vec<_>>()
130 };
131
132 let bm25_scored = to_scored(bm25);
133 let trigram_scored = to_scored(trigram);
134 let vector_scored = to_scored(vector);
135
136 let mut lists = Vec::new();
137 if !bm25_scored.is_empty() {
138 lists.push(RankedList {
139 results: &bm25_scored,
140 weight: spec.bm25_weight,
141 });
142 }
143 if !trigram_scored.is_empty() {
144 lists.push(RankedList {
145 results: &trigram_scored,
146 weight: spec.trigram_weight,
147 });
148 }
149 if !vector_scored.is_empty() {
150 lists.push(RankedList {
151 results: &vector_scored,
152 weight: spec.vector_weight,
153 });
154 }
155 fuse_rrf_weighted(&lists, 60.0)
156 .into_iter()
157 .map(|(id, score)| (id.0, score))
158 .collect()
159 }
160
161 fn mmr_select(
162 &self,
163 ranked: &[(DocId, f32)],
164 texts: &HashMap<DocId, ContextCandidate>,
165 budget: usize,
166 lambda: f32,
167 ) -> Vec<DocId> {
168 let mut selected: Vec<DocId> = Vec::new();
169 let mut used_tokens = 0usize;
170 let mut remaining: Vec<(DocId, f32)> = ranked.to_vec();
171
172 while !remaining.is_empty() && used_tokens < budget {
173 let mut best_idx = 0;
174 let mut best_mmr = f32::NEG_INFINITY;
175 for (i, (id, rel)) in remaining.iter().enumerate() {
176 let Some(cand) = texts.get(id) else { continue };
177 let tok = count_tokens_exact(&cand.text);
178 if used_tokens + tok > budget {
179 continue;
180 }
181 let max_sim = selected
182 .iter()
183 .filter_map(|sid| texts.get(sid))
184 .map(|s| jaccard(&cand.text, &s.text))
185 .fold(0.0f32, f32::max);
186 let mmr = lambda * rel - (1.0 - lambda) * max_sim;
187 if mmr > best_mmr {
188 best_mmr = mmr;
189 best_idx = i;
190 }
191 }
192 let (id, _) = remaining.remove(best_idx);
193 let Some(cand) = texts.get(&id) else { break };
194 used_tokens += count_tokens_exact(&cand.text);
195 selected.push(id);
196 }
197 selected
198 }
199
200 fn render(
201 &self,
202 spec: &ContextSpec,
203 selected: &[DocId],
204 texts: &HashMap<DocId, ContextCandidate>,
205 ) -> CompiledContext {
206 let mut facts = Vec::new();
207 let mut body_parts = Vec::new();
208 let mut total_tokens = 0usize;
209
210 for id in selected {
211 let Some(c) = texts.get(id) else { continue };
212 let tok = count_tokens_exact(&c.text);
213 if total_tokens + tok > spec.budget {
214 break;
215 }
216 total_tokens += tok;
217 let rendered = match spec.template {
218 ContextTemplate::Markdown => {
219 format!("### Memory {}\n{}\n", id, c.text)
220 }
221 ContextTemplate::Toon => format!("mem{}|{}\n", id, c.text.replace('\n', " ")),
222 ContextTemplate::Plain => c.text.clone(),
223 };
224 body_parts.push(rendered);
225 facts.push(CompiledFact {
226 text: c.text.clone(),
227 tokens: tok,
228 episode_id: c.episode_id,
229 t_valid_from: c.t_valid_from,
230 t_valid_to: c.t_valid_to,
231 trust_hint: c.relevance,
232 });
233 }
234
235 let truncated = facts.len() < selected.len();
236 CompiledContext {
237 body: body_parts.join("\n"),
238 exact_tokens: total_tokens,
239 budget: spec.budget,
240 facts,
241 truncated,
242 }
243 }
244}
245
/// Jaccard similarity between the sets of whitespace-separated words of `a`
/// and `b`.
///
/// Returns `|A ∩ B| / |A ∪ B|`, a value between 0.0 and 1.0. Words are
/// compared exactly (case-sensitive, punctuation included). When neither
/// input contains a word the result is 0.0.
fn jaccard(a: &str, b: &str) -> f32 {
    use std::collections::HashSet;

    let words_a: HashSet<&str> = a.split_whitespace().collect();
    let words_b: HashSet<&str> = b.split_whitespace().collect();

    let shared = words_a.intersection(&words_b).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|, so the union needs no second traversal.
    let union = words_a.len() + words_b.len() - shared;
    if union == 0 {
        return 0.0;
    }
    shared as f32 / union as f32
}
260
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a candidate with a zero timestamp, no validity window and an
    /// episode id equal to its document id.
    fn candidate(doc_id: DocId, text: &str, relevance: f32) -> ContextCandidate {
        ContextCandidate {
            doc_id,
            text: text.into(),
            relevance,
            timestamp_secs: 0.0,
            episode_id: Some(doc_id),
            t_valid_from: None,
            t_valid_to: None,
        }
    }

    /// The compiled context must stay within the budget and must not be empty
    /// when at least one candidate fits.
    // NOTE(review): whether the longer entry really overflows a 50-token
    // budget depends on `count_tokens_exact`; the assertions below do not
    // rely on it being excluded.
    #[test]
    fn compile_respects_exact_budget() {
        let spec = ContextSpec {
            budget: 50,
            ..ContextSpec::default()
        };
        let compiler = ContextCompiler::new(&spec);

        let bm25_hits: Vec<(DocId, f32)> = vec![(1, 1.0), (2, 0.8)];
        let texts: HashMap<DocId, ContextCandidate> = vec![
            candidate(1, "short memory", 1.0),
            candidate(
                2,
                "a much longer memory entry that would exceed the token budget if included",
                0.8,
            ),
        ]
        .into_iter()
        .map(|c| (c.doc_id, c))
        .collect();

        let out = compiler.compile(&spec, &bm25_hits, &[], &[], &texts);

        assert!(out.exact_tokens <= spec.budget);
        assert!(!out.facts.is_empty());
    }
}