car-inference 0.29.0

Local model inference for CAR — Candle backend with Qwen3 models
//! Local in-process inference backends — the local mirror of the remote
//! [`ProtocolHandler`](crate::protocol::ProtocolHandler) abstraction.
//!
//! Remote models dispatch cleanly through `ProtocolHandler` + `handler_for`.
//! Local (in-process) models historically did not: the engine hand-dispatched
//! a single hardcoded `Qwen3Model` via ad-hoc `cfg` + tag matching. This module
//! introduces the missing seam so each architecture is one isolated backend
//! impl instead of another arm in `generate_tracked_inner`.
//!
//! Two layers, on purpose:
//!
//! - [`TextDecoder`] — the token-level primitives (encode / forward / decode /
//!   eos / context / cache). The engine's shared decode loop (`drive_generation`)
//!   runs sampling, stop detection, TTFT timing, and the Metal panic-catch +
//!   cache-invalidation over `&mut dyn TextDecoder`, so that engine-state-coupled
//!   logic lives in ONE place and every backend reuses it.
//! - [`LocalInferenceBackend`] — the higher-level surface (capability contract,
//!   prompt rendering, tool-call parsing). A `LocalInferenceBackend` IS a
//!   `TextDecoder` (supertrait); the engine upcasts to drive the loop.
//!
//! Adding an architecture = implement both traits for a new backend struct and
//! add one arm to [`local_backend_for`]. No engine-dispatch surgery.

use crate::schema::ModelCapability;
use crate::tasks::generate::{parse_tool_calls, render_chat_prompt, GenerateRequest, ToolCall};
use crate::InferenceError;

/// Token-level primitives that every in-process text backend provides.
///
/// Logits come back from [`forward`](TextDecoder::forward) as a plain
/// `Vec<f32>`: that is the single shape the shared MLX sampler
/// (`sample_from_logits`) consumes, so a backend whose native forward output
/// is something else (Candle yields a `Tensor`) converts inside its impl.
pub trait TextDecoder: Send {
    /// Tokenize `text` into ids, including the tokenizer's special tokens.
    fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError>;

    /// Turn token ids back into text.
    fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError>;

    /// Run one prefill/decode step over `tokens` starting at position `pos`
    /// and return the logits of the last position.
    fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Vec<f32>, InferenceError>;

    /// Every token id that ends generation: the model's `eos` together with
    /// any control tokens that close a chat turn. The shared loop stops on any
    /// id in this list, so it never has to know an architecture's eos rules.
    fn eos_ids(&self) -> Vec<u32>;

    /// Size of the context window in tokens, consulted by the
    /// prompt-truncation guard.
    fn context_length(&self) -> usize;

    /// Discard the KV cache so the next, independent generation starts fresh.
    fn clear_kv_cache(&mut self);

    /// Get the KV cache ready to prefill `prompt_tokens`, keeping whatever
    /// already-cached prefix matches (prompt/prefix caching).
    ///
    /// Returns the offset at which prefill should begin. `prompt_tokens[..offset]`
    /// is already cached — KV for a fixed token prefix at fixed positions is
    /// deterministic, hence identical — so the caller prefills only
    /// `prompt_tokens[offset..]`, at position `offset`.
    ///
    /// By default nothing is reused: the cache is cleared and `0` is returned,
    /// forcing a full re-prefill. A backend opts into cross-call reuse by
    /// overriding this and tracking which tokens its cache holds. The payoff
    /// is largest in multi-turn agent loops, where every turn re-sends the
    /// whole, growing conversation.
    fn begin_prompt(&mut self, prompt_tokens: &[u32]) -> usize {
        // The default ignores the prompt: wipe everything and restart at 0.
        self.clear_kv_cache();
        let _ = prompt_tokens;
        0
    }
}

/// Result of one engine-driven decode pass.
///
/// Derives `Debug` (for tracing/logging), `Clone`, and equality so results can
/// be inspected, copied into responses, and compared in tests; every field is
/// a plain owned value, so all derives are cheap and total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalGeneration {
    pub text: String,
    pub ttft_ms: Option<u64>,
    /// `"stop"` (hit an eos id) or `"length"` (hit max_tokens).
    pub stop_reason: Option<String>,
    /// Prompt tokens fed to the model (post context-window truncation), so the
    /// in-process path reports real `TokenUsage` like the remote providers do —
    /// without it, decode-throughput (tokens/sec) is unmeasurable (it was always
    /// 0 for local models).
    pub prompt_tokens: usize,
    /// Tokens the model generated this pass.
    pub completion_tokens: usize,
}

/// Outcome of the shared decode loop. Distinguishes a normal failure (the
/// backend is still usable) from one where a panic crossed the compute/FFI
/// boundary and left the backend in an indeterminate state — the caller, which
/// owns the cache lock, must then evict it.
///
/// Derives `Debug` so a `Result<_, DriveError>` can be `unwrap`/`expect`ed and
/// logged with `{:?}` like any other error-carrying result.
#[derive(Debug)]
pub enum DriveError {
    /// Normal failure (encode/decode/sample). Backend remains usable.
    Recoverable(InferenceError),
    /// A panic was caught mid-forward. The caller MUST drop its guard and
    /// invalidate the backend from the cache before returning.
    BackendCorrupted(InferenceError),
}

impl DriveError {
    /// `true` when the backend must be evicted from the cache
    /// ([`DriveError::BackendCorrupted`]); lets callers branch on the eviction
    /// requirement without destructuring the error.
    pub fn is_backend_corrupted(&self) -> bool {
        matches!(self, DriveError::BackendCorrupted(_))
    }

    /// Discard the recoverable/corrupted distinction and return the
    /// underlying [`InferenceError`].
    pub fn into_inner(self) -> InferenceError {
        match self {
            DriveError::Recoverable(e) | DriveError::BackendCorrupted(e) => e,
        }
    }
}

/// What a backend claims *without* loading weights — consulted by the registry
/// gate (does a backend exist for this `model_type`?) and dispatch. Replaces the
/// flat `KNOWN_LLM_TYPES` whitelist as the source of truth.
///
/// Every field is a `&'static` borrow, so the descriptor is `Copy` and can live
/// in a `const`/`static` table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackendDescriptor {
    pub backend_name: &'static str,
    /// `config.json` `model_type` strings this backend services.
    pub model_types: &'static [&'static str],
}

impl BackendDescriptor {
    /// Whether this backend services the given `config.json` `model_type`.
    ///
    /// The comparison is ASCII case-insensitive, so callers may pass either the
    /// raw `model_type` or an already-lowercased one.
    pub fn supports_model_type(&self, model_type: &str) -> bool {
        self.model_types
            .iter()
            .any(|known| known.eq_ignore_ascii_case(model_type))
    }
}

/// In-process counterpart of [`ProtocolHandler`](crate::protocol::ProtocolHandler).
///
/// Each architecture family gets one impl, which the engine loads lazily and
/// caches. Every implementor is also a [`TextDecoder`] (supertrait), which is
/// what the engine upcasts to when driving the shared decode loop.
pub trait LocalInferenceBackend: TextDecoder {
    /// Stable identifier for tracing and unsupported-mode messages
    /// (for example `"native-mlx-qwen3"`).
    fn backend_name(&self) -> &'static str;

    /// The execution contract: what this particular loaded checkpoint can
    /// actually service, independent of whatever the registry claims for
    /// routing. Mirrors `MlxBackend::supports_capability`.
    fn supports_capability(&self, cap: ModelCapability) -> bool;

    /// Turn a request into the model's wire-format prompt string.
    ///
    /// Defaults to the hardcoded Qwen3 chat format; backends that carry their
    /// own template (gemma) or use the data-driven `ChatTemplate` path
    /// override this.
    fn render_prompt(&self, req: &GenerateRequest) -> Result<String, InferenceError> {
        let prompt = render_chat_prompt(req);
        Ok(prompt)
    }

    /// Extract tool calls from generated text.
    ///
    /// Defaults to the Qwen Hermes `<tool_call>{json}</tool_call>` convention;
    /// architectures with a different convention (gemma's
    /// `<|tool_call>…<tool_call|>`) override this.
    fn parse_tool_calls(&self, text: &str) -> (String, Vec<ToolCall>) {
        // A bare-path call names the free function imported from
        // `crate::tasks::generate`, not this method (a method would need `self.`).
        parse_tool_calls(text)
    }
}

// ── MLX (Apple Silicon) ──────────────────────────────────────────────────────

#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
mod mlx_impl {
    // Trait impls for the hand-written MLX backends. Each token-level method
    // forwards to the backend's inherent method of the same name using the
    // `Type::method(self, ..)` path form, which rustc resolves to the inherent
    // method when one exists. If an inherent method were ever removed, that
    // path would resolve to the trait method instead and recurse — keep the
    // inherent APIs in sync with these forwards.
    use super::*;
    use crate::backend::mlx_gemma4::{parse_gemma4_tool_calls, Gemma4Backend};
    use crate::backend::MlxBackend;

    // ── Qwen3 (MlxBackend) ───────────────────────────────────────────────────

    impl TextDecoder for MlxBackend {
        fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
            MlxBackend::encode(self, text)
        }

        fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError> {
            MlxBackend::decode(self, tokens)
        }

        fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Vec<f32>, InferenceError> {
            MlxBackend::forward(self, tokens, pos)
        }

        fn eos_ids(&self) -> Vec<u32> {
            // Qwen3 stops on the config eos and on the chat-turn terminator
            // `<|im_end|>`; the latter is only listed when it is a distinct id.
            let mut stop_ids: Vec<u32> = self.eos_token_id().into_iter().collect();
            if let Some(im_end) = self.token_id("<|im_end|>") {
                if !stop_ids.contains(&im_end) {
                    stop_ids.push(im_end);
                }
            }
            stop_ids
        }

        fn context_length(&self) -> usize {
            MlxBackend::context_length(self)
        }

        fn clear_kv_cache(&mut self) {
            MlxBackend::clear_kv_cache(self)
        }

        // begin_prompt: no override, so the trait default (clear + full
        // re-prefill) applies.
    }

    impl LocalInferenceBackend for MlxBackend {
        fn backend_name(&self) -> &'static str {
            "native-mlx-qwen3"
        }

        fn supports_capability(&self, cap: ModelCapability) -> bool {
            MlxBackend::supports_capability(self, cap)
        }

        // render_prompt / parse_tool_calls: the trait's Qwen3 defaults are
        // already the right behavior for this backend.
    }

    // ── Gemma 4 (gemma4_unified text) ────────────────────────────────────────

    impl TextDecoder for Gemma4Backend {
        fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
            Gemma4Backend::encode(self, text)
        }

        fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError> {
            Gemma4Backend::decode(self, tokens)
        }

        fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Vec<f32>, InferenceError> {
            Gemma4Backend::forward(self, tokens, pos)
        }

        fn eos_ids(&self) -> Vec<u32> {
            Gemma4Backend::eos_token_ids(self)
        }

        fn context_length(&self) -> usize {
            Gemma4Backend::context_length(self)
        }

        fn clear_kv_cache(&mut self) {
            Gemma4Backend::clear_kv_cache(self)
        }

        // Gemma 4 implements its own prefix reuse, so forward to it instead of
        // taking the clear-everything trait default.
        fn begin_prompt(&mut self, prompt_tokens: &[u32]) -> usize {
            Gemma4Backend::begin_prompt(self, prompt_tokens)
        }
    }

    impl LocalInferenceBackend for Gemma4Backend {
        fn backend_name(&self) -> &'static str {
            "native-mlx-gemma4"
        }

        fn supports_capability(&self, cap: ModelCapability) -> bool {
            use ModelCapability as C;
            // Exhaustive on purpose: a newly added capability must be
            // classified here explicitly. Only the text path is loaded for now
            // (vision/audio towers are not), so media capabilities are off.
            match cap {
                C::Rerank
                | C::Embed
                | C::Grounding
                | C::Vision
                | C::VideoUnderstanding
                | C::AudioUnderstanding
                | C::SpeechToText
                | C::TextToSpeech
                | C::ImageGeneration
                | C::VideoGeneration => false,
                C::Generate
                | C::ToolUse
                | C::MultiToolCall
                | C::Reasoning
                | C::Summarize
                | C::Code
                | C::Classify => true,
            }
        }

        fn render_prompt(&self, req: &GenerateRequest) -> Result<String, InferenceError> {
            // Prefer the checkpoint's own chat template; fall back to the
            // built-in renderer when none is available.
            if let Some(template) = self.chat_template() {
                template.render_request(req)
            } else {
                Ok(render_chat_prompt(req))
            }
        }

        fn parse_tool_calls(&self, text: &str) -> (String, Vec<ToolCall>) {
            parse_gemma4_tool_calls(text)
        }
    }
}

// Outside the MLX build (the `cfg` gate above: Apple Silicon macOS without
// `car_skip_mlx`), the Candle backend deliberately does NOT implement these
// traits: it is already a single generic GGUF loader, so it has no multi-arch
// dispatch problem to solve. This abstraction targets the MLX path, where each
// architecture (Qwen3, Gemma 4, …) is a distinct hand-written backend that
// would otherwise need its own arm in the engine's dispatch.

// ── Dispatch (macOS MLX) ─────────────────────────────────────────────────────

/// Read a local model's `config.json` `model_type`, lowercased.
///
/// This is the authoritative architecture signal — the same field the registry
/// gates on.
///
/// # Errors
///
/// Returns `InferenceError::InferenceFailed` if `config.json` cannot be read,
/// is not valid JSON, or has no string-valued `model_type` field.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn read_model_type(model_dir: &std::path::Path) -> Result<String, InferenceError> {
    let cfg_path = model_dir.join("config.json");
    let raw = std::fs::read_to_string(&cfg_path).map_err(|e| {
        InferenceError::InferenceFailed(format!("read {}: {e}", cfg_path.display()))
    })?;
    let cfg: serde_json::Value = serde_json::from_str(&raw).map_err(|e| {
        InferenceError::InferenceFailed(format!("parse {}: {e}", cfg_path.display()))
    })?;
    // Report a missing/non-string `model_type` here rather than mapping it to
    // "", which would only surface later as a misleading
    // "no in-process MLX backend for model_type ''" dispatch error.
    let model_type = cfg
        .get("model_type")
        .and_then(serde_json::Value::as_str)
        .ok_or_else(|| {
            InferenceError::InferenceFailed(format!(
                "{}: missing string field `model_type`",
                cfg_path.display()
            ))
        })?;
    Ok(model_type.to_ascii_lowercase())
}

/// Construct the loaded in-process backend for `model_dir`, chosen by its
/// `config.json` `model_type` — the local mirror of
/// [`handler_for`](crate::protocol::handler_for).
///
/// Qwen3 has no arm here by design: it stays on the engine's dedicated
/// `MlxBackend` path, which also backs streaming / tokenize / embeddings. This
/// dispatch only covers the *additional* architectures that path does not
/// serve; new arms are added as their backends land (currently Gemma 4).
///
/// # Errors
///
/// Propagates failures from reading `config.json` or loading the backend, and
/// returns `InferenceError::InferenceFailed` for any `model_type` without an
/// arm here (Qwen3 included).
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
pub fn local_backend_for(
    model_dir: &std::path::Path,
) -> Result<Box<dyn LocalInferenceBackend>, InferenceError> {
    use crate::backend::mlx_gemma4::Gemma4Backend;

    let model_type = read_model_type(model_dir)?;
    match model_type.as_str() {
        "gemma4_unified" | "gemma4_unified_text" => {
            let backend = Gemma4Backend::load(model_dir)?;
            Ok(Box::new(backend))
        }
        other => Err(InferenceError::InferenceFailed(format!(
            "no in-process MLX backend for model_type '{other}' ({}); Qwen3 uses the \
             dedicated native path and other architectures route to vLLM-MLX",
            model_dir.display()
        ))),
    }
}