Skip to main content

shared_model/
shared_model.rs

/*
Shared-model demo — load one ONNX session and drive two concurrent streams
from it, each with independent decoder state, for the streaming models:
Nemotron (default):
cargo run --release --example shared_model ./nemotron audio.wav

EOU:
cargo run --release --example shared_model ./fullstr audio.wav eou

Unified:
cargo run --release --example shared_model ./unified audio.wav unified

---

Nemotron (600M): https://huggingface.co/altunenes/parakeet-rs/tree/main/nemotron-speech-streaming-en-0.6b
EOU (120M): https://huggingface.co/altunenes/parakeet-rs/tree/main/realtime_eou_120m-v1-onnx
Unified: https://huggingface.co/bobNight/parakeet-unified-en-0.6b-onnx/tree/main
*/
19
20use parakeet_rs::{
21    Nemotron, NemotronHandle, ParakeetEOU, ParakeetEOUHandle, ParakeetUnified,
22    ParakeetUnifiedHandle,
23};
24
25fn load_wav(path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
26    let mut reader = hound::WavReader::open(path)?;
27    let spec = reader.spec();
28    let mut audio: Vec<f32> = match spec.sample_format {
29        hound::SampleFormat::Float => reader.samples::<f32>().collect::<Result<_, _>>()?,
30        hound::SampleFormat::Int => reader
31            .samples::<i16>()
32            .map(|s| s.map(|v| v as f32 / 32768.0))
33            .collect::<Result<_, _>>()?,
34    };
35    if spec.channels > 1 {
36        audio = audio
37            .chunks(spec.channels as usize)
38            .map(|c| c.iter().sum::<f32>() / spec.channels as f32)
39            .collect();
40    }
41    Ok(audio)
42}
43
44fn run_nemotron(model_dir: &str, audio: &[f32]) -> Result<(), Box<dyn std::error::Error>> {
45    let handle = NemotronHandle::load(model_dir, None)?;
46    let mut a = Nemotron::from_shared(&handle);
47    let mut b = Nemotron::from_shared(&handle);
48
49    let chunk_size = 8960; // 560 ms at 16 kHz
50    for chunk_data in audio.chunks(chunk_size) {
51        let mut chunk = chunk_data.to_vec();
52        chunk.resize(chunk_size, 0.0);
53        a.transcribe_chunk(&chunk)?;
54        b.transcribe_chunk(&chunk)?;
55    }
56
57    println!("A: {}", a.get_transcript());
58    println!("B: {}", b.get_transcript());
59    assert_eq!(
60        a.get_transcript(),
61        b.get_transcript(),
62        "shared model must be deterministic"
63    );
64    println!("same");
65    Ok(())
66}
67
68fn run_eou(model_dir: &str, audio: &[f32]) -> Result<(), Box<dyn std::error::Error>> {
69    let handle = ParakeetEOUHandle::load(model_dir, None)?;
70    let mut a = ParakeetEOU::from_shared(&handle);
71    let mut b = ParakeetEOU::from_shared(&handle);
72
73    let chunk_size = 2560;
74    let mut a_text = String::new();
75    let mut b_text = String::new();
76    for chunk_data in audio.chunks(chunk_size) {
77        let chunk: Vec<f32> = if chunk_data.len() < chunk_size {
78            let mut p = chunk_data.to_vec();
79            p.resize(chunk_size, 0.0);
80            p
81        } else {
82            chunk_data.to_vec()
83        };
84        a_text.push_str(&a.transcribe(&chunk, false)?);
85        b_text.push_str(&b.transcribe(&chunk, false)?);
86    }
87
88    println!("A: {}", a_text.trim());
89    println!("B: {}", b_text.trim());
90    assert_eq!(a_text, b_text, "shared model must be deterministic");
91    println!("same");
92    Ok(())
93}
94
95fn run_unified(model_dir: &str, audio: &[f32]) -> Result<(), Box<dyn std::error::Error>> {
96    let handle = ParakeetUnifiedHandle::load(model_dir, None)?;
97    let mut a = ParakeetUnified::from_shared(&handle);
98    let mut b = ParakeetUnified::from_shared(&handle);
99
100    let chunk_size = a.streaming_config().chunk_samples();
101    for chunk_data in audio.chunks(chunk_size) {
102        a.transcribe_chunk(chunk_data)?;
103        b.transcribe_chunk(chunk_data)?;
104    }
105    a.flush()?;
106    b.flush()?;
107
108    println!("A: {}", a.get_transcript());
109    println!("B: {}", b.get_transcript());
110    assert_eq!(
111        a.get_transcript(),
112        b.get_transcript(),
113        "shared model must be deterministic"
114    );
115    println!("same");
116    Ok(())
117}
118
119fn main() -> Result<(), Box<dyn std::error::Error>> {
120    let args: Vec<String> = std::env::args().collect();
121    if args.len() < 3 {
122        eprintln!("Usage: shared_model <model_dir> <audio.wav> [eou|unified]");
123        std::process::exit(1);
124    }
125    let model_dir = &args[1];
126    let audio_path = &args[2];
127    let variant = args.get(3).map(String::as_str).unwrap_or("nemotron");
128
129    let audio = load_wav(audio_path)?;
130
131    match variant {
132        "eou" => run_eou(model_dir, &audio),
133        "unified" => run_unified(model_dir, &audio),
134        _ => run_nemotron(model_dir, &audio),
135    }
136}