Skip to main content

quiver_agent/
classify.rs

1//! Optional Haiku-backed task classifier for the agent loop.
2//!
3//! Each `UserText` event is fed through a [`TaskClassifier`] before the
4//! recommender. The classifier returns:
5//!
6//!   * `is_task=false` ⇒ engine drops the message (no hint, no
7//!     `agent_suggestions` row) — filters greetings, acks, status checks.
8//!   * `is_task=true` + `query` ⇒ engine embeds `query` (a focused rewrite),
9//!     not the raw text, and continues.
10//!
11//! Two impls ship here:
12//!
13//!   * [`NoopClassifier`] — passthrough, `is_task=true`, `query=raw`. Used
14//!     when the agent is started without `--classify` / `QUIVER_TASK_CLASSIFIER`.
15//!   * [`HaikuClassifier`] — calls Anthropic API (`ANTHROPIC_API_KEY`) or local
16//!     `claude` CLI. Mirrors the backend selection in
17//!     [`quiver_ingestion::llm_extract::ClaudeExtractor`]; same model, same
18//!     timeout pattern.
19//!
20//! The trait method never returns an error: every failure path inside
21//! `HaikuClassifier` (timeout, non-2xx, malformed JSON, missing CLI binary)
22//! logs a warning and falls back to passthrough so the agent never silently
23//! drops a real task because of a transient LLM glitch.
24
25use std::time::Duration;
26
27use anyhow::{Context, Result, anyhow};
28use async_trait::async_trait;
29use serde::Deserialize;
30use serde_json::json;
31
/// Upper bound, in bytes, on how much of the raw user message is sent to the
/// model. (The prompt wording says "chars"; the cut is made on a UTF-8
/// character boundary at or below this many bytes.)
const RAW_TRUNCATE_BYTES: usize = 2_000;

/// Default wall-clock budget for a single classification call.
const CLASSIFY_TIMEOUT: Duration = Duration::from_secs(15);

/// Model identifier sent in the `model` field of Messages API requests.
const ANTHROPIC_MODEL: &str = "claude-haiku-4-5-20251001";

/// Value of the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Default Messages API endpoint used when the API backend is auto-detected.
const ANTHROPIC_URL: &str = "https://api.anthropic.com/v1/messages";

/// Instructions given to the model: reply with a bare JSON object of the form
/// `{ "is_task": bool, "query": string }`.
///
/// The trailing backslashes join the lines into one single-line string; the
/// space before each backslash is the only separator between sentences.
const SYSTEM_PROMPT: &str = "You triage developer messages for a tool recommender. \
Return ONLY a JSON object (no prose, no code fences) matching: \
{ \"is_task\": bool, \"query\": string }. \
is_task=false when the message is a greeting, ack, status check, or pure chit-chat. \
is_task=true when the user wants code written, changed, explained, or a tool invoked. \
query: a short imperative summary (<=120 chars) of what the user wants done; \
on is_task=false, return the empty string.";
45
/// Verdict produced by a single classifier call.
///
/// When `is_task` is `true`, `query` is the text to continue with: either a
/// focused rewrite or, for passthrough results, the raw message itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassifiedTask {
    /// Whether the message should be treated as an actionable task.
    pub is_task: bool,
    /// Text to hand on to the next stage.
    pub query: String,
}

impl ClassifiedTask {
    /// Builds the "no opinion" verdict: the message counts as a task and the
    /// query is the untouched `raw` text.
    pub fn passthrough(raw: &str) -> Self {
        let query = raw.to_owned();
        Self {
            query,
            is_task: true,
        }
    }
}
62
63#[async_trait]
64pub trait TaskClassifier: Send + Sync {
65    async fn classify(&self, raw: &str) -> ClassifiedTask;
66}
67
68// ── NoopClassifier ──────────────────────────────────────────────────────────
69
70#[derive(Default, Debug, Clone, Copy)]
71pub struct NoopClassifier;
72
73#[async_trait]
74impl TaskClassifier for NoopClassifier {
75    async fn classify(&self, raw: &str) -> ClassifiedTask {
76        ClassifiedTask::passthrough(raw)
77    }
78}
79
80// ── HaikuClassifier ─────────────────────────────────────────────────────────
81
/// Transport used to reach the model.
///
/// `Debug` is implemented by hand so that formatting a backend with `{:?}`
/// (for example in a log line) never prints the API key.
#[derive(Clone)]
pub enum ClaudeBackend {
    /// Call the HTTP Messages API at `base_url`, authenticating with `api_key`.
    Api { api_key: String, base_url: String },
    /// Spawn the CLI executable found at `binary`.
    Cli { binary: String },
}

impl std::fmt::Debug for ClaudeBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The key is a secret: show a fixed placeholder instead of its value.
            Self::Api { base_url, .. } => f
                .debug_struct("Api")
                .field("api_key", &"<redacted>")
                .field("base_url", base_url)
                .finish(),
            Self::Cli { binary } => f.debug_struct("Cli").field("binary", binary).finish(),
        }
    }
}
87
88pub struct HaikuClassifier {
89    backend: ClaudeBackend,
90    timeout: Duration,
91}
92
93impl HaikuClassifier {
94    pub fn new(backend: ClaudeBackend) -> Self {
95        Self {
96            backend,
97            timeout: CLASSIFY_TIMEOUT,
98        }
99    }
100
101    pub fn with_timeout(mut self, t: Duration) -> Self {
102        self.timeout = t;
103        self
104    }
105
106    /// API > CLI. Returns `None` when neither is available.
107    pub fn detect() -> Option<Self> {
108        if let Ok(key) = std::env::var("ANTHROPIC_API_KEY")
109            && !key.trim().is_empty()
110        {
111            return Some(Self::new(ClaudeBackend::Api {
112                api_key: key,
113                base_url: ANTHROPIC_URL.to_string(),
114            }));
115        }
116        if let Some(bin) = which_claude() {
117            return Some(Self::new(ClaudeBackend::Cli { binary: bin }));
118        }
119        None
120    }
121
122    pub fn label(&self) -> &'static str {
123        match self.backend {
124            ClaudeBackend::Api { .. } => "haiku-api",
125            ClaudeBackend::Cli { .. } => "haiku-cli",
126        }
127    }
128
129    async fn try_classify(&self, raw: &str) -> Result<ClassifiedTask> {
130        let user_msg = format!(
131            "Developer message (truncated to {RAW_TRUNCATE_BYTES} chars):\n{body}",
132            body = truncate_chars(raw, RAW_TRUNCATE_BYTES)
133        );
134        let text = match &self.backend {
135            ClaudeBackend::Api { api_key, base_url } => tokio::time::timeout(
136                self.timeout,
137                call_anthropic_api(base_url, api_key, &user_msg),
138            )
139            .await
140            .map_err(|_| anyhow!("anthropic api timeout"))??,
141            ClaudeBackend::Cli { binary } => {
142                tokio::time::timeout(self.timeout, call_claude_cli(binary, &user_msg))
143                    .await
144                    .map_err(|_| anyhow!("claude cli timeout"))??
145            },
146        };
147        Ok(parse_classify_json(&text, raw))
148    }
149}
150
151#[async_trait]
152impl TaskClassifier for HaikuClassifier {
153    async fn classify(&self, raw: &str) -> ClassifiedTask {
154        match self.try_classify(raw).await {
155            Ok(c) => c,
156            Err(e) => {
157                tracing::warn!("haiku classify failed, passthrough: {e:#}");
158                ClassifiedTask::passthrough(raw)
159            },
160        }
161    }
162}
163
/// Returns the longest prefix of `s` that is at most `max` bytes long and ends
/// on a UTF-8 character boundary. Strings already within the limit are
/// returned whole.
fn truncate_chars(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_owned();
    }
    // `max < s.len()` here, so every candidate index is in range. Index 0 is
    // always a boundary, so the search cannot come up empty.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    s[..cut].to_owned()
}
177
178#[derive(Deserialize)]
179struct ApiResp {
180    content: Vec<ApiContentBlock>,
181}
182#[derive(Deserialize)]
183struct ApiContentBlock {
184    #[serde(default)]
185    r#type: String,
186    #[serde(default)]
187    text: String,
188}
189
190async fn call_anthropic_api(base_url: &str, api_key: &str, user_msg: &str) -> Result<String> {
191    let body = json!({
192        "model": ANTHROPIC_MODEL,
193        "max_tokens": 200,
194        "system": SYSTEM_PROMPT,
195        "messages": [{ "role": "user", "content": user_msg }],
196    });
197    let client = reqwest::Client::builder().build()?;
198    let resp = client
199        .post(base_url)
200        .header("x-api-key", api_key)
201        .header("anthropic-version", ANTHROPIC_VERSION)
202        .header("content-type", "application/json")
203        .json(&body)
204        .send()
205        .await
206        .context("post anthropic /v1/messages")?;
207    if !resp.status().is_success() {
208        let s = resp.status();
209        let t = resp.text().await.unwrap_or_default();
210        return Err(anyhow!("anthropic http {s}: {t}"));
211    }
212    let parsed: ApiResp = resp.json().await.context("decode anthropic response")?;
213    let text = parsed
214        .content
215        .into_iter()
216        .filter(|b| b.r#type == "text")
217        .map(|b| b.text)
218        .collect::<Vec<_>>()
219        .join("\n");
220    Ok(text)
221}
222
223#[derive(Deserialize)]
224struct CliResp {
225    #[serde(default)]
226    result: String,
227}
228
229async fn call_claude_cli(binary: &str, user_msg: &str) -> Result<String> {
230    use tokio::io::AsyncWriteExt;
231    use tokio::process::Command;
232
233    let prompt = format!("{SYSTEM_PROMPT}\n\n{user_msg}");
234    let mut child = Command::new(binary)
235        .args(["--print", "--output-format", "json"])
236        .stdin(std::process::Stdio::piped())
237        .stdout(std::process::Stdio::piped())
238        .stderr(std::process::Stdio::piped())
239        .spawn()
240        .context("spawn claude cli")?;
241    if let Some(mut stdin) = child.stdin.take() {
242        stdin.write_all(prompt.as_bytes()).await?;
243        stdin.shutdown().await?;
244    }
245    let out = child.wait_with_output().await.context("wait claude cli")?;
246    if !out.status.success() {
247        let stderr = String::from_utf8_lossy(&out.stderr);
248        return Err(anyhow!("claude cli exit {}: {stderr}", out.status));
249    }
250    let raw = String::from_utf8_lossy(&out.stdout).to_string();
251    let parsed: CliResp =
252        serde_json::from_str(&raw).context("parse claude --output-format json")?;
253    Ok(parsed.result)
254}
255
256#[derive(Deserialize)]
257struct LlmJson {
258    #[serde(default)]
259    is_task: Option<bool>,
260    #[serde(default)]
261    query: Option<String>,
262}
263
264/// Parse the LLM response into a `ClassifiedTask`. Tolerant of stray prose
265/// and code fences. On garbage / missing fields, returns
266/// `ClassifiedTask::passthrough(raw)` — never silently drops a message.
267pub fn parse_classify_json(text: &str, raw: &str) -> ClassifiedTask {
268    let trimmed = strip_code_fence(text.trim()).trim();
269    let start = trimmed.find('{');
270    let end = trimmed.rfind('}');
271    let payload = match (start, end) {
272        (Some(a), Some(b)) if b >= a => &trimmed[a..=b],
273        _ => return ClassifiedTask::passthrough(raw),
274    };
275    let parsed: LlmJson = match serde_json::from_str(payload) {
276        Ok(p) => p,
277        Err(e) => {
278            tracing::warn!("classify json parse failed: {e:#}");
279            return ClassifiedTask::passthrough(raw);
280        },
281    };
282    let is_task = parsed.is_task.unwrap_or(true);
283    let query_raw = parsed.query.unwrap_or_default();
284    let query_trim = query_raw.trim();
285    let query = if !is_task {
286        String::new()
287    } else if query_trim.is_empty() {
288        raw.to_string()
289    } else {
290        query_trim.to_string()
291    };
292    ClassifiedTask { is_task, query }
293}
294
/// Removes a wrapping Markdown code fence from `s`.
///
/// If `s` starts with "```json" or "```" and contains a later "```", returns
/// the text between the opening marker and the last closing marker. In every
/// other case `s` is returned unchanged.
fn strip_code_fence(s: &str) -> &str {
    s.strip_prefix("```json")
        .or_else(|| s.strip_prefix("```"))
        .and_then(|inner| inner.rfind("```").map(|close| &inner[..close]))
        .unwrap_or(s)
}
303
/// Searches each directory in `PATH`, in order, for a regular file named
/// `claude` and returns the first hit as a (lossily converted) string.
/// Returns `None` when `PATH` is unset or no directory contains such a file.
fn which_claude() -> Option<String> {
    let path = std::env::var_os("PATH")?;
    let found = std::env::split_paths(&path)
        .map(|dir| dir.join("claude"))
        .find(|candidate| candidate.is_file())
        .map(|hit| hit.to_string_lossy().into_owned());
    found
}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected value for a message that was kept as a task.
    fn kept(query: &str) -> ClassifiedTask {
        ClassifiedTask {
            is_task: true,
            query: query.to_string(),
        }
    }

    #[tokio::test]
    async fn noop_classifier_returns_passthrough() {
        let got = NoopClassifier.classify("write a tailwind config").await;
        assert_eq!(got, kept("write a tailwind config"));
    }

    #[test]
    fn parse_strips_fences_and_extracts_fields() {
        // Leading prose plus a fenced JSON payload.
        let reply =
            "Sure! Here:\n```json\n{\"is_task\":true,\"query\":\"extract design tokens\"}\n```\n";
        assert_eq!(
            parse_classify_json(reply, "original"),
            kept("extract design tokens")
        );
    }

    #[test]
    fn parse_handles_is_task_false() {
        let got = parse_classify_json(r#"{"is_task":false,"query":""}"#, "thanks!");
        let dropped = ClassifiedTask {
            is_task: false,
            query: String::new(),
        };
        assert_eq!(got, dropped);
    }

    #[test]
    fn parse_falls_back_on_garbage() {
        assert_eq!(
            parse_classify_json("not json at all", "real task"),
            kept("real task")
        );
    }

    #[test]
    fn parse_falls_back_on_unbalanced_braces() {
        assert_eq!(
            parse_classify_json("{ broken json", "real task"),
            kept("real task")
        );
    }

    #[test]
    fn parse_empty_query_with_is_task_true_uses_raw() {
        // The model flagged a task but left the query blank: the original
        // text must survive.
        let got = parse_classify_json(
            r#"{"is_task":true,"query":""}"#,
            "wire up auth middleware",
        );
        assert_eq!(got, kept("wire up auth middleware"));
    }

    #[test]
    fn parse_missing_is_task_defaults_to_true() {
        let got = parse_classify_json(r#"{"query":"do x"}"#, "do x literal");
        assert_eq!(got, kept("do x"));
    }

    #[test]
    fn truncate_chars_respects_utf8_boundary() {
        // 11 chars, 13 bytes: cutting at 6 bytes has to land on a boundary.
        let original = "héllo wörld";
        let cut = truncate_chars(original, 6);
        assert!(original.starts_with(cut.as_str()));
        assert!(cut.len() <= 6);
    }
}