use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::backend::{Pipelines, WeightCache, WgpuCtx};
use crate::error::Result;
use crate::gguf::GgufReader;
use crate::model::config::Gemma4Config;
use crate::multimodal::{AudioConfig, GpuAudioForward, VisionConfig, VisionForward, decode_wav};
use crate::reference::Weights;
use crate::reference::forward_chained::Forward;
use crate::sampling::{Sampler, SamplingOptions};
use crate::template::gemma4_small;
use crate::tokenizer::BpeTokenizer;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
/// JS entry point (`computeSpike`) for the backend compute spike.
///
/// Hands `input` to `crate::backend::compute_spike` and returns its output,
/// turning any backend error into a `JsError` that carries the error's
/// display text.
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen(js_name = computeSpike)]
pub async fn compute_spike_js(input: Vec<f32>) -> std::result::Result<Vec<f32>, JsError> {
    match crate::backend::compute_spike(&input).await {
        Ok(output) => Ok(output),
        Err(err) => Err(JsError::new(&format!("{err}"))),
    }
}
/// A loaded Gemma 4 checkpoint: tokenizer, text forward pass, optional
/// vision/audio towers, and the token sampler.
///
/// On `wasm32` targets this type is exported to JavaScript via `wasm_bindgen`.
/// NOTE(review): fields are dropped in declaration order; keep that in mind
/// before reordering members that own GPU resources.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
pub struct Model {
// Tokenizer built from the GGUF reader at load time.
tokenizer: BpeTokenizer,
// Text-model forward pass; owns the KV position (`pos()`) that `reset` clears.
forward: Forward,
// Vision tower; `Some` only when vision was requested and the checkpoint
// contains `v.patch_embd.weight`.
vision: Option<VisionForward>,
// Audio tower; `Some` only when audio was requested and the checkpoint
// contains `a.conv1d.0.weight`.
audio: Option<GpuAudioForward>,
// Picks the next token from logits; tracks observed tokens as history.
sampler: Sampler,
}
impl Model {
    /// Builds a model from an opened GGUF reader with both optional towers
    /// requested and the context length set to
    /// `forward_chained::MAX_CONTEXT`.
    ///
    /// A tower is still only constructed when the checkpoint carries its
    /// probe tensor; see `from_reader_with_modes`.
    async fn from_reader(reader: GgufReader) -> Result<Self> {
        Self::from_reader_with_modes(
            reader,
            true,
            true,
            crate::reference::forward_chained::MAX_CONTEXT,
        )
        .await
    }

    /// Builds a model from an opened GGUF reader.
    ///
    /// * `with_vision` – when `true`, a `VisionForward` is built if the
    ///   checkpoint contains `v.patch_embd.weight`; when `false` it is skipped.
    /// * `with_audio` – when `true`, a `GpuAudioForward` is built if the
    ///   checkpoint contains `a.conv1d.0.weight`; when `false` it is skipped.
    /// * `max_context` – forwarded to `Forward::new_with_max_context`.
    ///
    /// The towers and the text forward pass are all built from the same
    /// `WgpuCtx`, `Pipelines` and `WeightCache`.
    ///
    /// # Errors
    /// Propagates failures from config/tokenizer parsing, GPU context
    /// creation, and tower or forward-pass construction.
    async fn from_reader_with_modes(
        reader: GgufReader,
        with_vision: bool,
        with_audio: bool,
        max_context: u32,
    ) -> Result<Self> {
        let cfg = Gemma4Config::from_gguf(&reader)?;
        let tokenizer = BpeTokenizer::from_gguf(&reader)?;
        // Read before `cfg` is moved into `Forward`; the tower configs need it too.
        let d_text = cfg.d_model;
        let r_arc = Arc::new(reader);
        let weights = Weights::new(Arc::clone(&r_arc));
        let ctx = WgpuCtx::new().await?;
        let pipes = Arc::new(Pipelines::new_with_features(
            &ctx.device,
            ctx.has_subgroups,
            ctx.has_f16,
        ));
        let wcache = Arc::new(WeightCache::new(
            Arc::clone(&r_arc),
            ctx.device.clone(),
            ctx.queue.clone(),
        ));
        // Towers are built first and receive clones; the text forward pass
        // below takes ownership of `ctx`, `pipes` and `wcache`.
        let vision = if with_vision && r_arc.tensor("v.patch_embd.weight").is_ok() {
            let vcfg = VisionConfig::from_gguf(&r_arc, d_text)?;
            Some(VisionForward::new(vcfg, ctx.clone(), pipes.clone(), wcache.clone()).await?)
        } else {
            None
        };
        let audio = if with_audio && r_arc.tensor("a.conv1d.0.weight").is_ok() {
            let acfg = AudioConfig::from_gguf(&r_arc, d_text)?;
            Some(GpuAudioForward::new(acfg, ctx.clone(), pipes.clone(), wcache.clone()).await?)
        } else {
            None
        };
        let forward =
            Forward::new_with_max_context(cfg, ctx, pipes, weights, wcache, max_context).await?;
        Ok(Self {
            tokenizer,
            forward,
            vision,
            audio,
            sampler: Sampler::new(SamplingOptions::default()),
        })
    }

    /// Returns `true` if a vision tower was loaded.
    pub fn has_vision_native(&self) -> bool {
        self.vision.is_some()
    }

    /// Runs the vision tower over an `h` x `w` image and returns its output
    /// embeddings as a flat `Vec<f32>`.
    ///
    /// `progress`, if given, is forwarded to `VisionForward::encode`.
    /// NOTE(review): the expected pixel layout/normalisation is defined by
    /// `VisionForward::encode` — confirm there.
    ///
    /// # Errors
    /// Returns `RullamaError::Inference` when no vision tower is loaded;
    /// otherwise propagates errors from the tower.
    pub async fn encode_image_native(
        &self,
        pixels: &[f32],
        h: usize,
        w: usize,
        progress: Option<&dyn Fn(u32, u32)>,
    ) -> Result<Vec<f32>> {
        let v = self.vision.as_ref().ok_or_else(|| {
            crate::error::RullamaError::Inference(
                "encode_image: this checkpoint has no vision tower".into(),
            )
        })?;
        v.encode(pixels, h, w, progress).await
    }

    /// Number of soft tokens an `h` x `w` image expands to:
    /// `(h / align) * (w / align)` with `align = patch_size * n_merge`.
    ///
    /// Returns `None` when:
    /// * no vision tower is loaded,
    /// * `align` is zero (malformed metadata) or overflows `usize`,
    /// * `h` or `w` is not a multiple of `align`,
    /// * the resulting token count overflows `usize`.
    pub fn image_soft_token_count_native(&self, h: usize, w: usize) -> Option<usize> {
        let v = self.vision.as_ref()?;
        let cfg = v.cfg();
        // Checked so a bogus config cannot overflow (usize is 32-bit on wasm32).
        let align = (cfg.patch_size as usize).checked_mul(cfg.n_merge as usize)?;
        // A zero factor would otherwise make `%` below panic (division by zero).
        if align == 0 || h % align != 0 || w % align != 0 {
            return None;
        }
        let pooled_h = h / align;
        let pooled_w = w / align;
        pooled_h.checked_mul(pooled_w)
    }

    /// Returns `true` if an audio tower was loaded.
    pub fn has_audio_native(&self) -> bool {
        self.audio.is_some()
    }

    /// Runs the audio tower over `pcm` samples and returns its output
    /// embeddings as a flat `Vec<f32>`.
    ///
    /// NOTE(review): the expected sample rate/channel layout is defined by
    /// `GpuAudioForward::encode` — confirm there.
    ///
    /// # Errors
    /// Returns `RullamaError::Inference` when no audio tower is loaded;
    /// otherwise propagates errors from the tower.
    pub async fn encode_audio_native(&self, pcm: &[f32]) -> Result<Vec<f32>> {
        let a = self.audio.as_ref().ok_or_else(|| {
            crate::error::RullamaError::Inference(
                "encode_audio: this checkpoint has no audio tower".into(),
            )
        })?;
        a.encode(pcm).await
    }

    /// Decodes WAV file bytes into `f32` samples via `multimodal::decode_wav`.
    pub fn decode_wav_native(bytes: &[u8]) -> Result<Vec<f32>> {
        decode_wav(bytes)
    }

    /// Token ids of the `<|audio>` / `<audio|>` sentinels, or `None` if either
    /// is missing from the vocabulary.
    pub fn audio_sentinel_ids_native(&self) -> Option<(u32, u32)> {
        let begin = self.tokenizer.str_to_id("<|audio>")?;
        let end = self.tokenizer.str_to_id("<audio|>")?;
        Some((begin, end))
    }

    /// Token ids of the `<|image>` / `<image|>` sentinels, or `None` if either
    /// is missing from the vocabulary.
    pub fn image_sentinel_ids_native(&self) -> Option<(u32, u32)> {
        let begin = self.tokenizer.str_to_id("<|image>")?;
        let end = self.tokenizer.str_to_id("<image|>")?;
        Some((begin, end))
    }

    /// Loads a model from a complete GGUF file held in memory.
    pub async fn load_native(bytes: Vec<u8>) -> Result<Self> {
        let reader = GgufReader::new(bytes)?;
        Self::from_reader(reader).await
    }

    /// Loads a model whose tensors are pulled on demand through `fetcher`,
    /// requesting both optional towers.
    pub async fn load_streaming(fetcher: Arc<dyn crate::gguf::TensorFetcher>) -> Result<Self> {
        let reader = GgufReader::new_streaming(fetcher).await?;
        Self::from_reader(reader).await
    }

    /// Like `load_streaming`, but never builds the vision or audio tower and
    /// uses `max_context` for the text forward pass.
    pub async fn load_streaming_text_only(
        fetcher: Arc<dyn crate::gguf::TensorFetcher>,
        max_context: u32,
    ) -> Result<Self> {
        let reader = GgufReader::new_streaming(fetcher).await?;
        Self::from_reader_with_modes(reader, false, false, max_context).await
    }

    /// Tokenizes `text` with the model's BPE tokenizer.
    pub fn encode_tokens(&self, text: &str) -> Vec<u32> {
        self.tokenizer.encode(text)
    }

    /// String form of token `id`, or `None` if the id is unknown.
    pub fn token_str_native(&self, id: u32) -> Option<String> {
        self.tokenizer.id_to_str(id).map(|s| s.to_string())
    }

    /// Vocabulary size from the text model config.
    pub fn vocab_size_native(&self) -> u32 {
        self.forward.cfg().vocab_size
    }

    /// Current position of the text forward pass.
    pub fn position_native(&self) -> u32 {
        self.forward.pos()
    }

    /// Whether `id` is one of the config's end-of-sequence token ids.
    pub fn is_eos_native(&self, id: u32) -> bool {
        self.forward.cfg().eos_ids.iter().any(|&e| e == id)
    }

    /// Mutable access to the underlying text forward pass.
    pub fn forward_mut(&mut self) -> &mut Forward {
        &mut self.forward
    }

    /// Shared access to the underlying text forward pass.
    pub fn forward(&self) -> &Forward {
        &self.forward
    }

    /// Resets the forward pass and clears the sampler's token history.
    pub fn reset_native(&mut self) {
        self.forward.reset();
        self.sampler.clear_history();
    }

    /// Replaces the sampler's options.
    pub fn set_sampling_native(&mut self, opts: SamplingOptions) {
        self.sampler.set_options(opts);
    }

    /// Records `token_id` in the sampler history, runs one forward step on it,
    /// and samples the next token id from the resulting logits.
    pub async fn step_native(&mut self, token_id: u32) -> Result<u32> {
        self.sampler.observe(token_id);
        let logits = self.forward.step(token_id).await?;
        Ok(self.sampler.sample(&logits))
    }

    /// Runs one forward step on a precomputed input embedding (e.g. a soft
    /// token from a tower) and samples the next token id.
    ///
    /// There is no token id here, so nothing is added to the sampler history.
    pub async fn step_with_embedding_native(&mut self, embedding: &[f32]) -> Result<u32> {
        let logits = self.forward.step_with_embedding(embedding).await?;
        Ok(self.sampler.sample(&logits))
    }

    /// Renders `messages` with the Gemma 4 small chat template, ready for
    /// completion; `with_bos` is passed through to the template.
    pub fn render_chat_native(&self, messages: &[ChatMessage], with_bos: bool) -> String {
        gemma4_small::render_for_completion(messages, with_bos)
    }
}
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl Model {
    /// Loads a model from a complete GGUF file passed in as bytes.
    #[wasm_bindgen(js_name = load)]
    pub async fn load_js(bytes: Vec<u8>) -> std::result::Result<Model, JsError> {
        Self::load_native(bytes).await.map_err(to_js_error)
    }

    /// Loads a model by streaming tensors from `url` with HTTP range requests.
    #[wasm_bindgen(js_name = loadFromUrl)]
    pub async fn load_from_url_js(url: String) -> std::result::Result<Model, JsError> {
        let fetcher = crate::gguf::HttpRangeFetcher::new(url)
            .await
            .map_err(to_js_error)?;
        let arc: Arc<dyn crate::gguf::TensorFetcher> = Arc::new(fetcher);
        Self::load_streaming(arc).await.map_err(to_js_error)
    }

    /// Loads a model by streaming tensors through a JS read callback over a
    /// file of `total_bytes` bytes (OPFS-backed).
    ///
    /// Rejects `total_bytes` values that are negative, non-finite or
    /// fractional.
    #[wasm_bindgen(js_name = loadFromOpfs)]
    pub async fn load_from_opfs_js(
        read_fn: js_sys::Function,
        total_bytes: f64,
    ) -> std::result::Result<Model, JsError> {
        let total = opfs_total_bytes(total_bytes, "loadFromOpfs")?;
        let fetcher = crate::gguf::OpfsFetcher::new(read_fn, total);
        let arc: Arc<dyn crate::gguf::TensorFetcher> = Arc::new(fetcher);
        Self::load_streaming(arc).await.map_err(to_js_error)
    }

    /// Like `loadFromOpfs`, but skips the vision and audio towers and caps the
    /// text context at `max_context`. A `max_context` of `0` selects 512.
    #[wasm_bindgen(js_name = loadFromOpfsTextOnly)]
    pub async fn load_from_opfs_text_only_js(
        read_fn: js_sys::Function,
        total_bytes: f64,
        max_context: u32,
    ) -> std::result::Result<Model, JsError> {
        // Used when JS passes 0 ("not specified").
        const DEFAULT_TEXT_ONLY_MAX_CONTEXT: u32 = 512;
        let total = opfs_total_bytes(total_bytes, "loadFromOpfsTextOnly")?;
        let max_ctx = if max_context == 0 {
            DEFAULT_TEXT_ONLY_MAX_CONTEXT
        } else {
            max_context
        };
        let fetcher = crate::gguf::OpfsFetcher::new(read_fn, total);
        let arc: Arc<dyn crate::gguf::TensorFetcher> = Arc::new(fetcher);
        Self::load_streaming_text_only(arc, max_ctx)
            .await
            .map_err(to_js_error)
    }

    /// Tokenizes `text`.
    #[wasm_bindgen(js_name = encode)]
    pub fn encode_js(&self, text: &str) -> Vec<u32> {
        self.encode_tokens(text)
    }

    /// String form of token `id`, or `undefined` if unknown.
    #[wasm_bindgen(js_name = tokenStr)]
    pub fn token_str_js(&self, id: u32) -> Option<String> {
        self.token_str_native(id)
    }

    /// Vocabulary size of the text model.
    #[wasm_bindgen(js_name = vocabSize, getter)]
    pub fn vocab_size_js(&self) -> u32 {
        self.vocab_size_native()
    }

    /// Current position of the text forward pass.
    #[wasm_bindgen(js_name = position, getter)]
    pub fn position_js(&self) -> u32 {
        self.position_native()
    }

    /// Whether `id` is an end-of-sequence token.
    #[wasm_bindgen(js_name = isEos)]
    pub fn is_eos_js(&self, id: u32) -> bool {
        self.is_eos_native(id)
    }

    /// Resets the forward pass and the sampler history.
    #[wasm_bindgen(js_name = reset)]
    pub fn reset_js(&mut self) {
        self.reset_native()
    }

    /// Feeds one token and returns the sampled next token id.
    #[wasm_bindgen(js_name = step)]
    pub async fn step_js(&mut self, token_id: u32) -> std::result::Result<u32, JsError> {
        self.step_native(token_id).await.map_err(to_js_error)
    }

    /// Feeds one precomputed embedding and returns the sampled next token id.
    #[wasm_bindgen(js_name = stepWithEmbedding)]
    pub async fn step_with_embedding_js(
        &mut self,
        embedding: Vec<f32>,
    ) -> std::result::Result<u32, JsError> {
        self.step_with_embedding_native(&embedding)
            .await
            .map_err(to_js_error)
    }

    /// Replaces the sampler options from a JS object deserialized into
    /// `SamplingOptions`.
    #[wasm_bindgen(js_name = setSampling)]
    pub fn set_sampling_js(&mut self, opts_json: JsValue) -> std::result::Result<(), JsError> {
        let opts: SamplingOptions = serde_wasm_bindgen::from_value(opts_json)
            .map_err(|e| JsError::new(&format!("invalid sampling options: {e}")))?;
        self.set_sampling_native(opts);
        Ok(())
    }

    /// Whether a vision tower was loaded.
    #[wasm_bindgen(js_name = hasVision, getter)]
    pub fn has_vision_js(&self) -> bool {
        self.has_vision_native()
    }

    /// Encodes an `h` x `w` image with the vision tower.
    ///
    /// `progress_cb`, if given, is called as `progress_cb(layer, total)`. Its
    /// return value and any exception it throws are discarded.
    #[wasm_bindgen(js_name = encodeImage)]
    pub async fn encode_image_js(
        &self,
        pixels: Vec<f32>,
        h: u32,
        w: u32,
        progress_cb: Option<js_sys::Function>,
    ) -> std::result::Result<Vec<f32>, JsError> {
        let cb: Option<Box<dyn Fn(u32, u32)>> = progress_cb.map(|f| {
            Box::new(move |layer: u32, total: u32| {
                let _ = f.call2(&JsValue::NULL, &JsValue::from(layer), &JsValue::from(total));
            }) as Box<dyn Fn(u32, u32)>
        });
        self.encode_image_native(&pixels, h as usize, w as usize, cb.as_deref())
            .await
            .map_err(to_js_error)
    }

    /// Soft-token count for an `h` x `w` image, or `undefined` if there is no
    /// vision tower, the size is not aligned, or the count exceeds `u32`.
    #[wasm_bindgen(js_name = imageSoftTokenCount)]
    pub fn image_soft_token_count_js(&self, h: u32, w: u32) -> Option<u32> {
        self.image_soft_token_count_native(h as usize, w as usize)
            .and_then(|n| u32::try_from(n).ok())
    }

    /// `[begin, end]` image sentinel token ids, or `undefined` if missing.
    #[wasm_bindgen(js_name = imageSentinelIds)]
    pub fn image_sentinel_ids_js(&self) -> Option<Vec<u32>> {
        self.image_sentinel_ids_native()
            .map(|(begin, end)| vec![begin, end])
    }

    /// Whether an audio tower was loaded.
    #[wasm_bindgen(js_name = hasAudio, getter)]
    pub fn has_audio_js(&self) -> bool {
        self.has_audio_native()
    }

    /// Encodes PCM samples with the audio tower.
    #[wasm_bindgen(js_name = encodeAudio)]
    pub async fn encode_audio_js(&self, pcm: Vec<f32>) -> std::result::Result<Vec<f32>, JsError> {
        self.encode_audio_native(&pcm).await.map_err(to_js_error)
    }

    /// Decodes WAV file bytes into `f32` samples.
    #[wasm_bindgen(js_name = decodeWav)]
    pub fn decode_wav_js(bytes: Vec<u8>) -> std::result::Result<Vec<f32>, JsError> {
        Self::decode_wav_native(&bytes).map_err(to_js_error)
    }

    /// `[begin, end]` audio sentinel token ids, or `undefined` if missing.
    #[wasm_bindgen(js_name = audioSentinelIds)]
    pub fn audio_sentinel_ids_js(&self) -> Option<Vec<u32>> {
        self.audio_sentinel_ids_native()
            .map(|(begin, end)| vec![begin, end])
    }

    /// Renders a JS array of `{ role, content }` messages with the chat
    /// template.
    #[wasm_bindgen(js_name = renderChat)]
    pub fn render_chat_js(
        &self,
        messages_json: JsValue,
        with_bos: bool,
    ) -> std::result::Result<String, JsError> {
        let msgs: Vec<ChatMessage> = serde_wasm_bindgen::from_value(messages_json)
            .map_err(|e| JsError::new(&format!("invalid messages: {e}")))?;
        Ok(self.render_chat_native(&msgs, with_bos))
    }
}

/// Wraps any displayable error in a `JsError` carrying its message.
#[cfg(target_arch = "wasm32")]
fn to_js_error(e: impl std::fmt::Display) -> JsError {
    JsError::new(&e.to_string())
}

/// Validates a JS-supplied byte count and converts it to `u64`.
///
/// `caller` prefixes the error message (the JS method name). Fractional values
/// are rejected rather than silently truncated by the cast.
#[cfg(target_arch = "wasm32")]
fn opfs_total_bytes(total_bytes: f64, caller: &str) -> std::result::Result<u64, JsError> {
    if !total_bytes.is_finite() || total_bytes < 0.0 {
        return Err(JsError::new(&format!(
            "{caller}: total_bytes must be a non-negative finite number"
        )));
    }
    if total_bytes.fract() != 0.0 {
        return Err(JsError::new(&format!(
            "{caller}: total_bytes must be a whole number of bytes"
        )));
    }
    Ok(total_bytes as u64)
}
/// One turn of a chat transcript, rendered by `gemma4_small::render_for_completion`.
///
/// On wasm, `renderChat` deserializes these from JS objects shaped
/// `{ role, content }`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Speaker of this turn.
pub role: ChatRole,
/// Text content of this turn.
pub content: String,
}
/// Speaker of a `ChatMessage`.
///
/// Serialized in lowercase: `"system"`, `"user"`, `"model"`. When
/// deserializing, `"assistant"` is also accepted for the model turn, so
/// OpenAI-style message lists can be passed through unchanged. The model turn
/// always serializes back as `"model"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// System / instruction turn.
    System,
    /// Turn authored by the user.
    User,
    /// Turn authored by the model.
    #[serde(alias = "assistant")]
    Model,
}
/// A generation request: chat transcript plus sampling and stopping settings.
///
/// Every field except `messages` may be omitted when deserializing. The
/// fallbacks are: `max_tokens` 256, `temperature` 0.7, `top_p` 0.95, `top_k` 40,
/// `repetition_penalty` 1.0, and `stop` empty.
/// NOTE(review): nothing in this file consumes these fields. The per-field
/// descriptions below follow the conventional meaning of the names — confirm
/// against the generation loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateOptions {
/// Chat transcript to complete (required).
pub messages: Vec<ChatMessage>,
/// Presumably the maximum number of tokens to generate.
#[serde(default = "default_max_tokens")]
pub max_tokens: u32,
/// Presumably the sampling temperature.
#[serde(default = "default_temperature")]
pub temperature: f32,
/// Presumably the nucleus (top-p) sampling threshold.
#[serde(default = "default_top_p")]
pub top_p: f32,
/// Presumably the top-k sampling cutoff.
#[serde(default = "default_top_k")]
pub top_k: u32,
/// Presumably the repetition penalty (1.0 looks like "no penalty").
#[serde(default = "default_repetition_penalty")]
pub repetition_penalty: f32,
/// Presumably strings that end generation when produced.
#[serde(default)]
pub stop: Vec<String>,
}
/// Serde fallback for `GenerateOptions::max_tokens`.
fn default_max_tokens() -> u32 {
    256
}

/// Serde fallback for `GenerateOptions::temperature`.
fn default_temperature() -> f32 {
    0.7
}

/// Serde fallback for `GenerateOptions::top_p`.
fn default_top_p() -> f32 {
    0.95
}

/// Serde fallback for `GenerateOptions::top_k`.
fn default_top_k() -> u32 {
    40
}

/// Serde fallback for `GenerateOptions::repetition_penalty`.
fn default_repetition_penalty() -> f32 {
    1.0
}