adk_audio/providers/stt/
assemblyai.rs1use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::Stream;
7
8use crate::error::{AudioError, AudioResult};
9use crate::frame::AudioFrame;
10use crate::providers::stt::frame_to_wav_bytes;
11use crate::traits::{Speaker, SttOptions, SttProvider, Transcript, Word};
12
13pub struct AssemblyAiStt {
18 api_key: String,
19 client: reqwest::Client,
20 base_url: String,
21}
22
23impl AssemblyAiStt {
24 #[doc(hidden)]
26 pub fn with_api_key(api_key: String) -> Self {
27 Self {
28 api_key,
29 client: reqwest::Client::new(),
30 base_url: "https://api.assemblyai.com".into(),
31 }
32 }
33
34 pub fn from_env() -> AudioResult<Self> {
36 let api_key = std::env::var("ASSEMBLYAI_API_KEY").map_err(|_| AudioError::Stt {
37 provider: "assemblyai".into(),
38 message: "ASSEMBLYAI_API_KEY not set".into(),
39 })?;
40 Ok(Self {
41 api_key,
42 client: reqwest::Client::new(),
43 base_url: "https://api.assemblyai.com".into(),
44 })
45 }
46}
47
48#[async_trait]
49impl SttProvider for AssemblyAiStt {
50 async fn transcribe(&self, audio: &AudioFrame, opts: &SttOptions) -> AudioResult<Transcript> {
51 let wav_bytes = frame_to_wav_bytes(audio)?;
52
53 assert!(self.base_url.starts_with("https://"), "AssemblyAI requires HTTPS");
55 let upload_url = format!("{}/v2/upload", self.base_url);
56 let upload_resp = self
57 .client
58 .post(&upload_url)
59 .header("authorization", &self.api_key)
60 .header("content-type", "application/octet-stream")
61 .body(wav_bytes.to_vec())
62 .send()
63 .await
64 .map_err(|e| AudioError::Stt {
65 provider: "assemblyai".into(),
66 message: e.to_string(),
67 })?;
68
69 if !upload_resp.status().is_success() {
70 return Err(AudioError::Stt {
71 provider: "assemblyai".into(),
72 message: format!("upload HTTP {}", upload_resp.status()),
73 });
74 }
75
76 let upload_json: serde_json::Value = upload_resp.json().await.map_err(|e| {
77 AudioError::Stt { provider: "assemblyai".into(), message: e.to_string() }
78 })?;
79 let audio_url = upload_json["upload_url"].as_str().ok_or_else(|| AudioError::Stt {
80 provider: "assemblyai".into(),
81 message: "no upload_url in response".into(),
82 })?;
83
84 let create_url = format!("{}/v2/transcript", self.base_url);
86 let mut body = serde_json::json!({
87 "audio_url": audio_url,
88 "language_detection": true,
89 });
90 if opts.diarize {
91 body["speaker_labels"] = serde_json::json!(true);
92 }
93 if let Some(ref lang) = opts.language {
94 body["language_code"] = serde_json::json!(lang);
95 body["language_detection"] = serde_json::json!(false);
96 }
97
98 let create_resp = self
99 .client
100 .post(&create_url)
101 .header("authorization", &self.api_key)
102 .json(&body)
103 .send()
104 .await
105 .map_err(|e| AudioError::Stt {
106 provider: "assemblyai".into(),
107 message: e.to_string(),
108 })?;
109
110 if !create_resp.status().is_success() {
111 return Err(AudioError::Stt {
112 provider: "assemblyai".into(),
113 message: format!("create HTTP {}", create_resp.status()),
114 });
115 }
116
117 let create_json: serde_json::Value = create_resp.json().await.map_err(|e| {
118 AudioError::Stt { provider: "assemblyai".into(), message: e.to_string() }
119 })?;
120 let transcript_id = create_json["id"].as_str().ok_or_else(|| AudioError::Stt {
121 provider: "assemblyai".into(),
122 message: "no id in response".into(),
123 })?;
124
125 let poll_url = format!("{}/v2/transcript/{transcript_id}", self.base_url);
127 loop {
128 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
129
130 let poll_resp = self
131 .client
132 .get(&poll_url)
133 .header("authorization", &self.api_key)
134 .send()
135 .await
136 .map_err(|e| AudioError::Stt {
137 provider: "assemblyai".into(),
138 message: e.to_string(),
139 })?;
140
141 let poll_json: serde_json::Value = poll_resp.json().await.map_err(|e| {
142 AudioError::Stt { provider: "assemblyai".into(), message: e.to_string() }
143 })?;
144
145 let status = poll_json["status"].as_str().unwrap_or("unknown");
146 match status {
147 "completed" => {
148 return parse_assemblyai_response(&poll_json);
149 }
150 "error" => {
151 let error_msg = poll_json["error"].as_str().unwrap_or("unknown error");
152 return Err(AudioError::Stt {
153 provider: "assemblyai".into(),
154 message: error_msg.to_string(),
155 });
156 }
157 _ => continue,
158 }
159 }
160 }
161
162 async fn transcribe_stream(
163 &self,
164 _audio: Pin<Box<dyn Stream<Item = AudioFrame> + Send>>,
165 _opts: &SttOptions,
166 ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<Transcript>> + Send>>> {
167 Err(AudioError::Stt {
168 provider: "assemblyai".into(),
169 message: "streaming transcription not yet implemented".into(),
170 })
171 }
172}
173
174fn parse_assemblyai_response(json: &serde_json::Value) -> AudioResult<Transcript> {
175 let text = json["text"].as_str().unwrap_or_default().to_string();
176 let confidence = json["confidence"].as_f64().unwrap_or(0.0) as f32;
177 let language_detected = json["language_code"].as_str().map(String::from);
178
179 let words: Vec<Word> = json["words"]
180 .as_array()
181 .map(|arr| {
182 arr.iter()
183 .map(|w| Word {
184 text: w["text"].as_str().unwrap_or_default().to_string(),
185 start_ms: w["start"].as_u64().unwrap_or(0) as u32,
186 end_ms: w["end"].as_u64().unwrap_or(0) as u32,
187 confidence: w["confidence"].as_f64().unwrap_or(0.0) as f32,
188 speaker: w["speaker"]
189 .as_str()
190 .and_then(|s| s.strip_prefix("speaker_").and_then(|n| n.parse().ok())),
191 })
192 .collect()
193 })
194 .unwrap_or_default();
195
196 let mut speaker_ids: Vec<u32> = words.iter().filter_map(|w| w.speaker).collect();
197 speaker_ids.sort();
198 speaker_ids.dedup();
199 let speakers: Vec<Speaker> =
200 speaker_ids.into_iter().map(|id| Speaker { id, label: None }).collect();
201
202 Ok(Transcript { text, words, speakers, confidence, language_detected })
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a parse result without requiring `Debug` on the error type.
    fn parse_ok(json: &serde_json::Value) -> Transcript {
        match parse_assemblyai_response(json) {
            Ok(transcript) => transcript,
            Err(err) => panic!("unexpected audio error: {err}"),
        }
    }

    #[tokio::test]
    async fn transcribe_stream_returns_explicit_unimplemented_error() {
        let provider = AssemblyAiStt::with_api_key("test-key".to_string());

        let result = provider
            .transcribe_stream(Box::pin(futures::stream::empty()), &SttOptions::default())
            .await;

        match result {
            Err(AudioError::Stt { provider, message }) => {
                assert_eq!(provider, "assemblyai");
                assert!(message.contains("not yet implemented"));
            }
            Err(err) => panic!("unexpected audio error: {err}"),
            Ok(_) => panic!("expected explicit STT error"),
        }
    }

    #[test]
    fn parse_reads_text_words_confidence_and_language() {
        let json = serde_json::json!({
            "status": "completed",
            "text": "hello world",
            "confidence": 0.5,
            "language_code": "en",
            "words": [
                { "text": "hello", "start": 0, "end": 400, "confidence": 0.75 },
                { "text": "world", "start": 450, "end": 900, "confidence": 0.25 }
            ]
        });

        let transcript = parse_ok(&json);

        assert_eq!(transcript.text, "hello world");
        assert_eq!(transcript.confidence, 0.5);
        assert_eq!(transcript.language_detected.as_deref(), Some("en"));
        assert_eq!(transcript.words.len(), 2);
        assert_eq!(transcript.words[1].text, "world");
        assert_eq!(transcript.words[1].start_ms, 450);
        assert_eq!(transcript.words[1].end_ms, 900);
        assert_eq!(transcript.words[1].confidence, 0.25);
        assert!(transcript.words.iter().all(|w| w.speaker.is_none()));
        assert!(transcript.speakers.is_empty());
    }

    #[test]
    fn parse_defaults_missing_fields() {
        let transcript = parse_ok(&serde_json::json!({}));

        assert!(transcript.text.is_empty());
        assert_eq!(transcript.confidence, 0.0);
        assert!(transcript.language_detected.is_none());
        assert!(transcript.words.is_empty());
        assert!(transcript.speakers.is_empty());
    }

    #[test]
    fn parse_collects_distinct_speakers_in_ascending_order() {
        let json = serde_json::json!({
            "text": "a b c",
            "words": [
                { "text": "a", "start": 0, "end": 10, "speaker": "speaker_1" },
                { "text": "b", "start": 10, "end": 20, "speaker": "speaker_0" },
                { "text": "c", "start": 20, "end": 30, "speaker": "speaker_1" }
            ]
        });

        let transcript = parse_ok(&json);

        assert_eq!(transcript.words[0].speaker, Some(1));
        assert_eq!(transcript.words[1].speaker, Some(0));
        let ids: Vec<u32> = transcript.speakers.iter().map(|s| s.id).collect();
        assert_eq!(ids, vec![0, 1]);
        assert!(transcript.speakers.iter().all(|s| s.label.is_none()));
    }
}