car-inference 0.14.0

Local model inference for CAR — Candle backend with Qwen3 models
//! Classification — score text against candidate labels using prompt-based inference.

use serde::{Deserialize, Serialize};

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use crate::backend::CandleBackend;
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
use crate::tasks::generate;
use crate::InferenceError;

/// A classification request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyRequest {
    /// The text to classify.
    pub text: String,
    /// Candidate labels to score against.
    pub labels: Vec<String>,
    /// Optional model override.
    pub model: Option<String>,
}

/// A classification result with label and confidence score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyResult {
    pub label: String,
    pub score: f64,
}
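
// With the default serde derives above, requests and results use plain
// JSON field names on the wire, e.g. (illustrative values only):
//   request: {"text":"...","labels":["hardware","software"],"model":null}
//   result:  [{"label":"hardware","score":0.76},{"label":"software","score":0.24}]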

#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Classify text against candidate labels.
///
/// Uses a prompt-based approach: asks the model to pick the best label,
/// then scores each candidate against the response with a string-match
/// heuristic (exact match, substring match, then word overlap). Scores
/// are normalized to sum to 1, and results are returned best-first.
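///
/// # Example
///
/// A sketch of the call shape (ignored in doctests since it needs a live
/// backend; how `backend` is constructed is assumed here, not part of
/// this module):
///
/// ```ignore
/// let req = ClassifyRequest {
///     text: "The battery drains within two hours".into(),
///     labels: vec!["hardware".into(), "software".into()],
///     model: None,
/// };
/// let results = classify(&mut backend, req).await?;
/// // Results are sorted best-first with normalized scores.
/// println!("{} ({:.2})", results[0].label, results[0].score);
/// ```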
pub async fn classify(
    backend: &mut CandleBackend,
    req: ClassifyRequest,
) -> Result<Vec<ClassifyResult>, InferenceError> {
    let labels_str = req
        .labels
        .iter()
        .enumerate()
        .map(|(i, l)| format!("{}. {}", i + 1, l))
        .collect::<Vec<_>>()
        .join("\n");

    let prompt = format!(
        "Classify the following text into one of these categories:\n\
         {labels_str}\n\n\
         Text: {}\n\n\
         Respond with ONLY the category name, nothing else.",
        req.text
    );
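
    // For labels ["spam", "ham"] and text "Buy now!", the prompt above
    // renders as:
    //
    //   Classify the following text into one of these categories:
    //   1. spam
    //   2. ham
    //
    //   Text: Buy now!
    //
    //   Respond with ONLY the category name, nothing else.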

    let gen_req = generate::GenerateRequest {
        prompt,
        model: req.model.clone(),
        params: generate::GenerateParams {
            temperature: 0.0, // greedy for classification
            max_tokens: 32,
            ..Default::default()
        },
        context: None,
        tools: None,
        images: None,
        messages: None,
        cache_control: false,
        response_format: None,
        intent: None,
    };

    let (response, _ttft_ms) = generate::generate(backend, gen_req).await?;
    let response_lower = response.trim().to_lowercase();

    // Score each label based on string match in the response
    let mut results: Vec<ClassifyResult> = req
        .labels
        .iter()
        .map(|label| {
            let label_lower = label.to_lowercase();
            let score = if response_lower == label_lower {
                1.0
            } else if response_lower.contains(&label_lower) {
                0.8
            } else {
                // Partial word overlap
                let label_words: Vec<&str> = label_lower.split_whitespace().collect();
                let matches = label_words
                    .iter()
                    .filter(|w| response_lower.contains(**w))
                    .count();
                if label_words.is_empty() {
                    0.0
                } else {
                    0.5 * (matches as f64 / label_words.len() as f64)
                }
            };
            ClassifyResult {
                label: label.clone(),
                score,
            }
        })
        .collect();

    // Sort by score descending
    results.sort_by(|a, b| b.score.total_cmp(&a.score));

    // Normalize scores to sum to 1
    let total: f64 = results.iter().map(|r| r.score).sum();
    if total > 0.0 {
        for r in &mut results {
            r.score /= total;
        }
    }
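
    // Worked example: for labels ["billing", "technical support"] and an
    // ambiguous response "billing support", the raw scores are 0.8
    // (substring match) and 0.25 (one of two label words matched), which
    // normalize to roughly 0.762 and 0.238.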

    Ok(results)
}
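
// A minimal test sketch. It assumes `serde_json` is available as a
// dev-dependency, which this module alone does not guarantee.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn classify_request_roundtrips_through_json() {
        let req = ClassifyRequest {
            text: "hello".into(),
            labels: vec!["a".into(), "b".into()],
            model: None,
        };
        let json = serde_json::to_string(&req).expect("serialize");
        let back: ClassifyRequest = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.text, "hello");
        assert_eq!(back.labels, vec!["a".to_string(), "b".to_string()]);
        assert!(back.model.is_none());
    }
}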