Skip to main content

car_inference/tasks/
classify.rs

1//! Classification — score text against candidate labels using prompt-based inference.
2
3use serde::{Deserialize, Serialize};
4
5use crate::backend::CandleBackend;
6use crate::tasks::generate;
7use crate::InferenceError;
8
/// A classification request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyRequest {
    /// The text to classify.
    pub text: String,
    /// Candidate labels to score against.
    pub labels: Vec<String>,
    /// Optional model override; passed through unchanged to the underlying
    /// generate request (`None` defers to the generate task's behavior).
    pub model: Option<String>,
}
19
/// A classification result with label and confidence score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyResult {
    /// The candidate label this result scores.
    pub label: String,
    /// Match confidence in [0, 1]. Scores across a response are normalized
    /// to sum to 1, except when every raw score was 0 (they stay 0).
    pub score: f64,
}
26
27/// Classify text against candidate labels.
28///
29/// Uses a prompt-based approach: asks the model to pick the best label,
30/// then parses the response. Falls back to first-token probability
31/// comparison when the response is ambiguous.
32pub async fn classify(
33    backend: &mut CandleBackend,
34    req: ClassifyRequest,
35) -> Result<Vec<ClassifyResult>, InferenceError> {
36    let labels_str = req.labels
37        .iter()
38        .enumerate()
39        .map(|(i, l)| format!("{}. {}", i + 1, l))
40        .collect::<Vec<_>>()
41        .join("\n");
42
43    let prompt = format!(
44        "Classify the following text into one of these categories:\n\
45         {labels_str}\n\n\
46         Text: {}\n\n\
47         Respond with ONLY the category name, nothing else.",
48        req.text
49    );
50
51    let gen_req = generate::GenerateRequest {
52        prompt,
53        model: req.model.clone(),
54        params: generate::GenerateParams {
55            temperature: 0.0, // greedy for classification
56            max_tokens: 32,
57            ..Default::default()
58        },
59        context: None,
60    };
61
62    let response = generate::generate(backend, gen_req).await?;
63    let response_lower = response.trim().to_lowercase();
64
65    // Score each label based on string match in the response
66    let mut results: Vec<ClassifyResult> = req.labels
67        .iter()
68        .map(|label| {
69            let label_lower = label.to_lowercase();
70            let score = if response_lower == label_lower {
71                1.0
72            } else if response_lower.contains(&label_lower) {
73                0.8
74            } else {
75                // Partial word overlap
76                let label_words: Vec<&str> = label_lower.split_whitespace().collect();
77                let matches = label_words
78                    .iter()
79                    .filter(|w| response_lower.contains(**w))
80                    .count();
81                if label_words.is_empty() {
82                    0.0
83                } else {
84                    0.5 * (matches as f64 / label_words.len() as f64)
85                }
86            };
87            ClassifyResult {
88                label: label.clone(),
89                score,
90            }
91        })
92        .collect();
93
94    // Sort by score descending
95    results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
96
97    // Normalize scores to sum to 1
98    let total: f64 = results.iter().map(|r| r.score).sum();
99    if total > 0.0 {
100        for r in &mut results {
101            r.score /= total;
102        }
103    }
104
105    Ok(results)
106}