gemini-tts-cli 0.1.1

Agent-friendly Gemini text-to-speech CLI for expressive scripts, voices, tags, and audio files
use base64::Engine;
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;

use crate::error::AppError;
use crate::prompt::SpeakerVoice;

/// Parameters for a single text-to-speech `generateContent` call.
#[derive(Debug, Clone)]
pub struct GenerateRequest {
    /// Model id interpolated into the `generateContent` request URL.
    pub model: String,
    /// Text sent to the API as the single content part.
    pub prompt: String,
    /// Prebuilt voice name; only used when `speakers` is empty.
    pub voice: String,
    /// Per-speaker voice assignments. When non-empty, a multi-speaker voice
    /// config is sent instead of `voice`.
    pub speakers: Vec<SpeakerVoice>,
    /// HTTP client timeout in seconds; `generate` raises values below 10 to 10.
    pub timeout_seconds: u64,
}

/// Audio produced by a successful `generate` call, plus request metadata.
// NOTE(review): the `Serialize` derive emits `pcm` as a JSON array of numbers;
// confirm callers never serialize this struct with the audio attached.
#[derive(Debug, Clone, Serialize)]
pub struct GenerateResponse {
    /// Audio bytes decoded from the response's base64 inline `data` field.
    pub pcm: Vec<u8>,
    /// MIME type reported by the API, or `audio/l16; rate=24000; channels=1`
    /// when the response omits it.
    pub mime_type: String,
    /// Model id copied from the request.
    pub model: String,
    /// Length of the prompt counted in `char`s (Unicode scalar values), not bytes.
    pub prompt_chars: usize,
}

pub fn generate(api_key: &str, request: &GenerateRequest) -> Result<GenerateResponse, AppError> {
    let client = Client::builder()
        .timeout(Duration::from_secs(request.timeout_seconds.max(10)))
        .build()?;

    let mut attempt = 0;
    loop {
        match generate_once(&client, api_key, request) {
            Ok(response) => return Ok(response),
            Err(error) if should_retry(&error) && attempt < 2 => {
                attempt += 1;
                std::thread::sleep(retry_delay(&error, attempt));
            }
            Err(error) => return Err(error),
        }
    }
}

/// Sends one `generateContent` request and decodes the audio it returns.
///
/// The body comes from `request_payload`, and the API key is sent in the
/// `x-goog-api-key` header. On success, the base64 audio located by
/// `extract_audio` is decoded into raw bytes. A missing MIME type falls back
/// to `audio/l16; rate=24000; channels=1`.
///
/// # Errors
/// - HTTP 429 → `AppError::RateLimited` carrying the API's error message.
/// - HTTP 400, 401, 403 or 404 → `AppError::Config`. A malformed request, a bad
///   or unauthorized key, or an unknown model will not succeed on retry.
/// - Any other non-success status, a body that is not JSON, missing audio, or
///   invalid base64 → `AppError::Transient`.
/// - Transport and body-read failures propagate through `?`.
fn generate_once(
    client: &Client,
    api_key: &str,
    request: &GenerateRequest,
) -> Result<GenerateResponse, AppError> {
    // The API's resource name is `models/{id}`. Accept that form as well as a
    // bare id, instead of building `.../models/models/{id}` (which 404s).
    let model_id = request
        .model
        .strip_prefix("models/")
        .unwrap_or(request.model.as_str());
    let url = format!(
        "https://generativelanguage.googleapis.com/v1beta/models/{model_id}:generateContent"
    );
    let payload = request_payload(request);
    let response = client
        .post(url)
        .header("x-goog-api-key", api_key)
        .header("Content-Type", "application/json")
        .json(&payload)
        .send()?;

    let status = response.status();
    let text = response.text()?;
    if !status.is_success() {
        // Prefer the structured `error.message`; fall back to the raw body.
        let message = extract_error_message(&text).unwrap_or_else(|| text.trim().to_string());
        return Err(match status.as_u16() {
            429 => AppError::RateLimited(message),
            // Client-side problems: retrying the identical request cannot help.
            400 | 401 | 403 | 404 => AppError::Config(format!(
                "Gemini API rejected the request ({status}): {message}"
            )),
            _ => AppError::Transient(format!("Gemini API returned {status}: {message}")),
        });
    }

    let json: Value = serde_json::from_str(&text)
        .map_err(|e| AppError::Transient(format!("Gemini returned invalid JSON: {e}")))?;
    let (data, mime_type) = extract_audio(&json)?;
    let pcm = base64::engine::general_purpose::STANDARD
        .decode(data)
        .map_err(|e| AppError::Transient(format!("Gemini returned invalid base64 audio: {e}")))?;

    Ok(GenerateResponse {
        pcm,
        mime_type: mime_type.unwrap_or_else(|| "audio/l16; rate=24000; channels=1".into()),
        model: request.model.clone(),
        prompt_chars: request.prompt.chars().count(),
    })
}

fn should_retry(error: &AppError) -> bool {
    matches!(
        error,
        AppError::Transient(_) | AppError::RateLimited(_) | AppError::Http(_)
    )
}

/// Back-off to wait before retry number `attempt` (1-based).
///
/// For a rate-limit error, the server's "Please retry in …" hint is used when
/// present; otherwise the wait is 2 s per attempt. Every other error waits
/// 400 ms per attempt.
fn retry_delay(error: &AppError, attempt: u32) -> Duration {
    let step = u64::from(attempt);
    match error {
        AppError::RateLimited(message) => {
            parse_retry_delay(message).unwrap_or_else(|| Duration::from_secs(2 * step))
        }
        _ => Duration::from_millis(400 * step),
    }
}

/// Extracts the server-suggested wait from a Google quota error message.
///
/// Quota errors embed a hint such as `"Please retry in 2.830310597s."`. The
/// run of digits and dots following the first `Please retry in ` is parsed as
/// seconds, rounded up to whole milliseconds, and padded with 500 ms so the
/// retry lands after the quota window reopens.
///
/// Returns `None` when the hint is absent or its number cannot be parsed.
/// Absurdly large hints saturate at `u64::MAX` milliseconds instead of
/// overflowing.
fn parse_retry_delay(message: &str) -> Option<Duration> {
    let (_, rest) = message.split_once("Please retry in ")?;
    let number: String = rest
        .chars()
        .take_while(|ch| ch.is_ascii_digit() || *ch == '.')
        .collect();
    let seconds = number.parse::<f64>().ok()?;
    if !seconds.is_finite() || seconds < 0.0 {
        return None;
    }
    // A float-to-int `as` cast saturates at u64::MAX. Without `saturating_add`,
    // the 500 ms padding could then overflow: a panic in debug builds, and a
    // wrap to a near-zero delay in release builds.
    let millis = ((seconds * 1000.0).ceil() as u64).saturating_add(500);
    Some(Duration::from_millis(millis))
}

/// Builds the JSON body for a text-to-speech `generateContent` call.
///
/// The prompt is sent as one text part, and the response is limited to the
/// `AUDIO` modality. With no `speakers`, the speech config names one prebuilt
/// voice (`request.voice`). Otherwise it holds a `multiSpeakerVoiceConfig` with
/// one entry per speaker, in order, and `request.voice` is not included.
pub fn request_payload(request: &GenerateRequest) -> Value {
    // `{"prebuiltVoiceConfig": {"voiceName": ...}}`, shared by both modes.
    let prebuilt = |voice_name: Value| {
        serde_json::json!({
            "prebuiltVoiceConfig": {
                "voiceName": voice_name
            }
        })
    };

    let speech_config = match request.speakers.as_slice() {
        [] => serde_json::json!({
            "voiceConfig": prebuilt(serde_json::json!(request.voice))
        }),
        speakers => {
            let mut speaker_entries = Vec::with_capacity(speakers.len());
            for entry in speakers {
                speaker_entries.push(serde_json::json!({
                    "speaker": entry.speaker,
                    "voiceConfig": prebuilt(serde_json::json!(entry.voice))
                }));
            }
            serde_json::json!({
                "multiSpeakerVoiceConfig": {
                    "speakerVoiceConfigs": speaker_entries
                }
            })
        }
    };

    let generation_config = serde_json::json!({
        "responseModalities": ["AUDIO"],
        "speechConfig": speech_config
    });

    serde_json::json!({
        "contents": [{
            "parts": [{
                "text": request.prompt
            }]
        }],
        "generationConfig": generation_config
    })
}

/// Finds the first inline audio payload in a `generateContent` response.
///
/// Walks `candidates[*].content.parts[*]` and returns the first part that has
/// `inlineData` (or `inline_data`). The result is the borrowed base64 `data`
/// string plus the optional `mimeType`/`mime_type`.
///
/// # Errors
/// Returns `AppError::Transient` in three cases:
/// - the response has no `candidates` array;
/// - the first inline part lacks a string `data` field;
/// - no part carries inline data at all.
///
/// When the API reports `promptFeedback.blockReason` or a candidate
/// `finishReason`, the reason is appended to the message. A blocked or cut-off
/// generation can then be told apart from a malformed response.
fn extract_audio(json: &Value) -> Result<(&str, Option<String>), AppError> {
    let Some(candidates) = json.get("candidates").and_then(Value::as_array) else {
        // A prompt rejected before generation comes back without candidates;
        // include the block reason when one is present.
        let message = match json
            .pointer("/promptFeedback/blockReason")
            .and_then(Value::as_str)
        {
            Some(reason) => {
                format!("Gemini response had no candidates array (prompt blocked: {reason})")
            }
            None => "Gemini response had no candidates array".to_string(),
        };
        return Err(AppError::Transient(message));
    };

    for candidate in candidates {
        let Some(parts) = candidate
            .pointer("/content/parts")
            .and_then(Value::as_array)
        else {
            continue;
        };
        for part in parts {
            // Accept both the camelCase and snake_case field spellings.
            let Some(inline) = part.get("inlineData").or_else(|| part.get("inline_data")) else {
                continue;
            };
            let data = inline
                .get("data")
                .and_then(Value::as_str)
                .ok_or_else(|| AppError::Transient("Gemini inline audio missing data".into()))?;
            let mime = inline
                .get("mimeType")
                .or_else(|| inline.get("mime_type"))
                .and_then(Value::as_str)
                .map(str::to_string);
            return Ok((data, mime));
        }
    }

    // No audio anywhere: say why generation stopped if the candidates report it.
    let finish_reasons: Vec<&str> = candidates
        .iter()
        .filter_map(|candidate| candidate.get("finishReason").and_then(Value::as_str))
        .collect();
    let message = if finish_reasons.is_empty() {
        "Gemini response did not contain inline audio data".to_string()
    } else {
        format!(
            "Gemini response did not contain inline audio data (finishReason: {})",
            finish_reasons.join(", ")
        )
    };
    Err(AppError::Transient(message))
}

/// Pulls `error.message` out of a Google API JSON error body.
///
/// Returns `None` when the body is not valid JSON or has no string at that path.
fn extract_error_message(text: &str) -> Option<String> {
    let body = serde_json::from_str::<Value>(text).ok()?;
    let message = body.pointer("/error/message")?.as_str()?;
    Some(message.to_owned())
}

/// The part of the `GET /v1beta/models/{model}` JSON that `check_model` reads.
/// Serde ignores all other fields.
#[derive(Debug, Clone, Deserialize)]
struct ModelResponse {
    /// Resource name reported by the API, if present.
    name: Option<String>,
}

/// Verifies that `model` exists and is visible to `api_key`.
///
/// Issues `GET /v1beta/models/{model}` with a client timeout of
/// `timeout_seconds` (at least 10 s). Returns the API's `name` for the model,
/// or `model` itself when the response omits it.
///
/// # Errors
/// - HTTP 429 → `AppError::RateLimited`.
/// - HTTP 5xx → `AppError::Transient` (a server outage, not a bad model name).
/// - Any other non-success status → `AppError::Config`.
/// - A body that is not JSON → `AppError::Transient`.
/// - Client-build and transport failures propagate through `?`.
pub fn check_model(api_key: &str, model: &str, timeout_seconds: u64) -> Result<String, AppError> {
    let client = Client::builder()
        .timeout(Duration::from_secs(timeout_seconds.max(10)))
        .build()?;
    // Accept both `gemini-...` and the canonical `models/gemini-...` form.
    let model_id = model.strip_prefix("models/").unwrap_or(model);
    let url = format!("https://generativelanguage.googleapis.com/v1beta/models/{model_id}");
    let response = client.get(url).header("x-goog-api-key", api_key).send()?;
    let status = response.status();
    let text = response.text()?;
    if !status.is_success() {
        let message = extract_error_message(&text).unwrap_or_else(|| text.trim().to_string());
        if status.as_u16() == 429 {
            return Err(AppError::RateLimited(message));
        }
        let detail = format!("model check failed ({status}): {message}");
        // Server-side failures say nothing about whether the model is valid.
        if status.is_server_error() {
            return Err(AppError::Transient(detail));
        }
        return Err(AppError::Config(detail));
    }
    let parsed: ModelResponse = serde_json::from_str(&text)
        .map_err(|e| AppError::Transient(format!("model check returned invalid JSON: {e}")))?;
    Ok(parsed.name.unwrap_or_else(|| model.into()))
}

#[cfg(test)]
mod tests {
    use super::parse_retry_delay;

    #[test]
    fn parses_google_retry_hint() {
        let parsed = parse_retry_delay("Quota exceeded. Please retry in 2.830310597s.");
        // 2.830310597 s rounds up to 2831 ms, plus the 500 ms safety margin.
        assert!(matches!(parsed, Some(delay) if delay.as_millis() >= 3_330));
    }

    #[test]
    fn missing_retry_hint_returns_none() {
        assert_eq!(parse_retry_delay("Quota exceeded"), None);
    }
}