#![cfg(feature = "parakeet")]
use crate::stt::SttProvider;
use crate::{Result, VoiceError};
use async_trait::async_trait;
use ndarray::{Array, Array1, Array2, Array3, ArrayD, ArrayViewD, IxDyn};
use once_cell::sync::Lazy;
use ort::execution_providers::CPUExecutionProvider;
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::TensorRef;
use regex::Regex;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
// Encoder time-axis subsampling: one encoder step covers this many feature frames.
const SUBSAMPLING_FACTOR: usize = 8;
// Feature hop size in seconds (10 ms). Together with SUBSAMPLING_FACTOR this
// makes each encoder step span 0.08 s — see `tokens_to_text` timestamping and
// the `timestamp_conversion_uses_subsampling_constant` test.
const WINDOW_SIZE_S: f32 = 0.01;
// Cap on non-blank tokens emitted per encoder step; guards the greedy decode
// loop against spinning forever on a single frame.
const MAX_TOKENS_PER_STEP: usize = 10;
// Directory name used under ~/.car/models for the downloaded model files.
const MODEL_NAME: &str = "parakeet-tdt-0.6b-v2-int8";
// Base URL the model files are fetched from when missing locally.
const HF_BASE_URL: &str = "https://huggingface.co/istupakov/parakeet-tdt-0.6b-v2-onnx/resolve/main";
// Files required for inference: encoder, decoder+joint, preprocessor, vocab.
const MODEL_FILES: &[&str] = &[
"encoder-model.int8.onnx",
"decoder_joint-model.int8.onnx",
"nemo128.onnx",
"vocab.txt",
];
// Detokenization regex: matches a leading space, a space not followed by a
// word boundary, or (capture group 1) a space at a word boundary. The
// replacement keeps only group-1 matches, collapsing SentencePiece padding
// spaces while preserving genuine word separators. Stored as a Result so a
// (theoretically impossible) compile failure degrades to a raw token join in
// `tokens_to_text` instead of panicking at first use.
static DECODE_SPACE_RE: Lazy<std::result::Result<Regex, regex::Error>> =
Lazy::new(|| Regex::new(r"\A\s|\s\B|(\s)\b"));
// Recurrent decoder state pair; each tensor is built as [dim0, 1, dim2] with
// the middle (batch) axis forced to 1 — see `create_decoder_state`.
// NOTE(review): dim0/dim2 meanings (layers/hidden) are assumed from the model
// metadata, not visible here — confirm against the exported decoder_joint.
pub type DecoderState = (Array3<f32>, Array3<f32>);
/// A transcription together with per-token emission information.
#[derive(Debug, Clone)]
pub struct TimestampedResult {
// Detokenized transcript text (spaces normalized by DECODE_SPACE_RE).
pub text: String,
// Emission time of each token in seconds, parallel to `tokens`.
pub timestamps: Vec<f32>,
// Raw vocab tokens with the U+2581 word marker already mapped to ' '.
pub tokens: Vec<String>,
}
/// Errors raised anywhere in the Parakeet ONNX pipeline; converted to the
/// generic [`VoiceError::Stt`] variant at the provider boundary.
#[derive(thiserror::Error, Debug)]
pub enum ParakeetError {
/// Failure inside the onnxruntime bindings (session build or run).
#[error("ORT error: {0}")]
Ort(#[from] ort::Error),
/// Filesystem failure (model files, vocab) — also reused for lock poisoning.
#[error("io error: {0}")]
Io(#[from] std::io::Error),
/// ndarray reshape/dimensionality failure.
#[error("ndarray shape error: {0}")]
Shape(#[from] ndarray::ShapeError),
/// A named input is missing from a model's signature.
#[error("model input not found: {0}")]
InputNotFound(String),
/// A named output is missing from a model's run results.
#[error("model output not found: {0}")]
OutputNotFound(String),
/// Input metadata exists but exposes no tensor shape.
#[error("failed to read tensor shape for input: {0}")]
TensorShape(String),
/// HTTP download of a model file failed.
#[error("model download failed: {0}")]
Download(String),
/// vocab.txt parsed but contained no `<blk>` entry.
#[error("vocab missing required <blk> token")]
NoBlankToken,
}
// Surface all parakeet failures to callers as the generic STT error variant,
// flattening the detail into the message string.
impl From<ParakeetError> for VoiceError {
fn from(e: ParakeetError) -> Self {
VoiceError::Stt(e.to_string())
}
}
/// The three ONNX sessions plus vocabulary that make up one loaded
/// Parakeet TDT model. Inference methods take `&mut self` because
/// `Session::run` requires mutable access.
pub struct ParakeetModel {
// Acoustic encoder (features -> encoder frames).
encoder: Session,
// Combined prediction network + joint network.
decoder_joint: Session,
// NeMo mel-spectrogram preprocessor (raw audio -> features).
preprocessor: Session,
// id -> token table; U+2581 already mapped to ' ' (see load_vocab).
vocab: Vec<String>,
// Index of the `<blk>` token used by the greedy decoder.
blank_idx: i32,
// Cached vocab.len(); used to trim trailing non-vocab (duration) logits.
vocab_size: usize,
}
impl ParakeetModel {
/// Load the encoder, decoder/joint and preprocessor sessions plus the
/// vocabulary from `model_dir`.
///
/// When `quantized` is true the int8 encoder/decoder artifacts are preferred
/// if present on disk; the preprocessor has no quantized variant.
pub fn new<P: AsRef<Path>>(
    model_dir: P,
    quantized: bool,
) -> std::result::Result<Self, ParakeetError> {
    let encoder = Self::init_session(&model_dir, "encoder-model", quantized)?;
    let decoder_joint = Self::init_session(&model_dir, "decoder_joint-model", quantized)?;
    let preprocessor = Self::init_session(&model_dir, "nemo128", false)?;
    let (vocab, blank_idx) = Self::load_vocab(&model_dir)?;
    let model = Self {
        encoder,
        decoder_joint,
        preprocessor,
        vocab_size: vocab.len(),
        vocab,
        blank_idx,
    };
    tracing::info!(
        "[parakeet] loaded {} ({} vocab tokens, blank_idx={})",
        model_dir.as_ref().display(),
        model.vocab_size,
        model.blank_idx
    );
    Ok(model)
}
/// Open one ONNX session from `model_dir`.
///
/// With `try_quantized` set, `<name>.int8.onnx` is used when it exists on
/// disk, otherwise `<name>.onnx`. Sessions run on CPU with full graph
/// optimization and parallel execution enabled.
fn init_session<P: AsRef<Path>>(
    model_dir: P,
    model_name: &str,
    try_quantized: bool,
) -> std::result::Result<Session, ParakeetError> {
    let int8_name = format!("{}.int8.onnx", model_name);
    let model_filename = if try_quantized && model_dir.as_ref().join(&int8_name).exists() {
        int8_name
    } else {
        format!("{}.onnx", model_name)
    };
    let session = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .with_execution_providers(vec![CPUExecutionProvider::default().build()])?
        .with_parallel_execution(true)?
        .commit_from_file(model_dir.as_ref().join(&model_filename))?;
    tracing::debug!(
        "[parakeet] session opened: {} ({} inputs)",
        model_filename,
        session.inputs.len()
    );
    Ok(session)
}
/// Parse `vocab.txt` (lines of `<token> <id>`) into a dense id -> token table.
///
/// Malformed lines (no separator, non-numeric id) are skipped. The
/// SentencePiece word-boundary marker U+2581 is mapped to a plain space.
/// Ids may be sparse; gaps stay empty strings. Returns the table plus the
/// index of the required `<blk>` token, or
/// [`ParakeetError::NoBlankToken`] when the vocab lacks one.
fn load_vocab<P: AsRef<Path>>(
    model_dir: P,
) -> std::result::Result<(Vec<String>, i32), ParakeetError> {
    let path = model_dir.as_ref().join("vocab.txt");
    let content = fs::read_to_string(path)?;
    let mut max_id = 0usize;
    let mut entries: Vec<(String, usize)> = Vec::new();
    let mut blank_idx: Option<usize> = None;
    for line in content.lines() {
        // Split on the LAST space so the id is always the final field and
        // tokens that themselves contain a space still parse. The previous
        // `split(' ')` + `parts[1]` silently dropped such lines.
        let Some((token, id_str)) = line.trim_end().rsplit_once(' ') else {
            continue;
        };
        let Ok(id) = id_str.parse::<usize>() else {
            continue;
        };
        if token == "<blk>" {
            blank_idx = Some(id);
        }
        entries.push((token.to_string(), id));
        max_id = max_id.max(id);
    }
    // Dense table indexed by token id.
    let mut vocab = vec![String::new(); max_id + 1];
    for (token, id) in entries {
        vocab[id] = token.replace('\u{2581}', " ");
    }
    let blank_idx = blank_idx.ok_or(ParakeetError::NoBlankToken)? as i32;
    Ok((vocab, blank_idx))
}
/// Run the NeMo preprocessor: raw waveforms -> acoustic features.
///
/// `waveforms` is a [batch, samples] f32 tensor (the providers reject
/// anything but 16 kHz input); `waveforms_lens` gives the valid sample count
/// per row. Returns the `features` and `features_lens` outputs of the
/// `nemo128` graph. NOTE(review): the feature tensor's axis layout is
/// determined by the model and not visible here — confirm against the export.
pub fn preprocess(
&mut self,
waveforms: &ArrayViewD<f32>,
waveforms_lens: &ArrayViewD<i64>,
) -> std::result::Result<(ArrayD<f32>, ArrayD<i64>), ParakeetError> {
let outputs = self.preprocessor.run(inputs![
"waveforms" => TensorRef::from_array_view(waveforms.view())?,
"waveforms_lens" => TensorRef::from_array_view(waveforms_lens.view())?,
])?;
let features = outputs
.get("features")
.ok_or_else(|| ParakeetError::OutputNotFound("features".into()))?
.try_extract_array()?;
let features_lens = outputs
.get("features_lens")
.ok_or_else(|| ParakeetError::OutputNotFound("features_lens".into()))?
.try_extract_array()?;
// Extracted arrays borrow the session output buffers; copy to owned.
Ok((features.to_owned(), features_lens.to_owned()))
}
/// Run the acoustic encoder over preprocessed features.
///
/// Returns the encoder output with axes 1 and 2 swapped (so each batch row
/// is sliceable one time step at a time in `decode_sequence_streaming`) plus
/// the valid encoded length per batch row.
/// NOTE(review): this assumes the raw "outputs" tensor is feature-major per
/// batch row — confirm against the exported encoder.
pub fn encode(
&mut self,
audio_signal: &ArrayViewD<f32>,
length: &ArrayViewD<i64>,
) -> std::result::Result<(ArrayD<f32>, ArrayD<i64>), ParakeetError> {
let outputs = self.encoder.run(inputs![
"audio_signal" => TensorRef::from_array_view(audio_signal.view())?,
"length" => TensorRef::from_array_view(length.view())?,
])?;
let encoder_output = outputs
.get("outputs")
.ok_or_else(|| ParakeetError::OutputNotFound("outputs".into()))?
.try_extract_array()?;
let encoded_lengths = outputs
.get("encoded_lengths")
.ok_or_else(|| ParakeetError::OutputNotFound("encoded_lengths".into()))?
.try_extract_array()?;
// Swap to time-major per batch row for step-wise greedy decoding.
let encoder_output = encoder_output.permuted_axes(IxDyn(&[0, 2, 1]));
Ok((encoder_output.to_owned(), encoded_lengths.to_owned()))
}
/// Build a zeroed initial decoder state with batch size 1.
///
/// The outer dimensions of each state tensor are read from the
/// decoder_joint model's input metadata; the middle (batch) dimension is
/// forced to 1 because decoding here is always per-utterance.
/// NOTE(review): symbolic/dynamic dims would surface as negative values and
/// the `as usize` casts would misbehave — assumed concrete in this model.
pub fn create_decoder_state(&self) -> std::result::Result<DecoderState, ParakeetError> {
let inputs = &self.decoder_joint.inputs;
let s1_shape = inputs
.iter()
.find(|i| i.name == "input_states_1")
.ok_or_else(|| ParakeetError::InputNotFound("input_states_1".into()))?
.input_type
.tensor_shape()
.ok_or_else(|| ParakeetError::TensorShape("input_states_1".into()))?;
let s2_shape = inputs
.iter()
.find(|i| i.name == "input_states_2")
.ok_or_else(|| ParakeetError::InputNotFound("input_states_2".into()))?
.input_type
.tensor_shape()
.ok_or_else(|| ParakeetError::TensorShape("input_states_2".into()))?;
let state1 = Array::zeros((s1_shape[0] as usize, 1, s1_shape[2] as usize));
let state2 = Array::zeros((s2_shape[0] as usize, 1, s2_shape[2] as usize));
Ok((state1, state2))
}
/// Run one prediction-network + joint step.
///
/// Conditions on the most recently emitted token (blank at sequence start)
/// and a single encoder frame. Returns the joint logits with the leading
/// batch axis removed, plus the updated recurrent state.
pub fn decode_step(
&mut self,
prev_tokens: &[i32],
prev_state: &DecoderState,
encoder_out: &ArrayViewD<f32>,
) -> std::result::Result<(ArrayD<f32>, DecoderState), ParakeetError> {
// Last emitted token, or blank when nothing has been emitted yet.
let target_token = prev_tokens.last().copied().unwrap_or(self.blank_idx);
// Expand the 1-D frame to [1, features, 1] (batch and time axes) for the graph.
let encoder_outputs = encoder_out
.to_owned()
.insert_axis(ndarray::Axis(0))
.insert_axis(ndarray::Axis(2));
let targets = Array2::from_shape_vec((1, 1), vec![target_token])?;
let target_length = Array1::from_vec(vec![1i32]);
let outputs = self.decoder_joint.run(inputs![
"encoder_outputs" => TensorRef::from_array_view(encoder_outputs.view())?,
"targets" => TensorRef::from_array_view(targets.view())?,
"target_length" => TensorRef::from_array_view(target_length.view())?,
"input_states_1" => TensorRef::from_array_view(prev_state.0.view())?,
"input_states_2" => TensorRef::from_array_view(prev_state.1.view())?,
])?;
let logits = outputs
.get("outputs")
.ok_or_else(|| ParakeetError::OutputNotFound("outputs".into()))?
.try_extract_array()?;
let state1 = outputs
.get("output_states_1")
.ok_or_else(|| ParakeetError::OutputNotFound("output_states_1".into()))?
.try_extract_array()?;
let state2 = outputs
.get("output_states_2")
.ok_or_else(|| ParakeetError::OutputNotFound("output_states_2".into()))?
.try_extract_array()?;
// Drop the batch axis; callers index the remaining flat logits directly.
let logits = logits.remove_axis(ndarray::Axis(0));
// State tensors come back dynamic-rank; pin them to the 3-D DecoderState shape.
let state1_3d = state1.to_owned().into_dimensionality::<ndarray::Ix3>()?;
let state2_3d = state2.to_owned().into_dimensionality::<ndarray::Ix3>()?;
Ok((logits.to_owned(), (state1_3d, state2_3d)))
}
/// Full pipeline for a batch of waveforms: preprocess and encode the whole
/// batch at once, then greedily decode each utterance and detokenize.
/// Returns one [`TimestampedResult`] per batch row.
pub fn recognize_batch(
    &mut self,
    waveforms: &ArrayViewD<f32>,
    waveforms_len: &ArrayViewD<i64>,
) -> std::result::Result<Vec<TimestampedResult>, ParakeetError> {
    let (features, features_lens) = self.preprocess(waveforms, waveforms_len)?;
    let (encoder_out, encoder_out_lens) =
        self.encode(&features.view(), &features_lens.view())?;
    let mut transcripts = Vec::with_capacity(encoder_out_lens.len());
    for (frames, frame_count) in encoder_out
        .outer_iter()
        .zip(encoder_out_lens.iter().copied())
    {
        let (ids, steps) = self.decode_sequence(&frames.view(), frame_count as usize)?;
        transcripts.push(self.decode_tokens(ids, steps));
    }
    Ok(transcripts)
}
/// Greedy decode of a single utterance's encoder output.
///
/// Delegates to [`Self::decode_sequence_streaming`] with a no-op per-token
/// callback — the two methods previously duplicated the entire greedy loop
/// line for line, and had already drifted (only this variant logged the
/// empty-result case). Returns (token ids, encoder-step index per token).
fn decode_sequence(
    &mut self,
    encodings: &ArrayViewD<f32>,
    encodings_len: usize,
) -> std::result::Result<(Vec<i32>, Vec<usize>), ParakeetError> {
    let (tokens, timestamps) =
        self.decode_sequence_streaming(encodings, encodings_len, |_, _| {})?;
    if tokens.is_empty() {
        tracing::debug!(
            "[parakeet] zero tokens decoded across {} encoder steps — likely silence or short audio",
            encodings_len
        );
    }
    Ok((tokens, timestamps))
}
/// Map decoded token ids and encoder-step indices to a final
/// [`TimestampedResult`] using this model's vocabulary.
/// (Fixes the HTML-entity-mangled `×tamps` — restored to `&timestamps`
/// so the file compiles.)
fn decode_tokens(&self, ids: Vec<i32>, timestamps: Vec<usize>) -> TimestampedResult {
    tokens_to_text(&self.vocab, &ids, &timestamps)
}
pub fn transcribe_samples(
&mut self,
samples: Vec<f32>,
) -> std::result::Result<TimestampedResult, ParakeetError> {
let samples_len = samples.len();
let waveforms = Array2::from_shape_vec((1, samples_len), samples)?.into_dyn();
let waveforms_lens = Array1::from_vec(vec![samples_len as i64]).into_dyn();
let mut results = self.recognize_batch(&waveforms.view(), &waveforms_lens.view())?;
results
.pop()
.ok_or_else(|| ParakeetError::Io(std::io::Error::other("empty result")))
}
}
fn tokens_to_text(vocab: &[String], ids: &[i32], timestamps: &[usize]) -> TimestampedResult {
let tokens: Vec<String> = ids
.iter()
.filter_map(|&id| {
let idx = id as usize;
(idx < vocab.len()).then(|| vocab[idx].clone())
})
.collect();
let text = match &*DECODE_SPACE_RE {
Ok(re) => re
.replace_all(&tokens.join(""), |caps: ®ex::Captures| {
if caps.get(1).is_some() {
" "
} else {
""
}
})
.to_string(),
Err(_) => tokens.join(""),
};
let float_timestamps: Vec<f32> = timestamps
.iter()
.map(|&t| WINDOW_SIZE_S * SUBSAMPLING_FACTOR as f32 * t as f32)
.collect();
TimestampedResult {
text,
timestamps: float_timestamps,
tokens,
}
}
/// A streaming partial transcript snapshot delivered to `on_partial`
/// callbacks; the last one sent has `is_final` set.
#[derive(Debug, Clone, PartialEq)]
pub struct ParakeetPartial {
// Transcript decoded so far.
pub text: String,
// Per-token emission times in seconds, parallel to `tokens`.
pub timestamps: Vec<f32>,
// Tokens decoded so far.
pub tokens: Vec<String>,
// True only on the final snapshot emitted after decoding completes.
pub is_final: bool,
}
impl ParakeetModel {
/// Greedy decode that invokes `on_token` after every non-blank emission with
/// the accumulated (token ids, step indices), so callers can render partial
/// transcripts. This is the canonical decode loop; [`Self::decode_sequence`]
/// wraps it with a no-op callback.
///
/// Per encoder step `t`: run one decoder/joint step conditioned on the last
/// emitted token (blank at start), arg-max over the first `vocab_size`
/// logits (any trailing logits beyond the vocab are ignored), then either
/// emit the token (keeping the advanced decoder state) or move to the next
/// step. `MAX_TOKENS_PER_STEP` bounds emissions per step so the loop always
/// terminates. (Fixes the HTML-entity-mangled `×tamps` — restored to
/// `&timestamps` so the file compiles.)
fn decode_sequence_streaming<F>(
    &mut self,
    encodings: &ArrayViewD<f32>,
    encodings_len: usize,
    mut on_token: F,
) -> std::result::Result<(Vec<i32>, Vec<usize>), ParakeetError>
where
    F: FnMut(&[i32], &[usize]),
{
    let mut prev_state = self.create_decoder_state()?;
    let mut tokens: Vec<i32> = Vec::new();
    let mut timestamps: Vec<usize> = Vec::new();
    let mut t = 0usize;
    let mut emitted_at_step = 0usize;
    while t < encodings_len {
        let encoder_step = encodings.slice(ndarray::s![t, ..]);
        let encoder_step_dyn = encoder_step.to_owned().into_dyn();
        let (probs, new_state) =
            self.decode_step(&tokens, &prev_state, &encoder_step_dyn.view())?;
        // The logits must be contiguous to slice flat below.
        let probs_slice = probs.as_slice().ok_or_else(|| {
            ParakeetError::Shape(ndarray::ShapeError::from_kind(
                ndarray::ErrorKind::IncompatibleShape,
            ))
        })?;
        // Score only vocabulary entries; the joint may append extra logits.
        let vocab_logits = if probs.len() > self.vocab_size {
            &probs_slice[..self.vocab_size]
        } else {
            probs_slice
        };
        // Arg-max; NaN comparisons fall back to Equal, empty slice to blank.
        let token = vocab_logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i as i32)
            .unwrap_or(self.blank_idx);
        if token != self.blank_idx {
            // Non-blank: commit token and advanced state, notify the caller.
            prev_state = new_state;
            tokens.push(token);
            timestamps.push(t);
            emitted_at_step += 1;
            on_token(&tokens, &timestamps);
        }
        if token == self.blank_idx || emitted_at_step == MAX_TOKENS_PER_STEP {
            // Blank (or per-step cap reached): advance to the next frame.
            t += 1;
            emitted_at_step = 0;
        }
    }
    Ok((tokens, timestamps))
}
/// Transcribe a single 16 kHz mono f32 buffer, invoking `on_partial` with a
/// non-final snapshot after each emitted token and once more with the final
/// result (`is_final = true`). Returns the final [`TimestampedResult`].
pub fn transcribe_samples_streaming<F>(
&mut self,
samples: Vec<f32>,
mut on_partial: F,
) -> std::result::Result<TimestampedResult, ParakeetError>
where
F: FnMut(ParakeetPartial),
{
// Wrap the samples as a batch of one.
let samples_len = samples.len();
let waveforms = Array2::from_shape_vec((1, samples_len), samples)?.into_dyn();
let waveforms_lens = Array1::from_vec(vec![samples_len as i64]).into_dyn();
let (features, features_lens) =
self.preprocess(&waveforms.view(), &waveforms_lens.view())?;
let (encoder_out, encoder_out_lens) =
self.encode(&features.view(), &features_lens.view())?;
// Single-utterance batch: take the first (only) row and its length.
let encodings = encoder_out
.outer_iter()
.next()
.ok_or_else(|| ParakeetError::Io(std::io::Error::other("empty encoder output")))?;
let enc_len = *encoder_out_lens
.iter()
.next()
.ok_or_else(|| ParakeetError::Io(std::io::Error::other("empty encoder lens")))?
as usize;
// The per-token closure runs while `self` is mutably borrowed by the
// decoder, so it cannot read `self.vocab` — clone it up front.
let vocab_clone = self.vocab.clone();
let (tokens, timestamps) = self.decode_sequence_streaming(
&encodings.view(),
enc_len,
|acc_tokens, acc_timestamps| {
let snapshot = tokens_to_text(&vocab_clone, acc_tokens, acc_timestamps);
on_partial(ParakeetPartial {
text: snapshot.text,
timestamps: snapshot.timestamps,
tokens: snapshot.tokens,
is_final: false,
});
},
)?;
let final_result = self.decode_tokens(tokens, timestamps);
// Emit one last snapshot flagged as final.
on_partial(ParakeetPartial {
text: final_result.text.clone(),
timestamps: final_result.timestamps.clone(),
tokens: final_result.tokens.clone(),
is_final: true,
});
Ok(final_result)
}
}
/// [`SttProvider`] backed by a lazily-loaded local Parakeet ONNX model.
pub struct ParakeetSttProvider {
// Directory holding (or receiving, via download) the model files.
model_dir: PathBuf,
// Prefer .int8.onnx artifacts when present (see init_session).
quantized: bool,
// Lazily-initialized model. The Mutex also serializes inference, since
// running a Session requires &mut ParakeetModel.
model: Mutex<Option<ParakeetModel>>,
}
impl ParakeetSttProvider {
/// Create a provider rooted at `model_dir`. The model itself is loaded
/// lazily on first transcription, or eagerly via [`Self::prepare`].
pub fn new(model_dir: impl Into<PathBuf>) -> Self {
Self {
model_dir: model_dir.into(),
// Always prefer int8 weights when they exist on disk.
quantized: true,
model: Mutex::new(None),
}
}
pub fn default_model_dir() -> std::result::Result<PathBuf, ParakeetError> {
let home = std::env::var_os("HOME").map(PathBuf::from).ok_or_else(|| {
ParakeetError::Io(std::io::Error::new(
std::io::ErrorKind::NotFound,
"HOME not set",
))
})?;
Ok(home.join(".car").join("models").join(MODEL_NAME))
}
/// Ensure every file in `MODEL_FILES` exists under `model_dir`, downloading
/// any that are missing from `HF_BASE_URL`.
///
/// Each download is written to a `<name>.part` file and renamed into place,
/// so a run interrupted mid-write never leaves a truncated file that a
/// later `exists()` check would mistake for a complete download (the
/// previous direct `fs::write(&local, ..)` had exactly that failure mode).
fn ensure_model_files(&self) -> std::result::Result<(), ParakeetError> {
    fs::create_dir_all(&self.model_dir)?;
    for filename in MODEL_FILES {
        let local = self.model_dir.join(filename);
        if local.exists() {
            continue;
        }
        let url = format!("{}/{}", HF_BASE_URL, filename);
        tracing::info!(
            "[parakeet] downloading {} → {} (this can take a few minutes for the encoder)",
            url,
            local.display()
        );
        // The whole response is buffered in memory before writing.
        let bytes = reqwest::blocking::get(&url)
            .and_then(|r| r.error_for_status())
            .and_then(|r| r.bytes())
            .map_err(|e| ParakeetError::Download(format!("{}: {}", filename, e)))?;
        // Stage then rename (atomic on the same filesystem).
        let partial = self.model_dir.join(format!("{}.part", filename));
        fs::write(&partial, &bytes)?;
        fs::rename(&partial, &local)?;
        tracing::info!(
            "[parakeet] wrote {} ({} MB)",
            filename,
            bytes.len() / (1024 * 1024)
        );
    }
    Ok(())
}
/// Eagerly download (if needed) and load the model so the first
/// transcription call does not pay the startup cost.
pub fn prepare(&self) -> std::result::Result<(), ParakeetError> {
self.ensure_loaded()
}
/// Load the model into the Mutex slot if it is not loaded yet.
///
/// The lock is held across download + session construction, deliberately
/// serializing concurrent first-use callers so the model is fetched and
/// built exactly once. A poisoned lock is mapped to an I/O error.
fn ensure_loaded(&self) -> std::result::Result<(), ParakeetError> {
let mut guard = self
.model
.lock()
.map_err(|e| ParakeetError::Io(std::io::Error::other(format!("lock: {}", e))))?;
// Fast path: already loaded.
if guard.is_some() {
return Ok(());
}
self.ensure_model_files()?;
let model = ParakeetModel::new(&self.model_dir, self.quantized)?;
*guard = Some(model);
Ok(())
}
}
impl ParakeetSttProvider {
/// Transcribe `samples` (16 kHz mono f32), streaming partial transcripts to
/// `on_partial` as tokens are emitted; returns the final text.
///
/// Rejects any sample rate other than 16 kHz. Inference runs synchronously
/// inside `tokio::task::block_in_place`, which keeps the (non-'static)
/// borrows of `self` and the callback usable — NOTE(review): block_in_place
/// panics on a current-thread tokio runtime; confirm callers always use the
/// multi-threaded runtime.
pub async fn transcribe_streaming<F>(
&self,
samples: &[f32],
sample_rate: u32,
on_partial: F,
) -> Result<String>
where
F: FnMut(ParakeetPartial) + Send,
{
if sample_rate != 16_000 {
return Err(VoiceError::Stt(format!(
"Parakeet expects 16 kHz, got {} Hz",
sample_rate
)));
}
// Own the samples so the blocking closure does not capture the slice borrow.
let samples_owned: Vec<f32> = samples.to_vec();
let provider_ref: &ParakeetSttProvider = self;
let mut on_partial = on_partial;
let result: TimestampedResult = tokio::task::block_in_place(|| {
// Lazily load, then lock for the duration of the inference.
provider_ref.ensure_loaded()?;
let mut guard = provider_ref
.model
.lock()
.map_err(|e| ParakeetError::Io(std::io::Error::other(format!("lock: {}", e))))?;
// Invariant: ensure_loaded() just succeeded, so the slot is Some.
let model = guard.as_mut().expect("ensure_loaded set Some");
model.transcribe_samples_streaming(samples_owned, &mut on_partial)
})?;
Ok(result.text)
}
}
#[async_trait]
impl SttProvider for ParakeetSttProvider {
/// One-shot transcription of `samples` (16 kHz mono f32) to text.
///
/// Rejects any sample rate other than 16 kHz. Inference runs synchronously
/// inside `tokio::task::block_in_place` (multi-threaded runtime required),
/// holding the model lock for the duration of the call.
async fn transcribe(&self, samples: &[f32], sample_rate: u32) -> Result<String> {
if sample_rate != 16_000 {
return Err(VoiceError::Stt(format!(
"Parakeet expects 16 kHz, got {} Hz",
sample_rate
)));
}
// Own the samples so the blocking closure does not capture the slice borrow.
let samples_owned: Vec<f32> = samples.to_vec();
let provider_ref: &ParakeetSttProvider = self;
let result: TimestampedResult = tokio::task::block_in_place(|| {
provider_ref.ensure_loaded()?;
let mut guard = provider_ref
.model
.lock()
.map_err(|e| ParakeetError::Io(std::io::Error::other(format!("lock: {}", e))))?;
// Invariant: ensure_loaded() just succeeded, so the slot is Some.
let model = guard.as_mut().expect("ensure_loaded set Some");
model.transcribe_samples(samples_owned)
})?;
Ok(result.text)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_vocab_lines_and_finds_blank() {
        let tmp = tempfile::tempdir().unwrap();
        let vocab = "\u{2581}hello 0\nworld 1\n<blk> 2\n";
        std::fs::write(tmp.path().join("vocab.txt"), vocab).unwrap();
        let (tokens, blank) = ParakeetModel::load_vocab(tmp.path()).unwrap();
        assert_eq!(tokens.len(), 3);
        // U+2581 maps to a leading space.
        assert_eq!(tokens[0], " hello");
        assert_eq!(tokens[1], "world");
        assert_eq!(tokens[2], "<blk>");
        assert_eq!(blank, 2);
    }

    #[test]
    fn vocab_without_blank_token_errors() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("vocab.txt"), "hello 0\nworld 1\n").unwrap();
        let err = ParakeetModel::load_vocab(tmp.path()).unwrap_err();
        // Previously the `matches!` result was discarded, so the variant was
        // never actually asserted.
        assert!(matches!(err, ParakeetError::NoBlankToken));
    }

    #[test]
    fn vocab_skips_malformed_lines() {
        let tmp = tempfile::tempdir().unwrap();
        let vocab = "no_id_here\nhello 0\nbad_id abc\nworld 1\n<blk> 2\n";
        std::fs::write(tmp.path().join("vocab.txt"), vocab).unwrap();
        let (tokens, blank) = ParakeetModel::load_vocab(tmp.path()).unwrap();
        assert_eq!(tokens.len(), 3);
        assert_eq!(tokens[0], "hello");
        assert_eq!(tokens[1], "world");
        assert_eq!(blank, 2);
    }

    #[test]
    fn space_regex_substitutes_correctly() {
        // Mirrors DECODE_SPACE_RE: keep only spaces captured at a word
        // boundary (group 1). (The closure type was HTML-entity-mangled to
        // `®ex::Captures`; restored to `&regex::Captures`.)
        let regex = Regex::new(r"\A\s|\s\B|(\s)\b").unwrap();
        let result = regex
            .replace_all(" hello world", |caps: &regex::Captures| {
                if caps.get(1).is_some() {
                    " "
                } else {
                    ""
                }
            })
            .to_string();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn tokens_to_text_joins_and_normalizes_spaces() {
        let vocab = vec![
            " hello".to_string(),
            "world".to_string(),
            " how".to_string(),
        ];
        let result = tokens_to_text(&vocab, &[0, 1, 2], &[0, 1, 5]);
        assert_eq!(result.text, "helloworld how");
        assert_eq!(result.tokens, vec![" hello", "world", " how"]);
        // Each encoder step is 0.08 s (WINDOW_SIZE_S * SUBSAMPLING_FACTOR).
        assert_eq!(result.timestamps.len(), 3);
        assert!((result.timestamps[0] - 0.0).abs() < 1e-6);
        assert!((result.timestamps[1] - 0.08).abs() < 1e-6);
        assert!((result.timestamps[2] - 0.40).abs() < 1e-6);
    }

    #[test]
    fn tokens_to_text_skips_out_of_range_ids() {
        let vocab = vec!["hi".to_string(), "there".to_string()];
        let result = tokens_to_text(&vocab, &[0, 5, 1], &[0, 1, 2]);
        assert_eq!(result.tokens, vec!["hi", "there"]);
    }

    #[test]
    fn parakeet_partial_extends_monotonically() {
        let vocab = vec![
            " hello".to_string(),
            " world".to_string(),
            " how".to_string(),
        ];
        // Build each partial from a single snapshot instead of calling
        // tokens_to_text twice per field as the old test did.
        let make = |ids: &[i32], steps: &[usize], is_final: bool| {
            let snap = tokens_to_text(&vocab, ids, steps);
            ParakeetPartial {
                text: snap.text,
                timestamps: snap.timestamps,
                tokens: snap.tokens,
                is_final,
            }
        };
        let p1 = make(&[0], &[0], false);
        let p2 = make(&[0, 1], &[0, 1], false);
        let p3 = make(&[0, 1, 2], &[0, 1, 2], true);
        assert!(p2.text.starts_with(&p1.text));
        assert!(p3.text.starts_with(&p2.text));
        assert!(p3.is_final);
        assert!(!p2.is_final);
        assert!(p3.timestamps.len() > p2.timestamps.len());
    }

    #[test]
    fn timestamp_conversion_uses_subsampling_constant() {
        assert!((WINDOW_SIZE_S * SUBSAMPLING_FACTOR as f32 - 0.08).abs() < 1e-6);
    }

    #[test]
    fn default_model_dir_is_under_dotcar() {
        // NOTE(review): mutates process-global env; fine while no other test
        // reads HOME, but would race under the parallel test runner if one did.
        std::env::set_var("HOME", "/tmp/test-home");
        let dir = ParakeetSttProvider::default_model_dir().unwrap();
        assert!(dir.ends_with(".car/models/parakeet-tdt-0.6b-v2-int8"));
    }
}