use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use js_sys::Uint8Array;
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
/// Location and layout of one tensor, taken from the GGUF tensor-info section.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TensorMeta {
    /// Absolute byte offset of the tensor data from the start of the file
    /// (data-section start plus the tensor's relative offset).
    pub file_offset: u64,
    /// Length of the tensor data in bytes.
    pub size_bytes: u64,
    /// `Debug` rendering of the GGUF tensor type, e.g. `"F32"`.
    pub dtype: String,
    /// Tensor dimensions as recorded in the file.
    pub shape: Vec<u64>,
}
/// Bounded least-recently-used cache mapping tensor names to their raw bytes.
///
/// Payloads are stored behind [`Arc`], so a hit hands out a cheap shared handle
/// instead of copying the tensor. A `capacity` of `0` disables eviction, and the
/// cache then grows without bound.
pub struct LruTensorCache {
    /// Cached payloads keyed by tensor name.
    data: HashMap<String, Arc<Vec<u8>>>,
    /// Recency queue: the front is the least recently used name, the back the most recent.
    order: VecDeque<String>,
    /// Maximum number of entries; `0` means unlimited.
    capacity: usize,
}

impl LruTensorCache {
    /// Creates an empty cache holding at most `capacity` entries (`0` = unlimited).
    pub fn new(capacity: usize) -> Self {
        let data = HashMap::new();
        let order = VecDeque::new();
        Self {
            data,
            order,
            capacity,
        }
    }

    /// Looks up `name`. On a hit the entry becomes the most recently used one.
    pub fn get(&mut self, name: &str) -> Option<Arc<Vec<u8>>> {
        let hit = Arc::clone(self.data.get(name)?);
        // Move the name to the most-recently-used end of the queue.
        self.order.retain(|queued| queued != name);
        self.order.push_back(name.to_owned());
        Some(hit)
    }

    /// Inserts (or replaces) `name`, evicting least recently used entries as
    /// needed so the cache stays within `capacity`.
    pub fn put(&mut self, name: String, bytes: Vec<u8>) {
        // Drop any previous entry first so it neither counts toward capacity
        // nor leaves a stale position in the recency queue.
        if self.data.remove(&name).is_some() {
            self.order.retain(|queued| queued != &name);
        }
        if self.capacity > 0 {
            while self.data.len() >= self.capacity {
                match self.order.pop_front() {
                    Some(victim) => {
                        self.data.remove(&victim);
                    }
                    None => break,
                }
            }
        }
        self.order.push_back(name.clone());
        self.data.insert(name, Arc::new(bytes));
    }

    /// Number of tensors currently cached.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
/// Progress of GGUF header parsing for a streaming loader.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadPhase {
    /// Not enough bytes have been buffered to parse the header yet.
    WaitingForHeader,
    /// The header was parsed and the tensor index is populated.
    HeaderParsed,
}
/// Mutable loader state shared behind the `StreamingGgufLoader` mutex.
pub struct LoaderInner {
    /// Every byte pushed so far, in push order; parsed as the file prefix.
    bytes_buffer: Vec<u8>,
    /// Tensor name -> location/layout. Populated only once the header fully parses.
    tensor_index: HashMap<String, TensorMeta>,
    /// Recently read tensor payloads.
    cache: LruTensorCache,
    /// Whether the header has been parsed yet.
    phase: LoadPhase,
}

impl LoaderInner {
    /// Creates empty state in `WaitingForHeader` with a tensor cache of
    /// `cache_capacity` entries.
    fn new(cache_capacity: usize) -> Self {
        Self {
            bytes_buffer: Vec::new(),
            tensor_index: HashMap::new(),
            cache: LruTensorCache::new(cache_capacity),
            phase: LoadPhase::WaitingForHeader,
        }
    }

    /// Attempts to parse the GGUF header from the bytes buffered so far.
    ///
    /// Returns `Ok(true)` once the header has been parsed (and on every later
    /// call), and `Ok(false)` when the parser reports it needs more bytes.
    ///
    /// # Errors
    /// Returns a message for malformed headers, for unparsable tensor-info
    /// records, and for tensor offsets that overflow `u64`.
    ///
    /// The tensor index is replaced only after every tensor-info record parses.
    /// A failed attempt therefore never leaves a partial index behind.
    fn try_parse_header(&mut self) -> Result<bool, String> {
        if self.phase == LoadPhase::HeaderParsed {
            return Ok(true);
        }
        // GGUF v3 fixed prefix: magic (4) + version (4) + tensor count (8) + kv count (8).
        if self.bytes_buffer.len() < 24 {
            return Ok(false);
        }
        let parser = match oxillama_gguf::StreamingGgufParser::new(&self.bytes_buffer) {
            Ok(parser) => parser,
            Err(oxillama_gguf::GgufError::UnexpectedEof { .. }) => return Ok(false),
            Err(other) => return Err(format!("GGUF parse error: {other}")),
        };
        let data_offset = parser.tensor_infos().data_offset();

        // Stage entries locally so an error part-way through cannot leak a
        // half-built index to `tensor_names()` / `tensor_meta_json()`.
        let mut staged = HashMap::new();
        for result in parser.tensor_infos() {
            // NOTE(review): if the tensor-info section itself can be truncated
            // in the buffer, an EOF here is reported as an error rather than
            // "need more bytes" — confirm the iterator's error semantics.
            let info = result.map_err(|e| format!("tensor info parse error: {e}"))?;
            // Offsets come from the (untrusted) file; wrapping would silently
            // point reads at the wrong bytes in release builds.
            let file_offset = data_offset.checked_add(info.offset).ok_or_else(|| {
                format!("tensor '{}' data offset overflows u64", info.name)
            })?;
            let meta = TensorMeta {
                file_offset,
                size_bytes: info.data_size(),
                dtype: format!("{:?}", info.tensor_type),
                shape: info.dimensions.clone(),
            };
            staged.insert(info.name, meta);
        }

        self.tensor_index = staged;
        self.phase = LoadPhase::HeaderParsed;
        Ok(true)
    }
}
/// JS-facing loader. It incrementally buffers a GGUF file and parses its header.
/// Individual tensors are then read on demand through a caller-supplied
/// byte-range fetcher, and recently read tensors are cached.
#[wasm_bindgen]
pub struct StreamingGgufLoader {
    inner: Arc<Mutex<LoaderInner>>,
}

impl StreamingGgufLoader {
    /// Locks the shared state and maps a poisoned lock to a JS error.
    ///
    /// This lives in a plain `impl` block so `#[wasm_bindgen]` does not try to
    /// export it.
    fn lock_inner(&self) -> Result<std::sync::MutexGuard<'_, LoaderInner>, JsValue> {
        self.inner
            .lock()
            .map_err(|e| JsValue::from_str(&format!("lock poisoned: {e}")))
    }
}

#[wasm_bindgen]
impl StreamingGgufLoader {
    /// Creates a loader whose tensor cache holds `cache_capacity` entries.
    /// The capacity defaults to 8 when omitted.
    #[wasm_bindgen(constructor)]
    pub fn new(cache_capacity: Option<usize>) -> Self {
        Self {
            inner: Arc::new(Mutex::new(LoaderInner::new(cache_capacity.unwrap_or(8)))),
        }
    }

    /// Appends `chunk` to the internal buffer and tries to parse the header.
    ///
    /// Returns `true` once the header (and tensor index) is available. Chunks
    /// pushed after that point are still appended to the buffer.
    ///
    /// # Errors
    /// Fails if the lock is poisoned or the buffered bytes are not a valid header.
    pub fn push_chunk(&mut self, chunk: &[u8]) -> Result<bool, JsValue> {
        let mut guard = self.lock_inner()?;
        guard.bytes_buffer.extend_from_slice(chunk);
        guard.try_parse_header().map_err(|e| JsValue::from_str(&e))
    }

    /// `true` once the header has been parsed; `false` if the lock is poisoned.
    pub fn is_header_ready(&self) -> bool {
        self.inner
            .lock()
            .map(|g| g.phase == LoadPhase::HeaderParsed)
            .unwrap_or(false)
    }

    /// Number of bytes buffered so far. The value saturates at `u32::MAX`
    /// instead of wrapping, and is 0 if the lock is poisoned.
    pub fn bytes_buffered(&self) -> u32 {
        self.inner
            .lock()
            .map(|g| g.bytes_buffer.len().min(u32::MAX as usize) as u32)
            .unwrap_or(0)
    }

    /// Names of all indexed tensors, sorted so the order is deterministic.
    ///
    /// The array is empty until the header has been parsed.
    pub fn tensor_names(&self) -> Result<js_sys::Array, JsValue> {
        let guard = self.lock_inner()?;
        let mut names: Vec<&String> = guard.tensor_index.keys().collect();
        names.sort_unstable();
        let arr = js_sys::Array::new();
        for name in names {
            arr.push(&JsValue::from_str(name));
        }
        Ok(arr)
    }

    /// JSON object with `file_offset`, `size_bytes`, `dtype` and `shape` for
    /// `name`. Returns `None` if the tensor is unknown or the lock is poisoned.
    pub fn tensor_meta_json(&self, name: &str) -> Option<String> {
        let guard = self.inner.lock().ok()?;
        let meta = guard.tensor_index.get(name)?;
        let json = serde_json::json!({
            "file_offset": meta.file_offset,
            "size_bytes": meta.size_bytes,
            "dtype": meta.dtype,
            "shape": meta.shape,
        });
        serde_json::to_string(&json).ok()
    }

    /// Returns the bytes of tensor `name` from the cache when present.
    ///
    /// On a cache miss it calls `fetcher(offset, size)` and caches the result.
    ///
    /// # Errors
    /// Fails if the header is not parsed yet, the tensor is unknown, the lock
    /// is poisoned, or the fetcher throws/rejects.
    pub async fn read_tensor(
        &self,
        name: &str,
        fetcher: &js_sys::Function,
    ) -> Result<Uint8Array, JsValue> {
        // One short critical section. The guard must be released before the
        // fetch is awaited.
        let (file_offset, size_bytes) = {
            let mut guard = self.lock_inner()?;
            if guard.phase != LoadPhase::HeaderParsed {
                return Err(JsValue::from_str(
                    "header not yet parsed — call push_chunk until is_header_ready() returns true",
                ));
            }
            if let Some(cached) = guard.cache.get(name) {
                return Ok(Uint8Array::from(cached.as_slice()));
            }
            let meta = guard
                .tensor_index
                .get(name)
                .ok_or_else(|| JsValue::from_str(&format!("tensor '{name}' not found in index")))?;
            (meta.file_offset, meta.size_bytes)
        };

        let bytes = call_byte_range_fetcher(fetcher, file_offset, size_bytes).await?;
        // Build the JS copy first so the Vec can be moved into the cache rather
        // than deep-cloned (tensors can be very large).
        let out = Uint8Array::from(bytes.as_slice());
        self.lock_inner()?.cache.put(name.to_owned(), bytes);
        Ok(out)
    }

    /// Snapshot of loader progress with the keys `bytes_buffered`, `phase`,
    /// `tensor_count` and `cache_size`.
    pub fn progress(&self) -> Result<js_sys::Object, JsValue> {
        let guard = self.lock_inner()?;
        let phase_str = match guard.phase {
            LoadPhase::WaitingForHeader => "waiting_for_header",
            LoadPhase::HeaderParsed => "header_parsed",
        };
        let obj = js_sys::Object::new();
        set_js_prop(
            &obj,
            "bytes_buffered",
            &JsValue::from(guard.bytes_buffer.len() as f64),
        )?;
        set_js_prop(&obj, "phase", &JsValue::from_str(phase_str))?;
        set_js_prop(
            &obj,
            "tensor_count",
            &JsValue::from(guard.tensor_index.len() as f64),
        )?;
        set_js_prop(&obj, "cache_size", &JsValue::from(guard.cache.len() as f64))?;
        Ok(obj)
    }
}
/// Options bag for streaming loads, exposed to JS.
///
/// NOTE(review): nothing visible in this file consumes this type.
/// `StreamingGgufLoader::new` takes its cache capacity directly. Confirm
/// whether callers elsewhere rely on it.
#[wasm_bindgen]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamingLoadOptions {
    /// Whether progress reporting is requested (defaults to `true`).
    pub progress_enabled: bool,
    /// Maximum number of cached tensors (defaults to `8`).
    pub cache_capacity: usize,
}

#[wasm_bindgen]
impl StreamingLoadOptions {
    /// Creates options with progress enabled and a cache capacity of 8.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self {
            progress_enabled: true,
            cache_capacity: 8,
        }
    }
}

impl Default for StreamingLoadOptions {
    /// Same as [`StreamingLoadOptions::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Sets `obj[key] = value` via `Reflect.set`.
///
/// A thrown JS exception is turned into a string error that names the key.
/// The boolean returned by `Reflect.set` is ignored.
fn set_js_prop(obj: &js_sys::Object, key: &str, value: &JsValue) -> Result<(), JsValue> {
    let js_key = JsValue::from_str(key);
    match js_sys::Reflect::set(obj, &js_key, value) {
        Ok(_) => Ok(()),
        Err(e) => Err(JsValue::from_str(&format!(
            "Reflect.set({key}) failed: {e:?}"
        ))),
    }
}
/// Invokes the JS `fetcher(offset, size)` and returns the bytes it yields.
///
/// The fetcher may return a `Promise` or a plain value. The settled value is
/// wrapped with `new Uint8Array(...)`, so an `ArrayBuffer` or `Uint8Array`
/// both work. `offset` and `size` are passed as JS numbers, so values above
/// 2^53 lose precision.
///
/// # Errors
/// Propagates anything the fetcher throws or rejects with. Also fails when the
/// number of returned bytes differs from `size` (e.g. a server that ignored an
/// HTTP `Range` header), so wrong bytes are never handed back or cached.
async fn call_byte_range_fetcher(
    fetcher: &js_sys::Function,
    offset: u64,
    size: u64,
) -> Result<Vec<u8>, JsValue> {
    let js_offset = JsValue::from_f64(offset as f64);
    let js_size = JsValue::from_f64(size as f64);
    let returned = fetcher.call2(&JsValue::NULL, &js_offset, &js_size)?;
    // `Promise.resolve` adopts thenables and wraps plain values. A synchronous
    // fetcher therefore works too, instead of being cast unchecked to a Promise.
    let resolved = JsFuture::from(js_sys::Promise::resolve(&returned)).await?;
    let array = Uint8Array::new(&resolved);
    let got = u64::from(array.length());
    if got != size {
        return Err(JsValue::from_str(&format!(
            "byte-range fetcher returned {got} bytes for offset {offset}, expected {size}"
        )));
    }
    Ok(array.to_vec())
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_gguf::types::{GgufTensorType, GgufValueType, GGUF_MAGIC};

    /// Appends a GGUF v3 string: a little-endian `u64` byte length, then the bytes.
    fn write_string_v3(buf: &mut Vec<u8>, s: &str) {
        buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
        buf.extend_from_slice(s.as_bytes());
    }

    /// Rounds `value` up to the next multiple of `alignment` (`0` leaves it unchanged).
    fn align_up(value: usize, alignment: usize) -> usize {
        if alignment == 0 {
            return value;
        }
        let rem = value % alignment;
        if rem == 0 {
            value
        } else {
            value + alignment - rem
        }
    }

    /// Builds a minimal GGUF v3 file with no tensors and no metadata.
    fn make_empty_gguf() -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(&GGUF_MAGIC.to_le_bytes()); // magic
        buf.extend_from_slice(&3u32.to_le_bytes()); // version
        buf.extend_from_slice(&0u64.to_le_bytes()); // tensor count
        buf.extend_from_slice(&0u64.to_le_bytes()); // metadata key/value count
        buf
    }

    /// Builds a GGUF v3 file with one string metadata entry and two tensors.
    ///
    /// - `blk.0.attn_q.weight`: F32, dims [4, 4], data offset 0.
    /// - `output.weight`: F16, dims [8], data offset 64.
    ///
    /// The tensor infos are followed by padding to a 32-byte boundary, then
    /// 256 bytes of `0xAB` data.
    fn make_two_tensor_gguf() -> Vec<u8> {
        let mut buf = Vec::new();
        // Header.
        buf.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes()); // version
        buf.extend_from_slice(&2u64.to_le_bytes()); // tensor count
        buf.extend_from_slice(&1u64.to_le_bytes()); // metadata key/value count
        // Metadata: general.architecture = "test_arch".
        write_string_v3(&mut buf, "general.architecture");
        buf.extend_from_slice(&(GgufValueType::String as u32).to_le_bytes());
        write_string_v3(&mut buf, "test_arch");
        // Tensor info: name, n_dims, dims..., type, relative data offset.
        write_string_v3(&mut buf, "blk.0.attn_q.weight");
        buf.extend_from_slice(&2u32.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());
        buf.extend_from_slice(&(GgufTensorType::F32 as u32).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());
        write_string_v3(&mut buf, "output.weight");
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&8u64.to_le_bytes());
        buf.extend_from_slice(&(GgufTensorType::F16 as u32).to_le_bytes());
        buf.extend_from_slice(&64u64.to_le_bytes());
        // Pad to the data-section alignment, then append the tensor data.
        let aligned = align_up(buf.len(), 32);
        buf.resize(aligned, 0u8);
        buf.resize(aligned + 256, 0xAB_u8);
        buf
    }

    #[test]
    fn lru_cache_evicts_oldest() {
        let mut cache = LruTensorCache::new(2);
        cache.put("a".into(), vec![1, 2, 3]);
        cache.put("b".into(), vec![4, 5, 6]);
        cache.put("c".into(), vec![7, 8, 9]);
        assert!(cache.get("a").is_none(), "entry 'a' should have been evicted");
        assert!(cache.get("b").is_some(), "entry 'b' should still be cached");
        assert!(cache.get("c").is_some(), "entry 'c' should be in the cache");
    }

    #[test]
    fn lru_cache_get_refreshes_order() {
        let mut cache = LruTensorCache::new(2);
        cache.put("a".into(), vec![1]);
        cache.put("b".into(), vec![2]);
        let _ = cache.get("a");
        cache.put("c".into(), vec![3]);
        assert!(
            cache.get("a").is_some(),
            "entry 'a' was recently used; must survive"
        );
        assert!(
            cache.get("b").is_none(),
            "entry 'b' is now LRU; must be evicted"
        );
        assert!(cache.get("c").is_some(), "entry 'c' must be present");
    }

    #[test]
    fn lru_cache_duplicate_put_no_corruption() {
        let mut cache = LruTensorCache::new(2);
        cache.put("a".into(), vec![1]);
        cache.put("b".into(), vec![2]);
        cache.put("a".into(), vec![99]);
        cache.put("c".into(), vec![3]);
        assert!(cache.get("b").is_none(), "stale 'b' should be evicted");
        assert!(cache.get("a").is_some(), "refreshed 'a' must survive");
        assert_eq!(
            cache.get("a").as_deref().map(|v| v.as_slice()),
            Some([99_u8].as_slice()),
            "re-inserted value must be the new bytes"
        );
    }

    #[test]
    fn lru_cache_zero_capacity_is_unbounded() {
        // Capacity 0 disables eviction; it does not reject or drop inserts.
        let mut cache = LruTensorCache::new(0);
        cache.put("a".into(), vec![1]);
        assert_eq!(cache.len(), 1);
        cache.put("b".into(), vec![2]);
        assert_eq!(cache.len(), 2);
        assert!(cache.get("a").is_some(), "nothing is evicted at capacity 0");
    }

    #[test]
    fn push_chunk_transitions_to_header_parsed_empty_gguf() {
        let header_bytes = make_empty_gguf();
        let mut inner = LoaderInner::new(8);
        inner.bytes_buffer.extend_from_slice(&header_bytes);
        let result = inner
            .try_parse_header()
            .expect("try_parse_header should succeed");
        assert!(result, "should return true when header is parsed");
        assert_eq!(inner.phase, LoadPhase::HeaderParsed);
    }

    #[test]
    fn push_chunk_partial_data_returns_false() {
        let header_bytes = make_empty_gguf();
        let mut inner = LoaderInner::new(8);
        inner.bytes_buffer.extend_from_slice(&header_bytes[..10]);
        let result = inner
            .try_parse_header()
            .expect("should not error on partial data");
        assert!(!result, "incomplete header should return false, not error");
        assert_eq!(inner.phase, LoadPhase::WaitingForHeader);
    }

    #[test]
    fn push_chunk_invalid_magic_returns_error() {
        let mut inner = LoaderInner::new(8);
        inner.bytes_buffer.extend_from_slice(&[0u8; 32]);
        let result = inner.try_parse_header();
        assert!(result.is_err(), "invalid magic must produce an error");
        let msg = result.expect_err("expected error string");
        assert!(
            msg.contains("GGUF parse error"),
            "error message should mention GGUF parse error, got: {msg}"
        );
    }

    #[test]
    fn tensor_index_populated_after_header() {
        let gguf_bytes = make_two_tensor_gguf();
        let mut inner = LoaderInner::new(8);
        inner.bytes_buffer.extend_from_slice(&gguf_bytes);
        let ready = inner.try_parse_header().expect("parse should succeed");
        assert!(ready, "header should be ready");
        assert_eq!(inner.phase, LoadPhase::HeaderParsed);
        assert!(
            inner.tensor_index.contains_key("blk.0.attn_q.weight"),
            "blk.0.attn_q.weight must be indexed"
        );
        assert!(
            inner.tensor_index.contains_key("output.weight"),
            "output.weight must be indexed"
        );
        let meta = inner
            .tensor_index
            .get("blk.0.attn_q.weight")
            .expect("meta must be present");
        assert_eq!(meta.shape, vec![4, 4], "F32 [4,4] shape must match");
        assert_eq!(meta.dtype, "F32", "dtype must be F32");
        assert_eq!(meta.size_bytes, 64, "F32 [4,4] data size must be 64 bytes");
    }

    #[test]
    fn tensor_file_offset_is_absolute() {
        let gguf_bytes = make_two_tensor_gguf();
        let parser = oxillama_gguf::StreamingGgufParser::new(&gguf_bytes)
            .expect("streaming parser should succeed");
        let data_section_offset = parser.tensor_infos().data_offset();
        let mut inner = LoaderInner::new(8);
        inner.bytes_buffer.extend_from_slice(&gguf_bytes);
        inner.try_parse_header().expect("header parse must succeed");
        let meta_q = inner
            .tensor_index
            .get("blk.0.attn_q.weight")
            .expect("tensor must be indexed");
        assert_eq!(
            meta_q.file_offset, data_section_offset,
            "file_offset for offset-0 tensor must equal data_section_offset"
        );
        let meta_out = inner
            .tensor_index
            .get("output.weight")
            .expect("output.weight must be indexed");
        assert_eq!(
            meta_out.file_offset,
            data_section_offset + 64,
            "file_offset for output.weight must be data_section_offset + 64"
        );
    }

    #[test]
    fn try_parse_header_is_idempotent() {
        let gguf_bytes = make_empty_gguf();
        let mut inner = LoaderInner::new(8);
        inner.bytes_buffer.extend_from_slice(&gguf_bytes);
        let first = inner.try_parse_header().expect("first call must succeed");
        assert!(first);
        let second = inner.try_parse_header().expect("second call must succeed");
        assert!(second);
        assert_eq!(inner.phase, LoadPhase::HeaderParsed);
    }
}