openai_agents_rust/voice/
mod.rs1use async_trait::async_trait;
2use reqwest::Client;
3use reqwest::multipart::{Form, Part};
4use serde::Deserialize;
5
6use crate::config::Config;
7use crate::error::AgentError;
8
9#[async_trait]
11pub trait VoicePipeline: Send + Sync {
12 async fn transcribe(&self, audio: &[u8]) -> Result<String, AgentError>;
14
15 async fn synthesize(&self, text: &str) -> Result<Vec<u8>, AgentError>;
17}
18
19#[async_trait]
21pub trait Stt: Send + Sync {
22 async fn stt(&self, audio: &[u8]) -> Result<String, AgentError>;
23}
24
25#[async_trait]
27pub trait Tts: Send + Sync {
28 async fn tts(&self, text: &str) -> Result<Vec<u8>, AgentError>;
29}
30
31pub struct OpenAiStt {
33 client: Client,
34 base_url: String,
35 model: String,
36 auth_token: Option<String>,
37}
38
39impl OpenAiStt {
40 pub fn new(config: Config) -> Self {
41 let client = Client::builder()
42 .user_agent("openai-agents-rust")
43 .build()
44 .expect("Failed to build reqwest client");
45 let auth_token = if config.api_key.is_empty() {
46 None
47 } else {
48 Some(config.api_key.clone())
49 };
50 Self {
51 client,
52 base_url: config.base_url.clone(),
53 model: config.model.clone(),
54 auth_token,
55 }
56 }
57
58 fn url(&self) -> String {
59 format!(
60 "{}/audio/transcriptions",
61 self.base_url.trim_end_matches('/')
62 )
63 }
64}
65
66#[derive(Deserialize)]
67struct SttResponse {
68 text: String,
69}
70
71#[async_trait]
72impl Stt for OpenAiStt {
73 async fn stt(&self, audio: &[u8]) -> Result<String, AgentError> {
74 let part = Part::bytes(audio.to_vec())
75 .file_name("audio.wav")
76 .mime_str("audio/wav")
77 .map_err(|e| AgentError::Other(format!("invalid audio mime: {}", e)))?;
78 let form = Form::new()
79 .text("model", self.model.clone())
80 .part("file", part);
81 let mut req = self.client.post(self.url());
82 if let Some(token) = &self.auth_token {
83 req = req.bearer_auth(token);
84 }
85 let resp = req.multipart(form).send().await.map_err(AgentError::from)?;
86 let status = resp.status();
87 let body = resp.text().await.map_err(AgentError::from)?;
88 if !status.is_success() {
89 return Err(AgentError::Other(format!(
90 "stt failed (status: {}): {}",
91 status, body
92 )));
93 }
94 let parsed: SttResponse = serde_json::from_str(&body)
95 .map_err(|e| AgentError::Other(format!("stt parse error: {} body={}", e, body)))?;
96 Ok(parsed.text)
97 }
98}
99
100pub struct OpenAiTts {
102 client: Client,
103 base_url: String,
104 model: String,
105 voice: Option<String>,
106 format: Option<String>,
107 auth_token: Option<String>,
108}
109
110impl OpenAiTts {
111 pub fn new(config: Config) -> Self {
112 let client = Client::builder()
113 .user_agent("openai-agents-rust")
114 .build()
115 .expect("Failed to build reqwest client");
116 let auth_token = if config.api_key.is_empty() {
117 None
118 } else {
119 Some(config.api_key.clone())
120 };
121 Self {
122 client,
123 base_url: config.base_url.clone(),
124 model: config.model.clone(),
125 voice: Some("alloy".into()),
126 format: Some("wav".into()),
127 auth_token,
128 }
129 }
130
131 pub fn with_voice(mut self, voice: impl Into<String>) -> Self {
132 self.voice = Some(voice.into());
133 self
134 }
135 pub fn with_format(mut self, fmt: impl Into<String>) -> Self {
136 self.format = Some(fmt.into());
137 self
138 }
139 fn url(&self) -> String {
140 format!("{}/audio/speech", self.base_url.trim_end_matches('/'))
141 }
142}
143
144#[async_trait]
145impl Tts for OpenAiTts {
146 async fn tts(&self, text: &str) -> Result<Vec<u8>, AgentError> {
147 let mut body = serde_json::json!({
148 "model": self.model,
149 "input": text,
150 });
151 if let Some(v) = &self.voice {
152 body["voice"] = serde_json::json!(v);
153 }
154 if let Some(f) = &self.format {
155 body["format"] = serde_json::json!(f);
156 }
157 let mut req = self.client.post(self.url());
158 if let Some(token) = &self.auth_token {
159 req = req.bearer_auth(token);
160 }
161 let resp = req.json(&body).send().await.map_err(AgentError::from)?;
162 let status = resp.status();
163 let bytes = resp.bytes().await.map_err(AgentError::from)?;
164 if !status.is_success() {
165 let body = String::from_utf8_lossy(&bytes).to_string();
166 return Err(AgentError::Other(format!(
167 "tts failed (status: {}): {}",
168 status, body
169 )));
170 }
171 Ok(bytes.to_vec())
172 }
173}
174
175pub struct HttpVoicePipeline {
177 stt: Box<dyn Stt>,
178 tts: Box<dyn Tts>,
179}
180
181impl HttpVoicePipeline {
182 pub fn new(stt: Box<dyn Stt>, tts: Box<dyn Tts>) -> Self {
183 Self { stt, tts }
184 }
185}
186
187#[async_trait]
188impl VoicePipeline for HttpVoicePipeline {
189 async fn transcribe(&self, audio: &[u8]) -> Result<String, AgentError> {
190 self.stt.stt(audio).await
191 }
192 async fn synthesize(&self, text: &str) -> Result<Vec<u8>, AgentError> {
193 self.tts.tts(text).await
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use axum::{Router, routing::post};

    /// Mock `/audio/transcriptions`: always replies with a fixed transcript.
    async fn fake_transcriptions() -> impl IntoResponse {
        (
            StatusCode::OK,
            axum::Json(serde_json::json!({"text":"hello world"})),
        )
    }

    /// Mock `/audio/speech`: requires a JSON body, ignores its contents, and
    /// replies with five fixed bytes.
    async fn fake_speech(axum::Json(_): axum::Json<serde_json::Value>) -> impl IntoResponse {
        let audio: Vec<u8> = vec![1, 2, 3, 4, 5];
        (StatusCode::OK, audio)
    }

    /// Exercises both halves of [`HttpVoicePipeline`] against an in-process
    /// axum server bound to an ephemeral localhost port.
    #[tokio::test]
    async fn stt_tts_roundtrip_against_mock_server() {
        let router = Router::new()
            .route("/audio/transcriptions", post(fake_transcriptions))
            .route("/audio/speech", post(fake_speech));
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, router.into_make_service())
                .await
                .unwrap();
        });

        // Start from the environment-derived config, then point it at the mock
        // and drop the API key so nothing reaches a real service.
        let _ = dotenvy::dotenv();
        let mut cfg = crate::config::load_from_env();
        cfg.api_key = String::new();
        if cfg.model.is_empty() {
            cfg.model = "whisper-1".into();
        }
        cfg.base_url = format!("http://{}:{}", local.ip(), local.port());

        let pipeline = HttpVoicePipeline::new(
            Box::new(OpenAiStt::new(cfg.clone())),
            Box::new(OpenAiTts::new(cfg.clone())),
        );

        assert_eq!(pipeline.transcribe(b"ignored").await.unwrap(), "hello world");
        assert_eq!(pipeline.synthesize("Hi").await.unwrap(), vec![1, 2, 3, 4, 5]);
    }
}