#![cfg(target_arch = "wasm32")]
use crate::tts_model::TTSModel;
use candle_core::Tensor;
use js_sys::{Date, Float32Array, Object, Reflect};
use wasm_bindgen::prelude::*;
/// Boxed, type-erased iterator held by `WasmTTSStream`.
///
/// Each item is either a `Tensor` or the `anyhow::Error` that interrupted
/// generation.
type StreamIter = Box<dyn Iterator<Item = std::result::Result<Tensor, anyhow::Error>>>;
#[wasm_bindgen(start)]
pub fn init() {
console_error_panic_hook::set_once();
web_sys::console::log_1(&"Pocket TTS WASM initialized".into());
}
/// JavaScript-facing handle around a `TTSModel` and an optional voice state.
#[wasm_bindgen]
pub struct WasmTTSModel {
/// Loaded model; `None` until `load_from_buffer` succeeds.
model: Option<TTSModel>,
/// Voice state set by `load_voice_from_buffer` / `load_voice_from_safetensors`.
/// While `None`, generation falls back to `crate::voice_state::init_states(1, 0)`.
voice_state: Option<crate::ModelState>,
/// Sample rate exposed to JS: 24000 initially, replaced by the model's own
/// `sample_rate` once a model has been loaded.
sample_rate: u32,
}
/// JavaScript-facing handle over an in-progress streaming generation.
#[wasm_bindgen]
pub struct WasmTTSStream {
/// Underlying tensor iterator; set to `None` once it has been exhausted.
iter: Option<StreamIter>,
/// Number of samples returned by the most recent chunk request (0 if none).
last_samples: u32,
/// Wall-clock milliseconds (measured with `Date::now()`) spent in the most
/// recent chunk request.
last_compute_ms: f64,
/// Number of iterator items merged into the most recently returned chunk.
last_chunks_merged: u32,
}
#[wasm_bindgen]
impl WasmTTSModel {
#[wasm_bindgen(constructor)]
pub fn new() -> WasmTTSModel {
Self {
model: None,
voice_state: None,
sample_rate: 24000,
}
}
#[wasm_bindgen]
pub fn load_from_buffer(
&mut self,
config_yaml: &[u8],
weights_data: &[u8],
tokenizer_bytes: &[u8],
) -> Result<(), JsValue> {
let tok_bytes = if tokenizer_bytes.is_empty() {
include_bytes!("../assets/tokenizer.json")
} else {
tokenizer_bytes
};
let model = TTSModel::load_from_bytes(config_yaml, weights_data, tok_bytes)
.map_err(|e| JsValue::from_str(&format!("Model loading failed: {:?}", e)))?;
self.sample_rate = model.sample_rate as u32;
self.model = Some(model);
web_sys::console::log_1(
&"Model loaded successfully (using embedded or provided tokenizer)".into(),
);
Ok(())
}
#[wasm_bindgen]
pub fn is_ready(&self) -> bool {
self.model.is_some()
}
#[wasm_bindgen]
pub fn load_voice_from_buffer(&mut self, wav_bytes: &[u8]) -> Result<(), JsValue> {
let model = self
.model
.as_ref()
.ok_or_else(|| JsValue::from_str("Model not loaded. Call load_from_buffer first."))?;
let voice_state = model
.get_voice_state_from_bytes(wav_bytes)
.map_err(|e| JsValue::from_str(&format!("Voice cloning failed: {:?}", e)))?;
self.voice_state = Some(voice_state);
web_sys::console::log_1(&"Voice loaded successfully (from audio)".into());
Ok(())
}
#[wasm_bindgen]
pub fn load_voice_from_safetensors(&mut self, bytes: &[u8]) -> Result<(), JsValue> {
let model = self
.model
.as_ref()
.ok_or_else(|| JsValue::from_str("Model not loaded. Call load_from_buffer first."))?;
let voice_state = model
.get_voice_state_from_prompt_bytes(bytes)
.map_err(|e| JsValue::from_str(&format!("Voice loading failed: {:?}", e)))?;
self.voice_state = Some(voice_state);
web_sys::console::log_1(&"Voice loaded successfully (from embedding)".into());
Ok(())
}
#[wasm_bindgen(getter)]
pub fn sample_rate(&self) -> u32 {
self.sample_rate
}
#[wasm_bindgen]
pub fn generate(&self, text: &str) -> Result<Float32Array, JsValue> {
let model = self
.model
.as_ref()
.ok_or_else(|| JsValue::from_str("Model not loaded. Call load_from_buffer first."))?;
let voice_state = self
.voice_state
.clone()
.unwrap_or_else(|| crate::voice_state::init_states(1, 0));
let audio_tensor = model
.generate(text, &voice_state)
.map_err(|e| JsValue::from_str(&format!("Generation failed: {:?}", e)))?;
let samples = audio_tensor
.to_vec2::<f32>()
.map_err(|e| JsValue::from_str(&format!("Failed to extract samples: {:?}", e)))?[0]
.clone();
let array = Float32Array::new_with_length(samples.len() as u32);
array.copy_from(&samples);
Ok(array)
}
#[wasm_bindgen]
pub fn start_stream(&self, text: &str) -> Result<WasmTTSStream, JsValue> {
let model = self
.model
.as_ref()
.ok_or_else(|| JsValue::from_str("Model not loaded. Call load_from_buffer first."))?;
let voice_state = self
.voice_state
.clone()
.unwrap_or_else(|| crate::voice_state::init_states(1, 0));
let iter = model.generate_stream_owned(text, &voice_state);
Ok(WasmTTSStream {
iter: Some(iter),
last_samples: 0,
last_compute_ms: 0.0,
last_chunks_merged: 0,
})
}
#[wasm_bindgen]
pub fn generate_wav_base64(&self, text: &str) -> Result<String, JsValue> {
let samples = self.generate(text)?;
let mut sample_vec = vec![0.0f32; samples.length() as usize];
samples.copy_to(&mut sample_vec);
let mut buffer = std::io::Cursor::new(Vec::new());
{
write_wav_header(&mut buffer, self.sample_rate, sample_vec.len() as u32)
.map_err(|e| JsValue::from_str(&format!("WAV header error: {:?}", e)))?;
let pcm_bytes = crate::audio::pcm_i16_le_bytes_mono(&sample_vec);
buffer.get_mut().extend_from_slice(&pcm_bytes);
}
let base64 = base64_encode(buffer.get_ref());
Ok(format!("data:audio/wav;base64,{}", base64))
}
}
#[wasm_bindgen]
impl WasmTTSStream {
    /// Returns the next non-empty chunk of samples, or `None` when the stream
    /// has nothing more to yield. Equivalent to `next_chunk_min_samples(1)`.
    #[wasm_bindgen]
    pub fn next_chunk(&mut self) -> Result<Option<Float32Array>, JsValue> {
        self.next_chunk_min_samples(1)
    }

    /// Pulls items from the underlying iterator until at least `min_samples`
    /// samples (minimum 1) have been gathered or the iterator ends, and
    /// returns them merged into one `Float32Array`.
    ///
    /// Returns `Ok(None)` when the stream was already exhausted, or when it
    /// ends without producing any samples during this call. The `last_*`
    /// statistics are refreshed on every call that reaches the iterator.
    ///
    /// # Errors
    /// Returns `Generation failed: ...` when the iterator yields an error, or
    /// the extraction error when a tensor cannot be converted to samples. In
    /// both cases the sample and chunk statistics are reset to zero and the
    /// elapsed time is still recorded; samples gathered earlier in the same
    /// call are discarded.
    #[wasm_bindgen]
    pub fn next_chunk_min_samples(
        &mut self,
        min_samples: u32,
    ) -> Result<Option<Float32Array>, JsValue> {
        let start_ms = Date::now();
        let iter = match self.iter.as_mut() {
            Some(iter) => iter,
            None => return Ok(None),
        };
        let target_samples = min_samples.max(1) as usize;
        let mut merged = Vec::<f32>::new();
        let mut merged_chunks = 0u32;
        let mut exhausted = false;
        let mut failure: Option<JsValue> = None;

        while merged.len() < target_samples {
            match iter.next() {
                Some(Ok(tensor)) => match tensor_to_mono_vec(&tensor) {
                    Ok(samples) => {
                        merged.extend(samples);
                        merged_chunks += 1;
                    }
                    // Route extraction failures through the same bookkeeping
                    // as generator failures so the stats never go stale.
                    Err(js_err) => {
                        failure = Some(js_err);
                        break;
                    }
                },
                Some(Err(e)) => {
                    failure = Some(JsValue::from_str(&format!("Generation failed: {:?}", e)));
                    break;
                }
                None => {
                    exhausted = true;
                    break;
                }
            }
        }

        self.last_compute_ms = Date::now() - start_ms;
        if exhausted {
            // Drop the iterator so later calls return `None` immediately.
            self.iter = None;
        }
        if let Some(js_err) = failure {
            self.last_samples = 0;
            self.last_chunks_merged = 0;
            return Err(js_err);
        }

        self.last_chunks_merged = merged_chunks;
        if merged.is_empty() {
            self.last_samples = 0;
            return Ok(None);
        }
        self.last_samples = merged.len() as u32;
        let array = Float32Array::new_with_length(self.last_samples);
        array.copy_from(&merged);
        Ok(Some(array))
    }

    /// Returns a plain JS object describing the most recent chunk request,
    /// with numeric fields `samples`, `compute_ms` and `chunks_merged`.
    #[wasm_bindgen]
    pub fn last_chunk_stats(&self) -> JsValue {
        let stats = Object::new();
        let fields = [
            ("samples", self.last_samples as f64),
            ("compute_ms", self.last_compute_ms),
            ("chunks_merged", self.last_chunks_merged as f64),
        ];
        for &(key, value) in fields.iter() {
            // Best effort: the result of the property write is deliberately
            // ignored for this diagnostics-only object.
            let _ = Reflect::set(&stats, &JsValue::from_str(key), &JsValue::from_f64(value));
        }
        JsValue::from(stats)
    }
}
/// Writes the 44-byte RIFF/WAVE header for mono, 16-bit PCM audio.
///
/// * `w` - destination for the header bytes.
/// * `sample_rate` - samples per second, stored verbatim in the `fmt ` chunk.
/// * `num_samples` - number of 16-bit samples that will follow the header.
///
/// # Errors
/// Returns `ErrorKind::InvalidInput` when the data size, the RIFF chunk size
/// or the byte rate would not fit in the header's 32-bit fields. All sizes are
/// validated before anything is written, so a rejected call writes nothing.
/// Any error from the writer itself is propagated unchanged.
fn write_wav_header(
    w: &mut dyn std::io::Write,
    sample_rate: u32,
    num_samples: u32,
) -> std::io::Result<()> {
    const PCM_FORMAT: u16 = 1;
    const CHANNELS: u16 = 1;
    const BITS_PER_SAMPLE: u16 = 16;
    // Bytes per sample frame: one channel of 16-bit samples.
    const BLOCK_ALIGN: u16 = 2;
    const FMT_CHUNK_SIZE: u32 = 16;
    // Header bytes counted by the RIFF chunk size (everything after it).
    const RIFF_OVERHEAD: u32 = 36;
    const HEADER_LEN: usize = 44;

    fn too_large(what: &str) -> std::io::Error {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("{} does not fit in a 32-bit WAV header field", what),
        )
    }

    // Checked arithmetic: a wrapped size would silently produce a corrupt
    // header in release builds and panic in debug builds.
    let subchunk2_size = num_samples
        .checked_mul(u32::from(BLOCK_ALIGN))
        .ok_or_else(|| too_large("PCM data size"))?;
    let chunk_size = subchunk2_size
        .checked_add(RIFF_OVERHEAD)
        .ok_or_else(|| too_large("RIFF chunk size"))?;
    let byte_rate = sample_rate
        .checked_mul(u32::from(BLOCK_ALIGN))
        .ok_or_else(|| too_large("byte rate"))?;

    let mut header: Vec<u8> = Vec::with_capacity(HEADER_LEN);
    header.extend_from_slice(b"RIFF");
    header.extend_from_slice(&chunk_size.to_le_bytes());
    header.extend_from_slice(b"WAVE");
    header.extend_from_slice(b"fmt ");
    header.extend_from_slice(&FMT_CHUNK_SIZE.to_le_bytes());
    header.extend_from_slice(&PCM_FORMAT.to_le_bytes());
    header.extend_from_slice(&CHANNELS.to_le_bytes());
    header.extend_from_slice(&sample_rate.to_le_bytes());
    header.extend_from_slice(&byte_rate.to_le_bytes());
    header.extend_from_slice(&BLOCK_ALIGN.to_le_bytes());
    header.extend_from_slice(&BITS_PER_SAMPLE.to_le_bytes());
    header.extend_from_slice(b"data");
    header.extend_from_slice(&subchunk2_size.to_le_bytes());

    w.write_all(&header)
}
/// Extracts a flat vector of `f32` samples from a 1-, 2- or 3-dimensional
/// tensor.
///
/// * rank 3: returns element `[0][0]` of the nested data.
/// * rank 2: returns row `[0]`.
/// * rank 1: returns the data as is.
///
/// # Errors
/// Returns `Failed to extract samples: ...` when the tensor cannot be read as
/// `f32`, and `Unexpected audio tensor shape` for any other rank or when the
/// selected leading row does not exist (a zero-sized leading dimension).
fn tensor_to_mono_vec(tensor: &Tensor) -> Result<Vec<f32>, JsValue> {
    fn extraction_error<E: std::fmt::Debug>(e: E) -> JsValue {
        JsValue::from_str(&format!("Failed to extract samples: {:?}", e))
    }
    fn unexpected_shape() -> JsValue {
        JsValue::from_str("Unexpected audio tensor shape")
    }

    // Rows are moved out of the nested vectors rather than indexed and
    // cloned: no extra copy, and no out-of-bounds panic on an empty
    // leading dimension.
    match tensor.dims().len() {
        3 => tensor
            .to_vec3::<f32>()
            .map_err(extraction_error)?
            .into_iter()
            .next()
            .and_then(|first| first.into_iter().next())
            .ok_or_else(unexpected_shape),
        2 => tensor
            .to_vec2::<f32>()
            .map_err(extraction_error)?
            .into_iter()
            .next()
            .ok_or_else(unexpected_shape),
        1 => tensor.to_vec1::<f32>().map_err(extraction_error),
        _ => Err(unexpected_shape()),
    }
}
/// Encodes `input` as standard Base64 (alphabet `A-Z a-z 0-9 + /`) with `=`
/// padding, returning the encoded text.
fn base64_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Bit offsets of the four 6-bit symbols inside a 24-bit group.
    const SHIFTS: [u32; 4] = [18, 12, 6, 0];

    let mut encoded = String::with_capacity((input.len() + 2) / 3 * 4);
    for group in input.chunks(3) {
        // Pack up to three bytes, left-aligned, into one 24-bit word.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (pos, &byte)| {
                acc | (u32::from(byte) << (16 - 8 * pos))
            });
        // n input bytes carry n + 1 meaningful symbols; the rest is padding.
        let meaningful = group.len() + 1;
        for (slot, &shift) in SHIFTS.iter().enumerate() {
            if slot < meaningful {
                encoded.push(ALPHABET[((word >> shift) & 0x3F) as usize] as char);
            } else {
                encoded.push('=');
            }
        }
    }
    encoded
}
impl Default for WasmTTSModel {
fn default() -> Self {
Self::new()
}
}