1use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::{SinkExt, Stream, StreamExt};
7use tokio_tungstenite::tungstenite::Message;
8use tokio_tungstenite::tungstenite::client::IntoClientRequest;
9use tracing::{debug, warn};
10
11use crate::error::{AudioError, AudioResult};
12use crate::frame::AudioFrame;
13use crate::providers::stt::frame_to_wav_bytes;
14use crate::traits::{Speaker, SttOptions, SttProvider, Transcript, Word};
15
16pub struct DeepgramStt {
21 api_key: String,
22 client: reqwest::Client,
23 base_url: String,
24}
25
26impl DeepgramStt {
27 #[doc(hidden)]
29 pub fn with_api_key(api_key: String) -> Self {
30 Self {
31 api_key,
32 client: reqwest::Client::new(),
33 base_url: "https://api.deepgram.com".into(),
34 }
35 }
36
37 pub fn from_env() -> AudioResult<Self> {
39 let api_key = std::env::var("DEEPGRAM_API_KEY").map_err(|_| AudioError::Stt {
40 provider: "deepgram".into(),
41 message: "DEEPGRAM_API_KEY not set".into(),
42 })?;
43 Ok(Self {
44 api_key,
45 client: reqwest::Client::new(),
46 base_url: "https://api.deepgram.com".into(),
47 })
48 }
49
50 fn build_ws_url(&self, opts: &SttOptions) -> String {
52 let ws_base = self.base_url.replace("https://", "wss://");
53 let mut params = vec![
54 "model=nova-2".to_string(),
55 "encoding=linear16".to_string(),
56 "sample_rate=16000".to_string(),
57 "channels=1".to_string(),
58 "smart_format=true".to_string(),
59 "interim_results=true".to_string(),
60 ];
61 if opts.diarize {
62 params.push("diarize=true".to_string());
63 }
64 if opts.word_timestamps {
65 params.push("utterances=true".to_string());
66 }
67 if let Some(ref lang) = opts.language {
68 params.push(format!("language={lang}"));
69 }
70 if opts.smart_format {
71 params.push("punctuate=true".to_string());
72 }
73 if let Some(ref model) = opts.model_hint {
74 params.retain(|p| !p.starts_with("model="));
76 params.push(format!("model={model}"));
77 }
78 format!("{ws_base}/v1/listen?{}", params.join("&"))
79 }
80}
81
82#[async_trait]
83impl SttProvider for DeepgramStt {
84 async fn transcribe(&self, audio: &AudioFrame, opts: &SttOptions) -> AudioResult<Transcript> {
85 assert!(self.base_url.starts_with("https://"), "Deepgram requires HTTPS");
86 let wav_bytes = frame_to_wav_bytes(audio)?;
87
88 let mut params = vec!["model=nova-2".to_string(), "smart_format=true".to_string()];
89 if opts.diarize {
90 params.push("diarize=true".to_string());
91 }
92 if opts.word_timestamps {
93 params.push("utterances=true".to_string());
94 }
95 if let Some(ref lang) = opts.language {
96 params.push(format!("language={lang}"));
97 }
98 if opts.smart_format {
99 params.push("punctuate=true".to_string());
100 }
101
102 let url = format!("{}/v1/listen?{}", self.base_url, params.join("&"));
103
104 let resp = self
105 .client
106 .post(&url)
107 .header("Authorization", format!("Token {}", self.api_key))
108 .header("Content-Type", "audio/wav")
109 .body(wav_bytes.to_vec())
110 .send()
111 .await
112 .map_err(|e| AudioError::Stt { provider: "deepgram".into(), message: e.to_string() })?;
113
114 if !resp.status().is_success() {
115 return Err(AudioError::Stt {
116 provider: "deepgram".into(),
117 message: format!("HTTP {}", resp.status()),
118 });
119 }
120
121 let json: serde_json::Value = resp
122 .json()
123 .await
124 .map_err(|e| AudioError::Stt { provider: "deepgram".into(), message: e.to_string() })?;
125
126 let channel = &json["results"]["channels"][0]["alternatives"][0];
127 let text = channel["transcript"].as_str().unwrap_or_default().to_string();
128 let confidence = channel["confidence"].as_f64().unwrap_or(0.0) as f32;
129
130 let words: Vec<Word> = channel["words"]
131 .as_array()
132 .map(|arr| {
133 arr.iter()
134 .map(|w| Word {
135 text: w["word"].as_str().unwrap_or_default().to_string(),
136 start_ms: (w["start"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
137 end_ms: (w["end"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
138 confidence: w["confidence"].as_f64().unwrap_or(0.0) as f32,
139 speaker: w["speaker"].as_u64().map(|s| s as u32),
140 })
141 .collect()
142 })
143 .unwrap_or_default();
144
145 let mut speaker_ids: Vec<u32> = words.iter().filter_map(|w| w.speaker).collect();
147 speaker_ids.sort();
148 speaker_ids.dedup();
149 let speakers: Vec<Speaker> =
150 speaker_ids.into_iter().map(|id| Speaker { id, label: None }).collect();
151
152 let language_detected =
153 json["results"]["channels"][0]["detected_language"].as_str().map(String::from);
154
155 Ok(Transcript { text, words, speakers, confidence, language_detected })
156 }
157
158 async fn transcribe_stream(
159 &self,
160 audio: Pin<Box<dyn Stream<Item = AudioFrame> + Send>>,
161 opts: &SttOptions,
162 ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<Transcript>> + Send>>> {
163 let ws_url = self.build_ws_url(opts);
164 debug!(url = %ws_url, "connecting to Deepgram streaming STT");
165
166 let mut request = ws_url.into_client_request().map_err(|e| AudioError::Stt {
168 provider: "deepgram".into(),
169 message: format!("failed to build WebSocket request: {e}"),
170 })?;
171 request.headers_mut().insert(
172 "Authorization",
173 format!("Token {}", self.api_key).parse().map_err(|e| AudioError::Stt {
174 provider: "deepgram".into(),
175 message: format!("invalid authorization header: {e}"),
176 })?,
177 );
178
179 let (ws_stream, _resp) =
181 tokio_tungstenite::connect_async(request).await.map_err(|e| AudioError::Stt {
182 provider: "deepgram".into(),
183 message: format!("WebSocket connection failed: {e}"),
184 })?;
185
186 let (mut ws_sink, mut ws_source) = ws_stream.split();
187
188 tokio::spawn(async move {
190 let mut audio = audio;
191 while let Some(frame) = audio.next().await {
192 if let Err(e) = ws_sink.send(Message::Binary(frame.data)).await {
195 warn!("deepgram ws send error: {e}");
196 break;
197 }
198 }
199 let close_msg = serde_json::json!({"type": "CloseStream"});
201 let _ = ws_sink.send(Message::Text(close_msg.to_string().into())).await;
202 });
203
204 let transcript_stream = async_stream::stream! {
206 while let Some(msg_result) = ws_source.next().await {
207 let msg = match msg_result {
208 Ok(m) => m,
209 Err(e) => {
210 yield Err(AudioError::Stt {
211 provider: "deepgram".into(),
212 message: format!("WebSocket read error: {e}"),
213 });
214 break;
215 }
216 };
217
218 match msg {
219 Message::Text(text) => {
220 let json: serde_json::Value = match serde_json::from_str(&text) {
221 Ok(v) => v,
222 Err(e) => {
223 warn!("deepgram: failed to parse JSON: {e}");
224 continue;
225 }
226 };
227
228 if let Some(err_msg) = json.get("error").and_then(|v| v.as_str()) {
230 yield Err(AudioError::Stt {
231 provider: "deepgram".into(),
232 message: err_msg.to_string(),
233 });
234 break;
235 }
236
237 if let Some(transcript) = parse_streaming_response(&json) {
239 yield Ok(transcript);
240 }
241 }
242 Message::Close(_) => break,
243 _ => {} }
245 }
246 };
247
248 Ok(Box::pin(transcript_stream))
249 }
250}
251
252fn parse_streaming_response(json: &serde_json::Value) -> Option<Transcript> {
256 let channel = json.get("channel")?;
258 let alt = channel.get("alternatives")?.get(0)?;
259
260 let text = alt["transcript"].as_str().unwrap_or_default().to_string();
261 if text.is_empty() {
263 return None;
264 }
265
266 let confidence = alt["confidence"].as_f64().unwrap_or(0.0) as f32;
267 let is_final = json.get("is_final").and_then(|v| v.as_bool()).unwrap_or(false);
268
269 let words: Vec<Word> = alt["words"]
270 .as_array()
271 .map(|arr| {
272 arr.iter()
273 .map(|w| Word {
274 text: w["word"].as_str().unwrap_or_default().to_string(),
275 start_ms: (w["start"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
276 end_ms: (w["end"].as_f64().unwrap_or(0.0) * 1000.0) as u32,
277 confidence: w["confidence"].as_f64().unwrap_or(0.0) as f32,
278 speaker: w["speaker"].as_u64().map(|s| s as u32),
279 })
280 .collect()
281 })
282 .unwrap_or_default();
283
284 let mut speaker_ids: Vec<u32> = words.iter().filter_map(|w| w.speaker).collect();
285 speaker_ids.sort();
286 speaker_ids.dedup();
287 let speakers: Vec<Speaker> =
288 speaker_ids.into_iter().map(|id| Speaker { id, label: None }).collect();
289
290 let language_detected =
291 json.get("metadata").and_then(|m| m["language"].as_str()).map(String::from);
292
293 let _ = is_final; Some(Transcript { text, words, speakers, confidence, language_detected })
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider aimed at the default endpoint with a dummy key.
    fn test_stt() -> DeepgramStt {
        DeepgramStt::with_api_key("test-key".into())
    }

    #[test]
    fn parse_streaming_final_transcript() {
        let msg = serde_json::json!({
            "type": "Results",
            "channel_index": [0, 1],
            "duration": 1.5,
            "start": 0.0,
            "is_final": true,
            "channel": {
                "alternatives": [{
                    "transcript": "hello world",
                    "confidence": 0.95,
                    "words": [
                        {"word": "hello", "start": 0.0, "end": 0.5, "confidence": 0.96},
                        {"word": "world", "start": 0.6, "end": 1.0, "confidence": 0.94}
                    ]
                }]
            }
        });

        let parsed = parse_streaming_response(&msg).expect("should parse");
        assert_eq!(parsed.text, "hello world");
        assert!((parsed.confidence - 0.95).abs() < 0.01);

        let texts: Vec<&str> = parsed.words.iter().map(|w| w.text.as_str()).collect();
        assert_eq!(texts, ["hello", "world"]);

        let first = &parsed.words[0];
        assert_eq!((first.start_ms, first.end_ms), (0, 500));
    }

    #[test]
    fn parse_streaming_interim_transcript() {
        let msg = serde_json::json!({
            "type": "Results",
            "is_final": false,
            "channel": {
                "alternatives": [{ "transcript": "hel", "confidence": 0.7, "words": [] }]
            }
        });

        let parsed = parse_streaming_response(&msg).expect("should parse interim");
        assert_eq!(parsed.text, "hel");
    }

    #[test]
    fn parse_streaming_empty_transcript_returns_none() {
        let msg = serde_json::json!({
            "type": "Results",
            "is_final": false,
            "channel": {
                "alternatives": [{ "transcript": "", "confidence": 0.0, "words": [] }]
            }
        });

        assert!(parse_streaming_response(&msg).is_none());
    }

    #[test]
    fn parse_streaming_metadata_message_returns_none() {
        // Non-result messages carry no `channel` and must be skipped.
        let msg = serde_json::json!({ "type": "UtteranceEnd", "last_word_end": 1.5 });

        assert!(parse_streaming_response(&msg).is_none());
    }

    #[test]
    fn build_ws_url_default_opts() {
        let url = test_stt().build_ws_url(&SttOptions::default());

        assert!(url.starts_with("wss://api.deepgram.com/v1/listen?"));
        let expected_params =
            ["model=nova-2", "encoding=linear16", "sample_rate=16000", "channels=1", "interim_results=true"];
        for param in expected_params {
            assert!(url.contains(param), "missing `{param}` in {url}");
        }
    }

    #[test]
    fn build_ws_url_with_language_and_diarize() {
        let opts =
            SttOptions { language: Some("en-US".into()), diarize: true, ..Default::default() };
        let url = test_stt().build_ws_url(&opts);

        for param in ["language=en-US", "diarize=true"] {
            assert!(url.contains(param), "missing `{param}` in {url}");
        }
    }

    #[test]
    fn build_ws_url_with_model_hint() {
        let opts = SttOptions { model_hint: Some("nova-3".into()), ..Default::default() };
        let url = test_stt().build_ws_url(&opts);

        assert!(url.contains("model=nova-3"));
        assert!(!url.contains("model=nova-2"), "default model must be replaced: {url}");
    }

    #[test]
    fn parse_streaming_with_speakers() {
        let msg = serde_json::json!({
            "type": "Results",
            "is_final": true,
            "channel": {
                "alternatives": [{
                    "transcript": "hi there",
                    "confidence": 0.9,
                    "words": [
                        {"word": "hi", "start": 0.0, "end": 0.3, "confidence": 0.9, "speaker": 0},
                        {"word": "there", "start": 0.4, "end": 0.8, "confidence": 0.9, "speaker": 1}
                    ]
                }]
            }
        });

        let parsed = parse_streaming_response(&msg).expect("should parse");
        let speaker_ids: Vec<u32> = parsed.speakers.iter().map(|s| s.id).collect();
        assert_eq!(speaker_ids, vec![0, 1]);
    }
}