aristech_tts_client/
lib.rs

1//! # Aristech TTS-Client
2//! The Aristech TTS-Client is a client library for the Aristech TTS-Server.
3
4#![warn(missing_docs)]
5
6/// The tts_services module contains types and functions generated from the Aristech TTS proto files.
7pub mod tts_services {
8    #![allow(missing_docs)]
9    tonic::include_proto!("aristech.tts");
10}
11
12use tts_services::speech_service_client::SpeechServiceClient as TtsServiceClient;
13use tts_services::{
14    PhonesetRequest, PhonesetResponse, SpeechRequest, TranscriptionRequest, TranscriptionResponse,
15    Voice, VoiceListRequest,
16};
17
18use std::error::Error;
19
20use tonic::codegen::InterceptedService;
21use tonic::service::Interceptor;
22use tonic::transport::{Certificate, Channel, ClientTlsConfig};
23
/// Credentials (token and secret) used to authenticate with the server.
#[derive(Clone)]
pub struct Auth {
    /// The token to authenticate with the server.
    pub token: String,
    /// The secret to authenticate with the server.
    pub secret: String,
}

impl Auth {
    /// Builds an [`Auth`] from borrowed strings.
    ///
    /// # Arguments
    /// * `token` - The token to authenticate with; copied into an owned `String`.
    /// * `secret` - The secret to authenticate with; copied into an owned `String`.
    pub fn new(token: &str, secret: &str) -> Self {
        let token = String::from(token);
        let secret = String::from(secret);
        Auth { token, secret }
    }
}
42
43/// The AuthInterceptor struct is used to intercept requests to the server and add the authentication headers.
44pub struct AuthInterceptor {
45    /// The authentication data to add to the headers
46    auth: Option<Auth>,
47}
48
49impl AuthInterceptor {
50    /// Creates a new AuthInterceptor with the given authentication data.
51    fn new(auth: Option<Auth>) -> Self {
52        Self { auth }
53    }
54}
55
56impl Interceptor for AuthInterceptor {
57    /// Adds the authentication data to the headers of the request.
58    fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
59        if let Some(auth) = &self.auth {
60            let mut request = request;
61            request
62                .metadata_mut()
63                .insert("token", auth.token.parse().unwrap());
64            request
65                .metadata_mut()
66                .insert("secret", auth.secret.parse().unwrap());
67            Ok(request)
68        } else {
69            Ok(request)
70        }
71    }
72}
73
74/// The TtsClient type is a type alias for the TtsServiceClient with the AuthInterceptor.
75pub type TtsClient = TtsServiceClient<InterceptedService<Channel, AuthInterceptor>>;
76
77/// The TlsOptions struct holds the tls options needed to communicate with the server.
78#[derive(Clone, Default)]
79pub struct TlsOptions {
80    /// The authentication data to authenticate with the server
81    pub auth: Option<Auth>,
82    /// The root certificate to verify the server's certificate
83    /// This is usually only needed when the server uses a self-signed certificate
84    pub ca_certificate: Option<String>,
85}
86
87impl TlsOptions {
88    /// Creates a new TlsOptions struct with the given authentication data and root certificate.
89    pub fn new(auth: Option<Auth>, ca_certificate: Option<String>) -> Self {
90        Self {
91            auth,
92            ca_certificate,
93        }
94    }
95}
96
97/// Creates a new [TtsClient] to communicate with the server.
98///
99/// # Arguments
100/// * `host` - The host to connect to (might include the port number e.g. "https://tts.example.com:8424"). Note that the protocol must be included in the host.
101/// * `tls_options` - The tls options to use when connecting to the server. If None is given, the connection will be unencrypted and unauthenticated (the server must also be configured to communicate without encryption in this case).
102pub async fn get_client(
103    host: String,
104    tls_options: Option<TlsOptions>,
105) -> Result<TtsClient, Box<dyn Error>> {
106    // Check if a schema is included in the host
107    // otherwise add http if no tls options are given and https otherwise
108    let host = if host.starts_with("http://") || host.starts_with("https://") {
109        host
110    } else {
111        match tls_options {
112            Some(_) => format!("https://{}", host),
113            None => format!("http://{}", host),
114        }
115    };
116    match tls_options {
117        Some(tls_options) => {
118            let tls = match tls_options.ca_certificate {
119                Some(ca_certificate) => {
120                    let ca_certificate = Certificate::from_pem(ca_certificate);
121                    ClientTlsConfig::new().ca_certificate(ca_certificate)
122                }
123                None => ClientTlsConfig::with_native_roots(ClientTlsConfig::new()),
124            };
125            let channel = Channel::from_shared(host)?
126                .tls_config(tls)?
127                .connect()
128                .await?;
129            let client: TtsServiceClient<InterceptedService<Channel, AuthInterceptor>> =
130                TtsServiceClient::with_interceptor(channel, AuthInterceptor::new(tls_options.auth));
131            Ok(client)
132        }
133        None => {
134            let channel = Channel::from_shared(host)?.connect().await?;
135            let client: TtsServiceClient<InterceptedService<Channel, AuthInterceptor>> =
136                TtsServiceClient::with_interceptor(channel, AuthInterceptor::new(None));
137            Ok(client)
138        }
139    }
140}
141
142/// Gets the list of available voices from the server.
143///
144/// # Arguments
145/// * `client` - The client to use to communicate with the server.
146/// * `request` - The request to send to the server. If None is given, the default request will be used.
147///
148/// # Example
149///
150/// ```no_run
151/// use aristech_tts_client::{get_client, TlsOptions, get_voices};
152/// use std::error::Error;
153///
154/// #[tokio::main]
155/// async fn main() -> Result<(), Box<dyn Error>> {
156///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
157///     let voices = get_voices(&mut client, None).await?;
158///     for voice in voices {
159///         println!("Voice: {:?}", voice);
160///     }
161///    Ok(())
162/// }
163/// ```
164pub async fn get_voices(
165    client: &mut TtsClient,
166    request: Option<VoiceListRequest>,
167) -> Result<Vec<Voice>, Box<dyn Error>> {
168    let req = request.unwrap_or(VoiceListRequest::default());
169    let request = tonic::Request::new(req);
170    let response = client.get_voice_list(request).await?;
171    let mut stream = response.into_inner();
172    let mut voices = vec![];
173    while let Ok(Some(voice)) = stream.message().await {
174        voices.push(voice);
175    }
176    Ok(voices)
177}
178
179/// Gets the phoneset for the given voice from the server.
180///
181/// # Arguments
182/// * `client` - The client to use to communicate with the server.
183/// * `request` - The request to send to the server.
184///
185/// # Example
186///
187/// ```no_run
188/// use aristech_tts_client::{
189///     get_client, TlsOptions, get_phoneset,
190///     tts_services::{PhonesetRequest, Voice},
191/// };
192/// use std::error::Error;
193///
194/// #[tokio::main]
195/// async fn main() -> Result<(), Box<dyn Error>> {
196///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
197///     let response = get_phoneset(&mut client, PhonesetRequest {
198///         voice: Some(Voice {
199///             voice_id: "anne_en_GB".to_string(),
200///             ..Voice::default()
201///         }),
202///         ..PhonesetRequest::default()
203///     }).await?;
204///     println!("{:?}", response);
205///     Ok(())
206/// }
207pub async fn get_phoneset(
208    client: &mut TtsClient,
209    request: PhonesetRequest,
210) -> Result<PhonesetResponse, Box<dyn Error>> {
211    let request = tonic::Request::new(request);
212    let response = client.get_phoneset(request).await?;
213    Ok(response.into_inner())
214}
215
216/// Gets the transcription for a word for a specific voice from the server.
217///
218/// # Arguments
219/// * `client` - The client to use to communicate with the server.
220/// * `request` - The request to send to the server.
221///
222/// # Example
223///
224/// ```no_run
225/// use aristech_tts_client::{
226///    get_client, TlsOptions, get_transcription,
227///   tts_services::{TranscriptionRequest, Voice},
228/// };
229/// use std::error::Error;
230///
231/// #[tokio::main]
232/// async fn main() -> Result<(), Box<dyn Error>> {
233///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
234///     let response = get_transcription(&mut client, TranscriptionRequest {
235///         voice: Some(Voice {
236///             voice_id: "anne_en_GB".to_string(),
237///             ..Voice::default()
238///         }),
239///         word: "hello".to_string(),
240///         ..TranscriptionRequest::default()
241///     }).await?;
242///     println!("{:?}", response);
243///     Ok(())
244/// }
245/// ```
246///
247pub async fn get_transcription(
248    client: &mut TtsClient,
249    request: TranscriptionRequest,
250) -> Result<TranscriptionResponse, Box<dyn Error>> {
251    let request = tonic::Request::new(request);
252    let response = client.get_transcription(request).await?;
253    Ok(response.into_inner())
254}
255
256/// Synthesizes the given text with the given options and returns the audio data.
257/// Currently only raw and wav audio formats are supported.
258/// If the audio format is set to wav (default), the audio data will be prepended with a wave header.
259///
260/// # Arguments
261/// * `client` - The client to use to communicate with the server.
262/// * `request` - The request to send to the server.
263///
264/// # Example
265///
266/// ```no_run
267/// use aristech_tts_client::{
268///     get_client, TlsOptions, synthesize,
269///     tts_services::{SpeechRequest, SpeechRequestOption},
270/// };
271/// use std::error::Error;
272///
273/// #[tokio::main]
274/// async fn main() -> Result<(), Box<dyn Error>> {
275///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
276///     let request = SpeechRequest {
277///         text: "Hello world.".to_string(),
278///         options: Some(SpeechRequestOption {
279///             voice_id: "anne_en_GB".to_string(),
280///             ..SpeechRequestOption::default()
281///         }),
282///         ..SpeechRequest::default()
283///     };
284///     let data = synthesize(&mut client, request).await?;
285///     std::fs::write("output.wav", data)?;
286///     Ok(())
287/// }
288pub async fn synthesize(
289    client: &mut TtsClient,
290    request: SpeechRequest,
291) -> Result<Vec<u8>, Box<dyn Error>> {
292    // Get the voice list to check if the voice exists
293    // and be able to create a header with the voices audio format
294    let voices = get_voices(client, None).await?;
295    // let voice = voices
296    voices
297        .iter()
298        .find(|voice| voice.voice_id == request.clone().options.unwrap_or_default().voice_id)
299        .ok_or("Voice not found")?;
300    // Start the audio stream
301    let request = tonic::Request::new(request);
302    let response = client.get_speech(request).await?;
303    let mut stream = response.into_inner();
304    let mut audio = vec![];
305    while let Ok(Some(chunk)) = stream.message().await {
306        // audio.extend_from_slice(&chunk.data);
307        audio.extend_from_slice(&chunk.data);
308    }
309
310    Ok(audio)
311}
312
313/// This is just an alias for [synthesize]
314pub async fn get_audio(
315    client: &mut TtsClient,
316    request: SpeechRequest,
317) -> Result<Vec<u8>, Box<dyn Error>> {
318    synthesize(client, request).await
319}
320
321/// Clears the cache of the server, removing all cached audio data.
322/// # Arguments
323/// * `client` - The client to use to communicate with the server.
324/// * `request` - The request to send to the server. If None is given, the default request will be used.
325/// # Example
326/// ```no_run
327/// use aristech_tts_client::{get_client, TlsOptions, clear_cache};
328/// use std::error::Error;
329/// #[tokio::main]
330/// async fn main() -> Result<(), Box<dyn Error>> {
331///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
332///     let response = clear_cache(&mut client, None).await?;
333///     println!("Cache cleared: {:?}", response);
334///     Ok(())
335/// }
336/// ```
337pub async fn clear_cache(
338    client: &mut TtsClient,
339    request: Option<tts_services::ClearCacheRequest>,
340) -> Result<tts_services::ClearCacheResponse, Box<dyn Error>> {
341    let req = request.unwrap_or(tts_services::ClearCacheRequest::default());
342    let request = tonic::Request::new(req);
343    let response = client.clear_cache(request).await?;
344    Ok(response.into_inner())
345}