openai_agents_rust/voice/
mod.rs

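//! Voice support: speech-to-text (STT) and text-to-speech (TTS) traits, HTTP
//! implementations targeting OpenAI-compatible `/audio/transcriptions` and
//! `/audio/speech` endpoints, and an `HttpVoicePipeline` that composes the two.
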
use async_trait::async_trait;
use reqwest::Client;
use reqwest::multipart::{Form, Part};
use serde::Deserialize;

use crate::config::Config;
use crate::error::AgentError;

/// Trait representing a generic voice pipeline.
#[async_trait]
pub trait VoicePipeline: Send + Sync {
    /// Process an input audio buffer and return a textual transcription.
    async fn transcribe(&self, audio: &[u8]) -> Result<String, AgentError>;

    /// Convert text to synthesized audio bytes.
    async fn synthesize(&self, text: &str) -> Result<Vec<u8>, AgentError>;
}

/// Speech-to-text trait.
#[async_trait]
pub trait Stt: Send + Sync {
    async fn stt(&self, audio: &[u8]) -> Result<String, AgentError>;
}

/// Text-to-speech trait.
#[async_trait]
pub trait Tts: Send + Sync {
    async fn tts(&self, text: &str) -> Result<Vec<u8>, AgentError>;
}

/// OpenAI-compatible STT implementation (POST /v1/audio/transcriptions).
pub struct OpenAiStt {
    client: Client,
    base_url: String,
    model: String,
    auth_token: Option<String>,
}

impl OpenAiStt {
    pub fn new(config: Config) -> Self {
        let client = Client::builder()
            .user_agent("openai-agents-rust")
            .build()
            .expect("Failed to build reqwest client");
        let auth_token = if config.api_key.is_empty() {
            None
        } else {
            Some(config.api_key.clone())
        };
        Self {
            client,
            base_url: config.base_url.clone(),
            model: config.model.clone(),
            auth_token,
        }
    }

    /// Transcription endpoint URL; `base_url` is expected to already include
    /// any API version segment (for OpenAI itself, e.g. `https://api.openai.com/v1`).
    fn url(&self) -> String {
        format!(
            "{}/audio/transcriptions",
            self.base_url.trim_end_matches('/')
        )
    }
}

#[derive(Deserialize)]
struct SttResponse {
    text: String,
}

#[async_trait]
impl Stt for OpenAiStt {
    async fn stt(&self, audio: &[u8]) -> Result<String, AgentError> {
        let part = Part::bytes(audio.to_vec())
            .file_name("audio.wav")
            .mime_str("audio/wav")
            .map_err(|e| AgentError::Other(format!("invalid audio mime: {}", e)))?;
        let form = Form::new()
            .text("model", self.model.clone())
            .part("file", part);
        let mut req = self.client.post(self.url());
        if let Some(token) = &self.auth_token {
            req = req.bearer_auth(token);
        }
        let resp = req.multipart(form).send().await.map_err(AgentError::from)?;
        let status = resp.status();
        let body = resp.text().await.map_err(AgentError::from)?;
        if !status.is_success() {
            return Err(AgentError::Other(format!(
                "stt failed (status: {}): {}",
                status, body
            )));
        }
        let parsed: SttResponse = serde_json::from_str(&body)
            .map_err(|e| AgentError::Other(format!("stt parse error: {} body={}", e, body)))?;
        Ok(parsed.text)
    }
}

/// OpenAI-compatible TTS implementation (POST /v1/audio/speech).
pub struct OpenAiTts {
    client: Client,
    base_url: String,
    model: String,
    voice: Option<String>,
    format: Option<String>,
    auth_token: Option<String>,
}

impl OpenAiTts {
    pub fn new(config: Config) -> Self {
        let client = Client::builder()
            .user_agent("openai-agents-rust")
            .build()
            .expect("Failed to build reqwest client");
        let auth_token = if config.api_key.is_empty() {
            None
        } else {
            Some(config.api_key.clone())
        };
        Self {
            client,
            base_url: config.base_url.clone(),
            model: config.model.clone(),
            voice: Some("alloy".into()),
            format: Some("wav".into()),
            auth_token,
        }
    }

    pub fn with_voice(mut self, voice: impl Into<String>) -> Self {
        self.voice = Some(voice.into());
        self
    }

    pub fn with_format(mut self, fmt: impl Into<String>) -> Self {
        self.format = Some(fmt.into());
        self
    }

    fn url(&self) -> String {
        format!("{}/audio/speech", self.base_url.trim_end_matches('/'))
    }
}
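
// Usage sketch (illustrative, not exercised by the tests below): the builders
// override the defaults set in `new`, e.g.
//
//     let tts = OpenAiTts::new(cfg).with_voice("nova").with_format("mp3");
//
// where `cfg` is a `crate::config::Config` and "nova"/"mp3" are example values
// accepted by OpenAI's speech endpoint.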

#[async_trait]
impl Tts for OpenAiTts {
    async fn tts(&self, text: &str) -> Result<Vec<u8>, AgentError> {
        let mut body = serde_json::json!({
            "model": self.model,
            "input": text,
        });
        if let Some(v) = &self.voice {
            body["voice"] = serde_json::json!(v);
        }
        if let Some(f) = &self.format {
            // The OpenAI speech endpoint expects the output format under
            // `response_format` (e.g. "wav", "mp3").
            body["response_format"] = serde_json::json!(f);
        }
        let mut req = self.client.post(self.url());
        if let Some(token) = &self.auth_token {
            req = req.bearer_auth(token);
        }
        let resp = req.json(&body).send().await.map_err(AgentError::from)?;
        let status = resp.status();
        let bytes = resp.bytes().await.map_err(AgentError::from)?;
        if !status.is_success() {
            let body = String::from_utf8_lossy(&bytes).to_string();
            return Err(AgentError::Other(format!(
                "tts failed (status: {}): {}",
                status, body
            )));
        }
        Ok(bytes.to_vec())
    }
}

/// Composed pipeline using STT and TTS.
pub struct HttpVoicePipeline {
    stt: Box<dyn Stt>,
    tts: Box<dyn Tts>,
}

impl HttpVoicePipeline {
    pub fn new(stt: Box<dyn Stt>, tts: Box<dyn Tts>) -> Self {
        Self { stt, tts }
    }
}

#[async_trait]
impl VoicePipeline for HttpVoicePipeline {
    async fn transcribe(&self, audio: &[u8]) -> Result<String, AgentError> {
        self.stt.stt(audio).await
    }

    async fn synthesize(&self, text: &str) -> Result<Vec<u8>, AgentError> {
        self.tts.tts(text).await
    }
}
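
/// Convenience sketch: build an `HttpVoicePipeline` from a single `Config`,
/// reusing it for both the STT and TTS clients. Illustrative only: it assumes
/// `Config` is `Clone` (as the tests below already rely on) and that one model
/// string is suitable for both endpoints; adjust per deployment.
pub fn pipeline_from_config(config: Config) -> HttpVoicePipeline {
    HttpVoicePipeline::new(
        Box::new(OpenAiStt::new(config.clone())),
        Box::new(OpenAiTts::new(config)),
    )
}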

#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use axum::response::IntoResponse;
    use axum::{Router, routing::post};

    #[tokio::test]
    async fn stt_tts_roundtrip_against_mock_server() {
        // Simple mock endpoints that accept any payload
        let app = Router::new()
            .route(
                "/audio/transcriptions",
                post(|| async move {
                    let body = serde_json::json!({"text":"hello world"});
                    (StatusCode::OK, axum::Json(body))
                }),
            )
            .route(
                "/audio/speech",
                post(|axum::Json(_): axum::Json<serde_json::Value>| async move {
                    let audio: Vec<u8> = vec![1, 2, 3, 4, 5];
                    (StatusCode::OK, audio).into_response()
                }),
            );
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app.into_make_service())
                .await
                .unwrap();
        });

        let _ = dotenvy::dotenv();
        let mut cfg = crate::config::load_from_env();
        cfg.api_key = String::new();
        cfg.model = if cfg.model.is_empty() { "whisper-1".into() } else { cfg.model };
        cfg.base_url = format!("http://{}:{}", addr.ip(), addr.port());
        let stt = OpenAiStt::new(cfg.clone());
        let tts = OpenAiTts::new(cfg.clone());
        let pipe = HttpVoicePipeline::new(Box::new(stt), Box::new(tts));

        let transcript = pipe.transcribe(b"ignored").await.unwrap();
        assert_eq!(transcript, "hello world");
        let audio = pipe.synthesize("Hi").await.unwrap();
        assert_eq!(audio, vec![1, 2, 3, 4, 5]);
    }
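
    // Additional illustrative checks (sketches under the same assumptions as
    // the test above: `crate::config::load_from_env()` is available and the
    // same mock-server pattern is reused).
    use async_trait::async_trait;

    // Minimal in-memory fakes showing how the `Stt`/`Tts` traits can be
    // implemented without an HTTP backend.
    struct FakeStt;
    struct FakeTts;

    #[async_trait]
    impl Stt for FakeStt {
        async fn stt(&self, _audio: &[u8]) -> Result<String, AgentError> {
            Ok("fake transcript".to_string())
        }
    }

    #[async_trait]
    impl Tts for FakeTts {
        async fn tts(&self, text: &str) -> Result<Vec<u8>, AgentError> {
            Ok(text.as_bytes().to_vec())
        }
    }

    #[tokio::test]
    async fn pipeline_delegates_to_its_parts() {
        let pipe = HttpVoicePipeline::new(Box::new(FakeStt), Box::new(FakeTts));
        assert_eq!(pipe.transcribe(b"pcm").await.unwrap(), "fake transcript");
        assert_eq!(pipe.synthesize("hi").await.unwrap(), b"hi".to_vec());
    }

    #[tokio::test]
    async fn stt_reports_non_success_status_as_error() {
        // Mock transcription endpoint that always fails.
        let app = Router::new().route(
            "/audio/transcriptions",
            post(|| async move { (StatusCode::INTERNAL_SERVER_ERROR, "boom") }),
        );
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app.into_make_service()).await.unwrap();
        });

        let _ = dotenvy::dotenv();
        let mut cfg = crate::config::load_from_env();
        cfg.api_key = String::new();
        cfg.base_url = format!("http://{}:{}", addr.ip(), addr.port());
        let stt = OpenAiStt::new(cfg);
        assert!(stt.stt(b"ignored").await.is_err());
    }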
}