use std::collections::HashMap;
use std::path::Path;
use mlx_rs::nn;
use mlx_rs::ops;
use mlx_rs::ops::indexing::IndexOp;
use mlx_rs::Array;
use tracing::{info, warn};
use super::mlx::load_all_tensors;
use crate::InferenceError;
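// Audio front-end and model hyper-parameters. At 16 kHz, the 400-sample
// window and 160-sample hop give 25 ms analysis frames with a 10 ms stride;
// the remaining sizes must match the checkpoint being loaded.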
const SAMPLE_RATE: usize = 16000;
const N_MEL: usize = 128;
const N_FFT: usize = 512;
const WIN_LEN_SAMPLES: usize = 400;
const HOP_LEN_SAMPLES: usize = 160;
const D_MODEL: usize = 1024;
const NUM_HEADS: usize = 8;
const HEAD_DIM: usize = D_MODEL / NUM_HEADS;
const NUM_ENCODER_LAYERS: usize = 24;
const PRED_HIDDEN: usize = 640;
const PRED_LAYERS: usize = 2;
const BLANK_ID: u32 = 0;
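/// Builds a triangular mel filterbank of shape `[N_MEL, N_FFT/2 + 1]` using
/// the HTK mel scale (`2595 * log10(1 + f/700)`), with unit-peak triangles
/// and no area normalization.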
fn mel_filterbank() -> Result<Array, mlx_rs::error::Exception> {
    let n_freqs = N_FFT / 2 + 1;
    let f_max = SAMPLE_RATE as f64 / 2.0;
let hz_to_mel = |f: f64| -> f64 { 2595.0 * (1.0 + f / 700.0).log10() };
let mel_to_hz = |m: f64| -> f64 { 700.0 * (10.0_f64.powf(m / 2595.0) - 1.0) };
let mel_min = hz_to_mel(0.0);
let mel_max = hz_to_mel(f_max);
let n_points = N_MEL + 2;
let mel_points: Vec<f64> = (0..n_points)
.map(|i| mel_min + (mel_max - mel_min) * i as f64 / (n_points - 1) as f64)
.collect();
let hz_points: Vec<f64> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
let bin_freqs: Vec<f64> = (0..n_freqs)
.map(|i| i as f64 * SAMPLE_RATE as f64 / N_FFT as f64)
.collect();
let mut fb = vec![0.0f32; N_MEL * n_freqs];
for m in 0..N_MEL {
let f_left = hz_points[m];
let f_center = hz_points[m + 1];
let f_right = hz_points[m + 2];
for k in 0..n_freqs {
let f = bin_freqs[k];
if f >= f_left && f <= f_center && f_center > f_left {
fb[m * n_freqs + k] = ((f - f_left) / (f_center - f_left)) as f32;
} else if f > f_center && f <= f_right && f_right > f_center {
fb[m * n_freqs + k] = ((f_right - f) / (f_right - f_center)) as f32;
}
}
}
Ok(Array::from_slice(&fb, &[N_MEL as i32, n_freqs as i32]))
}
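/// Computes a log-mel spectrogram `[1, frames, N_MEL]` from 16 kHz mono
/// samples: pad to a whole number of hops, apply a Hann window, take the
/// power spectrum via rFFT, project through the mel filterbank, and take the
/// natural log with a 1e-10 floor.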
fn compute_log_mel(samples: &[f32]) -> Result<Array, InferenceError> {
let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
let pad_len = if samples.len() < WIN_LEN_SAMPLES {
WIN_LEN_SAMPLES - samples.len()
} else {
let remainder = (samples.len() - WIN_LEN_SAMPLES) % HOP_LEN_SAMPLES;
if remainder == 0 {
0
} else {
HOP_LEN_SAMPLES - remainder
}
};
let mut padded = samples.to_vec();
padded.extend(std::iter::repeat(0.0f32).take(pad_len));
let num_frames = (padded.len() - WIN_LEN_SAMPLES) / HOP_LEN_SAMPLES + 1;
let hann: Vec<f32> = (0..WIN_LEN_SAMPLES)
.map(|i| {
let w = (std::f32::consts::PI * i as f32 / (WIN_LEN_SAMPLES - 1) as f32).sin();
w * w
})
.collect();
let mut framed = vec![0.0f32; num_frames * N_FFT];
for f in 0..num_frames {
let start = f * HOP_LEN_SAMPLES;
for s in 0..WIN_LEN_SAMPLES {
framed[f * N_FFT + s] = padded[start + s] * hann[s];
}
}
let framed_arr = Array::from_slice(&framed, &[num_frames as i32, N_FFT as i32]);
let spectrum = mlx_rs::fft::rfft(&framed_arr, N_FFT as i32, -1).map_err(map_err)?;
let mag = ops::abs(&spectrum).map_err(map_err)?;
let power = ops::square(&mag).map_err(map_err)?;
let power = power.as_dtype(mlx_rs::Dtype::Float32).map_err(map_err)?;
let fb = mel_filterbank().map_err(map_err)?;
let fb_t = ops::transpose_axes(&fb, &[1, 0]).map_err(map_err)?;
let mel = ops::matmul(&power, &fb_t).map_err(map_err)?;
let floor = Array::from_f32(1e-10);
let mel_clamped = ops::maximum(&mel, &floor).map_err(map_err)?;
let log_mel = ops::log(&mel_clamped).map_err(map_err)?;
ops::reshape(&log_mel, &[1, num_frames as i32, N_MEL as i32]).map_err(map_err)
}
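/// Minimal RIFF/WAV reader: walks the chunk list, validates the `fmt ` chunk,
/// and decodes the first channel of the `data` chunk to f32 samples.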
fn load_wav(path: &Path) -> Result<Vec<f32>, InferenceError> {
let data = std::fs::read(path)
.map_err(|e| InferenceError::InferenceFailed(format!("read wav: {e}")))?;
if data.len() < 44 {
return Err(InferenceError::InferenceFailed("WAV file too short".into()));
}
if &data[0..4] != b"RIFF" || &data[8..12] != b"WAVE" {
return Err(InferenceError::InferenceFailed(
"not a valid WAV file".into(),
));
}
let mut pos = 12;
let mut sample_rate = 0u32;
let mut bits_per_sample = 0u16;
let mut num_channels = 0u16;
let mut audio_format = 0u16;
let mut data_start = 0usize;
let mut data_len = 0usize;
while pos + 8 <= data.len() {
let chunk_id = &data[pos..pos + 4];
let chunk_size =
u32::from_le_bytes([data[pos + 4], data[pos + 5], data[pos + 6], data[pos + 7]])
as usize;
if chunk_id == b"fmt " && chunk_size >= 16 {
audio_format = u16::from_le_bytes([data[pos + 8], data[pos + 9]]);
num_channels = u16::from_le_bytes([data[pos + 10], data[pos + 11]]);
sample_rate = u32::from_le_bytes([
data[pos + 12],
data[pos + 13],
data[pos + 14],
data[pos + 15],
]);
bits_per_sample = u16::from_le_bytes([data[pos + 22], data[pos + 23]]);
} else if chunk_id == b"data" {
data_start = pos + 8;
data_len = chunk_size;
}
pos += 8 + chunk_size;
if chunk_size % 2 != 0 {
pos += 1;
}
}
    if audio_format != 1 && audio_format != 3 {
        return Err(InferenceError::InferenceFailed(format!(
            "unsupported WAV format {audio_format}, only PCM (1) and IEEE float (3) supported"
        )));
    }
    if num_channels == 0 {
        return Err(InferenceError::InferenceFailed(
            "WAV fmt chunk reports zero channels".into(),
        ));
    }
if sample_rate != SAMPLE_RATE as u32 {
return Err(InferenceError::InferenceFailed(format!(
"expected {SAMPLE_RATE}Hz, got {sample_rate}Hz"
)));
}
if data_start == 0 || data_len == 0 {
return Err(InferenceError::InferenceFailed(
"no data chunk found in WAV".into(),
));
}
let end = (data_start + data_len).min(data.len());
let raw = &data[data_start..end];
    let samples: Vec<f32> = match (bits_per_sample, audio_format) {
        // 16-bit integer PCM; take the first channel of each frame.
        (16, 1) => raw
            .chunks_exact(2 * num_channels as usize)
            .map(|frame| {
                let s = i16::from_le_bytes([frame[0], frame[1]]);
                s as f32 / 32768.0
            })
            .collect(),
        // 32-bit IEEE float.
        (32, 3) => raw
            .chunks_exact(4 * num_channels as usize)
            .map(|frame| f32::from_le_bytes([frame[0], frame[1], frame[2], frame[3]]))
            .collect(),
        // 32-bit integer PCM, normalized to [-1, 1).
        (32, 1) => raw
            .chunks_exact(4 * num_channels as usize)
            .map(|frame| {
                let s = i32::from_le_bytes([frame[0], frame[1], frame[2], frame[3]]);
                s as f32 / 2_147_483_648.0
            })
            .collect(),
        _ => {
            return Err(InferenceError::InferenceFailed(format!(
                "unsupported WAV encoding: {bits_per_sample}-bit, format {audio_format}"
            )))
        }
    };
Ok(samples)
}
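/// LayerNorm over the last axis: subtract the mean, scale by the reciprocal
/// standard deviation (with `eps` for stability), then apply the learned
/// weight and bias.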
struct LayerNorm {
weight: Array,
bias: Array,
eps: f32,
}
impl LayerNorm {
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let mean = x.mean_axes(&[-1], true)?;
let centered = ops::subtract(x, &mean)?;
let var = centered.square()?.mean_axes(&[-1], true)?;
let eps_arr = Array::from_f32(self.eps);
let norm = ops::rsqrt(&ops::add(&var, &eps_arr)?)?;
        let normed = ops::multiply(&centered, &norm)?;
let scaled = ops::multiply(&normed, &self.weight)?;
ops::add(&scaled, &self.bias)
}
fn all_arrays(&self) -> Vec<&Array> {
vec![&self.weight, &self.bias]
}
}
fn load_layer_norm(
tensors: &HashMap<String, Array>,
prefix: &str,
eps: f32,
) -> Result<LayerNorm, InferenceError> {
let weight = get_tensor(tensors, &format!("{prefix}.weight"))?;
let bias = get_tensor(tensors, &format!("{prefix}.bias"))?;
Ok(LayerNorm { weight, bias, eps })
}
struct Linear {
weight: Array,
bias: Option<Array>,
}
impl Linear {
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let w_t = ops::transpose_axes(&self.weight, &[1, 0])?;
let out = ops::matmul(x, &w_t)?;
if let Some(ref b) = self.bias {
ops::add(&out, b)
} else {
Ok(out)
}
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = vec![&self.weight];
if let Some(ref b) = self.bias {
v.push(b);
}
v
}
}
fn load_linear(tensors: &HashMap<String, Array>, prefix: &str) -> Result<Linear, InferenceError> {
let weight = get_tensor(tensors, &format!("{prefix}.weight"))?;
let bias = tensors.get(&format!("{prefix}.bias")).cloned();
Ok(Linear { weight, bias })
}
fn load_linear_with_bias(
tensors: &HashMap<String, Array>,
prefix: &str,
) -> Result<Linear, InferenceError> {
let weight = get_tensor(tensors, &format!("{prefix}.weight"))?;
let bias = get_tensor(tensors, &format!("{prefix}.bias"))?;
Ok(Linear {
weight,
bias: Some(bias),
})
}
fn get_tensor(tensors: &HashMap<String, Array>, key: &str) -> Result<Array, InferenceError> {
tensors
.get(key)
.cloned()
.ok_or_else(|| InferenceError::InferenceFailed(format!("missing tensor: {key}")))
}
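/// Subsampling front-end: three stride-2 convolutions (8x time reduction)
/// with ReLU, followed by a linear projection and LayerNorm. With a 10 ms
/// hop, each encoder frame covers 80 ms of audio, matching
/// `FRAME_DURATION_S`.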
struct DepthwiseSubsampling {
conv1_weight: Array,
conv1_bias: Array,
conv2_weight: Array,
conv2_bias: Array,
conv3_weight: Array,
conv3_bias: Array,
norm: LayerNorm,
proj: Linear,
}
impl DepthwiseSubsampling {
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let h = ops::conv1d(x, &self.conv1_weight, 2, 1, 1, 1)?;
let h = ops::add(&h, &self.conv1_bias)?;
let h = nn::relu(&h)?;
let h = ops::conv1d(&h, &self.conv2_weight, 2, 1, 1, 1)?;
let h = ops::add(&h, &self.conv2_bias)?;
let h = nn::relu(&h)?;
let h = ops::conv1d(&h, &self.conv3_weight, 2, 1, 1, 1)?;
let h = ops::add(&h, &self.conv3_bias)?;
let h = nn::relu(&h)?;
let h = self.proj.forward(&h)?;
self.norm.forward(&h)
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = vec![
&self.conv1_weight,
&self.conv1_bias,
&self.conv2_weight,
&self.conv2_bias,
&self.conv3_weight,
&self.conv3_bias,
];
v.extend(self.norm.all_arrays());
v.extend(self.proj.all_arrays());
v
}
}
fn load_subsampling(
tensors: &HashMap<String, Array>,
prefix: &str,
) -> Result<DepthwiseSubsampling, InferenceError> {
Ok(DepthwiseSubsampling {
conv1_weight: get_tensor(tensors, &format!("{prefix}.conv.0.weight"))?,
conv1_bias: get_tensor(tensors, &format!("{prefix}.conv.0.bias"))?,
conv2_weight: get_tensor(tensors, &format!("{prefix}.conv.2.weight"))?,
conv2_bias: get_tensor(tensors, &format!("{prefix}.conv.2.bias"))?,
conv3_weight: get_tensor(tensors, &format!("{prefix}.conv.4.weight"))?,
conv3_bias: get_tensor(tensors, &format!("{prefix}.conv.4.bias"))?,
norm: load_layer_norm(tensors, &format!("{prefix}.norm"), 1e-5)?,
proj: load_linear_with_bias(tensors, &format!("{prefix}.proj"))?,
})
}
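/// Conformer feed-forward module: pre-LayerNorm, an expansion linear, SiLU
/// activation, and a second linear projecting back down.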
struct FeedForwardModule {
norm: LayerNorm,
linear1: Linear,
linear2: Linear,
}
impl FeedForwardModule {
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let h = self.norm.forward(x)?;
let h = self.linear1.forward(&h)?;
let h = nn::silu(&h)?;
self.linear2.forward(&h)
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = self.norm.all_arrays();
v.extend(self.linear1.all_arrays());
v.extend(self.linear2.all_arrays());
v
}
}
fn load_ff_module(
tensors: &HashMap<String, Array>,
prefix: &str,
) -> Result<FeedForwardModule, InferenceError> {
Ok(FeedForwardModule {
norm: load_layer_norm(tensors, &format!("{prefix}.norm"), 1e-5)?,
linear1: load_linear_with_bias(tensors, &format!("{prefix}.linear1"))?,
linear2: load_linear_with_bias(tensors, &format!("{prefix}.linear2"))?,
})
}
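/// Pre-norm multi-head self-attention with a learned relative-position bias:
/// the signed frame offset `(i - j)` is fed through a small linear layer to
/// produce one bias per head, which is added to the attention scores before
/// the softmax.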
struct MultiHeadSelfAttention {
norm: LayerNorm,
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
pos_bias: Linear,
}
impl MultiHeadSelfAttention {
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let shape = x.shape().to_vec();
let batch = shape[0] as usize;
let seq_len = shape[1] as usize;
let h = self.norm.forward(x)?;
let q = self.q_proj.forward(&h)?;
let k = self.k_proj.forward(&h)?;
let v = self.v_proj.forward(&h)?;
let q = ops::transpose_axes(
&ops::reshape(
&q,
&[
batch as i32,
seq_len as i32,
NUM_HEADS as i32,
HEAD_DIM as i32,
],
)?,
&[0, 2, 1, 3],
)?;
let k = ops::transpose_axes(
&ops::reshape(
&k,
&[
batch as i32,
seq_len as i32,
NUM_HEADS as i32,
HEAD_DIM as i32,
],
)?,
&[0, 2, 1, 3],
)?;
let v = ops::transpose_axes(
&ops::reshape(
&v,
&[
batch as i32,
seq_len as i32,
NUM_HEADS as i32,
HEAD_DIM as i32,
],
)?,
&[0, 2, 1, 3],
)?;
let scale = Array::from_f32(1.0 / (HEAD_DIM as f32).sqrt());
let scores = ops::multiply(
&ops::matmul(&q, &ops::transpose_axes(&k, &[0, 1, 3, 2])?)?,
&scale,
)?;
let positions = self.compute_relative_position_bias(seq_len)?;
let scores = ops::add(&scores, &positions)?;
let attn = ops::softmax_axis(&scores, -1, None)?;
let out = ops::matmul(&attn, &v)?;
let out = ops::transpose_axes(&out, &[0, 2, 1, 3])?;
let out = ops::reshape(&out, &[batch as i32, seq_len as i32, D_MODEL as i32])?;
self.o_proj.forward(&out)
}
fn compute_relative_position_bias(
&self,
seq_len: usize,
) -> Result<Array, mlx_rs::error::Exception> {
let mut rel_pos = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in 0..seq_len {
rel_pos[i * seq_len + j] = (i as f32) - (j as f32);
}
}
let rel_pos_arr = Array::from_slice(&rel_pos, &[seq_len as i32, seq_len as i32]);
let rel_pos_arr = ops::reshape(&rel_pos_arr, &[seq_len as i32 * seq_len as i32, 1])?;
let bias = self.pos_bias.forward(&rel_pos_arr)?;
let bias = ops::reshape(&bias, &[seq_len as i32, seq_len as i32, NUM_HEADS as i32])?;
let bias = ops::transpose_axes(&bias, &[2, 0, 1])?;
ops::reshape(
&bias,
&[1, NUM_HEADS as i32, seq_len as i32, seq_len as i32],
)
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = self.norm.all_arrays();
v.extend(self.q_proj.all_arrays());
v.extend(self.k_proj.all_arrays());
v.extend(self.v_proj.all_arrays());
v.extend(self.o_proj.all_arrays());
v.extend(self.pos_bias.all_arrays());
v
}
}
fn load_mhsa(
tensors: &HashMap<String, Array>,
prefix: &str,
) -> Result<MultiHeadSelfAttention, InferenceError> {
Ok(MultiHeadSelfAttention {
norm: load_layer_norm(tensors, &format!("{prefix}.norm"), 1e-5)?,
q_proj: load_linear_with_bias(tensors, &format!("{prefix}.q_proj"))?,
k_proj: load_linear_with_bias(tensors, &format!("{prefix}.k_proj"))?,
v_proj: load_linear_with_bias(tensors, &format!("{prefix}.v_proj"))?,
o_proj: load_linear_with_bias(tensors, &format!("{prefix}.o_proj"))?,
pos_bias: load_linear_with_bias(tensors, &format!("{prefix}.pos_bias"))?,
})
}
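/// Conformer convolution module: a pointwise conv that doubles the channels,
/// a GLU gate, a grouped depthwise conv, inference-mode BatchNorm using the
/// stored running statistics, SiLU, and a final pointwise conv.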
struct ConvModule {
norm: LayerNorm,
pointwise1_weight: Array,
pointwise1_bias: Array,
depthwise_weight: Array,
depthwise_bias: Array,
batch_norm_weight: Array,
batch_norm_bias: Array,
batch_norm_mean: Array,
batch_norm_var: Array,
pointwise2_weight: Array,
pointwise2_bias: Array,
}
impl ConvModule {
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let h = self.norm.forward(x)?;
let h = ops::conv1d(&h, &self.pointwise1_weight, 1, 0, 1, 1)?;
let h = ops::add(&h, &self.pointwise1_bias)?;
let ch = h.shape()[2] as usize;
let half = (ch / 2) as i32;
let gate_input = h.index((.., .., ..half));
let gate = h.index((.., .., half..));
let h = ops::multiply(&gate_input, &nn::sigmoid(&gate)?)?;
let kernel_size = self.depthwise_weight.shape()[1] as i32;
let pad = (kernel_size - 1) / 2;
let groups = self.depthwise_weight.shape()[2] as i32;
let h = ops::conv1d(&h, &self.depthwise_weight, 1, pad, 1, groups)?;
let h = ops::add(&h, &self.depthwise_bias)?;
let eps = Array::from_f32(1e-5);
let bn_std = ops::rsqrt(&ops::add(&self.batch_norm_var, &eps)?)?;
let h = ops::multiply(&ops::subtract(&h, &self.batch_norm_mean)?, &bn_std)?;
let h = ops::multiply(&h, &self.batch_norm_weight)?;
let h = ops::add(&h, &self.batch_norm_bias)?;
let h = nn::silu(&h)?;
let h = ops::conv1d(&h, &self.pointwise2_weight, 1, 0, 1, 1)?;
ops::add(&h, &self.pointwise2_bias)
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = self.norm.all_arrays();
v.extend([
&self.pointwise1_weight,
&self.pointwise1_bias,
&self.depthwise_weight,
&self.depthwise_bias,
&self.batch_norm_weight,
&self.batch_norm_bias,
&self.batch_norm_mean,
&self.batch_norm_var,
&self.pointwise2_weight,
&self.pointwise2_bias,
]);
v
}
}
fn load_conv_module(
tensors: &HashMap<String, Array>,
prefix: &str,
) -> Result<ConvModule, InferenceError> {
Ok(ConvModule {
norm: load_layer_norm(tensors, &format!("{prefix}.norm"), 1e-5)?,
pointwise1_weight: get_tensor(tensors, &format!("{prefix}.pointwise_conv1.weight"))?,
pointwise1_bias: get_tensor(tensors, &format!("{prefix}.pointwise_conv1.bias"))?,
depthwise_weight: get_tensor(tensors, &format!("{prefix}.depthwise_conv.weight"))?,
depthwise_bias: get_tensor(tensors, &format!("{prefix}.depthwise_conv.bias"))?,
batch_norm_weight: get_tensor(tensors, &format!("{prefix}.batch_norm.weight"))?,
batch_norm_bias: get_tensor(tensors, &format!("{prefix}.batch_norm.bias"))?,
batch_norm_mean: get_tensor(tensors, &format!("{prefix}.batch_norm.running_mean"))?,
batch_norm_var: get_tensor(tensors, &format!("{prefix}.batch_norm.running_var"))?,
pointwise2_weight: get_tensor(tensors, &format!("{prefix}.pointwise_conv2.weight"))?,
pointwise2_bias: get_tensor(tensors, &format!("{prefix}.pointwise_conv2.bias"))?,
})
}
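/// One Conformer block in the Macaron arrangement: half-step feed-forward,
/// self-attention, convolution module, another half-step feed-forward, and a
/// closing LayerNorm, each wrapped in a residual connection.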
struct ConformerLayer {
ff1: FeedForwardModule,
mhsa: MultiHeadSelfAttention,
conv: ConvModule,
ff2: FeedForwardModule,
final_norm: LayerNorm,
}
impl ConformerLayer {
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let half = Array::from_f32(0.5);
let ff1_out = self.ff1.forward(x)?;
let x = ops::add(x, &ops::multiply(&half, &ff1_out)?)?;
let mhsa_out = self.mhsa.forward(&x)?;
let x = ops::add(&x, &mhsa_out)?;
let conv_out = self.conv.forward(&x)?;
let x = ops::add(&x, &conv_out)?;
let ff2_out = self.ff2.forward(&x)?;
let x = ops::add(&x, &ops::multiply(&half, &ff2_out)?)?;
self.final_norm.forward(&x)
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = self.ff1.all_arrays();
v.extend(self.mhsa.all_arrays());
v.extend(self.conv.all_arrays());
v.extend(self.ff2.all_arrays());
v.extend(self.final_norm.all_arrays());
v
}
}
fn load_conformer_layer(
tensors: &HashMap<String, Array>,
prefix: &str,
) -> Result<ConformerLayer, InferenceError> {
Ok(ConformerLayer {
ff1: load_ff_module(tensors, &format!("{prefix}.ff1"))?,
mhsa: load_mhsa(tensors, &format!("{prefix}.mhsa"))?,
conv: load_conv_module(tensors, &format!("{prefix}.conv"))?,
ff2: load_ff_module(tensors, &format!("{prefix}.ff2"))?,
final_norm: load_layer_norm(tensors, &format!("{prefix}.final_norm"), 1e-5)?,
})
}
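/// Full encoder: subsampling front-end, a stack of Conformer layers, and a
/// final LayerNorm.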
struct ConformerEncoder {
subsampling: DepthwiseSubsampling,
layers: Vec<ConformerLayer>,
final_norm: LayerNorm,
}
impl ConformerEncoder {
fn forward(&self, mel: &Array) -> Result<Array, mlx_rs::error::Exception> {
let mut h = self.subsampling.forward(mel)?;
for layer in &self.layers {
h = layer.forward(&h)?;
}
self.final_norm.forward(&h)
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = self.subsampling.all_arrays();
for layer in &self.layers {
v.extend(layer.all_arrays());
}
v.extend(self.final_norm.all_arrays());
v
}
}
fn load_encoder(tensors: &HashMap<String, Array>) -> Result<ConformerEncoder, InferenceError> {
let mut layers = Vec::with_capacity(NUM_ENCODER_LAYERS);
for i in 0..NUM_ENCODER_LAYERS {
layers.push(load_conformer_layer(
tensors,
&format!("encoder.layers.{i}"),
)?);
}
Ok(ConformerEncoder {
subsampling: load_subsampling(tensors, "encoder.subsampling")?,
layers,
final_norm: load_layer_norm(tensors, "encoder.final_norm", 1e-5)?,
})
}
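/// RNN-T prediction network: token embedding, stacked LSTM, and an output
/// projection. It is stepped one token at a time during decoding.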
struct PredictionNetwork {
    embedding: Array,
    lstm_layers: Vec<LstmLayer>,
proj: Linear,
}
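/// Single LSTM layer with gates packed in `[i, f, g, o]` order along the last
/// axis, matching the slicing in `step`.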
struct LstmLayer {
    wx: Array,
    wh: Array,
    bias: Option<Array>,
}
impl LstmLayer {
fn step(
&self,
x: &Array,
h: &Array,
c: &Array,
) -> Result<(Array, Array), mlx_rs::error::Exception> {
let wx_t = ops::transpose_axes(&self.wx, &[1, 0])?;
let wh_t = ops::transpose_axes(&self.wh, &[1, 0])?;
let mut gates = ops::add(&ops::matmul(x, &wx_t)?, &ops::matmul(h, &wh_t)?)?;
if let Some(ref bias) = self.bias {
gates = ops::add(&gates, bias)?;
}
let hidden = h.shape()[h.shape().len() - 1] as usize;
let i_gate = nn::sigmoid(&gates.index((.., ..(hidden as i32))))?;
let f_gate = nn::sigmoid(&gates.index((.., (hidden as i32)..(2 * hidden as i32))))?;
let g_gate = ops::tanh(&gates.index((.., (2 * hidden as i32)..(3 * hidden as i32))))?;
let o_gate = nn::sigmoid(&gates.index((.., (3 * hidden as i32)..)))?;
let new_c = ops::add(
&ops::multiply(&f_gate, c)?,
&ops::multiply(&i_gate, &g_gate)?,
)?;
let new_h = ops::multiply(&o_gate, &ops::tanh(&new_c)?)?;
Ok((new_h, new_c))
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = vec![&self.wx, &self.wh];
if let Some(ref b) = self.bias {
v.push(b);
}
v
}
}
impl PredictionNetwork {
fn step(
&self,
token: u32,
states: &[(Array, Array)],
) -> Result<(Array, Vec<(Array, Array)>), mlx_rs::error::Exception> {
let tok_arr = Array::from_slice(&[token as i32], &[1]);
let embed_row = self.embedding.index((tok_arr, ..));
let mut x = embed_row;
let mut new_states = Vec::with_capacity(self.lstm_layers.len());
for (i, lstm) in self.lstm_layers.iter().enumerate() {
let (h, c) = &states[i];
let (new_h, new_c) = lstm.step(&x, h, c)?;
x = new_h.clone();
new_states.push((new_h, new_c));
}
let pred_out = self.proj.forward(&x)?;
Ok((pred_out, new_states))
}
fn initial_states(&self) -> Result<Vec<(Array, Array)>, mlx_rs::error::Exception> {
let mut states = Vec::with_capacity(self.lstm_layers.len());
for _ in &self.lstm_layers {
let h = Array::zeros::<f32>(&[1, PRED_HIDDEN as i32])?;
let c = Array::zeros::<f32>(&[1, PRED_HIDDEN as i32])?;
states.push((h, c));
}
Ok(states)
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = vec![&self.embedding];
for lstm in &self.lstm_layers {
v.extend(lstm.all_arrays());
}
v.extend(self.proj.all_arrays());
v
}
}
fn load_prediction_network(
tensors: &HashMap<String, Array>,
) -> Result<PredictionNetwork, InferenceError> {
let embedding = get_tensor(tensors, "prediction.embedding.weight")?;
let mut lstm_layers = Vec::with_capacity(PRED_LAYERS);
for i in 0..PRED_LAYERS {
let pfx = format!("prediction.lstm.layers.{i}");
let wx = get_tensor(tensors, &format!("{pfx}.wx"))?;
let wh = get_tensor(tensors, &format!("{pfx}.wh"))?;
let bias = tensors.get(&format!("{pfx}.bias")).cloned();
lstm_layers.push(LstmLayer { wx, wh, bias });
}
let proj = load_linear(tensors, "prediction.proj")?;
Ok(PredictionNetwork {
embedding,
lstm_layers,
proj,
})
}
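/// TDT joint network: encoder and prediction outputs are projected into a
/// shared space, combined additively through ReLU, then split into two heads
/// producing token logits and duration logits.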
struct JointNetwork {
encoder_proj: Linear,
pred_proj: Linear,
    joint_proj: Linear,
    duration_proj: Linear,
}
impl JointNetwork {
fn forward(
&self,
encoder_frame: &Array,
pred_output: &Array,
) -> Result<(Array, Array), mlx_rs::error::Exception> {
let enc = self.encoder_proj.forward(encoder_frame)?;
let pred = self.pred_proj.forward(pred_output)?;
let joint = nn::relu(&ops::add(&enc, &pred)?)?;
let token_logits = self.joint_proj.forward(&joint)?;
let dur_logits = self.duration_proj.forward(&joint)?;
Ok((token_logits, dur_logits))
}
fn all_arrays(&self) -> Vec<&Array> {
let mut v = self.encoder_proj.all_arrays();
v.extend(self.pred_proj.all_arrays());
v.extend(self.joint_proj.all_arrays());
v.extend(self.duration_proj.all_arrays());
v
}
}
fn load_joint_network(tensors: &HashMap<String, Array>) -> Result<JointNetwork, InferenceError> {
Ok(JointNetwork {
encoder_proj: load_linear(tensors, "joint.encoder_proj")?,
pred_proj: load_linear(tensors, "joint.pred_proj")?,
joint_proj: load_linear(tensors, "joint.joint_proj")?,
duration_proj: load_linear(tensors, "joint.duration_proj")?,
})
}
fn load_vocabulary(model_dir: &Path) -> Result<Vec<String>, InferenceError> {
let vocab_path = model_dir.join("tokenizer.vocab");
if !vocab_path.exists() {
let tokenizer_json_path = model_dir.join("tokenizer.json");
if tokenizer_json_path.exists() {
return load_vocabulary_from_json(&tokenizer_json_path);
}
return Err(InferenceError::InferenceFailed(
"neither tokenizer.vocab nor tokenizer.json found".into(),
));
}
let content = std::fs::read_to_string(&vocab_path)
.map_err(|e| InferenceError::InferenceFailed(format!("read tokenizer.vocab: {e}")))?;
let vocab: Vec<String> = content.lines().map(|line| line.to_string()).collect();
if vocab.is_empty() {
return Err(InferenceError::InferenceFailed(
"empty vocabulary file".into(),
));
}
Ok(vocab)
}
fn load_vocabulary_from_json(path: &Path) -> Result<Vec<String>, InferenceError> {
let content = std::fs::read_to_string(path)
.map_err(|e| InferenceError::InferenceFailed(format!("read tokenizer.json: {e}")))?;
let json: serde_json::Value = serde_json::from_str(&content)
.map_err(|e| InferenceError::InferenceFailed(format!("parse tokenizer.json: {e}")))?;
if let Some(model) = json.get("model") {
if let Some(vocab_arr) = model.get("vocab").and_then(|v| v.as_array()) {
let vocab: Vec<String> = vocab_arr
.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect();
if !vocab.is_empty() {
return Ok(vocab);
}
}
}
if let Some(model) = json.get("model") {
if let Some(vocab_obj) = model.get("vocab").and_then(|v| v.as_object()) {
let mut pairs: Vec<(String, u64)> = vocab_obj
.iter()
.filter_map(|(k, v)| v.as_u64().map(|id| (k.clone(), id)))
.collect();
pairs.sort_by_key(|(_, id)| *id);
let vocab: Vec<String> = pairs.into_iter().map(|(k, _)| k).collect();
if !vocab.is_empty() {
return Ok(vocab);
}
}
}
Err(InferenceError::InferenceFailed(
"could not extract vocabulary from tokenizer.json".into(),
))
}
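/// A decoded token together with its TDT timing: the encoder frame where it
/// was emitted and the predicted duration in encoder frames.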
#[derive(Debug, Clone)]
pub(crate) struct TokenEmission {
pub token_id: u32,
pub start_frame: usize,
pub duration_frames: usize,
}
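// 160-sample hop (10 ms) x 8x conv subsampling = 80 ms per encoder frame.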
pub(crate) const FRAME_DURATION_S: f64 = 0.08;
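/// Greedy TDT (token-and-duration transducer) decoding: at each encoder frame
/// the joint network predicts both a token and a duration. A blank advances
/// time without emitting; a non-blank token is emitted, the prediction state
/// advances, and time jumps by the predicted duration (at least one frame, to
/// guarantee progress).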
fn greedy_tdt_decode(
encoder_output: &Array,
prediction: &PredictionNetwork,
joint: &JointNetwork,
vocab: &[String],
) -> Result<(String, Vec<TokenEmission>), InferenceError> {
let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
let enc_shape = encoder_output.shape();
let num_frames = enc_shape[1] as usize;
let mut pred_states = prediction.initial_states().map_err(map_err)?;
let mut last_token = BLANK_ID;
let mut output_tokens: Vec<u32> = Vec::new();
let mut emissions: Vec<TokenEmission> = Vec::new();
let mut t = 0usize;
    let max_steps = num_frames * 10;
    let mut step_count = 0usize;
while t < num_frames && step_count < max_steps {
step_count += 1;
let enc_frame = encoder_output.index((.., t as i32..t as i32 + 1, ..));
let enc_frame = ops::reshape(&enc_frame, &[1, D_MODEL as i32]).map_err(map_err)?;
let (pred_out, new_states) = prediction.step(last_token, &pred_states).map_err(map_err)?;
let (token_logits, dur_logits) = joint.forward(&enc_frame, &pred_out).map_err(map_err)?;
token_logits.eval().map_err(map_err)?;
dur_logits.eval().map_err(map_err)?;
let token_logits_flat = ops::reshape(&token_logits, &[-1]).map_err(map_err)?;
token_logits_flat.eval().map_err(map_err)?;
let token_data: &[f32] = token_logits_flat.as_slice();
let token_id = argmax_f32(token_data) as u32;
let dur_logits_flat = ops::reshape(&dur_logits, &[-1]).map_err(map_err)?;
dur_logits_flat.eval().map_err(map_err)?;
let dur_data: &[f32] = dur_logits_flat.as_slice();
let duration = argmax_f32(dur_data);
if token_id == BLANK_ID {
            t += duration.max(1);
} else {
output_tokens.push(token_id);
emissions.push(TokenEmission {
token_id,
start_frame: t,
duration_frames: duration,
});
last_token = token_id;
pred_states = new_states;
            t += duration.max(1);
}
}
if step_count >= max_steps {
warn!(
num_frames,
max_steps,
emitted_tokens = output_tokens.len(),
"parakeet TDT decode hit step cap — transcription may be truncated"
);
}
let mut oov_count: usize = 0;
let text: String = output_tokens
.iter()
.filter_map(|&id| {
let piece = vocab.get(id as usize);
if piece.is_none() {
oov_count += 1;
}
piece
})
.map(|token| token.replace('\u{2581}', " "))
.collect::<String>()
.trim()
.to_string();
if oov_count > 0 {
warn!(
oov_count,
vocab_size = vocab.len(),
"parakeet: emitted token id outside vocab range — model/vocab mismatch"
);
}
Ok((text, emissions))
}
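/// Groups SentencePiece token emissions into word-level timestamps: a leading
/// U+2581 marker starts a new word, marker-less pieces extend the current
/// one, and a word's end time is the maximum token end seen so far.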
pub(crate) fn emissions_to_words(
emissions: &[TokenEmission],
vocab: &[String],
) -> Vec<crate::tasks::transcribe::TranscribedWord> {
use crate::tasks::transcribe::TranscribedWord;
let mut words: Vec<TranscribedWord> = Vec::new();
let mut cur_text = String::new();
let mut cur_start_frame: Option<usize> = None;
let mut cur_end_frame: usize = 0;
let flush =
|words: &mut Vec<TranscribedWord>, start_frame: usize, end_frame: usize, text: String| {
if text.is_empty() {
return;
}
let start = (start_frame as f64 * FRAME_DURATION_S) as f32;
let end = (end_frame as f64 * FRAME_DURATION_S) as f32;
words.push(TranscribedWord { start, end, text });
};
for emission in emissions {
let Some(piece) = vocab.get(emission.token_id as usize).map(String::as_str) else {
continue;
};
let has_marker = piece.starts_with('\u{2581}');
let clean = piece.trim_start_matches('\u{2581}');
if clean.is_empty() && has_marker {
continue;
}
let starts_new_word = has_marker || cur_start_frame.is_none();
if starts_new_word {
if let Some(start_frame) = cur_start_frame.take() {
flush(
&mut words,
start_frame,
cur_end_frame,
std::mem::take(&mut cur_text),
);
}
cur_end_frame = 0;
}
if cur_start_frame.is_none() {
cur_start_frame = Some(emission.start_frame);
}
cur_text.push_str(clean);
let tok_end = emission
.start_frame
.saturating_add(emission.duration_frames);
cur_end_frame = cur_end_frame.max(tok_end);
}
if let Some(start_frame) = cur_start_frame {
flush(&mut words, start_frame, cur_end_frame, cur_text);
}
words
}
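/// Index of the maximum value; ties resolve to the first occurrence and an
/// empty slice yields 0.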
fn argmax_f32(data: &[f32]) -> usize {
let mut best_idx = 0;
let mut best_val = f32::NEG_INFINITY;
for (i, &v) in data.iter().enumerate() {
if v > best_val {
best_val = v;
best_idx = i;
}
}
best_idx
}
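/// MLX-backed Parakeet-TDT transcription backend: Conformer encoder, LSTM
/// prediction network, TDT joint network, and the SentencePiece vocabulary.
///
/// A minimal usage sketch (paths are illustrative):
///
/// ```ignore
/// let backend = ParakeetBackend::load(Path::new("models/parakeet"))?;
/// let text = backend.transcribe(Path::new("clip.wav"))?;
/// println!("{text}");
/// ```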
pub struct ParakeetBackend {
encoder: ConformerEncoder,
prediction: PredictionNetwork,
joint: JointNetwork,
vocab: Vec<String>,
}
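// SAFETY (assumption): `Array` wraps raw MLX handles that are not auto
// `Send`/`Sync`; these impls rely on callers serializing access to the
// backend (e.g. behind a lock) rather than on MLX's own thread-safety.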
unsafe impl Send for ParakeetBackend {}
unsafe impl Sync for ParakeetBackend {}
impl ParakeetBackend {
pub fn load(model_dir: &Path) -> Result<Self, InferenceError> {
info!(model_dir = %model_dir.display(), "loading Parakeet-TDT model via MLX");
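        // Default to the GPU when built with Metal support, otherwise the
        // CPU; CAR_MLX_DEVICE=cpu|gpu overrides the choice at runtime.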
#[cfg(feature = "mlx-metal")]
let default_device = mlx_rs::Device::gpu();
#[cfg(not(feature = "mlx-metal"))]
let default_device = mlx_rs::Device::cpu();
match std::env::var("CAR_MLX_DEVICE").ok().as_deref() {
Some("cpu") => mlx_rs::Device::set_default(&mlx_rs::Device::cpu()),
#[cfg(feature = "mlx-metal")]
Some("gpu") => mlx_rs::Device::set_default(&mlx_rs::Device::gpu()),
_ => mlx_rs::Device::set_default(&default_device),
}
let vocab = load_vocabulary(model_dir)?;
info!(vocab_size = vocab.len(), "vocabulary loaded");
info!("loading safetensors weights");
let tensors = load_all_tensors(model_dir)?;
info!(tensors = tensors.len(), "tensors loaded");
let encoder = load_encoder(&tensors)?;
info!(layers = NUM_ENCODER_LAYERS, "conformer encoder loaded");
let prediction = load_prediction_network(&tensors)?;
info!(
lstm_layers = PRED_LAYERS,
hidden = PRED_HIDDEN,
"prediction network loaded"
);
let joint = load_joint_network(&tensors)?;
info!("joint network loaded");
let mut all_params = encoder.all_arrays();
all_params.extend(prediction.all_arrays());
all_params.extend(joint.all_arrays());
mlx_rs::transforms::eval(all_params)
.map_err(|e| InferenceError::InferenceFailed(format!("eval weights: {e}")))?;
info!("Parakeet-TDT model loaded successfully");
Ok(Self {
encoder,
prediction,
joint,
vocab,
})
}
pub fn transcribe(&self, audio_path: &Path) -> Result<String, InferenceError> {
let (text, _words) = self.transcribe_detailed(audio_path)?;
Ok(text)
}
pub fn transcribe_detailed(
&self,
audio_path: &Path,
) -> Result<(String, Vec<crate::tasks::transcribe::TranscribedWord>), InferenceError> {
info!(path = %audio_path.display(), "transcribing audio (detailed)");
let samples = load_wav(audio_path)?;
if samples.is_empty() {
return Ok((String::new(), Vec::new()));
}
info!(
samples = samples.len(),
duration_secs = samples.len() as f32 / SAMPLE_RATE as f32,
"audio loaded"
);
let mel = compute_log_mel(&samples)?;
let mel_frames = mel.shape()[1] as usize;
info!(mel_frames = mel_frames, "mel spectrogram computed");
let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
let encoder_output = self.encoder.forward(&mel).map_err(map_err)?;
encoder_output.eval().map_err(map_err)?;
let enc_frames = encoder_output.shape()[1] as usize;
info!(encoder_frames = enc_frames, "encoder forward complete");
let (text, emissions) =
greedy_tdt_decode(&encoder_output, &self.prediction, &self.joint, &self.vocab)?;
let words = emissions_to_words(&emissions, &self.vocab);
info!(
text_len = text.len(),
word_count = words.len(),
"transcription complete"
);
Ok((text, words))
}
pub fn transcribe_samples(&self, samples: &[f32]) -> Result<String, InferenceError> {
let (text, _words) = self.transcribe_samples_detailed(samples)?;
Ok(text)
}
pub fn transcribe_samples_detailed(
&self,
samples: &[f32],
) -> Result<(String, Vec<crate::tasks::transcribe::TranscribedWord>), InferenceError> {
if samples.is_empty() {
return Ok((String::new(), Vec::new()));
}
let mel = compute_log_mel(samples)?;
let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
let encoder_output = self.encoder.forward(&mel).map_err(map_err)?;
encoder_output.eval().map_err(map_err)?;
let (text, emissions) =
greedy_tdt_decode(&encoder_output, &self.prediction, &self.joint, &self.vocab)?;
let words = emissions_to_words(&emissions, &self.vocab);
Ok((text, words))
}
pub fn vocab_size(&self) -> usize {
self.vocab.len()
}
pub fn sample_rate(&self) -> usize {
SAMPLE_RATE
}
}
#[cfg(test)]
mod tests {
use super::*;
fn vocab() -> Vec<String> {
vec![
"hello".into(),
"\u{2581}world".into(),
"\u{2581}".into(),
"!".into(),
"\u{2581}foo".into(),
"bar".into(),
]
}
fn emit(token_id: u32, start: usize, dur: usize) -> TokenEmission {
TokenEmission {
token_id,
start_frame: start,
duration_frames: dur,
}
}
#[test]
fn first_token_without_marker_starts_a_word() {
let v = vocab();
        let em = vec![emit(0, 0, 5), emit(1, 5, 10)];
        let words = emissions_to_words(&em, &v);
assert_eq!(words.len(), 2);
assert_eq!(words[0].text, "hello");
assert!((words[0].start - 0.0).abs() < 1e-6);
assert!((words[0].end - (5.0 * 0.08)).abs() < 1e-5);
assert_eq!(words[1].text, "world");
}
#[test]
fn pure_marker_token_is_skipped() {
let v = vocab();
let em = vec![emit(4, 0, 4), emit(2, 4, 1), emit(5, 5, 3)];
let words = emissions_to_words(&em, &v);
assert_eq!(words.len(), 1, "pure ▁ should not split or emit a word");
assert_eq!(words[0].text, "foobar");
}
#[test]
fn punctuation_attaches_to_previous_word() {
let v = vocab();
let em = vec![emit(1, 0, 4), emit(3, 4, 1)];
let words = emissions_to_words(&em, &v);
assert_eq!(words.len(), 1);
assert_eq!(words[0].text, "world!");
}
#[test]
fn zero_duration_does_not_regress_end() {
let v = vocab();
let em = vec![emit(0, 0, 5), emit(3, 5, 0)];
let words = emissions_to_words(&em, &v);
assert_eq!(words.len(), 1);
assert!(words[0].end >= words[0].start);
assert!((words[0].end - (5.0 * 0.08)).abs() < 1e-5);
}
#[test]
fn empty_emissions_yields_empty() {
let v = vocab();
let words = emissions_to_words(&[], &v);
assert!(words.is_empty());
}
#[test]
fn out_of_vocab_token_is_silently_dropped() {
let v = vocab();
let em = vec![emit(0, 0, 5), emit(99, 5, 3), emit(1, 8, 4)];
let words = emissions_to_words(&em, &v);
assert_eq!(words.len(), 2);
assert_eq!(words[0].text, "hello");
assert_eq!(words[1].text, "world");
}
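    #[test]
    fn argmax_prefers_first_maximum() {
        // argmax_f32 resolves ties to the first occurrence and returns 0 for
        // an empty slice.
        assert_eq!(argmax_f32(&[0.1, 0.9, 0.9]), 1);
        assert_eq!(argmax_f32(&[]), 0);
    }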
}