use serde::{Deserialize, Serialize};
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use crate::backend::CandleBackend;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use crate::tasks::generate;
use crate::InferenceError;
/// Input for [`classify`]: a piece of text plus the candidate categories to
/// choose from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyRequest {
/// The text to be classified; it is embedded verbatim in the model prompt.
pub text: String,
/// Candidate category names. They are shown to the model as a 1-based
/// numbered list in this order, and every one receives a score in the result.
pub labels: Vec<String>,
/// Model identifier forwarded as-is to the generation request.
/// NOTE(review): presumably `None` selects the backend's default model — confirm.
pub model: Option<String>,
}
/// One scored candidate category produced by [`classify`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyResult {
/// A copy of one of the request's labels (original casing preserved).
pub label: String,
/// Non-negative relevance score. As returned by `classify`, the scores of all
/// results are normalized to sum to `1.0` whenever at least one label matched;
/// otherwise they are all `0.0`.
pub score: f64,
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
pub async fn classify(
backend: &mut CandleBackend,
req: ClassifyRequest,
) -> Result<Vec<ClassifyResult>, InferenceError> {
let labels_str = req
.labels
.iter()
.enumerate()
.map(|(i, l)| format!("{}. {}", i + 1, l))
.collect::<Vec<_>>()
.join("\n");
let prompt = format!(
"Classify the following text into one of these categories:\n\
{labels_str}\n\n\
Text: {}\n\n\
Respond with ONLY the category name, nothing else.",
req.text
);
let gen_req = generate::GenerateRequest {
prompt,
model: req.model.clone(),
params: generate::GenerateParams {
temperature: 0.0, max_tokens: 32,
..Default::default()
},
context: None,
tools: None,
images: None,
messages: None,
cache_control: false,
response_format: None,
intent: None,
};
let (response, _ttft_ms) = generate::generate(backend, gen_req).await?;
let response_lower = response.trim().to_lowercase();
let mut results: Vec<ClassifyResult> = req
.labels
.iter()
.map(|label| {
let label_lower = label.to_lowercase();
let score = if response_lower == label_lower {
1.0
} else if response_lower.contains(&label_lower) {
0.8
} else {
let label_words: Vec<&str> = label_lower.split_whitespace().collect();
let matches = label_words
.iter()
.filter(|w| response_lower.contains(**w))
.count();
if label_words.is_empty() {
0.0
} else {
0.5 * (matches as f64 / label_words.len() as f64)
}
};
ClassifyResult {
label: label.clone(),
score,
}
})
.collect();
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let total: f64 = results.iter().map(|r| r.score).sum();
if total > 0.0 {
for r in &mut results {
r.score /= total;
}
}
Ok(results)
}