use wasm_bindgen::prelude::*;
use crate::loader::ModelFile;
use crate::models::bert::{BertConfig, BertModel};
use crate::ops::matmul::matmul;
use crate::tensor::Tensor;
/// Installs the `console_error_panic_hook` panic hook when debugging support is compiled in.
///
/// With the `wasm-debug` feature enabled this calls
/// `console_error_panic_hook::set_once()`; without it the function body is
/// empty and calling it has no effect.
#[wasm_bindgen]
pub fn init_panic_hook() {
    #[cfg(feature = "wasm-debug")]
    {
        console_error_panic_hook::set_once();
    }
}
/// Returns this crate's version as an owned string.
///
/// The value is taken from the `CARGO_PKG_VERSION` environment variable at
/// compile time, so it reflects the version the binary was built from.
#[wasm_bindgen]
pub fn version() -> String {
    const PKG_VERSION: &str = env!("CARGO_PKG_VERSION");
    String::from(PKG_VERSION)
}
/// Multiplies two `n`×`n` all-ones tensors with `matmul` and returns the first
/// element of the product.
///
/// Intended as a quick way to exercise `matmul` at a chosen size. For `n == 0`
/// there is no element to return, so `0.0` is returned without running the
/// multiplication.
///
/// NOTE(review): if `matmul` is a standard matrix product, the returned value
/// for `n > 0` should equal `n` — confirm against `ops::matmul`.
///
/// # Panics
///
/// Panics if `n * n` does not fit in `usize`.
#[wasm_bindgen]
pub fn matmul_bench(n: usize) -> f32 {
    // A 0×0 product has no elements; indexing element 0 below would panic.
    if n == 0 {
        return 0.0;
    }

    // On 32-bit targets such as wasm32 an unchecked `n * n` can wrap in release
    // builds, producing a buffer whose length disagrees with the `[n, n]` shape.
    // Fail loudly with a clear message instead.
    let len = n
        .checked_mul(n)
        .expect("matmul_bench: n * n overflows usize");

    let a = Tensor::from_vec(vec![1.0; len], &[n, n]);
    let b = Tensor::from_vec(vec![1.0; len], &[n, n]);
    let c = matmul(&a, &b);
    c.data()[0]
}
/// JavaScript-facing handle around a [`BertModel`].
///
/// Exported through `#[wasm_bindgen]`; the wrapped model itself is a private
/// field and is only reachable through the methods defined on this type.
#[wasm_bindgen]
pub struct WasmBertModel {
// The wrapped model that the exported methods delegate to.
inner: BertModel,
}
#[wasm_bindgen]
impl WasmBertModel {
#[wasm_bindgen(constructor)]
#[allow(clippy::too_many_arguments)]
pub fn new(
bytes: &[u8],
hidden_size: usize,
num_hidden_layers: usize,
num_attention_heads: usize,
intermediate_size: usize,
vocab_size: usize,
max_position_embeddings: usize,
type_vocab_size: usize,
prefix: &str,
) -> Result<WasmBertModel, JsError> {
let config = BertConfig {
hidden_size,
num_hidden_layers,
num_attention_heads,
intermediate_size,
vocab_size,
max_position_embeddings,
type_vocab_size,
layer_norm_eps: 1e-12,
};
let file = ModelFile::parse(bytes).map_err(|e| JsError::new(&e.to_string()))?;
let inner = BertModel::from_safetensors(&file, config, prefix)
.map_err(|e| JsError::new(&e.to_string()))?;
Ok(Self { inner })
}
pub fn forward(&self, input_ids: &[u32]) -> Box<[f32]> {
self.inner
.forward(input_ids, None)
.data()
.to_vec()
.into_boxed_slice()
}
pub fn embed(&self, input_ids: &[u32]) -> Box<[f32]> {
self.inner
.embed_sentence(input_ids, None, None)
.data()
.to_vec()
.into_boxed_slice()
}
}