Skip to main content

adk_audio/providers/stt/
whisper_api.rs

1//! OpenAI Whisper API STT provider.
2
3use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::Stream;
7
8use crate::error::{AudioError, AudioResult};
9use crate::frame::AudioFrame;
10use crate::providers::stt::frame_to_wav_bytes;
11use crate::traits::{SttOptions, SttProvider, Transcript, Word};
12
13/// OpenAI Whisper API STT provider.
14///
15/// Uses the `/v1/audio/transcriptions` endpoint.
16/// Configure via `OPENAI_API_KEY` environment variable.
17pub struct WhisperApiStt {
18    api_key: String,
19    client: reqwest::Client,
20    base_url: String,
21    model: String,
22}
23
24impl WhisperApiStt {
25    /// Create from environment variable `OPENAI_API_KEY`.
26    pub fn from_env() -> AudioResult<Self> {
27        let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| AudioError::Stt {
28            provider: "whisper".into(),
29            message: "OPENAI_API_KEY not set".into(),
30        })?;
31        Ok(Self {
32            api_key,
33            client: reqwest::Client::new(),
34            base_url: "https://api.openai.com".into(),
35            model: "whisper-1".into(),
36        })
37    }
38}
39
40#[async_trait]
41impl SttProvider for WhisperApiStt {
42    async fn transcribe(&self, audio: &AudioFrame, opts: &SttOptions) -> AudioResult<Transcript> {
43        let wav_bytes = frame_to_wav_bytes(audio)?;
44        let url = format!("{}/v1/audio/transcriptions", self.base_url);
45
46        let part = reqwest::multipart::Part::bytes(wav_bytes.to_vec())
47            .file_name("audio.wav")
48            .mime_str("audio/wav")
49            .map_err(|e| AudioError::Stt { provider: "whisper".into(), message: e.to_string() })?;
50
51        let mut form = reqwest::multipart::Form::new()
52            .text("model", self.model.clone())
53            .text("response_format", "verbose_json")
54            .part("file", part);
55
56        if let Some(ref lang) = opts.language {
57            form = form.text("language", lang.clone());
58        }
59        if opts.word_timestamps {
60            form = form.text("timestamp_granularities[]", "word");
61        }
62
63        let resp = self
64            .client
65            .post(&url)
66            .bearer_auth(&self.api_key)
67            .multipart(form)
68            .send()
69            .await
70            .map_err(|e| AudioError::Stt {
71            provider: "whisper".into(),
72            message: e.to_string(),
73        })?;
74
75        if !resp.status().is_success() {
76            return Err(AudioError::Stt {
77                provider: "whisper".into(),
78                message: format!("HTTP {}", resp.status()),
79            });
80        }
81
82        let json: serde_json::Value = resp
83            .json()
84            .await
85            .map_err(|e| AudioError::Stt { provider: "whisper".into(), message: e.to_string() })?;
86
87        let text = json["text"].as_str().unwrap_or_default().to_string();
88        let language_detected = json["language"].as_str().map(String::from);
89
90        let words = json["words"]
91            .as_array()
92            .map(|arr| {
93                arr.iter()
94                    .map(|w| Word {
95                        text: w["word"].as_str().unwrap_or_default().to_string(),
96                        start_ms: (w["start"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
97                        end_ms: (w["end"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
98                        confidence: 1.0,
99                        speaker: None,
100                    })
101                    .collect()
102            })
103            .unwrap_or_default();
104
105        Ok(Transcript { text, words, speakers: vec![], confidence: 1.0, language_detected })
106    }
107
108    async fn transcribe_stream(
109        &self,
110        _audio: Pin<Box<dyn Stream<Item = AudioFrame> + Send>>,
111        _opts: &SttOptions,
112    ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<Transcript>> + Send>>> {
113        // Whisper API doesn't support native streaming; use windowed fallback
114        Ok(Box::pin(futures::stream::empty()))
115    }
116}