1use std::time::Duration;
26
27use anyhow::{Context, Result, anyhow};
28use async_trait::async_trait;
29use serde::Deserialize;
30use serde_json::json;
31
/// Maximum number of bytes of the raw developer message forwarded to the model.
const RAW_TRUNCATE_BYTES: usize = 2000;
/// Default time budget for a single classification request.
const CLASSIFY_TIMEOUT: Duration = Duration::from_millis(15_000);
/// Model requested from the Anthropic Messages API.
const ANTHROPIC_MODEL: &str = "claude-haiku-4-5-20251001";
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Messages API endpoint used by the HTTP backend.
const ANTHROPIC_URL: &str = "https://api.anthropic.com/v1/messages";

/// Triage instructions given to the model: reply with a bare JSON object
/// `{ "is_task": bool, "query": string }` and nothing else.
const SYSTEM_PROMPT: &str = concat!(
    "You triage developer messages for a tool recommender. ",
    "Return ONLY a JSON object (no prose, no code fences) matching: ",
    "{ \"is_task\": bool, \"query\": string }. ",
    "is_task=false when the message is a greeting, ack, status check, or pure chit-chat. ",
    "is_task=true when the user wants code written, changed, explained, or a tool invoked. ",
    "query: a short imperative summary (<=120 chars) of what the user wants done; ",
    "on is_task=false, return the empty string."
);
45
/// Result of triaging a single developer message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassifiedTask {
    /// Whether the message asks for actual work rather than being a greeting,
    /// acknowledgement, status check or chit-chat.
    pub is_task: bool,
    /// Short summary of the requested work; the classifiers in this module
    /// leave it empty when `is_task` is false.
    pub query: String,
}

impl ClassifiedTask {
    /// Fallback classification: `raw` is treated as a task and used verbatim
    /// as the query, so a real request is never dropped when triage is
    /// unavailable or fails.
    pub fn passthrough(raw: &str) -> Self {
        let query = String::from(raw);
        Self {
            query,
            is_task: true,
        }
    }
}
62
63#[async_trait]
64pub trait TaskClassifier: Send + Sync {
65 async fn classify(&self, raw: &str) -> ClassifiedTask;
66}
67
68#[derive(Default, Debug, Clone, Copy)]
71pub struct NoopClassifier;
72
73#[async_trait]
74impl TaskClassifier for NoopClassifier {
75 async fn classify(&self, raw: &str) -> ClassifiedTask {
76 ClassifiedTask::passthrough(raw)
77 }
78}
79
/// Transport used to reach Claude for classification.
#[derive(Clone)]
pub enum ClaudeBackend {
    /// HTTP Messages API. `api_key` is sent as the `x-api-key` header and
    /// requests are POSTed to `base_url`.
    Api { api_key: String, base_url: String },
    /// Local `claude` CLI executable at path `binary`, run once per request.
    Cli { binary: String },
}

// Written by hand so that `{:?}` (in logs, panics, debug dumps) never prints
// the API key. A derived impl would print it verbatim.
impl std::fmt::Debug for ClaudeBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Api { base_url, .. } => f
                .debug_struct("Api")
                .field("api_key", &"<redacted>")
                .field("base_url", base_url)
                .finish(),
            Self::Cli { binary } => f.debug_struct("Cli").field("binary", binary).finish(),
        }
    }
}
87
88pub struct HaikuClassifier {
89 backend: ClaudeBackend,
90 timeout: Duration,
91}
92
93impl HaikuClassifier {
94 pub fn new(backend: ClaudeBackend) -> Self {
95 Self {
96 backend,
97 timeout: CLASSIFY_TIMEOUT,
98 }
99 }
100
101 pub fn with_timeout(mut self, t: Duration) -> Self {
102 self.timeout = t;
103 self
104 }
105
106 pub fn detect() -> Option<Self> {
108 if let Ok(key) = std::env::var("ANTHROPIC_API_KEY")
109 && !key.trim().is_empty()
110 {
111 return Some(Self::new(ClaudeBackend::Api {
112 api_key: key,
113 base_url: ANTHROPIC_URL.to_string(),
114 }));
115 }
116 if let Some(bin) = which_claude() {
117 return Some(Self::new(ClaudeBackend::Cli { binary: bin }));
118 }
119 None
120 }
121
122 pub fn label(&self) -> &'static str {
123 match self.backend {
124 ClaudeBackend::Api { .. } => "haiku-api",
125 ClaudeBackend::Cli { .. } => "haiku-cli",
126 }
127 }
128
129 async fn try_classify(&self, raw: &str) -> Result<ClassifiedTask> {
130 let user_msg = format!(
131 "Developer message (truncated to {RAW_TRUNCATE_BYTES} chars):\n{body}",
132 body = truncate_chars(raw, RAW_TRUNCATE_BYTES)
133 );
134 let text = match &self.backend {
135 ClaudeBackend::Api { api_key, base_url } => tokio::time::timeout(
136 self.timeout,
137 call_anthropic_api(base_url, api_key, &user_msg),
138 )
139 .await
140 .map_err(|_| anyhow!("anthropic api timeout"))??,
141 ClaudeBackend::Cli { binary } => {
142 tokio::time::timeout(self.timeout, call_claude_cli(binary, &user_msg))
143 .await
144 .map_err(|_| anyhow!("claude cli timeout"))??
145 },
146 };
147 Ok(parse_classify_json(&text, raw))
148 }
149}
150
151#[async_trait]
152impl TaskClassifier for HaikuClassifier {
153 async fn classify(&self, raw: &str) -> ClassifiedTask {
154 match self.try_classify(raw).await {
155 Ok(c) => c,
156 Err(e) => {
157 tracing::warn!("haiku classify failed, passthrough: {e:#}");
158 ClassifiedTask::passthrough(raw)
159 },
160 }
161 }
162}
163
/// Returns the longest prefix of `s` that is at most `max` bytes long and
/// ends on a UTF-8 character boundary, so no character is ever split.
///
/// Despite the name, the limit is measured in bytes, not characters.
fn truncate_chars(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Here `max < s.len()`, so every probed index is in range. Index 0 is
    // always a boundary, so the `unwrap_or` fallback is never taken.
    let cut = (0..=max)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    s[..cut].to_owned()
}
177
178#[derive(Deserialize)]
179struct ApiResp {
180 content: Vec<ApiContentBlock>,
181}
182#[derive(Deserialize)]
183struct ApiContentBlock {
184 #[serde(default)]
185 r#type: String,
186 #[serde(default)]
187 text: String,
188}
189
190async fn call_anthropic_api(base_url: &str, api_key: &str, user_msg: &str) -> Result<String> {
191 let body = json!({
192 "model": ANTHROPIC_MODEL,
193 "max_tokens": 200,
194 "system": SYSTEM_PROMPT,
195 "messages": [{ "role": "user", "content": user_msg }],
196 });
197 let client = reqwest::Client::builder().build()?;
198 let resp = client
199 .post(base_url)
200 .header("x-api-key", api_key)
201 .header("anthropic-version", ANTHROPIC_VERSION)
202 .header("content-type", "application/json")
203 .json(&body)
204 .send()
205 .await
206 .context("post anthropic /v1/messages")?;
207 if !resp.status().is_success() {
208 let s = resp.status();
209 let t = resp.text().await.unwrap_or_default();
210 return Err(anyhow!("anthropic http {s}: {t}"));
211 }
212 let parsed: ApiResp = resp.json().await.context("decode anthropic response")?;
213 let text = parsed
214 .content
215 .into_iter()
216 .filter(|b| b.r#type == "text")
217 .map(|b| b.text)
218 .collect::<Vec<_>>()
219 .join("\n");
220 Ok(text)
221}
222
223#[derive(Deserialize)]
224struct CliResp {
225 #[serde(default)]
226 result: String,
227}
228
229async fn call_claude_cli(binary: &str, user_msg: &str) -> Result<String> {
230 use tokio::io::AsyncWriteExt;
231 use tokio::process::Command;
232
233 let prompt = format!("{SYSTEM_PROMPT}\n\n{user_msg}");
234 let mut child = Command::new(binary)
235 .args(["--print", "--output-format", "json"])
236 .stdin(std::process::Stdio::piped())
237 .stdout(std::process::Stdio::piped())
238 .stderr(std::process::Stdio::piped())
239 .spawn()
240 .context("spawn claude cli")?;
241 if let Some(mut stdin) = child.stdin.take() {
242 stdin.write_all(prompt.as_bytes()).await?;
243 stdin.shutdown().await?;
244 }
245 let out = child.wait_with_output().await.context("wait claude cli")?;
246 if !out.status.success() {
247 let stderr = String::from_utf8_lossy(&out.stderr);
248 return Err(anyhow!("claude cli exit {}: {stderr}", out.status));
249 }
250 let raw = String::from_utf8_lossy(&out.stdout).to_string();
251 let parsed: CliResp =
252 serde_json::from_str(&raw).context("parse claude --output-format json")?;
253 Ok(parsed.result)
254}
255
256#[derive(Deserialize)]
257struct LlmJson {
258 #[serde(default)]
259 is_task: Option<bool>,
260 #[serde(default)]
261 query: Option<String>,
262}
263
264pub fn parse_classify_json(text: &str, raw: &str) -> ClassifiedTask {
268 let trimmed = strip_code_fence(text.trim()).trim();
269 let start = trimmed.find('{');
270 let end = trimmed.rfind('}');
271 let payload = match (start, end) {
272 (Some(a), Some(b)) if b >= a => &trimmed[a..=b],
273 _ => return ClassifiedTask::passthrough(raw),
274 };
275 let parsed: LlmJson = match serde_json::from_str(payload) {
276 Ok(p) => p,
277 Err(e) => {
278 tracing::warn!("classify json parse failed: {e:#}");
279 return ClassifiedTask::passthrough(raw);
280 },
281 };
282 let is_task = parsed.is_task.unwrap_or(true);
283 let query_raw = parsed.query.unwrap_or_default();
284 let query_trim = query_raw.trim();
285 let query = if !is_task {
286 String::new()
287 } else if query_trim.is_empty() {
288 raw.to_string()
289 } else {
290 query_trim.to_string()
291 };
292 ClassifiedTask { is_task, query }
293}
294
/// Removes a surrounding Markdown code fence from `s`.
///
/// Strips an opening triple-backtick marker (with a `json` tag or bare) and
/// drops everything from the last closing triple backtick onward. Returns
/// `s` unchanged when it does not start with a fence or the fence is never
/// closed. Language tags other than `json` stay in the output.
fn strip_code_fence(s: &str) -> &str {
    let Some(inner) = s.strip_prefix("```json").or_else(|| s.strip_prefix("```")) else {
        return s;
    };
    inner.rfind("```").map_or(s, |close| &inner[..close])
}
303
/// Finds the `claude` CLI on `PATH` and returns the first match as a string.
///
/// Follows shell lookup more closely than a bare existence check. Only
/// regular files count (symlinks are followed). On Unix the file must also
/// have an execute bit set, so a stray non-executable `claude` file earlier
/// on `PATH` no longer hides the real binary. On Windows, `claude.exe` is
/// tried before the extension-less name.
fn which_claude() -> Option<String> {
    let path = std::env::var_os("PATH")?;
    let names: &[&str] = if cfg!(windows) {
        &["claude.exe", "claude"]
    } else {
        &["claude"]
    };
    std::env::split_paths(&path)
        .flat_map(|dir| names.iter().map(move |name| dir.join(name)))
        .find(|candidate| is_executable_file(candidate))
        .map(|candidate| candidate.to_string_lossy().into_owned())
}

/// Returns true when `path` is a regular file with at least one execute bit set.
#[cfg(unix)]
fn is_executable_file(path: &std::path::Path) -> bool {
    use std::os::unix::fs::PermissionsExt;
    std::fs::metadata(path)
        .is_ok_and(|meta| meta.is_file() && (meta.permissions().mode() & 0o111) != 0)
}

/// Returns true when `path` is a regular file (there is no execute bit to check).
#[cfg(not(unix))]
fn is_executable_file(path: &std::path::Path) -> bool {
    path.is_file()
}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn noop_classifier_returns_passthrough() {
        let c = NoopClassifier;
        let out = c.classify("write a tailwind config").await;
        assert!(out.is_task);
        assert_eq!(out.query, "write a tailwind config");
    }

    #[test]
    fn parse_strips_fences_and_extracts_fields() {
        let raw =
            "Sure! Here:\n```json\n{\"is_task\":true,\"query\":\"extract design tokens\"}\n```\n";
        let c = parse_classify_json(raw, "original");
        assert!(c.is_task);
        assert_eq!(c.query, "extract design tokens");
    }

    #[test]
    fn parse_handles_is_task_false() {
        let raw = "{\"is_task\":false,\"query\":\"\"}";
        let c = parse_classify_json(raw, "thanks!");
        assert!(!c.is_task);
        assert_eq!(c.query, "");
    }

    #[test]
    fn parse_is_task_false_discards_query() {
        // Even if the model fills in a query, a non-task must come back empty.
        let raw = "{\"is_task\":false,\"query\":\"say hi back\"}";
        let c = parse_classify_json(raw, "hi!");
        assert!(!c.is_task);
        assert_eq!(c.query, "");
    }

    #[test]
    fn parse_falls_back_on_garbage() {
        let c = parse_classify_json("not json at all", "real task");
        assert!(c.is_task);
        assert_eq!(c.query, "real task");
    }

    #[test]
    fn parse_falls_back_on_unbalanced_braces() {
        let c = parse_classify_json("{ broken json", "real task");
        assert!(c.is_task);
        assert_eq!(c.query, "real task");
    }

    #[test]
    fn parse_empty_query_with_is_task_true_uses_raw() {
        let raw = "{\"is_task\":true,\"query\":\"\"}";
        let c = parse_classify_json(raw, "wire up auth middleware");
        assert!(c.is_task);
        assert_eq!(c.query, "wire up auth middleware");
    }

    #[test]
    fn parse_missing_is_task_defaults_to_true() {
        let raw = "{\"query\":\"do x\"}";
        let c = parse_classify_json(raw, "do x literal");
        assert!(c.is_task);
        assert_eq!(c.query, "do x");
    }

    #[test]
    fn truncate_chars_respects_utf8_boundary() {
        let s = "héllo wörld";
        // "héllo" is exactly 6 bytes ('é' is 2 bytes), so it must be kept whole.
        assert_eq!(truncate_chars(s, 6), "héllo");
        // A limit that lands inside 'é' (bytes 1..3) must back off to "h".
        assert_eq!(truncate_chars(s, 2), "h");
        assert_eq!(truncate_chars(s, 0), "");
        assert_eq!(truncate_chars(s, s.len()), s);
    }

    #[test]
    fn strip_code_fence_handles_fenced_and_unclosed_input() {
        assert_eq!(strip_code_fence("```json\n{}\n```"), "\n{}\n");
        assert_eq!(strip_code_fence("```\n{}\n```"), "\n{}\n");
        // An unclosed fence leaves the input untouched.
        assert_eq!(strip_code_fence("```json\n{}"), "```json\n{}");
        assert_eq!(strip_code_fence("{}"), "{}");
    }

    #[test]
    fn passthrough_keeps_raw_message_as_task() {
        assert_eq!(
            ClassifiedTask::passthrough("do it"),
            ClassifiedTask {
                is_task: true,
                query: "do it".to_string(),
            }
        );
    }

    #[test]
    fn label_reflects_backend() {
        let cli = HaikuClassifier::new(ClaudeBackend::Cli {
            binary: "claude".to_string(),
        });
        assert_eq!(cli.label(), "haiku-cli");
        let api = HaikuClassifier::new(ClaudeBackend::Api {
            api_key: "k".to_string(),
            base_url: ANTHROPIC_URL.to_string(),
        });
        assert_eq!(api.label(), "haiku-api");
    }
}
384}