use std::time::Duration;
use serde_json::{json, Value};
use tracing::{debug, info};
use crate::events::EventStore;
use crate::traits::{ChatOptions, ModelProvider, StateStore};
use std::sync::Arc;
/// Output-token cap for the classifier call; the reply only needs to carry a
/// single snake_case label.
const CLASSIFIER_MAX_OUTPUT_TOKENS: u32 = 20_u32;
/// Wall-clock budget for the classifier call. When it elapses,
/// `classify_intent` gives up and reports `LlmIntentClass::Unknown`.
const CLASSIFIER_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Intent label produced by the LLM-backed classifier.
///
/// Every variant except [`LlmIntentClass::Unknown`] corresponds to one label
/// advertised in the classifier prompt. `Unknown` means no usable label was
/// obtained and callers should treat it as "no opinion".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmIntentClass {
    /// `schedule_one_shot`: trigger a one-time future action.
    ScheduleOneShot,
    /// `schedule_recurring`: trigger a repeating action.
    ScheduleRecurring,
    /// `memory_storage`: store/remember a fact about the user.
    MemoryStorage,
    /// `memory_recall`: recall a fact already stored.
    MemoryRecall,
    /// `action`: run a tool, write code, search, browse, etc.
    Action,
    /// `knowledge_question`: answer a question from general knowledge.
    KnowledgeQuestion,
    /// `other`: doesn't fit any category, or is a mix. Also returned by
    /// `classify_intent` for blank input.
    Other,
    /// No usable label: the reply could not be parsed, or the model call
    /// failed or timed out.
    Unknown,
}

impl LlmIntentClass {
    /// Parses a raw model reply into a class.
    ///
    /// Matching is ASCII case-insensitive and tolerates the wrapping that
    /// models commonly add despite instructions: surrounding whitespace,
    /// straight double/single quotes, backticks, and leading/trailing periods,
    /// in any combination (e.g. `"\" action \""`, `` `memory_recall` ``,
    /// `action.`). Anything else — including structured output such as
    /// `{"label":"action"}` — maps to [`LlmIntentClass::Unknown`].
    pub(crate) fn from_response_str(s: &str) -> Self {
        // Trim one combined character set so nested wrappers like `" action "`
        // are fully removed; a whitespace-then-quote pass would leave inner
        // spaces behind and reject the label.
        let label = s
            .trim_matches(|c: char| c.is_whitespace() || matches!(c, '"' | '\'' | '`' | '.'))
            .to_ascii_lowercase();
        match label.as_str() {
            "schedule_one_shot" => Self::ScheduleOneShot,
            "schedule_recurring" => Self::ScheduleRecurring,
            "memory_storage" => Self::MemoryStorage,
            "memory_recall" => Self::MemoryRecall,
            "action" => Self::Action,
            "knowledge_question" => Self::KnowledgeQuestion,
            "other" => Self::Other,
            _ => Self::Unknown,
        }
    }

    /// Returns the canonical snake_case label for this class.
    ///
    /// For every variant, `from_response_str(x.as_label()) == x`.
    pub fn as_label(&self) -> &'static str {
        match self {
            Self::ScheduleOneShot => "schedule_one_shot",
            Self::ScheduleRecurring => "schedule_recurring",
            Self::MemoryStorage => "memory_storage",
            Self::MemoryRecall => "memory_recall",
            Self::Action => "action",
            Self::KnowledgeQuestion => "knowledge_question",
            Self::Other => "other",
            Self::Unknown => "unknown",
        }
    }
}
/// Builds the two-message chat transcript sent to the intent classifier.
///
/// The first message is a system prompt listing every valid label and asking
/// for the bare label only; the second carries `user_text` verbatim.
pub(crate) fn build_classifier_messages(user_text: &str) -> Vec<Value> {
    const SYSTEM_PROMPT: &str = concat!(
        "You are an intent classifier. Read the user's message and ",
        "return exactly one label (no prose, no explanation, no JSON ",
        "wrapping). Valid labels:\n",
        "- schedule_one_shot: trigger a one-time future action\n",
        "- schedule_recurring: trigger a repeating action\n",
        "- memory_storage: store/remember a fact about the user\n",
        "- memory_recall: recall a fact already stored\n",
        "- action: run a tool, write code, search, browse, etc.\n",
        "- knowledge_question: answer a question from general knowledge\n",
        "- other: doesn't fit any category, or is a mix\n",
        "Respond with the label only."
    );

    let system_message = json!({ "role": "system", "content": SYSTEM_PROMPT });
    let user_message = json!({ "role": "user", "content": user_text });
    vec![system_message, user_message]
}
/// Classifies `user_text` with a single short call to `fast_model`.
///
/// The call is capped at `CLASSIFIER_MAX_OUTPUT_TOKENS` output tokens and
/// bounded by `CLASSIFIER_TIMEOUT`. The function fails open and never errors:
/// - blank (whitespace-only) input yields [`LlmIntentClass::Other`] without
///   calling the provider;
/// - a provider error, a timeout, a reply with no content, or an unrecognised
///   label yields [`LlmIntentClass::Unknown`].
///
/// When both `state` and `event_store` are supplied, telemetry for a
/// successful call is recorded (and awaited) before the reply is parsed;
/// failed or timed-out calls record no telemetry here.
// NOTE(review): `dead_code` is allowed presumably because no non-test caller
// exists yet — confirm and drop the attribute once it is wired in.
#[allow(dead_code)]
pub async fn classify_intent(
    provider: &dyn ModelProvider,
    fast_model: &str,
    user_text: &str,
    state: Option<&Arc<dyn StateStore>>,
    event_store: Option<Arc<EventStore>>,
) -> LlmIntentClass {
    let text = user_text.trim();
    if text.is_empty() {
        // Nothing to classify; don't spend a model call on blank input.
        return LlmIntentClass::Other;
    }

    let messages = build_classifier_messages(text);
    let options = ChatOptions {
        max_tokens_override: Some(CLASSIFIER_MAX_OUTPUT_TOKENS),
        ..ChatOptions::default()
    };

    let started_at = std::time::Instant::now();
    let pending = provider.chat_with_options(fast_model, &messages, &[], &options);
    let response = match tokio::time::timeout(CLASSIFIER_TIMEOUT, pending).await {
        Err(_elapsed) => {
            debug!(
                timeout_s = CLASSIFIER_TIMEOUT.as_secs(),
                "intent classifier timeout"
            );
            return LlmIntentClass::Unknown;
        }
        Ok(Err(err)) => {
            debug!(?err, "intent classifier call failed; failing open");
            return LlmIntentClass::Unknown;
        }
        Ok(Ok(response)) => response,
    };

    // Telemetry needs both stores; with either missing it is skipped.
    if let Some((state, event_store)) = state.zip(event_store) {
        crate::events::record_background_model_call_telemetry(
            event_store,
            state.as_ref(),
            "background:intent_classifier",
            "intent_classifier",
            fast_model,
            &response,
            started_at.elapsed(),
        )
        .await;
    }

    response
        .content
        .as_deref()
        .map_or(LlmIntentClass::Unknown, LlmIntentClass::from_response_str)
}
/// Emits an `info!` event when the heuristic label and the LLM label differ.
///
/// Nothing is logged when the LLM produced no opinion
/// ([`LlmIntentClass::Unknown`]) or when `heuristic_label` equals the LLM's
/// canonical label (compared as exact, case-sensitive strings). The event
/// carries a preview of `user_text` truncated via `crate::utils::truncate_str`
/// with a limit of 200.
// NOTE(review): this logs (a prefix of) raw user text at info level — confirm
// that is acceptable for the deployment's privacy/log-retention policy.
// NOTE(review): `dead_code` is allowed presumably because no non-test caller
// exists yet — confirm.
#[allow(dead_code)]
pub fn log_intent_disagreement(user_text: &str, heuristic_label: &str, llm: LlmIntentClass) {
    let llm_label = llm.as_label();
    let llm_has_opinion = llm != LlmIntentClass::Unknown;

    if llm_has_opinion && heuristic_label != llm_label {
        info!(
            event = "intent_disagreement",
            heuristic = heuristic_label,
            llm = llm_label,
            user_text_preview = %crate::utils::truncate_str(user_text, 200),
            "heuristic and LLM intent classifiers disagree"
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::MockProvider;

    /// Every label advertised in the classifier prompt, paired with its
    /// variant. `Unknown` is deliberately absent: the model is never offered it.
    const PROMPT_LABELS: [(&str, LlmIntentClass); 7] = [
        ("schedule_one_shot", LlmIntentClass::ScheduleOneShot),
        ("schedule_recurring", LlmIntentClass::ScheduleRecurring),
        ("memory_storage", LlmIntentClass::MemoryStorage),
        ("memory_recall", LlmIntentClass::MemoryRecall),
        ("action", LlmIntentClass::Action),
        ("knowledge_question", LlmIntentClass::KnowledgeQuestion),
        ("other", LlmIntentClass::Other),
    ];

    /// Reads a string field from a chat message, panicking with context if absent.
    fn string_field<'a>(message: &'a Value, key: &str) -> &'a str {
        message
            .get(key)
            .and_then(Value::as_str)
            .unwrap_or_else(|| panic!("message is missing string field {key:?}: {message}"))
    }

    #[test]
    fn parses_known_labels() {
        let cases = [
            ("schedule_one_shot", LlmIntentClass::ScheduleOneShot),
            ("MEMORY_STORAGE", LlmIntentClass::MemoryStorage),
            (" action ", LlmIntentClass::Action),
            ("\"knowledge_question\"", LlmIntentClass::KnowledgeQuestion),
            ("other\n", LlmIntentClass::Other),
        ];
        for (raw, expected) in cases {
            assert_eq!(
                LlmIntentClass::from_response_str(raw),
                expected,
                "unexpected parse for {raw:?}"
            );
        }
    }

    #[test]
    fn rejects_unknown_labels_as_unknown() {
        for raw in ["garbage", "", "   ", "{\"label\":\"action\"}", "unknown_label"] {
            assert_eq!(
                LlmIntentClass::from_response_str(raw),
                LlmIntentClass::Unknown,
                "expected Unknown for {raw:?}"
            );
        }
    }

    #[test]
    fn label_roundtrip_is_stable() {
        for (label, class) in PROMPT_LABELS {
            assert_eq!(class.as_label(), label, "label mismatch for {class:?}");
            assert_eq!(
                LlmIntentClass::from_response_str(label),
                class,
                "round-trip failed for {class:?}"
            );
        }
        // `Unknown` is not advertised to the model but must still round-trip.
        let unknown = LlmIntentClass::Unknown;
        assert_eq!(LlmIntentClass::from_response_str(unknown.as_label()), unknown);
    }

    #[test]
    fn message_shape_contains_user_text_and_label_vocabulary() {
        let messages = build_classifier_messages("remind me at 5pm");
        assert_eq!(messages.len(), 2);

        assert_eq!(string_field(&messages[0], "role"), "system");
        assert_eq!(string_field(&messages[1], "role"), "user");
        assert_eq!(string_field(&messages[1], "content"), "remind me at 5pm");

        // Check for the bullet form (`- label:`), not a bare substring: words
        // like "action" and "other" also appear in other labels' descriptions.
        let system = string_field(&messages[0], "content");
        for (label, _) in PROMPT_LABELS {
            let bullet = format!("- {label}:");
            assert!(system.contains(&bullet), "prompt is missing label bullet {bullet:?}");
        }
    }

    #[tokio::test]
    async fn classify_returns_parsed_label_on_success() {
        let provider =
            MockProvider::with_responses(vec![MockProvider::text_response("memory_storage")]);
        let result =
            classify_intent(&provider, "fast-model", "remember my birthday", None, None).await;
        assert_eq!(result, LlmIntentClass::MemoryStorage);
    }

    #[tokio::test]
    async fn classify_fails_open_on_unparseable_response() {
        let provider =
            MockProvider::with_responses(vec![MockProvider::text_response("not a label")]);
        let result =
            classify_intent(&provider, "fast-model", "do something weird", None, None).await;
        assert_eq!(result, LlmIntentClass::Unknown);
    }

    #[tokio::test]
    async fn classify_short_circuits_on_empty_input() {
        let provider = MockProvider::with_responses(vec![]);
        let result = classify_intent(&provider, "fast-model", " ", None, None).await;
        assert_eq!(result, LlmIntentClass::Other);
        assert_eq!(provider.call_count().await, 0);
    }

    #[tokio::test]
    async fn classify_handles_provider_error_by_failing_open() {
        let provider = MockProvider::new();
        let result = classify_intent(&provider, "fast-model", "anything", None, None).await;
        assert_eq!(result, LlmIntentClass::Unknown);
    }

    /// Smoke test only: without a capturing subscriber this checks that the
    /// matching and `Unknown` paths return without panicking.
    #[test]
    fn disagreement_log_is_silent_on_match() {
        log_intent_disagreement("do something", "action", LlmIntentClass::Action);
        log_intent_disagreement("anything", "action", LlmIntentClass::Unknown);
    }

    /// Smoke test only: exercises the logging path (including the preview
    /// truncation helper) and checks it does not panic.
    #[test]
    fn disagreement_log_handles_mismatch() {
        log_intent_disagreement("remind me at 5pm", "action", LlmIntentClass::ScheduleOneShot);
    }
}