use std::sync::Arc;
use crate::backend::{Pipelines, WeightCache, WgpuCtx};
use crate::error::Result;
use crate::gguf::{GgufReader, TensorFetcher};
use crate::reference::embed::EmbedModel;
use crate::tokenizer::SpmTokenizer;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
/// Text-embedding model assembled from a GGUF file: the model itself, its
/// tokenizer (loaded from GGUF metadata), and the GPU context, pipelines and
/// weight cache used to run inference.
///
/// On `wasm32` this type is exported to JavaScript through `wasm_bindgen`.
///
/// NOTE(review): struct fields are dropped in declaration order; confirm
/// whether the GPU resources below (`ctx`, `pipes`, `wcache`) rely on that
/// order before reordering fields.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
pub struct EmbeddingModel {
/// Model built by `EmbedModel::new`; `cfg.embed_dim` is the native output width.
model: EmbedModel,
/// Tokenizer built by `SpmTokenizer::from_gguf` from the same GGUF reader.
tok: SpmTokenizer,
/// GPU context; its device, queue and bind cache are shared with `pipes` and `wcache`.
ctx: WgpuCtx,
/// Compute pipelines created from `ctx.device`, held behind an `Arc`.
pipes: Arc<Pipelines>,
/// Weight cache that owns the shared GGUF reader and uses the context's device/queue.
wcache: WeightCache,
/// BOS token id (`tokenizer.ggml.bos_token_id`, default 2).
bos: u32,
/// EOS token id (`tokenizer.ggml.eos_token_id`, default 1).
eos: u32,
/// Whether to prepend `bos` to every input (`tokenizer.ggml.add_bos_token`, default true).
add_bos: bool,
/// Whether to append `eos` to every input (`tokenizer.ggml.add_eos_token`, default true).
add_eos: bool,
}
impl EmbeddingModel {
    /// Builds the tokenizer, model and GPU state from a parsed GGUF reader.
    ///
    /// Special-token ids and the BOS/EOS flags come from the
    /// `tokenizer.ggml.*` metadata keys. The fallbacks (`bos = 2`, `eos = 1`,
    /// `add_bos = add_eos = true`) apply when the lookup or type conversion
    /// fails.
    ///
    /// # Errors
    /// Propagates failures from `SpmTokenizer::from_gguf`, `EmbedModel::new`
    /// and `WgpuCtx::new`, in that order.
    async fn from_reader(reader: GgufReader) -> Result<Self> {
        let reader = Arc::new(reader);

        // Tokenizer and its special-token configuration.
        let tok = SpmTokenizer::from_gguf(&reader)?;
        let bos = meta_u32(&reader, "tokenizer.ggml.bos_token_id", 2);
        let eos = meta_u32(&reader, "tokenizer.ggml.eos_token_id", 1);
        let add_bos = meta_bool(&reader, "tokenizer.ggml.add_bos_token", true);
        let add_eos = meta_bool(&reader, "tokenizer.ggml.add_eos_token", true);

        // The model keeps one handle to the reader; the weight cache takes the other.
        let model = EmbedModel::new(Arc::clone(&reader))?;

        // GPU state: pipelines and weight cache are both derived from the context.
        let ctx = WgpuCtx::new().await?;
        let pipes = Arc::new(Pipelines::new(&ctx.device));
        let device = ctx.device.clone();
        let queue = ctx.queue.clone();
        let bind_cache = Arc::clone(&ctx.bind_cache);
        let wcache = WeightCache::new(reader, device, queue, bind_cache);

        Ok(Self {
            model,
            tok,
            ctx,
            pipes,
            wcache,
            bos,
            eos,
            add_bos,
            add_eos,
        })
    }

    /// Loads a model from a complete in-memory GGUF file.
    ///
    /// # Errors
    /// Fails if `GgufReader::new` rejects the bytes or if model/GPU setup fails.
    pub async fn load_native(bytes: Vec<u8>) -> Result<Self> {
        let reader = GgufReader::new(bytes)?;
        Self::from_reader(reader).await
    }

    /// Loads a model whose tensor data is pulled on demand through `fetcher`.
    ///
    /// # Errors
    /// Fails if `GgufReader::new_streaming` fails or if model/GPU setup fails.
    pub async fn load_streaming_native(fetcher: Arc<dyn TensorFetcher>) -> Result<Self> {
        let reader = GgufReader::new_streaming(fetcher).await?;
        Self::from_reader(reader).await
    }

    /// Native embedding width taken from the model configuration (`cfg.embed_dim`).
    pub fn dim_native(&self) -> u32 {
        let cfg = &self.model.cfg;
        cfg.embed_dim
    }

    /// Tokenizes `text`, optionally prefixing BOS and suffixing EOS
    /// according to the `add_bos` / `add_eos` flags.
    fn ids_for(&self, text: &str) -> Vec<u32> {
        let mut ids: Vec<u32> = Vec::new();
        ids.extend(self.add_bos.then_some(self.bos));
        ids.extend(self.tok.encode(text));
        ids.extend(self.add_eos.then_some(self.eos));
        ids
    }

    /// Embeds a single text on the GPU.
    ///
    /// `target_dim` is forwarded unchanged to `EmbedModel::embed_ids_gpu`.
    ///
    /// # Errors
    /// Propagates any error from `embed_ids_gpu`.
    pub async fn embed_native(&self, text: &str, target_dim: usize) -> Result<Vec<f32>> {
        let token_ids = self.ids_for(text);
        let Self {
            model,
            ctx,
            pipes,
            wcache,
            ..
        } = self;
        model
            .embed_ids_gpu(ctx, pipes, wcache, &token_ids, target_dim)
            .await
    }

    /// Embeds each text in order and concatenates the vectors into one flat buffer.
    ///
    /// Returns `(flat, width)`, where `width` is the length of the last
    /// embedding produced (0 when `texts` is empty).
    ///
    /// # Errors
    /// Stops at, and returns, the first error from `embed_native`.
    pub async fn embed_batch_native(
        &self,
        texts: &[String],
        target_dim: usize,
    ) -> Result<(Vec<f32>, usize)> {
        let mut flat: Vec<f32> = Vec::new();
        let mut width = 0usize;
        for text in texts.iter() {
            let row = self.embed_native(text.as_str(), target_dim).await?;
            width = row.len();
            flat.extend(row);
        }
        Ok((flat, width))
    }
}
/// Reads an unsigned 32-bit GGUF metadata value.
///
/// Returns `default` when `r.get(key)` fails or when the stored value
/// cannot be converted with `as_u32`.
fn meta_u32(r: &GgufReader, key: &str, default: u32) -> u32 {
    match r.get(key) {
        Ok(value) => value.as_u32().unwrap_or(default),
        Err(_) => default,
    }
}
/// Reads a boolean GGUF metadata value.
///
/// Returns `default` when `r.get(key)` fails or when the stored value
/// cannot be converted with `as_bool`.
fn meta_bool(r: &GgufReader, key: &str, default: bool) -> bool {
    r.get(key)
        .map_or(default, |value| value.as_bool().unwrap_or(default))
}
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl EmbeddingModel {
    /// JS `load(bytes)`: loads a model from a complete in-memory GGUF file.
    ///
    /// Errors are surfaced as `JsError` carrying the `Debug` form of the crate error.
    #[wasm_bindgen(js_name = load)]
    pub async fn load_js(bytes: Vec<u8>) -> std::result::Result<EmbeddingModel, JsError> {
        Self::load_native(bytes).await.map_err(js_error)
    }

    /// JS `loadFromOpfs(readFn, totalBytes)`: loads a model whose tensor data is
    /// fetched on demand through `read_fn` (wrapped in `OpfsFetcher`).
    ///
    /// `total_bytes` must be a non-negative integer no larger than
    /// `Number.MAX_SAFE_INTEGER`. Otherwise a `JsError` is returned before
    /// any loading starts.
    #[wasm_bindgen(js_name = loadFromOpfs)]
    pub async fn load_from_opfs_js(
        read_fn: js_sys::Function,
        total_bytes: f64,
    ) -> std::result::Result<EmbeddingModel, JsError> {
        // 2^53 - 1: the largest integer an f64 (JS Number) represents exactly.
        const MAX_SAFE_INTEGER: f64 = 9_007_199_254_740_991.0;

        if !total_bytes.is_finite() || total_bytes < 0.0 {
            return Err(JsError::new(
                "loadFromOpfs: total_bytes must be a non-negative finite number",
            ));
        }
        // `as u64` silently truncates fractions and saturates out-of-range values,
        // which would hand the fetcher a wrong file size; reject those inputs.
        if total_bytes.fract() != 0.0 || total_bytes > MAX_SAFE_INTEGER {
            return Err(JsError::new(
                "loadFromOpfs: total_bytes must be an integer no larger than Number.MAX_SAFE_INTEGER",
            ));
        }

        let fetcher = crate::gguf::OpfsFetcher::new(read_fn, total_bytes as u64);
        let arc: Arc<dyn TensorFetcher> = Arc::new(fetcher);
        Self::load_streaming_native(arc).await.map_err(js_error)
    }

    /// JS getter `dim`: the model's native embedding width.
    #[wasm_bindgen(js_name = dim, getter)]
    pub fn dim_js(&self) -> u32 {
        self.dim_native()
    }

    /// JS `embed(text, targetDim)`: embeds one text; `target_dim` is forwarded
    /// unchanged to the native embedding path.
    #[wasm_bindgen(js_name = embed)]
    pub async fn embed_js(
        &self,
        text: String,
        target_dim: u32,
    ) -> std::result::Result<Vec<f32>, JsError> {
        self.embed_native(&text, target_dim as usize)
            .await
            .map_err(js_error)
    }

    /// JS `embedBatch(texts, targetDim)`: embeds every text and returns the
    /// vectors concatenated in input order as one flat array.
    ///
    /// The per-row width is not returned; callers must know it (e.g. from
    /// `targetDim` or `dim`) to split the result.
    #[wasm_bindgen(js_name = embedBatch)]
    pub async fn embed_batch_js(
        &self,
        texts: Vec<String>,
        target_dim: u32,
    ) -> std::result::Result<Vec<f32>, JsError> {
        self.embed_batch_native(&texts, target_dim as usize)
            .await
            .map(|(flat, _width)| flat)
            .map_err(js_error)
    }
}

/// Converts any `Debug` error into a `JsError` whose message is its `{:?}` rendering.
#[cfg(target_arch = "wasm32")]
fn js_error<E: std::fmt::Debug>(e: E) -> JsError {
    JsError::new(&format!("{e:?}"))
}