// car_inference/tasks/classify.rs

//! Classification — score text against candidate labels using prompt-based inference.

use serde::{Deserialize, Serialize};

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use crate::backend::CandleBackend;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use crate::tasks::generate;
use crate::InferenceError;
10
/// A classification request.
///
/// Consumed by [`classify`], which scores `text` against each entry in `labels`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyRequest {
    /// The text to classify.
    pub text: String,
    /// Candidate labels to score against.
    pub labels: Vec<String>,
    /// Optional model override, passed through to the underlying generate request.
    pub model: Option<String>,
}
21
/// A classification result with label and confidence score.
///
/// Produced by [`classify`]: results are sorted by descending score, and the
/// scores are normalized to sum to 1 whenever at least one label matched.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyResult {
    /// The candidate label this score applies to.
    pub label: String,
    /// Normalized match score in `[0, 1]`; higher means a better match.
    pub score: f64,
}
28
29#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
30/// Classify text against candidate labels.
31///
32/// Uses a prompt-based approach: asks the model to pick the best label,
33/// then parses the response. Falls back to first-token probability
34/// comparison when the response is ambiguous.
35pub async fn classify(
36    backend: &mut CandleBackend,
37    req: ClassifyRequest,
38) -> Result<Vec<ClassifyResult>, InferenceError> {
39    let labels_str = req
40        .labels
41        .iter()
42        .enumerate()
43        .map(|(i, l)| format!("{}. {}", i + 1, l))
44        .collect::<Vec<_>>()
45        .join("\n");
46
47    let prompt = format!(
48        "Classify the following text into one of these categories:\n\
49         {labels_str}\n\n\
50         Text: {}\n\n\
51         Respond with ONLY the category name, nothing else.",
52        req.text
53    );
54
55    let gen_req = generate::GenerateRequest {
56        prompt,
57        model: req.model.clone(),
58        params: generate::GenerateParams {
59            temperature: 0.0, // greedy for classification
60            max_tokens: 32,
61            ..Default::default()
62        },
63        context: None,
64        tools: None,
65        images: None,
66        messages: None,
67        cache_control: false,
68        response_format: None,
69        intent: None,
70    };
71
72    let (response, _ttft_ms) = generate::generate(backend, gen_req).await?;
73    let response_lower = response.trim().to_lowercase();
74
75    // Score each label based on string match in the response
76    let mut results: Vec<ClassifyResult> = req
77        .labels
78        .iter()
79        .map(|label| {
80            let label_lower = label.to_lowercase();
81            let score = if response_lower == label_lower {
82                1.0
83            } else if response_lower.contains(&label_lower) {
84                0.8
85            } else {
86                // Partial word overlap
87                let label_words: Vec<&str> = label_lower.split_whitespace().collect();
88                let matches = label_words
89                    .iter()
90                    .filter(|w| response_lower.contains(**w))
91                    .count();
92                if label_words.is_empty() {
93                    0.0
94                } else {
95                    0.5 * (matches as f64 / label_words.len() as f64)
96                }
97            };
98            ClassifyResult {
99                label: label.clone(),
100                score,
101            }
102        })
103        .collect();
104
105    // Sort by score descending
106    results.sort_by(|a, b| {
107        b.score
108            .partial_cmp(&a.score)
109            .unwrap_or(std::cmp::Ordering::Equal)
110    });
111
112    // Normalize scores to sum to 1
113    let total: f64 = results.iter().map(|r| r.score).sum();
114    if total > 0.0 {
115        for r in &mut results {
116            r.score /= total;
117        }
118    }
119
120    Ok(results)
121}