use anyhow::{Context, Result};
use rlx_core::embedded_safetensors::EmbeddedSafetensors;
use std::path::Path;
use std::sync::OnceLock;
/// Raw bytes of `weights/silero_vad_16k.safetensors`, baked into the binary at compile time.
const SAFETENSORS: &[u8] = include_bytes!("../../weights/silero_vad_16k.safetensors");

/// Cache for the decoded embedded weights; filled on the first call to [`embedded`].
static PARSED: OnceLock<super::weights::SileroWeights> = OnceLock::new();

/// Returns the Silero weights compiled into this binary.
///
/// The blob is decoded on the first call and the result is cached, so every later call
/// returns a reference to the same value without re-parsing.
///
/// # Panics
///
/// Panics if the embedded safetensors blob cannot be decoded. The bytes are fixed at compile
/// time, so this indicates a broken build artifact rather than a recoverable runtime error.
pub fn embedded() -> &'static super::weights::SileroWeights {
    PARSED.get_or_init(|| {
        let decoded = parse_bytes(SAFETENSORS);
        decoded.expect("embedded silero safetensors")
    })
}
/// Reads a safetensors file from `path` and decodes it into Silero weights.
///
/// # Errors
///
/// Returns an error if the file cannot be read, or if decoding its contents fails
/// (for example, when a required tensor cannot be extracted). In both cases the error
/// context includes `path`, so the failing file can be identified from the message alone.
pub fn load_file(path: &Path) -> Result<super::weights::SileroWeights> {
    let bytes = std::fs::read(path).with_context(|| format!("read {}", path.display()))?;
    // Without this context, a parse failure (e.g. a missing tensor name) would not say
    // which file on disk was at fault.
    parse_bytes(&bytes).with_context(|| format!("parse {}", path.display()))
}
/// Decodes a safetensors blob into [`SileroWeights`](super::weights::SileroWeights).
///
/// Every tensor is fetched as `f32` by its checkpoint name. The first lookup that fails
/// aborts the decode, and its error is propagated with `?`.
///
/// # Errors
///
/// Propagates any error from `EmbeddedSafetensors::parse` or from an individual
/// `tensor_f32` lookup.
fn parse_bytes(bytes: &[u8]) -> Result<super::weights::SileroWeights> {
    let tensors = EmbeddedSafetensors::parse(bytes)?;

    // Lookup order only decides which error is reported when several tensors are bad:
    // `final_conv.weight` is fetched first, then the rest in struct-field order.
    let final_w = tensors.tensor_f32("final_conv.weight")?;
    let stft_conv = tensors.tensor_f32("stft_conv.weight")?;

    // Weight/bias pairs for conv1..conv4.
    let conv1_w = tensors.tensor_f32("conv1.weight")?;
    let conv1_b = tensors.tensor_f32("conv1.bias")?;
    let conv2_w = tensors.tensor_f32("conv2.weight")?;
    let conv2_b = tensors.tensor_f32("conv2.bias")?;
    let conv3_w = tensors.tensor_f32("conv3.weight")?;
    let conv3_b = tensors.tensor_f32("conv3.bias")?;
    let conv4_w = tensors.tensor_f32("conv4.weight")?;
    let conv4_b = tensors.tensor_f32("conv4.bias")?;

    // LSTM cell parameters (`_ih` and `_hh` weight/bias tensors).
    let lstm_w_ih = tensors.tensor_f32("lstm_cell.weight_ih")?;
    let lstm_w_hh = tensors.tensor_f32("lstm_cell.weight_hh")?;
    let lstm_b_ih = tensors.tensor_f32("lstm_cell.bias_ih")?;
    let lstm_b_hh = tensors.tensor_f32("lstm_cell.bias_hh")?;

    let final_b = tensors.tensor_f32("final_conv.bias")?;

    Ok(super::weights::SileroWeights {
        stft_conv,
        conv1_w,
        conv1_b,
        conv2_w,
        conv2_b,
        conv3_w,
        conv3_b,
        conv4_w,
        conv4_b,
        lstm_w_ih,
        lstm_w_hh,
        lstm_b_ih,
        lstm_b_hh,
        final_w,
        final_b,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded blob decodes and a sample of tensors have the expected element counts.
    #[test]
    fn embedded_safetensors_parses() {
        let w = embedded();
        assert_eq!(w.stft_conv.len(), 130 * 128);
        assert_eq!(w.conv1_w.len(), 128 * 65 * 3);
        assert_eq!(w.lstm_w_ih.len(), 512 * 128);
        assert_eq!(w.final_w.len(), 128);
    }

    /// `embedded()` must hand out the same cached instance on every call.
    #[test]
    fn embedded_is_cached() {
        assert!(std::ptr::eq(embedded(), embedded()));
    }

    /// A missing file is reported as an error (not a panic) whose message names the path.
    #[test]
    fn load_file_missing_path_is_an_error() {
        let path = std::env::temp_dir().join("rlx-silero-vad-no-such-file.safetensors");
        // `expect_err` is avoided because it would require `SileroWeights: Debug`.
        let Err(err) = load_file(&path) else {
            panic!("loading a nonexistent file must fail");
        };
        let msg = format!("{err:#}");
        assert!(
            msg.contains(&path.display().to_string()),
            "error does not mention the path: {msg}"
        );
    }
}