aristech_tts_client/lib.rs
1//! # Aristech TTS-Client
2//! The Aristech TTS-Client is a client library for the Aristech TTS-Server.
3
4#![warn(missing_docs)]
5
6/// The tts_services module contains types and functions generated from the Aristech TTS proto files.
7pub mod tts_services {
8 #![allow(missing_docs)]
9 tonic::include_proto!("aristech.tts");
10}
11
12use tts_services::speech_service_client::SpeechServiceClient as TtsServiceClient;
13use tts_services::{
14 PhonesetRequest, PhonesetResponse, SpeechRequest, TranscriptionRequest, TranscriptionResponse,
15 Voice, VoiceListRequest,
16};
17
18use std::error::Error;
19
20use tonic::codegen::InterceptedService;
21use tonic::service::Interceptor;
22use tonic::transport::{Certificate, Channel, ClientTlsConfig};
23
/// Credentials (token and secret) used to authenticate with the server.
#[derive(Clone)]
pub struct Auth {
    /// Token half of the credentials.
    pub token: String,
    /// Secret half of the credentials.
    pub secret: String,
}

impl Auth {
    /// Builds an [`Auth`] that owns copies of the given token and secret.
    ///
    /// # Arguments
    /// * `token` - The token to authenticate with the server.
    /// * `secret` - The secret to authenticate with the server.
    pub fn new(token: &str, secret: &str) -> Self {
        let token = String::from(token);
        let secret = String::from(secret);
        Auth { token, secret }
    }
}
42
43/// The AuthInterceptor struct is used to intercept requests to the server and add the authentication headers.
44pub struct AuthInterceptor {
45 /// The authentication data to add to the headers
46 auth: Option<Auth>,
47}
48
49impl AuthInterceptor {
50 /// Creates a new AuthInterceptor with the given authentication data.
51 fn new(auth: Option<Auth>) -> Self {
52 Self { auth }
53 }
54}
55
56impl Interceptor for AuthInterceptor {
57 /// Adds the authentication data to the headers of the request.
58 fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
59 if let Some(auth) = &self.auth {
60 let mut request = request;
61 request
62 .metadata_mut()
63 .insert("token", auth.token.parse().unwrap());
64 request
65 .metadata_mut()
66 .insert("secret", auth.secret.parse().unwrap());
67 Ok(request)
68 } else {
69 Ok(request)
70 }
71 }
72}
73
74/// The TtsClient type is a type alias for the TtsServiceClient with the AuthInterceptor.
75pub type TtsClient = TtsServiceClient<InterceptedService<Channel, AuthInterceptor>>;
76
77/// The TlsOptions struct holds the tls options needed to communicate with the server.
78#[derive(Clone, Default)]
79pub struct TlsOptions {
80 /// The authentication data to authenticate with the server
81 pub auth: Option<Auth>,
82 /// The root certificate to verify the server's certificate
83 /// This is usually only needed when the server uses a self-signed certificate
84 pub ca_certificate: Option<String>,
85}
86
87impl TlsOptions {
88 /// Creates a new TlsOptions struct with the given authentication data and root certificate.
89 pub fn new(auth: Option<Auth>, ca_certificate: Option<String>) -> Self {
90 Self {
91 auth,
92 ca_certificate,
93 }
94 }
95}
96
97/// Creates a new [TtsClient] to communicate with the server.
98///
99/// # Arguments
100/// * `host` - The host to connect to (might include the port number e.g. "https://tts.example.com:8424"). Note that the protocol must be included in the host.
101/// * `tls_options` - The tls options to use when connecting to the server. If None is given, the connection will be unencrypted and unauthenticated (the server must also be configured to communicate without encryption in this case).
102pub async fn get_client(
103 host: String,
104 tls_options: Option<TlsOptions>,
105) -> Result<TtsClient, Box<dyn Error>> {
106 // Check if a schema is included in the host
107 // otherwise add http if no tls options are given and https otherwise
108 let host = if host.starts_with("http://") || host.starts_with("https://") {
109 host
110 } else {
111 match tls_options {
112 Some(_) => format!("https://{}", host),
113 None => format!("http://{}", host),
114 }
115 };
116 match tls_options {
117 Some(tls_options) => {
118 let tls = match tls_options.ca_certificate {
119 Some(ca_certificate) => {
120 let ca_certificate = Certificate::from_pem(ca_certificate);
121 ClientTlsConfig::new().ca_certificate(ca_certificate)
122 }
123 None => ClientTlsConfig::with_native_roots(ClientTlsConfig::new()),
124 };
125 let channel = Channel::from_shared(host)?
126 .tls_config(tls)?
127 .connect()
128 .await?;
129 let client: TtsServiceClient<InterceptedService<Channel, AuthInterceptor>> =
130 TtsServiceClient::with_interceptor(channel, AuthInterceptor::new(tls_options.auth));
131 Ok(client)
132 }
133 None => {
134 let channel = Channel::from_shared(host)?.connect().await?;
135 let client: TtsServiceClient<InterceptedService<Channel, AuthInterceptor>> =
136 TtsServiceClient::with_interceptor(channel, AuthInterceptor::new(None));
137 Ok(client)
138 }
139 }
140}
141
142/// Gets the list of available voices from the server.
143///
144/// # Arguments
145/// * `client` - The client to use to communicate with the server.
146/// * `request` - The request to send to the server. If None is given, the default request will be used.
147///
148/// # Example
149///
150/// ```no_run
151/// use aristech_tts_client::{get_client, TlsOptions, get_voices};
152/// use std::error::Error;
153///
154/// #[tokio::main]
155/// async fn main() -> Result<(), Box<dyn Error>> {
156/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
157/// let voices = get_voices(&mut client, None).await?;
158/// for voice in voices {
159/// println!("Voice: {:?}", voice);
160/// }
161/// Ok(())
162/// }
163/// ```
164pub async fn get_voices(
165 client: &mut TtsClient,
166 request: Option<VoiceListRequest>,
167) -> Result<Vec<Voice>, Box<dyn Error>> {
168 let req = request.unwrap_or(VoiceListRequest::default());
169 let request = tonic::Request::new(req);
170 let response = client.get_voice_list(request).await?;
171 let mut stream = response.into_inner();
172 let mut voices = vec![];
173 while let Ok(Some(voice)) = stream.message().await {
174 voices.push(voice);
175 }
176 Ok(voices)
177}
178
179/// Gets the phoneset for the given voice from the server.
180///
181/// # Arguments
182/// * `client` - The client to use to communicate with the server.
183/// * `request` - The request to send to the server.
184///
185/// # Example
186///
187/// ```no_run
188/// use aristech_tts_client::{
189/// get_client, TlsOptions, get_phoneset,
190/// tts_services::{PhonesetRequest, Voice},
191/// };
192/// use std::error::Error;
193///
194/// #[tokio::main]
195/// async fn main() -> Result<(), Box<dyn Error>> {
196/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
197/// let response = get_phoneset(&mut client, PhonesetRequest {
198/// voice: Some(Voice {
199/// voice_id: "anne_en_GB".to_string(),
200/// ..Voice::default()
201/// }),
202/// ..PhonesetRequest::default()
203/// }).await?;
204/// println!("{:?}", response);
205/// Ok(())
206/// }
207pub async fn get_phoneset(
208 client: &mut TtsClient,
209 request: PhonesetRequest,
210) -> Result<PhonesetResponse, Box<dyn Error>> {
211 let request = tonic::Request::new(request);
212 let response = client.get_phoneset(request).await?;
213 Ok(response.into_inner())
214}
215
216/// Gets the transcription for a word for a specific voice from the server.
217///
218/// # Arguments
219/// * `client` - The client to use to communicate with the server.
220/// * `request` - The request to send to the server.
221///
222/// # Example
223///
224/// ```no_run
225/// use aristech_tts_client::{
226/// get_client, TlsOptions, get_transcription,
227/// tts_services::{TranscriptionRequest, Voice},
228/// };
229/// use std::error::Error;
230///
231/// #[tokio::main]
232/// async fn main() -> Result<(), Box<dyn Error>> {
233/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
234/// let response = get_transcription(&mut client, TranscriptionRequest {
235/// voice: Some(Voice {
236/// voice_id: "anne_en_GB".to_string(),
237/// ..Voice::default()
238/// }),
239/// word: "hello".to_string(),
240/// ..TranscriptionRequest::default()
241/// }).await?;
242/// println!("{:?}", response);
243/// Ok(())
244/// }
245/// ```
246///
247pub async fn get_transcription(
248 client: &mut TtsClient,
249 request: TranscriptionRequest,
250) -> Result<TranscriptionResponse, Box<dyn Error>> {
251 let request = tonic::Request::new(request);
252 let response = client.get_transcription(request).await?;
253 Ok(response.into_inner())
254}
255
256/// Synthesizes the given text with the given options and returns the audio data.
257/// Currently only raw and wav audio formats are supported.
258/// If the audio format is set to wav (default), the audio data will be prepended with a wave header.
259///
260/// # Arguments
261/// * `client` - The client to use to communicate with the server.
262/// * `request` - The request to send to the server.
263///
264/// # Example
265///
266/// ```no_run
267/// use aristech_tts_client::{
268/// get_client, TlsOptions, synthesize,
269/// tts_services::{SpeechRequest, SpeechRequestOption},
270/// };
271/// use std::error::Error;
272///
273/// #[tokio::main]
274/// async fn main() -> Result<(), Box<dyn Error>> {
275/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
276/// let request = SpeechRequest {
277/// text: "Hello world.".to_string(),
278/// options: Some(SpeechRequestOption {
279/// voice_id: "anne_en_GB".to_string(),
280/// ..SpeechRequestOption::default()
281/// }),
282/// ..SpeechRequest::default()
283/// };
284/// let data = synthesize(&mut client, request).await?;
285/// std::fs::write("output.wav", data)?;
286/// Ok(())
287/// }
288pub async fn synthesize(
289 client: &mut TtsClient,
290 request: SpeechRequest,
291) -> Result<Vec<u8>, Box<dyn Error>> {
292 // Get the voice list to check if the voice exists
293 // and be able to create a header with the voices audio format
294 let voices = get_voices(client, None).await?;
295 // let voice = voices
296 voices
297 .iter()
298 .find(|voice| voice.voice_id == request.clone().options.unwrap_or_default().voice_id)
299 .ok_or("Voice not found")?;
300 // Start the audio stream
301 let request = tonic::Request::new(request);
302 let response = client.get_speech(request).await?;
303 let mut stream = response.into_inner();
304 let mut audio = vec![];
305 while let Ok(Some(chunk)) = stream.message().await {
306 // audio.extend_from_slice(&chunk.data);
307 audio.extend_from_slice(&chunk.data);
308 }
309
310 Ok(audio)
311}
312
313/// This is just an alias for [synthesize]
314pub async fn get_audio(
315 client: &mut TtsClient,
316 request: SpeechRequest,
317) -> Result<Vec<u8>, Box<dyn Error>> {
318 synthesize(client, request).await
319}
320
321/// Clears the cache of the server, removing all cached audio data.
322/// # Arguments
323/// * `client` - The client to use to communicate with the server.
324/// * `request` - The request to send to the server. If None is given, the default request will be used.
325/// # Example
326/// ```no_run
327/// use aristech_tts_client::{get_client, TlsOptions, clear_cache};
328/// use std::error::Error;
329/// #[tokio::main]
330/// async fn main() -> Result<(), Box<dyn Error>> {
331/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
332/// let response = clear_cache(&mut client, None).await?;
333/// println!("Cache cleared: {:?}", response);
334/// Ok(())
335/// }
336/// ```
337pub async fn clear_cache(
338 client: &mut TtsClient,
339 request: Option<tts_services::ClearCacheRequest>,
340) -> Result<tts_services::ClearCacheResponse, Box<dyn Error>> {
341 let req = request.unwrap_or(tts_services::ClearCacheRequest::default());
342 let request = tonic::Request::new(req);
343 let response = client.clear_cache(request).await?;
344 Ok(response.into_inner())
345}