use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use crate::error::{AudioError, AudioResult};
use crate::frame::AudioFrame;
use crate::providers::stt::frame_to_wav_bytes;
use crate::traits::{SttOptions, SttProvider, Transcript};
/// Gemini model name that the `GeminiStt` constructors use until it is
/// replaced through `GeminiStt::with_model`.
const DEFAULT_MODEL: &str = "gemini-3-flash-preview";
/// Speech-to-text provider that transcribes audio by calling the Gemini
/// `generateContent` REST endpoint.
///
/// Build one with `GeminiStt::new` or `GeminiStt::from_env`, then optionally
/// adjust it with the `with_model` / `with_prompt` builder methods.
pub struct GeminiStt {
/// API key sent in the `x-goog-api-key` request header.
api_key: String,
/// HTTP client used to issue the transcription requests.
client: reqwest::Client,
/// Model name interpolated into the endpoint URL.
model: String,
/// Instruction text sent to the model together with the audio.
prompt: String,
}
impl GeminiStt {
    /// Instruction sent with every request unless replaced via
    /// [`GeminiStt::with_prompt`].
    const DEFAULT_PROMPT: &'static str =
        "Transcribe this audio accurately. Return only the transcription text, no commentary.";

    /// Environment variables consulted by [`GeminiStt::from_env`], in
    /// priority order.
    const API_KEY_VARS: [&'static str; 2] = ["GEMINI_API_KEY", "GOOGLE_API_KEY"];

    /// Builds a provider whose API key is read from the environment.
    ///
    /// `GEMINI_API_KEY` is tried first, then `GOOGLE_API_KEY`. A variable that
    /// is unset, not valid Unicode, or blank is skipped, so an empty
    /// `GEMINI_API_KEY` no longer hides a usable `GOOGLE_API_KEY`.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::Stt`] when neither variable holds a non-blank
    /// value.
    pub fn from_env() -> AudioResult<Self> {
        let api_key = Self::API_KEY_VARS
            .iter()
            .filter_map(|name| std::env::var(name).ok())
            .find(|value| !value.trim().is_empty())
            .ok_or_else(|| AudioError::Stt {
                provider: "gemini".into(),
                message: "GEMINI_API_KEY or GOOGLE_API_KEY not set".into(),
            })?;
        // Delegate so both constructors share a single set of defaults.
        Ok(Self::new(api_key))
    }

    /// Builds a provider from an explicit API key, using the default model
    /// and the default transcription prompt.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            client: reqwest::Client::new(),
            model: DEFAULT_MODEL.into(),
            prompt: Self::DEFAULT_PROMPT.into(),
        }
    }

    /// Replaces the model name used in the endpoint URL and returns the
    /// updated provider.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Replaces the instruction text sent alongside the audio and returns the
    /// updated provider.
    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.prompt = prompt.into();
        self
    }

    /// Returns the `generateContent` endpoint URL for the configured model.
    fn url(&self) -> String {
        format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
            self.model
        )
    }
}
#[async_trait]
impl SttProvider for GeminiStt {
async fn transcribe(&self, audio: &AudioFrame, opts: &SttOptions) -> AudioResult<Transcript> {
let wav_bytes = frame_to_wav_bytes(audio)?;
use base64::Engine;
let audio_b64 = base64::engine::general_purpose::STANDARD.encode(&wav_bytes);
let prompt = if let Some(ref lang) = opts.language {
format!("{} The audio is in {lang}.", self.prompt)
} else {
self.prompt.clone()
};
let body = serde_json::json!({
"contents": [{
"parts": [
{"text": prompt},
{
"inlineData": {
"mimeType": "audio/wav",
"data": audio_b64
}
}
]
}]
});
let resp = self
.client
.post(self.url())
.header("x-goog-api-key", &self.api_key)
.json(&body)
.send()
.await
.map_err(|e| AudioError::Stt { provider: "gemini".into(), message: e.to_string() })?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AudioError::Stt {
provider: "gemini".into(),
message: format!("HTTP {status}: {body}"),
});
}
let json: serde_json::Value = resp
.json()
.await
.map_err(|e| AudioError::Stt { provider: "gemini".into(), message: e.to_string() })?;
let text = json["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.unwrap_or_default()
.trim()
.to_string();
Ok(Transcript {
text,
words: vec![],
speakers: vec![],
confidence: 1.0,
language_detected: opts.language.clone(),
})
}
async fn transcribe_stream(
&self,
_audio: Pin<Box<dyn Stream<Item = AudioFrame> + Send>>,
_opts: &SttOptions,
) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<Transcript>> + Send>>> {
Ok(Box::pin(futures::stream::empty()))
}
}