use std::time::Duration;
use serde_json::{json, Value};
use tracing::{debug, info};
use crate::events::EventStore;
use crate::traits::{ChatOptions, ModelProvider, StateStore};
use std::sync::Arc;
/// Output-token budget for the single-label intent classifier; its reply is a
/// single short label.
const CLASSIFIER_MAX_OUTPUT_TOKENS: u32 = 20;
/// Wall-clock limit applied to every classifier call in this module. When it
/// expires the classifiers fail open rather than propagating an error.
const CLASSIFIER_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Coarse intent label produced by the LLM-backed intent classifier.
///
/// Every variant except `Unknown` corresponds to one label listed in the
/// classifier prompt. `Unknown` is the fail-open value: it is produced when the
/// model's reply does not match any label, and `classify_intent` also returns it
/// when the provider call errors or times out.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmIntentClass {
    /// `schedule_one_shot`: trigger a one-time future action.
    ScheduleOneShot,
    /// `schedule_recurring`: trigger a repeating action.
    ScheduleRecurring,
    /// `memory_storage`: store/remember a fact about the user.
    MemoryStorage,
    /// `memory_recall`: recall a fact already stored.
    MemoryRecall,
    /// `action`: run a tool, write code, search, browse, etc.
    Action,
    /// `knowledge_question`: answer a question from general knowledge.
    KnowledgeQuestion,
    /// `other`: doesn't fit any category, or is a mix.
    Other,
    /// The reply could not be mapped to any label (or no reply was obtained).
    Unknown,
}

impl LlmIntentClass {
    /// Parses a raw classifier reply into a label.
    ///
    /// Matching is ASCII case-insensitive. Before matching, every leading and
    /// trailing whitespace character, quote mark (`"` or `'`), backtick and full
    /// stop is stripped. Replies such as `action`, `"action"`, `" action "`,
    /// `` `action` `` and `Action.` therefore all parse to [`LlmIntentClass::Action`].
    ///
    /// Anything else yields [`LlmIntentClass::Unknown`], including the empty
    /// string and JSON-wrapped labels like `{"label":"action"}`.
    pub(crate) fn from_response_str(s: &str) -> Self {
        // Strip the whole decoration set in a single pass. Trimming whitespace and
        // quotes separately (as a two-step trim would) leaves the inner padding of
        // `" action "` in place and misses backticks / trailing periods that small
        // models often add despite the "label only" instruction.
        let label = s
            .trim_matches(|c: char| c.is_whitespace() || matches!(c, '"' | '\'' | '`' | '.'))
            .to_ascii_lowercase();
        match label.as_str() {
            "schedule_one_shot" => Self::ScheduleOneShot,
            "schedule_recurring" => Self::ScheduleRecurring,
            "memory_storage" => Self::MemoryStorage,
            "memory_recall" => Self::MemoryRecall,
            "action" => Self::Action,
            "knowledge_question" => Self::KnowledgeQuestion,
            "other" => Self::Other,
            _ => Self::Unknown,
        }
    }

    /// Returns the canonical snake_case label for this class.
    ///
    /// For every variant except `Unknown`, this is exactly the string that
    /// [`LlmIntentClass::from_response_str`] maps back to the same variant.
    pub fn as_label(&self) -> &'static str {
        match self {
            Self::ScheduleOneShot => "schedule_one_shot",
            Self::ScheduleRecurring => "schedule_recurring",
            Self::MemoryStorage => "memory_storage",
            Self::MemoryRecall => "memory_recall",
            Self::Action => "action",
            Self::KnowledgeQuestion => "knowledge_question",
            Self::Other => "other",
            Self::Unknown => "unknown",
        }
    }
}
/// Kind of memory lookup a user message asks for, as judged by the relational
/// classifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelationalKind {
    /// A question about a relationship/connection between entities.
    Relational,
    /// A direct fact lookup about a single entity.
    Recall,
    /// Anything else; also the kind carried by [`RelationalIntent::none`].
    None,
}

/// Result of relational-intent classification: the kind of lookup plus the
/// entity names the question is about.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RelationalIntent {
    /// Which kind of lookup the message asks for.
    pub kind: RelationalKind,
    /// Entity names the question mentions, as returned by the classifier.
    pub entities: Vec<String>,
}

impl RelationalIntent {
    /// The neutral result: `RelationalKind::None` with an empty entity list.
    fn none() -> Self {
        let entities: Vec<String> = vec![];
        Self {
            entities,
            kind: RelationalKind::None,
        }
    }
}
pub fn parse_relational_intent(raw: &str) -> RelationalIntent {
let (Some(start), Some(end)) = (raw.find('{'), raw.rfind('}')) else {
return RelationalIntent::none();
};
if end < start {
return RelationalIntent::none();
}
let Ok(v) = serde_json::from_str::<serde_json::Value>(&raw[start..=end]) else {
return RelationalIntent::none();
};
let kind = match v.get("intent").and_then(|i| i.as_str()).unwrap_or("none") {
"relational" => RelationalKind::Relational,
"recall" => RelationalKind::Recall,
_ => RelationalKind::None,
};
let entities = v
.get("entities")
.and_then(|e| e.as_array())
.map(|arr| {
arr.iter()
.filter_map(|x| x.as_str())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect::<Vec<_>>()
})
.unwrap_or_default();
RelationalIntent { kind, entities }
}
/// Builds the two-message chat (system prompt, then the user's text) sent to
/// the relational classifier.
///
/// The system prompt asks the model to reply with only a JSON object carrying
/// `intent` and `entities`.
fn build_relational_classifier_messages(user_text: &str) -> Vec<Value> {
    // NOTE(review): the possessive example below maps "my mom" to itself, which
    // does not illustrate "resolve possessives to the owned entity" — confirm
    // the intended wording before changing the prompt.
    const SYSTEM_PROMPT: &str = "You classify a user message about their personal memory. \
        Reply with ONLY a JSON object: {\"intent\": \"relational\"|\"recall\"|\"none\", \"entities\": [..]}. \
        \"relational\" = a question about a relationship/connection between entities (e.g. \"who is Caro's spouse?\", \"who is my kid's mom?\", \"what tools does project X use?\"). \
        \"recall\" = a direct fact lookup about one entity (e.g. \"what's my dog's name?\"). \
        \"none\" = anything else (general knowledge, chit-chat, actions). \
        \"entities\" = the people/projects/things the question is about, as named (resolve possessives to the owned entity: \"my mom\" -> \"my mom\"). Keep it short.";

    [("system", SYSTEM_PROMPT), ("user", user_text)]
        .into_iter()
        .map(|(role, content)| json!({ "role": role, "content": content }))
        .collect()
}
/// Asks the fast model whether `user_text` is a relational question, a direct
/// recall, or neither, and extracts the entities it mentions.
///
/// Fail-open: the following all yield a `RelationalKind::None` result with no
/// entities:
/// - empty or whitespace-only input (no provider call is made);
/// - a provider error;
/// - a call exceeding `CLASSIFIER_TIMEOUT`;
/// - a reply without content.
pub async fn classify_relational_intent(
    provider: &dyn ModelProvider,
    fast_model: &str,
    user_text: &str,
) -> RelationalIntent {
    // The reply is a small JSON object (intent + entity list), so it gets a
    // larger output budget than the single-label intent classifier.
    const RELATIONAL_MAX_OUTPUT_TOKENS: u32 = 120;

    let text = user_text.trim();
    if text.is_empty() {
        return RelationalIntent::none();
    }

    let messages = build_relational_classifier_messages(text);
    let options = ChatOptions {
        max_tokens_override: Some(RELATIONAL_MAX_OUTPUT_TOKENS),
        ..ChatOptions::default()
    };
    let pending = provider.chat_with_options(fast_model, &messages, &[], &options);

    match tokio::time::timeout(CLASSIFIER_TIMEOUT, pending).await {
        // A missing content field is treated like an empty reply.
        Ok(Ok(reply)) => parse_relational_intent(reply.content.as_deref().unwrap_or("")),
        Ok(Err(err)) => {
            debug!(?err, "relational classifier call failed; failing open");
            RelationalIntent::none()
        }
        Err(_) => {
            debug!(
                timeout_s = CLASSIFIER_TIMEOUT.as_secs(),
                "relational classifier timeout"
            );
            RelationalIntent::none()
        }
    }
}
/// Builds the two-message chat (system prompt, then the user's text) sent to
/// the single-label intent classifier.
///
/// The system prompt lists the seven labels that
/// `LlmIntentClass::from_response_str` recognises and asks for the bare label
/// with no prose or JSON wrapping.
pub(crate) fn build_classifier_messages(user_text: &str) -> Vec<Value> {
    const SYSTEM_PROMPT: &str = "You are an intent classifier. Read the user's message and \
        return exactly one label (no prose, no explanation, no JSON \
        wrapping). Valid labels:\n\
        - schedule_one_shot: trigger a one-time future action\n\
        - schedule_recurring: trigger a repeating action\n\
        - memory_storage: store/remember a fact about the user\n\
        - memory_recall: recall a fact already stored\n\
        - action: run a tool, write code, search, browse, etc.\n\
        - knowledge_question: answer a question from general knowledge\n\
        - other: doesn't fit any category, or is a mix\n\
        Respond with the label only.";

    let system_message = json!({ "role": "system", "content": SYSTEM_PROMPT });
    let user_message = json!({ "role": "user", "content": user_text });
    vec![system_message, user_message]
}
/// Classifies `user_text` into a coarse [`LlmIntentClass`] using the fast model.
///
/// - Empty or whitespace-only input short-circuits to `LlmIntentClass::Other`
///   without calling the provider.
/// - A provider error, or a call exceeding `CLASSIFIER_TIMEOUT`, yields
///   `LlmIntentClass::Unknown` (fail-open). No telemetry is recorded in that case.
/// - When both `state` and `event_store` are supplied, telemetry for the
///   successful call is recorded under `"background:intent_classifier"`
///   before the reply is parsed.
/// - A reply with no content, or content that is not a recognised label, yields
///   `LlmIntentClass::Unknown`.
// NOTE(review): `dead_code` is allowed as in the original — presumably this
// classifier has no production call site yet; confirm before removing.
#[allow(dead_code)]
pub async fn classify_intent(
    provider: &dyn ModelProvider,
    fast_model: &str,
    user_text: &str,
    state: Option<&Arc<dyn StateStore>>,
    event_store: Option<Arc<EventStore>>,
) -> LlmIntentClass {
    let text = user_text.trim();
    if text.is_empty() {
        return LlmIntentClass::Other;
    }

    let messages = build_classifier_messages(text);
    let options = ChatOptions {
        max_tokens_override: Some(CLASSIFIER_MAX_OUTPUT_TOKENS),
        ..ChatOptions::default()
    };
    let started = std::time::Instant::now();
    let pending = provider.chat_with_options(fast_model, &messages, &[], &options);

    let reply = match tokio::time::timeout(CLASSIFIER_TIMEOUT, pending).await {
        Ok(Ok(reply)) => reply,
        Ok(Err(err)) => {
            debug!(?err, "intent classifier call failed; failing open");
            return LlmIntentClass::Unknown;
        }
        Err(_) => {
            debug!(
                timeout_s = CLASSIFIER_TIMEOUT.as_secs(),
                "intent classifier timeout"
            );
            return LlmIntentClass::Unknown;
        }
    };

    // Telemetry is optional: it needs both the state store and the event store.
    if let (Some(state), Some(store)) = (state, event_store) {
        crate::events::record_background_model_call_telemetry(
            store,
            state.as_ref(),
            "background:intent_classifier",
            "intent_classifier",
            fast_model,
            &reply,
            started.elapsed(),
        )
        .await;
    }

    reply
        .content
        .as_deref()
        .map_or(LlmIntentClass::Unknown, LlmIntentClass::from_response_str)
}
/// Emits an `info`-level `intent_disagreement` event when the heuristic label
/// differs from the LLM's label.
///
/// Nothing is logged in either of these cases:
/// - the LLM result is `LlmIntentClass::Unknown` (no usable signal);
/// - `heuristic_label` equals `llm.as_label()`.
///
/// The logged preview is `crate::utils::truncate_str(user_text, 200)`, which is
/// presumably a bounded-length prefix of the message (check `truncate_str` for units).
// NOTE(review): `dead_code` is allowed as in the original — presumably there is
// no production caller yet; confirm before removing.
#[allow(dead_code)]
pub fn log_intent_disagreement(user_text: &str, heuristic_label: &str, llm: LlmIntentClass) {
    let llm_label = llm.as_label();
    let has_signal = llm != LlmIntentClass::Unknown;
    if !has_signal || heuristic_label == llm_label {
        return;
    }
    info!(
        event = "intent_disagreement",
        heuristic = heuristic_label,
        llm = llm_label,
        user_text_preview = %crate::utils::truncate_str(user_text, 200),
        "heuristic and LLM intent classifiers disagree"
    );
}
// Unit tests for the intent classifier (label parsing, prompt shape, provider
// fail-open paths) and the relational classifier (JSON parsing, provider call).
// Provider calls use `crate::testing::MockProvider` with canned responses.
#[cfg(test)]
mod tests {
use super::*;
use crate::testing::MockProvider;
// Exact, upper-case, whitespace-padded and double-quoted labels all parse.
#[test]
fn parses_known_labels() {
assert_eq!(
LlmIntentClass::from_response_str("schedule_one_shot"),
LlmIntentClass::ScheduleOneShot
);
assert_eq!(
LlmIntentClass::from_response_str("MEMORY_STORAGE"),
LlmIntentClass::MemoryStorage
);
assert_eq!(
LlmIntentClass::from_response_str(" action "),
LlmIntentClass::Action
);
assert_eq!(
LlmIntentClass::from_response_str("\"knowledge_question\""),
LlmIntentClass::KnowledgeQuestion
);
}
// Unrecognised text, the empty string, and a JSON-wrapped label all map to
// Unknown — the parser does not unwrap JSON.
#[test]
fn rejects_unknown_labels_as_unknown() {
assert_eq!(
LlmIntentClass::from_response_str("garbage"),
LlmIntentClass::Unknown
);
assert_eq!(
LlmIntentClass::from_response_str(""),
LlmIntentClass::Unknown
);
assert_eq!(
LlmIntentClass::from_response_str("{\"label\":\"action\"}"),
LlmIntentClass::Unknown
);
}
// as_label -> from_response_str returns the same variant for every class except
// Unknown (which is intentionally left out of the loop).
#[test]
fn label_roundtrip_is_stable() {
for class in [
LlmIntentClass::ScheduleOneShot,
LlmIntentClass::ScheduleRecurring,
LlmIntentClass::MemoryStorage,
LlmIntentClass::MemoryRecall,
LlmIntentClass::Action,
LlmIntentClass::KnowledgeQuestion,
LlmIntentClass::Other,
] {
assert_eq!(
LlmIntentClass::from_response_str(class.as_label()),
class,
"round-trip failed for {class:?}"
);
}
}
// The prompt is [system, user]. The system prompt names every label, and the
// user text is passed through verbatim.
#[test]
fn message_shape_contains_user_text_and_label_vocabulary() {
let messages = build_classifier_messages("remind me at 5pm");
assert_eq!(messages.len(), 2);
let system = messages[0].get("content").and_then(|c| c.as_str()).unwrap();
for label in [
"schedule_one_shot",
"schedule_recurring",
"memory_storage",
"memory_recall",
"action",
"knowledge_question",
"other",
] {
assert!(system.contains(label), "prompt is missing label {label:?}");
}
let user = messages[1].get("content").and_then(|c| c.as_str()).unwrap();
assert_eq!(user, "remind me at 5pm");
}
// Happy path: the canned provider reply is mapped to the enum. Telemetry is
// skipped because state and event store are None.
#[tokio::test]
async fn classify_returns_parsed_label_on_success() {
let provider =
MockProvider::with_responses(vec![MockProvider::text_response("memory_storage")]);
let result =
classify_intent(&provider, "fast-model", "remember my birthday", None, None).await;
assert_eq!(result, LlmIntentClass::MemoryStorage);
}
// A reply that is not a label fails open to Unknown.
#[tokio::test]
async fn classify_fails_open_on_unparseable_response() {
let provider =
MockProvider::with_responses(vec![MockProvider::text_response("not a label")]);
let result =
classify_intent(&provider, "fast-model", "do something weird", None, None).await;
assert_eq!(result, LlmIntentClass::Unknown);
}
// Whitespace-only input returns Other without ever calling the provider.
#[tokio::test]
async fn classify_short_circuits_on_empty_input() {
let provider = MockProvider::with_responses(vec![]);
let result = classify_intent(&provider, "fast-model", " ", None, None).await;
assert_eq!(result, LlmIntentClass::Other);
assert_eq!(provider.call_count().await, 0);
}
// NOTE(review): presumably MockProvider::new() has no canned responses and
// returns an error when called, so this exercises the provider-error branch —
// confirm in crate::testing.
#[tokio::test]
async fn classify_handles_provider_error_by_failing_open() {
let provider = MockProvider::new();
let result = classify_intent(&provider, "fast-model", "anything", None, None).await;
assert_eq!(result, LlmIntentClass::Unknown);
}
// NOTE(review): this only checks that the early-return paths (matching labels,
// Unknown LLM result) do not panic. Tracing output is not captured, so
// "silent" is not actually asserted.
#[test]
fn disagreement_log_is_silent_on_match() {
log_intent_disagreement("do something", "action", LlmIntentClass::Action);
log_intent_disagreement("anything", "action", LlmIntentClass::Unknown);
}
// A bare JSON reply yields the intent kind and entities in order.
#[test]
fn parse_relational_intent_reads_json() {
let r = parse_relational_intent(r#"{"intent":"relational","entities":["Caro","Frank"]}"#);
assert_eq!(r.kind, RelationalKind::Relational);
assert_eq!(r.entities, vec!["Caro".to_string(), "Frank".to_string()]);
}
// Leading prose and a markdown code fence around the JSON are tolerated.
#[test]
fn parse_relational_intent_tolerates_fencing_and_prose() {
let r = parse_relational_intent(
"Sure!\n```json\n{\"intent\":\"recall\",\"entities\":[\"my dog\"]}\n```",
);
assert_eq!(r.kind, RelationalKind::Recall);
assert_eq!(r.entities, vec!["my dog".to_string()]);
}
// Non-JSON replies fail open to kind None with no entities.
#[test]
fn parse_relational_intent_fails_open_on_garbage() {
let r = parse_relational_intent("not json at all");
assert_eq!(r.kind, RelationalKind::None);
assert!(r.entities.is_empty());
}
// End-to-end through the provider: the canned JSON reply is parsed.
#[tokio::test]
async fn classify_relational_intent_parses_provider_json() {
let provider = crate::testing::MockProvider::with_responses(vec![
crate::testing::MockProvider::text_response(
r#"{"intent":"relational","entities":["Caro"]}"#,
),
]);
let r = classify_relational_intent(&provider, "fast-model", "who is caro's spouse?").await;
assert_eq!(r.kind, RelationalKind::Relational);
assert_eq!(r.entities, vec!["Caro".to_string()]);
}
// Whitespace-only input yields kind None. NOTE(review): unlike the intent
// classifier test above, this does not assert that the provider was not called.
#[tokio::test]
async fn classify_relational_intent_fails_open_on_empty_input() {
let provider = crate::testing::MockProvider::new();
let r = classify_relational_intent(&provider, "fast-model", " ").await;
assert_eq!(r.kind, RelationalKind::None);
}
}