adk_audio/providers/stt/
gemini.rs1use std::pin::Pin;
11
12use async_trait::async_trait;
13use futures::Stream;
14
15use crate::error::{AudioError, AudioResult};
16use crate::frame::AudioFrame;
17use crate::providers::stt::frame_to_wav_bytes;
18use crate::traits::{SttOptions, SttProvider, Transcript};

/// Model id used unless overridden with `GeminiStt::with_model`.
const DEFAULT_MODEL: &str = "gemini-3-flash-preview";

23pub struct GeminiStt {
39 api_key: String,
40 client: reqwest::Client,
41 model: String,
42 prompt: String,
44}
45
46impl GeminiStt {
47 pub fn from_env() -> AudioResult<Self> {
49 let api_key = std::env::var("GEMINI_API_KEY")
50 .or_else(|_| std::env::var("GOOGLE_API_KEY"))
51 .map_err(|_| AudioError::Stt {
52 provider: "gemini".into(),
53 message: "GEMINI_API_KEY or GOOGLE_API_KEY not set".into(),
54 })?;
55 Ok(Self {
56 api_key,
57 client: reqwest::Client::new(),
58 model: DEFAULT_MODEL.into(),
59 prompt: "Transcribe this audio accurately. Return only the transcription text, no commentary.".into(),
60 })
61 }
62
63 pub fn new(api_key: impl Into<String>) -> Self {
65 Self {
66 api_key: api_key.into(),
67 client: reqwest::Client::new(),
68 model: DEFAULT_MODEL.into(),
69 prompt: "Transcribe this audio accurately. Return only the transcription text, no commentary.".into(),
70 }
71 }
72
73 pub fn with_model(mut self, model: impl Into<String>) -> Self {
75 self.model = model.into();
76 self
77 }
78
79 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
84 self.prompt = prompt.into();
85 self
86 }
87
88 fn url(&self) -> String {
89 format!(
90 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
91 self.model
92 )
93 }
94}
95
96#[async_trait]
97impl SttProvider for GeminiStt {
98 async fn transcribe(&self, audio: &AudioFrame, opts: &SttOptions) -> AudioResult<Transcript> {
99 let wav_bytes = frame_to_wav_bytes(audio)?;
100
101 use base64::Engine;
102 let audio_b64 = base64::engine::general_purpose::STANDARD.encode(&wav_bytes);
103
104 let prompt = if let Some(ref lang) = opts.language {
106 format!("{} The audio is in {lang}.", self.prompt)
107 } else {
108 self.prompt.clone()
109 };
110
111 let body = serde_json::json!({
112 "contents": [{
113 "parts": [
114 {"text": prompt},
115 {
116 "inlineData": {
117 "mimeType": "audio/wav",
118 "data": audio_b64
119 }
120 }
121 ]
122 }]
123 });
124
125 let resp = self
126 .client
127 .post(self.url())
128 .header("x-goog-api-key", &self.api_key)
129 .json(&body)
130 .send()
131 .await
132 .map_err(|e| AudioError::Stt { provider: "gemini".into(), message: e.to_string() })?;
133
134 if !resp.status().is_success() {
135 let status = resp.status();
136 let body = resp.text().await.unwrap_or_default();
137 return Err(AudioError::Stt {
138 provider: "gemini".into(),
139 message: format!("HTTP {status}: {body}"),
140 });
141 }
142
143 let json: serde_json::Value = resp
144 .json()
145 .await
146 .map_err(|e| AudioError::Stt { provider: "gemini".into(), message: e.to_string() })?;
147
148 let text = json["candidates"][0]["content"]["parts"][0]["text"]
149 .as_str()
150 .unwrap_or_default()
151 .trim()
152 .to_string();
153
154 Ok(Transcript {
155 text,
156 words: vec![],
157 speakers: vec![],
158 confidence: 1.0,
159 language_detected: opts.language.clone(),
160 })
161 }
162
163 async fn transcribe_stream(
164 &self,
165 _audio: Pin<Box<dyn Stream<Item = AudioFrame> + Send>>,
166 _opts: &SttOptions,
167 ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<Transcript>> + Send>>> {
168 Ok(Box::pin(futures::stream::empty()))
171 }
172}