nornir 0.4.31

Companion to cargo: dependency tracking, release gating, deploy, benchmarks, and documentation assembly. Project-agnostic.
//! `candle` generative backend (`gen-candle`) — pure-Rust generation via
//! [`candle-core`] + [`candle-transformers`].
//!
//! Loads a **quantized GGUF** Qwen2-family model (weights + tokenizer fetched
//! once from the HF hub into the candle cache) and runs a real prefill → sample
//! → decode loop with a [`LogitsProcessor`]. No C dependency: candle is pure
//! Rust on the CPU back-end.
//!
//! ## Model spec
//! The factory spec `candle:<model>` selects a built-in [`Preset`]; an empty
//! model (`candle:`) uses the default ([`Preset::Qwen2_0_5b`], the smallest). A
//! preset names the HF repo + GGUF file + tokenizer repo so `new` is enough to
//! know what to fetch; the heavy fetch+load happens lazily on the first
//! [`complete`](crate::warehouse::generator::Generator::complete) so constructing
//! the generator (and probing [`available`]) is cheap.
//!
//! ## `available()`
//! Reports `true` when the candle cache already holds this preset's GGUF +
//! tokenizer (an offline probe — no network, no model load). A fresh machine
//! reports `false` until the first online `complete` populates the cache.

use std::path::PathBuf;
use std::sync::Mutex;

use anyhow::{anyhow, Context, Result};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_qwen2::ModelWeights;
use tokenizers::Tokenizer;

use super::{Backend, GenAnswer, GenRequest, Generator};

/// A built-in quantized model preset: where to fetch the GGUF + tokenizer.
#[derive(Debug, Clone, Copy)]
struct Preset {
    /// The factory id (`candle:<id>`).
    id: &'static str,
    /// HF repo holding the GGUF.
    gguf_repo: &'static str,
    /// GGUF filename inside the repo.
    gguf_file: &'static str,
    /// HF repo holding `tokenizer.json` (the unquantized base).
    tokenizer_repo: &'static str,
}

/// The supported presets. Small first so a CI/dev fetch is cheap. All are Qwen2
/// family, which `quantized_qwen2::ModelWeights` loads directly.
const PRESETS: &[Preset] = &[
    Preset {
        id: "qwen2-0.5b",
        gguf_repo: "Qwen/Qwen2-0.5B-Instruct-GGUF",
        gguf_file: "qwen2-0_5b-instruct-q4_0.gguf",
        tokenizer_repo: "Qwen/Qwen2-0.5B-Instruct",
    },
    Preset {
        id: "qwen2-1.5b",
        gguf_repo: "Qwen/Qwen2-1.5B-Instruct-GGUF",
        gguf_file: "qwen2-1_5b-instruct-q4_0.gguf",
        tokenizer_repo: "Qwen/Qwen2-1.5B-Instruct",
    },
];

/// The default preset when the spec is `candle:` (smallest).
const DEFAULT_PRESET: &Preset = &PRESETS[0];

fn resolve_preset(model: &str) -> Result<&'static Preset> {
    if model.is_empty() {
        return Ok(DEFAULT_PRESET);
    }
    PRESETS
        .iter()
        .find(|p| p.id == model)
        .ok_or_else(|| {
            let ids: Vec<&str> = PRESETS.iter().map(|p| p.id).collect();
            anyhow!("candle: unknown model `{model}` — known presets: {}", ids.join(", "))
        })
}

/// A loaded model + tokenizer (the expensive state), built lazily.
struct Loaded {
    model: ModelWeights,
    tokenizer: Tokenizer,
    device: Device,
}

/// The candle generator. Holds its preset + a lazily-loaded model behind a mutex
/// (generation mutates KV-cache state, so it needs `&mut`; the mutex makes the
/// generator `Sync` for the bake-off/server).
pub struct CandleGenerator {
    id: String,
    preset: &'static Preset,
    loaded: Mutex<Option<Loaded>>,
}

impl CandleGenerator {
    /// Build the generator for `model` (a preset id, or empty for the default).
    /// Does NOT fetch or load weights — that happens on first `complete`.
    pub fn new(model: &str) -> Result<Self> {
        let preset = resolve_preset(model)?;
        Ok(Self {
            id: format!("candle:{}", preset.id),
            preset,
            loaded: Mutex::new(None),
        })
    }

    /// The candle/HF cache file for this preset's GGUF, if already on disk.
    fn cached_gguf(&self) -> Option<PathBuf> {
        cached_hub_file(self.preset.gguf_repo, self.preset.gguf_file)
    }

    /// The cache file for this preset's tokenizer, if already on disk.
    fn cached_tokenizer(&self) -> Option<PathBuf> {
        cached_hub_file(self.preset.tokenizer_repo, "tokenizer.json")
    }

    /// Fetch (if needed) + load the model & tokenizer into [`Loaded`].
    fn load(&self) -> Result<Loaded> {
        let api = hf_hub::api::sync::Api::new().context("candle: hf-hub api init")?;
        let gguf_path = api
            .model(self.preset.gguf_repo.to_string())
            .get(self.preset.gguf_file)
            .context("candle: fetch GGUF weights")?;
        let tok_path = api
            .model(self.preset.tokenizer_repo.to_string())
            .get("tokenizer.json")
            .context("candle: fetch tokenizer.json")?;

        let device = Device::Cpu;
        let mut file = std::fs::File::open(&gguf_path)
            .with_context(|| format!("candle: open {}", gguf_path.display()))?;
        let content = gguf_file::Content::read(&mut file)
            .map_err(|e| anyhow!("candle: read GGUF: {e}"))?;
        let model = ModelWeights::from_gguf(content, &mut file, &device)
            .map_err(|e| anyhow!("candle: build model from GGUF: {e}"))?;
        let tokenizer =
            Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("candle: load tokenizer: {e}"))?;
        Ok(Loaded { model, tokenizer, device })
    }

    /// Build the chat-formatted prompt for Qwen2 (system + user turns).
    fn format_prompt(req: &GenRequest) -> String {
        let mut s = String::new();
        if let Some(sys) = &req.system {
            s.push_str(&format!("<|im_start|>system\n{sys}<|im_end|>\n"));
        }
        s.push_str(&format!(
            "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
            req.prompt
        ));
        s
    }
}

impl Backend for CandleGenerator {
    fn id(&self) -> &str {
        &self.id
    }

    /// Available when the GGUF + tokenizer are already cached locally (offline
    /// probe). A fresh machine reports `false` until the first online `complete`.
    fn available(&self) -> bool {
        self.cached_gguf().is_some() && self.cached_tokenizer().is_some()
    }
}

impl Generator for CandleGenerator {
    fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
        let started = std::time::Instant::now();
        let mut guard = self.loaded.lock().expect("candle loaded mutex");
        if guard.is_none() {
            *guard = Some(self.load()?);
        }
        let Loaded { model, tokenizer, device } = guard.as_mut().expect("loaded set above");

        let prompt = Self::format_prompt(req);
        let encoding =
            tokenizer.encode(prompt, true).map_err(|e| anyhow!("candle: encode prompt: {e}"))?;
        let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
        let tokens_in = prompt_tokens.len() as i64;

        let mut logits_processor = LogitsProcessor::new(
            42,
            if req.temperature > 0.0 { Some(req.temperature as f64) } else { None },
            None,
        );

        // Prefill the prompt, then sample one token at a time.
        let mut all_tokens: Vec<u32> = Vec::new();
        let input = Tensor::new(prompt_tokens.as_slice(), device)
            .map_err(|e| anyhow!("candle: prompt tensor: {e}"))?
            .unsqueeze(0)
            .map_err(|e| anyhow!("candle: unsqueeze: {e}"))?;
        let mut logits = model
            .forward(&input, 0)
            .map_err(|e| anyhow!("candle: prefill forward: {e}"))?;
        logits = logits
            .squeeze(0)
            .map_err(|e| anyhow!("candle: squeeze logits: {e}"))?;
        let eos = tokenizer.token_to_id("<|im_end|>").unwrap_or(u32::MAX);

        let mut next = logits_processor
            .sample(&logits)
            .map_err(|e| anyhow!("candle: sample: {e}"))?;
        // `index` is the KV-cache position fed to `forward`, not a plain loop
        // counter — it starts past the prompt and advances per decoded token.
        let mut index = prompt_tokens.len();
        #[allow(clippy::explicit_counter_loop)]
        for _ in 0..req.max_tokens {
            if next == eos {
                break;
            }
            all_tokens.push(next);
            let input = Tensor::new(&[next], device)
                .map_err(|e| anyhow!("candle: step tensor: {e}"))?
                .unsqueeze(0)
                .map_err(|e| anyhow!("candle: step unsqueeze: {e}"))?;
            let l = model
                .forward(&input, index)
                .map_err(|e| anyhow!("candle: decode forward: {e}"))?
                .squeeze(0)
                .map_err(|e| anyhow!("candle: decode squeeze: {e}"))?;
            next = logits_processor.sample(&l).map_err(|e| anyhow!("candle: sample: {e}"))?;
            index += 1;

            // Honor stop sequences against the running decode.
            if !req.stop.is_empty() {
                let so_far = tokenizer
                    .decode(&all_tokens, true)
                    .map_err(|e| anyhow!("candle: decode: {e}"))?;
                if req.stop.iter().any(|s| so_far.contains(s)) {
                    break;
                }
            }
        }

        let text = tokenizer
            .decode(&all_tokens, true)
            .map_err(|e| anyhow!("candle: final decode: {e}"))?;
        let latency_ms = started.elapsed().as_secs_f64() * 1000.0;
        let tokens_out = all_tokens.len() as i64;
        let tokens_per_s = if latency_ms > 0.0 {
            tokens_out as f64 / (latency_ms / 1000.0)
        } else {
            0.0
        };
        Ok(GenAnswer { text, tokens_in, tokens_out, tokens_per_s, latency_ms })
    }
}

/// Best-effort offline lookup of an HF-hub-cached file (no network). Returns the
/// path only if it already exists on disk, so `available()` never reaches out.
fn cached_hub_file(repo: &str, file: &str) -> Option<PathBuf> {
    let api = hf_hub::api::sync::Api::new().ok()?;
    let cached = api.model(repo.to_string()).get(file);
    // `get` for a cached file returns the local path without downloading only
    // when offline mode is set; to stay strictly offline we instead check the
    // path the cache would use. hf-hub exposes the cache via `Cache`.
    match cached {
        Ok(p) if p.exists() => Some(p),
        _ => {
            // Fall back to the cache's own path resolution (no download).
            let cache = hf_hub::Cache::default();
            cache.model(repo.to_string()).get(file).filter(|p| p.exists())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unknown_preset_errors_with_known_list() {
        let err = match CandleGenerator::new("no-such-model") {
            Ok(_) => panic!("unknown preset must error"),
            Err(e) => e.to_string(),
        };
        assert!(err.contains("unknown model"), "{err}");
        assert!(err.contains("qwen2-0.5b"), "lists presets: {err}");
    }

    #[test]
    fn default_preset_when_empty() {
        let gen = CandleGenerator::new("").unwrap();
        assert_eq!(gen.id(), "candle:qwen2-0.5b");
    }

    #[test]
    fn constructs_and_reports_availability_without_loading() {
        // Constructing must not fetch/load — only `available()` probes the cache.
        let gen = CandleGenerator::new("qwen2-0.5b").unwrap();
        assert_eq!(gen.id(), "candle:qwen2-0.5b");
        // available() is a pure offline probe: it returns a bool either way and
        // must not panic or block. (false on a machine without the model cached.)
        let _ = gen.available();
    }

    /// Heavy: fetches + loads the real model and generates. Network + multi-GB.
    /// Gated `#[ignore]` so the default `cargo test` stays offline + fast.
    #[test]
    #[ignore = "downloads a multi-GB GGUF model"]
    fn real_generation_round_trips() {
        let gen = CandleGenerator::new("qwen2-0.5b").unwrap();
        let req = GenRequest::new("Reply with the single word: pong").with_max_tokens(8);
        let ans = gen.complete(&req).unwrap();
        assert!(ans.tokens_in > 0);
        assert!(ans.tokens_out > 0);
        assert!(!ans.text.is_empty());
    }
}