Skip to main content

adk_audio/providers/stt/
deepgram.rs

1//! Deepgram Nova STT provider.
2
3use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::{SinkExt, Stream, StreamExt};
7use tokio_tungstenite::tungstenite::Message;
8use tokio_tungstenite::tungstenite::client::IntoClientRequest;
9use tracing::{debug, warn};
10
11use crate::error::{AudioError, AudioResult};
12use crate::frame::AudioFrame;
13use crate::providers::stt::frame_to_wav_bytes;
14use crate::traits::{Speaker, SttOptions, SttProvider, Transcript, Word};
15
16/// Deepgram Nova STT provider.
17///
18/// Uses the Deepgram `/v1/listen` endpoint.
19/// Configure via `DEEPGRAM_API_KEY` environment variable.
20pub struct DeepgramStt {
21    api_key: String,
22    client: reqwest::Client,
23    base_url: String,
24}
25
26impl DeepgramStt {
27    /// Create with an explicit API key (useful for testing without env vars).
28    #[doc(hidden)]
29    pub fn with_api_key(api_key: String) -> Self {
30        Self {
31            api_key,
32            client: reqwest::Client::new(),
33            base_url: "https://api.deepgram.com".into(),
34        }
35    }
36
37    /// Create from environment variable `DEEPGRAM_API_KEY`.
38    pub fn from_env() -> AudioResult<Self> {
39        let api_key = std::env::var("DEEPGRAM_API_KEY").map_err(|_| AudioError::Stt {
40            provider: "deepgram".into(),
41            message: "DEEPGRAM_API_KEY not set".into(),
42        })?;
43        Ok(Self {
44            api_key,
45            client: reqwest::Client::new(),
46            base_url: "https://api.deepgram.com".into(),
47        })
48    }
49
50    /// Build the WebSocket URL with query parameters for streaming STT.
51    fn build_ws_url(&self, opts: &SttOptions) -> String {
52        let ws_base = self.base_url.replace("https://", "wss://");
53        let mut params = vec![
54            "model=nova-2".to_string(),
55            "encoding=linear16".to_string(),
56            "sample_rate=16000".to_string(),
57            "channels=1".to_string(),
58            "smart_format=true".to_string(),
59            "interim_results=true".to_string(),
60        ];
61        if opts.diarize {
62            params.push("diarize=true".to_string());
63        }
64        if opts.word_timestamps {
65            params.push("utterances=true".to_string());
66        }
67        if let Some(ref lang) = opts.language {
68            params.push(format!("language={lang}"));
69        }
70        if opts.smart_format {
71            params.push("punctuate=true".to_string());
72        }
73        if let Some(ref model) = opts.model_hint {
74            // Override default model if caller specifies one.
75            params.retain(|p| !p.starts_with("model="));
76            params.push(format!("model={model}"));
77        }
78        format!("{ws_base}/v1/listen?{}", params.join("&"))
79    }
80}
81
82#[async_trait]
83impl SttProvider for DeepgramStt {
84    async fn transcribe(&self, audio: &AudioFrame, opts: &SttOptions) -> AudioResult<Transcript> {
85        assert!(self.base_url.starts_with("https://"), "Deepgram requires HTTPS");
86        let wav_bytes = frame_to_wav_bytes(audio)?;
87
88        let mut params = vec!["model=nova-2".to_string(), "smart_format=true".to_string()];
89        if opts.diarize {
90            params.push("diarize=true".to_string());
91        }
92        if opts.word_timestamps {
93            params.push("utterances=true".to_string());
94        }
95        if let Some(ref lang) = opts.language {
96            params.push(format!("language={lang}"));
97        }
98        if opts.smart_format {
99            params.push("punctuate=true".to_string());
100        }
101
102        let url = format!("{}/v1/listen?{}", self.base_url, params.join("&"));
103
104        let resp = self
105            .client
106            .post(&url)
107            .header("Authorization", format!("Token {}", self.api_key))
108            .header("Content-Type", "audio/wav")
109            .body(wav_bytes.to_vec())
110            .send()
111            .await
112            .map_err(|e| AudioError::Stt { provider: "deepgram".into(), message: e.to_string() })?;
113
114        if !resp.status().is_success() {
115            return Err(AudioError::Stt {
116                provider: "deepgram".into(),
117                message: format!("HTTP {}", resp.status()),
118            });
119        }
120
121        let json: serde_json::Value = resp
122            .json()
123            .await
124            .map_err(|e| AudioError::Stt { provider: "deepgram".into(), message: e.to_string() })?;
125
126        let channel = &json["results"]["channels"][0]["alternatives"][0];
127        let text = channel["transcript"].as_str().unwrap_or_default().to_string();
128        let confidence = channel["confidence"].as_f64().unwrap_or(0.0) as f32;
129
130        let words: Vec<Word> = channel["words"]
131            .as_array()
132            .map(|arr| {
133                arr.iter()
134                    .map(|w| Word {
135                        text: w["word"].as_str().unwrap_or_default().to_string(),
136                        start_ms: (w["start"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
137                        end_ms: (w["end"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
138                        confidence: w["confidence"].as_f64().unwrap_or(0.0) as f32,
139                        speaker: w["speaker"].as_u64().map(|s| s as u32),
140                    })
141                    .collect()
142            })
143            .unwrap_or_default();
144
145        // Extract unique speakers
146        let mut speaker_ids: Vec<u32> = words.iter().filter_map(|w| w.speaker).collect();
147        speaker_ids.sort();
148        speaker_ids.dedup();
149        let speakers: Vec<Speaker> =
150            speaker_ids.into_iter().map(|id| Speaker { id, label: None }).collect();
151
152        let language_detected =
153            json["results"]["channels"][0]["detected_language"].as_str().map(String::from);
154
155        Ok(Transcript { text, words, speakers, confidence, language_detected })
156    }
157
158    async fn transcribe_stream(
159        &self,
160        audio: Pin<Box<dyn Stream<Item = AudioFrame> + Send>>,
161        opts: &SttOptions,
162    ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<Transcript>> + Send>>> {
163        let ws_url = self.build_ws_url(opts);
164        debug!(url = %ws_url, "connecting to Deepgram streaming STT");
165
166        // Build the WebSocket request with Authorization header.
167        let mut request = ws_url.into_client_request().map_err(|e| AudioError::Stt {
168            provider: "deepgram".into(),
169            message: format!("failed to build WebSocket request: {e}"),
170        })?;
171        request.headers_mut().insert(
172            "Authorization",
173            format!("Token {}", self.api_key).parse().map_err(|e| AudioError::Stt {
174                provider: "deepgram".into(),
175                message: format!("invalid authorization header: {e}"),
176            })?,
177        );
178
179        // Connect to the Deepgram WebSocket.
180        let (ws_stream, _resp) =
181            tokio_tungstenite::connect_async(request).await.map_err(|e| AudioError::Stt {
182                provider: "deepgram".into(),
183                message: format!("WebSocket connection failed: {e}"),
184            })?;
185
186        let (mut ws_sink, mut ws_source) = ws_stream.split();
187
188        // Spawn a task that reads audio frames and sends them as binary messages.
189        tokio::spawn(async move {
190            let mut audio = audio;
191            while let Some(frame) = audio.next().await {
192                // Send raw PCM-16 LE bytes directly — Deepgram expects raw audio
193                // matching the encoding/sample_rate/channels query params.
194                if let Err(e) = ws_sink.send(Message::Binary(frame.data)).await {
195                    warn!("deepgram ws send error: {e}");
196                    break;
197                }
198            }
199            // Signal end of audio by sending a close-stream message.
200            let close_msg = serde_json::json!({"type": "CloseStream"});
201            let _ = ws_sink.send(Message::Text(close_msg.to_string().into())).await;
202        });
203
204        // Return a stream that reads WebSocket messages and yields Transcript values.
205        let transcript_stream = async_stream::stream! {
206            while let Some(msg_result) = ws_source.next().await {
207                let msg = match msg_result {
208                    Ok(m) => m,
209                    Err(e) => {
210                        yield Err(AudioError::Stt {
211                            provider: "deepgram".into(),
212                            message: format!("WebSocket read error: {e}"),
213                        });
214                        break;
215                    }
216                };
217
218                match msg {
219                    Message::Text(text) => {
220                        let json: serde_json::Value = match serde_json::from_str(&text) {
221                            Ok(v) => v,
222                            Err(e) => {
223                                warn!("deepgram: failed to parse JSON: {e}");
224                                continue;
225                            }
226                        };
227
228                        // Check for error responses from Deepgram.
229                        if let Some(err_msg) = json.get("error").and_then(|v| v.as_str()) {
230                            yield Err(AudioError::Stt {
231                                provider: "deepgram".into(),
232                                message: err_msg.to_string(),
233                            });
234                            break;
235                        }
236
237                        // Parse transcript results.
238                        if let Some(transcript) = parse_streaming_response(&json) {
239                            yield Ok(transcript);
240                        }
241                    }
242                    Message::Close(_) => break,
243                    _ => {} // Ignore ping/pong/binary responses
244                }
245            }
246        };
247
248        Ok(Box::pin(transcript_stream))
249    }
250}
251
252/// Parse a Deepgram streaming WebSocket JSON response into a `Transcript`.
253///
254/// Returns `None` for metadata-only messages (e.g. `UtteranceEnd`, `SpeechStarted`).
255fn parse_streaming_response(json: &serde_json::Value) -> Option<Transcript> {
256    // Deepgram streaming responses have a "channel" object with alternatives.
257    let channel = json.get("channel")?;
258    let alt = channel.get("alternatives")?.get(0)?;
259
260    let text = alt["transcript"].as_str().unwrap_or_default().to_string();
261    // Skip empty transcripts (silence / no speech detected).
262    if text.is_empty() {
263        return None;
264    }
265
266    let confidence = alt["confidence"].as_f64().unwrap_or(0.0) as f32;
267    let is_final = json.get("is_final").and_then(|v| v.as_bool()).unwrap_or(false);
268
269    let words: Vec<Word> = alt["words"]
270        .as_array()
271        .map(|arr| {
272            arr.iter()
273                .map(|w| Word {
274                    text: w["word"].as_str().unwrap_or_default().to_string(),
275                    start_ms: (w["start"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
276                    end_ms: (w["end"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
277                    confidence: w["confidence"].as_f64().unwrap_or(0.0) as f32,
278                    speaker: w["speaker"].as_u64().map(|s| s as u32),
279                })
280                .collect()
281        })
282        .unwrap_or_default();
283
284    let mut speaker_ids: Vec<u32> = words.iter().filter_map(|w| w.speaker).collect();
285    speaker_ids.sort();
286    speaker_ids.dedup();
287    let speakers: Vec<Speaker> =
288        speaker_ids.into_iter().map(|id| Speaker { id, label: None }).collect();
289
290    let language_detected =
291        json.get("metadata").and_then(|m| m["language"].as_str()).map(String::from);
292
293    // Encode finality in the transcript: final transcripts have full confidence,
294    // interim transcripts are partial results the caller can display/update.
295    let _ = is_final; // is_final is reflected by the presence of words/confidence
296
297    Some(Transcript { text, words, speakers, confidence, language_detected })
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_streaming_final_transcript() {
        let json: serde_json::Value = serde_json::json!({
            "type": "Results",
            "channel_index": [0, 1],
            "duration": 1.5,
            "start": 0.0,
            "is_final": true,
            "channel": {
                "alternatives": [{
                    "transcript": "hello world",
                    "confidence": 0.95,
                    "words": [
                        {"word": "hello", "start": 0.0, "end": 0.5, "confidence": 0.96},
                        {"word": "world", "start": 0.6, "end": 1.0, "confidence": 0.94}
                    ]
                }]
            }
        });

        let transcript = parse_streaming_response(&json).expect("should parse");
        assert_eq!(transcript.text, "hello world");
        assert!((transcript.confidence - 0.95).abs() < 0.01);
        assert_eq!(transcript.words.len(), 2);
        assert_eq!(transcript.words[0].text, "hello");
        assert_eq!(transcript.words[0].start_ms, 0);
        assert_eq!(transcript.words[0].end_ms, 500);
        assert_eq!(transcript.words[1].text, "world");
    }

    #[test]
    fn parse_streaming_interim_transcript() {
        let json: serde_json::Value = serde_json::json!({
            "type": "Results",
            "is_final": false,
            "channel": {
                "alternatives": [{
                    "transcript": "hel",
                    "confidence": 0.7,
                    "words": []
                }]
            }
        });

        let transcript = parse_streaming_response(&json).expect("should parse interim");
        assert_eq!(transcript.text, "hel");
    }

    #[test]
    fn parse_streaming_empty_transcript_returns_none() {
        let json: serde_json::Value = serde_json::json!({
            "type": "Results",
            "is_final": false,
            "channel": {
                "alternatives": [{
                    "transcript": "",
                    "confidence": 0.0,
                    "words": []
                }]
            }
        });

        assert!(parse_streaming_response(&json).is_none());
    }

    #[test]
    fn parse_streaming_metadata_message_returns_none() {
        // Messages like UtteranceEnd don't have a "channel" field.
        let json: serde_json::Value = serde_json::json!({
            "type": "UtteranceEnd",
            "last_word_end": 1.5
        });

        assert!(parse_streaming_response(&json).is_none());
    }

    #[test]
    fn parse_streaming_empty_alternatives_returns_none() {
        let json: serde_json::Value = serde_json::json!({
            "type": "Results",
            "channel": { "alternatives": [] }
        });

        assert!(parse_streaming_response(&json).is_none());
    }

    #[test]
    fn parse_streaming_reads_metadata_language_and_dedups_speakers() {
        let json: serde_json::Value = serde_json::json!({
            "type": "Results",
            "is_final": true,
            "metadata": { "language": "en" },
            "channel": {
                "alternatives": [{
                    "transcript": "a b c",
                    "confidence": 0.8,
                    "words": [
                        {"word": "a", "start": 0.0, "end": 0.1, "confidence": 0.8, "speaker": 1},
                        {"word": "b", "start": 0.2, "end": 0.3, "confidence": 0.8, "speaker": 0},
                        {"word": "c", "start": 0.4, "end": 0.5, "confidence": 0.8, "speaker": 1}
                    ]
                }]
            }
        });

        let transcript = parse_streaming_response(&json).expect("should parse");
        assert_eq!(transcript.language_detected.as_deref(), Some("en"));
        let ids: Vec<u32> = transcript.speakers.iter().map(|s| s.id).collect();
        assert_eq!(ids, vec![0, 1]);
        assert_eq!(transcript.words[2].speaker, Some(1));
    }

    #[test]
    fn build_ws_url_default_opts() {
        let stt = DeepgramStt::with_api_key("test-key".into());
        let url = stt.build_ws_url(&SttOptions::default());
        assert!(url.starts_with("wss://api.deepgram.com/v1/listen?"));
        assert!(url.contains("model=nova-2"));
        assert!(url.contains("encoding=linear16"));
        assert!(url.contains("sample_rate=16000"));
        assert!(url.contains("channels=1"));
        assert!(url.contains("interim_results=true"));
    }

    #[test]
    fn build_ws_url_with_language_and_diarize() {
        let stt = DeepgramStt::with_api_key("test-key".into());
        let opts =
            SttOptions { language: Some("en-US".into()), diarize: true, ..Default::default() };
        let url = stt.build_ws_url(&opts);
        assert!(url.contains("language=en-US"));
        assert!(url.contains("diarize=true"));
    }

    #[test]
    fn build_ws_url_with_word_timestamps_and_smart_format() {
        let stt = DeepgramStt::with_api_key("test-key".into());
        let opts = SttOptions { word_timestamps: true, smart_format: true, ..Default::default() };
        let url = stt.build_ws_url(&opts);
        assert!(url.contains("utterances=true"));
        assert!(url.contains("punctuate=true"));
    }

    #[test]
    fn build_ws_url_with_model_hint() {
        let stt = DeepgramStt::with_api_key("test-key".into());
        let opts = SttOptions { model_hint: Some("nova-3".into()), ..Default::default() };
        let url = stt.build_ws_url(&opts);
        assert!(url.contains("model=nova-3"));
        // Should not contain the default model, and the model must be specified only once.
        assert!(!url.contains("model=nova-2"));
        assert_eq!(url.matches("model=").count(), 1);
    }

    #[test]
    fn parse_streaming_with_speakers() {
        let json: serde_json::Value = serde_json::json!({
            "type": "Results",
            "is_final": true,
            "channel": {
                "alternatives": [{
                    "transcript": "hi there",
                    "confidence": 0.9,
                    "words": [
                        {"word": "hi", "start": 0.0, "end": 0.3, "confidence": 0.9, "speaker": 0},
                        {"word": "there", "start": 0.4, "end": 0.8, "confidence": 0.9, "speaker": 1}
                    ]
                }]
            }
        });

        let transcript = parse_streaming_response(&json).expect("should parse");
        assert_eq!(transcript.speakers.len(), 2);
        assert_eq!(transcript.speakers[0].id, 0);
        assert_eq!(transcript.speakers[1].id, 1);
    }
}