car_inference/tasks/
classify.rs1use serde::{Deserialize, Serialize};
4
5#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
6use crate::backend::CandleBackend;
7#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
8use crate::tasks::generate;
9use crate::InferenceError;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ClassifyRequest {
14 pub text: String,
16 pub labels: Vec<String>,
18 pub model: Option<String>,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ClassifyResult {
25 pub label: String,
26 pub score: f64,
27}
28
29#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
30pub async fn classify(
36 backend: &mut CandleBackend,
37 req: ClassifyRequest,
38) -> Result<Vec<ClassifyResult>, InferenceError> {
39 let labels_str = req
40 .labels
41 .iter()
42 .enumerate()
43 .map(|(i, l)| format!("{}. {}", i + 1, l))
44 .collect::<Vec<_>>()
45 .join("\n");
46
47 let prompt = format!(
48 "Classify the following text into one of these categories:\n\
49 {labels_str}\n\n\
50 Text: {}\n\n\
51 Respond with ONLY the category name, nothing else.",
52 req.text
53 );
54
55 let gen_req = generate::GenerateRequest {
56 prompt,
57 model: req.model.clone(),
58 params: generate::GenerateParams {
59 temperature: 0.0, max_tokens: 32,
61 ..Default::default()
62 },
63 context: None,
64 tools: None,
65 images: None,
66 messages: None,
67 cache_control: false,
68 response_format: None,
69 intent: None,
70 };
71
72 let (response, _ttft_ms) = generate::generate(backend, gen_req).await?;
73 let response_lower = response.trim().to_lowercase();
74
75 let mut results: Vec<ClassifyResult> = req
77 .labels
78 .iter()
79 .map(|label| {
80 let label_lower = label.to_lowercase();
81 let score = if response_lower == label_lower {
82 1.0
83 } else if response_lower.contains(&label_lower) {
84 0.8
85 } else {
86 let label_words: Vec<&str> = label_lower.split_whitespace().collect();
88 let matches = label_words
89 .iter()
90 .filter(|w| response_lower.contains(**w))
91 .count();
92 if label_words.is_empty() {
93 0.0
94 } else {
95 0.5 * (matches as f64 / label_words.len() as f64)
96 }
97 };
98 ClassifyResult {
99 label: label.clone(),
100 score,
101 }
102 })
103 .collect();
104
105 results.sort_by(|a, b| {
107 b.score
108 .partial_cmp(&a.score)
109 .unwrap_or(std::cmp::Ordering::Equal)
110 });
111
112 let total: f64 = results.iter().map(|r| r.score).sum();
114 if total > 0.0 {
115 for r in &mut results {
116 r.score /= total;
117 }
118 }
119
120 Ok(results)
121}