car_inference/tasks/
classify.rs1use serde::{Deserialize, Serialize};
4
5use crate::backend::CandleBackend;
6use crate::tasks::generate;
7use crate::InferenceError;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ClassifyRequest {
12 pub text: String,
14 pub labels: Vec<String>,
16 pub model: Option<String>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ClassifyResult {
23 pub label: String,
24 pub score: f64,
25}
26
27pub async fn classify(
33 backend: &mut CandleBackend,
34 req: ClassifyRequest,
35) -> Result<Vec<ClassifyResult>, InferenceError> {
36 let labels_str = req.labels
37 .iter()
38 .enumerate()
39 .map(|(i, l)| format!("{}. {}", i + 1, l))
40 .collect::<Vec<_>>()
41 .join("\n");
42
43 let prompt = format!(
44 "Classify the following text into one of these categories:\n\
45 {labels_str}\n\n\
46 Text: {}\n\n\
47 Respond with ONLY the category name, nothing else.",
48 req.text
49 );
50
51 let gen_req = generate::GenerateRequest {
52 prompt,
53 model: req.model.clone(),
54 params: generate::GenerateParams {
55 temperature: 0.0, max_tokens: 32,
57 ..Default::default()
58 },
59 context: None,
60 };
61
62 let response = generate::generate(backend, gen_req).await?;
63 let response_lower = response.trim().to_lowercase();
64
65 let mut results: Vec<ClassifyResult> = req.labels
67 .iter()
68 .map(|label| {
69 let label_lower = label.to_lowercase();
70 let score = if response_lower == label_lower {
71 1.0
72 } else if response_lower.contains(&label_lower) {
73 0.8
74 } else {
75 let label_words: Vec<&str> = label_lower.split_whitespace().collect();
77 let matches = label_words
78 .iter()
79 .filter(|w| response_lower.contains(**w))
80 .count();
81 if label_words.is_empty() {
82 0.0
83 } else {
84 0.5 * (matches as f64 / label_words.len() as f64)
85 }
86 };
87 ClassifyResult {
88 label: label.clone(),
89 score,
90 }
91 })
92 .collect();
93
94 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
96
97 let total: f64 = results.iter().map(|r| r.score).sum();
99 if total > 0.0 {
100 for r in &mut results {
101 r.score /= total;
102 }
103 }
104
105 Ok(results)
106}