use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::sync::Arc;
use tokio::sync::Semaphore;
use tracing::{debug, info, warn};
use crate::config::ContextWindowConfig;
use crate::traits::{ModelProvider, StateStore};
use crate::types::UserRole;
/// Caps the number of concurrent `extract_inline_facts` calls at 2.
///
/// A permit is held for the whole extraction (including the LLM call). The
/// semaphore is never closed in this module, so `acquire()` is not expected
/// to return an error here.
static EXTRACTION_SEMAPHORE: std::sync::LazyLock<Semaphore> =
    std::sync::LazyLock::new(|| Semaphore::new(2));
/// One durable fact returned by `extract_inline_facts`.
///
/// Deserialized from the JSON array produced by the extraction model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InlineFact {
    /// Fact category. The extraction prompt asks for one of `user`,
    /// `preference`, `project` or `technical`; this is not validated here.
    pub category: String,
    /// snake_case identifier such as `dog_name` (requested by the prompt).
    pub key: String,
    /// Free-text value of the fact.
    pub value: String,
}
/// Cheap token estimate: one token per four bytes of UTF-8 text, rounded down.
pub fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    byte_len / BYTES_PER_TOKEN
}
/// Returns how many tokens remain for conversation history.
///
/// The model's total budget comes from `config.model_budgets[model]`, falling
/// back to `config.default_budget`. From that total we subtract the estimated
/// size of the system prompt, the JSON-serialized tool definitions, and a fixed
/// reserve for the model's reply. The result saturates at 0.
pub fn compute_available_budget(
    model: &str,
    system_prompt: &str,
    tool_defs: &[Value],
    config: &ContextWindowConfig,
) -> usize {
    // Tokens held back so the model has room to answer.
    const RESPONSE_RESERVE: usize = 1536;

    let total_budget = match config.model_budgets.get(model) {
        Some(&budget) => budget,
        None => config.default_budget,
    };

    let serialized_tools = serde_json::to_string(tool_defs).unwrap_or_default();
    let overhead = estimate_tokens(system_prompt)
        + estimate_tokens(&serialized_tools)
        + RESPONSE_RESERVE;

    total_budget.saturating_sub(overhead)
}
/// Trims `messages` to roughly fit `budget_tokens` by keeping the first
/// message, an optional summary, and the most recent messages.
///
/// * Returns `messages` unchanged if its serialized size already fits, or if
///   it has at most 2 entries.
/// * Otherwise the result is: `messages[0]` (the anchor), then — if
///   `session_summary` is non-blank — a `system` message
///   `"[Conversation summary: …]"`, then the last `min(8, len - 1)` messages.
///
/// The result is not re-checked against the budget.
#[allow(dead_code)]
pub fn fit_messages_to_budget(
    messages: Vec<Value>,
    budget_tokens: usize,
    session_summary: Option<&str>,
) -> Vec<Value> {
    let messages_json = serde_json::to_string(&messages).unwrap_or_default();
    let current_tokens = estimate_tokens(&messages_json);
    if current_tokens <= budget_tokens {
        return messages;
    }
    let msg_count = messages.len();
    if msg_count <= 2 {
        return messages;
    }

    let keep_recent = 8.min(msg_count - 1);
    // Original messages not carried over: everything except the anchor and
    // the recent tail. Computed directly (not from `result.len()`) so it can
    // never underflow when the summary makes the result longer than the input.
    let dropped = msg_count - 1 - keep_recent;
    // A blank summary adds no information; skip it (same rule as
    // `fit_messages_with_source_quotas`).
    let summary = session_summary.filter(|s| !s.trim().is_empty());

    let mut result = Vec::with_capacity(keep_recent + 2);
    let mut remaining = messages.into_iter();
    // Anchor: the very first message.
    result.extend(remaining.next());
    if let Some(summary) = summary {
        result.push(json!({
            "role": "system",
            "content": format!("[Conversation summary: {}]", summary)
        }));
    }
    // Skip the middle section, keep the last `keep_recent` messages.
    result.extend(remaining.skip(dropped));

    info!(
        original_count = msg_count,
        result_count = result.len(),
        dropped,
        original_tokens = current_tokens,
        budget_tokens,
        "Context window: trimmed messages to fit budget"
    );
    result
}
/// Maximum number of messages of a given role that quota-based trimming keeps:
/// 10 for `user` and `assistant`, 8 for `tool`, 6 for any other role.
fn role_quota(role: &str) -> usize {
    match role {
        "user" | "assistant" => 10,
        "tool" => 8,
        _ => 6,
    }
}
/// Trims `messages` toward `budget_tokens` using per-role quotas.
///
/// Returns `messages` unchanged when its serialized size already fits the
/// budget or when it has at most 2 entries. Otherwise it selects, in order:
///
/// 1. An anchor: the first message whose `role` is `"user"` (index 0 if none).
/// 2. The most recent 8 messages (all of them if there are fewer).
/// 3. Older messages, newest first, while their role's count is below
///    `role_quota(role)`. Counts include the anchor and the recent messages.
///
/// Selected messages keep their original relative order. A non-blank
/// `session_summary` is then inserted as a `system` message at index 1.
/// Finally, while the serialized result is over budget and has more than 7
/// entries, the entry at index 1 is removed. The returned vector may therefore
/// still exceed the budget.
pub fn fit_messages_with_source_quotas(
    messages: Vec<Value>,
    budget_tokens: usize,
    session_summary: Option<&str>,
) -> Vec<Value> {
    let messages_json = serde_json::to_string(&messages).unwrap_or_default();
    let current_tokens = estimate_tokens(&messages_json);
    if current_tokens <= budget_tokens {
        return messages;
    }
    if messages.len() <= 2 {
        return messages;
    }
    // A BTreeSet keeps indices sorted, so the output preserves message order.
    let mut selected_indices: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();
    // Number of selected messages per role, checked against `role_quota`.
    let mut role_counts: std::collections::HashMap<String, usize> =
        std::collections::HashMap::new();
    // Anchor on the first user message (falls back to index 0).
    let anchor_idx = messages
        .iter()
        .position(|m| m.get("role").and_then(|r| r.as_str()) == Some("user"))
        .unwrap_or(0);
    selected_indices.insert(anchor_idx);
    let anchor_role = messages[anchor_idx]
        .get("role")
        .and_then(|r| r.as_str())
        .unwrap_or("unknown")
        .to_string();
    *role_counts.entry(anchor_role).or_insert(0) += 1;
    // Always keep the newest 8 messages; they count toward role quotas but are
    // never rejected by them. `insert` returning false means the anchor was
    // already among them, so it is not counted twice.
    let keep_recent = 8usize.min(messages.len());
    let start = messages.len().saturating_sub(keep_recent);
    for (idx, msg) in messages.iter().enumerate().skip(start) {
        if selected_indices.insert(idx) {
            let role = msg
                .get("role")
                .and_then(|r| r.as_str())
                .unwrap_or("unknown")
                .to_string();
            *role_counts.entry(role).or_insert(0) += 1;
        }
    }
    // Fill remaining per-role quota from newest to oldest.
    // NOTE(review): nothing here keeps a `tool` message together with the
    // assistant message that requested it — confirm downstream tolerates that.
    for idx in (0..messages.len()).rev() {
        if selected_indices.contains(&idx) {
            continue;
        }
        let role = messages[idx]
            .get("role")
            .and_then(|r| r.as_str())
            .unwrap_or("unknown");
        let quota = role_quota(role);
        let count = role_counts.get(role).copied().unwrap_or(0);
        if count >= quota {
            continue;
        }
        selected_indices.insert(idx);
        *role_counts.entry(role.to_string()).or_insert(0) += 1;
    }
    let mut result: Vec<Value> = selected_indices
        .iter()
        .map(|idx| messages[*idx].clone())
        .collect();
    // Blank summaries are skipped; otherwise the summary goes right after the
    // first selected message.
    if let Some(summary) = session_summary {
        if !summary.trim().is_empty() {
            let insert_at = 1.min(result.len());
            result.insert(
                insert_at,
                json!({
                    "role": "system",
                    "content": format!("[Conversation summary: {}]", summary)
                }),
            );
        }
    }
    // Last resort: drop the entry at index 1 until under budget, but never
    // go below 7 entries (the `<= 2` check is implied by the `> 7` guard).
    // NOTE(review): when a summary was inserted it sits at index 1, so it is
    // the first entry removed here — confirm that is intended.
    loop {
        let json = serde_json::to_string(&result).unwrap_or_default();
        if estimate_tokens(&json) <= budget_tokens || result.len() <= 2 {
            break;
        }
        if result.len() > 7 {
            result.remove(1);
        } else {
            break;
        }
    }
    info!(
        original_count = messages.len(),
        result_count = result.len(),
        original_tokens = current_tokens,
        budget_tokens,
        "Context window: applied source quotas"
    );
    result
}
/// Shrinks a tool result so it fits roughly within `max_chars` characters.
///
/// Results with at most `max_chars` characters are returned unchanged.
/// Otherwise one of three strategies is used. Each inserts a `[truncated …]`
/// annotation that reports how many characters were dropped:
///
/// * Structured payloads (see `looks_like_structured_payload`) keep ~70% of
///   the available space from the head and ~30% from the tail, so both the
///   opening section and the closing brackets survive.
/// * Very small budgets (`max_chars <= 264`) keep only a head slice.
/// * Everything else keeps a head of 120..=1000 chars and a tail of 80..=800
///   chars around the annotation.
///
/// "Available space" is `max_chars` minus a fixed 64-char allowance for the
/// annotation. The annotation text can exceed that allowance for large counts,
/// so the output may overshoot `max_chars` slightly. All lengths are counted
/// in `char`s, and slicing always lands on char boundaries.
pub fn compress_tool_result(tool_name: &str, result: &str, max_chars: usize) -> String {
    let total_chars = result.chars().count();
    if total_chars <= max_chars {
        return result.to_string();
    }

    // Allowance for the "[truncated …]" annotation text.
    const ANNOTATION_OVERHEAD: usize = 64;
    const MAX_HEAD_CHARS: usize = 1000;
    const MAX_TAIL_CHARS: usize = 800;
    const MIN_HEAD_CHARS: usize = 120;
    const MIN_TAIL_CHARS: usize = 80;

    if looks_like_structured_payload(result) {
        let available = max_chars.saturating_sub(ANNOTATION_OVERHEAD);
        let struct_head = (available * 7) / 10;
        let struct_tail = available.saturating_sub(struct_head);
        // Defensive: head + tail == available < total_chars at this point.
        if total_chars <= struct_head + struct_tail {
            return result.to_string();
        }
        let head_end = byte_index_after_chars(result, struct_head);
        let tail_start = byte_index_before_last_chars(result, struct_tail);
        let omitted = total_chars.saturating_sub(struct_head + struct_tail);
        let compressed = format!(
            "{}\n\n[truncated {} chars from structured payload of {} total]\n\n{}",
            &result[..head_end],
            omitted,
            total_chars,
            &result[tail_start..]
        );
        // Both lengths are reported in chars so they are directly comparable.
        debug!(
            tool = tool_name,
            original_len = total_chars,
            compressed_len = compressed.chars().count(),
            "Compressed structured tool result"
        );
        return compressed;
    }

    // Too small for a head + tail split: keep just the head (at least 1 char).
    if max_chars <= ANNOTATION_OVERHEAD + MIN_HEAD_CHARS + MIN_TAIL_CHARS {
        let head_chars = max_chars.saturating_sub(ANNOTATION_OVERHEAD).max(1);
        let head_end = byte_index_after_chars(result, head_chars);
        let omitted = total_chars.saturating_sub(head_chars);
        let compressed = format!(
            "{}\n\n[truncated {} chars from {} total]",
            &result[..head_end],
            omitted,
            total_chars
        );
        debug!(
            tool = tool_name,
            original_len = total_chars,
            compressed_len = compressed.chars().count(),
            "Compressed tool result (head only)"
        );
        return compressed;
    }

    // Split the available space ~5:4 between head and tail, then clamp each
    // side to its bounds and shrink the tail if the clamps overshoot.
    let available = max_chars.saturating_sub(ANNOTATION_OVERHEAD);
    let mut head_chars = (available * 5) / 9;
    let mut tail_chars = available.saturating_sub(head_chars);
    head_chars = head_chars.clamp(MIN_HEAD_CHARS, MAX_HEAD_CHARS);
    tail_chars = tail_chars.clamp(MIN_TAIL_CHARS, MAX_TAIL_CHARS);
    if head_chars + tail_chars > available {
        tail_chars = available.saturating_sub(head_chars);
    }
    if tail_chars < MIN_TAIL_CHARS {
        tail_chars = MIN_TAIL_CHARS.min(available.saturating_sub(1));
        head_chars = available.saturating_sub(tail_chars);
    }
    // Defensive: head + tail <= available < total_chars at this point.
    if total_chars <= head_chars + tail_chars {
        return result.to_string();
    }
    let head_end = byte_index_after_chars(result, head_chars);
    let tail_start = byte_index_before_last_chars(result, tail_chars);
    let omitted = total_chars.saturating_sub(head_chars + tail_chars);
    let compressed = format!(
        "{}\n\n[truncated {} chars from middle of {} total]\n\n{}",
        &result[..head_end],
        omitted,
        total_chars,
        &result[tail_start..]
    );
    debug!(
        tool = tool_name,
        original_len = total_chars,
        compressed_len = compressed.chars().count(),
        "Compressed tool result"
    );
    compressed
}
/// Heuristic: does this tool output look like JSON or a JSON-derived summary?
///
/// True when the text (ignoring leading whitespace) opens with `{` or with `[`
/// — unless that `[` starts an `[UNTRUSTED` banner — or when it contains a
/// `JSON summary:` section or a `Top-level JSON array` line.
fn looks_like_structured_payload(result: &str) -> bool {
    const MARKERS: [&str; 2] = ["\nJSON summary:\n", "\nTop-level JSON array"];

    let body = result.trim_start();
    let opens_object = body.starts_with('{');
    let opens_array = body.starts_with('[') && !body.starts_with("[UNTRUSTED");

    opens_object || opens_array || MARKERS.iter().any(|marker| result.contains(*marker))
}
/// Byte offset just past the first `char_count` characters of `s`
/// (clamped to `s.len()`), always on a char boundary.
fn byte_index_after_chars(s: &str, char_count: usize) -> usize {
    match char_count {
        0 => 0,
        n => s
            .char_indices()
            .nth(n)
            .map_or(s.len(), |(byte_idx, _)| byte_idx),
    }
}
/// Byte offset where the last `char_count` characters of `s` begin.
///
/// Returns `s.len()` for `char_count == 0` and `0` when `char_count` covers
/// the whole string. The result is always a char boundary.
fn byte_index_before_last_chars(s: &str, char_count: usize) -> usize {
    if char_count == 0 {
        return s.len();
    }
    // Walk from the end: the (char_count)-th char from the back starts the tail.
    s.char_indices()
        .rev()
        .nth(char_count - 1)
        .map_or(0, |(byte_idx, _)| byte_idx)
}
/// Heuristically flags text that likely carries identity or family facts
/// (names, spouse, children) so summarization can keep more of it.
///
/// Matching is ASCII case-insensitive on the trimmed text; blank text is
/// never critical. Short, ambiguous words ("son", "sons") must appear as
/// whole words, so "json", "reason", "person" or "season" do not trigger a
/// match. Compound forms "grandson"/"stepson" are matched explicitly.
fn message_contains_critical_fact_signal(content: &str) -> bool {
    // Substrings distinctive enough to match anywhere.
    const PHRASES: &[&str] = &[
        "my name is",
        "owner name",
        "assistant name",
        "bot name",
        "call me ",
        " is myself",
        "daughter",
        "grandson",
        "stepson",
        "children",
        "wife",
        "husband",
        "spouse",
    ];
    // Words that are also common substrings of unrelated words.
    const WHOLE_WORDS: &[&str] = &["son", "sons"];

    let lower = content.trim().to_ascii_lowercase();
    if lower.is_empty() {
        return false;
    }
    PHRASES.iter().any(|phrase| lower.contains(*phrase))
        || WHOLE_WORDS
            .iter()
            .any(|word| contains_whole_word(&lower, word))
        || (lower.contains("saved fact") && lower.contains("name"))
}

/// Returns true if `word` occurs in `haystack` with no alphanumeric character
/// immediately before or after it.
fn contains_whole_word(haystack: &str, word: &str) -> bool {
    haystack.match_indices(word).any(|(start, matched)| {
        let before = haystack[..start].chars().next_back();
        let after = haystack[start + matched.len()..].chars().next();
        !before.is_some_and(char::is_alphanumeric) && !after.is_some_and(char::is_alphanumeric)
    })
}
pub async fn summarize_messages(
provider: &Arc<dyn ModelProvider>,
model: &str,
messages: &[Value],
state: Option<&Arc<dyn StateStore>>,
) -> anyhow::Result<String> {
let mut conversation_text = String::new();
for msg in messages {
let role = msg
.get("role")
.and_then(|r| r.as_str())
.unwrap_or("unknown");
let content = msg
.get("content")
.and_then(|c| c.as_str())
.unwrap_or("[no content]");
let contains_critical = message_contains_critical_fact_signal(content);
let max_chars = if contains_critical { 1200 } else { 500 };
let truncated = if content.len() > max_chars {
let mut end = max_chars;
while !content.is_char_boundary(end) && end > 0 {
end -= 1;
}
&content[..end]
} else {
content
};
let critical_prefix = if contains_critical { "[CRITICAL] " } else { "" };
conversation_text.push_str(&format!("{}{}: {}\n", critical_prefix, role, truncated));
}
let llm_messages = vec![
json!({
"role": "system",
"content": "You are a conversation summarizer. Be extremely concise and preserve critical identity/profile facts."
}),
json!({
"role": "user",
"content": format!(
"Summarize this conversation concisely. Preserve: topics discussed, decisions made, \
important data/values mentioned, user preferences expressed, pending tasks, \
and critical identity/relationship updates (owner name, assistant name, spouse/children).\n\
Output 3-5 sentences max.\n\n{}",
conversation_text
)
}),
];
let response = provider.chat(model, &llm_messages, &[]).await?;
if let (Some(state), Some(usage)) = (state, &response.usage) {
let _ = state
.record_token_usage("background:summarization", usage)
.await;
}
response
.content
.ok_or_else(|| anyhow::anyhow!("Empty response from summarization LLM"))
}
/// Cheap pre-filter deciding whether a user message is worth sending to
/// fact extraction.
///
/// Rejects messages shorter than 20 bytes after trimming, and messages that
/// (case-insensitively) equal one of a fixed set of acknowledgements/emoji.
///
/// NOTE(review): every entry in the acknowledgement list is shorter than 20
/// bytes, so the length check already rejects them; the list only matters
/// if the minimum length is lowered.
pub fn should_extract_facts(user_text: &str) -> bool {
    const MIN_LEN_BYTES: usize = 20;
    const TRIVIAL: &[&str] = &[
        "ok",
        "okay",
        "thanks",
        "thank you",
        "thx",
        "yes",
        "no",
        "yep",
        "nope",
        "sure",
        "got it",
        "cool",
        "nice",
        "great",
        "good",
        "lol",
        "haha",
        "hmm",
        "ah",
        "oh",
        "right",
        "exactly",
        "agreed",
        "understood",
        "roger",
        "k",
        "kk",
        "ty",
        "np",
        "👍",
        "👋",
        "🙏",
        "✅",
        "done",
        "perfect",
        "awesome",
    ];

    let candidate = user_text.trim();
    if candidate.len() < MIN_LEN_BYTES {
        return false;
    }
    let folded = candidate.to_lowercase();
    !TRIVIAL.iter().any(|phrase| *phrase == folded)
}
/// Asks `model` to extract durable facts from one user/assistant exchange.
///
/// At most two extractions run at once (guarded by `EXTRACTION_SEMAPHORE`).
/// Both sides of the exchange are cut to 500 bytes on a char boundary. Token
/// usage is recorded under `background:progressive_extraction` when `state`
/// is given; failures there are ignored.
///
/// The model's reply is parsed from its first `[` through the last `]` that
/// follows it. An empty reply, a reply without such a span, or unparseable
/// JSON all yield `Ok(vec![])`.
///
/// # Errors
/// Fails only if the semaphore cannot be acquired or the provider call fails.
pub async fn extract_inline_facts(
    provider: &Arc<dyn ModelProvider>,
    model: &str,
    user_message: &str,
    assistant_response: &str,
    state: Option<&Arc<dyn StateStore>>,
) -> anyhow::Result<Vec<InlineFact>> {
    let _permit = EXTRACTION_SEMAPHORE.acquire().await?;
    let llm_messages = vec![
        json!({
            "role": "system",
            "content": "You extract durable facts from conversations. Only extract facts that would be useful to remember long-term. \
             Return a JSON array of objects with 'category', 'key', and 'value' fields.\n\n\
             Categories: user (personal info), preference (likes/dislikes), project (project details), technical (technical facts).\n\
             Use snake_case keys like 'dog_name', 'favorite_color', 'work_company'. Be consistent with naming.\n\n\
             CORRECTIONS: If the user is correcting or updating previously stated information (e.g., \"actually\", \"not X, it's Y\", \
             \"I changed\", \"I meant\"), extract the CORRECTED fact using the same key format as the original would have used. \
             The corrected value will automatically supersede the old one.\n\n\
             If nothing is worth remembering, return an empty array: []\n\n\
             Examples:\n\
             - \"My dog's name is Bella\" → [{\"category\":\"user\",\"key\":\"dog_name\",\"value\":\"Bella\"}]\n\
             - \"Actually my dog's name is Max, not Bella\" → [{\"category\":\"user\",\"key\":\"dog_name\",\"value\":\"Max\"}]\n\
             - \"I prefer dark mode\" → [{\"category\":\"preference\",\"key\":\"ui_theme\",\"value\":\"dark mode\"}]\n\
             - \"My sister lives in Tokyo, not Paris\" → [{\"category\":\"user\",\"key\":\"sister_location\",\"value\":\"Tokyo\"}]\n\
             - \"How's the weather?\" → []\n\n\
             IMPORTANT: Return ONLY the JSON array, no other text."
        }),
        json!({
            "role": "user",
            "content": format!(
                "User said: {}\n\nAssistant replied: {}",
                truncate_for_extraction(user_message, 500),
                truncate_for_extraction(assistant_response, 500)
            )
        }),
    ];
    let response = provider.chat(model, &llm_messages, &[]).await?;
    if let (Some(state), Some(usage)) = (state, &response.usage) {
        // Best-effort accounting; extraction proceeds even if this fails.
        let _ = state
            .record_token_usage("background:progressive_extraction", usage)
            .await;
    }
    let Some(text) = response.content else {
        return Ok(vec![]);
    };
    let trimmed = text.trim();
    // The model may wrap the array in prose or code fences; isolate it.
    let Some(json_str) = extract_json_array(trimmed) else {
        return Ok(vec![]);
    };
    match serde_json::from_str::<Vec<InlineFact>>(json_str) {
        Ok(facts) => {
            if !facts.is_empty() {
                info!(count = facts.len(), "Progressive extraction found facts");
            }
            Ok(facts)
        }
        Err(e) => {
            debug!(error = %e, response = trimmed, "Failed to parse extraction response");
            Ok(vec![])
        }
    }
}

/// Returns the span of `text` from its first `[` through the last `]` that
/// comes after it, or `None` if there is no such pair.
///
/// The closing bracket is only searched for after the opening one, so model
/// output like `"x] y ["` yields `None` instead of an inverted (panicking)
/// slice range.
fn extract_json_array(text: &str) -> Option<&str> {
    let start = text.find('[')?;
    let end = start + text[start..].rfind(']')?;
    Some(&text[start..=end])
}
/// Returns the longest prefix of `text` that is at most `max_len` bytes long
/// and ends on a UTF-8 char boundary.
fn truncate_for_extraction(text: &str, max_len: usize) -> &str {
    if text.len() <= max_len {
        return text;
    }
    // Walk back from `max_len` to the nearest boundary; 0 always qualifies.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);
    &text[..cut]
}
/// Spawns a background task that extracts facts from one exchange and
/// upserts them with source `"progressive"`.
///
/// Nothing is spawned when `user_role` may not persist owner memory or the
/// channel is `PublicExternal`. Facts in the `user` category (case- and
/// whitespace-insensitive) are stored as `Private`; all others as `Channel`.
/// Extraction errors are logged at debug level and per-fact storage errors
/// at warn level; neither propagates.
#[allow(clippy::too_many_arguments)]
pub fn spawn_progressive_extraction(
    provider: Arc<dyn ModelProvider>,
    fast_model: String,
    state: Arc<dyn StateStore>,
    user_text: String,
    assistant_response: String,
    channel_id: Option<String>,
    visibility: crate::types::ChannelVisibility,
    user_role: UserRole,
) {
    // Decide eligibility before spawning so ineligible turns cost no task.
    if !user_role.can_persist_owner_memory()
        || matches!(visibility, crate::types::ChannelVisibility::PublicExternal)
    {
        return;
    }
    tokio::spawn(async move {
        let facts = match extract_inline_facts(
            &provider,
            &fast_model,
            &user_text,
            &assistant_response,
            Some(&state),
        )
        .await
        {
            Ok(facts) => facts,
            Err(e) => {
                debug!(error = %e, "Progressive fact extraction failed");
                return;
            }
        };
        for fact in facts {
            let privacy = if fact.category.trim().eq_ignore_ascii_case("user") {
                crate::types::FactPrivacy::Private
            } else {
                crate::types::FactPrivacy::Channel
            };
            if let Err(e) = state
                .upsert_fact(
                    &fact.category,
                    &fact.key,
                    &fact.value,
                    "progressive",
                    channel_id.as_deref(),
                    privacy,
                )
                .await
            {
                warn!(error = %e, key = fact.key, "Failed to store progressive fact");
            }
        }
    });
}
/// Spawns a background task that summarizes the older part of a session's
/// history and stores it via `upsert_conversation_summary`.
///
/// Nothing is spawned unless `user_role` may persist owner memory. The task
/// requests history with a limit of 100. If at least `threshold` entries come
/// back, everything except the newest `window` entries is summarized with
/// `fast_model`. The stored summary records how many messages it covers and
/// the id of the last one. All failures are logged and otherwise ignored.
///
/// NOTE(review): if `get_history` honours its limit, a `threshold` above 100
/// can never trigger summarization — confirm.
pub fn spawn_incremental_summarization(
    provider: Arc<dyn ModelProvider>,
    fast_model: String,
    state: Arc<dyn StateStore>,
    session_id: String,
    threshold: usize,
    window: usize,
    user_role: UserRole,
) {
    // Decide eligibility before spawning so ineligible sessions cost no task.
    if !user_role.can_persist_owner_memory() {
        return;
    }
    tokio::spawn(async move {
        let history = match state.get_history(&session_id, 100).await {
            Ok(h) => h,
            Err(e) => {
                warn!(error = %e, "Failed to get history for summarization");
                return;
            }
        };
        if history.len() < threshold {
            return;
        }
        // Keep the newest `window` messages verbatim; summarize the rest.
        let to_summarize_count = history.len().saturating_sub(window);
        if to_summarize_count == 0 {
            return;
        }
        let to_summarize: Vec<Value> = history[..to_summarize_count]
            .iter()
            .map(|m| {
                json!({
                    "role": m.role,
                    "content": m.content.as_deref().unwrap_or("")
                })
            })
            .collect();
        match summarize_messages(&provider, &fast_model, &to_summarize, Some(&state)).await {
            Ok(text) => {
                let last_msg_id = history[to_summarize_count - 1].id.clone();
                let summary = crate::traits::ConversationSummary {
                    session_id: session_id.clone(),
                    summary: text,
                    message_count: to_summarize_count,
                    last_message_id: last_msg_id,
                    updated_at: chrono::Utc::now(),
                };
                if let Err(e) = state.upsert_conversation_summary(&summary).await {
                    warn!(error = %e, "Failed to store conversation summary");
                } else {
                    info!(
                        session_id = session_id.as_str(),
                        message_count = to_summarize_count,
                        "Stored conversation summary"
                    );
                }
            }
            Err(e) => {
                warn!(error = %e, "Failed to summarize messages");
            }
        }
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    // `estimate_tokens` is byte length divided by 4, rounded down.
    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        // 2 bytes -> 0 tokens.
        assert_eq!(estimate_tokens("hi"), 0);
        // 13 bytes -> 3 tokens.
        assert_eq!(estimate_tokens("hello world!!"), 3);
        let long = "a".repeat(1000);
        assert_eq!(estimate_tokens(&long), 250);
    }

    // Input that already fits the budget is returned unchanged.
    #[test]
    fn test_fit_messages_under_budget() {
        let messages = vec![
            json!({"role": "user", "content": "Hello"}),
            json!({"role": "assistant", "content": "Hi there"}),
        ];
        let result = fit_messages_to_budget(messages.clone(), 100_000, None);
        assert_eq!(result.len(), 2);
        assert_eq!(result, messages);
    }

    // 15 messages over budget with a summary: anchor + summary + last 8 = 10.
    #[test]
    fn test_fit_messages_over_budget() {
        let mut messages = Vec::new();
        for i in 0..15 {
            let role = if i % 2 == 0 { "user" } else { "assistant" };
            messages.push(json!({"role": role, "content": format!("Message number {}", i)}));
        }
        let result =
            fit_messages_to_budget(messages.clone(), 50, Some("We discussed topics A and B"));
        assert_eq!(result.len(), 10);
        assert_eq!(result[0]["content"], "Message number 0");
        assert!(result[1]["content"]
            .as_str()
            .unwrap()
            .contains("Conversation summary"));
        assert_eq!(result[9]["content"], "Message number 14");
    }

    // 10 messages over budget without a summary: anchor + last 8 = 9.
    #[test]
    fn test_fit_messages_over_budget_no_summary() {
        let mut messages = Vec::new();
        for i in 0..10 {
            let role = if i % 2 == 0 { "user" } else { "assistant" };
            messages.push(json!({"role": role, "content": format!("Message {}", i)}));
        }
        let result = fit_messages_to_budget(messages, 50, None);
        assert_eq!(result.len(), 9);
        assert_eq!(result[0]["content"], "Message 0");
        assert_eq!(result[8]["content"], "Message 9");
    }

    // Under a tiny budget the quota-based trimmer must still start with the
    // first user message and end with the newest message.
    #[test]
    fn test_fit_with_source_quotas_keeps_anchor_and_recent() {
        let mut messages = Vec::new();
        for i in 0..18 {
            let role = if i % 3 == 0 {
                "user"
            } else if i % 3 == 1 {
                "assistant"
            } else {
                "tool"
            };
            messages.push(json!({"role": role, "content": format!("msg-{i}")}));
        }
        let result = fit_messages_with_source_quotas(messages, 40, Some("summary"));
        assert!(!result.is_empty());
        assert_eq!(result[0]["role"], "user");
        let tail = result.last().unwrap()["content"].as_str().unwrap();
        assert!(tail.contains("msg-17"));
    }

    // Results within the limit are returned as-is.
    #[test]
    fn test_compress_tool_result_short() {
        let short = "Hello world";
        let result = compress_tool_result("test_tool", short, 2000);
        assert_eq!(result, short);
    }

    // Long unstructured text keeps both ends around a truncation marker.
    #[test]
    fn test_compress_tool_result_long() {
        let long = format!("HEAD:{}:TAIL", "x".repeat(5000));
        let result = compress_tool_result("test_tool", &long, 2000);
        assert!(result.len() < long.len());
        assert!(result.contains("[truncated"));
        assert!(result.contains("HEAD:"));
        assert!(result.contains(":TAIL"));
    }

    // The payload starts with "[UNTRUSTED" but contains "\nJSON summary:\n",
    // so it takes the structured path and keeps the summary and closing brackets.
    #[test]
    fn test_compress_tool_result_keeps_head_and_tail_for_structured_payloads() {
        let json_body =
            "{\n  \"items\": [\n".to_string() + &"    {\"id\":1},\n".repeat(100) + "  ]\n}";
        let structured = format!(
            "[UNTRUSTED EXTERNAL DATA from 'http_request']\nHTTP 200 OK\n\nJSON summary:\nitems: array(2 item(s))\n\n{}",
            json_body
        );
        let result = compress_tool_result("http_request", &structured, 600);
        assert!(result.contains("JSON summary:"));
        assert!(result.contains("structured payload"));
        assert!(result.contains("]\n}"));
    }

    // Unknown models fall back to `default_budget`; listed models use their own
    // budget. An empty tool list serializes to "[]".
    #[test]
    fn test_compute_budget() {
        let config = ContextWindowConfig {
            default_budget: 32000,
            model_budgets: {
                let mut m = std::collections::HashMap::new();
                m.insert("big-model".to_string(), 100000);
                m
            },
            ..Default::default()
        };
        let budget = compute_available_budget("unknown-model", "system prompt", &[], &config);
        let expected = 32000 - estimate_tokens("system prompt") - estimate_tokens("[]") - 1536;
        assert_eq!(budget, expected);
        let budget = compute_available_budget("big-model", "system prompt", &[], &config);
        let expected = 100000 - estimate_tokens("system prompt") - estimate_tokens("[]") - 1536;
        assert_eq!(budget, expected);
    }

    // Acknowledgements and anything under 20 bytes are not worth extracting.
    #[test]
    fn test_should_extract_facts_trivial() {
        assert!(!should_extract_facts("ok"));
        assert!(!should_extract_facts("thanks"));
        assert!(!should_extract_facts("yes"));
        assert!(!should_extract_facts("lol"));
        assert!(!should_extract_facts("👍"));
        // Both fall under the 20-byte minimum.
        assert!(!should_extract_facts("short"));
        assert!(!should_extract_facts("Got it"));
    }

    // Longer, fact-bearing messages pass the pre-filter.
    #[test]
    fn test_should_extract_facts_meaningful() {
        assert!(should_extract_facts(
            "My dog's name is Bella and she's a golden retriever"
        ));
        assert!(should_extract_facts(
            "I work at Acme Corp in the engineering department"
        ));
        assert!(should_extract_facts(
            "Please set up a new React project with TypeScript"
        ));
    }

    // The expected extraction output shape deserializes into `InlineFact`.
    #[test]
    fn test_inline_fact_deserialization() {
        let json = r#"[{"category":"user","key":"dog_name","value":"Bella"}]"#;
        let facts: Vec<InlineFact> = serde_json::from_str(json).unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].category, "user");
        assert_eq!(facts[0].key, "dog_name");
        assert_eq!(facts[0].value, "Bella");
    }

    // "Nothing worth remembering" is an empty array, which parses to no facts.
    #[test]
    fn test_inline_fact_empty_array() {
        let json = "[]";
        let facts: Vec<InlineFact> = serde_json::from_str(json).unwrap();
        assert!(facts.is_empty());
    }
}