1use std::{
2 fs::File,
3 io::{BufRead, BufReader},
4 path::Path,
5};
6
7use hound::{SampleFormat, WavReader};
8use kaldi_native_fbank::{FbankComputer, FbankOptions, OnlineFeature, online::FeatureComputer};
9use ort::{
10 session::{Session, builder::GraphOptimizationLevel},
11 value::Tensor,
12};
13
/// Scaling constant (2^15) used when converting `i16` PCM samples to `f32`.
pub const PCM_SCALE: f32 = 32_768.0;
/// Sample rate, in Hz, that input audio must have; other rates are rejected.
pub const EXPECTED_SAMPLE_RATE: u32 = 16_000;
/// Blank token id that `Citrinet::from_files` installs by default.
pub const DEFAULT_BLANK_ID: usize = 1_024;
/// Number of mel bins configured for filterbank feature extraction.
pub const FBANK_BINS: usize = 80;
18
/// Errors produced while loading the model, reading audio, or running inference.
#[derive(Debug)]
pub enum CitrinetError {
    /// An I/O operation failed (for example, opening or reading the token file).
    Io(std::io::Error),
    /// The WAV file could not be opened or decoded by `hound`.
    Audio(hound::Error),
    /// ONNX Runtime reported an error.
    Ort(ort::Error),
    /// The audio sample rate differs from the required one.
    InvalidSampleRate { expected: u32, got: u32 },
    /// The audio is not mono; carries the channel count that was found.
    InvalidChannels(u16),
    /// The audio contained no samples.
    EmptyAudio,
    /// Filterbank feature extraction failed; carries a description.
    Feature(String),
    /// The model output was missing or had an unexpected type, shape, or size.
    ModelOutput(String),
    /// The token table could not be built; carries a description.
    Tokens(String),
}
31
32impl std::fmt::Display for CitrinetError {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 match self {
35 Self::Io(err) => write!(f, "I/O error: {}", err),
36 Self::Audio(err) => write!(f, "WAV error: {}", err),
37 Self::Ort(err) => write!(f, "ONNX Runtime error: {}", err),
38 Self::InvalidSampleRate { expected, got } => {
39 write!(f, "expected {} Hz audio, got {} Hz", expected, got)
40 }
41 Self::InvalidChannels(ch) => write!(f, "expected mono audio, got {} channels", ch),
42 Self::EmptyAudio => write!(f, "no samples found in audio"),
43 Self::Feature(msg) => write!(f, "feature extraction failed: {}", msg),
44 Self::ModelOutput(msg) => write!(f, "model output error: {}", msg),
45 Self::Tokens(msg) => write!(f, "token table error: {}", msg),
46 }
47 }
48}
49
/// Marks `CitrinetError` as a standard error type. Only the trait's default methods are
/// used, so no `source()` is exposed here.
impl std::error::Error for CitrinetError {}
51
52impl From<std::io::Error> for CitrinetError {
53 fn from(value: std::io::Error) -> Self {
54 Self::Io(value)
55 }
56}
57
58impl From<hound::Error> for CitrinetError {
59 fn from(value: hound::Error) -> Self {
60 Self::Audio(value)
61 }
62}
63
64impl From<ort::Error> for CitrinetError {
65 fn from(value: ort::Error) -> Self {
66 Self::Ort(value)
67 }
68}
69
/// A loaded model together with the token table used to turn token ids into text.
pub struct Citrinet {
    /// ONNX Runtime session created from the model file.
    session: Session,
    /// Token symbols indexed by token id.
    tokens: Vec<String>,
    /// Id of the blank token, which decoding removes from the output.
    blank_id: usize,
}
75
/// Outcome of a single inference call.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, stored, and compared
/// in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct CitrinetResult {
    /// Decoded text: the symbols of `token_ids` concatenated in order.
    pub text: String,
    /// Token ids after greedy decoding (repeats collapsed, blank removed).
    pub token_ids: Vec<usize>,
    /// Flattened model output, laid out according to `log_prob_shape`.
    pub log_probs: Vec<f32>,
    /// Shape of `log_probs` as `[batch, time, vocab]`.
    pub log_prob_shape: [usize; 3],
}
82
83impl Citrinet {
84 pub fn from_files(
85 model_path: impl AsRef<Path>,
86 tokens_path: impl AsRef<Path>,
87 ) -> Result<Self, CitrinetError> {
88 let session = Session::builder()?
89 .with_optimization_level(GraphOptimizationLevel::Level3)?
90 .commit_from_file(model_path)?;
91 let tokens = load_tokens(tokens_path)?;
92 Ok(Self {
93 session,
94 tokens,
95 blank_id: DEFAULT_BLANK_ID,
96 })
97 }
98
99 pub fn with_blank_id(mut self, blank_id: usize) -> Self {
100 self.blank_id = blank_id;
101 self
102 }
103
104 pub fn infer_file(
105 &mut self,
106 wav_path: impl AsRef<Path>,
107 ) -> Result<CitrinetResult, CitrinetError> {
108 let (samples, sample_rate) = read_wav(wav_path)?;
109 self.infer_samples(&samples, sample_rate)
110 }
111
112 pub fn infer_samples(
113 &mut self,
114 samples: &[i16],
115 sample_rate: u32,
116 ) -> Result<CitrinetResult, CitrinetError> {
117 if sample_rate != EXPECTED_SAMPLE_RATE {
118 return Err(CitrinetError::InvalidSampleRate {
119 expected: EXPECTED_SAMPLE_RATE,
120 got: sample_rate,
121 });
122 }
123
124 let scaled: Vec<f32> = samples.iter().map(|s| *s as f32 * PCM_SCALE).collect();
125 let (features, frames, dim) = compute_fbank(&scaled, sample_rate)?;
126 let nct = to_nct(&features, frames, dim);
127 let (log_probs, shape) = self.run_model(&nct, frames, dim)?;
128 let time = shape[1];
129 let vocab = shape[2];
130 let token_ids = greedy_decode(&log_probs, time, vocab, self.blank_id);
131 let text = token_ids
132 .iter()
133 .filter(|&&id| id != self.blank_id)
134 .filter_map(|&id| self.tokens.get(id))
135 .fold(String::new(), |mut acc, sym| {
136 acc.push_str(sym);
137 acc
138 });
139
140 Ok(CitrinetResult {
141 text,
142 token_ids,
143 log_probs,
144 log_prob_shape: shape,
145 })
146 }
147
148 fn run_model(
149 &mut self,
150 nct_features: &[f32],
151 frames: usize,
152 dim: usize,
153 ) -> Result<(Vec<f32>, [usize; 3]), CitrinetError> {
154 let feature_tensor = Tensor::from_array(([1usize, dim, frames], nct_features.to_vec()))?;
155 let length_tensor = Tensor::from_array(([1usize], vec![frames as i64]))?;
156 let outputs = self
157 .session
158 .run(ort::inputs![feature_tensor, length_tensor])?;
159 if outputs.len() == 0 {
160 return Err(CitrinetError::ModelOutput(
161 "model produced no outputs".to_string(),
162 ));
163 }
164 let output = &outputs[0];
165 let (shape, data) = output
166 .try_extract_tensor::<f32>()
167 .map_err(|e| CitrinetError::ModelOutput(e.to_string()))?;
168 if shape.len() != 3 {
169 return Err(CitrinetError::ModelOutput(format!(
170 "expected 3D logits, got shape {:?}",
171 shape
172 )));
173 }
174 let dims: Vec<usize> = shape.iter().map(|d| *d as usize).collect();
175 let [batch, time, vocab]: [usize; 3] = dims
176 .clone()
177 .try_into()
178 .map_err(|_| CitrinetError::ModelOutput("invalid logits rank".to_string()))?;
179 if batch != 1 {
180 return Err(CitrinetError::ModelOutput(format!(
181 "only batch size 1 supported, got {}",
182 batch
183 )));
184 }
185 if time == 0 || vocab == 0 {
186 return Err(CitrinetError::ModelOutput(
187 "empty logits returned by model".to_string(),
188 ));
189 }
190 let expected = time
191 .checked_mul(vocab)
192 .and_then(|v| v.checked_mul(batch))
193 .ok_or_else(|| CitrinetError::ModelOutput("logit shape overflow".to_string()))?;
194 if data.len() != expected {
195 return Err(CitrinetError::ModelOutput(format!(
196 "logit data length {} does not match shape {:?}",
197 data.len(),
198 dims
199 )));
200 }
201 Ok((data.to_vec(), [batch, time, vocab]))
202 }
203}
204
205fn load_tokens(path: impl AsRef<Path>) -> Result<Vec<String>, CitrinetError> {
206 let file = File::open(path).map_err(CitrinetError::Io)?;
207 let reader = BufReader::new(file);
208 let mut table = Vec::new();
209 for line in reader.lines() {
210 let line = line?;
211 let mut parts = line.split_whitespace();
212 let Some(symbol) = parts.next() else { continue };
213 let Some(idx) = parts.next().and_then(|p| p.parse::<usize>().ok()) else {
214 continue;
215 };
216 if table.len() <= idx {
217 table.resize(idx + 1, String::new());
218 }
219 table[idx] = symbol.to_string();
220 }
221 if table.is_empty() {
222 return Err(CitrinetError::Tokens(
223 "token file contained no entries".to_string(),
224 ));
225 }
226 Ok(table)
227}
228
229fn read_wav(path: impl AsRef<Path>) -> Result<(Vec<i16>, u32), CitrinetError> {
230 let mut reader = WavReader::open(path)?;
231 let spec = reader.spec();
232 if spec.channels != 1 {
233 return Err(CitrinetError::InvalidChannels(spec.channels));
234 }
235 if spec.sample_rate != EXPECTED_SAMPLE_RATE {
236 return Err(CitrinetError::InvalidSampleRate {
237 expected: EXPECTED_SAMPLE_RATE,
238 got: spec.sample_rate,
239 });
240 }
241 let samples: Vec<i16> = match spec.sample_format {
242 SampleFormat::Int => reader.samples::<i16>().collect::<Result<_, _>>()?,
243 SampleFormat::Float => {
244 let raw: Vec<f32> = reader.samples::<f32>().collect::<Result<_, _>>()?;
245 raw.into_iter()
246 .map(|s| (s * i16::MAX as f32) as i16)
247 .collect()
248 }
249 };
250 if samples.is_empty() {
251 return Err(CitrinetError::EmptyAudio);
252 }
253 Ok((samples, spec.sample_rate))
254}
255
256fn compute_fbank(
257 samples: &[f32],
258 sample_rate: u32,
259) -> Result<(Vec<f32>, usize, usize), CitrinetError> {
260 if samples.is_empty() {
261 return Err(CitrinetError::EmptyAudio);
262 }
263
264 let mut opts = FbankOptions::default();
265 opts.frame_opts.dither = 0.0;
266 opts.frame_opts.snip_edges = false;
267 opts.frame_opts.samp_freq = sample_rate as f32;
268 opts.mel_opts.num_bins = FBANK_BINS;
269 opts.use_energy = false;
270 opts.raw_energy = false;
271
272 let computer =
273 FeatureComputer::Fbank(FbankComputer::new(opts).map_err(CitrinetError::Feature)?);
274 let mut online = OnlineFeature::new(computer);
275 online.accept_waveform(sample_rate as f32, samples);
276 online.input_finished();
277
278 let frames = online.num_frames_ready();
279 let Some(dim) = online.features.first().map(|f| f.len()) else {
280 return Err(CitrinetError::Feature(
281 "no feature frames produced".to_string(),
282 ));
283 };
284 let mut matrix = Vec::with_capacity(frames * dim);
285 for frame in &online.features {
286 matrix.extend_from_slice(frame);
287 }
288 normalize_features(&mut matrix, frames, dim);
289 Ok((matrix, frames, dim))
290}
291
/// Normalizes each feature column in place to zero mean and (approximately) unit variance.
///
/// `features` holds `frames` rows of `dim` values each. The standard deviation is the
/// population form (divides by `frames`), and a small epsilon is added to it before
/// dividing so that constant columns do not divide by zero.
fn normalize_features(features: &mut [f32], frames: usize, dim: usize) {
    const EPS: f32 = 1e-5;
    let count = frames as f32;

    for col in 0..dim {
        // Flat indices of this column, visited from the first frame to the last.
        let column = || (0..frames).map(|row| row * dim + col);

        let mean = column().fold(0.0_f32, |acc, i| acc + features[i]) / count;
        let sum_sq = column().fold(0.0_f32, |acc, i| {
            let delta = features[i] - mean;
            acc + delta * delta
        });
        let inv_std = 1.0 / ((sum_sq / count).sqrt() + EPS);

        for i in column() {
            features[i] = (features[i] - mean) * inv_std;
        }
    }
}
315
/// Transposes a frame-by-frame feature buffer (`frames` rows of `dim` values) into a
/// feature-by-feature buffer (`dim` rows of `frames` values).
fn to_nct(features: &[f32], frames: usize, dim: usize) -> Vec<f32> {
    (0..dim)
        .flat_map(|channel| (0..frames).map(move |step| features[step * dim + channel]))
        .collect()
}
325
/// Greedy decoding: picks the best-scoring id per frame, collapses consecutive repeats,
/// and removes `blank_id`.
///
/// `log_probs` holds `time` rows of `vocab` scores each. The function never panics:
///
/// * NaN scores are ignored when choosing the best id; a frame whose scores are all NaN
///   is treated as a blank frame.
/// * If `vocab` is zero the result is empty.
/// * If `log_probs` holds fewer than `time` complete rows, only the complete rows are
///   decoded.
///
/// When several ids share the highest score in a frame, the last one wins.
fn greedy_decode(log_probs: &[f32], time: usize, vocab: usize, blank_id: usize) -> Vec<usize> {
    // `chunks_exact` panics on a chunk size of zero, and there is nothing to decode anyway.
    if vocab == 0 {
        return Vec::new();
    }

    let mut decoded = Vec::new();
    let mut prev: Option<usize> = None;
    for row in log_probs.chunks_exact(vocab).take(time) {
        let best = row
            .iter()
            .enumerate()
            .filter(|(_, score)| !score.is_nan())
            // NaN has been filtered out, so `partial_cmp` always yields an ordering here.
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(blank_id, |(idx, _)| idx);

        // Repeats are collapsed before blanks are dropped, so a blank between two equal
        // ids keeps both of them in the output.
        if prev != Some(best) {
            prev = Some(best);
            if best != blank_id {
                decoded.push(best);
            }
        }
    }
    decoded
}