use std::time::Duration;
use serde::{Deserialize, Serialize};
use tokio::time::timeout;
use zeph_llm::any::AnyProvider;
use zeph_llm::provider::{LlmProvider as _, Message, Role};
use crate::error::MemoryError;
use crate::store::DbStore;
use crate::store::persona::PersonaFactRow;
/// System prompt instructing the LLM to extract first-person persona
/// facts (preference, domain_knowledge, working_style, communication,
/// background) from user messages and emit them as a JSON array whose
/// objects match the `ExtractedFact` shape parsed by
/// `parse_extraction_response`.
const EXTRACTION_SYSTEM_PROMPT: &str = "\
You are a persona fact extractor. Given a list of user messages and any existing persona \
facts for each category, extract factual claims the user makes about themselves: their \
preferences, domain knowledge, working style, communication style, and background.
Rules:
1. Only extract facts from first-person user statements (\"I prefer\", \"I work on\", \
\"my team\", \"I use\", etc.). Ignore assistant messages.
2. Do NOT extract facts from questions, greetings, or tool outputs.
3. For each extracted fact, decide if it is NEW (no existing fact contradicts it) or \
UPDATE (it contradicts or replaces an existing fact). For UPDATE, provide the \
`supersedes_id` of the older fact.
4. Confidence: 0.8-1.0 for explicit statements (\"I prefer X\"), 0.4-0.7 for inferences.
5. Categories: preference, domain_knowledge, working_style, communication, background.
6. Keep content concise (one sentence max). Normalize to English.
7. Return empty array if no facts are found.
Output JSON array of objects:
[
{
\"category\": \"preference|domain_knowledge|working_style|communication|background\",
\"content\": \"concise factual statement\",
\"confidence\": 0.0-1.0,
\"action\": \"new|update\",
\"supersedes_id\": null or integer id of the fact being replaced
}
]";
/// Tuning knobs for the persona-fact extraction pipeline.
///
/// Public config types should be inspectable and copyable by callers,
/// hence the `Debug`/`Clone` derives.
#[derive(Debug, Clone)]
pub struct PersonaExtractionConfig {
    /// Master switch; when `false`, extraction is skipped entirely.
    pub enabled: bool,
    /// Minimum number of user messages required before extraction runs.
    pub min_messages: usize,
    /// Maximum number of most-recent user messages sent to the LLM.
    pub max_messages: usize,
    /// Timeout for the LLM extraction call, in seconds.
    pub extraction_timeout_secs: u64,
}
#[derive(Debug, Deserialize, Serialize)]
struct ExtractedFact {
category: String,
content: String,
confidence: f64,
action: String,
supersedes_id: Option<i64>,
}
/// Cheap heuristic gate: does `text` contain first-person language
/// ("I ...", "my ...", "... me ...") suggesting the user is talking
/// about themselves?
///
/// Used to avoid LLM calls for message batches that cannot contain
/// persona facts. Case-insensitive via lowercasing.
#[must_use]
pub fn contains_self_referential_language(text: &str) -> bool {
    let lower = text.to_lowercase();
    // Space-padded probes match first-person words mid-sentence.
    let tokens = [" i ", " i'", " my ", " me ", " mine ", "i am ", "i'm "];
    tokens.iter().any(|t| lower.contains(t))
        // The padded tokens miss first-person openers at position 0
        // ("I work", "I've been", "My team"), so check the start of the
        // text separately. `starts_with("i'")` covers contractions like
        // "I've"/"I'd"/"I'll" that `starts_with("i ")` does not.
        || lower.starts_with("i ")
        || lower.starts_with("i'")
        || lower.starts_with("my ")
}
/// Extract persona facts from `user_messages` via the LLM and upsert
/// them into `store`, returning how many facts were stored.
///
/// Gating (each returns `Ok(0)` without calling the LLM): extraction
/// disabled, fewer than `config.min_messages` messages, or no message
/// containing self-referential language. Only the most recent
/// `config.max_messages` messages are sent, in their original order.
/// An LLM timeout is logged and treated as "no facts" rather than an
/// error; per-fact upsert failures are logged and skipped.
///
/// # Errors
///
/// Returns [`MemoryError`] when loading existing facts fails or when
/// the LLM call itself fails (other than by timeout).
#[cfg_attr(
    feature = "profiling",
    tracing::instrument(name = "memory.persona_extract", skip_all, fields(fact_count = tracing::field::Empty))
)]
pub async fn extract_persona_facts(
    store: &DbStore,
    provider: &AnyProvider,
    user_messages: &[&str],
    config: &PersonaExtractionConfig,
    conversation_id: Option<i64>,
) -> Result<usize, MemoryError> {
    if !config.enabled || user_messages.len() < config.min_messages {
        return Ok(0);
    }
    // Skip the LLM entirely when nothing looks first-person.
    if !user_messages
        .iter()
        .any(|m| contains_self_referential_language(m))
    {
        return Ok(0);
    }
    // Keep only the most recent `max_messages`, preserving order; a tail
    // slice avoids the intermediate allocations of rev/collect/rev.
    let tail_start = user_messages.len().saturating_sub(config.max_messages);
    let messages_to_send = &user_messages[tail_start..];
    // Threshold 0.0 loads every active fact so the LLM can detect
    // contradictions against all of them.
    let existing_facts = store.load_persona_facts(0.0).await?;
    let user_prompt = build_extraction_prompt(messages_to_send, &existing_facts);
    let llm_messages = [
        Message::from_legacy(Role::System, EXTRACTION_SYSTEM_PROMPT),
        Message::from_legacy(Role::User, user_prompt),
    ];
    let extraction_timeout = Duration::from_secs(config.extraction_timeout_secs);
    let response = match timeout(extraction_timeout, provider.chat(&llm_messages)).await {
        Ok(Ok(text)) => text,
        Ok(Err(e)) => return Err(MemoryError::Llm(e)),
        // Timeout is best-effort: report nothing extracted, don't fail.
        Err(_) => {
            tracing::warn!(
                "persona extraction timed out after {}s",
                config.extraction_timeout_secs
            );
            return Ok(0);
        }
    };
    let facts = parse_extraction_response(&response);
    if facts.is_empty() {
        return Ok(0);
    }
    let mut upserted = 0usize;
    for fact in facts {
        if fact.category.is_empty() || fact.content.is_empty() {
            continue;
        }
        if !is_valid_category(&fact.category) {
            tracing::debug!(
                category = %fact.category,
                "persona extraction: skipping unknown category"
            );
            continue;
        }
        match store
            .upsert_persona_fact(
                &fact.category,
                &fact.content,
                // Defend against out-of-range confidences from the LLM.
                fact.confidence.clamp(0.0, 1.0),
                conversation_id,
                fact.supersedes_id,
            )
            .await
        {
            Ok(_) => upserted += 1,
            // A single bad fact should not abort the rest of the batch.
            Err(e) => {
                tracing::warn!(error = %e, "persona extraction: failed to upsert fact");
            }
        }
    }
    tracing::debug!(upserted, "persona extraction complete");
    #[cfg(feature = "profiling")]
    tracing::Span::current().record("fact_count", upserted);
    Ok(upserted)
}
/// Whitelist check for the persona-fact category names the extractor
/// is allowed to store.
fn is_valid_category(category: &str) -> bool {
    const VALID_CATEGORIES: [&str; 5] = [
        "preference",
        "domain_knowledge",
        "working_style",
        "communication",
        "background",
    ];
    VALID_CATEGORIES.contains(&category)
}
/// Build the user prompt: the numbered messages to analyze plus, when
/// present, the existing facts so the LLM can report contradictions
/// via `supersedes_id`.
fn build_extraction_prompt(messages: &[&str], existing_facts: &[PersonaFactRow]) -> String {
    // Hoisted once instead of being re-declared inside each loop body.
    use std::fmt::Write as _;
    let mut prompt = String::from("User messages to analyze:\n");
    for (i, msg) in messages.iter().enumerate() {
        // Writing into a String cannot fail; the Result is ignored.
        let _ = writeln!(prompt, "[{}] {}", i + 1, msg);
    }
    if !existing_facts.is_empty() {
        prompt.push_str("\nExisting persona facts (for contradiction detection):\n");
        for fact in existing_facts {
            let _ = writeln!(
                prompt,
                " id={} category={} content=\"{}\"",
                fact.id, fact.category, fact.content
            );
        }
    }
    prompt.push_str("\nExtract persona facts as JSON array.");
    prompt
}
/// Parse the LLM response into facts.
///
/// Tries the whole response as a JSON array first; if that fails,
/// retries on the outermost `[...]` span (models often wrap the array
/// in prose). Unparseable responses are logged (truncated to 200
/// chars) and yield an empty vec.
fn parse_extraction_response(response: &str) -> Vec<ExtractedFact> {
    let try_parse = |s: &str| serde_json::from_str::<Vec<ExtractedFact>>(s).ok();
    if let Some(facts) = try_parse(response) {
        return facts;
    }
    let bracketed = match (response.find('['), response.rfind(']')) {
        (Some(start), Some(end)) if end > start => Some(&response[start..=end]),
        _ => None,
    };
    if let Some(facts) = bracketed.and_then(try_parse) {
        return facts;
    }
    tracing::warn!(
        "persona extraction: failed to parse LLM response (len={}): {:.200}",
        response.len(),
        response
    );
    Vec::new()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;

    /// Fresh single-connection in-memory store for each test.
    async fn make_store() -> DbStore {
        DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store")
    }

    #[test]
    fn self_ref_detects_i_prefer() {
        assert!(contains_self_referential_language("I prefer dark mode"));
    }

    #[test]
    fn self_ref_detects_my_team() {
        assert!(contains_self_referential_language("my team uses Rust"));
    }

    #[test]
    fn self_ref_detects_sentence_starting_with_i() {
        assert!(contains_self_referential_language("I work remotely"));
    }

    #[test]
    fn self_ref_detects_inline_i() {
        assert!(contains_self_referential_language(
            "Sometimes I prefer async"
        ));
    }

    #[test]
    fn self_ref_detects_me_inline() {
        assert!(contains_self_referential_language(
            "That helps me understand"
        ));
    }

    #[test]
    fn self_ref_no_match_for_third_person() {
        assert!(!contains_self_referential_language("The team uses Python"));
    }

    #[test]
    fn self_ref_no_match_for_tool_output() {
        assert!(!contains_self_referential_language("Error: file not found"));
    }

    #[test]
    fn self_ref_no_match_for_empty_string() {
        assert!(!contains_self_referential_language(""));
    }

    #[tokio::test]
    async fn extraction_gate_skips_when_no_self_ref() {
        // No LLM provider is available in unit tests, so exercise the
        // gate predicates that short-circuit before any provider call:
        // disabled config, too few messages, and no self-referential
        // language in the batch.
        let store = make_store().await;
        let cfg = PersonaExtractionConfig {
            enabled: true,
            min_messages: 1,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let cfg_disabled = PersonaExtractionConfig {
            enabled: false,
            min_messages: 1,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let cfg_min = PersonaExtractionConfig {
            enabled: true,
            min_messages: 5,
            max_messages: 10,
            extraction_timeout_secs: 5,
        };
        let no_self_ref = ["The weather is nice today", "Error: file not found"];
        assert!(
            !no_self_ref
                .iter()
                .any(|m| contains_self_referential_language(m)),
            "batch without first-person language must not pass the gate"
        );
        assert!(!cfg_disabled.enabled, "disabled config must short-circuit");
        let messages: Vec<&str> = vec![];
        assert!(
            messages.len() < cfg_min.min_messages,
            "below-minimum batch must short-circuit"
        );
        let _ = (store, cfg);
    }

    #[test]
    fn parse_direct_json_array() {
        let json = r#"[{"category":"preference","content":"I prefer dark mode","confidence":0.9,"action":"new","supersedes_id":null}]"#;
        let facts = parse_extraction_response(json);
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].category, "preference");
        assert_eq!(facts[0].content, "I prefer dark mode");
        assert!((facts[0].confidence - 0.9).abs() < 1e-9);
        assert_eq!(facts[0].action, "new");
        assert!(facts[0].supersedes_id.is_none());
    }

    #[test]
    fn parse_json_embedded_in_prose() {
        let response = "Sure! Here are the facts:\n[{\"category\":\"domain_knowledge\",\"content\":\"Uses Rust\",\"confidence\":0.8,\"action\":\"new\",\"supersedes_id\":null}]\nThat's all.";
        let facts = parse_extraction_response(response);
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].category, "domain_knowledge");
    }

    #[test]
    fn parse_empty_array() {
        let facts = parse_extraction_response("[]");
        assert!(facts.is_empty());
    }

    #[test]
    fn parse_invalid_json_returns_empty() {
        let facts = parse_extraction_response("not json at all");
        assert!(facts.is_empty());
    }

    #[test]
    fn parse_supersedes_id_populated() {
        let json = r#"[{"category":"preference","content":"I prefer dark mode","confidence":0.9,"action":"update","supersedes_id":7}]"#;
        let facts = parse_extraction_response(json);
        assert_eq!(facts[0].supersedes_id, Some(7));
        assert_eq!(facts[0].action, "update");
    }

    #[tokio::test]
    async fn contradiction_second_fact_supersedes_first() {
        let store = make_store().await;
        let old_id = store
            .upsert_persona_fact("preference", "I prefer light mode", 0.8, None, None)
            .await
            .expect("old fact");
        store
            .upsert_persona_fact("preference", "I prefer dark mode", 0.9, None, Some(old_id))
            .await
            .expect("new fact");
        let facts = store.load_persona_facts(0.0).await.expect("load");
        assert_eq!(facts.len(), 1, "superseded fact should be excluded");
        assert_eq!(facts[0].content, "I prefer dark mode");
    }
}