use std::fmt::Write as _;
use std::sync::LazyLock;
use std::time::Duration;
use regex::Regex;
use zeph_llm::provider::{Message, MessageMetadata, Role};
use zeph_memory::TokenCounter;
/// Strips common prompt-injection artifacts from LLM-generated digest text.
///
/// Removes angle-bracketed tags, `[INST]`/`[SYS]` markers, chat-template
/// special tokens (`<|...|>`), and leading role labels, then drops any whole
/// line that reads like an "ignore previous instructions" directive.
fn sanitize_digest(text: &str) -> String {
    // Patterns are compiled once and reused across calls.
    static INJECTION_PATTERNS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
        vec![
            Regex::new(r"<[^>]{1,100}>").unwrap(),
            Regex::new(r"(?i)\[/?INST\]|\[/?SYS\]").unwrap(),
            Regex::new(r"<\|[^|]{1,30}\|>").unwrap(),
            Regex::new(r"(?im)^(system|assistant|user)\s*:\s*").unwrap(),
        ]
    });
    static INJECTION_LINE: LazyLock<Regex> = LazyLock::new(|| {
        Regex::new(r"(?i)ignore\s+.{0,30}(instruction|above|previous|system)").unwrap()
    });
    // Apply every removal pattern in order, threading the partially-cleaned
    // text through the fold.
    let cleaned = INJECTION_PATTERNS
        .iter()
        .fold(text.to_string(), |acc, pattern| {
            pattern.replace_all(&acc, "").into_owned()
        });
    // Finally discard entire lines that match the injection-directive pattern.
    let kept: Vec<&str> = cleaned
        .lines()
        .filter(|line| !INJECTION_LINE.is_match(line))
        .collect();
    kept.join("\n")
}
/// Truncates `text` to at most `max_tokens` tokens as measured by `tc`.
///
/// When truncation is needed, binary-searches for the longest character
/// prefix that fits the budget, then trims back to the last complete line
/// (if the prefix contains a newline) so the digest never ends mid-line.
fn truncate_digest(text: &str, max_tokens: usize, tc: &TokenCounter) -> String {
    // Fast path: the whole text already fits.
    if tc.count_tokens(text) <= max_tokens {
        return text.to_string();
    }
    // Work in chars so every probe slices on a valid UTF-8 boundary.
    let glyphs: Vec<char> = text.chars().collect();
    let mut low = 0usize; // longest prefix length known to fit
    let mut high = glyphs.len();
    while low < high {
        // Ceiling midpoint guarantees progress when `high == low + 1`.
        let probe = (low + high).div_ceil(2);
        let prefix: String = glyphs[..probe].iter().collect();
        if tc.count_tokens(&prefix) <= max_tokens {
            low = probe;
        } else {
            high = probe - 1;
        }
    }
    let prefix: String = glyphs[..low].iter().collect();
    // Cut at the last newline so we don't end on a partial line; '\n' is
    // ASCII, so the byte index from rfind is always a char boundary.
    match prefix.rfind('\n') {
        Some(cut) => prefix[..cut].to_string(),
        None => prefix,
    }
}
use crate::channel::Channel;
use super::Agent;
impl<C: Channel> Agent<C> {
    /// Summarizes the recent conversation via the configured LLM provider and
    /// persists the digest for the active conversation, caching it in memory.
    ///
    /// Silently returns when digests are disabled, when no memory backend or
    /// conversation id is present, or when there are no non-system messages
    /// to summarize. LLM timeouts/failures and storage errors are logged but
    /// never surfaced to the caller; status updates are best-effort.
    pub(super) async fn maybe_store_session_digest(&mut self) {
        if !self.memory_state.digest_config.enabled {
            return;
        }
        let Some(memory) = self.memory_state.memory.clone() else {
            return;
        };
        let Some(conversation_id) = self.memory_state.conversation_id else {
            return;
        };
        let max_input = self.memory_state.digest_config.max_input_messages;
        let max_tokens = self.memory_state.digest_config.max_tokens;
        let provider_name = self.memory_state.digest_config.provider.clone();
        // Skip the leading message (presumably the system prompt — TODO
        // confirm) and drop any remaining system messages from the excerpt.
        let non_system: Vec<_> = self
            .msg
            .messages
            .iter()
            .skip(1)
            .filter(|m| m.role != Role::System)
            .collect();
        if non_system.is_empty() {
            return;
        }
        // Summarize only the most recent `max_input` messages.
        let slice = if non_system.len() > max_input {
            &non_system[non_system.len() - max_input..]
        } else {
            &non_system[..]
        };
        let mut conv_text = String::new();
        for msg in slice {
            let role = match msg.role {
                Role::User => "User",
                Role::Assistant => "Assistant",
                // Unreachable after the filter above; kept for exhaustiveness.
                Role::System => "System",
            };
            let _ = write!(conv_text, "{role}: {}\n\n", msg.content);
        }
        let prompt = format!(
            "You are a session summarizer. Read the following conversation excerpt and produce \
             a compact digest (under {max_tokens} tokens) of the key facts, decisions, outcomes, \
             and open questions from this session. Be specific and concise. \
             Output ONLY the digest text, no preamble.\n\n\
             Conversation:\n{conv_text}\n\
             Digest:"
        );
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];
        let _ = self
            .channel
            .send_status("Generating session digest...")
            .await;
        // Bound the LLM call: a hung provider must not stall this path.
        // `tokio::time::timeout` is the idiomatic form of the former
        // select!-against-sleep race.
        let timeout = Duration::from_secs(30);
        let digest_text = match tokio::time::timeout(
            timeout,
            self.provider
                .chat_with_named_provider(&provider_name, &chat_messages),
        )
        .await
        {
            Err(_) => {
                tracing::warn!("session digest: LLM call timed out");
                let _ = self.channel.send_status("").await;
                return;
            }
            Ok(Err(e)) => {
                tracing::warn!("session digest: LLM call failed: {e:#}");
                let _ = self.channel.send_status("").await;
                return;
            }
            Ok(Ok(text)) => text,
        };
        let sanitized = sanitize_digest(&digest_text);
        let tc = &self.metrics.token_counter;
        let final_text = truncate_digest(&sanitized, max_tokens, tc);
        // Sanitization/truncation can strip the digest to nothing; storing an
        // empty digest would only pollute the database and the cache.
        if final_text.trim().is_empty() {
            tracing::warn!("session digest: empty after sanitization; skipping store");
            let _ = self.channel.send_status("").await;
            return;
        }
        let token_count = i64::try_from(tc.count_tokens(&final_text)).unwrap_or(i64::MAX);
        if let Err(e) = memory
            .sqlite()
            .save_session_digest(conversation_id, &final_text, token_count)
            .await
        {
            tracing::warn!("session digest: storage failed: {e:#}");
        } else {
            tracing::info!(
                conversation_id = conversation_id.0,
                tokens = token_count,
                "session digest stored"
            );
            self.memory_state.cached_session_digest = Some((
                final_text,
                usize::try_from(token_count).unwrap_or(max_tokens),
            ));
        }
        // Clear the transient status line regardless of outcome.
        let _ = self.channel.send_status("").await;
    }

    /// Loads the stored digest for the active conversation (if any) and
    /// caches it in `memory_state.cached_session_digest`.
    ///
    /// No-op when digests are disabled or no memory backend/conversation id
    /// is present; a missing digest is not an error, and load failures are
    /// only logged.
    pub(super) async fn load_and_cache_session_digest(&mut self) {
        if !self.memory_state.digest_config.enabled {
            return;
        }
        let Some(memory) = self.memory_state.memory.clone() else {
            return;
        };
        let Some(conversation_id) = self.memory_state.conversation_id else {
            return;
        };
        match memory.sqlite().load_session_digest(conversation_id).await {
            Ok(Some(digest)) => {
                // Fall back to a rough chars/4 token estimate if the stored
                // count does not fit in usize.
                let token_count =
                    usize::try_from(digest.token_count).unwrap_or(digest.digest.len() / 4);
                tracing::debug!(
                    conversation_id = conversation_id.0,
                    tokens = token_count,
                    "session digest loaded"
                );
                self.memory_state.cached_session_digest = Some((digest.digest, token_count));
            }
            Ok(None) => {}
            Err(e) => {
                tracing::warn!("session digest: load failed: {e:#}");
            }
        }
    }
}