use crate::embedding::EmbeddingProvider;
use crate::fact::{Fact, MemoryTier};
use crate::graph::GraphStore;
use crate::retrieve::{HybridRetriever, RetrievalConfig};
use crate::scope::Scope;
use crate::store::{FactStore, MemoryError};
use crate::vector::VectorStore;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Estimates how many tokens a piece of text will occupy.
///
/// Implementors must be `Send + Sync` (supertrait bounds), so an estimator can
/// be stored behind a `Box<dyn TokenEstimator>` in a shareable builder.
pub trait TokenEstimator: Send + Sync {
/// Returns the estimated number of tokens in `text`.
fn estimate(&self, text: &str) -> usize;
}
/// Heuristic [`TokenEstimator`] that assumes roughly four bytes of text per token.
///
/// The estimate is `ceil(text.len() / 4)`. Note that `str::len` counts UTF-8
/// *bytes*, not characters, so multi-byte text yields a proportionally larger
/// estimate than ASCII text with the same number of characters.
#[derive(Debug, Clone, Copy, Default)]
pub struct CharTokenEstimator;

impl TokenEstimator for CharTokenEstimator {
    /// Returns the UTF-8 byte length of `text` divided by four, rounded up.
    ///
    /// The empty string estimates to `0`.
    fn estimate(&self, text: &str) -> usize {
        text.len().div_ceil(4)
    }
}
/// Rendering style for the assembled memory context.
///
/// Serialized through serde with snake_case variant names
/// (`system_prompt`, `markdown`, `raw`).
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OutputFormat {
/// Tagged `<memory>` block rendered by `format_system_prompt`; the default variant.
#[default]
SystemPrompt,
/// Markdown headings with bullet lists, rendered by `format_markdown`.
Markdown,
/// Pretty-printed JSON array, rendered by `format_raw`.
Raw,
}
/// Settings that control how `ContextBuilder::build` assembles a context block.
#[derive(Debug, Clone)]
pub struct ContextConfig {
/// Token budget that facts are packed against, measured with the builder's
/// token estimator.
pub token_budget: usize,
/// Which formatter renders the selected facts.
pub format: OutputFormat,
/// Value passed to the retriever's `search` call as its third argument
/// (presumably the maximum number of candidates returned — confirm against
/// `HybridRetriever::search`).
pub max_candidates: usize,
}
impl Default for ContextConfig {
fn default() -> Self {
Self {
token_budget: 2000,
format: OutputFormat::SystemPrompt,
max_candidates: 50,
}
}
}
/// Result of `ContextBuilder::build`: the rendered text plus bookkeeping about
/// what went into it.
#[derive(Debug, Clone, Serialize)]
pub struct ContextBlock {
/// Rendered context in the configured output format.
pub text: String,
/// Token estimate of `text`, as computed by the builder's estimator.
pub token_count: usize,
/// Number of facts rendered into `text`.
pub facts_included: usize,
/// Number of retrieved candidate facts that were left out of `text`.
pub facts_omitted: usize,
/// Count of included facts per tier, keyed by `"working"`,
/// `"conversation"`, or `"knowledge"`; tiers with no included facts have no entry.
pub tier_breakdown: HashMap<String, usize>,
}
/// Maps a memory tier to its sort rank; a lower rank sorts earlier.
///
/// Working memory ranks first (0), then conversation (1), then knowledge (2).
fn tier_priority(tier: &MemoryTier) -> u8 {
    const WORKING_RANK: u8 = 0;
    const CONVERSATION_RANK: u8 = 1;
    const KNOWLEDGE_RANK: u8 = 2;

    match *tier {
        MemoryTier::Working => WORKING_RANK,
        MemoryTier::Conversation => CONVERSATION_RANK,
        MemoryTier::Knowledge => KNOWLEDGE_RANK,
    }
}
/// Reorders `facts` in place so that working-tier facts come first, then
/// conversation, then knowledge.
///
/// The sort is stable: facts that share a tier keep their relative order.
pub fn sort_by_tier_priority(facts: &mut [Fact]) {
    facts.sort_by(|lhs, rhs| {
        let lhs_rank = tier_priority(&lhs.tier);
        let rhs_rank = tier_priority(&rhs.tier);
        lhs_rank.cmp(&rhs_rank)
    });
}
/// Renders `facts` as a tagged `<memory>` block suitable for a system prompt.
///
/// Facts are grouped by tier into `<working>`, `<conversation>` and
/// `<knowledge>` sections, always emitted in that order regardless of the
/// input order. Each fact becomes a `- text` bullet line, and a section is
/// omitted entirely when its tier has no facts. The result has no trailing
/// newline after `</memory>`.
pub fn format_system_prompt(facts: &[Fact]) -> String {
    let mut working: Vec<String> = Vec::new();
    let mut conversation: Vec<String> = Vec::new();
    let mut knowledge: Vec<String> = Vec::new();

    // Bucket one bullet per fact, preserving input order within each tier.
    for fact in facts {
        let bucket = match fact.tier {
            MemoryTier::Working => &mut working,
            MemoryTier::Conversation => &mut conversation,
            MemoryTier::Knowledge => &mut knowledge,
        };
        bucket.push(format!("- {}", fact.text));
    }

    // (opening tag, closing tag, bullets) in the fixed emission order.
    let sections = [
        ("<working>\n", "</working>\n", &working),
        ("<conversation>\n", "</conversation>\n", &conversation),
        ("<knowledge>\n", "</knowledge>\n", &knowledge),
    ];

    let mut prompt = String::from("<memory>\n");
    for (open_tag, close_tag, bullets) in sections {
        if bullets.is_empty() {
            continue;
        }
        prompt.push_str(open_tag);
        for bullet in bullets {
            prompt.push_str(bullet);
            prompt.push('\n');
        }
        prompt.push_str(close_tag);
    }
    prompt.push_str("</memory>");
    prompt
}
/// Renders `facts` as a Markdown document titled `## Memory Context`.
///
/// Facts are grouped by tier under `### Working Memory`, `### Conversation`
/// and `### Knowledge` headings, emitted in that order; each fact becomes a
/// `- text` bullet line. A heading is omitted when its tier has no facts.
///
/// Grouping does not depend on the order of `facts`: every tier gets at most
/// one heading even when the slice is not sorted by tier (previously a new
/// heading was emitted on every tier change between adjacent facts, so
/// unsorted input produced duplicate headings). Output for tier-sorted input
/// is unchanged. Within a tier, facts keep their input order.
pub fn format_markdown(facts: &[Fact]) -> String {
    let mut working: Vec<String> = Vec::new();
    let mut conversation: Vec<String> = Vec::new();
    let mut knowledge: Vec<String> = Vec::new();

    // Exhaustive match: adding a tier variant must be handled here explicitly.
    for fact in facts {
        let bucket = match fact.tier {
            MemoryTier::Working => &mut working,
            MemoryTier::Conversation => &mut conversation,
            MemoryTier::Knowledge => &mut knowledge,
        };
        bucket.push(format!("- {}\n", fact.text));
    }

    let sections = [
        ("Working Memory", &working),
        ("Conversation", &conversation),
        ("Knowledge", &knowledge),
    ];

    let mut out = String::from("## Memory Context\n\n");
    for (label, bullets) in sections {
        if bullets.is_empty() {
            continue;
        }
        out.push_str("### ");
        out.push_str(label);
        out.push_str("\n\n");
        for bullet in bullets {
            out.push_str(bullet);
        }
    }
    out
}
/// Renders `facts` as a pretty-printed JSON array.
///
/// Each element is an object with the fact's `id` (stringified via
/// `to_string`), `text`, `tier`, `category` and `confidence`. If
/// serialization fails, the literal string `"[]"` is returned instead of an
/// error.
pub fn format_raw(facts: &[Fact]) -> String {
    let mut entries: Vec<serde_json::Value> = Vec::with_capacity(facts.len());
    for fact in facts {
        entries.push(serde_json::json!({
            "id": fact.id.to_string(),
            "text": fact.text,
            "tier": fact.tier,
            "category": fact.category,
            "confidence": fact.confidence,
        }));
    }

    match serde_json::to_string_pretty(&entries) {
        Ok(json) => json,
        // Best-effort output: fall back to an empty JSON array.
        Err(_) => "[]".to_string(),
    }
}
/// Assembles a token-budgeted context block from retrieved facts.
///
/// Holds the stores and embedding provider that are handed to a
/// `HybridRetriever` on every `build` call, plus the estimator and
/// configuration used to pack and render the result.
pub struct ContextBuilder {
// Shared handles passed to `HybridRetriever::new` in `build`.
fact_store: Arc<dyn FactStore>,
vector_store: Arc<dyn VectorStore>,
graph_store: Arc<dyn GraphStore>,
embedding: Arc<dyn EmbeddingProvider>,
// Measures fact text and the final rendered text in tokens.
estimator: Box<dyn TokenEstimator>,
// Budget, output format, and candidate count for `build`.
config: ContextConfig,
}
impl ContextBuilder {
    /// Creates a builder over the given stores and embedding provider.
    ///
    /// Token counts are estimated with [`CharTokenEstimator`] unless a
    /// different estimator is installed via [`ContextBuilder::with_estimator`].
    pub fn new(
        fact_store: Arc<dyn FactStore>,
        vector_store: Arc<dyn VectorStore>,
        graph_store: Arc<dyn GraphStore>,
        embedding: Arc<dyn EmbeddingProvider>,
        config: ContextConfig,
    ) -> Self {
        Self {
            fact_store,
            vector_store,
            graph_store,
            embedding,
            estimator: Box::new(CharTokenEstimator),
            config,
        }
    }

    /// Replaces the token estimator and returns the builder for chaining.
    pub fn with_estimator(mut self, estimator: Box<dyn TokenEstimator>) -> Self {
        self.estimator = estimator;
        self
    }

    /// Retrieves facts relevant to `query` within `scope` and renders as many
    /// as fit into the configured token budget.
    ///
    /// Steps:
    /// 1. Run a `HybridRetriever` search (with `RetrievalConfig::default()`)
    ///    asking for `config.max_candidates` results.
    /// 2. Stable-sort the candidates by tier priority, so the retriever's
    ///    order is kept within each tier.
    /// 3. Greedily pack facts against the budget minus a fixed per-format
    ///    overhead. A fact that does not fit is skipped, and later (possibly
    ///    shorter) facts are still considered.
    /// 4. Render the selection. If the estimator says the rendered text is
    ///    still over `config.token_budget`, drop the lowest-priority included
    ///    fact and re-render, repeating until the text fits or no facts remain.
    ///    The per-fact charge in step 3 is only an approximation of the real
    ///    formatting cost (notably for `OutputFormat::Raw`, which emits extra
    ///    JSON fields per fact), so this step is what actually enforces the
    ///    budget.
    ///
    /// # Errors
    ///
    /// Returns any `MemoryError` produced by the retriever's search.
    pub async fn build(&self, query: &str, scope: &Scope) -> Result<ContextBlock, MemoryError> {
        let retriever = HybridRetriever::new(
            self.fact_store.clone(),
            self.vector_store.clone(),
            self.graph_store.clone(),
            self.embedding.clone(),
            RetrievalConfig::default(),
        );
        let scored = retriever
            .search(query, scope, self.config.max_candidates)
            .await?;

        let mut candidates: Vec<Fact> = scored.into_iter().map(|sf| sf.fact).collect();
        sort_by_tier_priority(&mut candidates);
        let total_candidates = candidates.len();

        let budget_for_facts = self
            .config
            .token_budget
            .saturating_sub(self.format_overhead());

        // Consume the candidate list so selected facts are moved, not cloned.
        let mut included: Vec<Fact> = Vec::with_capacity(total_candidates);
        let mut reserved: usize = 0;
        for fact in candidates {
            // +2 tokens per fact as an allowance for its bullet/line framing.
            // Saturating arithmetic keeps a custom estimator that returns huge
            // values from overflowing.
            let cost = self.estimator.estimate(&fact.text).saturating_add(2);
            if reserved.saturating_add(cost) > budget_for_facts {
                continue;
            }
            reserved += cost;
            included.push(fact);
        }

        // Enforce the budget on the real output: trim from the low-priority
        // end until the rendered text fits (or nothing is left to trim).
        let mut text = self.render_facts(&included);
        let mut token_count = self.estimator.estimate(&text);
        while token_count > self.config.token_budget && included.pop().is_some() {
            text = self.render_facts(&included);
            token_count = self.estimator.estimate(&text);
        }

        let mut tier_breakdown: HashMap<String, usize> = HashMap::new();
        for fact in &included {
            let tier_name = match fact.tier {
                MemoryTier::Working => "working",
                MemoryTier::Conversation => "conversation",
                MemoryTier::Knowledge => "knowledge",
            };
            *tier_breakdown.entry(tier_name.to_string()).or_insert(0) += 1;
        }

        Ok(ContextBlock {
            text,
            token_count,
            facts_included: included.len(),
            facts_omitted: total_candidates - included.len(),
            tier_breakdown,
        })
    }

    /// Renders `facts` with the formatter selected by `config.format`.
    fn render_facts(&self, facts: &[Fact]) -> String {
        match self.config.format {
            OutputFormat::SystemPrompt => format_system_prompt(facts),
            OutputFormat::Markdown => format_markdown(facts),
            OutputFormat::Raw => format_raw(facts),
        }
    }

    /// Fixed number of tokens set aside for each format's wrapper markup
    /// before facts are packed (presumably hand-tuned approximations).
    fn format_overhead(&self) -> usize {
        match self.config.format {
            OutputFormat::SystemPrompt => 15,
            OutputFormat::Markdown => 13,
            OutputFormat::Raw => 5,
        }
    }
}
// Unit tests for the estimator, configuration defaults, tier ordering, and
// the three formatters. `ContextBuilder::build` is not exercised here.
#[cfg(test)]
mod tests {
use super::*;
// The char estimator rounds byte length / 4 up, and maps "" to 0.
#[test]
fn char_estimator_basic() {
let est = CharTokenEstimator;
assert_eq!(est.estimate("hello world"), 3);
assert_eq!(est.estimate(""), 0);
let hundred = "a".repeat(100);
assert_eq!(est.estimate(&hundred), 25);
}
// Pins the default budget, format, and candidate count.
#[test]
fn context_config_defaults() {
let cfg = ContextConfig::default();
assert_eq!(cfg.token_budget, 2000);
assert_eq!(cfg.format, OutputFormat::SystemPrompt);
assert_eq!(cfg.max_candidates, 50);
}
// Working < Conversation < Knowledge in sort rank.
#[test]
fn tier_priority_ordering() {
assert!(tier_priority(&MemoryTier::Working) < tier_priority(&MemoryTier::Conversation));
assert!(tier_priority(&MemoryTier::Conversation) < tier_priority(&MemoryTier::Knowledge));
}
// Sorting reorders facts into Working, Conversation, Knowledge.
#[test]
fn sort_by_tier_groups_correctly() {
use crate::scope::Scope;
let scope = Scope::org("test");
let mut facts = vec![
Fact::new("knowledge fact", scope.clone()).with_tier(MemoryTier::Knowledge),
Fact::new("working fact", scope.clone()).with_tier(MemoryTier::Working),
Fact::new("convo fact", scope).with_tier(MemoryTier::Conversation),
];
sort_by_tier_priority(&mut facts);
assert_eq!(facts[0].tier, MemoryTier::Working);
assert_eq!(facts[1].tier, MemoryTier::Conversation);
assert_eq!(facts[2].tier, MemoryTier::Knowledge);
}
// The system-prompt format wraps everything in <memory> and emits one
// tagged section per tier with "- " bullets.
#[test]
fn format_system_prompt_groups_by_tier() {
use crate::scope::Scope;
let scope = Scope::org("test");
let facts = vec![
Fact::new("ephemeral note", scope.clone()).with_tier(MemoryTier::Working),
Fact::new("user likes pizza", scope.clone()).with_tier(MemoryTier::Conversation),
Fact::new("company policy", scope).with_tier(MemoryTier::Knowledge),
];
let output = format_system_prompt(&facts);
assert!(output.starts_with("<memory>"));
assert!(output.ends_with("</memory>"));
assert!(output.contains("<working>"));
assert!(output.contains("- ephemeral note"));
assert!(output.contains("<conversation>"));
assert!(output.contains("- user likes pizza"));
assert!(output.contains("<knowledge>"));
assert!(output.contains("- company policy"));
}
// Tiers with no facts produce no section tags.
#[test]
fn format_system_prompt_omits_empty_tiers() {
use crate::scope::Scope;
let scope = Scope::org("test");
let facts = vec![Fact::new("just a convo fact", scope).with_tier(MemoryTier::Conversation)];
let output = format_system_prompt(&facts);
assert!(!output.contains("<working>"));
assert!(output.contains("<conversation>"));
assert!(!output.contains("<knowledge>"));
}
// The Markdown format has a document title, a tier heading, and bullets.
#[test]
fn format_markdown_produces_sections() {
use crate::scope::Scope;
let scope = Scope::org("test");
let facts = vec![
Fact::new("user prefers dark mode", scope.clone()).with_tier(MemoryTier::Knowledge),
Fact::new("user asked about Rust", scope).with_tier(MemoryTier::Knowledge),
];
let output = format_markdown(&facts);
assert!(output.contains("## Memory Context"));
assert!(output.contains("### Knowledge"));
assert!(output.contains("- user prefers dark mode"));
}
// The raw format parses back as a JSON array; the tier serializes in
// lowercase ("conversation").
#[test]
fn format_raw_produces_valid_json() {
use crate::scope::Scope;
let scope = Scope::org("test");
let facts = vec![Fact::new("test fact", scope).with_tier(MemoryTier::Conversation)];
let output = format_raw(&facts);
let parsed: Vec<serde_json::Value> = serde_json::from_str(&output).unwrap();
assert_eq!(parsed.len(), 1);
assert_eq!(parsed[0]["text"], "test fact");
assert_eq!(parsed[0]["tier"], "conversation");
}
}