citrinet_rs/
lib.rs

1use std::{
2    fs::File,
3    io::{BufRead, BufReader},
4    path::Path,
5};
6
7use hound::{SampleFormat, WavReader};
8use kaldi_native_fbank::{FbankComputer, FbankOptions, OnlineFeature, online::FeatureComputer};
9use ort::{
10    session::{Session, builder::GraphOptimizationLevel},
11    value::Tensor,
12};
13
/// Factor of 2^15 (the magnitude of `i16::MIN`) between unit-range floating-point audio
/// and the 16-bit PCM integer range.
pub const PCM_SCALE: f32 = 32768.0;
/// The only audio sample rate, in Hz, accepted for inference; any other rate is rejected.
pub const EXPECTED_SAMPLE_RATE: u32 = 16000;
/// Blank token id a freshly loaded model starts with (overridable via `with_blank_id`).
pub const DEFAULT_BLANK_ID: usize = 1024;
/// Number of mel filterbank bins requested from the feature extractor.
pub const FBANK_BINS: usize = 80;
18
19#[derive(Debug)]
20pub enum CitrinetError {
21    Io(std::io::Error),
22    Audio(hound::Error),
23    Ort(ort::Error),
24    InvalidSampleRate { expected: u32, got: u32 },
25    InvalidChannels(u16),
26    EmptyAudio,
27    Feature(String),
28    ModelOutput(String),
29    Tokens(String),
30}
31
32impl std::fmt::Display for CitrinetError {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        match self {
35            Self::Io(err) => write!(f, "I/O error: {}", err),
36            Self::Audio(err) => write!(f, "WAV error: {}", err),
37            Self::Ort(err) => write!(f, "ONNX Runtime error: {}", err),
38            Self::InvalidSampleRate { expected, got } => {
39                write!(f, "expected {} Hz audio, got {} Hz", expected, got)
40            }
41            Self::InvalidChannels(ch) => write!(f, "expected mono audio, got {} channels", ch),
42            Self::EmptyAudio => write!(f, "no samples found in audio"),
43            Self::Feature(msg) => write!(f, "feature extraction failed: {}", msg),
44            Self::ModelOutput(msg) => write!(f, "model output error: {}", msg),
45            Self::Tokens(msg) => write!(f, "token table error: {}", msg),
46        }
47    }
48}
49
50impl std::error::Error for CitrinetError {}
51
52impl From<std::io::Error> for CitrinetError {
53    fn from(value: std::io::Error) -> Self {
54        Self::Io(value)
55    }
56}
57
58impl From<hound::Error> for CitrinetError {
59    fn from(value: hound::Error) -> Self {
60        Self::Audio(value)
61    }
62}
63
64impl From<ort::Error> for CitrinetError {
65    fn from(value: ort::Error) -> Self {
66        Self::Ort(value)
67    }
68}
69
70pub struct Citrinet {
71    session: Session,
72    tokens: Vec<String>,
73    blank_id: usize,
74}
75
76pub struct CitrinetResult {
77    pub text: String,
78    pub token_ids: Vec<usize>,
79    pub log_probs: Vec<f32>,
80    pub log_prob_shape: [usize; 3],
81}
82
83impl Citrinet {
84    pub fn from_files(
85        model_path: impl AsRef<Path>,
86        tokens_path: impl AsRef<Path>,
87    ) -> Result<Self, CitrinetError> {
88        let session = Session::builder()?
89            .with_optimization_level(GraphOptimizationLevel::Level3)?
90            .commit_from_file(model_path)?;
91        let tokens = load_tokens(tokens_path)?;
92        Ok(Self {
93            session,
94            tokens,
95            blank_id: DEFAULT_BLANK_ID,
96        })
97    }
98
99    pub fn with_blank_id(mut self, blank_id: usize) -> Self {
100        self.blank_id = blank_id;
101        self
102    }
103
104    pub fn infer_file(
105        &mut self,
106        wav_path: impl AsRef<Path>,
107    ) -> Result<CitrinetResult, CitrinetError> {
108        let (samples, sample_rate) = read_wav(wav_path)?;
109        self.infer_samples(&samples, sample_rate)
110    }
111
112    pub fn infer_samples(
113        &mut self,
114        samples: &[i16],
115        sample_rate: u32,
116    ) -> Result<CitrinetResult, CitrinetError> {
117        if sample_rate != EXPECTED_SAMPLE_RATE {
118            return Err(CitrinetError::InvalidSampleRate {
119                expected: EXPECTED_SAMPLE_RATE,
120                got: sample_rate,
121            });
122        }
123
124        let scaled: Vec<f32> = samples.iter().map(|s| *s as f32 * PCM_SCALE).collect();
125        let (features, frames, dim) = compute_fbank(&scaled, sample_rate)?;
126        let nct = to_nct(&features, frames, dim);
127        let (log_probs, shape) = self.run_model(&nct, frames, dim)?;
128        let time = shape[1];
129        let vocab = shape[2];
130        let token_ids = greedy_decode(&log_probs, time, vocab, self.blank_id);
131        let text = token_ids
132            .iter()
133            .filter(|&&id| id != self.blank_id)
134            .filter_map(|&id| self.tokens.get(id))
135            .fold(String::new(), |mut acc, sym| {
136                acc.push_str(sym);
137                acc
138            });
139
140        Ok(CitrinetResult {
141            text,
142            token_ids,
143            log_probs,
144            log_prob_shape: shape,
145        })
146    }
147
148    fn run_model(
149        &mut self,
150        nct_features: &[f32],
151        frames: usize,
152        dim: usize,
153    ) -> Result<(Vec<f32>, [usize; 3]), CitrinetError> {
154        let feature_tensor = Tensor::from_array(([1usize, dim, frames], nct_features.to_vec()))?;
155        let length_tensor = Tensor::from_array(([1usize], vec![frames as i64]))?;
156        let outputs = self
157            .session
158            .run(ort::inputs![feature_tensor, length_tensor])?;
159        if outputs.len() == 0 {
160            return Err(CitrinetError::ModelOutput(
161                "model produced no outputs".to_string(),
162            ));
163        }
164        let output = &outputs[0];
165        let (shape, data) = output
166            .try_extract_tensor::<f32>()
167            .map_err(|e| CitrinetError::ModelOutput(e.to_string()))?;
168        if shape.len() != 3 {
169            return Err(CitrinetError::ModelOutput(format!(
170                "expected 3D logits, got shape {:?}",
171                shape
172            )));
173        }
174        let dims: Vec<usize> = shape.iter().map(|d| *d as usize).collect();
175        let [batch, time, vocab]: [usize; 3] = dims
176            .clone()
177            .try_into()
178            .map_err(|_| CitrinetError::ModelOutput("invalid logits rank".to_string()))?;
179        if batch != 1 {
180            return Err(CitrinetError::ModelOutput(format!(
181                "only batch size 1 supported, got {}",
182                batch
183            )));
184        }
185        if time == 0 || vocab == 0 {
186            return Err(CitrinetError::ModelOutput(
187                "empty logits returned by model".to_string(),
188            ));
189        }
190        let expected = time
191            .checked_mul(vocab)
192            .and_then(|v| v.checked_mul(batch))
193            .ok_or_else(|| CitrinetError::ModelOutput("logit shape overflow".to_string()))?;
194        if data.len() != expected {
195            return Err(CitrinetError::ModelOutput(format!(
196                "logit data length {} does not match shape {:?}",
197                data.len(),
198                dims
199            )));
200        }
201        Ok((data.to_vec(), [batch, time, vocab]))
202    }
203}
204
205fn load_tokens(path: impl AsRef<Path>) -> Result<Vec<String>, CitrinetError> {
206    let file = File::open(path).map_err(CitrinetError::Io)?;
207    let reader = BufReader::new(file);
208    let mut table = Vec::new();
209    for line in reader.lines() {
210        let line = line?;
211        let mut parts = line.split_whitespace();
212        let Some(symbol) = parts.next() else { continue };
213        let Some(idx) = parts.next().and_then(|p| p.parse::<usize>().ok()) else {
214            continue;
215        };
216        if table.len() <= idx {
217            table.resize(idx + 1, String::new());
218        }
219        table[idx] = symbol.to_string();
220    }
221    if table.is_empty() {
222        return Err(CitrinetError::Tokens(
223            "token file contained no entries".to_string(),
224        ));
225    }
226    Ok(table)
227}
228
229fn read_wav(path: impl AsRef<Path>) -> Result<(Vec<i16>, u32), CitrinetError> {
230    let mut reader = WavReader::open(path)?;
231    let spec = reader.spec();
232    if spec.channels != 1 {
233        return Err(CitrinetError::InvalidChannels(spec.channels));
234    }
235    if spec.sample_rate != EXPECTED_SAMPLE_RATE {
236        return Err(CitrinetError::InvalidSampleRate {
237            expected: EXPECTED_SAMPLE_RATE,
238            got: spec.sample_rate,
239        });
240    }
241    let samples: Vec<i16> = match spec.sample_format {
242        SampleFormat::Int => reader.samples::<i16>().collect::<Result<_, _>>()?,
243        SampleFormat::Float => {
244            let raw: Vec<f32> = reader.samples::<f32>().collect::<Result<_, _>>()?;
245            raw.into_iter()
246                .map(|s| (s * i16::MAX as f32) as i16)
247                .collect()
248        }
249    };
250    if samples.is_empty() {
251        return Err(CitrinetError::EmptyAudio);
252    }
253    Ok((samples, spec.sample_rate))
254}
255
256fn compute_fbank(
257    samples: &[f32],
258    sample_rate: u32,
259) -> Result<(Vec<f32>, usize, usize), CitrinetError> {
260    if samples.is_empty() {
261        return Err(CitrinetError::EmptyAudio);
262    }
263
264    let mut opts = FbankOptions::default();
265    opts.frame_opts.dither = 0.0;
266    opts.frame_opts.snip_edges = false;
267    opts.frame_opts.samp_freq = sample_rate as f32;
268    opts.mel_opts.num_bins = FBANK_BINS;
269    opts.use_energy = false;
270    opts.raw_energy = false;
271
272    let computer =
273        FeatureComputer::Fbank(FbankComputer::new(opts).map_err(CitrinetError::Feature)?);
274    let mut online = OnlineFeature::new(computer);
275    online.accept_waveform(sample_rate as f32, samples);
276    online.input_finished();
277
278    let frames = online.num_frames_ready();
279    let Some(dim) = online.features.first().map(|f| f.len()) else {
280        return Err(CitrinetError::Feature(
281            "no feature frames produced".to_string(),
282        ));
283    };
284    let mut matrix = Vec::with_capacity(frames * dim);
285    for frame in &online.features {
286        matrix.extend_from_slice(frame);
287    }
288    normalize_features(&mut matrix, frames, dim);
289    Ok((matrix, frames, dim))
290}
291
/// Normalizes each column of a row-major `(frames, dim)` matrix in place to zero mean and
/// unit variance.
///
/// Uses the population standard deviation and divides by `std + 1e-5`, so constant columns
/// become all zeros rather than NaN. Statistics are accumulated in `f64` to limit rounding
/// error on long inputs, and rows are walked contiguously. Values past `frames * dim` are
/// left untouched; the call is a no-op when `frames` or `dim` is zero.
///
/// # Panics
///
/// Panics if `features` holds fewer than `frames * dim` values.
fn normalize_features(features: &mut [f32], frames: usize, dim: usize) {
    const EPS: f64 = 1e-5;
    if frames == 0 || dim == 0 {
        return;
    }
    let features = &mut features[..frames * dim];
    let count = frames as f64;

    // Pass 1: per-column means.
    let mut mean = vec![0.0f64; dim];
    for row in features.chunks_exact(dim) {
        for (acc, &v) in mean.iter_mut().zip(row) {
            *acc += f64::from(v);
        }
    }
    for m in &mut mean {
        *m /= count;
    }

    // Pass 2: per-column sums of squared deviations -> reciprocal (std + EPS).
    let mut sq_dev = vec![0.0f64; dim];
    for row in features.chunks_exact(dim) {
        for ((acc, &m), &v) in sq_dev.iter_mut().zip(&mean).zip(row) {
            let d = f64::from(v) - m;
            *acc += d * d;
        }
    }
    let inv_std: Vec<f64> = sq_dev
        .iter()
        .map(|&s| 1.0 / ((s / count).sqrt() + EPS))
        .collect();

    // Pass 3: apply the normalization.
    for row in features.chunks_exact_mut(dim) {
        for ((v, &m), &k) in row.iter_mut().zip(&mean).zip(&inv_std) {
            *v = ((f64::from(*v) - m) * k) as f32;
        }
    }
}
315
/// Transposes a row-major `(frames, dim)` feature matrix into `(dim, frames)` order, i.e.
/// the channel-major layout the model input is built from.
///
/// # Panics
///
/// Panics if `features` holds fewer than `frames * dim` values.
fn to_nct(features: &[f32], frames: usize, dim: usize) -> Vec<f32> {
    (0..dim)
        .flat_map(move |channel| (0..frames).map(move |frame| features[frame * dim + channel]))
        .collect()
}
325
/// Greedy CTC decoding of row-major `[time, vocab]` scores.
///
/// Takes the arg-max of each of the first `time` rows (ties resolve to the highest index,
/// per `Iterator::max_by`), collapses consecutive repeats, then drops `blank_id`.
/// NaN scores are ordered with `f32::total_cmp` instead of panicking. Returns an empty
/// vector when `vocab` is zero; if `log_probs` holds fewer than `time` full rows, only the
/// available rows are decoded.
fn greedy_decode(log_probs: &[f32], time: usize, vocab: usize, blank_id: usize) -> Vec<usize> {
    if vocab == 0 {
        return Vec::new();
    }

    let mut ids: Vec<usize> = log_probs
        .chunks_exact(vocab)
        .take(time)
        .filter_map(|row| {
            row.iter()
                .enumerate()
                .max_by(|a, b| a.1.total_cmp(b.1))
                .map(|(idx, _)| idx)
        })
        .collect();

    // CTC: merge repeated symbols first, then remove blanks (order matters: a blank
    // between two identical symbols keeps them distinct).
    ids.dedup();
    ids.retain(|&id| id != blank_id);
    ids
}