langchain_rust/tools/text2speech/openai/
client.rs1use std::{error::Error, sync::Arc};
2
3use async_openai::types::CreateSpeechRequestArgs;
4use async_openai::Client;
5pub use async_openai::{
6 config::{Config, OpenAIConfig},
7 types::{SpeechModel, SpeechResponseFormat, Voice},
8};
9use async_trait::async_trait;
10use serde_json::Value;
11
12use crate::tools::{SpeechStorage, Tool};
13
14#[derive(Clone)]
15pub struct Text2SpeechOpenAI<C: Config> {
16 config: C,
17 model: SpeechModel,
18 voice: Voice,
19 storage: Option<Arc<dyn SpeechStorage>>,
20 response_format: SpeechResponseFormat,
21 path: String,
22}
23
24impl<C: Config> Text2SpeechOpenAI<C> {
25 pub fn new(config: C) -> Self {
26 Self {
27 config,
28 model: SpeechModel::Tts1,
29 voice: Voice::Alloy,
30 storage: None,
31 response_format: SpeechResponseFormat::Mp3,
32 path: "./data/audio.mp3".to_string(),
33 }
34 }
35
36 pub fn with_model(mut self, model: SpeechModel) -> Self {
37 self.model = model;
38 self
39 }
40
41 pub fn with_voice(mut self, voice: Voice) -> Self {
42 self.voice = voice;
43 self
44 }
45
46 pub fn with_storage<SS: SpeechStorage + 'static>(mut self, storage: SS) -> Self {
47 self.storage = Some(Arc::new(storage));
48 self
49 }
50
51 pub fn with_response_format(mut self, response_format: SpeechResponseFormat) -> Self {
52 self.response_format = response_format;
53 self
54 }
55
56 pub fn with_path<S: Into<String>>(mut self, path: S) -> Self {
57 self.path = path.into();
58 self
59 }
60
61 pub fn with_config(mut self, config: C) -> Self {
62 self.config = config;
63 self
64 }
65}
66
67impl Default for Text2SpeechOpenAI<OpenAIConfig> {
68 fn default() -> Self {
69 Self::new(OpenAIConfig::default())
70 }
71}
72
73#[async_trait]
74impl<C: Config + Send + Sync> Tool for Text2SpeechOpenAI<C> {
75 fn name(&self) -> String {
76 "Text2SpeechOpenAI".to_string()
77 }
78
79 fn description(&self) -> String {
80 r#"A wrapper around OpenAI Text2Speech. "
81 "Useful for when you need to convert text to speech. "
82 "It supports multiple languages, including English, German, Polish, "
83 "Spanish, Italian, French, Portuguese""#
84 .to_string()
85 }
86
87 async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
88 let input = input.as_str().ok_or("Invalid input")?;
89 let client = Client::new();
90 let response_format: SpeechResponseFormat = self.response_format;
91
92 let request = CreateSpeechRequestArgs::default()
93 .input(input)
94 .voice(self.voice.clone())
95 .response_format(response_format)
96 .model(self.model.clone())
97 .build()?;
98
99 let response = client.audio().speech(request).await?;
100
101 if self.storage.is_some() {
102 let storage = self.storage.as_ref().unwrap(); let data = response.bytes;
104 return storage.save(&self.path, &data).await;
105 } else {
106 response.save(&self.path).await?;
107 }
108
109 Ok(self.path.clone())
110 }
111}
112
#[cfg(test)]
mod tests {
    use crate::tools::{Text2SpeechOpenAI, Tool};

    /// Calls the default tool with a short Spanish phrase and prints the
    /// returned value. Marked `#[ignore]`, so it only runs when requested
    /// explicitly (presumably because it needs live API access — confirm).
    #[tokio::test]
    #[ignore]
    async fn openai_speech2text_tool() {
        let tool = Text2SpeechOpenAI::default();
        let saved = tool.call("Hola como estas").await.unwrap();
        println!("{}", saved);
    }
}