// adk_audio/providers/tts/openai.rs

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;

use crate::error::{AudioError, AudioResult};
use crate::frame::AudioFrame;
use crate::providers::tts::CloudTtsConfig;
use crate::traits::{TtsProvider, TtsRequest, Voice};
/// Text-to-speech provider backed by OpenAI's `/v1/audio/speech` endpoint.
///
/// Construct via [`OpenAiTts::new`] or [`OpenAiTts::from_env`]; upgrade to the
/// HD model with the [`OpenAiTts::hd`] builder method.
pub struct OpenAiTts {
    // API key plus optional base-URL override (see `base_url`).
    config: CloudTtsConfig,
    // Reused across requests so reqwest can pool connections.
    client: reqwest::Client,
    // OpenAI TTS model id: "tts-1" by default, "tts-1-hd" after `hd()`.
    model: String,
    // Static catalog returned by `voice_catalog()`.
    voices: Vec<Voice>,
}
20
21impl OpenAiTts {
22 pub fn from_env() -> AudioResult<Self> {
24 let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| AudioError::Tts {
25 provider: "openai".into(),
26 message: "OPENAI_API_KEY not set".into(),
27 })?;
28 Ok(Self::new(CloudTtsConfig::new(api_key)))
29 }
30
31 pub fn new(config: CloudTtsConfig) -> Self {
33 let voices = vec![
34 Voice { id: "alloy".into(), name: "Alloy".into(), language: "en".into(), gender: None },
35 Voice {
36 id: "echo".into(),
37 name: "Echo".into(),
38 language: "en".into(),
39 gender: Some("male".into()),
40 },
41 Voice { id: "fable".into(), name: "Fable".into(), language: "en".into(), gender: None },
42 Voice {
43 id: "onyx".into(),
44 name: "Onyx".into(),
45 language: "en".into(),
46 gender: Some("male".into()),
47 },
48 Voice {
49 id: "nova".into(),
50 name: "Nova".into(),
51 language: "en".into(),
52 gender: Some("female".into()),
53 },
54 Voice {
55 id: "shimmer".into(),
56 name: "Shimmer".into(),
57 language: "en".into(),
58 gender: Some("female".into()),
59 },
60 ];
61 Self { config, client: reqwest::Client::new(), model: "tts-1".into(), voices }
62 }
63
64 pub fn hd(mut self) -> Self {
66 self.model = "tts-1-hd".into();
67 self
68 }
69
70 fn base_url(&self) -> &str {
71 self.config.base_url.as_deref().unwrap_or("https://api.openai.com")
72 }
73}
74
75#[async_trait]
76impl TtsProvider for OpenAiTts {
77 async fn synthesize(&self, request: &TtsRequest) -> AudioResult<AudioFrame> {
78 let voice = if request.voice.is_empty() { "alloy" } else { &request.voice };
79 let url = format!("{}/v1/audio/speech", self.base_url());
80
81 let body = serde_json::json!({
82 "model": self.model,
83 "input": request.text,
84 "voice": voice,
85 "response_format": "pcm",
86 "speed": request.speed,
87 });
88
89 let resp = self
90 .client
91 .post(&url)
92 .bearer_auth(&self.config.api_key)
93 .json(&body)
94 .send()
95 .await
96 .map_err(|e| AudioError::Tts { provider: "openai".into(), message: e.to_string() })?;
97
98 if !resp.status().is_success() {
99 return Err(AudioError::Tts {
100 provider: "openai".into(),
101 message: format!("HTTP {}", resp.status()),
102 });
103 }
104
105 let pcm = resp
106 .bytes()
107 .await
108 .map_err(|e| AudioError::Tts { provider: "openai".into(), message: e.to_string() })?;
109
110 Ok(AudioFrame::new(pcm, 24000, 1))
111 }
112
113 async fn synthesize_stream(
114 &self,
115 request: &TtsRequest,
116 ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<AudioFrame>> + Send>>> {
117 let frame = self.synthesize(request).await?;
120 Ok(Box::pin(futures::stream::once(async { Ok(frame) })))
121 }
122
123 fn voice_catalog(&self) -> &[Voice] {
124 &self.voices
125 }
126}