1use serde::{Deserialize, Serialize};
31use wasm_bindgen::prelude::*;
32
33pub mod gpu_bridge;
34pub mod idb_cache;
35pub mod service_worker;
36pub mod simd_check;
37pub mod streaming_load;
38pub mod streaming_loader;
39pub mod webgpu;
40pub mod worker;
41
42pub use service_worker::{
43 get_service_worker_script, register_service_worker, ServiceWorkerOptions,
44};
45pub use simd_check::get_simd128_status;
46pub use streaming_loader::StreamingGgufLoader;
47pub use streaming_loader::StreamingLoadOptions;
48
49#[wasm_bindgen(start)]
56pub fn init() {
57 #[cfg(feature = "console_error_panic_hook")]
58 console_error_panic_hook::set_once();
59}
60
61#[wasm_bindgen(js_name = parseGgufHeader)]
72pub fn parse_gguf_header(data: &[u8]) -> Result<JsValue, JsValue> {
73 let gguf = oxillama_gguf::GgufFile::parse(data)
74 .map_err(|e| JsValue::from_str(&format!("GGUF parse error: {e}")))?;
75
76 let obj = js_sys::Object::new();
77 js_sys::Reflect::set(
78 &obj,
79 &JsValue::from_str("tensorCount"),
80 &JsValue::from_f64(gguf.tensors.len() as f64),
81 )
82 .map_err(|e| JsValue::from_str(&format!("Reflect.set error: {e:?}")))?;
83 js_sys::Reflect::set(
84 &obj,
85 &JsValue::from_str("metadataCount"),
86 &JsValue::from_f64(gguf.metadata.len() as f64),
87 )
88 .map_err(|e| JsValue::from_str(&format!("Reflect.set error: {e:?}")))?;
89 js_sys::Reflect::set(
90 &obj,
91 &JsValue::from_str("version"),
92 &JsValue::from_f64(gguf.header.version as f64),
93 )
94 .map_err(|e| JsValue::from_str(&format!("Reflect.set error: {e:?}")))?;
95
96 Ok(JsValue::from(obj))
97}
98
99#[wasm_bindgen(js_name = listTensorNames)]
103pub fn list_tensor_names(data: &[u8]) -> Result<Vec<JsValue>, JsValue> {
104 let gguf = oxillama_gguf::GgufFile::parse(data)
105 .map_err(|e| JsValue::from_str(&format!("GGUF parse error: {e}")))?;
106
107 Ok(gguf
108 .tensors
109 .names()
110 .map(|name| JsValue::from_str(name))
111 .collect())
112}
113
114#[wasm_bindgen(js_name = dequantQ4_0)]
126pub fn dequant_q4_0(data: &[u8]) -> Result<Vec<f32>, JsValue> {
127 use oxillama_quant::reference::Q4_0Ref;
128 use oxillama_quant::traits::QuantKernel;
129
130 const BLOCK_BYTES: usize = 18;
131 const BLOCK_SIZE: usize = 32;
132
133 if !data.len().is_multiple_of(BLOCK_BYTES) {
134 return Err(JsValue::from_str(&format!(
135 "Q4_0 data length {} is not a multiple of {} bytes per block",
136 data.len(),
137 BLOCK_BYTES,
138 )));
139 }
140
141 let n_blocks = data.len() / BLOCK_BYTES;
142 let n_weights = n_blocks * BLOCK_SIZE;
143 let mut out = vec![0.0f32; n_weights];
144 let kernel = Q4_0Ref;
145
146 for (blk_idx, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
147 let output_slice = &mut out[blk_idx * BLOCK_SIZE..(blk_idx + 1) * BLOCK_SIZE];
148 kernel.dequant_block(block, output_slice).map_err(|e| {
149 JsValue::from_str(&format!("dequant_block error at block {blk_idx}: {e}"))
150 })?;
151 }
152
153 Ok(out)
154}
155
/// One-shot text generation: builds a fresh engine, loads the model, and
/// generates up to `max_tokens` tokens for `prompt`.
///
/// * `model_bytes` / `tokenizer_json` – passed straight to the engine's
///   `load_model_from_bytes`.
/// * `on_token` – optional JS function invoked with each token's text
///   (`this` is `null`). Its return value and any exception it throws are
///   deliberately ignored so a faulty callback cannot abort generation.
///
/// # Errors
///
/// Returns a string `JsValue` prefixed with `model load error:` or
/// `generation error:` depending on which stage failed.
#[cfg(feature = "inference")]
#[wasm_bindgen]
pub fn generate(
    model_bytes: &[u8],
    tokenizer_json: &str,
    prompt: &str,
    max_tokens: usize,
    on_token: Option<js_sys::Function>,
) -> Result<String, JsValue> {
    use oxillama_runtime::{EngineConfig, InferenceEngine};

    let mut engine = InferenceEngine::new(EngineConfig::default());
    if let Err(e) = engine.load_model_from_bytes(model_bytes, tokenizer_json) {
        return Err(JsValue::from_str(&format!("model load error: {e}")));
    }

    engine
        .generate(prompt, max_tokens, |token_text| {
            if let Some(callback) = on_token.as_ref() {
                // Best-effort notification: the call result is discarded on purpose.
                let _ = callback.call1(&JsValue::NULL, &JsValue::from_str(token_text));
            }
        })
        .map_err(|e| JsValue::from_str(&format!("generation error: {e}")))
}
221
222#[wasm_bindgen(js_name = dequantQ4K)]
231pub fn dequant_q4_k(data: &[u8]) -> Result<Vec<f32>, JsValue> {
232 use oxillama_quant::reference::Q4KRef;
233 use oxillama_quant::traits::QuantKernel;
234
235 const BLOCK_BYTES: usize = 144;
236 const BLOCK_SIZE: usize = 256;
237
238 if !data.len().is_multiple_of(BLOCK_BYTES) {
239 return Err(JsValue::from_str(&format!(
240 "Q4_K data length {} is not a multiple of {} bytes per block",
241 data.len(),
242 BLOCK_BYTES,
243 )));
244 }
245
246 let n_blocks = data.len() / BLOCK_BYTES;
247 let n_weights = n_blocks * BLOCK_SIZE;
248 let mut out = vec![0.0f32; n_weights];
249 let kernel = Q4KRef;
250
251 for (blk_idx, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
252 let output_slice = &mut out[blk_idx * BLOCK_SIZE..(blk_idx + 1) * BLOCK_SIZE];
253 kernel.dequant_block(block, output_slice).map_err(|e| {
254 JsValue::from_str(&format!("dequant_block error at block {blk_idx}: {e}"))
255 })?;
256 }
257
258 Ok(out)
259}
260
261#[wasm_bindgen(js_name = dequantQ5K)]
268pub fn dequant_q5_k(data: &[u8]) -> Result<Vec<f32>, JsValue> {
269 use oxillama_quant::reference::Q5KRef;
270 use oxillama_quant::traits::QuantKernel;
271
272 const BLOCK_BYTES: usize = 176;
273 const BLOCK_SIZE: usize = 256;
274
275 if !data.len().is_multiple_of(BLOCK_BYTES) {
276 return Err(JsValue::from_str(&format!(
277 "Q5_K data length {} is not a multiple of {} bytes per block",
278 data.len(),
279 BLOCK_BYTES,
280 )));
281 }
282
283 let n_blocks = data.len() / BLOCK_BYTES;
284 let n_weights = n_blocks * BLOCK_SIZE;
285 let mut out = vec![0.0f32; n_weights];
286 let kernel = Q5KRef;
287
288 for (blk_idx, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
289 let output_slice = &mut out[blk_idx * BLOCK_SIZE..(blk_idx + 1) * BLOCK_SIZE];
290 kernel.dequant_block(block, output_slice).map_err(|e| {
291 JsValue::from_str(&format!("dequant_block error at block {blk_idx}: {e}"))
292 })?;
293 }
294
295 Ok(out)
296}
297
298#[wasm_bindgen(js_name = dequantQ6K)]
305pub fn dequant_q6_k(data: &[u8]) -> Result<Vec<f32>, JsValue> {
306 use oxillama_quant::reference::Q6KRef;
307 use oxillama_quant::traits::QuantKernel;
308
309 const BLOCK_BYTES: usize = 210;
310 const BLOCK_SIZE: usize = 256;
311
312 if !data.len().is_multiple_of(BLOCK_BYTES) {
313 return Err(JsValue::from_str(&format!(
314 "Q6_K data length {} is not a multiple of {} bytes per block",
315 data.len(),
316 BLOCK_BYTES,
317 )));
318 }
319
320 let n_blocks = data.len() / BLOCK_BYTES;
321 let n_weights = n_blocks * BLOCK_SIZE;
322 let mut out = vec![0.0f32; n_weights];
323 let kernel = Q6KRef;
324
325 for (blk_idx, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
326 let output_slice = &mut out[blk_idx * BLOCK_SIZE..(blk_idx + 1) * BLOCK_SIZE];
327 kernel.dequant_block(block, output_slice).map_err(|e| {
328 JsValue::from_str(&format!("dequant_block error at block {blk_idx}: {e}"))
329 })?;
330 }
331
332 Ok(out)
333}
334
335#[derive(Debug, Serialize, Deserialize)]
341pub struct GgufMetadataJs {
342 pub version: u32,
343 pub tensor_count: u64,
344 pub kv_count: u64,
345 pub arch: Option<String>,
346 pub context_length: Option<u64>,
347 pub embedding_length: Option<u64>,
348 pub feed_forward_length: Option<u64>,
349 pub attention_head_count: Option<u64>,
350 pub block_count: Option<u64>,
351 pub quantization_version: Option<u32>,
352 pub general_name: Option<String>,
353 pub general_author: Option<String>,
354 pub general_description: Option<String>,
355}
356
/// Shared model-loading routine: creates an engine with the default
/// configuration, loads the model into it, and reports coarse progress.
///
/// `on_progress`, when given, is called with an integer percentage at three
/// fixed points: `0` before the engine is created, `25` after it is created,
/// and `100` once the model has loaded. Callback failures are ignored, and
/// `100` is never reported when loading fails.
///
/// # Errors
///
/// Returns a string `JsValue` prefixed with `model load error:` when the
/// engine rejects the model or tokenizer.
#[cfg(feature = "inference")]
fn load_model_core(
    model_bytes: &[u8],
    tokenizer_json: &str,
    on_progress: Option<&js_sys::Function>,
) -> Result<oxillama_runtime::InferenceEngine, JsValue> {
    use oxillama_runtime::{EngineConfig, InferenceEngine};

    let report = |percent: u32| {
        let Some(callback) = on_progress else {
            return;
        };
        // Best-effort: a throwing progress callback must not abort the load.
        let _ = callback.call1(&JsValue::UNDEFINED, &JsValue::from(percent));
    };

    report(0);
    let mut engine = InferenceEngine::new(EngineConfig::default());
    report(25);

    if let Err(e) = engine.load_model_from_bytes(model_bytes, tokenizer_json) {
        return Err(JsValue::from_str(&format!("model load error: {e}")));
    }
    report(100);

    Ok(engine)
}
386
/// Loads a model and returns a reusable [`WasmEngine`] handle, exported to
/// JavaScript as `loadModelFromBytesWithProgress`.
///
/// `on_progress` is an optional JS function; it is forwarded to
/// `load_model_core`, which invokes it with percentage values.
///
/// # Errors
///
/// Propagates the error produced by `load_model_core` unchanged.
#[cfg(feature = "inference")]
#[wasm_bindgen(js_name = loadModelFromBytesWithProgress)]
pub fn load_model_from_bytes_with_progress(
    model_bytes: &[u8],
    tokenizer_json: &str,
    on_progress: Option<js_sys::Function>,
) -> Result<WasmEngine, JsValue> {
    load_model_core(model_bytes, tokenizer_json, on_progress.as_ref())
        .map(|inner| WasmEngine { inner })
}
404
/// JS-visible handle around a loaded inference engine, so a model can be
/// loaded once and used for several `generate` calls.
#[cfg(feature = "inference")]
#[wasm_bindgen]
pub struct WasmEngine {
    /// The wrapped runtime engine; not exposed to JavaScript.
    inner: oxillama_runtime::InferenceEngine,
}

#[cfg(feature = "inference")]
#[wasm_bindgen]
impl WasmEngine {
    /// Generates up to `max_tokens` tokens for `prompt` on the wrapped engine.
    ///
    /// `on_token`, when given, is called with each token's text (`this` is
    /// `null`); its result and any exception it throws are ignored.
    ///
    /// # Errors
    ///
    /// Returns a string `JsValue` prefixed with `generation error:` when the
    /// engine reports a failure.
    pub fn generate(
        &mut self,
        prompt: &str,
        max_tokens: usize,
        on_token: Option<js_sys::Function>,
    ) -> Result<String, JsValue> {
        let engine = &mut self.inner;
        let outcome = engine.generate(prompt, max_tokens, |tok| {
            if let Some(callback) = on_token.as_ref() {
                // Best-effort notification: the call result is discarded on purpose.
                let _ = callback.call1(&JsValue::NULL, &JsValue::from_str(tok));
            }
        });

        outcome.map_err(|e| JsValue::from_str(&format!("generation error: {e}")))
    }
}
434
435#[wasm_bindgen(js_name = parseGgufMetadata)]
444pub fn parse_gguf_metadata(data: &[u8]) -> Result<JsValue, JsValue> {
445 let gguf = oxillama_gguf::GgufFile::parse(data)
446 .map_err(|e| JsValue::from_str(&format!("GGUF parse error: {e}")))?;
447
448 let meta = &gguf.metadata;
449
450 let arch: Option<String> = meta
452 .get("general.architecture")
453 .and_then(|v| v.as_str())
454 .map(|s| s.to_owned());
455
456 let get_u64 = |suffix: &str| -> Option<u64> {
459 let prefixes: &[&str] = match arch.as_deref() {
460 Some(a) => {
461 &[a, "llama", "mistral", "qwen3", "gemma", "phi"][..]
463 }
464 None => &["llama", "mistral", "qwen3", "gemma", "phi"][..],
465 };
466 for prefix in prefixes {
467 let key = format!("{prefix}.{suffix}");
468 if let Some(val) = meta.get(&key).and_then(|v| v.as_u64()) {
469 return Some(val);
470 }
471 }
472 None
473 };
474
475 let metadata_js = GgufMetadataJs {
476 version: gguf.header.version,
477 tensor_count: gguf.header.tensor_count,
478 kv_count: gguf.header.metadata_kv_count,
479 context_length: get_u64("context_length"),
480 embedding_length: get_u64("embedding_length"),
481 feed_forward_length: get_u64("feed_forward_length"),
482 attention_head_count: get_u64("attention.head_count"),
483 block_count: get_u64("block_count"),
484 quantization_version: meta
485 .get("general.quantization_version")
486 .and_then(|v| v.as_u32()),
487 general_name: meta
488 .get("general.name")
489 .and_then(|v| v.as_str())
490 .map(|s| s.to_owned()),
491 general_author: meta
492 .get("general.author")
493 .and_then(|v| v.as_str())
494 .map(|s| s.to_owned()),
495 general_description: meta
496 .get("general.description")
497 .and_then(|v| v.as_str())
498 .map(|s| s.to_owned()),
499 arch,
500 };
501
502 serde_wasm_bindgen::to_value(&metadata_js).map_err(|e| JsValue::from_str(&e.to_string()))
503}
504
#[cfg(test)]
mod tests {
    // These tests call the GGUF parser and the reference quant kernels
    // directly rather than the `#[wasm_bindgen]` wrappers in this file.
    // NOTE(review): presumably because `JsValue` is unusable on native test
    // targets — confirm.
    use oxillama_quant::reference::{Q4KRef, Q4_0Ref, Q5KRef, Q6KRef};
    use oxillama_quant::traits::QuantKernel;

    // ── GGUF parsing ────────────────────────────────────────────────────

    #[test]
    fn test_parse_gguf_empty_fails() {
        let empty: &[u8] = &[];
        assert!(
            oxillama_gguf::GgufFile::parse(empty).is_err(),
            "empty buffer should fail to parse"
        );
    }

    #[test]
    fn test_parse_gguf_bad_magic_fails() {
        // Sixteen bytes whose first four are not the GGUF magic.
        let bad = b"BAAD\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
        assert!(
            oxillama_gguf::GgufFile::parse(bad).is_err(),
            "wrong magic should fail to parse"
        );
    }

    // ── Q4_0 ────────────────────────────────────────────────────────────

    #[test]
    fn test_dequant_q4_0_wrong_length_fails() {
        const BLOCK_BYTES: usize = 18;
        let bad = [0u8; 17];
        assert_ne!(
            bad.len() % BLOCK_BYTES,
            0,
            "17 must not be a multiple of 18"
        );
        let mut out = [0.0f32; 32];
        let result = Q4_0Ref.dequant_block(&bad[..], &mut out[..]);
        assert!(result.is_err(), "incomplete block should fail");
    }

    #[test]
    fn test_dequant_q4_0_zero_block() {
        // Bytes 0..2 are 0x00, 0x3C; every remaining byte is 0x88.
        // NOTE(review): presumably an f16 scale of 1.0 followed by nibbles of
        // 8 that decode to zero — confirm against the Q4_0 reference kernel.
        let mut block = [0x88u8; 18];
        block[0] = 0x00;
        block[1] = 0x3C;

        let mut out = [0.0f32; 32];
        Q4_0Ref
            .dequant_block(&block[..], &mut out[..])
            .expect("should not fail on valid block");
        assert_eq!(out.len(), 32, "one block = 32 weights");
        out.iter().enumerate().for_each(|(i, &v)| {
            assert!(v.abs() < 1e-5, "weight[{i}] = {v}, expected ~0.0");
        });
    }

    #[test]
    fn test_dequant_q4_0_two_blocks_length() {
        const BLOCK_BYTES: usize = 18;
        const BLOCK_SIZE: usize = 32;
        let data = [0u8; 2 * BLOCK_BYTES];
        let mut out = vec![0.0f32; (data.len() / BLOCK_BYTES) * BLOCK_SIZE];

        let packed_blocks = data.chunks_exact(BLOCK_BYTES);
        let weight_windows = out.chunks_exact_mut(BLOCK_SIZE);
        for (packed, window) in packed_blocks.zip(weight_windows) {
            Q4_0Ref
                .dequant_block(packed, window)
                .expect("dequant_block should succeed on zeroed data");
        }
        assert_eq!(out.len(), 64, "two blocks = 64 weights");
    }

    // ── Q4_K ────────────────────────────────────────────────────────────

    #[test]
    fn test_dequant_q4_k_wrong_length_fails() {
        const BLOCK_BYTES: usize = 144;
        let bad = [0u8; 143];
        assert_ne!(bad.len() % BLOCK_BYTES, 0);
        let mut out = [0.0f32; 256];
        let result = Q4KRef.dequant_block(&bad[..], &mut out[..]);
        assert!(result.is_err(), "incomplete Q4_K block should fail");
    }

    #[test]
    fn test_dequant_q4_k_zero_block() {
        const BLOCK_BYTES: usize = 144;
        const BLOCK_SIZE: usize = 256;
        let block = [0u8; BLOCK_BYTES];
        let mut out = [0.0f32; BLOCK_SIZE];
        Q4KRef
            .dequant_block(&block[..], &mut out[..])
            .expect("zero block should succeed");
        out.iter().enumerate().for_each(|(i, &v)| {
            assert!(v.abs() < 1e-5, "Q4_K weight[{i}] = {v}, expected ~0.0");
        });
    }

    // ── Q5_K ────────────────────────────────────────────────────────────

    #[test]
    fn test_dequant_q5_k_wrong_length_fails() {
        const BLOCK_BYTES: usize = 176;
        let bad = [0u8; 175];
        assert_ne!(bad.len() % BLOCK_BYTES, 0);
        let mut out = [0.0f32; 256];
        let result = Q5KRef.dequant_block(&bad[..], &mut out[..]);
        assert!(result.is_err(), "incomplete Q5_K block should fail");
    }

    #[test]
    fn test_dequant_q5_k_zero_block() {
        const BLOCK_BYTES: usize = 176;
        const BLOCK_SIZE: usize = 256;
        let block = [0u8; BLOCK_BYTES];
        let mut out = [0.0f32; BLOCK_SIZE];
        Q5KRef
            .dequant_block(&block[..], &mut out[..])
            .expect("zero block should succeed");
        out.iter().enumerate().for_each(|(i, &v)| {
            assert!(v.abs() < 1e-5, "Q5_K weight[{i}] = {v}, expected ~0.0");
        });
    }

    // ── Q6_K ────────────────────────────────────────────────────────────

    #[test]
    fn test_dequant_q6_k_wrong_length_fails() {
        const BLOCK_BYTES: usize = 210;
        let bad = [0u8; 209];
        assert_ne!(bad.len() % BLOCK_BYTES, 0);
        let mut out = [0.0f32; 256];
        let result = Q6KRef.dequant_block(&bad[..], &mut out[..]);
        assert!(result.is_err(), "incomplete Q6_K block should fail");
    }

    #[test]
    fn test_dequant_q6_k_zero_block() {
        const BLOCK_BYTES: usize = 210;
        const BLOCK_SIZE: usize = 256;
        let block = [0u8; BLOCK_BYTES];
        let mut out = [0.0f32; BLOCK_SIZE];
        Q6KRef
            .dequant_block(&block[..], &mut out[..])
            .expect("zero block should succeed");
        out.iter().enumerate().for_each(|(i, &v)| {
            assert!(v.abs() < 1e-5, "Q6_K weight[{i}] = {v}, expected ~0.0");
        });
    }

    // ── Stand-ins for the JS-facing loaders ─────────────────────────────
    // Both tests below only check that parsing an empty buffer fails; they
    // do not call the model-loading or metadata wrappers themselves.

    #[test]
    fn test_load_model_with_progress_empty_fails() {
        let empty: &[u8] = &[];
        assert!(
            oxillama_gguf::GgufFile::parse(empty).is_err(),
            "empty bytes must fail GGUF parse"
        );
    }

    #[test]
    fn test_parse_gguf_metadata_empty_fails() {
        let empty: &[u8] = &[];
        assert!(
            oxillama_gguf::GgufFile::parse(empty).is_err(),
            "empty bytes must fail metadata extraction"
        );
    }
}
697}