Skip to main content

unified/
unified.rs

1/*
2Unified model: offline + buffered streaming transcription
3
4Offline:
5cargo run --release --example unified 6_speakers.wav
6
7Streaming:
8cargo run --release --example unified 6_speakers.wav streaming
9
10---
11
12Download model from: https://huggingface.co/bobNight/parakeet-unified-en-0.6b-onnx
13Files: encoder.onnx, encoder.onnx.data, decoder_joint.onnx, tokenizer.model
14Place in: ./unified/
15*/
16
17use parakeet_rs::{ParakeetUnified, TimestampMode, Transcriber};
18use std::env;
19use std::io::Write;
20use std::time::Instant;
21
22fn main() -> Result<(), Box<dyn std::error::Error>> {
23    let start_time = Instant::now();
24    let args: Vec<String> = env::args().collect();
25
26    let audio_path = if args.len() > 1 {
27        &args[1]
28    } else {
29        "6_speakers.wav"
30    };
31
32    let use_streaming = args.len() > 2 && args[2] == "streaming";
33
34    // Load audio
35    let mut reader = hound::WavReader::open(audio_path)?;
36    let spec = reader.spec();
37
38    if spec.sample_rate != 16000 {
39        return Err(format!("Expected 16kHz, got {}Hz", spec.sample_rate).into());
40    }
41
42    let mut audio: Vec<f32> = match spec.sample_format {
43        hound::SampleFormat::Float => reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?,
44        hound::SampleFormat::Int => reader
45            .samples::<i16>()
46            .map(|s| s.map(|s| s as f32 / 32768.0))
47            .collect::<Result<Vec<_>, _>>()?,
48    };
49
50    if spec.channels > 1 {
51        audio = audio
52            .chunks(spec.channels as usize)
53            .map(|c| c.iter().sum::<f32>() / spec.channels as f32)
54            .collect();
55    }
56
57    let duration = audio.len() as f32 / 16000.0;
58    println!("Audio: {:.1}s, {}Hz, {} ch", duration, spec.sample_rate, spec.channels);
59
60    let mut model = ParakeetUnified::from_pretrained("./unified", None)?;
61    let load_time = start_time.elapsed();
62    println!("Model loaded in {:.2}s", load_time.as_secs_f32());
63
64    if use_streaming {
65        let config = model.streaming_config();
66        let chunk_size = config.chunk_samples();
67        println!("Streaming mode: {:.0}ms chunks", config.chunk_secs * 1000.0);
68
69        let transcribe_start = Instant::now();
70        print!("Streaming: ");
71
72        for chunk in audio.chunks(chunk_size) {
73            let text = model.transcribe_chunk(chunk)?;
74            if !text.is_empty() {
75                print!("{}", text);
76                std::io::stdout().flush()?;
77            }
78        }
79
80        let remaining = model.flush()?;
81        if !remaining.is_empty() {
82            print!("{}", remaining);
83        }
84
85        let result = model.get_timed_transcript(TimestampMode::Sentences);
86        println!("\n\nFinal: {}", result.text);
87
88        println!("\nSentences:");
89        for segment in &result.tokens {
90            println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text);
91        }
92
93        let elapsed = transcribe_start.elapsed();
94        println!(
95            "Transcribed in {:.2}s (audio: {:.1}s, RTF: {:.2}x)",
96            elapsed.as_secs_f32(),
97            duration,
98            duration / elapsed.as_secs_f32()
99        );
100    } else {
101        println!("Offline mode (with word timestamps)");
102
103        let transcribe_start = Instant::now();
104        let result = model.transcribe_samples(
105            audio,
106            spec.sample_rate,
107            spec.channels,
108            Some(TimestampMode::Words),
109        )?;
110        let elapsed = transcribe_start.elapsed();
111
112        println!("Result: {}", result.text);
113
114        println!("\nWords (first 20):");
115        for word in result.tokens.iter().take(20) {
116            println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text);
117        }
118
119        println!(
120            "Transcribed in {:.2}s (audio: {:.1}s, RTF: {:.2}x)",
121            elapsed.as_secs_f32(),
122            duration,
123            duration / elapsed.as_secs_f32()
124        );
125    }
126
127    Ok(())
128}