1use crate::embedding::EmbeddingProvider;
8use crate::fact::{Fact, MemoryTier};
9use crate::graph::GraphStore;
10use crate::retrieve::{HybridRetriever, RetrievalConfig};
11use crate::scope::Scope;
12use crate::store::{FactStore, MemoryError};
13use crate::vector::VectorStore;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17
/// Estimates how many tokens a piece of text will occupy.
///
/// Implementors must be `Send + Sync`. `ContextBuilder` stores an estimator
/// behind a `Box<dyn TokenEstimator>` and uses it both to budget individual
/// facts and to measure the final rendered text.
pub trait TokenEstimator: Send + Sync {
    /// Returns the estimated number of tokens in `text`.
    fn estimate(&self, text: &str) -> usize;
}
28
/// Heuristic [`TokenEstimator`] that counts one token per four bytes of UTF-8
/// text, rounding up. This is the estimator `ContextBuilder::new` installs.
pub struct CharTokenEstimator;
32
33impl TokenEstimator for CharTokenEstimator {
34 fn estimate(&self, text: &str) -> usize {
35 text.len().div_ceil(4)
37 }
38}
39
/// Layout used when rendering the selected facts into text.
///
/// Variants are (de)serialized with snake_case names: `system_prompt`,
/// `markdown`, and `raw`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OutputFormat {
    /// `<memory>` block with one tagged section per tier, rendered by
    /// `format_system_prompt`. This is the default.
    #[default]
    SystemPrompt,
    /// Markdown document with a heading per tier, rendered by `format_markdown`.
    Markdown,
    /// Pretty-printed JSON array of fact objects, rendered by `format_raw`.
    Raw,
}
56
/// Settings that control how `ContextBuilder` selects and renders facts.
#[derive(Debug, Clone)]
pub struct ContextConfig {
    /// Token budget for the rendered block. A fixed per-format overhead is
    /// subtracted from it before facts are selected.
    pub token_budget: usize,
    /// Output layout for the rendered text.
    pub format: OutputFormat,
    /// Maximum number of candidate facts requested from the retriever.
    pub max_candidates: usize,
}
72
73impl Default for ContextConfig {
74 fn default() -> Self {
75 Self {
76 token_budget: 2000,
77 format: OutputFormat::SystemPrompt,
78 max_candidates: 50,
79 }
80 }
81}
82
/// Result of `ContextBuilder::build`: the rendered text plus selection statistics.
#[derive(Debug, Clone, Serialize)]
pub struct ContextBlock {
    /// Rendered context in the configured output format.
    pub text: String,
    /// Token estimate for `text`, computed by the builder's estimator.
    pub token_count: usize,
    /// Number of facts rendered into `text`.
    pub facts_included: usize,
    /// Number of retrieved candidates that were left out of `text`.
    pub facts_omitted: usize,
    /// Count of included facts per tier, keyed by `"working"`,
    /// `"conversation"`, or `"knowledge"`. Tiers with no included facts
    /// have no entry.
    pub tier_breakdown: HashMap<String, usize>,
}
101
102fn tier_priority(tier: &MemoryTier) -> u8 {
108 match tier {
109 MemoryTier::Working => 0,
110 MemoryTier::Conversation => 1,
111 MemoryTier::Knowledge => 2,
112 }
113}
114
115pub fn sort_by_tier_priority(facts: &mut [Fact]) {
118 facts.sort_by_key(|f| tier_priority(&f.tier));
119}
120
121pub fn format_system_prompt(facts: &[Fact]) -> String {
127 let mut working = Vec::new();
128 let mut conversation = Vec::new();
129 let mut knowledge = Vec::new();
130
131 for fact in facts {
132 let line = format!("- {}", fact.text);
133 match fact.tier {
134 MemoryTier::Working => working.push(line),
135 MemoryTier::Conversation => conversation.push(line),
136 MemoryTier::Knowledge => knowledge.push(line),
137 }
138 }
139
140 let mut out = String::from("<memory>\n");
141
142 if !working.is_empty() {
143 out.push_str("<working>\n");
144 for line in &working {
145 out.push_str(line);
146 out.push('\n');
147 }
148 out.push_str("</working>\n");
149 }
150 if !conversation.is_empty() {
151 out.push_str("<conversation>\n");
152 for line in &conversation {
153 out.push_str(line);
154 out.push('\n');
155 }
156 out.push_str("</conversation>\n");
157 }
158 if !knowledge.is_empty() {
159 out.push_str("<knowledge>\n");
160 for line in &knowledge {
161 out.push_str(line);
162 out.push('\n');
163 }
164 out.push_str("</knowledge>\n");
165 }
166
167 out.push_str("</memory>");
168 out
169}
170
171pub fn format_markdown(facts: &[Fact]) -> String {
173 let mut out = String::from("## Memory Context\n\n");
174
175 let mut current_tier: Option<&MemoryTier> = None;
176 for fact in facts {
177 if current_tier != Some(&fact.tier) {
178 let label = match fact.tier {
179 MemoryTier::Working => "Working Memory",
180 MemoryTier::Conversation => "Conversation",
181 MemoryTier::Knowledge => "Knowledge",
182 };
183 out.push_str(&format!("### {label}\n\n"));
184 current_tier = Some(&fact.tier);
185 }
186 out.push_str(&format!("- {}\n", fact.text));
187 }
188
189 out
190}
191
192pub fn format_raw(facts: &[Fact]) -> String {
194 let entries: Vec<serde_json::Value> = facts
195 .iter()
196 .map(|f| {
197 serde_json::json!({
198 "id": f.id.to_string(),
199 "text": f.text,
200 "tier": f.tier,
201 "category": f.category,
202 "confidence": f.confidence,
203 })
204 })
205 .collect();
206 serde_json::to_string_pretty(&entries).unwrap_or_else(|_| "[]".to_string())
207}
208
/// Assembles a token-budgeted block of memory context for a query.
///
/// The builder holds the stores and embedding provider that are handed to a
/// `HybridRetriever` on every `build` call, together with the estimator and
/// configuration used to select and render the retrieved facts.
pub struct ContextBuilder {
    /// Fact storage handed to the retriever.
    fact_store: Arc<dyn FactStore>,
    /// Vector storage handed to the retriever.
    vector_store: Arc<dyn VectorStore>,
    /// Graph storage handed to the retriever.
    graph_store: Arc<dyn GraphStore>,
    /// Embedding provider handed to the retriever.
    embedding: Arc<dyn EmbeddingProvider>,
    /// Token estimator used for budgeting; `new` installs `CharTokenEstimator`.
    estimator: Box<dyn TokenEstimator>,
    /// Budget, output format, and candidate limit.
    config: ContextConfig,
}
226
227impl ContextBuilder {
228 pub fn new(
229 fact_store: Arc<dyn FactStore>,
230 vector_store: Arc<dyn VectorStore>,
231 graph_store: Arc<dyn GraphStore>,
232 embedding: Arc<dyn EmbeddingProvider>,
233 config: ContextConfig,
234 ) -> Self {
235 Self {
236 fact_store,
237 vector_store,
238 graph_store,
239 embedding,
240 estimator: Box::new(CharTokenEstimator),
241 config,
242 }
243 }
244
245 pub fn with_estimator(mut self, estimator: Box<dyn TokenEstimator>) -> Self {
247 self.estimator = estimator;
248 self
249 }
250
251 pub async fn build(&self, query: &str, scope: &Scope) -> Result<ContextBlock, MemoryError> {
253 let retriever = HybridRetriever::new(
255 self.fact_store.clone(),
256 self.vector_store.clone(),
257 self.graph_store.clone(),
258 self.embedding.clone(),
259 RetrievalConfig::default(),
260 );
261
262 let scored = retriever
263 .search(query, scope, self.config.max_candidates)
264 .await?;
265
266 let mut facts: Vec<Fact> = scored.into_iter().map(|sf| sf.fact).collect();
268 sort_by_tier_priority(&mut facts);
269
270 let mut included: Vec<Fact> = Vec::new();
272 let mut token_count: usize = 0;
273 let overhead = self.format_overhead();
274 let budget_for_facts = self.config.token_budget.saturating_sub(overhead);
275
276 for fact in &facts {
277 let fact_tokens = self.estimator.estimate(&fact.text) + 2; if token_count + fact_tokens > budget_for_facts {
279 continue; }
281 token_count += fact_tokens;
282 included.push(fact.clone());
283 }
284
285 let facts_omitted = facts.len() - included.len();
286
287 let mut tier_breakdown: HashMap<String, usize> = HashMap::new();
289 for fact in &included {
290 let tier_name = match fact.tier {
291 MemoryTier::Working => "working",
292 MemoryTier::Conversation => "conversation",
293 MemoryTier::Knowledge => "knowledge",
294 };
295 *tier_breakdown.entry(tier_name.to_string()).or_insert(0) += 1;
296 }
297
298 let text = match self.config.format {
300 OutputFormat::SystemPrompt => format_system_prompt(&included),
301 OutputFormat::Markdown => format_markdown(&included),
302 OutputFormat::Raw => format_raw(&included),
303 };
304
305 let total_tokens = self.estimator.estimate(&text);
306
307 Ok(ContextBlock {
308 text,
309 token_count: total_tokens,
310 facts_included: included.len(),
311 facts_omitted,
312 tier_breakdown,
313 })
314 }
315
316 fn format_overhead(&self) -> usize {
318 match self.config.format {
319 OutputFormat::SystemPrompt => 15,
321 OutputFormat::Markdown => 13,
323 OutputFormat::Raw => 5,
325 }
326 }
327}
328
// Unit tests for the estimator, the default configuration, tier ordering,
// and the three output formatters.
#[cfg(test)]
mod tests {
    use super::*;

    // The char estimator rounds byte length / 4 up, and maps "" to 0.
    #[test]
    fn char_estimator_basic() {
        let est = CharTokenEstimator;
        assert_eq!(est.estimate("hello world"), 3);
        assert_eq!(est.estimate(""), 0);
        let hundred = "a".repeat(100);
        assert_eq!(est.estimate(&hundred), 25);
    }

    // The default configuration uses a 2000-token budget, the system-prompt
    // format, and 50 candidates.
    #[test]
    fn context_config_defaults() {
        let cfg = ContextConfig::default();
        assert_eq!(cfg.token_budget, 2000);
        assert_eq!(cfg.format, OutputFormat::SystemPrompt);
        assert_eq!(cfg.max_candidates, 50);
    }

    // Working ranks before Conversation, which ranks before Knowledge.
    #[test]
    fn tier_priority_ordering() {
        assert!(tier_priority(&MemoryTier::Working) < tier_priority(&MemoryTier::Conversation));
        assert!(tier_priority(&MemoryTier::Conversation) < tier_priority(&MemoryTier::Knowledge));
    }

    // Sorting reorders mixed-tier facts into Working, Conversation, Knowledge.
    #[test]
    fn sort_by_tier_groups_correctly() {
        use crate::scope::Scope;
        let scope = Scope::org("test");
        let mut facts = vec![
            Fact::new("knowledge fact", scope.clone()).with_tier(MemoryTier::Knowledge),
            Fact::new("working fact", scope.clone()).with_tier(MemoryTier::Working),
            Fact::new("convo fact", scope).with_tier(MemoryTier::Conversation),
        ];
        sort_by_tier_priority(&mut facts);
        assert_eq!(facts[0].tier, MemoryTier::Working);
        assert_eq!(facts[1].tier, MemoryTier::Conversation);
        assert_eq!(facts[2].tier, MemoryTier::Knowledge);
    }

    // The system-prompt format wraps everything in <memory> and emits a
    // tagged section with bullet lines for each tier that has facts.
    #[test]
    fn format_system_prompt_groups_by_tier() {
        use crate::scope::Scope;
        let scope = Scope::org("test");
        let facts = vec![
            Fact::new("ephemeral note", scope.clone()).with_tier(MemoryTier::Working),
            Fact::new("user likes pizza", scope.clone()).with_tier(MemoryTier::Conversation),
            Fact::new("company policy", scope).with_tier(MemoryTier::Knowledge),
        ];
        let output = format_system_prompt(&facts);
        assert!(output.starts_with("<memory>"));
        assert!(output.ends_with("</memory>"));
        assert!(output.contains("<working>"));
        assert!(output.contains("- ephemeral note"));
        assert!(output.contains("<conversation>"));
        assert!(output.contains("- user likes pizza"));
        assert!(output.contains("<knowledge>"));
        assert!(output.contains("- company policy"));
    }

    // Tiers without facts produce no section at all.
    #[test]
    fn format_system_prompt_omits_empty_tiers() {
        use crate::scope::Scope;
        let scope = Scope::org("test");
        let facts = vec![Fact::new("just a convo fact", scope).with_tier(MemoryTier::Conversation)];
        let output = format_system_prompt(&facts);
        assert!(!output.contains("<working>"));
        assert!(output.contains("<conversation>"));
        assert!(!output.contains("<knowledge>"));
    }

    // The Markdown format emits the document title, a tier heading, and bullets.
    #[test]
    fn format_markdown_produces_sections() {
        use crate::scope::Scope;
        let scope = Scope::org("test");
        let facts = vec![
            Fact::new("user prefers dark mode", scope.clone()).with_tier(MemoryTier::Knowledge),
            Fact::new("user asked about Rust", scope).with_tier(MemoryTier::Knowledge),
        ];
        let output = format_markdown(&facts);
        assert!(output.contains("## Memory Context"));
        assert!(output.contains("### Knowledge"));
        assert!(output.contains("- user prefers dark mode"));
    }

    // The raw format parses back as a JSON array carrying text and tier.
    #[test]
    fn format_raw_produces_valid_json() {
        use crate::scope::Scope;
        let scope = Scope::org("test");
        let facts = vec![Fact::new("test fact", scope).with_tier(MemoryTier::Conversation)];
        let output = format_raw(&facts);
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&output).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0]["text"], "test fact");
        assert_eq!(parsed[0]["tier"], "conversation");
    }
}