car-inference 0.14.0

Local model inference for CAR — Candle backend with Qwen3 models
//! Candle inference backend — loads GGUF models, runs on Metal/CUDA/CPU.
//!
//! Supports both standard Qwen3 and Qwen3-MoE (Mixture of Experts) architectures.
//! Auto-detects the architecture from GGUF metadata.
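//!
//! A minimal usage sketch (model path, device choice, and error handling are
//! illustrative, not prescriptive):
//!
//! ```ignore
//! use std::path::Path;
//!
//! let mut backend = CandleBackend::load(Path::new("models/qwen3-4b"), Device::Cpu)?;
//! let tokens = backend.encode("Hello")?;
//! let logits = backend.forward(&tokens, 0)?; // logits used to pick the next token
//! println!("{}", backend.decode(&tokens)?);
//! ```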

use std::path::Path;

use candle_core::quantized::gguf_file;
use candle_core::{Device as CandleDevice, Tensor};
use candle_transformers::models::quantized_qwen3 as qwen;
use tokenizers::Tokenizer;

use super::moe::Qwen3MoeModel;
use crate::{Device, InferenceError};

/// Loaded model — either standard Qwen3 or MoE variant.
enum QwenModel {
    Standard(qwen::ModelWeights),
    /// MoE model using our naive (non-fused) implementation that works on Metal/CPU.
    Moe(Qwen3MoeModel),
}

/// A loaded model ready for inference.
pub struct CandleBackend {
    model: QwenModel,
    pub tokenizer: Tokenizer,
    pub device: CandleDevice,
}

impl CandleBackend {
    /// Load a GGUF model + tokenizer from a model directory.
    ///
    /// Expects `model.gguf` and `tokenizer.json` in `model_dir`.
    /// Auto-detects standard vs MoE architecture from GGUF metadata.
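    ///
    /// A load sketch (the directory name is illustrative):
    ///
    /// ```ignore
    /// // models/qwen3-4b/
    /// //   model.gguf
    /// //   tokenizer.json
    /// let backend = CandleBackend::load(Path::new("models/qwen3-4b"), Device::Metal)?;
    /// ```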
    pub fn load(model_dir: &Path, device: Device) -> Result<Self, InferenceError> {
        let candle_device = to_candle_device(device)?;

        // Load GGUF weights
        let model_path = model_dir.join("model.gguf");
        let mut file = std::fs::File::open(&model_path)
            .map_err(|e| InferenceError::InferenceFailed(format!("open model: {e}")))?;
        let gguf = gguf_file::Content::read(&mut file)
            .map_err(|e| InferenceError::InferenceFailed(format!("read gguf: {e}")))?;

        // Detect architecture from GGUF metadata
        let arch = gguf
            .metadata
            .get("general.architecture")
            .and_then(|v| v.to_string().ok())
            .map(|s| s.to_string())
            .unwrap_or_default();

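        // GGUF converters typically report "qwen3" for the dense models and
        // "qwen3moe" for the MoE variants; anything unrecognized falls through
        // to the standard loader below.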
        let model = if arch == "qwen3moe" {
            let moe = Qwen3MoeModel::from_gguf(gguf, &mut file, &candle_device)
                .map_err(|e| InferenceError::InferenceFailed(format!("load moe weights: {e}")))?;
            QwenModel::Moe(moe)
        } else {
            let std = qwen::ModelWeights::from_gguf(gguf, &mut file, &candle_device)
                .map_err(|e| InferenceError::InferenceFailed(format!("load weights: {e}")))?;
            QwenModel::Standard(std)
        };

        // Load tokenizer
        let tokenizer_path = model_dir.join("tokenizer.json");
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(format!("load tokenizer: {e}")))?;

        Ok(Self {
            model,
            tokenizer,
            device: candle_device,
        })
    }

    /// Clear the KV cache so the next forward pass starts fresh.
    pub fn clear_kv_cache(&mut self) {
        match &mut self.model {
            QwenModel::Standard(m) => m.clear_kv_cache(),
            QwenModel::Moe(m) => m.clear_kv_cache(),
        }
    }

    /// Run a forward pass for a sequence of token IDs. Returns logits.
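    ///
    /// A rough generation-loop sketch (`sample_greedy` is a hypothetical argmax
    /// over the vocabulary dimension, not part of this crate):
    ///
    /// ```ignore
    /// backend.clear_kv_cache();
    /// let mut tokens = backend.encode(prompt)?;
    /// let mut pos = 0;
    /// for _ in 0..max_new_tokens {
    ///     // Prefill the whole prompt on the first pass, then feed one token at a time.
    ///     let logits = backend.forward(&tokens[pos..], pos)?;
    ///     pos = tokens.len();
    ///     let next = sample_greedy(&logits)?;
    ///     if Some(next) == backend.eos_token_id() {
    ///         break;
    ///     }
    ///     tokens.push(next);
    /// }
    /// println!("{}", backend.decode(&tokens)?);
    /// ```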
    pub fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Tensor, InferenceError> {
        let input = Tensor::new(tokens, &self.device)
            .map_err(|e| InferenceError::InferenceFailed(format!("tensor: {e}")))?
            .unsqueeze(0)
            .map_err(|e| InferenceError::InferenceFailed(format!("unsqueeze: {e}")))?;

        match &mut self.model {
            QwenModel::Standard(m) => m
                .forward(&input, pos)
                .map_err(|e| InferenceError::InferenceFailed(format!("forward: {e}"))),
            QwenModel::Moe(m) => m
                .forward(&input, pos)
                .map_err(|e| InferenceError::InferenceFailed(format!("forward moe: {e}"))),
        }
    }

    /// Encode text to token IDs.
    pub fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        Ok(encoding.get_ids().to_vec())
    }

    /// Decode token IDs back to text.
    pub fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError> {
        self.tokenizer
            .decode(tokens, true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))
    }

    /// Encode text to token IDs *without* adding tokenizer special tokens
    /// (BOS, etc.). Pair with [`Self::detokenize_raw`] for the round-trip
    /// property `detokenize_raw(tokenize_raw(s)) == s` that downstream
    /// validation harnesses (e.g. tokhn) check.
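    ///
    /// A round-trip check sketch (assumes an already loaded `backend`):
    ///
    /// ```ignore
    /// let ids = backend.tokenize_raw("fn main() {}")?;
    /// assert_eq!(backend.detokenize_raw(&ids)?, "fn main() {}");
    /// ```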
    pub fn tokenize_raw(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        Ok(encoding.get_ids().to_vec())
    }

    /// Decode token IDs back to text *without* skipping special tokens, so
    /// the caller sees exactly what's in the token sequence (matching the
    /// raw-tokenize path).
    pub fn detokenize_raw(&self, tokens: &[u32]) -> Result<String, InferenceError> {
        self.tokenizer
            .decode(tokens, false)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))
    }

    /// Get the EOS token ID.
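    ///
    /// Tries the Qwen-style `<|endoftext|>` token first, then falls back to `</s>`.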
    pub fn eos_token_id(&self) -> Option<u32> {
        self.tokenizer
            .token_to_id("<|endoftext|>")
            .or_else(|| self.tokenizer.token_to_id("</s>"))
    }

    /// Look up any token's ID by string.
    pub fn token_id(&self, token: &str) -> Option<u32> {
        self.tokenizer.token_to_id(token)
    }

    /// Get the model's maximum context length (tokens).
    ///
    /// Currently hard-coded to the Qwen3 default of 32768; GGUF metadata is
    /// not consulted.
    pub fn context_length(&self) -> Option<usize> {
        // Qwen3 models default to 32768 context
        Some(32768)
    }
}

/// Convert our Device enum to candle's Device. Public for use by vision backend.
pub fn to_candle_device_pub(device: Device) -> Result<CandleDevice, InferenceError> {
    to_candle_device(device)
}

fn to_candle_device(device: Device) -> Result<CandleDevice, InferenceError> {
    match device {
        Device::Cpu => Ok(CandleDevice::Cpu),
        Device::Metal => {
            #[cfg(feature = "metal")]
            {
                CandleDevice::new_metal(0)
                    .map_err(|e| InferenceError::DeviceError(format!("metal: {e}")))
            }
            #[cfg(not(feature = "metal"))]
            {
                Err(InferenceError::DeviceError(
                    "metal feature not enabled".to_string(),
                ))
            }
        }
        Device::Cuda(ordinal) => {
            #[cfg(feature = "cuda")]
            {
                CandleDevice::new_cuda(ordinal)
                    .map_err(|e| InferenceError::DeviceError(format!("cuda({ordinal}): {e}")))
            }
            #[cfg(not(feature = "cuda"))]
            {
                let _ = ordinal;
                Err(InferenceError::DeviceError(
                    "cuda feature not enabled".to_string(),
                ))
            }
        }
    }
}