use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
pub mod gpu_bridge;
pub mod idb_cache;
pub mod service_worker;
pub mod simd_check;
pub mod streaming_load;
pub mod streaming_loader;
pub mod webgpu;
pub mod worker;
pub use service_worker::{
get_service_worker_script, register_service_worker, ServiceWorkerOptions,
};
pub use simd_check::get_simd128_status;
pub use streaming_loader::StreamingGgufLoader;
pub use streaming_loader::StreamingLoadOptions;
/// Module start hook, registered through `#[wasm_bindgen(start)]`.
///
/// When the crate is built with the `console_error_panic_hook` feature this
/// calls `console_error_panic_hook::set_once()`; without the feature the
/// function body is empty.
#[wasm_bindgen(start)]
pub fn init() {
// The call below is only compiled in when the optional feature is enabled.
#[cfg(feature = "console_error_panic_hook")]
console_error_panic_hook::set_once();
}
#[wasm_bindgen(js_name = parseGgufHeader)]
pub fn parse_gguf_header(data: &[u8]) -> Result<JsValue, JsValue> {
let gguf = oxillama_gguf::GgufFile::parse(data)
.map_err(|e| JsValue::from_str(&format!("GGUF parse error: {e}")))?;
let obj = js_sys::Object::new();
js_sys::Reflect::set(
&obj,
&JsValue::from_str("tensorCount"),
&JsValue::from_f64(gguf.tensors.len() as f64),
)
.map_err(|e| JsValue::from_str(&format!("Reflect.set error: {e:?}")))?;
js_sys::Reflect::set(
&obj,
&JsValue::from_str("metadataCount"),
&JsValue::from_f64(gguf.metadata.len() as f64),
)
.map_err(|e| JsValue::from_str(&format!("Reflect.set error: {e:?}")))?;
js_sys::Reflect::set(
&obj,
&JsValue::from_str("version"),
&JsValue::from_f64(gguf.header.version as f64),
)
.map_err(|e| JsValue::from_str(&format!("Reflect.set error: {e:?}")))?;
Ok(JsValue::from(obj))
}
#[wasm_bindgen(js_name = listTensorNames)]
pub fn list_tensor_names(data: &[u8]) -> Result<Vec<JsValue>, JsValue> {
let gguf = oxillama_gguf::GgufFile::parse(data)
.map_err(|e| JsValue::from_str(&format!("GGUF parse error: {e}")))?;
Ok(gguf
.tensors
.names()
.map(|name| JsValue::from_str(name))
.collect())
}
/// Dequantizes a buffer of Q4_0 blocks into `f32` weights.
///
/// `data` must be made of whole 18-byte blocks. Each block is expanded to
/// 32 weights by the reference `Q4_0Ref` kernel, so the result contains
/// `data.len() / 18 * 32` values. An empty input yields an empty vector.
///
/// # Errors
///
/// Returns a JS string when `data.len()` is not a multiple of 18, or when
/// the kernel rejects a block (the message includes the block index).
#[wasm_bindgen(js_name = dequantQ4_0)]
pub fn dequant_q4_0(data: &[u8]) -> Result<Vec<f32>, JsValue> {
    use oxillama_quant::reference::Q4_0Ref;
    use oxillama_quant::traits::QuantKernel;

    const BLOCK_BYTES: usize = 18;
    const BLOCK_SIZE: usize = 32;

    if !data.len().is_multiple_of(BLOCK_BYTES) {
        let message = format!(
            "Q4_0 data length {} is not a multiple of {} bytes per block",
            data.len(),
            BLOCK_BYTES,
        );
        return Err(JsValue::from_str(&message));
    }

    let kernel = Q4_0Ref;
    let mut weights = vec![0.0f32; (data.len() / BLOCK_BYTES) * BLOCK_SIZE];

    // Walk input blocks and their matching output windows in lockstep.
    let blocks = data.chunks_exact(BLOCK_BYTES);
    let windows = weights.chunks_exact_mut(BLOCK_SIZE);
    for (blk_idx, (block, window)) in blocks.zip(windows).enumerate() {
        if let Err(e) = kernel.dequant_block(block, window) {
            return Err(JsValue::from_str(&format!(
                "dequant_block error at block {blk_idx}: {e}"
            )));
        }
    }

    Ok(weights)
}
/// One-shot text generation: builds a fresh engine, loads the model, and
/// generates up to `max_tokens` tokens for `prompt`.
///
/// Only available with the `inference` feature.
///
/// * `model_bytes` / `tokenizer_json` – passed to
///   `InferenceEngine::load_model_from_bytes`.
/// * `on_token` – optional JS function called with each piece of token text
///   (with `this` set to `null`). Whatever the callback returns or throws is
///   ignored.
///
/// # Errors
///
/// Returns a JS string prefixed with `"model load error: "` or
/// `"generation error: "` depending on which step failed.
#[cfg(feature = "inference")]
#[wasm_bindgen]
pub fn generate(
    model_bytes: &[u8],
    tokenizer_json: &str,
    prompt: &str,
    max_tokens: usize,
    on_token: Option<js_sys::Function>,
) -> Result<String, JsValue> {
    use oxillama_runtime::{EngineConfig, InferenceEngine};

    let mut engine = InferenceEngine::new(EngineConfig::default());
    if let Err(e) = engine.load_model_from_bytes(model_bytes, tokenizer_json) {
        return Err(JsValue::from_str(&format!("model load error: {e}")));
    }

    let callback = on_token.as_ref();
    engine
        .generate(prompt, max_tokens, |token_text| {
            if let Some(cb) = callback {
                // Best effort: the callback's outcome is deliberately discarded.
                let _ = cb.call1(&JsValue::NULL, &JsValue::from_str(token_text));
            }
        })
        .map_err(|e| JsValue::from_str(&format!("generation error: {e}")))
}
/// Dequantizes a buffer of Q4_K blocks into `f32` weights.
///
/// `data` must be made of whole 144-byte blocks. Each block is expanded to
/// 256 weights by the reference `Q4KRef` kernel, so the result contains
/// `data.len() / 144 * 256` values. An empty input yields an empty vector.
///
/// # Errors
///
/// Returns a JS string when `data.len()` is not a multiple of 144, or when
/// the kernel rejects a block (the message includes the block index).
#[wasm_bindgen(js_name = dequantQ4K)]
pub fn dequant_q4_k(data: &[u8]) -> Result<Vec<f32>, JsValue> {
    use oxillama_quant::reference::Q4KRef;
    use oxillama_quant::traits::QuantKernel;

    const BLOCK_BYTES: usize = 144;
    const BLOCK_SIZE: usize = 256;

    if !data.len().is_multiple_of(BLOCK_BYTES) {
        let message = format!(
            "Q4_K data length {} is not a multiple of {} bytes per block",
            data.len(),
            BLOCK_BYTES,
        );
        return Err(JsValue::from_str(&message));
    }

    let kernel = Q4KRef;
    let mut weights = vec![0.0f32; (data.len() / BLOCK_BYTES) * BLOCK_SIZE];

    // Walk input blocks and their matching output windows in lockstep.
    let blocks = data.chunks_exact(BLOCK_BYTES);
    let windows = weights.chunks_exact_mut(BLOCK_SIZE);
    for (blk_idx, (block, window)) in blocks.zip(windows).enumerate() {
        if let Err(e) = kernel.dequant_block(block, window) {
            return Err(JsValue::from_str(&format!(
                "dequant_block error at block {blk_idx}: {e}"
            )));
        }
    }

    Ok(weights)
}
/// Dequantizes a buffer of Q5_K blocks into `f32` weights.
///
/// `data` must be made of whole 176-byte blocks. Each block is expanded to
/// 256 weights by the reference `Q5KRef` kernel, so the result contains
/// `data.len() / 176 * 256` values. An empty input yields an empty vector.
///
/// # Errors
///
/// Returns a JS string when `data.len()` is not a multiple of 176, or when
/// the kernel rejects a block (the message includes the block index).
#[wasm_bindgen(js_name = dequantQ5K)]
pub fn dequant_q5_k(data: &[u8]) -> Result<Vec<f32>, JsValue> {
    use oxillama_quant::reference::Q5KRef;
    use oxillama_quant::traits::QuantKernel;

    const BLOCK_BYTES: usize = 176;
    const BLOCK_SIZE: usize = 256;

    if !data.len().is_multiple_of(BLOCK_BYTES) {
        let message = format!(
            "Q5_K data length {} is not a multiple of {} bytes per block",
            data.len(),
            BLOCK_BYTES,
        );
        return Err(JsValue::from_str(&message));
    }

    let kernel = Q5KRef;
    let mut weights = vec![0.0f32; (data.len() / BLOCK_BYTES) * BLOCK_SIZE];

    // Walk input blocks and their matching output windows in lockstep.
    let blocks = data.chunks_exact(BLOCK_BYTES);
    let windows = weights.chunks_exact_mut(BLOCK_SIZE);
    for (blk_idx, (block, window)) in blocks.zip(windows).enumerate() {
        if let Err(e) = kernel.dequant_block(block, window) {
            return Err(JsValue::from_str(&format!(
                "dequant_block error at block {blk_idx}: {e}"
            )));
        }
    }

    Ok(weights)
}
/// Dequantizes a buffer of Q6_K blocks into `f32` weights.
///
/// `data` must be made of whole 210-byte blocks. Each block is expanded to
/// 256 weights by the reference `Q6KRef` kernel, so the result contains
/// `data.len() / 210 * 256` values. An empty input yields an empty vector.
///
/// # Errors
///
/// Returns a JS string when `data.len()` is not a multiple of 210, or when
/// the kernel rejects a block (the message includes the block index).
#[wasm_bindgen(js_name = dequantQ6K)]
pub fn dequant_q6_k(data: &[u8]) -> Result<Vec<f32>, JsValue> {
    use oxillama_quant::reference::Q6KRef;
    use oxillama_quant::traits::QuantKernel;

    const BLOCK_BYTES: usize = 210;
    const BLOCK_SIZE: usize = 256;

    if !data.len().is_multiple_of(BLOCK_BYTES) {
        let message = format!(
            "Q6_K data length {} is not a multiple of {} bytes per block",
            data.len(),
            BLOCK_BYTES,
        );
        return Err(JsValue::from_str(&message));
    }

    let kernel = Q6KRef;
    let mut weights = vec![0.0f32; (data.len() / BLOCK_BYTES) * BLOCK_SIZE];

    // Walk input blocks and their matching output windows in lockstep.
    let blocks = data.chunks_exact(BLOCK_BYTES);
    let windows = weights.chunks_exact_mut(BLOCK_SIZE);
    for (blk_idx, (block, window)) in blocks.zip(windows).enumerate() {
        if let Err(e) = kernel.dequant_block(block, window) {
            return Err(JsValue::from_str(&format!(
                "dequant_block error at block {blk_idx}: {e}"
            )));
        }
    }

    Ok(weights)
}
/// Serializable summary of a GGUF file, produced by `parse_gguf_metadata`
/// and handed to JavaScript through `serde_wasm_bindgen::to_value`.
///
/// No serde rename attributes are applied, so the serialized keys are the
/// Rust field names as written here. Every `Option` field is `None` when the
/// corresponding metadata key is absent or does not have the expected type.
#[derive(Debug, Serialize, Deserialize)]
pub struct GgufMetadataJs {
/// GGUF format version from the file header.
pub version: u32,
/// Tensor count from the file header.
pub tensor_count: u64,
/// Metadata key/value count from the file header.
pub kv_count: u64,
/// Value of the `general.architecture` key.
pub arch: Option<String>,
/// `<prefix>.context_length`, looked up under the architecture name first
/// and then under a fixed list of fallback prefixes.
pub context_length: Option<u64>,
/// `<prefix>.embedding_length`, same prefix lookup as `context_length`.
pub embedding_length: Option<u64>,
/// `<prefix>.feed_forward_length`, same prefix lookup as `context_length`.
pub feed_forward_length: Option<u64>,
/// `<prefix>.attention.head_count`, same prefix lookup as `context_length`.
pub attention_head_count: Option<u64>,
/// `<prefix>.block_count`, same prefix lookup as `context_length`.
pub block_count: Option<u64>,
/// Value of the `general.quantization_version` key.
pub quantization_version: Option<u32>,
/// Value of the `general.name` key.
pub general_name: Option<String>,
/// Value of the `general.author` key.
pub general_author: Option<String>,
/// Value of the `general.description` key.
pub general_description: Option<String>,
}
/// Builds an `InferenceEngine` with the default configuration and loads a
/// model into it, reporting coarse progress along the way.
///
/// `on_progress`, when present, is called with the integer percentages
/// `0` (before the engine is created), `25` (after creation, before loading)
/// and `100` (after a successful load). `this` is `undefined` for each call
/// and anything the callback returns or throws is ignored. On a load failure
/// the `100` notification is not sent.
///
/// # Errors
///
/// Returns a JS string prefixed with `"model load error: "` when
/// `load_model_from_bytes` fails.
#[cfg(feature = "inference")]
fn load_model_core(
    model_bytes: &[u8],
    tokenizer_json: &str,
    on_progress: Option<&js_sys::Function>,
) -> Result<oxillama_runtime::InferenceEngine, JsValue> {
    use oxillama_runtime::{EngineConfig, InferenceEngine};

    /// Best-effort progress notification; the callback's outcome is discarded.
    fn report(callback: Option<&js_sys::Function>, pct: u32) {
        if let Some(cb) = callback {
            let _ = cb.call1(&JsValue::UNDEFINED, &JsValue::from(pct));
        }
    }

    report(on_progress, 0);
    let mut engine = InferenceEngine::new(EngineConfig::default());
    report(on_progress, 25);

    if let Err(e) = engine.load_model_from_bytes(model_bytes, tokenizer_json) {
        return Err(JsValue::from_str(&format!("model load error: {e}")));
    }

    report(on_progress, 100);
    Ok(engine)
}
/// Loads a model from raw bytes and returns a reusable [`WasmEngine`].
///
/// Exposed to JavaScript as `loadModelFromBytesWithProgress`; only available
/// with the `inference` feature. The optional `on_progress` function is
/// forwarded to `load_model_core`, which invokes it with progress
/// percentages.
///
/// # Errors
///
/// Propagates the error returned by `load_model_core` unchanged.
#[cfg(feature = "inference")]
#[wasm_bindgen(js_name = loadModelFromBytesWithProgress)]
pub fn load_model_from_bytes_with_progress(
    model_bytes: &[u8],
    tokenizer_json: &str,
    on_progress: Option<js_sys::Function>,
) -> Result<WasmEngine, JsValue> {
    load_model_core(model_bytes, tokenizer_json, on_progress.as_ref())
        .map(|inner| WasmEngine { inner })
}
/// JavaScript-visible handle around an `oxillama_runtime::InferenceEngine`.
///
/// Only compiled with the `inference` feature. In this file instances are
/// created by `load_model_from_bytes_with_progress`.
#[cfg(feature = "inference")]
#[wasm_bindgen]
pub struct WasmEngine {
// The wrapped engine; private, so it is only reachable through the methods
// of this type.
inner: oxillama_runtime::InferenceEngine,
}
#[cfg(feature = "inference")]
#[wasm_bindgen]
impl WasmEngine {
    /// Generates up to `max_tokens` tokens for `prompt` using the wrapped
    /// engine and returns the resulting text.
    ///
    /// `on_token`, when present, is called with each piece of token text
    /// (with `this` set to `null`); whatever it returns or throws is ignored.
    ///
    /// # Errors
    ///
    /// Returns a JS string prefixed with `"generation error: "` when the
    /// engine reports a failure.
    pub fn generate(
        &mut self,
        prompt: &str,
        max_tokens: usize,
        on_token: Option<js_sys::Function>,
    ) -> Result<String, JsValue> {
        let callback = on_token.as_ref();
        let outcome = self.inner.generate(prompt, max_tokens, |tok| {
            if let Some(cb) = callback {
                // Best effort: the callback's outcome is deliberately discarded.
                let _ = cb.call1(&JsValue::NULL, &JsValue::from_str(tok));
            }
        });

        match outcome {
            Ok(text) => Ok(text),
            Err(e) => Err(JsValue::from_str(&format!("generation error: {e}"))),
        }
    }
}
/// Parses a GGUF buffer and returns a [`GgufMetadataJs`] summary converted to
/// a JS value with `serde_wasm_bindgen`.
///
/// Header fields are copied directly. `general.*` strings are read by exact
/// key. Architecture-scoped integers (`context_length`, `embedding_length`,
/// `feed_forward_length`, `attention.head_count`, `block_count`) are looked
/// up as `<prefix>.<suffix>`, trying the file's own `general.architecture`
/// value first (when present) and then `llama`, `mistral`, `qwen3`, `gemma`,
/// `phi` in that order; the first key holding a `u64` wins.
///
/// # Errors
///
/// Returns a JS string when the buffer cannot be parsed as GGUF
/// (`"GGUF parse error: …"`) or when serialization to a JS value fails.
#[wasm_bindgen(js_name = parseGgufMetadata)]
pub fn parse_gguf_metadata(data: &[u8]) -> Result<JsValue, JsValue> {
    // Prefixes tried after the file's own architecture name.
    const FALLBACK_ARCHS: [&str; 5] = ["llama", "mistral", "qwen3", "gemma", "phi"];

    let gguf = oxillama_gguf::GgufFile::parse(data)
        .map_err(|e| JsValue::from_str(&format!("GGUF parse error: {e}")))?;
    let meta = &gguf.metadata;

    // Owned copy of a string-valued metadata entry, if present.
    let text = |key: &str| -> Option<String> {
        meta.get(key).and_then(|v| v.as_str()).map(|s| s.to_owned())
    };

    let arch = text("general.architecture");

    // First `<prefix>.<suffix>` entry that holds a u64, in prefix order.
    let arch_u64 = |suffix: &str| -> Option<u64> {
        arch.as_deref()
            .into_iter()
            .chain(FALLBACK_ARCHS)
            .find_map(|prefix| {
                let key = format!("{prefix}.{suffix}");
                meta.get(&key).and_then(|v| v.as_u64())
            })
    };

    let metadata_js = GgufMetadataJs {
        version: gguf.header.version,
        tensor_count: gguf.header.tensor_count,
        kv_count: gguf.header.metadata_kv_count,
        context_length: arch_u64("context_length"),
        embedding_length: arch_u64("embedding_length"),
        feed_forward_length: arch_u64("feed_forward_length"),
        attention_head_count: arch_u64("attention.head_count"),
        block_count: arch_u64("block_count"),
        quantization_version: meta
            .get("general.quantization_version")
            .and_then(|v| v.as_u32()),
        general_name: text("general.name"),
        general_author: text("general.author"),
        general_description: text("general.description"),
        // Moved last: the lookups above still borrow `arch`.
        arch,
    };

    serde_wasm_bindgen::to_value(&metadata_js).map_err(|e| JsValue::from_str(&e.to_string()))
}
// Unit tests for the parsing and dequantization paths.
//
// NOTE(review): every test here calls `oxillama_gguf::GgufFile::parse` or a
// reference kernel's `dequant_block` directly; none of them call the
// `#[wasm_bindgen]` wrapper functions defined in this file. Presumably that is
// because the wrappers build `JsValue` errors, which may not be usable outside
// a wasm target — confirm before relying on these as wrapper coverage.
#[cfg(test)]
mod tests {
use oxillama_quant::reference::Q4_0Ref;
use oxillama_quant::traits::QuantKernel;
// An empty buffer must be rejected by the GGUF parser.
#[test]
fn test_parse_gguf_empty_fails() {
let result = oxillama_gguf::GgufFile::parse(&[]);
assert!(result.is_err(), "empty buffer should fail to parse");
}
// A 16-byte buffer starting with "BAAD" must be rejected by the GGUF parser.
#[test]
fn test_parse_gguf_bad_magic_fails() {
let bad = b"BAAD\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
let result = oxillama_gguf::GgufFile::parse(bad);
assert!(result.is_err(), "wrong magic should fail to parse");
}
// The Q4_0 kernel must reject a 17-byte input (one byte short of the 18-byte
// block size used by `dequant_q4_0`).
#[test]
fn test_dequant_q4_0_wrong_length_fails() {
const BLOCK_BYTES: usize = 18;
let bad = vec![0u8; 17];
assert_ne!(
bad.len() % BLOCK_BYTES,
0,
"17 must not be a multiple of 18"
);
let kernel = Q4_0Ref;
let mut out = vec![0.0f32; 32];
let result = kernel.dequant_block(&bad, &mut out);
assert!(result.is_err(), "incomplete block should fail");
}
// A hand-built Q4_0 block that is expected to decode to 32 near-zero weights.
#[test]
fn test_dequant_q4_0_zero_block() {
let mut block = vec![0u8; 18];
// Bytes 0..2 are 0x00, 0x3C — presumably a little-endian f16 scale of 1.0;
// TODO confirm against the Q4_0 block layout.
block[0] = 0x00;
block[1] = 0x3C;
// Every remaining byte is 0x88, i.e. both 4-bit halves are 8; the
// assertions below expect that to decode to ~0.0.
for b in block[2..].iter_mut() {
*b = 0x88;
}
let kernel = Q4_0Ref;
let mut out = vec![0.0f32; 32];
kernel
.dequant_block(&block, &mut out)
.expect("should not fail on valid block");
assert_eq!(out.len(), 32, "one block = 32 weights");
for (i, &v) in out.iter().enumerate() {
assert!(v.abs() < 1e-5, "weight[{i}] = {v}, expected ~0.0");
}
}
// Repeats the block-by-block loop used by `dequant_q4_0` on two zeroed blocks
// and checks that it succeeds and fills a 64-element output.
#[test]
fn test_dequant_q4_0_two_blocks_length() {
const BLOCK_BYTES: usize = 18;
const BLOCK_SIZE: usize = 32;
let data = [0u8; 2 * BLOCK_BYTES];
let kernel = Q4_0Ref;
let n_blocks = data.len() / BLOCK_BYTES;
let mut out = vec![0.0f32; n_blocks * BLOCK_SIZE];
for (blk_idx, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
let slice = &mut out[blk_idx * BLOCK_SIZE..(blk_idx + 1) * BLOCK_SIZE];
kernel
.dequant_block(block, slice)
.expect("dequant_block should succeed on zeroed data");
}
assert_eq!(out.len(), 64, "two blocks = 64 weights");
}
// The Q4_K kernel must reject a 143-byte input (one byte short of 144).
#[test]
fn test_dequant_q4_k_wrong_length_fails() {
use oxillama_quant::reference::Q4KRef;
const BLOCK_BYTES: usize = 144;
let bad = vec![0u8; 143];
assert_ne!(bad.len() % BLOCK_BYTES, 0);
let kernel = Q4KRef;
let mut out = vec![0.0f32; 256];
let result = kernel.dequant_block(&bad, &mut out);
assert!(result.is_err(), "incomplete Q4_K block should fail");
}
// An all-zero Q4_K block must decode to 256 near-zero weights.
#[test]
fn test_dequant_q4_k_zero_block() {
use oxillama_quant::reference::Q4KRef;
const BLOCK_BYTES: usize = 144;
const BLOCK_SIZE: usize = 256;
let block = vec![0u8; BLOCK_BYTES];
let kernel = Q4KRef;
let mut out = vec![0.0f32; BLOCK_SIZE];
kernel
.dequant_block(&block, &mut out)
.expect("zero block should succeed");
for (i, &v) in out.iter().enumerate() {
assert!(v.abs() < 1e-5, "Q4_K weight[{i}] = {v}, expected ~0.0");
}
}
// The Q5_K kernel must reject a 175-byte input (one byte short of 176).
#[test]
fn test_dequant_q5_k_wrong_length_fails() {
use oxillama_quant::reference::Q5KRef;
const BLOCK_BYTES: usize = 176;
let bad = vec![0u8; 175];
assert_ne!(bad.len() % BLOCK_BYTES, 0);
let kernel = Q5KRef;
let mut out = vec![0.0f32; 256];
let result = kernel.dequant_block(&bad, &mut out);
assert!(result.is_err(), "incomplete Q5_K block should fail");
}
// An all-zero Q5_K block must decode to 256 near-zero weights.
#[test]
fn test_dequant_q5_k_zero_block() {
use oxillama_quant::reference::Q5KRef;
const BLOCK_BYTES: usize = 176;
const BLOCK_SIZE: usize = 256;
let block = vec![0u8; BLOCK_BYTES];
let kernel = Q5KRef;
let mut out = vec![0.0f32; BLOCK_SIZE];
kernel
.dequant_block(&block, &mut out)
.expect("zero block should succeed");
for (i, &v) in out.iter().enumerate() {
assert!(v.abs() < 1e-5, "Q5_K weight[{i}] = {v}, expected ~0.0");
}
}
// The Q6_K kernel must reject a 209-byte input (one byte short of 210).
#[test]
fn test_dequant_q6_k_wrong_length_fails() {
use oxillama_quant::reference::Q6KRef;
const BLOCK_BYTES: usize = 210;
let bad = vec![0u8; 209];
assert_ne!(bad.len() % BLOCK_BYTES, 0);
let kernel = Q6KRef;
let mut out = vec![0.0f32; 256];
let result = kernel.dequant_block(&bad, &mut out);
assert!(result.is_err(), "incomplete Q6_K block should fail");
}
// An all-zero Q6_K block must decode to 256 near-zero weights.
#[test]
fn test_dequant_q6_k_zero_block() {
use oxillama_quant::reference::Q6KRef;
const BLOCK_BYTES: usize = 210;
const BLOCK_SIZE: usize = 256;
let block = vec![0u8; BLOCK_BYTES];
let kernel = Q6KRef;
let mut out = vec![0.0f32; BLOCK_SIZE];
kernel
.dequant_block(&block, &mut out)
.expect("zero block should succeed");
for (i, &v) in out.iter().enumerate() {
assert!(v.abs() < 1e-5, "Q6_K weight[{i}] = {v}, expected ~0.0");
}
}
// NOTE(review): despite its name, this test does not call any model-loading
// function; it only repeats the empty-buffer parse check from
// `test_parse_gguf_empty_fails`.
#[test]
fn test_load_model_with_progress_empty_fails() {
let result = oxillama_gguf::GgufFile::parse(&[]);
assert!(result.is_err(), "empty bytes must fail GGUF parse");
}
// NOTE(review): despite its name, this test does not call
// `parse_gguf_metadata`; it only repeats the empty-buffer parse check from
// `test_parse_gguf_empty_fails`.
#[test]
fn test_parse_gguf_metadata_empty_fails() {
let result = oxillama_gguf::GgufFile::parse(&[]);
assert!(result.is_err(), "empty bytes must fail metadata extraction");
}
}