use crate::audio::response::TranscriptionResponse;
use crate::common::auth::AuthProvider;
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use reqwest::multipart::{Form, Part};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::Duration;

const AUDIO_PATH: &str = "audio";

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum TtsModel {
    #[serde(rename = "tts-1")]
    #[default]
    Tts1,
    #[serde(rename = "tts-1-hd")]
    Tts1Hd,
    #[serde(rename = "gpt-4o-mini-tts")]
    Gpt4oMiniTts,
}

impl TtsModel {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Tts1 => "tts-1",
            Self::Tts1Hd => "tts-1-hd",
            Self::Gpt4oMiniTts => "gpt-4o-mini-tts",
        }
    }

    pub fn supports_instructions(&self) -> bool {
        matches!(self, Self::Gpt4oMiniTts)
    }
}

impl std::fmt::Display for TtsModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Voice {
    #[default]
    Alloy,
    Ash,
    Ballad,
    Cedar,
    Coral,
    Echo,
    Fable,
    Marin,
    Nova,
    Onyx,
    Sage,
    Shimmer,
    Verse,
}

impl Voice {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Alloy => "alloy",
            Self::Ash => "ash",
            Self::Ballad => "ballad",
            Self::Cedar => "cedar",
            Self::Coral => "coral",
            Self::Echo => "echo",
            Self::Fable => "fable",
            Self::Marin => "marin",
            Self::Nova => "nova",
            Self::Onyx => "onyx",
            Self::Sage => "sage",
            Self::Shimmer => "shimmer",
            Self::Verse => "verse",
        }
    }
}

impl std::fmt::Display for Voice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
    #[default]
    Mp3,
    Opus,
    Aac,
    Flac,
    Wav,
    Pcm,
}

impl AudioFormat {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Mp3 => "mp3",
            Self::Opus => "opus",
            Self::Aac => "aac",
            Self::Flac => "flac",
            Self::Wav => "wav",
            Self::Pcm => "pcm",
        }
    }

    pub fn file_extension(&self) -> &'static str {
        self.as_str()
    }
}

impl std::fmt::Display for AudioFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum SttModel {
    #[serde(rename = "whisper-1")]
    #[default]
    Whisper1,
    #[serde(rename = "gpt-4o-transcribe")]
    Gpt4oTranscribe,
}

impl SttModel {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Whisper1 => "whisper-1",
            Self::Gpt4oTranscribe => "gpt-4o-transcribe",
        }
    }
}

impl std::fmt::Display for SttModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptionFormat {
    #[default]
    Json,
    Text,
    Srt,
    VerboseJson,
    Vtt,
}

impl TranscriptionFormat {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Json => "json",
            Self::Text => "text",
            Self::Srt => "srt",
            Self::VerboseJson => "verbose_json",
            Self::Vtt => "vtt",
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
    Word,
    Segment,
}

impl TimestampGranularity {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Word => "word",
            Self::Segment => "segment",
        }
    }
}

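/// Options for text-to-speech synthesis.
///
/// A minimal construction sketch (the crate/module path in the `use` line is an
/// assumption; adjust it to wherever this module is re-exported):
///
/// ```ignore
/// use openai_tools::audio::{TtsModel, TtsOptions, Voice};
///
/// let options = TtsOptions {
///     model: TtsModel::Gpt4oMiniTts,
///     voice: Voice::Coral,
///     instructions: Some("Speak slowly and calmly.".to_string()),
///     ..Default::default()
/// };
/// ```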
#[derive(Debug, Clone, Default)]
pub struct TtsOptions {
    pub model: TtsModel,
    pub voice: Voice,
    pub response_format: AudioFormat,
    pub speed: Option<f32>,
    pub instructions: Option<String>,
}

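/// Options for audio transcription.
///
/// A minimal sketch of requesting word-level timestamps (the OpenAI API expects
/// the `verbose_json` response format when timestamp granularities are set):
///
/// ```ignore
/// let options = TranscribeOptions {
///     response_format: Some(TranscriptionFormat::VerboseJson),
///     timestamp_granularities: Some(vec![TimestampGranularity::Word]),
///     ..Default::default()
/// };
/// ```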
#[derive(Debug, Clone, Default)]
pub struct TranscribeOptions {
    pub model: Option<SttModel>,
    pub language: Option<String>,
    pub prompt: Option<String>,
    pub response_format: Option<TranscriptionFormat>,
    pub temperature: Option<f32>,
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}

#[derive(Debug, Clone, Default)]
pub struct TranslateOptions {
    pub model: Option<SttModel>,
    pub prompt: Option<String>,
    pub response_format: Option<TranscriptionFormat>,
    pub temperature: Option<f32>,
}

#[derive(Debug, Clone, Serialize)]
struct TtsRequest {
    model: String,
    input: String,
    voice: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    speed: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
}

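/// Client for the OpenAI audio endpoints (speech, transcriptions, translations).
///
/// A minimal usage sketch, run inside an async context and assuming
/// `OPENAI_API_KEY` is set in the environment (the crate path in the `use`
/// line is an assumption):
///
/// ```ignore
/// use openai_tools::audio::{Audio, TtsOptions, Voice};
///
/// let audio = Audio::new()?;
/// let mp3_bytes = audio
///     .text_to_speech("Hello!", TtsOptions { voice: Voice::Coral, ..Default::default() })
///     .await?;
/// std::fs::write("hello.mp3", &mp3_bytes)?;
/// ```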
pub struct Audio {
    auth: AuthProvider,
    timeout: Option<Duration>,
}

impl Audio {
    pub fn new() -> Result<Self> {
        let auth = AuthProvider::openai_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    pub fn with_auth(auth: AuthProvider) -> Self {
        Self { auth, timeout: None }
    }

    pub fn azure() -> Result<Self> {
        let auth = AuthProvider::azure_from_env()?;
        Ok(Self { auth, timeout: None })
    }

    pub fn detect_provider() -> Result<Self> {
        let auth = AuthProvider::from_env()?;
        Ok(Self { auth, timeout: None })
    }

    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
        let auth = AuthProvider::from_url_with_key(base_url, api_key);
        Self { auth, timeout: None }
    }

    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
        let auth = AuthProvider::from_url(url)?;
        Ok(Self { auth, timeout: None })
    }

    pub fn auth(&self) -> &AuthProvider {
        &self.auth
    }

    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    fn create_client(&self) -> Result<(reqwest::Client, reqwest::header::HeaderMap)> {
        let client = create_http_client(self.timeout)?;
        let mut headers = reqwest::header::HeaderMap::new();
        self.auth.apply_headers(&mut headers)?;
        headers.insert("User-Agent", reqwest::header::HeaderValue::from_static("openai-tools-rust"));
        Ok((client, headers))
    }

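    /// Synthesizes speech from `text` and returns the raw audio bytes in the
    /// requested output format.
    ///
    /// A minimal sketch (run inside an async context, with an `Audio` client in scope):
    ///
    /// ```ignore
    /// let speech = audio.text_to_speech("Good morning!", TtsOptions::default()).await?;
    /// tokio::fs::write("greeting.mp3", &speech).await?;
    /// ```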
    pub async fn text_to_speech(&self, text: &str, options: TtsOptions) -> Result<Vec<u8>> {
        let (client, mut headers) = self.create_client()?;
        headers.insert("Content-Type", reqwest::header::HeaderValue::from_static("application/json"));

        let instructions = if options.instructions.is_some() {
            if options.model.supports_instructions() {
                options.instructions
            } else {
                tracing::warn!("Model '{}' does not support the instructions parameter. Ignoring instructions.", options.model);
                None
            }
        } else {
            None
        };

        let request_body = TtsRequest {
            model: options.model.as_str().to_string(),
            input: text.to_string(),
            voice: options.voice.as_str().to_string(),
            response_format: Some(options.response_format.as_str().to_string()),
            speed: options.speed,
            instructions,
        };

        let body = serde_json::to_string(&request_body).map_err(OpenAIToolError::SerdeJsonError)?;

        let url = format!("{}/speech", self.auth.endpoint(AUDIO_PATH));

        let response = client.post(&url).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;

        // Surface API errors instead of returning the error payload as audio bytes.
        let status = response.status();
        if !status.is_success() {
            let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        let bytes = response.bytes().await.map_err(OpenAIToolError::RequestError)?;

        Ok(bytes.to_vec())
    }

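    /// Transcribes the audio file at `audio_path` in its original language.
    ///
    /// A minimal sketch (run inside an async context; the response shape is
    /// defined by `crate::audio::response::TranscriptionResponse`):
    ///
    /// ```ignore
    /// let transcript = audio.transcribe("meeting.mp3", TranscribeOptions::default()).await?;
    /// ```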
    pub async fn transcribe(&self, audio_path: &str, options: TranscribeOptions) -> Result<TranscriptionResponse> {
        let audio_content = tokio::fs::read(audio_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read audio file: {}", e)))?;

        let filename = Path::new(audio_path).file_name().and_then(|n| n.to_str()).unwrap_or("audio.mp3").to_string();

        self.transcribe_bytes(&audio_content, &filename, options).await
    }

    pub async fn transcribe_bytes(&self, audio_data: &[u8], filename: &str, options: TranscribeOptions) -> Result<TranscriptionResponse> {
        let (client, headers) = self.create_client()?;

        let audio_part = Part::bytes(audio_data.to_vec())
            .file_name(filename.to_string())
            .mime_str("audio/mpeg")
            .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;

        let mut form = Form::new().part("file", audio_part);

        let model = options.model.unwrap_or_default();
        form = form.text("model", model.as_str().to_string());

        if let Some(language) = options.language {
            form = form.text("language", language);
        }
        if let Some(prompt) = options.prompt {
            form = form.text("prompt", prompt);
        }
        if let Some(response_format) = options.response_format {
            form = form.text("response_format", response_format.as_str().to_string());
        }
        if let Some(temperature) = options.temperature {
            form = form.text("temperature", temperature.to_string());
        }
        if let Some(granularities) = options.timestamp_granularities {
            for g in granularities {
                form = form.text("timestamp_granularities[]", g.as_str().to_string());
            }
        }

        let url = format!("{}/transcriptions", self.auth.endpoint(AUDIO_PATH));

        let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;

        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        serde_json::from_str::<TranscriptionResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }

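    /// Translates the audio file at `audio_path` into English text.
    ///
    /// A minimal sketch (run inside an async context):
    ///
    /// ```ignore
    /// let english = audio.translate("interview_fr.mp3", TranslateOptions::default()).await?;
    /// ```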
    pub async fn translate(&self, audio_path: &str, options: TranslateOptions) -> Result<TranscriptionResponse> {
        let audio_content = tokio::fs::read(audio_path).await.map_err(|e| OpenAIToolError::Error(format!("Failed to read audio file: {}", e)))?;

        let filename = Path::new(audio_path).file_name().and_then(|n| n.to_str()).unwrap_or("audio.mp3").to_string();

        self.translate_bytes(&audio_content, &filename, options).await
    }

    pub async fn translate_bytes(&self, audio_data: &[u8], filename: &str, options: TranslateOptions) -> Result<TranscriptionResponse> {
        let (client, headers) = self.create_client()?;

        let audio_part = Part::bytes(audio_data.to_vec())
            .file_name(filename.to_string())
            .mime_str("audio/mpeg")
            .map_err(|e| OpenAIToolError::Error(format!("Failed to set MIME type: {}", e)))?;

        let mut form = Form::new().part("file", audio_part);

        let model = options.model.unwrap_or(SttModel::Whisper1);
        form = form.text("model", model.as_str().to_string());

        if let Some(prompt) = options.prompt {
            form = form.text("prompt", prompt);
        }
        if let Some(response_format) = options.response_format {
            form = form.text("response_format", response_format.as_str().to_string());
        }
        if let Some(temperature) = options.temperature {
            form = form.text("temperature", temperature.to_string());
        }

        let url = format!("{}/translations", self.auth.endpoint(AUDIO_PATH));

        let response = client.post(&url).headers(headers).multipart(form).send().await.map_err(OpenAIToolError::RequestError)?;

        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;

        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }

        serde_json::from_str::<TranscriptionResponse>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tts_model_as_str() {
        assert_eq!(TtsModel::Tts1.as_str(), "tts-1");
        assert_eq!(TtsModel::Tts1Hd.as_str(), "tts-1-hd");
        assert_eq!(TtsModel::Gpt4oMiniTts.as_str(), "gpt-4o-mini-tts");
    }

    #[test]
    fn test_tts_model_supports_instructions() {
        assert!(TtsModel::Gpt4oMiniTts.supports_instructions());
        assert!(!TtsModel::Tts1.supports_instructions());
        assert!(!TtsModel::Tts1Hd.supports_instructions());
    }

    #[test]
    fn test_tts_model_default() {
        let model = TtsModel::default();
        assert_eq!(model, TtsModel::Tts1);
    }

    #[test]
    fn test_tts_model_display() {
        assert_eq!(format!("{}", TtsModel::Gpt4oMiniTts), "gpt-4o-mini-tts");
    }

    #[test]
    fn test_voice_as_str_all_voices() {
        assert_eq!(Voice::Alloy.as_str(), "alloy");
        assert_eq!(Voice::Ash.as_str(), "ash");
        assert_eq!(Voice::Ballad.as_str(), "ballad");
        assert_eq!(Voice::Cedar.as_str(), "cedar");
        assert_eq!(Voice::Coral.as_str(), "coral");
        assert_eq!(Voice::Echo.as_str(), "echo");
        assert_eq!(Voice::Fable.as_str(), "fable");
        assert_eq!(Voice::Marin.as_str(), "marin");
        assert_eq!(Voice::Nova.as_str(), "nova");
        assert_eq!(Voice::Onyx.as_str(), "onyx");
        assert_eq!(Voice::Sage.as_str(), "sage");
        assert_eq!(Voice::Shimmer.as_str(), "shimmer");
        assert_eq!(Voice::Verse.as_str(), "verse");
    }

    #[test]
    fn test_voice_new_voices() {
        assert_eq!(Voice::Ballad.as_str(), "ballad");
        assert_eq!(Voice::Cedar.as_str(), "cedar");
        assert_eq!(Voice::Marin.as_str(), "marin");
        assert_eq!(Voice::Verse.as_str(), "verse");
    }

    #[test]
    fn test_voice_default() {
        let voice = Voice::default();
        assert_eq!(voice, Voice::Alloy);
    }

    #[test]
    fn test_voice_serialization() {
        let voice = Voice::Coral;
        let json = serde_json::to_string(&voice).unwrap();
        assert_eq!(json, "\"coral\"");

        let ballad = Voice::Ballad;
        let json = serde_json::to_string(&ballad).unwrap();
        assert_eq!(json, "\"ballad\"");
    }

    #[test]
    fn test_voice_deserialization() {
        let voice: Voice = serde_json::from_str("\"coral\"").unwrap();
        assert_eq!(voice, Voice::Coral);

        let cedar: Voice = serde_json::from_str("\"cedar\"").unwrap();
        assert_eq!(cedar, Voice::Cedar);

        let marin: Voice = serde_json::from_str("\"marin\"").unwrap();
        assert_eq!(marin, Voice::Marin);
    }

    #[test]
    fn test_tts_options_default() {
        let options = TtsOptions::default();
        assert_eq!(options.model, TtsModel::Tts1);
        assert_eq!(options.voice, Voice::Alloy);
        assert_eq!(options.response_format, AudioFormat::Mp3);
        assert!(options.speed.is_none());
        assert!(options.instructions.is_none());
    }

    #[test]
    fn test_tts_options_with_instructions() {
        let options = TtsOptions {
            model: TtsModel::Gpt4oMiniTts,
            voice: Voice::Coral,
            instructions: Some("Speak in a cheerful tone.".to_string()),
            ..Default::default()
        };
        assert_eq!(options.model, TtsModel::Gpt4oMiniTts);
        assert_eq!(options.instructions, Some("Speak in a cheerful tone.".to_string()));
    }

    #[test]
    fn test_tts_request_serialization_with_instructions() {
        let request = TtsRequest {
            model: "gpt-4o-mini-tts".to_string(),
            input: "Hello, world!".to_string(),
            voice: "coral".to_string(),
            response_format: Some("mp3".to_string()),
            speed: None,
            instructions: Some("Speak cheerfully.".to_string()),
        };
        let json = serde_json::to_value(&request).unwrap();

        assert_eq!(json["model"], "gpt-4o-mini-tts");
        assert_eq!(json["input"], "Hello, world!");
        assert_eq!(json["voice"], "coral");
        assert_eq!(json["response_format"], "mp3");
        assert_eq!(json["instructions"], "Speak cheerfully.");
        assert!(json.get("speed").is_none());
    }

    #[test]
    fn test_tts_request_serialization_without_instructions() {
        let request = TtsRequest {
            model: "tts-1".to_string(),
            input: "Hello".to_string(),
            voice: "alloy".to_string(),
            response_format: Some("mp3".to_string()),
            speed: Some(1.0),
            instructions: None,
        };
        let json = serde_json::to_value(&request).unwrap();

        assert_eq!(json["model"], "tts-1");
        assert_eq!(json["speed"], 1.0);
        assert!(json.get("instructions").is_none());
    }

    #[test]
    fn test_tts_request_skip_serializing_none_fields() {
        let request = TtsRequest {
            model: "tts-1".to_string(),
            input: "Test".to_string(),
            voice: "echo".to_string(),
            response_format: None,
            speed: None,
            instructions: None,
        };
        let json = serde_json::to_value(&request).unwrap();

        assert!(json.get("model").is_some());
        assert!(json.get("input").is_some());
        assert!(json.get("voice").is_some());

        assert!(json.get("response_format").is_none());
        assert!(json.get("speed").is_none());
        assert!(json.get("instructions").is_none());
    }

    #[test]
    fn test_audio_format_as_str() {
        assert_eq!(AudioFormat::Mp3.as_str(), "mp3");
        assert_eq!(AudioFormat::Opus.as_str(), "opus");
        assert_eq!(AudioFormat::Aac.as_str(), "aac");
        assert_eq!(AudioFormat::Flac.as_str(), "flac");
        assert_eq!(AudioFormat::Wav.as_str(), "wav");
        assert_eq!(AudioFormat::Pcm.as_str(), "pcm");
    }

    #[test]
    fn test_audio_format_file_extension() {
        assert_eq!(AudioFormat::Mp3.file_extension(), "mp3");
        assert_eq!(AudioFormat::Wav.file_extension(), "wav");
    }

    #[test]
    fn test_stt_model_as_str() {
        assert_eq!(SttModel::Whisper1.as_str(), "whisper-1");
        assert_eq!(SttModel::Gpt4oTranscribe.as_str(), "gpt-4o-transcribe");
    }

    #[test]
    fn test_transcription_format_as_str() {
        assert_eq!(TranscriptionFormat::Json.as_str(), "json");
        assert_eq!(TranscriptionFormat::Text.as_str(), "text");
        assert_eq!(TranscriptionFormat::Srt.as_str(), "srt");
        assert_eq!(TranscriptionFormat::VerboseJson.as_str(), "verbose_json");
        assert_eq!(TranscriptionFormat::Vtt.as_str(), "vtt");
    }

    #[test]
    fn test_timestamp_granularity_as_str() {
        assert_eq!(TimestampGranularity::Word.as_str(), "word");
        assert_eq!(TimestampGranularity::Segment.as_str(), "segment");
    }
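
    // Sanity check: the serde representation (via `rename_all = "lowercase"`) should
    // agree with `as_str()` for every output audio format.
    #[test]
    fn test_audio_format_serialization_matches_as_str() {
        for format in [AudioFormat::Mp3, AudioFormat::Opus, AudioFormat::Aac, AudioFormat::Flac, AudioFormat::Wav, AudioFormat::Pcm] {
            let json = serde_json::to_string(&format).unwrap();
            assert_eq!(json, format!("\"{}\"", format.as_str()));
        }
    }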
}