// adk_audio/providers/tts/openai.rs
1//! OpenAI TTS provider.
2
3use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::Stream;
7
8use crate::error::{AudioError, AudioResult};
9use crate::frame::AudioFrame;
10use crate::providers::tts::CloudTtsConfig;
11use crate::traits::{TtsProvider, TtsRequest, Voice};
12
13/// OpenAI TTS provider using the `/v1/audio/speech` endpoint.
/// OpenAI TTS provider using the `/v1/audio/speech` endpoint.
pub struct OpenAiTts {
    /// API key plus optional base-URL override (see `base_url()`).
    config: CloudTtsConfig,
    /// Reused HTTP client so connections are pooled across requests.
    client: reqwest::Client,
    /// Model name sent in the request body: "tts-1" by default, "tts-1-hd" after `hd()`.
    model: String,
    /// Static catalog of the six built-in OpenAI voices, served by `voice_catalog()`.
    voices: Vec<Voice>,
}
20
21impl OpenAiTts {
22    /// Create from environment variable `OPENAI_API_KEY`.
23    pub fn from_env() -> AudioResult<Self> {
24        let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| AudioError::Tts {
25            provider: "openai".into(),
26            message: "OPENAI_API_KEY not set".into(),
27        })?;
28        Ok(Self::new(CloudTtsConfig::new(api_key)))
29    }
30
31    /// Create with explicit config.
32    pub fn new(config: CloudTtsConfig) -> Self {
33        let voices = vec![
34            Voice { id: "alloy".into(), name: "Alloy".into(), language: "en".into(), gender: None },
35            Voice {
36                id: "echo".into(),
37                name: "Echo".into(),
38                language: "en".into(),
39                gender: Some("male".into()),
40            },
41            Voice { id: "fable".into(), name: "Fable".into(), language: "en".into(), gender: None },
42            Voice {
43                id: "onyx".into(),
44                name: "Onyx".into(),
45                language: "en".into(),
46                gender: Some("male".into()),
47            },
48            Voice {
49                id: "nova".into(),
50                name: "Nova".into(),
51                language: "en".into(),
52                gender: Some("female".into()),
53            },
54            Voice {
55                id: "shimmer".into(),
56                name: "Shimmer".into(),
57                language: "en".into(),
58                gender: Some("female".into()),
59            },
60        ];
61        Self { config, client: reqwest::Client::new(), model: "tts-1".into(), voices }
62    }
63
64    /// Use the HD model (tts-1-hd) for higher quality.
65    pub fn hd(mut self) -> Self {
66        self.model = "tts-1-hd".into();
67        self
68    }
69
70    fn base_url(&self) -> &str {
71        self.config.base_url.as_deref().unwrap_or("https://api.openai.com")
72    }
73}
74
75#[async_trait]
76impl TtsProvider for OpenAiTts {
77    async fn synthesize(&self, request: &TtsRequest) -> AudioResult<AudioFrame> {
78        let voice = if request.voice.is_empty() { "alloy" } else { &request.voice };
79        let url = format!("{}/v1/audio/speech", self.base_url());
80
81        let body = serde_json::json!({
82            "model": self.model,
83            "input": request.text,
84            "voice": voice,
85            "response_format": "pcm",
86            "speed": request.speed,
87        });
88
89        let resp = self
90            .client
91            .post(&url)
92            .bearer_auth(&self.config.api_key)
93            .json(&body)
94            .send()
95            .await
96            .map_err(|e| AudioError::Tts { provider: "openai".into(), message: e.to_string() })?;
97
98        if !resp.status().is_success() {
99            return Err(AudioError::Tts {
100                provider: "openai".into(),
101                message: format!("HTTP {}", resp.status()),
102            });
103        }
104
105        let pcm = resp
106            .bytes()
107            .await
108            .map_err(|e| AudioError::Tts { provider: "openai".into(), message: e.to_string() })?;
109
110        Ok(AudioFrame::new(pcm, 24000, 1))
111    }
112
113    async fn synthesize_stream(
114        &self,
115        request: &TtsRequest,
116    ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<AudioFrame>> + Send>>> {
117        // OpenAI TTS doesn't have a native streaming endpoint for PCM,
118        // so we fetch the full response and yield it as a single frame.
119        let frame = self.synthesize(request).await?;
120        Ok(Box::pin(futures::stream::once(async { Ok(frame) })))
121    }
122
123    fn voice_catalog(&self) -> &[Voice] {
124        &self.voices
125    }
126}