mod legacy;
mod tokens;
use {
super::{OperationError, get_session_builder},
async_stream::stream,
futures_util::{Stream, stream},
ndarray::{Array1, Array2, Array3, ArrayD, IxDyn},
ort::{
inputs,
session::{RunOptions, Session, SessionInputValue},
value::TensorRef,
},
std::{collections::HashMap, path::Path, pin::Pin},
tokens::TOKENS,
};
fn decode_sp_token(token: &str) -> String {
if let Some(rest) = token.strip_prefix('▁') {
if rest.chars().next().is_some_and(is_cjk) {
rest.to_owned()
} else {
format!(" {}", rest)
}
} else {
token.to_owned()
}
}
/// Reports whether `c` lies in one of the CJK-related Unicode ranges this module treats as
/// "written without spaces".
fn is_cjk(c: char) -> bool {
    // Inclusive code-point ranges, in ascending order.
    const CJK_RANGES: [(u32, u32); 6] = [
        (0x3000, 0x303F),   // CJK Symbols and Punctuation
        (0x3400, 0x4DBF),   // CJK Unified Ideographs Extension A
        (0x4E00, 0x9FFF),   // CJK Unified Ideographs
        (0xF900, 0xFAFF),   // CJK Compatibility Ideographs
        (0xFF00, 0xFFEF),   // Halfwidth and Fullwidth Forms
        (0x20000, 0x2A6DF), // CJK Unified Ideographs Extension B
    ];
    let code = u32::from(c);
    CJK_RANGES
        .iter()
        .any(|&(first, last)| (first..=last).contains(&code))
}
pub use legacy::AutomaticSpeechRecognizerLegacy;
/// One encoder state tensor that is fed back into the encoder on the next chunk.
///
/// `Clone` is derived: both variants wrap owned `ndarray` arrays, so the derived
/// implementation performs the same per-variant deep copy a hand-written one would.
#[derive(Debug, Clone)]
enum CacheValue {
    /// Floating-point state of arbitrary rank.
    F32(ArrayD<f32>),
    /// One-dimensional integer state.
    I64(Array1<i64>),
}
/// Collects the custom metadata entries of the encoder model that this module consumes.
///
/// Keys absent from the model are simply left out of the returned map.
///
/// # Errors
///
/// Fails if the session's metadata cannot be obtained.
fn parse_encoder_metadata(session: &Session) -> Result<HashMap<String, String>, OperationError> {
    const KEYS: [&str; 9] = [
        "num_encoder_layers",
        "encoder_dims",
        "query_head_dims",
        "value_head_dims",
        "num_heads",
        "cnn_module_kernels",
        "left_context_len",
        "T",
        "decode_chunk_len",
    ];

    let metadata = session.metadata()?;
    let found: HashMap<String, String> = KEYS
        .iter()
        .filter_map(|&key| metadata.custom(key).map(|value| (key.to_string(), value)))
        .collect();
    Ok(found)
}
/// Parses a comma-separated list of integers.
///
/// Each field is trimmed before parsing; fields that are not valid `i32` values (including
/// empty ones) are skipped rather than reported.
fn parse_int_list(s: &str) -> Vec<i32> {
    let mut values = Vec::new();
    for field in s.split(',') {
        if let Ok(number) = field.trim().parse::<i32>() {
            values.push(number);
        }
    }
    values
}
fn build_cache_shapes(meta: &HashMap<String, String>) -> Vec<(String, Vec<i64>)> {
let layers_per_stack = parse_int_list(&meta["num_encoder_layers"]);
let encoder_dims = parse_int_list(&meta["encoder_dims"]);
let query_head_dims = parse_int_list(&meta["query_head_dims"]);
let value_head_dims = parse_int_list(&meta["value_head_dims"]);
let num_heads = parse_int_list(&meta["num_heads"]);
let cnn_kernels = parse_int_list(&meta["cnn_module_kernels"]);
let left_ctx = parse_int_list(&meta["left_context_len"]);
let mut entries = Vec::new();
let mut layer_idx: i32 = 0;
for (stack, &n_layers) in layers_per_stack.iter().enumerate() {
let enc_dim = encoder_dims[stack];
let key_dim = query_head_dims[stack] * num_heads[stack];
let val_dim = value_head_dims[stack] * num_heads[stack];
let ctx = left_ctx[stack];
let cnn_half = cnn_kernels[stack] / 2;
let nonlin_dim = 3 * enc_dim / 4;
for _ in 0..n_layers {
entries.push((
format!("cached_key_{layer_idx}"),
vec![ctx as i64, 1, key_dim as i64],
));
entries.push((
format!("cached_nonlin_attn_{layer_idx}"),
vec![1, 1, ctx as i64, nonlin_dim as i64],
));
entries.push((
format!("cached_val1_{layer_idx}"),
vec![ctx as i64, 1, val_dim as i64],
));
entries.push((
format!("cached_val2_{layer_idx}"),
vec![ctx as i64, 1, val_dim as i64],
));
entries.push((
format!("cached_conv1_{layer_idx}"),
vec![1, enc_dim as i64, cnn_half as i64],
));
entries.push((
format!("cached_conv2_{layer_idx}"),
vec![1, enc_dim as i64, cnn_half as i64],
));
layer_idx += 1;
}
}
entries
}
/// Creates the zero-initialised encoder state used before the first chunk.
///
/// Besides one `f32` tensor per entry of `cache_shapes`, two fixed entries are added:
/// `embed_states`, an `f32` tensor of shape `[1, 128, 3, E]` with
/// `E = ((feature_dim - 1) / 2 - 1) / 2`, and `processed_lens`, a single `i64` zero.
///
/// Negative dimensions are clamped to zero instead of being wrapped into enormous sizes by
/// an `as usize` cast.
fn init_caches(
    cache_shapes: &[(String, Vec<i64>)],
    feature_dim: i32,
) -> HashMap<String, CacheValue> {
    let to_dim = |d: i64| d.max(0) as usize;

    let mut caches: HashMap<String, CacheValue> = cache_shapes
        .iter()
        .map(|(name, shape)| {
            let dims: Vec<usize> = shape.iter().map(|&d| to_dim(d)).collect();
            (name.clone(), CacheValue::F32(ArrayD::zeros(IxDyn(&dims))))
        })
        .collect();

    let embed_dim = to_dim(i64::from(((feature_dim - 1) / 2 - 1) / 2));
    caches.insert(
        "embed_states".to_string(),
        CacheValue::F32(ArrayD::zeros(IxDyn(&[1, 128, 3, embed_dim]))),
    );
    caches.insert(
        "processed_lens".to_string(),
        CacheValue::I64(Array1::zeros(1)),
    );
    caches
}
async fn build_and_run_encoder(
encoder_session: &mut Session,
run_options: &RunOptions,
features: &Array3<f32>,
caches: &HashMap<String, CacheValue>,
) -> Result<(Array3<f32>, HashMap<String, CacheValue>), OperationError> {
let mut inputs: Vec<(&str, SessionInputValue)> = Vec::new();
let x_tensor: TensorRef<f32> = TensorRef::from_array_view(features)?;
inputs.push(("x", SessionInputValue::View(x_tensor.into_dyn())));
let input_names: Vec<String> = encoder_session
.inputs()
.iter()
.skip(1)
.map(|i| i.name().to_string())
.collect();
for name in &input_names {
if let Some(cache) = caches.get(name) {
match cache {
CacheValue::F32(arr) => {
let tensor = TensorRef::from_array_view(arr)?;
inputs.push((name.as_str(), SessionInputValue::View(tensor.into_dyn())));
}
CacheValue::I64(arr) => {
let tensor = TensorRef::from_array_view(arr)?;
inputs.push((name.as_str(), SessionInputValue::View(tensor.into_dyn())));
}
}
}
}
let outputs = encoder_session.run_async(inputs, run_options)?.await?;
let (shape, data) = outputs["encoder_out"].try_extract_tensor::<f32>()?;
let shape_vec: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
let enc_out =
Array3::from_shape_vec((shape_vec[0], shape_vec[1], shape_vec[2]), data.to_vec())?;
let mut new_caches = HashMap::new();
for name in caches.keys() {
let new_name = format!("new_{name}");
if let Some(val) = outputs.get(&new_name) {
let new_value = match &caches[name] {
CacheValue::F32(old_arr) => {
let (s, d) = val.try_extract_tensor::<f32>()?;
let sv: Vec<usize> = s.iter().map(|&d| d as usize).collect();
CacheValue::F32(
ArrayD::from_shape_vec(IxDyn(&sv), d.to_vec())
.unwrap_or_else(|_| old_arr.clone()),
)
}
CacheValue::I64(old_arr) => {
if let Ok((s, d)) = val.try_extract_tensor::<i64>() {
let sv: Vec<usize> = s.iter().map(|&d| d as usize).collect();
let total: usize = sv.iter().product();
CacheValue::I64(Array1::from_vec(d[..total].to_vec()))
} else {
CacheValue::I64(old_arr.clone())
}
}
};
new_caches.insert(name.clone(), new_value);
} else {
new_caches.insert(name.clone(), caches[name].clone());
}
}
Ok((enc_out, new_caches))
}
/// Streaming speech recognizer built from three model sessions: an encoder, a decoder and a
/// joiner.
pub struct AutomaticSpeechRecognizer {
/// Session for the encoder model; run once per chunk of feature frames.
encoder_session: Session,
/// Session for the decoder model; fed the most recent token ids.
decoder_session: Session,
/// Session for the joiner model; fed one encoder frame and the current decoder output.
joiner_session: Session,
/// Run options passed to every inference call.
run_options: RunOptions,
/// Number of token ids passed to the decoder (decoder metadata `context_size`, default 2).
context_size: i64,
/// Feature frames per encoder chunk (encoder metadata `T`, default 34).
t_frames: i64,
/// Frames the input window advances after each chunk (encoder metadata
/// `decode_chunk_len`, default 8).
decode_chunk_len: i64,
/// Name and shape of each encoder cache tensor, derived from the encoder metadata.
cache_shapes: Vec<(String, Vec<i64>)>,
}
impl AutomaticSpeechRecognizer {
/// Number of feature values per frame; `recognize` splits its flat input into frames of
/// this width. Same value as the legacy recognizer's `NUM_BINS`.
pub const NUM_BINS: i32 = AutomaticSpeechRecognizerLegacy::NUM_BINS;
/// Builds a recognizer from the encoder, decoder and joiner model files.
///
/// Currently forwards all three paths unchanged to `with_config`. Note that the single type
/// parameter means all three paths must have the same type.
///
/// # Errors
///
/// Returns whatever error `with_config` reports.
pub fn new<P>(encoder_path: P, decoder_path: P, joiner_path: P) -> Result<Self, OperationError>
where
P: AsRef<Path>,
{
Self::with_config(encoder_path, decoder_path, joiner_path)
}
pub fn with_config<P>(
encoder_path: P,
decoder_path: P,
joiner_path: P,
) -> Result<Self, OperationError>
where
P: AsRef<Path>,
{
let encoder_session = Self::build_session(&encoder_path)?;
let decoder_session = Self::build_session(&decoder_path)?;
let joiner_session = Self::build_session(&joiner_path)?;
let run_options = RunOptions::new()?;
let meta = parse_encoder_metadata(&encoder_session)?;
let t_frames = meta.get("T").and_then(|s| s.parse().ok()).unwrap_or(34);
let decode_chunk_len = meta
.get("decode_chunk_len")
.and_then(|s| s.parse().ok())
.unwrap_or(8);
let context_size = {
let dec_meta = decoder_session.metadata()?;
dec_meta
.custom("context_size")
.and_then(|s| s.parse().ok())
.unwrap_or(2)
};
let cache_shapes = build_cache_shapes(&meta);
Ok(Self {
encoder_session,
decoder_session,
joiner_session,
run_options,
context_size,
t_frames,
decode_chunk_len,
cache_shapes,
})
}
/// Builds the legacy recognizer from a single model file.
///
/// Convenience entry point that delegates to `AutomaticSpeechRecognizerLegacy::new`.
///
/// # Errors
///
/// Returns whatever error the legacy constructor reports.
pub fn new_legacy<P: AsRef<Path>>(
model_path: P,
) -> Result<AutomaticSpeechRecognizerLegacy, OperationError> {
AutomaticSpeechRecognizerLegacy::new(model_path)
}
/// Loads the model file at `path` into a session, using the builder returned by
/// `get_session_builder`.
///
/// # Errors
///
/// Propagates failures from obtaining the builder or from loading the file.
fn build_session<P: AsRef<Path>>(path: P) -> Result<Session, OperationError> {
Ok(get_session_builder()?.commit_from_file(path)?)
}
async fn decode_frame(
&mut self,
enc_frame: &Array2<f32>,
decoder_out: &mut Option<Array2<f32>>,
token_ids: &mut Vec<i64>,
blank_id: i64,
unk_id: i64,
) -> Result<Option<(String, i64)>, OperationError> {
if decoder_out.is_none() {
let ids = token_ids.clone();
*decoder_out = Some(self.run_decoder(&ids).await?);
}
let dec_out = decoder_out.clone().unwrap();
let logits = self.run_joiner(enc_frame, &dec_out).await?;
let pred_id = logits
.row(0)
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i as i64)
.unwrap_or(0);
if pred_id != blank_id && pred_id != unk_id {
let text = decode_sp_token(TOKENS[pred_id as usize]);
token_ids.push(pred_id);
let ids = token_ids.clone();
*decoder_out = Some(self.run_decoder(&ids).await?);
Ok(Some((text, pred_id)))
} else {
Ok(None)
}
}
/// Runs the decoder on the most recent `context_size` token ids.
///
/// If fewer ids are available, the context is left-padded with zeros. The ids are passed as
/// a `(1, context_size)` tensor named `y`, and the first output of the model is returned as
/// a two-dimensional array.
///
/// # Errors
///
/// Returns an error if inference fails, if the model produces no outputs, or if the first
/// output cannot be read as an `f32` tensor with at least two dimensions.
async fn run_decoder(&mut self, token_ids: &[i64]) -> Result<Array2<f32>, OperationError> {
    let context_len = self.context_size as usize;
    let ctx: Vec<i64> = match token_ids.len().checked_sub(context_len) {
        // Enough history: keep only the trailing window.
        Some(skip) => token_ids[skip..].to_vec(),
        None => {
            let mut padded = vec![0i64; context_len - token_ids.len()];
            padded.extend_from_slice(token_ids);
            padded
        }
    };
    let ctx_arr = Array2::from_shape_vec((1, ctx.len()), ctx)?;
    let outputs = self
        .decoder_session
        .run_async(
            inputs![
                "y" => TensorRef::from_array_view(&ctx_arr)?,
            ],
            &self.run_options,
        )?
        .await?;
    let out = outputs
        .values()
        .next()
        .ok_or_else(|| OperationError::Ort("decoder returned no outputs".to_string()))?;
    let (shape, data) = out.try_extract_tensor::<f32>()?;
    let dims: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
    let &[rows, cols, ..] = dims.as_slice() else {
        return Err(OperationError::Ort(format!(
            "decoder output has rank {}, expected at least 2",
            dims.len()
        )));
    };
    Ok(Array2::from_shape_vec((rows, cols), data.to_vec())?)
}
/// Runs the joiner on one encoder frame and the current decoder output.
///
/// The two arrays are passed as the inputs `encoder_out` and `decoder_out`; the first
/// output of the model is returned as a two-dimensional array.
///
/// # Errors
///
/// Returns an error if inference fails, if the model produces no outputs, or if the first
/// output cannot be read as an `f32` tensor with at least two dimensions.
async fn run_joiner(
    &mut self,
    enc_frame: &Array2<f32>,
    dec_out: &Array2<f32>,
) -> Result<Array2<f32>, OperationError> {
    let outputs = self
        .joiner_session
        .run_async(
            inputs![
                "encoder_out" => TensorRef::from_array_view(enc_frame)?,
                "decoder_out" => TensorRef::from_array_view(dec_out)?,
            ],
            &self.run_options,
        )?
        .await?;
    let out = outputs
        .values()
        .next()
        .ok_or_else(|| OperationError::Ort("joiner returned no outputs".to_string()))?;
    let (shape, data) = out.try_extract_tensor::<f32>()?;
    let dims: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
    let &[rows, cols, ..] = dims.as_slice() else {
        return Err(OperationError::Ort(format!(
            "joiner output has rank {}, expected at least 2",
            dims.len()
        )));
    };
    Ok(Array2::from_shape_vec((rows, cols), data.to_vec())?)
}
/// Transcribes `features` and streams the recognised text piece by piece.
///
/// `features` is a flat slice of frames, `NUM_BINS` values each; trailing values that do not
/// form a whole frame are ignored, and an input shorter than one frame yields an empty
/// stream. The frames are fed to the encoder in windows of `T` frames that advance by
/// `decode_chunk_len` frames. A final, shorter window is zero-padded to `T` frames. Each
/// encoder output frame is then decoded greedily, and every emitted token is yielded as
/// text. Encoder caches and decoder state start fresh on every call.
///
/// The first error ends the stream after being yielded as its last item.
pub fn recognize<'a>(
    &'a mut self,
    features: &'a [f32],
) -> Pin<Box<dyn Stream<Item = Result<String, OperationError>> + Send + '_>> {
    let n_bins = Self::NUM_BINS as usize;
    let n_frames = features.len() / n_bins;
    if n_frames == 0 {
        return Box::pin(stream::empty());
    }
    Box::pin(stream! {
        // A non-positive step would never advance the window below.
        if self.t_frames <= 0 || self.decode_chunk_len <= 0 {
            yield Err(OperationError::Ort(format!(
                "invalid chunking configuration: T={}, decode_chunk_len={}",
                self.t_frames, self.decode_chunk_len
            )));
            return;
        }
        let t_frames = self.t_frames as usize;
        let chunk_step = self.decode_chunk_len as usize;

        // Ids that never produce text; both fall back to 0 when absent from the table.
        let blank_id = TOKENS
            .iter()
            .position(|tok| *tok == "<blk>")
            .map_or(0, |i| i as i64);
        let unk_id = TOKENS
            .iter()
            .position(|tok| *tok == "<unk>")
            .map_or(0, |i| i as i64);

        let mut encoder_caches = init_caches(&self.cache_shapes, Self::NUM_BINS);
        let mut token_ids = Vec::new();
        let mut decoder_out = None;

        let mut start = 0usize;
        while start < n_frames {
            // Full windows hold `t_frames` frames; only the final window may be shorter,
            // in which case the zero-initialised remainder acts as padding.
            let avail = (n_frames - start).min(t_frames);
            let is_tail = avail < t_frames;
            let mut chunk = vec![0.0f32; t_frames * n_bins];
            chunk[..avail * n_bins]
                .copy_from_slice(&features[start * n_bins..(start + avail) * n_bins]);
            let chunk_arr = Array3::from_shape_vec((1, t_frames, n_bins), chunk)?;

            let (encoder_out, new_caches) = build_and_run_encoder(
                &mut self.encoder_session,
                &self.run_options,
                &chunk_arr,
                &encoder_caches,
            )
            .await?;
            encoder_caches = new_caches;

            let (_, num_frames, encoder_dim) = encoder_out.dim();
            for t in 0..num_frames {
                let enc_frame = encoder_out
                    .slice(ndarray::s![.., t, ..])
                    .to_owned()
                    .into_shape_with_order((1, encoder_dim))?;
                if let Some((text, _)) = self.decode_frame(
                    &enc_frame,
                    &mut decoder_out,
                    &mut token_ids,
                    blank_id,
                    unk_id,
                ).await? {
                    yield Ok(text);
                }
            }

            if is_tail {
                break;
            }
            start += chunk_step;
        }
    })
}
}