use wasm_bindgen::prelude::*;
use crate::loader::ModelFile;
use crate::models::bert::{BertConfig, BertModel};
use crate::ops::matmul::matmul;
use crate::tensor::Tensor;
use crate::tokenizer::{EncodedInput, WordPieceOptions, WordPieceTokenizer};
/// Installs `console_error_panic_hook` so Rust panics are reported through `console.error`.
///
/// The hook is only compiled in when the crate is built with the `wasm-debug` feature; in other
/// builds this function is an empty no-op, so it is always safe to call from JS.
#[wasm_bindgen]
pub fn init_panic_hook() {
    #[cfg(feature = "wasm-debug")]
    {
        console_error_panic_hook::set_once();
    }
}
/// Returns the crate version, captured from Cargo's `CARGO_PKG_VERSION` at compile time.
#[wasm_bindgen]
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Benchmark helper: multiplies two `n × n` all-ones matrices with [`matmul`].
///
/// Returns the first element of the product so JS can observe the result. If `matmul` is a
/// standard matrix product, that element is the sum of `n` ones, i.e. `n`.
///
/// For `n == 0` the product has no elements, so `0.0` is returned directly. This matches the
/// `n`-valued result above, and indexing the empty output would otherwise panic, which traps the
/// wasm instance.
#[wasm_bindgen]
pub fn matmul_bench(n: usize) -> f32 {
    if n == 0 {
        return 0.0;
    }
    let a = Tensor::from_vec(vec![1.0; n * n], &[n, n]);
    let b = Tensor::from_vec(vec![1.0; n * n], &[n, n]);
    let c = matmul(&a, &b);
    c.data()[0]
}
/// JS-facing wrapper around a [`WordPieceTokenizer`].
#[wasm_bindgen]
pub struct WasmWordPieceTokenizer {
    inner: WordPieceTokenizer,
}

#[wasm_bindgen]
impl WasmWordPieceTokenizer {
    /// Builds a tokenizer from raw vocabulary bytes.
    ///
    /// Only `lowercase` is taken from the caller; every other option comes from
    /// `WordPieceOptions::default()`.
    ///
    /// # Errors
    /// Returns a `JsError` carrying the loader's error message if the tokenizer cannot be built
    /// from `vocab`.
    #[wasm_bindgen(constructor)]
    pub fn new(vocab: &[u8], lowercase: bool) -> Result<WasmWordPieceTokenizer, JsError> {
        let options = WordPieceOptions {
            lowercase,
            ..WordPieceOptions::default()
        };
        match WordPieceTokenizer::from_vocab_bytes_with_options(vocab, options) {
            Ok(inner) => Ok(Self { inner }),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Encodes `text` with the wrapped tokenizer.
    ///
    /// `max_len` is forwarded unchanged to `WordPieceTokenizer::encode` (presumably the maximum
    /// sequence length; confirm truncation semantics there).
    ///
    /// # Errors
    /// Returns a `JsError` with the tokenizer's error message if encoding fails.
    pub fn encode(&self, text: &str, max_len: usize) -> Result<WasmEncodedInput, JsError> {
        let encoded = self
            .inner
            .encode(text, max_len)
            .map_err(|err| JsError::new(&err.to_string()))?;
        Ok(WasmEncodedInput::new(encoded))
    }

    /// Like [`Self::encode`], but delegates to `WordPieceTokenizer::encode_padded`.
    ///
    /// It presumably pads the output up to `max_len`; confirm in the tokenizer.
    ///
    /// # Errors
    /// Returns a `JsError` with the tokenizer's error message if encoding fails.
    pub fn encode_padded(&self, text: &str, max_len: usize) -> Result<WasmEncodedInput, JsError> {
        let encoded = self
            .inner
            .encode_padded(text, max_len)
            .map_err(|err| JsError::new(&err.to_string()))?;
        Ok(WasmEncodedInput::new(encoded))
    }

    /// Looks up the id of a single vocabulary token.
    ///
    /// Returns `None` (which wasm-bindgen exposes as `undefined`) when the token is unknown.
    pub fn token_id(&self, token: &str) -> Option<u32> {
        self.inner.token_id(token)
    }
}
/// Tokenizer output (input ids, token type ids, attention mask) exported to JS.
///
/// Each getter returns a fresh copy of the corresponding buffer.
#[wasm_bindgen]
pub struct WasmEncodedInput {
    inner: EncodedInput,
}

impl WasmEncodedInput {
    /// Wraps a tokenizer result for export; crate-internal, not visible from JS.
    fn new(inner: EncodedInput) -> Self {
        WasmEncodedInput { inner }
    }

    /// Copies a borrowed id buffer into an owned boxed slice that wasm-bindgen can return to JS.
    fn copy_out(values: &[u32]) -> Box<[u32]> {
        Box::from(values)
    }
}

#[wasm_bindgen]
impl WasmEncodedInput {
    /// Vocabulary ids of the encoded tokens.
    pub fn input_ids(&self) -> Box<[u32]> {
        Self::copy_out(&self.inner.input_ids)
    }

    /// Token type (segment) ids produced by the tokenizer.
    pub fn token_type_ids(&self) -> Box<[u32]> {
        Self::copy_out(&self.inner.token_type_ids)
    }

    /// Attention mask produced by the tokenizer.
    pub fn attention_mask(&self) -> Box<[u32]> {
        Self::copy_out(&self.inner.attention_mask)
    }
}
/// JS-facing wrapper around a [`BertModel`] loaded from safetensors bytes.
#[wasm_bindgen]
pub struct WasmBertModel {
    inner: BertModel,
}

#[wasm_bindgen]
impl WasmBertModel {
    /// Parses `bytes` with `ModelFile::parse` and builds a BERT model with the given architecture.
    ///
    /// `layer_norm_eps` is fixed to `1e-12` and cannot be set from JS. `prefix` is forwarded to
    /// `BertModel::from_safetensors` (presumably the tensor-name prefix; confirm there).
    ///
    /// # Errors
    /// Returns a `JsError` if any of the following holds:
    /// - `num_attention_heads` is zero.
    /// - `hidden_size` is not divisible by `num_attention_heads`.
    /// - The bytes cannot be parsed.
    /// - The model cannot be built from the parsed weights.
    #[wasm_bindgen(constructor)]
    // The parameters mirror `BertConfig` field-for-field because JS callers pass plain numbers.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        bytes: &[u8],
        hidden_size: usize,
        num_hidden_layers: usize,
        num_attention_heads: usize,
        intermediate_size: usize,
        vocab_size: usize,
        max_position_embeddings: usize,
        type_vocab_size: usize,
        prefix: &str,
    ) -> Result<WasmBertModel, JsError> {
        // Multi-head attention splits `hidden_size` evenly across heads. Reject shapes that cannot
        // be split here, as a catchable error. Otherwise they would likely reach the model code's
        // head-size arithmetic, where a panic traps the wasm instance instead.
        if num_attention_heads == 0 {
            return Err(JsError::new("num_attention_heads must be greater than zero"));
        }
        if hidden_size % num_attention_heads != 0 {
            return Err(JsError::new(&format!(
                "hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
            )));
        }

        let config = BertConfig {
            hidden_size,
            num_hidden_layers,
            num_attention_heads,
            intermediate_size,
            vocab_size,
            max_position_embeddings,
            type_vocab_size,
            layer_norm_eps: 1e-12,
        };
        let file = ModelFile::parse(bytes).map_err(|e| JsError::new(&e.to_string()))?;
        let inner = BertModel::from_safetensors(&file, config, prefix)
            .map_err(|e| JsError::new(&e.to_string()))?;
        Ok(Self { inner })
    }

    /// Runs `BertModel::forward` on `input_ids` (no token type ids) and returns the output data.
    ///
    /// The output is presumably the flattened per-token hidden states; confirm its shape in
    /// `BertModel::forward`.
    pub fn forward(&self, input_ids: &[u32]) -> Box<[f32]> {
        self.inner
            .forward(input_ids, None)
            .data()
            .to_vec()
            .into_boxed_slice()
    }

    /// Computes a sentence embedding via `BertModel::embed_sentence`.
    ///
    /// No token type ids and no attention mask are passed.
    pub fn embed(&self, input_ids: &[u32]) -> Box<[f32]> {
        self.inner
            .embed_sentence(input_ids, None, None)
            .data()
            .to_vec()
            .into_boxed_slice()
    }

    /// Like [`Self::embed`], but passes an explicit attention mask.
    ///
    /// # Errors
    /// Returns a `JsError` if `input_ids` and `attention_mask` differ in length.
    pub fn embed_with_mask(
        &self,
        input_ids: &[u32],
        attention_mask: &[u32],
    ) -> Result<Box<[f32]>, JsError> {
        if input_ids.len() != attention_mask.len() {
            return Err(JsError::new(
                "input_ids and attention_mask must have the same length",
            ));
        }
        Ok(self
            .inner
            .embed_sentence(input_ids, None, Some(attention_mask))
            .data()
            .to_vec()
            .into_boxed_slice())
    }

    /// Embeds a tokenizer result.
    ///
    /// The input ids, token type ids and attention mask are all passed through.
    pub fn embed_encoded(&self, encoded: &WasmEncodedInput) -> Box<[f32]> {
        self.inner
            .embed_sentence(
                &encoded.inner.input_ids,
                Some(&encoded.inner.token_type_ids),
                Some(&encoded.inner.attention_mask),
            )
            .data()
            .to_vec()
            .into_boxed_slice()
    }

    /// Tokenizes `text` with `tokenizer` and embeds it in one call via `BertModel::embed_text`.
    ///
    /// `max_len` is forwarded unchanged.
    ///
    /// # Errors
    /// Returns a `JsError` with the underlying error message if tokenization or embedding fails.
    pub fn embed_text(
        &self,
        tokenizer: &WasmWordPieceTokenizer,
        text: &str,
        max_len: usize,
    ) -> Result<Box<[f32]>, JsError> {
        self.inner
            .embed_text(&tokenizer.inner, text, max_len)
            .map(|t| t.data().to_vec().into_boxed_slice())
            .map_err(|e| JsError::new(&e.to_string()))
    }
}