#[cfg(feature = "wasm")]
use wasm_bindgen::prelude::*;
#[cfg(feature = "wasm")]
/// Installs the panic hook provided by the `console_error_panic_hook` crate.
///
/// This is a thin exported wrapper around `console_error_panic_hook::set_once`,
/// so JavaScript callers can install the hook explicitly.
#[wasm_bindgen]
pub fn init_panic_hook() {
    console_error_panic_hook::set_once()
}
#[cfg(feature = "wasm")]
/// Writes `message` to the JavaScript console via `console.log`.
///
/// The string is converted to a `JsValue` and passed as the single argument
/// to `web_sys::console::log_1`.
#[wasm_bindgen]
pub fn wasm_log(message: &str) {
    let js_message = JsValue::from_str(message);
    web_sys::console::log_1(&js_message);
}
#[cfg(feature = "wasm")]
/// Writes `message` to the JavaScript console via `console.error`.
///
/// The string is converted to a `JsValue` and passed as the single argument
/// to `web_sys::console::error_1`.
#[wasm_bindgen]
pub fn wasm_error(message: &str) {
    let js_message = JsValue::from_str(message);
    web_sys::console::error_1(&js_message);
}
#[cfg(feature = "wasm")]
/// Model configuration exported to JavaScript.
///
/// All fields are private; values are set through the `with_*` builder
/// methods on the `impl` block.
#[wasm_bindgen]
pub struct WasmConfig {
    // Hidden dimension setting.
    hidden_dim: usize,
    // Number of layers.
    num_layers: usize,
    // Vocabulary size.
    vocab_size: usize,
    // Maximum sequence length.
    max_seq_len: usize,
}
#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl WasmConfig {
    /// Creates a configuration with the built-in defaults:
    /// `hidden_dim = 256`, `num_layers = 4`, `vocab_size = 32000`,
    /// `max_seq_len = 512`.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self {
            hidden_dim: 256,
            num_layers: 4,
            vocab_size: 32000,
            max_seq_len: 512,
        }
    }

    /// Returns the configuration with `hidden_dim` replaced by `dim`.
    ///
    /// Takes `self` by value, so the call consumes the receiver and hands
    /// back the updated configuration.
    pub fn with_hidden_dim(mut self, dim: usize) -> Self {
        self.hidden_dim = dim;
        self
    }

    /// Returns the configuration with `num_layers` replaced by `layers`.
    pub fn with_num_layers(mut self, layers: usize) -> Self {
        self.num_layers = layers;
        self
    }

    /// Returns the configuration with `vocab_size` replaced by `vocab`.
    pub fn with_vocab_size(mut self, vocab: usize) -> Self {
        self.vocab_size = vocab;
        self
    }

    /// Returns the configuration with `max_seq_len` replaced by `len`.
    pub fn with_max_seq_len(mut self, len: usize) -> Self {
        self.max_seq_len = len;
        self
    }

    /// Current `hidden_dim` value.
    #[wasm_bindgen(getter)]
    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Current `num_layers` value.
    #[wasm_bindgen(getter)]
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Current `vocab_size` value.
    ///
    /// Added so that every field settable through a `with_*` builder can
    /// also be read back, matching `hidden_dim` and `num_layers`.
    #[wasm_bindgen(getter)]
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Current `max_seq_len` value.
    #[wasm_bindgen(getter)]
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
}

// A public zero-argument `new` should be mirrored by `Default` so the type
// works with `..Default::default()`, `unwrap_or_default()`, and similar APIs.
#[cfg(feature = "wasm")]
impl Default for WasmConfig {
    /// Same values as [`WasmConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(feature = "wasm")]
/// Model handle exported to JavaScript.
///
/// Holds the configuration plus two pieces of runtime state.
#[wasm_bindgen]
pub struct WasmBitLlama {
    // Configuration supplied at construction time.
    config: WasmConfig,
    // Becomes `true` once `load_weights` has been called.
    initialized: bool,
    // Position counter; `reset` sets it back to 0.
    current_pos: usize,
}
#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl WasmBitLlama {
    /// Creates an uninitialized model that uses `WasmConfig::new()`.
    ///
    /// Installs the console panic hook as a side effect. The current
    /// implementation always returns `Ok`.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Result<WasmBitLlama, JsValue> {
        Self::with_config(WasmConfig::new())
    }

    /// Creates an uninitialized model that uses the supplied `config`.
    ///
    /// Installs the console panic hook as a side effect. The current
    /// implementation always returns `Ok`.
    pub fn with_config(config: WasmConfig) -> Result<WasmBitLlama, JsValue> {
        console_error_panic_hook::set_once();
        let model = Self {
            config,
            initialized: false,
            current_pos: 0,
        };
        Ok(model)
    }

    /// Stub weight loader: ignores `_weights_data`, logs a message and marks
    /// the model as initialized. Always returns `Ok(())`.
    pub fn load_weights(&mut self, _weights_data: &[u8]) -> Result<(), JsValue> {
        wasm_log("WasmBitLlama: load_weights called (stub implementation)");
        self.initialized = true;
        Ok(())
    }

    /// Stub generation: logs the request and returns a placeholder string
    /// that echoes `prompt` and `max_tokens`.
    ///
    /// # Errors
    ///
    /// Returns a string `JsValue` error if `load_weights` has not been
    /// called yet.
    pub fn generate(&mut self, prompt: &str, max_tokens: u32) -> Result<String, JsValue> {
        if self.initialized {
            let trace = format!(
                "WasmBitLlama: generate called with prompt='{}', max_tokens={}",
                prompt, max_tokens
            );
            wasm_log(&trace);

            let reply = format!(
                "[STUB] Generated response for: '{}' (max_tokens: {}). \
                 Full implementation pending.",
                prompt, max_tokens
            );
            Ok(reply)
        } else {
            Err(JsValue::from_str(
                "Model not initialized. Call load_weights first.",
            ))
        }
    }

    /// Stub streaming generation: feeds a fixed list of placeholder tokens
    /// (including `prompt`) to `callback`, one call per token, emitting at
    /// most `max_tokens` of them. Returns the concatenation of the emitted
    /// tokens.
    ///
    /// A failing callback is logged through `wasm_error` and does not stop
    /// the stream.
    ///
    /// # Errors
    ///
    /// Returns a string `JsValue` error if `load_weights` has not been
    /// called yet.
    pub fn generate_stream(
        &mut self,
        prompt: &str,
        max_tokens: u32,
        callback: &js_sys::Function,
    ) -> Result<String, JsValue> {
        let ready = self.initialized;
        if !ready {
            let reason = JsValue::from_str("Model not initialized. Call load_weights first.");
            return Err(reason);
        }

        let trace = format!(
            "WasmBitLlama: generate_stream called with prompt='{}'",
            prompt
        );
        wasm_log(&trace);

        let stub_tokens = ["[STUB]", " response", " for:", " '", prompt, "'"];
        // The callback is invoked with `null` as its `this` value.
        let receiver = JsValue::NULL;
        let mut emitted = String::new();

        // `take` caps the number of emitted tokens at `max_tokens`.
        for piece in stub_tokens.iter().copied().take(max_tokens as usize) {
            emitted.push_str(piece);
            let piece_js = JsValue::from_str(piece);
            match callback.call1(&receiver, &piece_js) {
                Ok(_) => {}
                Err(e) => wasm_error(&format!("Callback error: {:?}", e)),
            }
        }

        Ok(emitted)
    }

    /// Sets the position counter back to 0 and logs the reset.
    pub fn reset(&mut self) {
        self.current_pos = 0;
        wasm_log("WasmBitLlama: state reset");
    }

    /// `true` once `load_weights` has been called.
    #[wasm_bindgen(getter)]
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Current value of the position counter.
    #[wasm_bindgen(getter)]
    pub fn position(&self) -> usize {
        self.current_pos
    }

    /// Returns a JSON string holding the four configuration values together
    /// with the `initialized` flag and `current_pos`.
    pub fn get_info(&self) -> String {
        let cfg = &self.config;
        let info = serde_json::json!({
            "hidden_dim": cfg.hidden_dim,
            "num_layers": cfg.num_layers,
            "vocab_size": cfg.vocab_size,
            "max_seq_len": cfg.max_seq_len,
            "initialized": self.initialized,
            "current_pos": self.current_pos
        });
        info.to_string()
    }
}
#[cfg(feature = "wasm")]
use js_sys;
#[cfg(feature = "wasm")]
/// Legacy entry point: logs `input` and returns a fixed notice that echoes
/// it and points callers to `WasmBitLlama`.
///
/// The current implementation always returns `Ok`.
#[wasm_bindgen]
pub fn wasm_infer(input: &str) -> Result<String, JsValue> {
    let trace = format!("wasm_infer called with: {}", input);
    wasm_log(&trace);

    let notice = format!(
        "[Legacy API] Input received: '{}'. Use WasmBitLlama for full functionality.",
        input
    );
    Ok(notice)
}
#[cfg(feature = "wasm")]
/// Returns the crate version, taken from `CARGO_PKG_VERSION` at compile time.
#[wasm_bindgen]
pub fn wasm_version() -> String {
    let version: &'static str = env!("CARGO_PKG_VERSION");
    String::from(version)
}
#[cfg(all(test, feature = "wasm"))]
mod tests {
    use super::*;

    // Checks two of the defaults produced by `WasmConfig::new`.
    #[test]
    fn test_wasm_config_creation() {
        let WasmConfig {
            hidden_dim,
            num_layers,
            ..
        } = WasmConfig::new();

        assert_eq!(hidden_dim, 256);
        assert_eq!(num_layers, 4);
    }
}