Skip to main content

oxillama_wasm/
lib.rs

1//! WebAssembly bindings for OxiLLaMa.
2//!
3//! Exposes GGUF header parsing, Q4_0 dequantization, **and full text
4//! generation** to JavaScript/TypeScript via wasm-bindgen.
5//!
6//! ## Feature flags
7//!
8//! | Feature     | Default | Description                                        |
9//! |-------------|---------|---------------------------------------------------|
10//! | `inference` | yes     | Enables `generate()` via `oxillama-runtime` with  |
11//! |             |         | the pure-Rust `unstable_wasm` tokenizer backend.  |
12//!
13//! ## Usage (generate)
14//!
15//! ```js
16//! import init, { generate } from './oxillama_wasm.js';
17//! await init();
18//!
19//! const modelResp = await fetch('model.gguf');
20//! const modelBytes = new Uint8Array(await modelResp.arrayBuffer());
21//! const tokenizerResp = await fetch('tokenizer.json');
22//! const tokenizerJson = await tokenizerResp.text();
23//!
24//! // Streaming: pass a callback to receive each token as it is generated.
25//! const text = generate(modelBytes, tokenizerJson, "Hello, world!", 128,
26//!     (token) => process.stdout.write(token));
27//! console.log(text);
28//! ```
29
30use serde::{Deserialize, Serialize};
31use wasm_bindgen::prelude::*;
32
33pub mod gpu_bridge;
34pub mod idb_cache;
35pub mod service_worker;
36pub mod simd_check;
37pub mod streaming_load;
38pub mod streaming_loader;
39pub mod webgpu;
40pub mod worker;
41
42pub use service_worker::{
43    get_service_worker_script, register_service_worker, ServiceWorkerOptions,
44};
45pub use simd_check::get_simd128_status;
46pub use streaming_loader::StreamingGgufLoader;
47pub use streaming_loader::StreamingLoadOptions;
48
49// ── Panic hook (default feature) ─────────────────────────────────────────────
50
51/// Initialize the WASM module (sets up panic hook for better error messages).
52///
53/// Called automatically by the generated JS glue code when the WASM module
54/// is first instantiated.
55#[wasm_bindgen(start)]
56pub fn init() {
57    #[cfg(feature = "console_error_panic_hook")]
58    console_error_panic_hook::set_once();
59}
60
61// ── GGUF header parsing ───────────────────────────────────────────────────────
62
63/// Parse a GGUF file header from raw bytes and return key metadata as a JS object.
64///
65/// The returned object has the following numeric fields:
66/// - `tensorCount`   — number of tensors in the file
67/// - `metadataCount` — number of metadata KV pairs
68/// - `version`       — GGUF file version (2 or 3)
69///
70/// Throws a JavaScript error string if the bytes are not a valid GGUF file.
71#[wasm_bindgen(js_name = parseGgufHeader)]
72pub fn parse_gguf_header(data: &[u8]) -> Result<JsValue, JsValue> {
73    let gguf = oxillama_gguf::GgufFile::parse(data)
74        .map_err(|e| JsValue::from_str(&format!("GGUF parse error: {e}")))?;
75
76    let obj = js_sys::Object::new();
77    js_sys::Reflect::set(
78        &obj,
79        &JsValue::from_str("tensorCount"),
80        &JsValue::from_f64(gguf.tensors.len() as f64),
81    )
82    .map_err(|e| JsValue::from_str(&format!("Reflect.set error: {e:?}")))?;
83    js_sys::Reflect::set(
84        &obj,
85        &JsValue::from_str("metadataCount"),
86        &JsValue::from_f64(gguf.metadata.len() as f64),
87    )
88    .map_err(|e| JsValue::from_str(&format!("Reflect.set error: {e:?}")))?;
89    js_sys::Reflect::set(
90        &obj,
91        &JsValue::from_str("version"),
92        &JsValue::from_f64(gguf.header.version as f64),
93    )
94    .map_err(|e| JsValue::from_str(&format!("Reflect.set error: {e:?}")))?;
95
96    Ok(JsValue::from(obj))
97}
98
99/// Return all tensor names stored in a GGUF file as a JS array of strings.
100///
101/// Throws a JavaScript error string if parsing fails.
102#[wasm_bindgen(js_name = listTensorNames)]
103pub fn list_tensor_names(data: &[u8]) -> Result<Vec<JsValue>, JsValue> {
104    let gguf = oxillama_gguf::GgufFile::parse(data)
105        .map_err(|e| JsValue::from_str(&format!("GGUF parse error: {e}")))?;
106
107    Ok(gguf
108        .tensors
109        .names()
110        .map(|name| JsValue::from_str(name))
111        .collect())
112}
113
114// ── Q4_0 dequantization ───────────────────────────────────────────────────────
115
116/// Dequantize a buffer of Q4_0 blocks to an array of f32 values.
117///
118/// The Q4_0 block layout is 18 bytes per 32 weights:
119/// - 2 bytes: FP16 scale factor
120/// - 16 bytes: 32 × 4-bit nibbles packed two per byte
121///
122/// `data` must be a multiple of 18 bytes.  Returns a `Vec<f32>` of length
123/// `(data.len() / 18) * 32`.  Throws a JavaScript error string on any
124/// malformed input.
125#[wasm_bindgen(js_name = dequantQ4_0)]
126pub fn dequant_q4_0(data: &[u8]) -> Result<Vec<f32>, JsValue> {
127    use oxillama_quant::reference::Q4_0Ref;
128    use oxillama_quant::traits::QuantKernel;
129
130    const BLOCK_BYTES: usize = 18;
131    const BLOCK_SIZE: usize = 32;
132
133    if !data.len().is_multiple_of(BLOCK_BYTES) {
134        return Err(JsValue::from_str(&format!(
135            "Q4_0 data length {} is not a multiple of {} bytes per block",
136            data.len(),
137            BLOCK_BYTES,
138        )));
139    }
140
141    let n_blocks = data.len() / BLOCK_BYTES;
142    let n_weights = n_blocks * BLOCK_SIZE;
143    let mut out = vec![0.0f32; n_weights];
144    let kernel = Q4_0Ref;
145
146    for (blk_idx, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
147        let output_slice = &mut out[blk_idx * BLOCK_SIZE..(blk_idx + 1) * BLOCK_SIZE];
148        kernel.dequant_block(block, output_slice).map_err(|e| {
149            JsValue::from_str(&format!("dequant_block error at block {blk_idx}: {e}"))
150        })?;
151    }
152
153    Ok(out)
154}
155
156// ── Text generation ───────────────────────────────────────────────────────────
157
/// Run full text generation from an in-memory GGUF model.
///
/// # Arguments
///
/// - `model_bytes`    — raw bytes of the `.gguf` model file
/// - `tokenizer_json` — contents of the HuggingFace `tokenizer.json` file
///                      for this model
/// - `prompt`         — input text prompt
/// - `max_tokens`     — maximum number of tokens to generate
/// - `on_token`       — optional JS callback invoked with each generated
///                      token's text as a string, enabling streaming output
///
/// Returns the generated text as a JS string.
///
/// # Errors
///
/// Throws a JavaScript error string: `"model load error: …"` if the model or
/// tokenizer cannot be loaded, `"generation error: …"` if generation fails.
///
/// # Notes
///
/// This function requires the `inference` feature.  The model is loaded on
/// every call through `load_model_core` (without a progress callback); use
/// `loadModelFromBytesWithProgress` and `WasmEngine::generate` to load once
/// and generate many times.
///
/// Exceptions thrown by `on_token` are discarded and do not stop generation.
#[cfg(feature = "inference")]
#[wasm_bindgen]
pub fn generate(
    model_bytes: &[u8],
    tokenizer_json: &str,
    prompt: &str,
    max_tokens: usize,
    on_token: Option<js_sys::Function>,
) -> Result<String, JsValue> {
    // Build and load the engine through the shared helper so this entry point
    // and `load_model_from_bytes_with_progress` cannot drift apart.  Passing
    // `None` means no progress events are emitted.
    let mut engine = load_model_core(model_bytes, tokenizer_json, None)?;

    // Forward each decoded token to the JS callback, if one was supplied.
    engine
        .generate(prompt, max_tokens, |token_text| {
            if let Some(ref cb) = on_token {
                // Best-effort streaming: the result of the JS call (including
                // a thrown exception) is intentionally ignored.
                let _ = cb.call1(&JsValue::NULL, &JsValue::from_str(token_text));
            }
        })
        .map_err(|e| JsValue::from_str(&format!("generation error: {e}")))
}
221
222// ── K-quant dequantization ─────────────────────────────────────────────────────
223
224/// Dequantize a buffer of Q4_K blocks to an array of f32 values.
225///
226/// The Q4_K block layout is 144 bytes per 256 weights.
227/// `data` must be a multiple of 144 bytes.  Returns a `Vec<f32>` of length
228/// `(data.len() / 144) * 256`.  Throws a JavaScript error string on any
229/// malformed input.
230#[wasm_bindgen(js_name = dequantQ4K)]
231pub fn dequant_q4_k(data: &[u8]) -> Result<Vec<f32>, JsValue> {
232    use oxillama_quant::reference::Q4KRef;
233    use oxillama_quant::traits::QuantKernel;
234
235    const BLOCK_BYTES: usize = 144;
236    const BLOCK_SIZE: usize = 256;
237
238    if !data.len().is_multiple_of(BLOCK_BYTES) {
239        return Err(JsValue::from_str(&format!(
240            "Q4_K data length {} is not a multiple of {} bytes per block",
241            data.len(),
242            BLOCK_BYTES,
243        )));
244    }
245
246    let n_blocks = data.len() / BLOCK_BYTES;
247    let n_weights = n_blocks * BLOCK_SIZE;
248    let mut out = vec![0.0f32; n_weights];
249    let kernel = Q4KRef;
250
251    for (blk_idx, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
252        let output_slice = &mut out[blk_idx * BLOCK_SIZE..(blk_idx + 1) * BLOCK_SIZE];
253        kernel.dequant_block(block, output_slice).map_err(|e| {
254            JsValue::from_str(&format!("dequant_block error at block {blk_idx}: {e}"))
255        })?;
256    }
257
258    Ok(out)
259}
260
261/// Dequantize a buffer of Q5_K blocks to an array of f32 values.
262///
263/// The Q5_K block layout is 176 bytes per 256 weights.
264/// `data` must be a multiple of 176 bytes.  Returns a `Vec<f32>` of length
265/// `(data.len() / 176) * 256`.  Throws a JavaScript error string on any
266/// malformed input.
267#[wasm_bindgen(js_name = dequantQ5K)]
268pub fn dequant_q5_k(data: &[u8]) -> Result<Vec<f32>, JsValue> {
269    use oxillama_quant::reference::Q5KRef;
270    use oxillama_quant::traits::QuantKernel;
271
272    const BLOCK_BYTES: usize = 176;
273    const BLOCK_SIZE: usize = 256;
274
275    if !data.len().is_multiple_of(BLOCK_BYTES) {
276        return Err(JsValue::from_str(&format!(
277            "Q5_K data length {} is not a multiple of {} bytes per block",
278            data.len(),
279            BLOCK_BYTES,
280        )));
281    }
282
283    let n_blocks = data.len() / BLOCK_BYTES;
284    let n_weights = n_blocks * BLOCK_SIZE;
285    let mut out = vec![0.0f32; n_weights];
286    let kernel = Q5KRef;
287
288    for (blk_idx, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
289        let output_slice = &mut out[blk_idx * BLOCK_SIZE..(blk_idx + 1) * BLOCK_SIZE];
290        kernel.dequant_block(block, output_slice).map_err(|e| {
291            JsValue::from_str(&format!("dequant_block error at block {blk_idx}: {e}"))
292        })?;
293    }
294
295    Ok(out)
296}
297
298/// Dequantize a buffer of Q6_K blocks to an array of f32 values.
299///
300/// The Q6_K block layout is 210 bytes per 256 weights.
301/// `data` must be a multiple of 210 bytes.  Returns a `Vec<f32>` of length
302/// `(data.len() / 210) * 256`.  Throws a JavaScript error string on any
303/// malformed input.
304#[wasm_bindgen(js_name = dequantQ6K)]
305pub fn dequant_q6_k(data: &[u8]) -> Result<Vec<f32>, JsValue> {
306    use oxillama_quant::reference::Q6KRef;
307    use oxillama_quant::traits::QuantKernel;
308
309    const BLOCK_BYTES: usize = 210;
310    const BLOCK_SIZE: usize = 256;
311
312    if !data.len().is_multiple_of(BLOCK_BYTES) {
313        return Err(JsValue::from_str(&format!(
314            "Q6_K data length {} is not a multiple of {} bytes per block",
315            data.len(),
316            BLOCK_BYTES,
317        )));
318    }
319
320    let n_blocks = data.len() / BLOCK_BYTES;
321    let n_weights = n_blocks * BLOCK_SIZE;
322    let mut out = vec![0.0f32; n_weights];
323    let kernel = Q6KRef;
324
325    for (blk_idx, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
326        let output_slice = &mut out[blk_idx * BLOCK_SIZE..(blk_idx + 1) * BLOCK_SIZE];
327        kernel.dequant_block(block, output_slice).map_err(|e| {
328            JsValue::from_str(&format!("dequant_block error at block {blk_idx}: {e}"))
329        })?;
330    }
331
332    Ok(out)
333}
334
335// ── Load model with progress callback ────────────────────────────────────────
336
337/// Typed GGUF metadata returned by [`parse_gguf_metadata`].
338///
339/// Optional fields are `None` when the corresponding key is absent in the file.
340#[derive(Debug, Serialize, Deserialize)]
341pub struct GgufMetadataJs {
342    pub version: u32,
343    pub tensor_count: u64,
344    pub kv_count: u64,
345    pub arch: Option<String>,
346    pub context_length: Option<u64>,
347    pub embedding_length: Option<u64>,
348    pub feed_forward_length: Option<u64>,
349    pub attention_head_count: Option<u64>,
350    pub block_count: Option<u64>,
351    pub quantization_version: Option<u32>,
352    pub general_name: Option<String>,
353    pub general_author: Option<String>,
354    pub general_description: Option<String>,
355}
356
/// Core model-loading logic behind [`load_model_from_bytes_with_progress`].
///
/// Creates an engine with the default configuration and loads the model and
/// tokenizer from memory.  When `on_progress` is provided it is called with
/// `0` before the engine is created, `25` just before loading starts, and
/// `100` once loading has succeeded; the result of each callback invocation
/// is ignored.
///
/// # Errors
///
/// Returns a `JsValue` string (`"model load error: …"`) if loading fails; in
/// that case the `100` milestone is never reported.
#[cfg(feature = "inference")]
fn load_model_core(
    model_bytes: &[u8],
    tokenizer_json: &str,
    on_progress: Option<&js_sys::Function>,
) -> Result<oxillama_runtime::InferenceEngine, JsValue> {
    use oxillama_runtime::{EngineConfig, InferenceEngine};

    // Sends one milestone to the JS listener, if there is one.
    fn report(listener: Option<&js_sys::Function>, pct: u32) {
        if let Some(cb) = listener {
            let _ = cb.call1(&JsValue::UNDEFINED, &JsValue::from(pct));
        }
    }

    report(on_progress, 0);
    let mut engine = InferenceEngine::new(EngineConfig::default());

    report(on_progress, 25);
    if let Err(e) = engine.load_model_from_bytes(model_bytes, tokenizer_json) {
        return Err(JsValue::from_str(&format!("model load error: {e}")));
    }

    report(on_progress, 100);
    Ok(engine)
}
386
/// Load a GGUF model from raw bytes and hand back a reusable engine handle.
///
/// `on_progress`, when supplied, receives a percentage (`0`, `25`, `100`) at
/// key milestones during loading.  Pass `undefined` / `null` from JS to skip
/// progress reporting.
///
/// # Errors
///
/// Throws a JS error string if the model cannot be loaded; otherwise returns
/// an opaque `WasmEngine` wrapping the loaded engine.
#[cfg(feature = "inference")]
#[wasm_bindgen(js_name = loadModelFromBytesWithProgress)]
pub fn load_model_from_bytes_with_progress(
    model_bytes: &[u8],
    tokenizer_json: &str,
    on_progress: Option<js_sys::Function>,
) -> Result<WasmEngine, JsValue> {
    load_model_core(model_bytes, tokenizer_json, on_progress.as_ref())
        .map(|inner| WasmEngine { inner })
}
404
/// Opaque JS-facing handle that owns a loaded `InferenceEngine`.
#[cfg(feature = "inference")]
#[wasm_bindgen]
pub struct WasmEngine {
    /// The wrapped engine; not exposed to JS.
    inner: oxillama_runtime::InferenceEngine,
}

#[cfg(feature = "inference")]
#[wasm_bindgen]
impl WasmEngine {
    /// Generate text with the model already loaded into this engine.
    ///
    /// Same contract as the top-level [`generate`] function, minus the model
    /// load: `on_token`, when supplied, is called with each generated token's
    /// text, and the result of that call is ignored.
    ///
    /// Throws a JS error string (`"generation error: …"`) if generation fails.
    pub fn generate(
        &mut self,
        prompt: &str,
        max_tokens: usize,
        on_token: Option<js_sys::Function>,
    ) -> Result<String, JsValue> {
        let listener = on_token.as_ref();

        let outcome = self.inner.generate(prompt, max_tokens, |tok| {
            if let Some(cb) = listener {
                let _ = cb.call1(&JsValue::NULL, &JsValue::from_str(tok));
            }
        });

        match outcome {
            Ok(text) => Ok(text),
            Err(e) => Err(JsValue::from_str(&format!("generation error: {e}"))),
        }
    }
}
434
435// ── Typed GGUF metadata export ────────────────────────────────────────────────
436
437/// Parse a GGUF file and return typed metadata as a JS object.
438///
439/// The returned object conforms to the [`GgufMetadataJs`] schema.  Optional
440/// fields are `null` when the corresponding metadata key is absent.
441///
442/// Throws a JavaScript error string if parsing fails.
443#[wasm_bindgen(js_name = parseGgufMetadata)]
444pub fn parse_gguf_metadata(data: &[u8]) -> Result<JsValue, JsValue> {
445    let gguf = oxillama_gguf::GgufFile::parse(data)
446        .map_err(|e| JsValue::from_str(&format!("GGUF parse error: {e}")))?;
447
448    let meta = &gguf.metadata;
449
450    // Detect architecture first — used as prefix for architecture-specific keys.
451    let arch: Option<String> = meta
452        .get("general.architecture")
453        .and_then(|v| v.as_str())
454        .map(|s| s.to_owned());
455
456    // Helper: look up an integer key, trying the arch-prefixed form first then
457    // falling back to common prefixes.
458    let get_u64 = |suffix: &str| -> Option<u64> {
459        let prefixes: &[&str] = match arch.as_deref() {
460            Some(a) => {
461                // Use the detected arch plus a handful of common fallbacks.
462                &[a, "llama", "mistral", "qwen3", "gemma", "phi"][..]
463            }
464            None => &["llama", "mistral", "qwen3", "gemma", "phi"][..],
465        };
466        for prefix in prefixes {
467            let key = format!("{prefix}.{suffix}");
468            if let Some(val) = meta.get(&key).and_then(|v| v.as_u64()) {
469                return Some(val);
470            }
471        }
472        None
473    };
474
475    let metadata_js = GgufMetadataJs {
476        version: gguf.header.version,
477        tensor_count: gguf.header.tensor_count,
478        kv_count: gguf.header.metadata_kv_count,
479        context_length: get_u64("context_length"),
480        embedding_length: get_u64("embedding_length"),
481        feed_forward_length: get_u64("feed_forward_length"),
482        attention_head_count: get_u64("attention.head_count"),
483        block_count: get_u64("block_count"),
484        quantization_version: meta
485            .get("general.quantization_version")
486            .and_then(|v| v.as_u32()),
487        general_name: meta
488            .get("general.name")
489            .and_then(|v| v.as_str())
490            .map(|s| s.to_owned()),
491        general_author: meta
492            .get("general.author")
493            .and_then(|v| v.as_str())
494            .map(|s| s.to_owned()),
495        general_description: meta
496            .get("general.description")
497            .and_then(|v| v.as_str())
498            .map(|s| s.to_owned()),
499        arch,
500    };
501
502    serde_wasm_bindgen::to_value(&metadata_js).map_err(|e| JsValue::from_str(&e.to_string()))
503}
504
505// ── Tests ─────────────────────────────────────────────────────────────────────
506
#[cfg(test)]
mod tests {
    // These tests call the underlying libraries directly rather than the
    // wasm-bindgen wrappers: JsValue/Reflect only work inside a WASM runtime,
    // so the glue layer is left to wasm-bindgen-test on a real WASM target.

    use oxillama_quant::reference::{Q4KRef, Q4_0Ref, Q5KRef, Q6KRef};
    use oxillama_quant::traits::QuantKernel;

    // ── GGUF parsing ──────────────────────────────────────────────────────────

    /// An empty buffer is rejected with an error rather than a panic.
    #[test]
    fn test_parse_gguf_empty_fails() {
        let empty: &[u8] = &[];
        let outcome = oxillama_gguf::GgufFile::parse(empty);
        assert!(outcome.is_err(), "empty buffer should fail to parse");
    }

    /// Wrong magic bytes are rejected with an error rather than a panic.
    #[test]
    fn test_parse_gguf_bad_magic_fails() {
        const BAD_MAGIC: &[u8] = b"BAAD\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
        let outcome = oxillama_gguf::GgufFile::parse(BAD_MAGIC);
        assert!(outcome.is_err(), "wrong magic should fail to parse");
    }

    // ── Q4_0 tests ────────────────────────────────────────────────────────────

    /// A 17-byte buffer is one byte short of a Q4_0 block and must be rejected.
    #[test]
    fn test_dequant_q4_0_wrong_length_fails() {
        const BLOCK_BYTES: usize = 18;
        let short = [0u8; 17];
        assert_ne!(
            short.len() % BLOCK_BYTES,
            0,
            "17 must not be a multiple of 18"
        );
        // The kernel itself must refuse the incomplete block.
        let mut out = [0.0f32; 32];
        let outcome = Q4_0Ref.dequant_block(&short[..], &mut out[..]);
        assert!(outcome.is_err(), "incomplete block should fail");
    }

    /// Scale 1.0 with every nibble equal to 8 decodes to all zeros.
    #[test]
    fn test_dequant_q4_0_zero_block() {
        // Fill the whole block with 0x88 (lo = hi = 8, i.e. 8 - 8 = 0), then
        // overwrite the first two bytes with FP16 1.0 = 0x3C00 (little-endian).
        let mut block = [0x88u8; 18];
        block[0] = 0x00;
        block[1] = 0x3C;

        let mut out = [0.0f32; 32];
        Q4_0Ref
            .dequant_block(&block[..], &mut out[..])
            .expect("should not fail on valid block");
        assert_eq!(out.len(), 32, "one block = 32 weights");
        for (i, &v) in out.iter().enumerate() {
            assert!(v.abs() < 1e-5, "weight[{i}] = {v}, expected ~0.0");
        }
    }

    /// Two zeroed blocks dequantize into 64 weights.
    #[test]
    fn test_dequant_q4_0_two_blocks_length() {
        const BLOCK_BYTES: usize = 18;
        const BLOCK_SIZE: usize = 32;
        let data = [0u8; 2 * BLOCK_BYTES];
        let mut out = vec![0.0f32; (data.len() / BLOCK_BYTES) * BLOCK_SIZE];

        let inputs = data.chunks_exact(BLOCK_BYTES);
        let outputs = out.chunks_exact_mut(BLOCK_SIZE);
        for (block, dst) in inputs.zip(outputs) {
            Q4_0Ref
                .dequant_block(block, dst)
                .expect("dequant_block should succeed on zeroed data");
        }
        assert_eq!(out.len(), 64, "two blocks = 64 weights");
    }

    // ── Q4_K tests ────────────────────────────────────────────────────────────

    /// A 143-byte buffer is one byte short of a Q4_K block and must be rejected.
    #[test]
    fn test_dequant_q4_k_wrong_length_fails() {
        const BLOCK_BYTES: usize = 144;
        let short = [0u8; 143];
        assert_ne!(short.len() % BLOCK_BYTES, 0);
        let mut out = [0.0f32; 256];
        let outcome = Q4KRef.dequant_block(&short[..], &mut out[..]);
        assert!(outcome.is_err(), "incomplete Q4_K block should fail");
    }

    /// All-zero block: d=0, dmin=0, all nibbles=0 → every weight is 0.
    #[test]
    fn test_dequant_q4_k_zero_block() {
        const BLOCK_BYTES: usize = 144;
        const BLOCK_SIZE: usize = 256;
        let block = [0u8; BLOCK_BYTES];
        let mut out = [0.0f32; BLOCK_SIZE];
        Q4KRef
            .dequant_block(&block[..], &mut out[..])
            .expect("zero block should succeed");
        for (i, &v) in out.iter().enumerate() {
            assert!(v.abs() < 1e-5, "Q4_K weight[{i}] = {v}, expected ~0.0");
        }
    }

    // ── Q5_K tests ────────────────────────────────────────────────────────────

    /// A 175-byte buffer is one byte short of a Q5_K block and must be rejected.
    #[test]
    fn test_dequant_q5_k_wrong_length_fails() {
        const BLOCK_BYTES: usize = 176;
        let short = [0u8; 175];
        assert_ne!(short.len() % BLOCK_BYTES, 0);
        let mut out = [0.0f32; 256];
        let outcome = Q5KRef.dequant_block(&short[..], &mut out[..]);
        assert!(outcome.is_err(), "incomplete Q5_K block should fail");
    }

    /// An all-zero Q5_K block dequantizes to all-zero weights.
    #[test]
    fn test_dequant_q5_k_zero_block() {
        const BLOCK_BYTES: usize = 176;
        const BLOCK_SIZE: usize = 256;
        let block = [0u8; BLOCK_BYTES];
        let mut out = [0.0f32; BLOCK_SIZE];
        Q5KRef
            .dequant_block(&block[..], &mut out[..])
            .expect("zero block should succeed");
        for (i, &v) in out.iter().enumerate() {
            assert!(v.abs() < 1e-5, "Q5_K weight[{i}] = {v}, expected ~0.0");
        }
    }

    // ── Q6_K tests ────────────────────────────────────────────────────────────

    /// A 209-byte buffer is one byte short of a Q6_K block and must be rejected.
    #[test]
    fn test_dequant_q6_k_wrong_length_fails() {
        const BLOCK_BYTES: usize = 210;
        let short = [0u8; 209];
        assert_ne!(short.len() % BLOCK_BYTES, 0);
        let mut out = [0.0f32; 256];
        let outcome = Q6KRef.dequant_block(&short[..], &mut out[..]);
        assert!(outcome.is_err(), "incomplete Q6_K block should fail");
    }

    /// Q6_K zero block: d=0 → all weights are 0 regardless of quant values
    /// (the 6-bit quants get a -32 offset, but d=0 zeroes everything).
    #[test]
    fn test_dequant_q6_k_zero_block() {
        const BLOCK_BYTES: usize = 210;
        const BLOCK_SIZE: usize = 256;
        let block = [0u8; BLOCK_BYTES];
        let mut out = [0.0f32; BLOCK_SIZE];
        Q6KRef
            .dequant_block(&block[..], &mut out[..])
            .expect("zero block should succeed");
        for (i, &v) in out.iter().enumerate() {
            assert!(v.abs() < 1e-5, "Q6_K weight[{i}] = {v}, expected ~0.0");
        }
    }

    // ── Progress callback / metadata tests ───────────────────────────────────

    /// Stand-in for the progress-loading path: only checks that the
    /// underlying GGUF parse rejects empty bytes with an error.
    #[test]
    fn test_load_model_with_progress_empty_fails() {
        let empty: &[u8] = &[];
        let outcome = oxillama_gguf::GgufFile::parse(empty);
        assert!(outcome.is_err(), "empty bytes must fail GGUF parse");
    }

    /// Stand-in for `parse_gguf_metadata`: only checks that the underlying
    /// GGUF parse rejects empty bytes, which the wrapper would propagate.
    #[test]
    fn test_parse_gguf_metadata_empty_fails() {
        let empty: &[u8] = &[];
        let outcome = oxillama_gguf::GgufFile::parse(empty);
        assert!(outcome.is_err(), "empty bytes must fail metadata extraction");
    }
}