use std::fmt::Debug;
/// Routing decision for an incoming user query: either inject pre-extracted
/// structured facts from past conversations into the prompt, or skip them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Route {
    /// Inject stored cross-conversation facts into the reader's prompt.
    Inject,
    /// Answer without injecting stored facts.
    Skip,
}

impl Route {
    /// Stable lowercase label for this route, suitable for logs and metrics.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Inject => "inject",
            Self::Skip => "skip",
        }
    }
}
/// Cheap keyword-based router: returns `Route::Inject` when the query
/// contains any counting/aggregation trigger phrase as a whole word
/// (case-insensitive), otherwise `Route::Skip`.
pub fn classify_naive(query: &str) -> Route {
    let lowered = query.to_ascii_lowercase();
    let triggered = COUNTING_TRIGGERS
        .iter()
        .any(|phrase| contains_whole_phrase(&lowered, phrase));
    if triggered {
        Route::Inject
    } else {
        Route::Skip
    }
}
/// Returns `true` if `phrase` occurs in `haystack` bounded on both sides by
/// a non-alphanumeric ASCII byte or the string edge — i.e. not embedded
/// inside a longer word ("count" matches "the count" but not "discounted").
///
/// Matching is byte-exact; callers are expected to lowercase both sides
/// first (`classify_naive` lowercases the query, and the trigger table is
/// already lowercase).
fn contains_whole_phrase(haystack: &str, phrase: &str) -> bool {
    // An empty phrase is meaningless here, and `find("")` matches at every
    // position — the scan below would eventually slice past the end of
    // `haystack` and panic. Reject it up front.
    if phrase.is_empty() {
        return false;
    }
    let bytes = haystack.as_bytes();
    let mut start = 0;
    while let Some(idx) = haystack[start..].find(phrase) {
        let abs = start + idx;
        // Boundary before the match: start of string or a non-alphanumeric
        // byte. A UTF-8 continuation byte is never ASCII-alphanumeric, so a
        // preceding multi-byte character also counts as a boundary.
        let before_ok = abs == 0 || !bytes[abs - 1].is_ascii_alphanumeric();
        let after_pos = abs + phrase.len();
        let after_ok =
            after_pos >= haystack.len() || !bytes[after_pos].is_ascii_alphanumeric();
        if before_ok && after_ok {
            return true;
        }
        // Advance past the first character of this (rejected) match. Using
        // the character's full UTF-8 width — rather than a fixed +1 — keeps
        // `start` on a char boundary even if a phrase ever begins with a
        // non-ASCII character.
        start = abs + haystack[abs..].chars().next().map_or(1, char::len_utf8);
    }
    false
}
/// Lowercase trigger phrases for the naive counting/aggregation router.
/// `classify_naive` lowercases the query and matches each phrase as a whole
/// word via `contains_whole_phrase`, so "count" does not fire on
/// "discounted". Keep entries lowercase and non-empty.
const COUNTING_TRIGGERS: &[&str] = &[
    "how many",
    "how often",
    "how much",
    "list every",
    "list all",
    "count",
    "total number",
    "in total",
    "altogether",
    "over the course",
    "across sessions",
    "across all",
    "across the",
    "so far",
    "sum of",
    "aggregate",
];
#[cfg(feature = "observation-extraction")]
mod haiku {
use super::Route;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Default Anthropic model id used for routing decisions.
const DEFAULT_MODEL: &str = "claude-haiku-4-5-20251001";
/// The model must reply with a single word, so a tiny completion budget
/// is sufficient.
const DEFAULT_MAX_TOKENS: u32 = 16;
/// Whole-request HTTP timeout for a classification call.
const DEFAULT_TIMEOUT_SECS: u64 = 10;
/// Value sent in the required `anthropic-version` header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// System prompt (v1) for the routing model. Instructs the model to reply
/// with exactly one word — `inject` or `skip` — which `parse_route` then
/// maps onto [`Route`]. The `_V1` suffix versions the prompt wording; bump
/// it when the instructions change materially.
pub const CLASSIFIER_PROMPT_V1: &str = "You are a query router. Decide \
whether to inject pre-extracted structured facts from past conversations \
into the reader's prompt. Reply `inject` when the question asks about \
*either* of the following:\n\
\n\
1. COUNTING or ENUMERATION across conversations — \"how many\", \"list \
every\", \"total X\", \"how often\", \"sum of\", \"in total\".\n\
2. TEMPORAL reasoning or CHRONOLOGY — ordering events in time, asking \
when something happened, tracking how things changed over time, or \
comparing items mentioned in different sessions (e.g., \"what was the \
last X\", \"when did I start Y\", \"which came first\", \"what did the \
assistant recommend before suggesting Z\", \"what was I doing around \
the time we discussed Y\").\n\
\n\
Reply `skip` for everything else, including: current-state preference \
questions (\"what's my favorite…?\"), requests for advice or action \
(\"should I…?\", \"remind me…\"), single-session factual lookups \
(\"what did I tell you about X?\"), and assistant-output recall \
(\"what did you recommend?\") unless the answer requires comparing \
across sessions.\n\
\n\
When in doubt between a temporal/chronology question and a single-shot \
lookup, prefer `inject`. Respond with exactly one word (`inject` or \
`skip`), no punctuation, no explanation.";
/// Errors surfaced by the Haiku classifier plumbing.
///
/// Note: `classify` / `classify_haiku_only` never return these — they fall
/// back to the naive classifier on failure. These errors are only visible
/// from the constructors and the internal API call.
#[derive(Debug, thiserror::Error)]
pub enum ClassifierError {
    /// Misconfiguration: HTTP client build failure or missing API key.
    #[error("classifier config error: {0}")]
    Config(String),
    /// Network-level or non-2xx HTTP failure talking to the API.
    #[error("classifier transport error: {0}")]
    Transport(String),
    /// The API responded, but the body or the one-word verdict was not
    /// understandable.
    #[error("classifier response parse error: {0}")]
    Parse(String),
}

/// Convenience alias for fallible classifier operations.
pub type ClassifierResult<T> = Result<T, ClassifierError>;
/// Small capacity- and TTL-bounded memo of classification results, keyed by
/// the normalized query. Backed by a linear-scan `Vec` in insertion order —
/// acceptable for the small capacity used here (1024 entries).
#[derive(Debug)]
struct ClassifierCache {
    // Maximum number of entries retained; the oldest is evicted first.
    capacity: usize,
    // Entries older than this are pruned on every lookup.
    ttl: Duration,
    // (normalized query, cached route, insertion time), oldest first.
    entries: Vec<(String, Route, Instant)>,
}
impl ClassifierCache {
    /// Creates an empty cache holding at most `capacity` entries, each valid
    /// for `ttl` after insertion.
    fn new(capacity: usize, ttl: Duration) -> Self {
        Self {
            capacity,
            ttl,
            // Cap the eager allocation so an unusually large configured
            // capacity does not reserve all of it up front.
            entries: Vec::with_capacity(capacity.min(1024)),
        }
    }

    /// Returns the cached route for `key`, if present and unexpired.
    /// Expired entries (any key) are pruned as a side effect.
    fn get(&mut self, key: &str) -> Option<Route> {
        let now = Instant::now();
        self.entries
            .retain(|(_, _, ts)| now.duration_since(*ts) < self.ttl);
        self.entries
            .iter()
            .find(|(k, _, _)| k == key)
            .map(|(_, r, _)| *r)
    }

    /// Inserts (or refreshes) `key`, evicting the oldest entry when full.
    fn put(&mut self, key: String, route: Route) {
        // A zero-capacity cache stores nothing. Without this guard the
        // eviction below would call `remove(0)` on an empty vec and panic.
        if self.capacity == 0 {
            return;
        }
        // Drop any existing entry for the same key: `get` returns the first
        // match, so a stale duplicate would otherwise shadow this insert.
        self.entries.retain(|(k, _, _)| k != &key);
        if self.entries.len() >= self.capacity {
            // Entries are kept in insertion order, so index 0 is the oldest.
            self.entries.remove(0);
        }
        self.entries.push((key, route, Instant::now()));
    }
}
/// Query router backed by a small Anthropic model, with an in-memory result
/// cache. `Clone` is cheap: clones share the same cache through the `Arc`.
#[derive(Debug, Clone)]
pub struct HaikuQueryClassifier {
    // Pre-built HTTP client carrying the request timeout.
    client: reqwest::Client,
    // Sent as the `x-api-key` header on every request.
    api_key: String,
    // Model id placed in the request body.
    model: String,
    // API origin; overridable via `with_base_url` (used by tests to point
    // at a mock server).
    base_url: String,
    // Shared across clones so repeated queries hit one cache.
    cache: Arc<Mutex<ClassifierCache>>,
}
impl HaikuQueryClassifier {
    /// Creates a classifier with the default model and endpoint, a 10 s
    /// request timeout, and a 1024-entry / 1 h-TTL result cache.
    ///
    /// # Errors
    /// Returns `ClassifierError::Config` if the HTTP client cannot be built.
    pub fn new(api_key: impl Into<String>) -> ClassifierResult<Self> {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
            .build()
            .map_err(|e| ClassifierError::Config(format!("http client: {e}")))?;
        Ok(Self {
            client,
            api_key: api_key.into(),
            model: DEFAULT_MODEL.into(),
            base_url: "https://api.anthropic.com".into(),
            cache: Arc::new(Mutex::new(ClassifierCache::new(
                1024,
                Duration::from_secs(3600),
            ))),
        })
    }

    /// Builds a classifier from the `ANTHROPIC_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns `ClassifierError::Config` when the variable is unset.
    pub fn from_env() -> ClassifierResult<Self> {
        let api_key = std::env::var("ANTHROPIC_API_KEY")
            .map_err(|_| ClassifierError::Config("ANTHROPIC_API_KEY env var not set".into()))?;
        Self::new(api_key)
    }

    /// Overrides the API base URL. Trailing slashes are tolerated (trimmed
    /// when the request URL is built).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Hybrid routing: the cheap keyword classifier short-circuits straight
    /// to `Inject`; only non-matching queries are sent to the model.
    pub async fn classify(&self, query: &str) -> Route {
        if super::classify_naive(query) == Route::Inject {
            return Route::Inject;
        }
        self.classify_haiku_only(query).await
    }

    /// Model-only routing with caching. Infallible by design: any API
    /// failure degrades to `classify_naive` instead of returning an error.
    ///
    /// NOTE(review): the fallback verdict is cached under the same TTL as a
    /// real model answer, so a transient outage can pin the naive result for
    /// up to an hour — confirm this trade-off is intended.
    pub async fn classify_haiku_only(&self, query: &str) -> Route {
        // The normalized form serves as both the cache key and the text
        // sent to the model.
        let key = normalize_query(query);
        match self.cache.lock() {
            Ok(mut cache) => {
                if let Some(hit) = cache.get(&key) {
                    return hit;
                }
            }
            Err(e) => {
                // A poisoned lock only disables caching; classification
                // still proceeds.
                tracing::warn!(
                    target: "pensyve::classifier",
                    error = %e,
                    "classifier cache lock poisoned on get; bypassing cache"
                );
            }
        }
        let route = match self.call_api(&key).await {
            Ok(r) => r,
            Err(e) => {
                tracing::warn!(
                    target: "pensyve::classifier",
                    error = %e,
                    "Haiku classifier failed; falling back to naive regex"
                );
                super::classify_naive(query)
            }
        };
        match self.cache.lock() {
            Ok(mut cache) => cache.put(key, route),
            Err(e) => {
                tracing::warn!(
                    target: "pensyve::classifier",
                    error = %e,
                    "classifier cache lock poisoned on put; result not cached"
                );
            }
        }
        route
    }

    /// One-shot call to the Anthropic Messages API; parses the first text
    /// content block of the response into a `Route`.
    ///
    /// # Errors
    /// `Transport` for connection or non-2xx failures, `Parse` for malformed
    /// JSON or an unrecognized verdict word.
    async fn call_api(&self, query: &str) -> ClassifierResult<Route> {
        let req = AnthropicRequest {
            model: &self.model,
            max_tokens: DEFAULT_MAX_TOKENS,
            // Deterministic decoding: a routing decision should not be
            // sampled.
            temperature: 0.0,
            system: CLASSIFIER_PROMPT_V1,
            messages: vec![AnthropicMessage {
                role: "user",
                content: query,
            }],
        };
        // Trim trailing slashes so user-supplied base URLs compose cleanly.
        let url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
        let response = self
            .client
            .post(&url)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", ANTHROPIC_VERSION)
            .header("content-type", "application/json")
            .json(&req)
            .send()
            .await
            .map_err(|e| ClassifierError::Transport(e.to_string()))?;
        if !response.status().is_success() {
            let status = response.status();
            // Best effort: include the error body in the message if readable.
            let body = response.text().await.unwrap_or_default();
            return Err(ClassifierError::Transport(format!("HTTP {status}: {body}")));
        }
        let parsed: AnthropicResponse = response
            .json()
            .await
            .map_err(|e| ClassifierError::Parse(e.to_string()))?;
        // The response may carry several content blocks; only the first
        // text block holds the verdict.
        let text = parsed
            .content
            .into_iter()
            .find(|b| b.block_type == "text")
            .map(|b| b.text)
            .unwrap_or_default();
        parse_route(&text)
    }
}
/// Canonicalizes a query for cache keying: ASCII-lowercases it and collapses
/// every run of whitespace (including leading/trailing) to single interior
/// spaces, so trivially different spellings share one cache entry.
fn normalize_query(q: &str) -> String {
    let lowered = q.to_ascii_lowercase();
    let mut out = String::with_capacity(lowered.len());
    for word in lowered.split_whitespace() {
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}
/// Maps the model's one-word reply onto a `Route`, tolerating surrounding
/// whitespace, casing, and trailing punctuation (prefix match).
fn parse_route(text: &str) -> ClassifierResult<Route> {
    let normalized = text.trim().to_ascii_lowercase();
    match normalized {
        s if s.starts_with("inject") => Ok(Route::Inject),
        s if s.starts_with("skip") => Ok(Route::Skip),
        _ => Err(ClassifierError::Parse(format!(
            "classifier returned unexpected output: {text:?}"
        ))),
    }
}
/// Request body for `POST /v1/messages` (only the fields this code uses).
#[derive(Debug, Serialize)]
struct AnthropicRequest<'a> {
    model: &'a str,
    max_tokens: u32,
    temperature: f32,
    system: &'a str,
    messages: Vec<AnthropicMessage<'a>>,
}

/// A single chat message in the request.
#[derive(Debug, Serialize)]
struct AnthropicMessage<'a> {
    role: &'a str,
    content: &'a str,
}

/// Response body: only the content blocks are read here.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContentBlock>,
}

/// One response content block, discriminated by its `type` field.
#[derive(Debug, Deserialize)]
struct AnthropicContentBlock {
    #[serde(rename = "type")]
    block_type: String,
    // `text` is absent on non-text block types, hence the default.
    #[serde(default)]
    text: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a minimal Anthropic Messages API success body containing a
    /// single text content block with the given verdict text.
    fn anthropic_response(text: &str) -> serde_json::Value {
        serde_json::json!({
            "id": "msg_test",
            "type": "message",
            "role": "assistant",
            "model": "claude-haiku-4-5-20251001",
            "content": [{"type": "text", "text": text}],
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 0, "output_tokens": 0},
        })
    }

    #[tokio::test]
    async fn classifier_returns_inject_on_inject_reply() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(anthropic_response("inject")),
            )
            .mount(&server)
            .await;
        let c = HaikuQueryClassifier::new("test-key")
            .unwrap()
            .with_base_url(server.uri());
        assert_eq!(
            c.classify("how many books did I read?").await,
            Route::Inject
        );
    }

    #[tokio::test]
    async fn classifier_returns_skip_on_skip_reply() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(anthropic_response("skip")))
            .mount(&server)
            .await;
        let c = HaikuQueryClassifier::new("test-key")
            .unwrap()
            .with_base_url(server.uri());
        assert_eq!(c.classify("what's my favorite color?").await, Route::Skip);
    }

    #[tokio::test]
    async fn classifier_handles_trailing_punctuation() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(anthropic_response("inject.\n")),
            )
            .mount(&server)
            .await;
        let c = HaikuQueryClassifier::new("k")
            .unwrap()
            .with_base_url(server.uri());
        assert_eq!(c.classify("how many games?").await, Route::Inject);
    }

    #[tokio::test]
    async fn classifier_falls_back_to_naive_on_http_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(500).set_body_string("broken"))
            .mount(&server)
            .await;
        let c = HaikuQueryClassifier::new("k")
            .unwrap()
            .with_base_url(server.uri());
        // Both routes exercised: trigger-matching query and plain query.
        assert_eq!(c.classify("how many books?").await, Route::Inject);
        assert_eq!(c.classify("what did I eat?").await, Route::Skip);
    }

    #[tokio::test]
    async fn classifier_caches_repeat_queries() {
        let server = MockServer::start().await;
        // `.expect(1)` proves the second and third calls are cache hits.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(anthropic_response("inject")),
            )
            .expect(1)
            .mount(&server)
            .await;
        let c = HaikuQueryClassifier::new("k")
            .unwrap()
            .with_base_url(server.uri());
        assert_eq!(
            c.classify_haiku_only("what was the last book I read?")
                .await,
            Route::Inject
        );
        assert_eq!(
            c.classify_haiku_only("what was the last book I read?")
                .await,
            Route::Inject
        );
        // Casing/whitespace variants normalize to the same cache key.
        assert_eq!(
            c.classify_haiku_only("  What Was The Last Book I Read? ")
                .await,
            Route::Inject
        );
    }

    #[tokio::test]
    async fn hybrid_classify_short_circuits_on_naive_inject() {
        let server = MockServer::start().await;
        // `.expect(0)` proves the keyword match never reaches the API.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(500).set_body_string("should not be hit"))
            .expect(0)
            .mount(&server)
            .await;
        let c = HaikuQueryClassifier::new("k")
            .unwrap()
            .with_base_url(server.uri());
        assert_eq!(
            c.classify("How many books did I read?").await,
            Route::Inject
        );
    }

    #[tokio::test]
    async fn hybrid_classify_falls_through_to_haiku_when_naive_skips() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(anthropic_response("inject")),
            )
            .expect(1)
            .mount(&server)
            .await;
        let c = HaikuQueryClassifier::new("k")
            .unwrap()
            .with_base_url(server.uri());
        assert_eq!(
            c.classify("When did I first start using Rust?").await,
            Route::Inject
        );
    }

    #[test]
    fn parse_route_rejects_garbage() {
        assert!(parse_route("").is_err());
        assert!(parse_route("I think maybe").is_err());
    }

    #[test]
    fn normalize_query_lowercases_and_collapses_whitespace() {
        assert_eq!(normalize_query("  How Many "), "how many");
        assert_eq!(normalize_query("how\tmany\n\nbooks"), "how many books");
    }

    #[test]
    fn cache_expires_entries() {
        let mut cache = ClassifierCache::new(4, Duration::from_millis(50));
        cache.put("q".into(), Route::Inject);
        assert_eq!(cache.get("q"), Some(Route::Inject));
        std::thread::sleep(Duration::from_millis(60));
        assert_eq!(cache.get("q"), None);
    }

    #[test]
    fn cache_evicts_oldest_when_full() {
        let mut cache = ClassifierCache::new(2, Duration::from_secs(3600));
        cache.put("a".into(), Route::Inject);
        cache.put("b".into(), Route::Inject);
        cache.put("c".into(), Route::Skip);
        assert_eq!(cache.get("a"), None);
        assert_eq!(cache.get("b"), Some(Route::Inject));
        assert_eq!(cache.get("c"), Some(Route::Skip));
    }
}
}
#[cfg(feature = "observation-extraction")]
pub use haiku::{CLASSIFIER_PROMPT_V1, ClassifierError, ClassifierResult, HaikuQueryClassifier};
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every query in `queries` routes to `expected`.
    fn assert_routes(queries: &[&str], expected: Route) {
        for query in queries {
            assert_eq!(classify_naive(query), expected, "query: {query:?}");
        }
    }

    #[test]
    fn naive_classifier_catches_how_many() {
        assert_routes(
            &[
                "how many games did I play?",
                "How many books?",
                "HOW MANY??",
            ],
            Route::Inject,
        );
    }

    #[test]
    fn naive_classifier_catches_list_every() {
        assert_routes(
            &["list every place I've visited", "List all of the games"],
            Route::Inject,
        );
    }

    #[test]
    fn naive_classifier_catches_count() {
        assert_routes(&["count the total items"], Route::Inject);
    }

    #[test]
    fn naive_classifier_catches_total() {
        assert_routes(
            &[
                "what's the total number of hours?",
                "spent in total 40 hours",
            ],
            Route::Inject,
        );
    }

    #[test]
    fn naive_classifier_catches_aggregation_phrases() {
        assert_routes(
            &[
                "across all my sessions",
                "over the course of a year",
                "so far this year",
            ],
            Route::Inject,
        );
    }

    #[test]
    fn naive_classifier_skips_non_counting_questions() {
        assert_routes(
            &[
                "what is my favorite color?",
                "who is my boss?",
                "remember to pick up milk tomorrow",
            ],
            Route::Skip,
        );
    }

    #[test]
    fn naive_classifier_avoids_partial_word_matches() {
        // "counter"/"discounted" embed "count" but must not trigger.
        assert_routes(&["my favorite counter", "a discounted meal"], Route::Skip);
        assert_routes(&["the count was off"], Route::Inject);
    }

    #[test]
    fn naive_classifier_handles_empty_input() {
        assert_routes(&["", "   "], Route::Skip);
    }

    #[test]
    fn route_as_str_returns_stable_strings() {
        assert_eq!(Route::Inject.as_str(), "inject");
        assert_eq!(Route::Skip.as_str(), "skip");
    }
}