mod decoder;
mod vocab;
use std::path::Path;
use std::time::Instant;
use ort::session::Session;
use ort::value::Tensor;
use self::decoder::decode_autoregressive;
use self::vocab::Vocab;
use crate::{
ModelCapabilities, SpeechModel, TranscribeError, TranscribeOptions, TranscriptionResult,
};
/// Canary model families this module can tell apart.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CanaryVariant {
    /// Reported as "Canary Flash"; advertises the `FLASH_LANGUAGES` table.
    Flash,
    /// Reported as "Canary 1B v2"; advertises the `V2_LANGUAGES` table.
    V2,
}

/// Language codes advertised for the `Flash` variant.
const FLASH_LANGUAGES: &[&str] = &["en", "de", "es", "fr"];

/// Language codes advertised for the `V2` variant.
const V2_LANGUAGES: &[&str] = &[
    "bg", "hr", "cs", "da", "nl", "en", "et", "fi", "fr", "de", "el", "hu", "it", "lv", "lt", "mt",
    "pl", "pt", "ro", "sk", "sl", "es", "sv", "ru", "uk",
];

impl CanaryVariant {
    /// Chooses the variant from the vocabulary size alone.
    ///
    /// Vocabularies with fewer than 10,000 entries map to `Flash`; everything
    /// else maps to `V2`.
    fn detect(vocab_size: usize) -> Self {
        match vocab_size {
            0..=9_999 => Self::Flash,
            _ => Self::V2,
        }
    }

    /// Human-readable model name for this variant.
    fn name(self) -> &'static str {
        match self {
            Self::V2 => "Canary 1B v2",
            Self::Flash => "Canary Flash",
        }
    }

    /// Language codes advertised for this variant.
    fn languages(self) -> &'static [&'static str] {
        match self {
            Self::V2 => V2_LANGUAGES,
            Self::Flash => FLASH_LANGUAGES,
        }
    }
}
/// Per-call settings accepted by `CanaryModel::transcribe_with`.
#[derive(Debug, Clone)]
pub struct CanaryParams {
    /// Source language code; `transcribe_with` falls back to `"en"` when `None`.
    pub language: Option<String>,
    /// Target language code; `transcribe_with` falls back to the source
    /// language when `None`.
    pub target_language: Option<String>,
    /// Forwarded to the prompt builder.
    // NOTE(review): presumably toggles punctuation and capitalization — confirm in `vocab`.
    pub use_pnc: bool,
    /// Forwarded to the prompt builder, but always treated as `false` for the
    /// `Flash` variant.
    // NOTE(review): presumably toggles inverse text normalization — confirm in `vocab`.
    pub use_itn: bool,
    /// Forwarded to the autoregressive decoder.
    // NOTE(review): presumably caps the decoded sequence length — confirm in `decoder`.
    pub max_sequence_length: usize,
}

impl Default for CanaryParams {
    /// No explicit languages, both prompt flags enabled, and a
    /// `max_sequence_length` of 1024.
    fn default() -> Self {
        CanaryParams {
            use_pnc: true,
            use_itn: true,
            max_sequence_length: 1024,
            language: None,
            target_language: None,
        }
    }
}
/// A loaded Canary model: three ONNX sessions plus the vocabulary they share.
pub struct CanaryModel {
    /// Session for `nemo128.onnx`; maps `waveforms` to `features`.
    preprocessor: Session,
    /// Session for the `encoder-model` file; maps features to encoder embeddings and a mask.
    encoder: Session,
    /// Session for the `decoder-model` file; driven by `decode_autoregressive`.
    decoder: Session,
    /// Vocabulary loaded from `vocab.txt`; builds the prompt and is handed to the decoder.
    vocab: Vocab,
    /// Variant inferred from the vocabulary size at load time.
    variant: CanaryVariant,
}

impl CanaryModel {
    /// Loads the preprocessor, encoder, decoder and vocabulary from `model_dir`.
    ///
    /// The preprocessor is read from `nemo128.onnx` and the vocabulary from
    /// `vocab.txt`. Encoder and decoder paths are resolved from the base names
    /// `encoder-model` / `decoder-model` together with `quantization`.
    ///
    /// # Errors
    ///
    /// Returns [`TranscribeError::ModelNotFound`] when `model_dir` does not
    /// exist, and propagates any error from session creation or vocabulary
    /// loading.
    pub fn load(
        model_dir: &Path,
        quantization: &super::Quantization,
    ) -> Result<Self, TranscribeError> {
        if !model_dir.exists() {
            return Err(TranscribeError::ModelNotFound(model_dir.to_path_buf()));
        }
        let load_start = Instant::now();

        let preprocessor_path = model_dir.join("nemo128.onnx");
        log::info!(
            "Loading Canary preprocessor from {:?}...",
            preprocessor_path
        );
        let preprocessor = super::session::create_session(&preprocessor_path)?;

        let encoder_path =
            super::session::resolve_model_path(model_dir, "encoder-model", quantization);
        log::info!("Loading Canary encoder from {:?}...", encoder_path);
        let encoder = super::session::create_session(&encoder_path)?;

        let decoder_path =
            super::session::resolve_model_path(model_dir, "decoder-model", quantization);
        log::info!("Loading Canary decoder from {:?}...", decoder_path);
        let decoder = super::session::create_session(&decoder_path)?;

        let vocab_path = model_dir.join("vocab.txt");
        let vocab = Vocab::load(&vocab_path)?;
        // The variant is not stored on disk; it is inferred from the vocabulary size.
        let variant = CanaryVariant::detect(vocab.size());

        log::info!(
            "Canary model loaded in {:.2?} (variant: {:?}, vocab: {} tokens)",
            load_start.elapsed(),
            variant,
            vocab.size()
        );
        Ok(Self {
            preprocessor,
            encoder,
            decoder,
            vocab,
            variant,
        })
    }

    /// Transcribes (or translates) `samples` using explicit [`CanaryParams`].
    ///
    /// Runs the preprocessor, then the encoder, builds the prompt from the
    /// language / PnC / ITN settings and finally decodes autoregressively.
    /// The source language defaults to `"en"`, the target language defaults
    /// to the source language, and ITN is always disabled for the `Flash`
    /// variant.
    ///
    /// # Errors
    ///
    /// Returns [`TranscribeError::Inference`] when a session does not produce
    /// an expected output, and propagates errors from tensor creation,
    /// session execution, prompt building and decoding.
    pub fn transcribe_with(
        &mut self,
        samples: &[f32],
        params: &CanaryParams,
    ) -> Result<TranscriptionResult, TranscribeError> {
        let src_lang = params.language.as_deref().unwrap_or("en");
        let tgt_lang = params.target_language.as_deref().unwrap_or(src_lang);
        let use_itn = params.use_itn && self.variant != CanaryVariant::Flash;
        let total_start = Instant::now();

        // --- Preprocessing: waveform -> features -------------------------------
        let preprocess_start = Instant::now();
        let num_samples = samples.len();
        log::debug!("Preprocessor input: waveforms shape [1, {}]", num_samples);
        let waveforms = Tensor::from_array((
            vec![1i64, num_samples as i64],
            samples.to_vec().into_boxed_slice(),
        ))?;
        let waveforms_lens =
            Tensor::from_array((vec![1i64], vec![num_samples as i64].into_boxed_slice()))?;
        let mut preprocess_out = self.preprocessor.run(ort::inputs![
            "waveforms" => waveforms,
            "waveforms_lens" => waveforms_lens
        ])?;
        let preprocess_elapsed = preprocess_start.elapsed();

        // Take ownership of the outputs *before* logging their shapes: indexing
        // the output map by name panics when the entry is missing, whereas
        // `remove` lets us surface a proper `TranscribeError` instead.
        let features = preprocess_out
            .remove("features")
            .ok_or_else(|| TranscribeError::Inference("Missing features output".to_string()))?;
        let features_lens = preprocess_out.remove("features_lens").ok_or_else(|| {
            TranscribeError::Inference("Missing features_lens output".to_string())
        })?;
        log::debug!(
            "Preprocessor output: features shape {:?} ({:.2?})",
            features.shape(),
            preprocess_elapsed
        );

        // --- Encoding: features -> embeddings + mask ---------------------------
        let encode_start = Instant::now();
        let mut encoder_out = self.encoder.run(ort::inputs![
            "audio_signal" => features,
            "length" => features_lens
        ])?;
        let encode_elapsed = encode_start.elapsed();

        // Same ordering as above: validate presence first, log afterwards.
        let encoder_embeddings = encoder_out.remove("encoder_embeddings").ok_or_else(|| {
            TranscribeError::Inference("Missing encoder_embeddings output".to_string())
        })?;
        let encoder_mask = encoder_out
            .remove("encoder_mask")
            .ok_or_else(|| TranscribeError::Inference("Missing encoder_mask output".to_string()))?;
        log::debug!(
            "Encoder output: embeddings shape {:?}, mask shape {:?} ({:.2?})",
            encoder_embeddings.shape(),
            encoder_mask.shape(),
            encode_elapsed
        );

        // --- Prompt + autoregressive decoding ----------------------------------
        let prompt_tokens = self
            .vocab
            .build_prompt(src_lang, tgt_lang, params.use_pnc, use_itn)?;
        log::debug!(
            "Prompt tokens ({}): {:?}",
            prompt_tokens.len(),
            prompt_tokens
        );

        let decode_start = Instant::now();
        let text = decode_autoregressive(
            &mut self.decoder,
            &encoder_embeddings,
            &encoder_mask,
            prompt_tokens,
            &self.vocab,
            params.max_sequence_length,
        )?;
        log::debug!("Decoding completed in {:.2?}", decode_start.elapsed());
        log::info!(
            "Transcription completed in {:.2?}: \"{}\"",
            total_start.elapsed(),
            text
        );

        Ok(TranscriptionResult {
            text,
            segments: None,
        })
    }
}
impl SpeechModel for CanaryModel {
    /// Reports the static capabilities of this engine.
    ///
    /// The name and language list depend on the detected variant; the other
    /// fields are fixed: 16000 Hz sample rate, translation supported, no
    /// timestamps and no streaming.
    fn capabilities(&self) -> ModelCapabilities {
        ModelCapabilities {
            name: self.variant.name(),
            engine_id: "canary",
            sample_rate: 16000,
            languages: self.variant.languages(),
            supports_timestamps: false,
            supports_translation: true,
            supports_streaming: false,
        }
    }

    /// Transcribes `samples` using the generic [`TranscribeOptions`].
    ///
    /// The source language defaults to `"en"` when `options.language` is
    /// unset. When `options.translate` is set the target language is `"en"`;
    /// otherwise it equals the source language. All remaining parameters use
    /// [`CanaryParams::default`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`CanaryModel::transcribe_with`].
    fn transcribe_raw(
        &mut self,
        samples: &[f32],
        options: &TranscribeOptions,
    ) -> Result<TranscriptionResult, TranscribeError> {
        let src_lang = options.language.as_deref().unwrap_or("en");
        let tgt_lang = if options.translate { "en" } else { src_lang };
        let params = CanaryParams {
            language: Some(src_lang.to_string()),
            target_language: Some(tgt_lang.to_string()),
            ..Default::default()
        };
        self.transcribe_with(samples, &params)
    }
}