Skip to main content

adk_audio/providers/stt/
gemini.rs

1//! Gemini audio understanding STT provider.
2//!
3//! Uses the `generateContent` API with audio input to transcribe speech.
4//! Gemini models can process audio inline and return text transcriptions.
5//!
6//! This is a batch-mode provider — audio is sent as a single request and
7//! the full transcript is returned. For real-time streaming transcription,
8//! use the Gemini Live API via `adk-realtime`.
9
10use std::pin::Pin;
11
12use async_trait::async_trait;
13use futures::Stream;
14
15use crate::error::{AudioError, AudioResult};
16use crate::frame::AudioFrame;
17use crate::providers::stt::frame_to_wav_bytes;
18use crate::traits::{SttOptions, SttProvider, Transcript};
19
/// Default model for Gemini STT (audio understanding).
///
/// Used by [`GeminiStt::new`] and [`GeminiStt::from_env`]; callers can replace it
/// with [`GeminiStt::with_model`]. The value is interpolated into the request URL.
const DEFAULT_MODEL: &str = "gemini-3-flash-preview";
22
23/// Gemini STT provider using `generateContent` with audio input.
24///
25/// Sends audio as inline data to the Gemini API and receives a text
26/// transcription. Supports language detection and optional prompting
27/// for specialized transcription tasks.
28///
29/// # Example
30///
31/// ```rust,ignore
32/// use adk_audio::{GeminiStt, SttProvider, SttOptions, AudioFrame};
33///
34/// let stt = GeminiStt::from_env()?;
35/// let transcript = stt.transcribe(&audio_frame, &SttOptions::default()).await?;
36/// println!("Transcribed: {}", transcript.text);
37/// ```
38pub struct GeminiStt {
39    api_key: String,
40    client: reqwest::Client,
41    model: String,
42    /// Optional custom prompt for transcription.
43    prompt: String,
44}
45
46impl GeminiStt {
47    /// Create from environment variable `GEMINI_API_KEY` or `GOOGLE_API_KEY`.
48    pub fn from_env() -> AudioResult<Self> {
49        let api_key = std::env::var("GEMINI_API_KEY")
50            .or_else(|_| std::env::var("GOOGLE_API_KEY"))
51            .map_err(|_| AudioError::Stt {
52                provider: "gemini".into(),
53                message: "GEMINI_API_KEY or GOOGLE_API_KEY not set".into(),
54            })?;
55        Ok(Self {
56            api_key,
57            client: reqwest::Client::new(),
58            model: DEFAULT_MODEL.into(),
59            prompt: "Transcribe this audio accurately. Return only the transcription text, no commentary.".into(),
60        })
61    }
62
63    /// Create with an explicit API key.
64    pub fn new(api_key: impl Into<String>) -> Self {
65        Self {
66            api_key: api_key.into(),
67            client: reqwest::Client::new(),
68            model: DEFAULT_MODEL.into(),
69            prompt: "Transcribe this audio accurately. Return only the transcription text, no commentary.".into(),
70        }
71    }
72
73    /// Set the model to use for transcription.
74    pub fn with_model(mut self, model: impl Into<String>) -> Self {
75        self.model = model.into();
76        self
77    }
78
79    /// Set a custom transcription prompt.
80    ///
81    /// The prompt is sent alongside the audio to guide the model's output.
82    /// For example: "Transcribe this audio in English with punctuation."
83    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
84        self.prompt = prompt.into();
85        self
86    }
87
88    fn url(&self) -> String {
89        format!(
90            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
91            self.model
92        )
93    }
94}
95
96#[async_trait]
97impl SttProvider for GeminiStt {
98    async fn transcribe(&self, audio: &AudioFrame, opts: &SttOptions) -> AudioResult<Transcript> {
99        let wav_bytes = frame_to_wav_bytes(audio)?;
100
101        use base64::Engine;
102        let audio_b64 = base64::engine::general_purpose::STANDARD.encode(&wav_bytes);
103
104        // Build the prompt with optional language hint
105        let prompt = if let Some(ref lang) = opts.language {
106            format!("{} The audio is in {lang}.", self.prompt)
107        } else {
108            self.prompt.clone()
109        };
110
111        let body = serde_json::json!({
112            "contents": [{
113                "parts": [
114                    {"text": prompt},
115                    {
116                        "inlineData": {
117                            "mimeType": "audio/wav",
118                            "data": audio_b64
119                        }
120                    }
121                ]
122            }]
123        });
124
125        let resp = self
126            .client
127            .post(self.url())
128            .header("x-goog-api-key", &self.api_key)
129            .json(&body)
130            .send()
131            .await
132            .map_err(|e| AudioError::Stt { provider: "gemini".into(), message: e.to_string() })?;
133
134        if !resp.status().is_success() {
135            let status = resp.status();
136            let body = resp.text().await.unwrap_or_default();
137            return Err(AudioError::Stt {
138                provider: "gemini".into(),
139                message: format!("HTTP {status}: {body}"),
140            });
141        }
142
143        let json: serde_json::Value = resp
144            .json()
145            .await
146            .map_err(|e| AudioError::Stt { provider: "gemini".into(), message: e.to_string() })?;
147
148        let text = json["candidates"][0]["content"]["parts"][0]["text"]
149            .as_str()
150            .unwrap_or_default()
151            .trim()
152            .to_string();
153
154        Ok(Transcript {
155            text,
156            words: vec![],
157            speakers: vec![],
158            confidence: 1.0,
159            language_detected: opts.language.clone(),
160        })
161    }
162
163    async fn transcribe_stream(
164        &self,
165        _audio: Pin<Box<dyn Stream<Item = AudioFrame> + Send>>,
166        _opts: &SttOptions,
167    ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<Transcript>> + Send>>> {
168        // Gemini generateContent doesn't support streaming STT.
169        // For real-time streaming, use adk-realtime with Gemini Live.
170        Ok(Box::pin(futures::stream::empty()))
171    }
172}