use std::sync::Arc;
use async_trait::async_trait;
use crate::error::{Result, RullamaError};
/// Random-access reader over the raw bytes of a model (GGUF) file.
///
/// The futures returned by `fetch` are not required to be `Send`
/// (`#[async_trait(?Send)]`), so implementations may hold non-`Send` values
/// such as JS handles across `.await` points, as the wasm32 fetchers here do.
#[async_trait(?Send)]
pub trait TensorFetcher {
/// Total length of the underlying byte source, in bytes.
fn total_len(&self) -> u64;
/// Reads `len` bytes starting at byte `offset`.
///
/// The implementations in this file return exactly `len` bytes or an error
/// (range overflow, out-of-range request, or transport failure). None of them
/// returns a short read.
async fn fetch(&self, offset: u64, len: u64) -> Result<Vec<u8>>;
}
/// A [`TensorFetcher`] backed by a byte buffer that is already resident in memory.
///
/// The bytes live behind an `Arc<[u8]>`. Building several fetchers over the
/// same buffer via [`InMemoryFetcher::from_arc`] shares one allocation.
pub struct InMemoryFetcher {
    bytes: Arc<[u8]>,
}

impl InMemoryFetcher {
    /// Takes ownership of `bytes` and moves them into a shared `Arc<[u8]>`.
    pub fn new(bytes: Vec<u8>) -> Self {
        let shared: Arc<[u8]> = Arc::from(bytes);
        Self::from_arc(shared)
    }

    /// Wraps an already-shared buffer without copying it.
    pub fn from_arc(bytes: Arc<[u8]>) -> Self {
        InMemoryFetcher { bytes }
    }

    /// Borrows the entire backing buffer.
    pub fn as_slice(&self) -> &[u8] {
        &self.bytes[..]
    }
}
#[async_trait(?Send)]
impl TensorFetcher for InMemoryFetcher {
    /// Length of the backing buffer in bytes.
    fn total_len(&self) -> u64 {
        self.bytes.len() as u64
    }

    /// Copies `len` bytes starting at `offset` out of the backing buffer.
    ///
    /// A zero-length request returns an empty vector without a bounds check.
    /// This matches the other fetchers in this file.
    ///
    /// # Errors
    ///
    /// Returns [`RullamaError::Gguf`] if `offset + len` overflows `u64`, or if
    /// the requested range extends past the end of the buffer.
    async fn fetch(&self, offset: u64, len: u64) -> Result<Vec<u8>> {
        if len == 0 {
            return Ok(Vec::new());
        }
        // Keep all arithmetic in u64. Casting `offset`/`len` to usize first
        // would silently drop the high bits on 32-bit targets such as wasm32,
        // reading from the wrong place instead of reporting an error.
        let end = offset.checked_add(len).ok_or_else(|| {
            RullamaError::Gguf(format!("InMemoryFetcher: range overflow {offset}+{len}"))
        })?;
        let buf_len = self.bytes.len();
        if end > buf_len as u64 {
            return Err(RullamaError::Gguf(format!(
                "InMemoryFetcher: range {offset}..{end} extends past buffer end ({buf_len})"
            )));
        }
        // offset < end <= buf_len, so both bounds fit in usize without truncation.
        Ok(self.bytes[offset as usize..end as usize].to_vec())
    }
}
/// A [`TensorFetcher`] that reads a remote file with HTTP `Range` requests.
///
/// Only compiled for `wasm32`. Requests go through the `fetch()` of the current
/// `Window` or `WorkerGlobalScope`.
#[cfg(target_arch = "wasm32")]
pub struct HttpRangeFetcher {
    /// URL of the remote file; every ranged request targets it.
    url: String,
    /// File length in bytes, learned from the probe in [`HttpRangeFetcher::new`].
    total: u64,
}

#[cfg(target_arch = "wasm32")]
impl HttpRangeFetcher {
    /// Probes `url` with a one-byte `Range: bytes=0-0` request to learn the file size.
    ///
    /// The size is read from the text after the last `/` in `Content-Range`. If
    /// that header is absent, it is read from `X-Total-Size` instead.
    ///
    /// # Errors
    ///
    /// Returns [`RullamaError::Gguf`] in any of these cases:
    /// - the request cannot be built or sent;
    /// - the status is neither ok nor 206;
    /// - a present `Content-Range` or `X-Total-Size` header does not parse as a
    ///   `u64` total;
    /// - neither header is available.
    pub async fn new(url: String) -> Result<Self> {
        use wasm_bindgen::JsCast;
        use wasm_bindgen_futures::JsFuture;

        let init = web_sys::RequestInit::new();
        init.set_method("GET");
        let probe_headers = web_sys::Headers::new()
            .map_err(|e| RullamaError::Gguf(format!("Headers::new: {e:?}")))?;
        probe_headers
            .set("Range", "bytes=0-0")
            .map_err(|e| RullamaError::Gguf(format!("set Range: {e:?}")))?;
        init.set_headers(&probe_headers);
        let probe = web_sys::Request::new_with_str_and_init(&url, &init)
            .map_err(|e| RullamaError::Gguf(format!("Request::new: {e:?}")))?;

        let response = JsFuture::from(global_fetch(&probe)?)
            .await
            .map_err(|e| RullamaError::Gguf(format!("fetch failed: {e:?}")))?
            .dyn_into::<web_sys::Response>()
            .map_err(|e| RullamaError::Gguf(format!("response cast: {e:?}")))?;

        let status = response.status();
        if !(response.ok() || status == 206) {
            return Err(RullamaError::Gguf(format!("HTTP {status} from {url}")));
        }

        // Prefer the standard `Content-Range: bytes 0-0/<total>`. Fall back to
        // the custom `X-Total-Size` header only when Content-Range is missing.
        let headers = response.headers();
        let total = match headers.get("Content-Range").ok().flatten() {
            Some(cr) => cr
                .rsplit('/')
                .next()
                .and_then(|tail| tail.parse::<u64>().ok())
                .ok_or_else(|| RullamaError::Gguf(format!("bad Content-Range: {cr}")))?,
            None => match headers.get("X-Total-Size").ok().flatten() {
                Some(xs) => xs
                    .parse::<u64>()
                    .map_err(|e| RullamaError::Gguf(format!("bad X-Total-Size: {e}")))?,
                None => {
                    return Err(RullamaError::Gguf(
                        "server returned no Content-Range or X-Total-Size; cannot determine GGUF length"
                            .into(),
                    ));
                }
            },
        };

        Ok(Self { url, total })
    }
}
/// Starts a `fetch()` of `request` from whichever global scope is present.
///
/// Supports a `Window` global and a `WorkerGlobalScope` global. Any other
/// global is reported as an error.
#[cfg(target_arch = "wasm32")]
fn global_fetch(request: &web_sys::Request) -> Result<js_sys::Promise> {
    use wasm_bindgen::JsCast;

    let scope = js_sys::global();
    if let Some(window) = scope.dyn_ref::<web_sys::Window>() {
        Ok(window.fetch_with_request(request))
    } else if let Some(worker) = scope.dyn_ref::<web_sys::WorkerGlobalScope>() {
        Ok(worker.fetch_with_request(request))
    } else {
        Err(RullamaError::Gguf(
            "no Window or WorkerGlobalScope for fetch()".into(),
        ))
    }
}
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
impl TensorFetcher for HttpRangeFetcher {
    /// File length in bytes, as learned by the probe in [`HttpRangeFetcher::new`].
    fn total_len(&self) -> u64 {
        self.total
    }

    /// Downloads bytes `offset..offset + len` with a single HTTP `Range` request.
    ///
    /// A zero-length request returns an empty vector without touching the network.
    ///
    /// # Errors
    ///
    /// Returns [`RullamaError::Gguf`] in any of these cases:
    /// - the range overflows or extends past `total_len()`;
    /// - the request cannot be built or sent;
    /// - the server answers with a non-2xx status;
    /// - the server ignores the `Range` header (answers with anything other than
    ///   206) for a request that does not cover the whole file;
    /// - the body length differs from `len`.
    async fn fetch(&self, offset: u64, len: u64) -> Result<Vec<u8>> {
        use wasm_bindgen::JsCast;
        use wasm_bindgen_futures::JsFuture;

        if len == 0 {
            return Ok(Vec::new());
        }
        // HTTP byte ranges are inclusive, so the last requested byte is offset + len - 1.
        let end = offset.checked_add(len - 1).ok_or_else(|| {
            RullamaError::Gguf(format!("HttpRangeFetcher: range overflow {offset}+{len}"))
        })?;
        if end >= self.total {
            return Err(RullamaError::Gguf(format!(
                "HttpRangeFetcher: range {offset}..={end} extends past file end ({})",
                self.total
            )));
        }

        let req_init = web_sys::RequestInit::new();
        req_init.set_method("GET");
        let headers = web_sys::Headers::new()
            .map_err(|e| RullamaError::Gguf(format!("Headers::new: {e:?}")))?;
        headers
            .set("Range", &format!("bytes={offset}-{end}"))
            .map_err(|e| RullamaError::Gguf(format!("set Range: {e:?}")))?;
        req_init.set_headers(&headers);
        let request = web_sys::Request::new_with_str_and_init(&self.url, &req_init)
            .map_err(|e| RullamaError::Gguf(format!("Request::new: {e:?}")))?;
        let resp_value = JsFuture::from(global_fetch(&request)?)
            .await
            .map_err(|e| RullamaError::Gguf(format!("fetch failed: {e:?}")))?;
        let resp: web_sys::Response = resp_value
            .dyn_into()
            .map_err(|e| RullamaError::Gguf(format!("response cast: {e:?}")))?;

        let status = resp.status();
        // `Response.ok` is true for every 2xx status, 206 Partial Content included.
        if !resp.ok() {
            return Err(RullamaError::Gguf(format!(
                "HTTP {status} fetching range {offset}-{end}"
            )));
        }
        // A server that ignores `Range` replies 200 with the *entire* file. For a
        // multi-gigabyte GGUF, buffering that body only to reject it on the length
        // check below can exhaust memory. Bail out before reading the body,
        // unless the request genuinely covers the whole file.
        let covers_whole_file = offset == 0 && len == self.total;
        if status != 206 && !covers_whole_file {
            return Err(RullamaError::Gguf(format!(
                "HttpRangeFetcher: server ignored Range header (HTTP {status}) for bytes {offset}-{end}; refusing to download the full body"
            )));
        }

        let buf_promise = resp
            .array_buffer()
            .map_err(|e| RullamaError::Gguf(format!("array_buffer: {e:?}")))?;
        let array_buffer = JsFuture::from(buf_promise)
            .await
            .map_err(|e| RullamaError::Gguf(format!("await array_buffer: {e:?}")))?;
        let bytes = js_sys::Uint8Array::new(&array_buffer).to_vec();
        if bytes.len() as u64 != len {
            return Err(RullamaError::Gguf(format!(
                "HttpRangeFetcher: server returned {} bytes, expected {len}",
                bytes.len()
            )));
        }
        Ok(bytes)
    }
}
/// A [`TensorFetcher`] that delegates every read to a JavaScript callback.
///
/// The name suggests the callback reads from the Origin Private File System.
/// This type itself only calls the supplied function and checks its result.
#[cfg(target_arch = "wasm32")]
pub struct OpfsFetcher {
    /// JS function invoked as `read_fn(offset, len)` with `this` set to `null`.
    read_fn: js_sys::Function,
    /// Total file length in bytes, as supplied to [`OpfsFetcher::new`].
    total: u64,
}

#[cfg(target_arch = "wasm32")]
impl OpfsFetcher {
    /// Pairs a JS read callback with the length of the file it serves.
    ///
    /// `total` is trusted as given. It is reported by `total_len` and used to
    /// bounds-check fetches.
    pub fn new(read_fn: js_sys::Function, total: u64) -> Self {
        OpfsFetcher { read_fn, total }
    }
}
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
impl TensorFetcher for OpfsFetcher {
    /// Total file length in bytes, as supplied to [`OpfsFetcher::new`].
    fn total_len(&self) -> u64 {
        self.total
    }

    /// Reads `len` bytes at `offset` by calling `read_fn(offset, len)`.
    ///
    /// `offset` and `len` are passed to JS as numbers (f64). The callback may
    /// return the data directly or as a `Promise`. The resolved value must be
    /// an object that `new Uint8Array(value)` can copy from, such as a
    /// `Uint8Array` or an `ArrayBuffer`.
    ///
    /// A zero-length request returns an empty vector without calling `read_fn`.
    ///
    /// # Errors
    ///
    /// Returns [`RullamaError::Gguf`] in any of these cases:
    /// - the range overflows or extends past `total_len()`;
    /// - `read_fn` throws, or its promise rejects;
    /// - it yields a non-object value;
    /// - it yields a number of bytes other than `len`.
    async fn fetch(&self, offset: u64, len: u64) -> Result<Vec<u8>> {
        use wasm_bindgen::{JsCast, JsValue};
        use wasm_bindgen_futures::JsFuture;

        if len == 0 {
            return Ok(Vec::new());
        }
        let end = offset.checked_add(len).ok_or_else(|| {
            RullamaError::Gguf(format!("OpfsFetcher: range overflow {offset}+{len}"))
        })?;
        if end > self.total {
            return Err(RullamaError::Gguf(format!(
                "OpfsFetcher: range {offset}..{end} extends past file end ({})",
                self.total
            )));
        }

        let result = self
            .read_fn
            .call2(
                &JsValue::NULL,
                &JsValue::from_f64(offset as f64),
                &JsValue::from_f64(len as f64),
            )
            .map_err(|e| RullamaError::Gguf(format!("OPFS read_fn call failed: {e:?}")))?;

        // The callback may answer synchronously or with a Promise. `dyn_into`
        // hands the original value back on failure, so no clone is needed.
        let value = match result.dyn_into::<js_sys::Promise>() {
            Ok(promise) => JsFuture::from(promise)
                .await
                .map_err(|e| RullamaError::Gguf(format!("OPFS read_fn promise rejected: {e:?}")))?,
            Err(plain) => plain,
        };

        // `new Uint8Array(x)` treats a non-object `x` as a *length*: a number n
        // yields n zero bytes. A `read_fn` that mistakenly returns a byte count
        // (e.g. the result of a `read()` call) would therefore produce zeros
        // that pass the length check below. Reject such values explicitly.
        if !value.is_object() {
            return Err(RullamaError::Gguf(format!(
                "OpfsFetcher: read_fn returned {value:?}, expected a Uint8Array or ArrayBuffer"
            )));
        }

        let bytes = js_sys::Uint8Array::new(&value).to_vec();
        if bytes.len() as u64 != len {
            return Err(RullamaError::Gguf(format!(
                "OpfsFetcher: read_fn returned {} bytes, expected {len}",
                bytes.len()
            )));
        }
        Ok(bytes)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a fetcher future to completion on the current thread.
    fn block<F: core::future::Future>(fut: F) -> F::Output {
        pollster::block_on(fut)
    }

    #[test]
    fn in_memory_fetcher_returns_correct_slice() {
        let fetcher = InMemoryFetcher::new((0u8..=255).collect());
        assert_eq!(fetcher.total_len(), 256);
        let got = block(fetcher.fetch(10, 8)).unwrap();
        let want: Vec<u8> = (10..18).collect();
        assert_eq!(got, want);
    }

    #[test]
    fn in_memory_fetcher_rejects_out_of_range() {
        let fetcher = InMemoryFetcher::new(vec![0u8; 16]);
        // One byte longer than the buffer.
        assert!(block(fetcher.fetch(0, 17)).is_err());
        // Starts beyond the end of the buffer.
        assert!(block(fetcher.fetch(20, 1)).is_err());
    }

    #[test]
    fn in_memory_fetcher_zero_length() {
        let fetcher = InMemoryFetcher::new(vec![1, 2, 3, 4]);
        let got = block(fetcher.fetch(2, 0)).unwrap();
        assert!(got.is_empty());
    }
}