use ndarray::Array2;
use ort::execution_providers::CPUExecutionProvider;
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::TensorRef;
use rustfft::num_complex::Complex;
use rustfft::FftPlanner;
use std::f32::consts::PI;
use std::path::Path;
use std::sync::Arc;
// Audio front-end parameters shared by the STFT and the mel filterbank.

/// FFT size: the number of samples taken per analysis frame and the length of the planned FFT.
const N_FFT: usize = 320;
/// Stride, in samples, between the starts of consecutive analysis frames.
const HOP_LENGTH: usize = 160;
/// Number of Hann window coefficients generated in `GigaAMModel::new`.
const WIN_LENGTH: usize = 320;
/// Number of mel bands, i.e. rows of the mel filterbank matrix.
const N_MELS: usize = 64;
/// Lower frequency edge handed to `compute_mel_filterbank` as `f_min` (converted with `hz_to_mel`).
const F_MIN: f32 = 0.0;
/// Upper frequency edge handed to `compute_mel_filterbank` as `f_max` (converted with `hz_to_mel`).
const F_MAX: f32 = 8000.0;
/// Token table indexed by the class id selected in `ctc_greedy_decode`.
///
/// A token that starts with `▁` begins a new word: the decoder strips the marker and
/// inserts a space before the remainder (unless the output is still empty). The
/// `<unk>` entry is never written to the output text.
// NOTE(review): the ordering presumably mirrors the tokenizer shipped with the ONNX
// model — confirm against the model's vocabulary file before editing any entry.
const VOCAB: &[&str] = &[
"<unk>", "▁", ".", "е", "а", "с", "и", ",", "о", "т", "н", "м", "у", "й", "л", "я", "в", "д", "з", "к", "но", "▁с", "ы", "г", "▁в", "б", "р", "п", "то", "ть", "ра", "▁по", "ка", "ш", "ни", "ли", "на", "го", "х", "ро", "ва", "▁на", "ю", "ко", "ль", "те", "?", "ч", "ж", "во", "ла", "ре", "да", "▁и", "ло", "ст", "-", "ё", "▁не", "ле", "ри", "де", "та", "ны", "▁В", "▁С", "ь", "ки", "ер", "▁о", "ви", "ти", "ма", "▁за", "▁А", "▁Т", "▁у", "же", "э", "▁М", "ц", "ди", "не", "ру", "че", "ф", "ве", "▁Д", "бо", "▁К", "щ", "▁О", "ми", "▁что", "▁«", "»", "ся", "▁По", "▁про", "e", "a", "ку", "ну", "▁это", "мо", "жи", "▁ко", "▁П", "▁И", "ча", "му", "0", "ты", "ста", "сь", "▁как", "o", "▁мо", "i", "до", "ля", "▁до", "▁от", "У", "Б", "ры", "чи", "ци", "▁бы", "▁Включи", "па", "ключ", "по", "ду", "▁при", "\u{2014}", "Л", "n", "Р", "сто", "r", "▁так", "сти", "Г", "▁На", "Н", "▁об", "▁мне", "l", "Я", "t", "1", "▁За", "s", "Э", "Ч", "Е", "▁есть", "ень", "▁Ну", "2", "▁Сбер", "вер", "▁вот", "ение", "смотр", "В", "▁раз", "Ф", "▁пере", "ешь", "▁тебя", "u", "3", "5", "d", "y", "Х", "4", "З", "S", "С", "h", "c", "m", "9", ":", "8", "6", "7", "M", "B", "П", "D", "T", "!", "k", "g", "О", "C", "Ш", "М", "A", "p", "Ю", "P", "Т", "К", "А", "L", "b", "Д", "ъ", "H", "%", "F", "v", "V", "R", "O", "I", "И", "N", "Ж", "\"", "K", "G", "Ц", "f", "w", "E", "₽", "W", "J", "x", "z", "'", "U", "Y", "&", "Z", "X", "+", "/", "Щ", ";", "j", "Й", "q", "Q", "°", "Ё", "Ы", "€", "$", "«", ];
/// Class id that `ctc_greedy_decode` treats as the CTC blank and never emits as text.
// NOTE(review): presumably one past the last `VOCAB` index — confirm if the table changes.
const BLANK_ID: usize = 256;
/// Errors surfaced by [`GigaAMModel`].
#[derive(thiserror::Error, Debug)]
pub enum GigaAMError {
/// An error reported by the `ort` crate; converted automatically via `From<ort::Error>`.
#[error("ORT error: {0}")]
Ort(#[from] ort::Error),
/// An `ndarray` shape error; converted automatically via `From<ndarray::ShapeError>`.
#[error("ndarray shape error: {0}")]
Shape(#[from] ndarray::ShapeError),
/// The model path given to `GigaAMModel::new` does not exist; carries the displayed path.
#[error("Model file not found: {0}")]
ModelNotFound(String),
/// No model is loaded.
// NOTE(review): nothing in the visible code constructs this variant — presumably it is
// produced by callers that hold an optional model; confirm before removing.
#[error("Model not loaded")]
ModelNotLoaded,
}
/// A loaded GigaAM ONNX model together with the precomputed pieces of its audio front-end.
pub struct GigaAMModel {
/// ONNX Runtime session created from the model file.
session: Session,
/// Mel filterbank built by `compute_mel_filterbank`; `N_MELS` rows by `N_FFT / 2 + 1` columns.
mel_filterbank: Array2<f32>,
/// Hann window coefficients (`WIN_LENGTH` values) multiplied into each frame before the FFT.
hann_window: Vec<f32>,
/// Forward FFT plan of size `N_FFT`, created once in `new` and reused for every frame.
fft: Arc<dyn rustfft::Fft<f32>>,
}
/// Emits a debug log line when the model is dropped; field destructors still run as usual.
impl Drop for GigaAMModel {
fn drop(&mut self) {
// Diagnostic only: lets the logs show when the model instance is released.
log::debug!("Dropping GigaAMModel");
}
}
impl GigaAMModel {
    /// Sample rate, in Hz, used to map mel band edges onto FFT bins.
    // NOTE(review): callers are presumably expected to pass audio at this rate to
    // `transcribe` — nothing here resamples or checks it.
    const SAMPLE_RATE_HZ: u32 = 16_000;

    /// Loads the ONNX model at `model_path` and precomputes the Hann window, the FFT plan
    /// and the mel filterbank used by [`transcribe`](Self::transcribe).
    ///
    /// # Errors
    ///
    /// * [`GigaAMError::ModelNotFound`] if `model_path` does not exist.
    /// * [`GigaAMError::Ort`] if the ONNX Runtime session cannot be created.
    pub fn new(model_path: &Path) -> Result<Self, GigaAMError> {
        if !model_path.exists() {
            return Err(GigaAMError::ModelNotFound(model_path.display().to_string()));
        }
        log::info!("Loading GigaAM model from {:?}...", model_path);
        let session = Self::init_session(model_path)?;

        // Hann window with the denominator `WIN_LENGTH` (not `WIN_LENGTH - 1`).
        let hann_window: Vec<f32> = (0..WIN_LENGTH)
            .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / WIN_LENGTH as f32).cos()))
            .collect();

        let mut planner = FftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(N_FFT);

        let mel_filterbank =
            compute_mel_filterbank(N_MELS, N_FFT, Self::SAMPLE_RATE_HZ, F_MIN, F_MAX);

        Ok(Self {
            session,
            mel_filterbank,
            hann_window,
            fft,
        })
    }

    /// Builds the ONNX Runtime session for the model file at `path` (CPU execution provider,
    /// optimization level 3, parallel execution enabled) and logs the model's declared
    /// inputs and outputs.
    fn init_session(path: &Path) -> Result<Session, GigaAMError> {
        let providers = vec![CPUExecutionProvider::default().build()];
        let session = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_execution_providers(providers)?
            .with_parallel_execution(true)?
            .commit_from_file(path)?;

        for input in &session.inputs {
            log::info!(
                "Model input: name={}, type={:?}",
                input.name,
                input.input_type
            );
        }
        for output in &session.outputs {
            log::info!(
                "Model output: name={}, type={:?}",
                output.name,
                output.output_type
            );
        }
        Ok(session)
    }

    /// Transcribes `samples` into text.
    ///
    /// The samples are turned into a log-mel spectrogram, fed to the model as the
    /// `features` / `feature_lengths` inputs, and the first model output is decoded with
    /// greedy CTC. Input shorter than one analysis frame (`N_FFT` samples) yields an
    /// empty string without running the model.
    ///
    /// # Errors
    ///
    /// * [`GigaAMError::Ort`] if tensor creation, inference or output extraction fails.
    /// * [`GigaAMError::Shape`] if the first output is not a rank-3 array.
    pub fn transcribe(&mut self, samples: &[f32]) -> Result<String, GigaAMError> {
        if samples.len() < N_FFT {
            return Ok(String::new());
        }

        let mel = self.compute_mel_spectrogram(samples);
        let time_steps = mel.shape()[1];
        log::debug!(
            "Mel spectrogram shape: [{}, {}]",
            mel.shape()[0],
            mel.shape()[1]
        );

        // Prepend a batch axis of length 1: [N_MELS, T] -> [1, N_MELS, T].
        let features = mel.insert_axis(ndarray::Axis(0)).into_dyn();
        let feature_lengths = ndarray::arr1(&[time_steps as i64]).into_dyn();

        let model_inputs = inputs! {
            "features" => TensorRef::from_array_view(features.view())?,
            "feature_lengths" => TensorRef::from_array_view(feature_lengths.view())?,
        };
        let outputs = self.session.run(model_inputs)?;

        let log_probs = outputs[0].try_extract_array::<f32>()?;
        let log_probs = log_probs.to_owned().into_dimensionality::<ndarray::Ix3>()?;
        log::debug!("Log probs shape: {:?}", log_probs.shape());

        Ok(ctc_greedy_decode(&log_probs))
    }

    /// Computes the log-mel spectrogram of `audio`.
    ///
    /// Frames of `N_FFT` samples are taken every `HOP_LENGTH` samples with no padding, so
    /// trailing samples that do not fill a whole frame are ignored. The result has one row
    /// per mel band and one column per frame; audio shorter than one frame produces an
    /// array with zero columns.
    fn compute_mel_spectrogram(&self, audio: &[f32]) -> Array2<f32> {
        // `checked_sub` keeps this safe even if a caller skips the length guard in
        // `transcribe`; a plain subtraction would underflow on short input.
        let n_frames = audio
            .len()
            .checked_sub(N_FFT)
            .map_or(0, |extra| extra / HOP_LENGTH + 1);
        if n_frames == 0 {
            return Array2::<f32>::zeros((self.mel_filterbank.nrows(), 0));
        }

        let freq_bins = N_FFT / 2 + 1;
        let mut power_spec = Array2::<f32>::zeros((freq_bins, n_frames));

        // Panics (as indexing did before) if the window is shorter than a frame, so a
        // stale FFT buffer slot can never leak into the next frame.
        let window = &self.hann_window[..N_FFT];
        // One scratch buffer reused for every frame instead of a fresh allocation each time.
        let mut fft_buf = vec![Complex::new(0.0f32, 0.0f32); N_FFT];

        // `windows(N_FFT).step_by(HOP_LENGTH)` yields exactly `n_frames` frames.
        for (frame_idx, frame) in audio.windows(N_FFT).step_by(HOP_LENGTH).enumerate() {
            for ((slot, &sample), &weight) in fft_buf.iter_mut().zip(frame).zip(window) {
                *slot = Complex::new(sample * weight, 0.0);
            }
            self.fft.process(&mut fft_buf);
            // Keep only the non-redundant half of the spectrum of the real-valued frame.
            for (bin, val) in fft_buf.iter().take(freq_bins).enumerate() {
                power_spec[[bin, frame_idx]] = val.norm_sqr();
            }
        }

        let mel = self.mel_filterbank.dot(&power_spec);
        // Clamp before the logarithm so silent bands do not produce -inf.
        mel.mapv(|v| v.clamp(1e-9, 1e9).ln())
    }
}
/// Greedy CTC decoding of the first batch item in `log_probs`.
///
/// `log_probs` is indexed as `[batch, frame, class]`. For every frame the highest-scoring
/// class is chosen (the lowest id wins ties), consecutive repeats are collapsed, and the
/// blank id, ids outside `VOCAB` and the `<unk>` token are discarded. Tokens starting with
/// `▁` open a new word. The result is trimmed of surrounding whitespace.
fn ctc_greedy_decode(log_probs: &ndarray::Array3<f32>) -> String {
    let n_frames = log_probs.shape()[1];
    let n_classes = log_probs.shape()[2];

    // Best class per frame; the strict `>` keeps the first maximum, and a frame with no
    // score above negative infinity falls back to id 0.
    let mut path: Vec<usize> = (0..n_frames)
        .map(|frame| {
            (0..n_classes)
                .fold((0usize, f32::NEG_INFINITY), |(top_id, top_score), class| {
                    let score = log_probs[[0, frame, class]];
                    if score > top_score {
                        (class, score)
                    } else {
                        (top_id, top_score)
                    }
                })
                .0
        })
        .collect();

    // Collapse runs of identical ids *before* dropping blanks, so a blank between two
    // equal tokens still separates them.
    path.dedup();

    let pieces = path
        .into_iter()
        .filter(|&id| id != BLANK_ID)
        .filter_map(|id| VOCAB.get(id).copied())
        .filter(|&piece| piece != "<unk>");

    let mut text = String::new();
    for piece in pieces {
        match piece.strip_prefix('▁') {
            Some(rest) => {
                // Word boundary: separate from preceding text, but never lead with a space.
                if !text.is_empty() {
                    text.push(' ');
                }
                text.push_str(rest);
            }
            None => text.push_str(piece),
        }
    }
    text.trim().to_string()
}
/// Builds a bank of `n_mels` triangular filters over the `n_fft / 2 + 1` FFT bins.
///
/// Band edges are spaced evenly between `f_min` and `f_max` on the mel scale
/// (`2595 * log10(1 + f / 700)`), converted back to Hz and then to fractional FFT bin
/// positions using `sample_rate`. Each filter rises linearly from 0 at its left edge to 1
/// at its centre and falls back to 0 at its right edge; no area normalisation is applied.
///
/// Returns a matrix with `n_mels` rows and `n_fft / 2 + 1` columns.
fn compute_mel_filterbank(
    n_mels: usize,
    n_fft: usize,
    sample_rate: u32,
    f_min: f32,
    f_max: f32,
) -> Array2<f32> {
    fn hz_to_mel(hz: f32) -> f32 {
        2595.0 * (1.0 + hz / 700.0).log10()
    }
    fn mel_to_hz(mel: f32) -> f32 {
        700.0 * (10.0f32.powf(mel / 2595.0) - 1.0)
    }

    let n_bins = n_fft / 2 + 1;
    let mel_lo = hz_to_mel(f_min);
    let mel_hi = hz_to_mel(f_max);

    // `n_mels + 2` edge positions, expressed in (fractional) FFT bin units.
    let edges: Vec<f32> = (0..n_mels + 2)
        .map(|i| {
            let mel = mel_lo + (mel_hi - mel_lo) * i as f32 / (n_mels + 1) as f32;
            let hz = mel_to_hz(mel);
            hz * n_fft as f32 / sample_rate as f32
        })
        .collect();

    let mut weights = Array2::<f32>::zeros((n_mels, n_bins));
    // Every run of three consecutive edges is the (left, centre, right) of one band.
    for (band, triple) in edges.windows(3).enumerate() {
        let (left, centre, right) = (triple[0], triple[1], triple[2]);
        let rise = centre - left;
        let fall = right - centre;

        for bin in 0..n_bins {
            let pos = bin as f32;
            let on_rising_side = pos >= left && pos <= centre;
            let on_falling_side = pos > centre && pos <= right;

            if on_rising_side {
                // A zero-width slope leaves the weight at 0.
                if rise > 0.0 {
                    weights[[band, bin]] = (pos - left) / rise;
                }
            } else if on_falling_side && fall > 0.0 {
                weights[[band, bin]] = (right - pos) / fall;
            }
        }
    }
    weights
}