Skip to main content

adk_audio/providers/stt/
assemblyai.rs

1//! AssemblyAI Universal STT provider.
2
3use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::Stream;
7
8use crate::error::{AudioError, AudioResult};
9use crate::frame::AudioFrame;
10use crate::providers::stt::frame_to_wav_bytes;
11use crate::traits::{Speaker, SttOptions, SttProvider, Transcript, Word};
12
13/// AssemblyAI Universal STT provider.
14///
15/// Uses the AssemblyAI async transcription API (upload → create → poll).
16/// Configure via `ASSEMBLYAI_API_KEY` environment variable.
17pub struct AssemblyAiStt {
18    api_key: String,
19    client: reqwest::Client,
20    base_url: String,
21}
22
23impl AssemblyAiStt {
24    /// Create with an explicit API key (useful for testing without env vars).
25    #[doc(hidden)]
26    pub fn with_api_key(api_key: String) -> Self {
27        Self {
28            api_key,
29            client: reqwest::Client::new(),
30            base_url: "https://api.assemblyai.com".into(),
31        }
32    }
33
34    /// Create from environment variable `ASSEMBLYAI_API_KEY`.
35    pub fn from_env() -> AudioResult<Self> {
36        let api_key = std::env::var("ASSEMBLYAI_API_KEY").map_err(|_| AudioError::Stt {
37            provider: "assemblyai".into(),
38            message: "ASSEMBLYAI_API_KEY not set".into(),
39        })?;
40        Ok(Self {
41            api_key,
42            client: reqwest::Client::new(),
43            base_url: "https://api.assemblyai.com".into(),
44        })
45    }
46}
47
48#[async_trait]
49impl SttProvider for AssemblyAiStt {
50    async fn transcribe(&self, audio: &AudioFrame, opts: &SttOptions) -> AudioResult<Transcript> {
51        let wav_bytes = frame_to_wav_bytes(audio)?;
52
53        // Step 1: Upload audio (base_url is always HTTPS — enforced at construction)
54        assert!(self.base_url.starts_with("https://"), "AssemblyAI requires HTTPS");
55        let upload_url = format!("{}/v2/upload", self.base_url);
56        let upload_resp = self
57            .client
58            .post(&upload_url)
59            .header("authorization", &self.api_key)
60            .header("content-type", "application/octet-stream")
61            .body(wav_bytes.to_vec())
62            .send()
63            .await
64            .map_err(|e| AudioError::Stt {
65                provider: "assemblyai".into(),
66                message: e.to_string(),
67            })?;
68
69        if !upload_resp.status().is_success() {
70            return Err(AudioError::Stt {
71                provider: "assemblyai".into(),
72                message: format!("upload HTTP {}", upload_resp.status()),
73            });
74        }
75
76        let upload_json: serde_json::Value = upload_resp.json().await.map_err(|e| {
77            AudioError::Stt { provider: "assemblyai".into(), message: e.to_string() }
78        })?;
79        let audio_url = upload_json["upload_url"].as_str().ok_or_else(|| AudioError::Stt {
80            provider: "assemblyai".into(),
81            message: "no upload_url in response".into(),
82        })?;
83
84        // Step 2: Create transcription job
85        let create_url = format!("{}/v2/transcript", self.base_url);
86        let mut body = serde_json::json!({
87            "audio_url": audio_url,
88            "language_detection": true,
89        });
90        if opts.diarize {
91            body["speaker_labels"] = serde_json::json!(true);
92        }
93        if let Some(ref lang) = opts.language {
94            body["language_code"] = serde_json::json!(lang);
95            body["language_detection"] = serde_json::json!(false);
96        }
97
98        let create_resp = self
99            .client
100            .post(&create_url)
101            .header("authorization", &self.api_key)
102            .json(&body)
103            .send()
104            .await
105            .map_err(|e| AudioError::Stt {
106                provider: "assemblyai".into(),
107                message: e.to_string(),
108            })?;
109
110        if !create_resp.status().is_success() {
111            return Err(AudioError::Stt {
112                provider: "assemblyai".into(),
113                message: format!("create HTTP {}", create_resp.status()),
114            });
115        }
116
117        let create_json: serde_json::Value = create_resp.json().await.map_err(|e| {
118            AudioError::Stt { provider: "assemblyai".into(), message: e.to_string() }
119        })?;
120        let transcript_id = create_json["id"].as_str().ok_or_else(|| AudioError::Stt {
121            provider: "assemblyai".into(),
122            message: "no id in response".into(),
123        })?;
124
125        // Step 3: Poll for completion
126        let poll_url = format!("{}/v2/transcript/{transcript_id}", self.base_url);
127        loop {
128            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
129
130            let poll_resp = self
131                .client
132                .get(&poll_url)
133                .header("authorization", &self.api_key)
134                .send()
135                .await
136                .map_err(|e| AudioError::Stt {
137                    provider: "assemblyai".into(),
138                    message: e.to_string(),
139                })?;
140
141            let poll_json: serde_json::Value = poll_resp.json().await.map_err(|e| {
142                AudioError::Stt { provider: "assemblyai".into(), message: e.to_string() }
143            })?;
144
145            let status = poll_json["status"].as_str().unwrap_or("unknown");
146            match status {
147                "completed" => {
148                    return parse_assemblyai_response(&poll_json);
149                }
150                "error" => {
151                    let error_msg = poll_json["error"].as_str().unwrap_or("unknown error");
152                    return Err(AudioError::Stt {
153                        provider: "assemblyai".into(),
154                        message: error_msg.to_string(),
155                    });
156                }
157                _ => continue,
158            }
159        }
160    }
161
162    async fn transcribe_stream(
163        &self,
164        _audio: Pin<Box<dyn Stream<Item = AudioFrame> + Send>>,
165        _opts: &SttOptions,
166    ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<Transcript>> + Send>>> {
167        Err(AudioError::Stt {
168            provider: "assemblyai".into(),
169            message: "streaming transcription not yet implemented".into(),
170        })
171    }
172}
173
174fn parse_assemblyai_response(json: &serde_json::Value) -> AudioResult<Transcript> {
175    let text = json["text"].as_str().unwrap_or_default().to_string();
176    let confidence = json["confidence"].as_f64().unwrap_or(0.0) as f32;
177    let language_detected = json["language_code"].as_str().map(String::from);
178
179    let words: Vec<Word> = json["words"]
180        .as_array()
181        .map(|arr| {
182            arr.iter()
183                .map(|w| Word {
184                    text: w["text"].as_str().unwrap_or_default().to_string(),
185                    start_ms: w["start"].as_u64().unwrap_or(0) as u32,
186                    end_ms: w["end"].as_u64().unwrap_or(0) as u32,
187                    confidence: w["confidence"].as_f64().unwrap_or(0.0) as f32,
188                    speaker: w["speaker"]
189                        .as_str()
190                        .and_then(|s| s.strip_prefix("speaker_").and_then(|n| n.parse().ok())),
191                })
192                .collect()
193        })
194        .unwrap_or_default();
195
196    let mut speaker_ids: Vec<u32> = words.iter().filter_map(|w| w.speaker).collect();
197    speaker_ids.sort();
198    speaker_ids.dedup();
199    let speakers: Vec<Speaker> =
200        speaker_ids.into_iter().map(|id| Speaker { id, label: None }).collect();
201
202    Ok(Transcript { text, words, speakers, confidence, language_detected })
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Streaming is unsupported, and callers must get a provider-tagged
    /// "not yet implemented" error rather than a panic or a silent success.
    #[tokio::test]
    async fn transcribe_stream_returns_explicit_unimplemented_error() {
        let stt = AssemblyAiStt::with_api_key("test-key".to_string());
        let no_audio = Box::pin(futures::stream::empty::<AudioFrame>());
        let options = SttOptions::default();

        let outcome = stt.transcribe_stream(no_audio, &options).await;

        match outcome {
            Ok(_) => panic!("expected explicit STT error"),
            Err(AudioError::Stt { provider, message }) => {
                assert_eq!(provider, "assemblyai");
                assert!(message.contains("not yet implemented"));
            }
            Err(err) => panic!("unexpected audio error: {err}"),
        }
    }
}