shared_model/
shared_model.rs1use parakeet_rs::{
21 Nemotron, NemotronHandle, ParakeetEOU, ParakeetEOUHandle, ParakeetUnified,
22 ParakeetUnifiedHandle,
23};
24
25fn load_wav(path: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
26 let mut reader = hound::WavReader::open(path)?;
27 let spec = reader.spec();
28 let mut audio: Vec<f32> = match spec.sample_format {
29 hound::SampleFormat::Float => reader.samples::<f32>().collect::<Result<_, _>>()?,
30 hound::SampleFormat::Int => reader
31 .samples::<i16>()
32 .map(|s| s.map(|v| v as f32 / 32768.0))
33 .collect::<Result<_, _>>()?,
34 };
35 if spec.channels > 1 {
36 audio = audio
37 .chunks(spec.channels as usize)
38 .map(|c| c.iter().sum::<f32>() / spec.channels as f32)
39 .collect();
40 }
41 Ok(audio)
42}
43
44fn run_nemotron(model_dir: &str, audio: &[f32]) -> Result<(), Box<dyn std::error::Error>> {
45 let handle = NemotronHandle::load(model_dir, None)?;
46 let mut a = Nemotron::from_shared(&handle);
47 let mut b = Nemotron::from_shared(&handle);
48
49 let chunk_size = 8960; for chunk_data in audio.chunks(chunk_size) {
51 let mut chunk = chunk_data.to_vec();
52 chunk.resize(chunk_size, 0.0);
53 a.transcribe_chunk(&chunk)?;
54 b.transcribe_chunk(&chunk)?;
55 }
56
57 println!("A: {}", a.get_transcript());
58 println!("B: {}", b.get_transcript());
59 assert_eq!(
60 a.get_transcript(),
61 b.get_transcript(),
62 "shared model must be deterministic"
63 );
64 println!("same");
65 Ok(())
66}
67
68fn run_eou(model_dir: &str, audio: &[f32]) -> Result<(), Box<dyn std::error::Error>> {
69 let handle = ParakeetEOUHandle::load(model_dir, None)?;
70 let mut a = ParakeetEOU::from_shared(&handle);
71 let mut b = ParakeetEOU::from_shared(&handle);
72
73 let chunk_size = 2560;
74 let mut a_text = String::new();
75 let mut b_text = String::new();
76 for chunk_data in audio.chunks(chunk_size) {
77 let chunk: Vec<f32> = if chunk_data.len() < chunk_size {
78 let mut p = chunk_data.to_vec();
79 p.resize(chunk_size, 0.0);
80 p
81 } else {
82 chunk_data.to_vec()
83 };
84 a_text.push_str(&a.transcribe(&chunk, false)?);
85 b_text.push_str(&b.transcribe(&chunk, false)?);
86 }
87
88 println!("A: {}", a_text.trim());
89 println!("B: {}", b_text.trim());
90 assert_eq!(a_text, b_text, "shared model must be deterministic");
91 println!("same");
92 Ok(())
93}
94
95fn run_unified(model_dir: &str, audio: &[f32]) -> Result<(), Box<dyn std::error::Error>> {
96 let handle = ParakeetUnifiedHandle::load(model_dir, None)?;
97 let mut a = ParakeetUnified::from_shared(&handle);
98 let mut b = ParakeetUnified::from_shared(&handle);
99
100 let chunk_size = a.streaming_config().chunk_samples();
101 for chunk_data in audio.chunks(chunk_size) {
102 a.transcribe_chunk(chunk_data)?;
103 b.transcribe_chunk(chunk_data)?;
104 }
105 a.flush()?;
106 b.flush()?;
107
108 println!("A: {}", a.get_transcript());
109 println!("B: {}", b.get_transcript());
110 assert_eq!(
111 a.get_transcript(),
112 b.get_transcript(),
113 "shared model must be deterministic"
114 );
115 println!("same");
116 Ok(())
117}
118
119fn main() -> Result<(), Box<dyn std::error::Error>> {
120 let args: Vec<String> = std::env::args().collect();
121 if args.len() < 3 {
122 eprintln!("Usage: shared_model <model_dir> <audio.wav> [eou|unified]");
123 std::process::exit(1);
124 }
125 let model_dir = &args[1];
126 let audio_path = &args[2];
127 let variant = args.get(3).map(String::as_str).unwrap_or("nemotron");
128
129 let audio = load_wav(audio_path)?;
130
131 match variant {
132 "eou" => run_eou(model_dir, &audio),
133 "unified" => run_unified(model_dir, &audio),
134 _ => run_nemotron(model_dir, &audio),
135 }
136}