nornir 0.4.41

Companion to cargo: dependency tracking, release gating, deploy, benchmarks, and documentation assembly. Project-agnostic.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
//! `candle` generative backend (`gen-candle`) — pure-Rust generation via
//! [`candle-core`](https://docs.rs/candle-core) +
//! [`candle-transformers`](https://docs.rs/candle-transformers).
//!
//! Loads a **quantized GGUF** Qwen2-family model (weights + tokenizer fetched
//! once from the HF hub into the candle cache) and runs a real prefill → sample
//! → decode loop with a [`LogitsProcessor`]. No C dependency: candle is pure
//! Rust on the CPU back-end.
//!
//! ## Model spec
//! The factory spec `candle:<model>` selects a built-in [`Preset`] by id; an
//! empty model (`candle:`) uses the default preset, `qwen2-0.5b` (the smallest).
//! A preset names the HF repo + GGUF file + tokenizer repo, so `new` is enough to
//! know what to fetch; the heavy fetch+load happens lazily on the first
//! [`complete`](crate::warehouse::generator::Generator::complete), so constructing
//! the generator (and probing [`available`](Backend::available)) is cheap.
//!
//! ## `available()`
//! Reports `true` when the candle cache already holds this preset's GGUF +
//! tokenizer (an offline probe — no network, no model load). A fresh machine
//! reports `false` until the first online `complete` populates the cache.

use std::path::PathBuf;
use std::sync::Mutex;

use anyhow::{anyhow, Context, Result};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_qwen2::ModelWeights;
use tokenizers::Tokenizer;

use super::{Backend, GenAnswer, GenRequest, Generator};

/// A built-in quantized model preset: where to fetch the GGUF + tokenizer.
///
/// Pure static data; `CandleGenerator::new` resolves one of these by id and
/// `load` uses its repo/file names for the HF hub fetch.
#[derive(Debug, Clone, Copy)]
struct Preset {
    /// The factory id (`candle:<id>`).
    id: &'static str,
    /// HF repo holding the GGUF.
    gguf_repo: &'static str,
    /// GGUF filename inside the repo.
    gguf_file: &'static str,
    /// HF repo holding `tokenizer.json` (the unquantized base).
    tokenizer_repo: &'static str,
}

/// The supported presets. Small first so a CI/dev fetch is cheap. All are Qwen2
/// family, which `quantized_qwen2::ModelWeights` loads directly.
///
/// Order matters: [`DEFAULT_PRESET`] is the first entry, so keep the smallest
/// model at index 0.
const PRESETS: &[Preset] = &[
    Preset {
        id: "qwen2-0.5b",
        gguf_repo: "Qwen/Qwen2-0.5B-Instruct-GGUF",
        gguf_file: "qwen2-0_5b-instruct-q4_0.gguf",
        tokenizer_repo: "Qwen/Qwen2-0.5B-Instruct",
    },
    Preset {
        id: "qwen2-1.5b",
        gguf_repo: "Qwen/Qwen2-1.5B-Instruct-GGUF",
        gguf_file: "qwen2-1_5b-instruct-q4_0.gguf",
        tokenizer_repo: "Qwen/Qwen2-1.5B-Instruct",
    },
];

/// The default preset when the spec is `candle:` (smallest).
/// Indexes `PRESETS[0]`, so reordering [`PRESETS`] changes the default.
const DEFAULT_PRESET: &Preset = &PRESETS[0];

/// Map the `<model>` part of a `candle:<model>` spec to its built-in [`Preset`].
///
/// An empty `model` selects [`DEFAULT_PRESET`]; anything else must equal a
/// preset id exactly.
///
/// # Errors
/// An unrecognised id yields an error naming it and listing every known id.
fn resolve_preset(model: &str) -> Result<&'static Preset> {
    match model {
        "" => Ok(DEFAULT_PRESET),
        wanted => match PRESETS.iter().find(|candidate| candidate.id == wanted) {
            Some(preset) => Ok(preset),
            None => {
                let known = PRESETS
                    .iter()
                    .map(|candidate| candidate.id)
                    .collect::<Vec<&str>>()
                    .join(", ");
                Err(anyhow!("candle: unknown model `{wanted}` — known presets: {known}"))
            }
        },
    }
}

/// A loaded model + tokenizer (the expensive state), built lazily.
///
/// Created once by `CandleGenerator::load_from` on the first `complete` and
/// then kept inside the generator's mutex for reuse by later calls.
struct Loaded {
    /// Quantized Qwen2 weights; `forward` needs `&mut` because generation
    /// mutates the KV-cache state.
    model: ModelWeights,
    /// Encodes the chat-formatted prompt and decodes sampled token ids.
    tokenizer: Tokenizer,
    /// Device the tensors are built on (`Device::Cpu` as loaded here).
    device: Device,
}

/// The candle generator. Holds its preset + a lazily-loaded model behind a mutex
/// (generation mutates KV-cache state, so it needs `&mut`; the mutex makes the
/// generator `Sync` for the bake-off/server).
pub struct CandleGenerator {
    /// Factory id reported by `Backend::id`, e.g. `candle:qwen2-0.5b`.
    id: String,
    /// The resolved built-in preset (which repo/files to fetch and load).
    preset: &'static Preset,
    /// `None` until the first `complete` loads the model, then reused. The lock
    /// is held for a whole `complete`, so concurrent generations run one at a time.
    loaded: Mutex<Option<Loaded>>,
}

impl CandleGenerator {
    /// Env var naming an on-disk GGUF to load instead of fetching the preset's.
    const GGUF_ENV: &'static str = "NORNIR_CANDLE_GGUF";
    /// Env var naming an on-disk `tokenizer.json` to load instead of fetching it.
    const TOKENIZER_ENV: &'static str = "NORNIR_CANDLE_TOKENIZER";

    /// Build the generator for `model` (a preset id, or empty for the default).
    /// Does NOT fetch or load weights — that happens on first `complete`.
    ///
    /// # Errors
    /// Returns an error listing the known preset ids when `model` is not one.
    pub fn new(model: &str) -> Result<Self> {
        let preset = resolve_preset(model)?;
        Ok(Self {
            id: format!("candle:{}", preset.id),
            preset,
            loaded: Mutex::new(None),
        })
    }

    /// The candle/HF cache file for this preset's GGUF, if already on disk.
    fn cached_gguf(&self) -> Option<PathBuf> {
        cached_hub_file(self.preset.gguf_repo, self.preset.gguf_file)
    }

    /// The cache file for this preset's tokenizer, if already on disk.
    fn cached_tokenizer(&self) -> Option<PathBuf> {
        cached_hub_file(self.preset.tokenizer_repo, "tokenizer.json")
    }

    /// Fetch (if needed) + load the model & tokenizer into [`Loaded`].
    ///
    /// Path resolution decouples the load PATH from the HF download: an explicit
    /// `NORNIR_CANDLE_GGUF` / `NORNIR_CANDLE_TOKENIZER` (both set) loads from disk
    /// with NO network, so the real GGUF/tokenizer load path can be exercised
    /// against a local fixture (a synthetic GGUF the test writes, or a
    /// pre-downloaded model) without ever reaching the hub. With neither set,
    /// the preset's GGUF + tokenizer are fetched from the HF hub.
    ///
    /// # Errors
    /// Fails when exactly ONE of the two override env vars is set. A partial
    /// override is almost certainly an operator mistake, and silently falling
    /// back to a multi-GB download would hide it. Also fails when the hub fetch
    /// fails or the files cannot be parsed.
    fn load(&self) -> Result<Loaded> {
        let (gguf_path, tok_path) = match Self::override_paths() {
            Some(paths) => paths,
            None => {
                Self::reject_partial_override()?;
                let api = hf_hub::api::sync::Api::new().context("candle: hf-hub api init")?;
                let gguf_path = api
                    .model(self.preset.gguf_repo.to_string())
                    .get(self.preset.gguf_file)
                    .context("candle: fetch GGUF weights")?;
                let tok_path = api
                    .model(self.preset.tokenizer_repo.to_string())
                    .get("tokenizer.json")
                    .context("candle: fetch tokenizer.json")?;
                (gguf_path, tok_path)
            }
        };
        Self::load_from(&gguf_path, &tok_path)
    }

    /// Explicit on-disk override for the GGUF + tokenizer files. Returns `Some`
    /// only when BOTH `NORNIR_CANDLE_GGUF` and `NORNIR_CANDLE_TOKENIZER` are set.
    /// A half-set pair returns `None` here; `load` rejects that case via
    /// `reject_partial_override` instead of quietly downloading.
    fn override_paths() -> Option<(PathBuf, PathBuf)> {
        let gguf = std::env::var_os(Self::GGUF_ENV)?;
        let tok = std::env::var_os(Self::TOKENIZER_ENV)?;
        Some((PathBuf::from(gguf), PathBuf::from(tok)))
    }

    /// Error when exactly one of the override env vars is set (both-or-neither).
    fn reject_partial_override() -> Result<()> {
        let gguf_set = std::env::var_os(Self::GGUF_ENV).is_some();
        let tok_set = std::env::var_os(Self::TOKENIZER_ENV).is_some();
        if gguf_set == tok_set {
            return Ok(());
        }
        let (set, missing) = if gguf_set {
            (Self::GGUF_ENV, Self::TOKENIZER_ENV)
        } else {
            (Self::TOKENIZER_ENV, Self::GGUF_ENV)
        };
        Err(anyhow!(
            "candle: `{set}` is set but `{missing}` is not — set both to load from disk, \
             or neither to fetch the preset from the HF hub"
        ))
    }

    /// Load a GGUF model + tokenizer from explicit on-disk paths. This is the REAL
    /// load path — `gguf_file::Content::read` then `ModelWeights::from_gguf` —
    /// shared by both the hub-download arm and any local/synthetic fixture.
    ///
    /// # Errors
    /// Fails if either file cannot be opened/parsed; messages name the path.
    fn load_from(gguf_path: &std::path::Path, tok_path: &std::path::Path) -> Result<Loaded> {
        let device = Device::Cpu;
        let mut file = std::fs::File::open(gguf_path)
            .with_context(|| format!("candle: open {}", gguf_path.display()))?;
        let content = gguf_file::Content::read(&mut file)
            .map_err(|e| anyhow!("candle: read GGUF {}: {e}", gguf_path.display()))?;
        let model = ModelWeights::from_gguf(content, &mut file, &device)
            .map_err(|e| anyhow!("candle: build model from GGUF: {e}"))?;
        let tokenizer = Tokenizer::from_file(tok_path)
            .map_err(|e| anyhow!("candle: load tokenizer {}: {e}", tok_path.display()))?;
        Ok(Loaded { model, tokenizer, device })
    }

    /// Build the chat-formatted prompt for Qwen2: an optional system turn, the
    /// user turn, then an open assistant turn for the model to complete.
    fn format_prompt(req: &GenRequest) -> String {
        let system_turn = req
            .system
            .as_ref()
            .map(|sys| format!("<|im_start|>system\n{sys}<|im_end|>\n"))
            .unwrap_or_default();
        format!(
            "{system_turn}<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
            req.prompt
        )
    }
}

impl Backend for CandleGenerator {
    /// The factory id this generator was built for, e.g. `candle:qwen2-0.5b`.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Local-cache readiness probe: `true` only when BOTH the preset's GGUF and
    /// its `tokenizer.json` are already cached. The tokenizer is only looked up
    /// once the GGUF is found. A fresh machine reports `false` until the first
    /// online `complete` fills the cache.
    fn available(&self) -> bool {
        if self.cached_gguf().is_none() {
            return false;
        }
        self.cached_tokenizer().is_some()
    }
}

impl Generator for CandleGenerator {
    /// Generate a completion for `req`, loading the model on first use.
    ///
    /// One prefill pass over the chat-formatted prompt gives the logits for the
    /// first token. Each further token costs one single-token decode pass.
    /// Generation stops at an end-of-turn token, when the decoded text contains
    /// any of `req.stop`, or after `req.max_tokens` generated tokens. A
    /// `req.temperature` of zero or less selects greedy sampling.
    ///
    /// # Errors
    /// Fails if the model cannot be fetched/loaded, if tokenization or
    /// detokenization fails, or if any candle tensor op fails.
    fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
        let started = std::time::Instant::now();
        // A poisoned lock only means an earlier call panicked while holding it.
        // Every call starts with a position-0 prefill, exactly as a normal
        // repeat call does, so recover the guard rather than panicking forever.
        let mut guard = self
            .loaded
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if guard.is_none() {
            *guard = Some(self.load()?);
        }
        let Loaded { model, tokenizer, device } = guard.as_mut().expect("loaded set above");

        let prompt = Self::format_prompt(req);
        let encoding =
            tokenizer.encode(prompt, true).map_err(|e| anyhow!("candle: encode prompt: {e}"))?;
        let prompt_tokens: Vec<u32> = encoding.get_ids().to_vec();
        let tokens_in = prompt_tokens.len() as i64;

        // Qwen2-Instruct closes a turn with `<|im_end|>` but can also emit the
        // base model's `<|endoftext|>`; either ends generation. Tokens missing
        // from the vocabulary are simply not treated as stops.
        let eos_ids: Vec<u32> = ["<|im_end|>", "<|endoftext|>"]
            .iter()
            .filter_map(|token| tokenizer.token_to_id(token))
            .collect();

        let mut logits_processor = LogitsProcessor::new(
            42,
            if req.temperature > 0.0 { Some(req.temperature as f64) } else { None },
            None,
        );

        // Prefill: one forward over the whole prompt yields the logits for the
        // first generated token.
        let input = Tensor::new(prompt_tokens.as_slice(), device)
            .map_err(|e| anyhow!("candle: prompt tensor: {e}"))?
            .unsqueeze(0)
            .map_err(|e| anyhow!("candle: unsqueeze: {e}"))?;
        let mut logits = model
            .forward(&input, 0)
            .map_err(|e| anyhow!("candle: prefill forward: {e}"))?
            .squeeze(0)
            .map_err(|e| anyhow!("candle: squeeze logits: {e}"))?;

        let mut all_tokens: Vec<u32> = Vec::new();
        // KV-cache position of the next token fed to `forward`: starts just past
        // the prompt and advances once per decoded token.
        let mut index = prompt_tokens.len();
        // Token sampled in the previous step that has not been run through the
        // model yet. Feeding it at the top of the NEXT step means the final
        // sampled token (max_tokens reached, or a stop hit) never pays for a
        // forward pass whose logits would be thrown away.
        let mut pending: Option<u32> = None;
        for _ in 0..req.max_tokens {
            if let Some(prev) = pending {
                let input = Tensor::new(&[prev], device)
                    .map_err(|e| anyhow!("candle: step tensor: {e}"))?
                    .unsqueeze(0)
                    .map_err(|e| anyhow!("candle: step unsqueeze: {e}"))?;
                logits = model
                    .forward(&input, index)
                    .map_err(|e| anyhow!("candle: decode forward: {e}"))?
                    .squeeze(0)
                    .map_err(|e| anyhow!("candle: decode squeeze: {e}"))?;
                index += 1;
            }
            let next =
                logits_processor.sample(&logits).map_err(|e| anyhow!("candle: sample: {e}"))?;
            if eos_ids.contains(&next) {
                break;
            }
            all_tokens.push(next);
            pending = Some(next);

            // Honor stop sequences against the running decode.
            if !req.stop.is_empty() {
                let so_far = tokenizer
                    .decode(&all_tokens, true)
                    .map_err(|e| anyhow!("candle: decode: {e}"))?;
                if req.stop.iter().any(|s| so_far.contains(s)) {
                    break;
                }
            }
        }

        let text = tokenizer
            .decode(&all_tokens, true)
            .map_err(|e| anyhow!("candle: final decode: {e}"))?;
        let latency_ms = started.elapsed().as_secs_f64() * 1000.0;
        let tokens_out = all_tokens.len() as i64;
        let tokens_per_s = if latency_ms > 0.0 {
            tokens_out as f64 / (latency_ms / 1000.0)
        } else {
            0.0
        };
        Ok(GenAnswer { text, tokens_in, tokens_out, tokens_per_s, latency_ms })
    }
}

/// Best-effort offline lookup of an HF-hub-cached file (no network). Returns the
/// path only if it already exists on disk, so `available()` never reaches out.
fn cached_hub_file(repo: &str, file: &str) -> Option<PathBuf> {
    let api = hf_hub::api::sync::Api::new().ok()?;
    let cached = api.model(repo.to_string()).get(file);
    // `get` for a cached file returns the local path without downloading only
    // when offline mode is set; to stay strictly offline we instead check the
    // path the cache would use. hf-hub exposes the cache via `Cache`.
    match cached {
        Ok(p) if p.exists() => Some(p),
        _ => {
            // Fall back to the cache's own path resolution (no download).
            let cache = hf_hub::Cache::default();
            cache.model(repo.to_string()).get(file).filter(|p| p.exists())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The env vars that redirect `load()` to on-disk files.
    const OVERRIDE_ENV: [&str; 2] = ["NORNIR_CANDLE_GGUF", "NORNIR_CANDLE_TOKENIZER"];

    /// Serializes every test that reads or mutates the `NORNIR_CANDLE_*` env
    /// vars. The process environment is global and libtest runs tests on
    /// parallel threads, so saving/restoring alone does NOT stop another test
    /// (e.g. the heavy load) from observing the temporary fake paths.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire [`ENV_LOCK`], tolerating poisoning left behind by a failed test.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Snapshot of env vars that is written back on drop, so a failing
    /// assertion cannot leak temporary values into later tests.
    struct EnvRestore(Vec<(&'static str, Option<std::ffi::OsString>)>);

    impl EnvRestore {
        fn capture(keys: &[&'static str]) -> Self {
            Self(keys.iter().map(|&key| (key, std::env::var_os(key))).collect())
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            for (key, value) in &self.0 {
                match value {
                    Some(v) => std::env::set_var(*key, v),
                    None => std::env::remove_var(*key),
                }
            }
        }
    }

    #[test]
    fn unknown_preset_errors_with_known_list() {
        let err = match CandleGenerator::new("no-such-model") {
            Ok(_) => panic!("unknown preset must error"),
            Err(e) => e.to_string(),
        };
        assert!(err.contains("unknown model"), "{err}");
        assert!(err.contains("qwen2-0.5b"), "lists presets: {err}");
    }

    #[test]
    fn default_preset_when_empty() {
        let gen = CandleGenerator::new("").unwrap();
        assert_eq!(gen.id(), "candle:qwen2-0.5b");
    }

    #[test]
    fn constructs_and_reports_availability_without_loading() {
        // Constructing must not fetch/load — only `available()` probes the cache.
        let gen = CandleGenerator::new("qwen2-0.5b").unwrap();
        assert_eq!(gen.id(), "candle:qwen2-0.5b");
        // available() is a pure offline probe: it returns a bool either way and
        // must not panic or block. (false on a machine without the model cached.)
        let _ = gen.available();
    }

    /// SYNTHETIC tier — ALWAYS runs (offline, no download, no `--heavy`).
    ///
    /// Writes a tiny but VALID GGUF fixture (real magic, real metadata, a real
    /// quantized tensor) with candle's own `gguf_file::write`, then drives the
    /// EXACT reader the production `load()` uses — `gguf_file::Content::read` —
    /// and asserts the injected metadata + tensor round-trip. This exercises the
    /// real GGUF parse path the model load depends on against a synthesized
    /// input, instead of silently skipping when no multi-GB model is cached.
    /// The full transformer `from_gguf` + generate stays in the heavy arm (it
    /// needs the model's full tensor set, which we don't fabricate).
    #[test]
    fn synthetic_gguf_round_trips_through_the_real_reader() {
        use candle_core::quantized::{gguf_file::Value, GgmlDType, QTensor};

        let device = Device::Cpu;

        // INJECT: a known F32 tensor, quantized to a real GGML dtype, plus a known
        // metadata key/value the reader must hand back unchanged.
        let src = Tensor::from_vec(
            (0..32).map(|i| i as f32).collect::<Vec<f32>>(),
            (1, 32),
            &device,
        )
        .expect("build source tensor");
        let qtensor = QTensor::quantize(&src, GgmlDType::Q4_0).expect("quantize tensor");

        let arch = Value::String("qwen2-synthetic".to_string());
        let ctx_len = Value::U32(2048);
        let metadata: Vec<(&str, &Value)> = vec![
            ("general.architecture", &arch),
            ("qwen2.context_length", &ctx_len),
        ];
        let tensors: Vec<(&str, &QTensor)> = vec![("token_embd.weight", &qtensor)];

        // Write the synthetic GGUF to a temp file, then read it back via the
        // production reader path. Reopening by path gives a fresh handle at
        // offset 0, so no rewind of the write handle is needed; `tmp` stays
        // alive (and the file on disk) until the end of the test.
        let mut tmp = tempfile::NamedTempFile::new().expect("tempfile");
        gguf_file::write(tmp.as_file_mut(), &metadata, &tensors).expect("write synthetic GGUF");

        let mut f = std::fs::File::open(tmp.path()).expect("reopen synthetic GGUF");
        let content = gguf_file::Content::read(&mut f).expect("read synthetic GGUF (real path)");

        // ASSERT real round-tripped values — not "didn't panic".
        match content.metadata.get("general.architecture") {
            Some(Value::String(s)) => assert_eq!(s, "qwen2-synthetic", "architecture round-trips"),
            other => panic!("architecture metadata missing/wrong: {other:?}"),
        }
        match content.metadata.get("qwen2.context_length") {
            Some(Value::U32(n)) => assert_eq!(*n, 2048, "context_length round-trips"),
            other => panic!("context_length metadata missing/wrong: {other:?}"),
        }
        assert!(
            content.tensor_infos.contains_key("token_embd.weight"),
            "tensor info round-trips; got {:?}",
            content.tensor_infos.keys().collect::<Vec<_>>()
        );
        // And the tensor data itself is readable back through the real accessor.
        let read_back = content
            .tensor(&mut f, "token_embd.weight", &device)
            .expect("read tensor back");
        assert_eq!(read_back.shape().dims(), &[1, 32], "tensor shape round-trips");
    }

    /// SYNTHETIC tier — ALWAYS runs. Proves the explicit-path load override is
    /// wired: with neither env set there is no override; with both set the
    /// override resolves to those paths (so `load()` skips the network). We don't
    /// need a real model to assert the *resolution*, only the precedence rule.
    #[test]
    fn override_paths_require_both_env_vars() {
        // Hold the env lock for the whole test; `_restore` is declared after it,
        // so it drops (restoring the vars) while the lock is still held.
        let _lock = env_lock();
        let _restore = EnvRestore::capture(&OVERRIDE_ENV);

        std::env::remove_var("NORNIR_CANDLE_GGUF");
        std::env::remove_var("NORNIR_CANDLE_TOKENIZER");
        assert!(CandleGenerator::override_paths().is_none(), "no override when unset");

        std::env::set_var("NORNIR_CANDLE_GGUF", "/tmp/synthetic.gguf");
        assert!(
            CandleGenerator::override_paths().is_none(),
            "half-set override must not engage (would hide intent)"
        );
        std::env::set_var("NORNIR_CANDLE_TOKENIZER", "/tmp/tok.json");
        let resolved = CandleGenerator::override_paths().expect("both set → override");
        assert_eq!(resolved.0, PathBuf::from("/tmp/synthetic.gguf"));
        assert_eq!(resolved.1, PathBuf::from("/tmp/tok.json"));
    }

    /// Heavy: fetches + loads the real model and generates. Network + multi-GB.
    /// Gated `#[ignore]` so the default `cargo test` stays offline + fast. Run it
    /// on the REAL-DATA opt-in (`--include-ignored` / `--heavy`), or point it at a
    /// pre-downloaded model with `NORNIR_CANDLE_GGUF` + `NORNIR_CANDLE_TOKENIZER`.
    #[test]
    #[ignore = "downloads a multi-GB GGUF model (real-data arm)"]
    fn real_generation_round_trips() {
        // `load()` reads the override env vars; hold the lock so the override
        // test cannot swap in fake paths mid-load when run with --include-ignored.
        let _lock = env_lock();
        let gen = CandleGenerator::new("qwen2-0.5b").unwrap();
        let req = GenRequest::new("Reply with the single word: pong").with_max_tokens(8);
        let ans = gen.complete(&req).unwrap();
        assert!(ans.tokens_in > 0);
        assert!(ans.tokens_out > 0);
        assert!(!ans.text.is_empty());
    }
}