#[cfg(feature = "async")]
use crate::{
context::WhisperContext, error::Result, params::FullParams,
state::WhisperState, stream::{WhisperStreamConfig, WhisperStream}, Segment, TranscriptionResult,
};
#[cfg(feature = "async")]
use std::sync::Arc;
#[cfg(feature = "async")]
use tokio::sync::{mpsc, oneshot, Mutex};
#[cfg(feature = "async")]
use tokio::task;
#[cfg(feature = "async")]
impl WhisperContext {
    /// Transcribes `audio` to text on Tokio's blocking thread pool.
    ///
    /// A clone of this context is moved into the blocking task, so the synchronous
    /// [`WhisperContext::transcribe`] call does not run on the async executor thread.
    ///
    /// # Errors
    /// Returns whatever error `transcribe` returns. If the blocking task cannot be
    /// joined (it panicked or was cancelled), returns
    /// `WhisperError::TranscriptionError` carrying the join error's message.
    pub async fn transcribe_async(&self, audio: Vec<f32>) -> Result<String> {
        let worker = self.clone();
        let joined = task::spawn_blocking(move || worker.transcribe(&audio)).await;
        match joined {
            Ok(outcome) => outcome,
            Err(join_err) => Err(crate::WhisperError::TranscriptionError(join_err.to_string())),
        }
    }

    /// Runs [`WhisperContext::transcribe_with_params`] on Tokio's blocking thread pool.
    ///
    /// # Errors
    /// Returns whatever error `transcribe_with_params` returns. A failed join of the
    /// blocking task is reported as `WhisperError::TranscriptionError`.
    pub async fn transcribe_with_params_async(
        &self,
        audio: Vec<f32>,
        params: crate::TranscriptionParams,
    ) -> Result<TranscriptionResult> {
        let worker = self.clone();
        let joined =
            task::spawn_blocking(move || worker.transcribe_with_params(&audio, params)).await;
        match joined {
            Ok(outcome) => outcome,
            Err(join_err) => Err(crate::WhisperError::TranscriptionError(join_err.to_string())),
        }
    }

    /// Runs [`WhisperContext::create_state`] on Tokio's blocking thread pool.
    ///
    /// # Errors
    /// Returns whatever error `create_state` returns. A failed join of the blocking
    /// task is reported as `WhisperError::TranscriptionError`.
    pub async fn create_state_async(&self) -> Result<WhisperState> {
        let worker = self.clone();
        let joined = task::spawn_blocking(move || worker.create_state()).await;
        match joined {
            Ok(outcome) => outcome,
            Err(join_err) => Err(crate::WhisperError::TranscriptionError(join_err.to_string())),
        }
    }
}
#[cfg(feature = "async")]
/// Handle to a streaming transcriber whose `WhisperStream` runs inside a Tokio
/// `spawn_blocking` worker.
///
/// Commands (audio, flush, stop) travel to the worker over `audio_tx`; non-empty
/// segment batches produced while processing fed audio come back over `segment_rx`.
pub struct AsyncWhisperStream {
/// Bounded command channel to the worker.
audio_tx: mpsc::Sender<AudioCommand>,
/// Bounded channel of segment batches emitted by the worker.
segment_rx: mpsc::Receiver<Vec<Segment>>,
/// The worker task; resolves to the error that ended it, if any.
handle: task::JoinHandle<Result<()>>,
}
#[cfg(feature = "async")]
/// Messages sent from an [`AsyncWhisperStream`] handle to its blocking worker.
///
/// The worker handles them strictly in the order they were sent.
enum AudioCommand {
    /// Leave the worker's command loop; the worker then returns `Ok(())`.
    Stop,
    /// Flush the stream and reply with the resulting segments on the enclosed channel.
    Flush(oneshot::Sender<Vec<Segment>>),
    /// Append these samples to the stream and process them.
    Feed(Vec<f32>),
}
#[cfg(feature = "async")]
impl AsyncWhisperStream {
    /// Starts a streaming transcriber using [`WhisperStreamConfig::default`].
    ///
    /// Equivalent to [`Self::with_config`] with the default configuration; see it for
    /// threading, backpressure and error-reporting details.
    pub fn new(context: WhisperContext, params: FullParams) -> Result<Self> {
        Self::with_config(context, params, WhisperStreamConfig::default())
    }

    /// Starts a streaming transcriber whose [`WhisperStream`] lives on a Tokio
    /// blocking thread.
    ///
    /// The worker builds the stream from `context`, `params` and `config`. It then
    /// serves commands from [`Self::feed_audio`], [`Self::flush`] and [`Self::stop`]
    /// in the order they were sent. Each non-empty segment batch produced while
    /// processing fed audio is delivered through [`Self::recv_segments`] /
    /// [`Self::try_recv_segments`].
    ///
    /// Both internal channels are bounded to 100 entries. If segments are never
    /// received, the worker eventually blocks on the full segment channel. It then
    /// stops serving commands, including `flush`, until segments are drained or the
    /// stream is stopped.
    ///
    /// # Errors
    /// As written this always returns `Ok`, because stream construction happens on
    /// the worker. A construction error is only reported by [`Self::stop`]. So is any
    /// processing or flush error, since those also end the worker. Once the worker
    /// has exited, `feed_audio` and `flush` fail with `"Stream closed"`.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime (a requirement of `spawn_blocking`).
    pub fn with_config(
        context: WhisperContext,
        params: FullParams,
        config: WhisperStreamConfig,
    ) -> Result<Self> {
        let (audio_tx, mut audio_rx) = mpsc::channel::<AudioCommand>(100);
        let (segment_tx, segment_rx) = mpsc::channel::<Vec<Segment>>(100);
        let handle = task::spawn_blocking(move || {
            let mut stream = WhisperStream::with_config(&context, params, config)?;
            // `None` means every command sender has been dropped.
            while let Some(cmd) = audio_rx.blocking_recv() {
                match cmd {
                    AudioCommand::Feed(audio) => {
                        stream.feed_audio(&audio);
                        // Keep stepping until `process_step` reports no further step.
                        while let Some(segments) = stream.process_step()? {
                            if !segments.is_empty() {
                                // A send error only means the receiver was dropped;
                                // keep going so queued flush/stop commands still run.
                                let _ = segment_tx.blocking_send(segments);
                            }
                        }
                    }
                    AudioCommand::Flush(response) => {
                        let segments = stream.flush()?;
                        // The caller may have stopped waiting; nothing else to do then.
                        let _ = response.send(segments);
                    }
                    AudioCommand::Stop => break,
                }
            }
            Ok(())
        });
        Ok(Self {
            audio_tx,
            segment_rx,
            handle,
        })
    }

    /// Queues `audio` for the worker.
    ///
    /// This waits only for space in the command channel, not for the audio to be
    /// processed.
    ///
    /// # Errors
    /// Returns `TranscriptionError("Stream closed")` if the worker has exited.
    pub async fn feed_audio(&self, audio: Vec<f32>) -> Result<()> {
        self.audio_tx
            .send(AudioCommand::Feed(audio))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))
    }

    /// Waits for the next batch of segments.
    ///
    /// Returns `None` once the worker has exited and every buffered batch has been
    /// received.
    pub async fn recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.recv().await
    }

    /// Returns a buffered segment batch without waiting.
    ///
    /// Returns `None` both when nothing is ready and when the worker is gone.
    pub fn try_recv_segments(&mut self) -> Option<Vec<Segment>> {
        self.segment_rx.try_recv().ok()
    }

    /// Asks the worker to flush the stream and waits for the resulting segments.
    ///
    /// The flushed segments are returned here. They are not delivered through
    /// [`Self::recv_segments`].
    ///
    /// # Errors
    /// - `TranscriptionError("Stream closed")` if the worker has already exited.
    /// - `TranscriptionError("Failed to flush")` if the worker dropped the reply,
    ///   e.g. because the flush itself failed and ended the worker. The underlying
    ///   error is then reported by [`Self::stop`].
    pub async fn flush(&self) -> Result<Vec<Segment>> {
        let (tx, rx) = oneshot::channel();
        self.audio_tx
            .send(AudioCommand::Flush(tx))
            .await
            .map_err(|_| crate::WhisperError::TranscriptionError("Stream closed".into()))?;
        rx.await
            .map_err(|_| crate::WhisperError::TranscriptionError("Failed to flush".into()))
    }

    /// Stops the worker and waits for it to finish.
    ///
    /// Audio fed before this call is still processed, because commands are handled
    /// in order. Any segments it yields are discarded, since the segment receiver
    /// is dropped first.
    ///
    /// # Errors
    /// Returns the error that ended the worker (stream construction, processing, or
    /// flush failure). Returns `TranscriptionError` if the worker task panicked or
    /// was cancelled.
    pub async fn stop(self) -> Result<()> {
        let Self {
            audio_tx,
            segment_rx,
            handle,
        } = self;
        // Drop the receiver before anything else. A worker blocked in `blocking_send`
        // on a full segment channel would otherwise never read `Stop`, and awaiting
        // `handle` below would hang forever.
        drop(segment_rx);
        // A send failure only means the worker already exited; its result is in `handle`.
        let _ = audio_tx.send(AudioCommand::Stop).await;
        // Closing the command channel also ends the worker's receive loop.
        drop(audio_tx);
        handle
            .await
            .map_err(|e| crate::WhisperError::TranscriptionError(e.to_string()))?
    }
}
#[cfg(feature = "async")]
/// Cloneable handle to a single `WhisperStream` guarded by a `tokio::sync::Mutex`.
///
/// Clones share the same underlying stream (cloning copies the `Arc`), and
/// operations on it are serialized by the lock.
pub struct SharedAsyncStream {
/// Shared stream state behind an async mutex.
inner: Arc<Mutex<AsyncStreamInner>>,
}
#[cfg(feature = "async")]
/// State protected by a [`SharedAsyncStream`]'s lock.
struct AsyncStreamInner {
/// The wrapped streaming transcriber.
stream: WhisperStream,
/// Copies of every segment returned by `feed_and_process` / `flush` since the last
/// `drain_segments`. Nothing trims it, so it grows until drained.
pending_segments: Vec<Segment>,
}
#[cfg(feature = "async")]
impl SharedAsyncStream {
pub async fn new(
context: &WhisperContext,
params: FullParams,
config: WhisperStreamConfig,
) -> Result<Self> {
let stream = WhisperStream::with_config(context, params, config)?;
Ok(Self {
inner: Arc::new(Mutex::new(AsyncStreamInner {
stream,
pending_segments: Vec::new(),
})),
})
}
pub async fn feed_and_process(&self, audio: Vec<f32>) -> Result<Vec<Segment>> {
let mut inner = self.inner.lock().await;
inner.stream.feed_audio(&audio);
let mut segments = Vec::new();
while let Some(segs) = inner.stream.process_step()? {
segments.extend(segs);
}
inner.pending_segments.extend(segments.clone());
Ok(segments)
}
pub async fn drain_segments(&self) -> Vec<Segment> {
let mut inner = self.inner.lock().await;
std::mem::take(&mut inner.pending_segments)
}
pub async fn flush(&self) -> Result<Vec<Segment>> {
let mut inner = self.inner.lock().await;
let segments = inner.stream.flush()?;
inner.pending_segments.extend(segments.clone());
Ok(segments)
}
}
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::SamplingStrategy;
    use std::path::Path;

    /// Model file used by every test here. Each test silently does nothing when
    /// the file is absent.
    const MODEL_PATH: &str = "tests/models/ggml-tiny.en.bin";

    /// Builds a buffer of 16000 zero-valued samples.
    fn silence() -> Vec<f32> {
        vec![0.0f32; 16000]
    }

    #[tokio::test]
    async fn test_async_transcribe() {
        if !Path::new(MODEL_PATH).exists() {
            return;
        }
        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let outcome = ctx.transcribe_async(silence()).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_async_stream() {
        if !Path::new(MODEL_PATH).exists() {
            return;
        }
        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        let created = AsyncWhisperStream::new(ctx, params);
        assert!(created.is_ok());
        let streamer = created.unwrap();
        let fed = streamer.feed_audio(silence()).await;
        assert!(fed.is_ok());
        let stopped = streamer.stop().await;
        assert!(stopped.is_ok());
    }

    #[tokio::test]
    async fn test_shared_stream() {
        if !Path::new(MODEL_PATH).exists() {
            return;
        }
        let ctx = WhisperContext::new(MODEL_PATH).unwrap();
        let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        let created =
            SharedAsyncStream::new(&ctx, params, WhisperStreamConfig::default()).await;
        assert!(created.is_ok());
        let shared = created.unwrap();

        // Two concurrent producers feeding the same shared stream.
        let workers: Vec<_> = (0..2)
            .map(|_| {
                let handle = shared.clone();
                tokio::spawn(async move { handle.feed_and_process(silence()).await })
            })
            .collect();

        let mut outcomes = Vec::with_capacity(workers.len());
        for worker in workers {
            outcomes.push(worker.await.unwrap());
        }
        for outcome in outcomes {
            assert!(outcome.is_ok());
        }
    }
}
#[cfg(feature = "async")]
impl Clone for SharedAsyncStream {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}