aristech_stt_client/
lib.rs

//! # Aristech STT-Client
//! The Aristech STT-Client is a client library for the Aristech STT-Server.

#![warn(missing_docs)]

/// The stt_service module contains types and functions generated from the Aristech STT proto file.
pub mod stt_service {
    // The included code is generated from the `ari.stt.v1` proto package (presumably by
    // tonic-build in build.rs) and carries no doc comments, so the crate-wide
    // `missing_docs` warning is silenced for this module only.
    #![allow(missing_docs)]
    tonic::include_proto!("ari.stt.v1");
}
11
12use std::error::Error;
13
14use stt_service::streaming_recognition_request::StreamingRequest;
15use tonic::codegen::InterceptedService;
16use tonic::service::Interceptor;
17use tonic::transport::{Certificate, Channel, ClientTlsConfig};
18
19use stt_service::stt_service_client::SttServiceClient;
20use stt_service::{
21    streaming_recognition_request, AccountInfoRequest, AccountInfoResponse, ModelsRequest,
22    ModelsResponse, NlpFunctionsRequest, NlpFunctionsResponse, NlpProcessRequest,
23    NlpProcessResponse, RecognitionConfig, RecognitionSpec, StreamingRecognitionRequest,
24    StreamingRecognitionResponse,
25};
26
27use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
28
/// Credentials (a token and a secret) used to authenticate requests against the STT server.
#[derive(Clone)]
pub struct Auth {
    /// The token to authenticate with the server
    pub token: String,
    /// The secret to authenticate with the server
    pub secret: String,
}

impl Auth {
    /// Builds an `Auth` by copying the given token and secret into owned strings.
    pub fn new(token: &str, secret: &str) -> Self {
        let owned_token = String::from(token);
        let owned_secret = secret.to_owned();
        Auth {
            token: owned_token,
            secret: owned_secret,
        }
    }
}
47
48/// The AuthInterceptor struct is used to intercept requests to the server and add the authentication headers.
49pub struct AuthInterceptor {
50    /// The authentication data to add to the headers
51    auth: Option<Auth>,
52}
53
54impl AuthInterceptor {
55    /// Creates a new AuthInterceptor with the given authentication data.
56    fn new(auth: Option<Auth>) -> Self {
57        Self { auth }
58    }
59}
60impl Interceptor for AuthInterceptor {
61    /// Adds the authentication data to the headers of the request.
62    fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
63        if let Some(auth) = &self.auth {
64            let mut request = request;
65            request
66                .metadata_mut()
67                .insert("token", auth.token.parse().unwrap());
68            request
69                .metadata_mut()
70                .insert("secret", auth.secret.parse().unwrap());
71            Ok(request)
72        } else {
73            Ok(request)
74        }
75    }
76}
77
78/// The SttClient type is a type alias for the SttServiceClient with the AuthInterceptor.
79pub type SttClient = SttServiceClient<InterceptedService<Channel, AuthInterceptor>>;
80
81struct ApiKeyData {
82    token: String,
83    secret: String,
84    host: Option<String>,
85}
86
87/// Decodes the given api key into an Auth struct.
88fn decode_api_key(api_key: &str) -> Result<ApiKeyData, Box<dyn Error>> {
89    // The API key is base64 url encoded and has no padding and starts with "at-"
90    let api_key = api_key.trim_start_matches("at-");
91    let key_data = URL_SAFE_NO_PAD.decode(api_key)?;
92    let key_data = String::from_utf8(key_data)?;
93
94    let mut token = None;
95    let mut secret = None;
96    let mut host = None;
97    for line in key_data.lines() {
98        let mut parts = line.splitn(2, ":");
99        let key = match parts.next() {
100            Some(key) => key.trim(),
101            None => continue,
102        };
103        let value = match parts.next() {
104            Some(value) => value.trim(),
105            None => continue,
106        };
107        match key {
108            "token" => token = Some(value.to_string()),
109            "secret" => secret = Some(value.to_string()),
110            "host" => {
111                // If the host doesn't start with http:// or https://, add https://
112                let key_host = value.to_string();
113                host = match key_host.starts_with("http://") || key_host.starts_with("https://") {
114                    true => Some(
115                        key_host
116                            .trim_end_matches('/')
117                            .trim_end_matches('/')
118                            .to_string(),
119                    ),
120                    false => Some(format!(
121                        "https://{}",
122                        key_host.trim_end_matches('/').trim_end_matches('/')
123                    )),
124                };
125            }
126            _ => {}
127        }
128    }
129    match (token, secret) {
130        (Some(token), Some(secret)) => Ok(ApiKeyData {
131            token,
132            secret,
133            host,
134        }),
135        _ => Err("API key is missing token or secret".into()),
136    }
137}
138
139/// The SttClientBuilder struct is used to build a SttClient with the given host and tls options.
140#[derive(Default)]
141pub struct SttClientBuilder {
142    host: String,
143    tls: bool,
144    auth: Option<Auth>,
145    ca_certificate: Option<String>,
146}
147
148impl SttClientBuilder {
149    /// Creates a new SttClientBuilder and tries to parse the API key from the environment variable `ARISTECH_STT_API_KEY`.
150    /// If no API key is found or the API key is invalid, the builder will be created without authentication data.
151    /// Tls will be enabled if a valid API key was found or false otherwise.
152    /// To catch any errors with the API key, use the `.api_key` method.
153    ///
154    /// If a valid API key was found, a custom root certificate can be set with the environment variable `ARISTECH_STT_CA_CERTIFICATE` as well if the server uses a self-signed certificate for example.
155    /// The certificate should be the path to the certificate in PEM format. It is also possible to set the certificate with the `.ca_certificate` method.
156    ///
157    /// To create a client without automatically checking the environment variable, use the default constructor.
158    ///
159    /// # Example
160    ///
161    /// ```no_run
162    /// use aristech_stt_client::{SttClientBuilder};
163    ///
164    /// #[tokio::main]
165    /// async fn main() {
166    ///     let client = SttClientBuilder::new()
167    ///       .build()
168    ///       .await
169    ///       .unwrap();
170    ///     // Use the client
171    ///     // ...
172    /// }
173    pub fn new() -> Self {
174        // Try parsing the API key from the environment `ARISTECH_STT_API_KEY`
175        if let Ok(api_key) = std::env::var("ARISTECH_STT_API_KEY") {
176            if let Ok(api_key_data) = decode_api_key(&api_key) {
177                let ca_certificate = match std::env::var("ARISTECH_STT_CA_CERTIFICATE") {
178                    Ok(ca_certificate) if !ca_certificate.is_empty() => {
179                        // Try to read the certificate from the file
180                        match std::fs::read_to_string(ca_certificate) {
181                            Ok(ca_certificate) => Some(ca_certificate),
182                            Err(_) => None,
183                        }
184                    }
185                    _ => None,
186                };
187                let host = api_key_data.host.unwrap_or_default();
188                return Self {
189                    tls: true,
190                    host,
191                    auth: Some(Auth::new(&api_key_data.token, &api_key_data.secret)),
192                    ca_certificate,
193                };
194            }
195        }
196        Self {
197            tls: false,
198            ..Default::default()
199        }
200    }
201
202    /// Attempts to parse the given API key and set the authentication data for the SttClientBuilder.
203    /// When using the `new` method, the builder will automatically try to parse the API key from the environment variable `ARISTECH_STT_API_KEY` but won't fail if the API key is invalid or missing.
204    /// You can use this method to manually set the API key and catch any errors.
205    /// Note that the host from the API key will only be used if no host was set before.
206    ///
207    /// # Arguments
208    /// * `api_key` - The API key to use for the connection.
209    ///
210    /// # Example
211    ///
212    /// ```no_run
213    /// use aristech_stt_client::{SttClientBuilder};
214    ///
215    /// #[tokio::main]
216    /// async fn main() {
217    ///     // Use the default constructor to create a new SttClientBuilder without
218    ///     // automatically attempting to parse the environment variable `ARISTECH_STT_API_KEY`.
219    ///     let client = SttClientBuilder::default()
220    ///       .api_key("at-abc123...").unwrap()
221    ///       .build()
222    ///       .await
223    ///       .unwrap();
224    ///     // Use the client
225    ///     // ...
226    /// }
227    /// ```
228    pub fn api_key(mut self, api_key: &str) -> Result<Self, Box<dyn Error>> {
229        let api_key_data = decode_api_key(api_key)?;
230        if let Some(host) = api_key_data.host {
231            if self.host.is_empty() {
232                self.host = host;
233            }
234        }
235        self.tls = true;
236        self.auth = Some(Auth::new(&api_key_data.token, &api_key_data.secret));
237        Ok(self)
238    }
239
240    /// Allows to set a custom root certificate to use for the connection and enables tls when a certificate is set to Some.
241    /// This is especially useful when the server uses a self-signed certificate.
242    /// The `ca_certificate` should be the content of the certificate in PEM format.
243    ///
244    /// # Arguments
245    /// * `ca_certificate` - The root certificate to use for the connection.
246    ///
247    /// # Example
248    ///
249    /// ```no_run
250    /// use aristech_stt_client::{SttClientBuilder};
251    ///
252    /// #[tokio::main]
253    /// async fn main() {
254    ///     let client = SttClientBuilder::new()
255    ///       .ca_certificate(Some(std::fs::read_to_string("path/to/certificate.pem").unwrap()))
256    ///       .build()
257    ///       .await
258    ///       .unwrap();
259    ///      // Use the client
260    ///      // ...
261    /// }
262    /// ```
263    pub fn ca_certificate(mut self, ca_certificate: Option<String>) -> Self {
264        // If a ca_certificate is set, we need to use tls
265        match ca_certificate {
266            Some(_) => self.tls = true,
267            _ => {}
268        }
269        self.ca_certificate = ca_certificate;
270        self
271    }
272
273    /// Sets the auth options for the SttClientBuilder manually and enables tls when auth is set to Some.  
274    /// **Note:** Calling `.api_key` after `.auth` will overwrite the auth data.
275    ///
276    /// # Arguments
277    /// * `auth` - The authentication data to use for the connection.
278    ///
279    /// # Example
280    ///
281    /// ```no_run
282    /// use aristech_stt_client::{SttClientBuilder, Auth};
283    ///
284    /// #[tokio::main]
285    /// async fn main() {
286    ///     let client = SttClientBuilder::default()
287    ///       .host("https://stt.example.com:9424").unwrap()
288    ///       .auth(Some(Auth { token: "my-token".to_string(). secret: "my-secret".to_string() }))
289    ///       .build()
290    ///       .await
291    ///       .unwrap();
292    ///       // Use the client
293    ///       // ...
294    /// }
295    pub fn auth(mut self, auth: Option<Auth>) -> Self {
296        // If auth is set, we need to use tls
297        match auth {
298            Some(_) => self.tls = true,
299            _ => {}
300        }
301        self.auth = auth;
302        self
303    }
304
305    /// Sets the host for the SttClientBuilder manually and enables tls depending on the protocol of the host.  
306    /// **Note:** When the API key in the environment variable ARISTECH_STT_API_KEY contains a host or when you call `.api_key` before this call, this will automatically be set to the host from the API key but you can still overwrite it with this call.
307    ///
308    /// # Arguments
309    /// * `host` - The host to connect to (might include the port number e.g. "https://stt.example.com:9424"). Note that the protocol must be included.
310    ///
311    /// # Example
312    ///
313    /// ```no_run
314    /// use aristech_stt_client::{SttClientBuilder};
315    ///
316    /// #[tokio::main]
317    /// async fn main() {
318    ///     let client = SttClientBuilder::default()
319    ///       .host("https://stt.example.com:9423").unwrap()
320    ///       .build()
321    ///       .await
322    ///       .unwrap();
323    ///       // Use the client
324    ///       // ...
325    /// }
326    pub fn host(mut self, host: &str) -> Result<Self, Box<dyn Error>> {
327        if host.is_empty() {
328            return Err("Host cannot be empty".into());
329        }
330        if !host.starts_with("http://") && host.starts_with("https://") {
331            return Err("Host must start with http:// or https://".into());
332        }
333        self.tls = host.starts_with("https://");
334        self.host = host.to_string();
335        Ok(self)
336    }
337
338    /// Manually enables or disables tls for the SttClientBuilder.  
339    /// **Note:** The other methods will overwrite this setting depending on the given values if called after this method.
340    ///
341    /// # Arguments
342    /// * `tls` - Whether to use tls for the connection.
343    ///
344    /// # Example
345    ///
346    /// ```no_run
347    /// use aristech_stt_client::{SttClientBuilder};
348    ///
349    /// #[tokio::main]
350    /// async fn main() {
351    ///     let client = SttClientBuilder::default()
352    ///       .host("https://stt.example.com:9424").unwrap()
353    ///       .tls(false) // <- This doesn't make much sense because the host obviously uses tls but it's just an example
354    ///       .build()
355    ///       .await
356    ///       .unwrap();
357    ///
358    ///     // Use the client
359    ///     // ...
360    /// }
361    /// ``````
362    pub fn tls(mut self, tls: bool) -> Self {
363        self.tls = tls;
364        self
365    }
366
367    /// Atttempts to build the SttClient with the given options.
368    ///
369    /// # Example
370    ///
371    /// ```no_run
372    /// use aristech_stt_client::{SttClientBuilder};
373    /// use std::error::Error;
374    ///
375    /// #[tokio::main]
376    /// async fn main() -> Result<(), Box<dyn Error>> {
377    ///    let client = SttClientBuilder::new()
378    ///      .build()
379    ///      .await?;
380    ///      // Use the client
381    ///      // ...
382    ///      Ok(())
383    /// }
384    pub async fn build(self) -> Result<SttClient, Box<dyn Error>> {
385        let tls_options = match self.tls {
386            true => Some(TlsOptions::new(self.auth, self.ca_certificate)),
387            false => None,
388        };
389        get_client(self.host, tls_options).await
390    }
391}
392
393/// The TlsOptions struct holds the tls options needed to communicate with the server.
394#[derive(Clone, Default)]
395pub struct TlsOptions {
396    /// The authentication data to authenticate with the server
397    pub auth: Option<Auth>,
398    /// The root certificate to verify the server's certificate
399    /// This is usually only needed when the server uses a self-signed certificate
400    pub ca_certificate: Option<String>,
401}
402
403impl TlsOptions {
404    /// Creates a new TlsOptions struct with the given authentication data and root certificate.
405    pub fn new(auth: Option<Auth>, ca_certificate: Option<String>) -> Self {
406        Self {
407            auth,
408            ca_certificate,
409        }
410    }
411}
412
413/// Creates a new [SttClient] to communicate with the server.
414///
415/// # Arguments
416/// * `host` - The host to connect to (might include the port number e.g. "https://stt.example.com:9424"). Note that the protocol must be included in the host.
417/// * `tls_options` - The tls options to use when connecting to the server. If None is given, the connection will be unencrypted and unauthenticated (the server must also be configured to communicate without encryption in this case).
418pub async fn get_client(
419    host: String,
420    tls_options: Option<TlsOptions>,
421) -> Result<SttClient, Box<dyn Error>> {
422    // Check if a schema is included in the host
423    // otherwise add http if no tls options are given and https otherwise
424    let host = if host.starts_with("http://") || host.starts_with("https://") {
425        host
426    } else {
427        match tls_options {
428            Some(_) => format!("https://{}", host),
429            None => format!("http://{}", host),
430        }
431    };
432    match tls_options {
433        Some(tls_options) => {
434            let tls = match tls_options.ca_certificate {
435                Some(ca_certificate) => {
436                    let ca_certificate = Certificate::from_pem(ca_certificate);
437                    ClientTlsConfig::new().ca_certificate(ca_certificate)
438                }
439                None => ClientTlsConfig::with_native_roots(ClientTlsConfig::new()),
440            };
441            let channel = Channel::from_shared(host)?
442                .tls_config(tls)?
443                .connect()
444                .await?;
445            let client: SttServiceClient<InterceptedService<Channel, AuthInterceptor>> =
446                SttServiceClient::with_interceptor(channel, AuthInterceptor::new(tls_options.auth));
447            Ok(client)
448        }
449        None => {
450            let channel = Channel::from_shared(host)?.connect().await?;
451            let client: SttServiceClient<InterceptedService<Channel, AuthInterceptor>> =
452                SttServiceClient::with_interceptor(channel, AuthInterceptor::new(None));
453            Ok(client)
454        }
455    }
456}
457
458/// Gets the list of available models from the server.
459///
460/// # Arguments
461/// * `client` - The client to use to communicate with the server.
462/// * `request` - The request to send to the server. If None is given, the default request will be used.
463///
464/// # Example
465///
466/// ```no_run
467/// use aristech_stt_client::{get_client, TlsOptions, get_models};
468/// use std::error::Error;
469///
470/// #[tokio::main]
471/// async fn main() -> Result<(), Box<dyn Error>> {
472///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
473///     let response = get_models(&mut client, None).await?;
474///     for model in response.model {
475///         println!("{:?}", model);
476///     }
477///     Ok(())
478/// }
479/// ```
480pub async fn get_models(
481    client: &mut SttClient,
482    request: Option<ModelsRequest>,
483) -> Result<ModelsResponse, Box<dyn Error>> {
484    let req = request.unwrap_or(ModelsRequest::default());
485    let request = tonic::Request::new(req);
486    let response = client.models(request).await?;
487    Ok(response.get_ref().to_owned())
488}
489
490/// Gets the account information from the server.
491///
492/// # Arguments
493/// * `client` - The client to use to communicate with the server.
494/// * `request` - The request to send to the server. If None is given, the default request will be used.
495///
496/// # Example
497///
498/// ```no_run
499/// use aristech_stt_client::{get_client, TlsOptions, get_account_info};
500/// use std::error::Error;
501///
502/// #[tokio::main]
503/// async fn main() -> Result<(), Box<dyn Error>> {
504///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
505///     let response = get_account_info(&mut client, None).await?;
506///     println!("{:#?}", response);
507///     Ok(())
508/// }
509/// ```
510pub async fn get_account_info(
511    client: &mut SttClient,
512    request: Option<AccountInfoRequest>,
513) -> Result<AccountInfoResponse, Box<dyn Error>> {
514    let req = request.unwrap_or(AccountInfoRequest::default());
515    let request = tonic::Request::new(req);
516    let response = client.account_info(request).await?;
517    Ok(response.get_ref().to_owned())
518}
519
520/// Gets the list of available NLP functions for each configured NLP-Server.
521///
522/// # Arguments
523/// * `client` - The client to use to communicate with the server.
524/// * `request` - The request to send to the server. If None is given, the default request will be used.
525///
526/// # Example
527///
528/// ```no_run
529/// use aristech_stt_client::{get_client, TlsOptions, get_nlp_functions};
530/// use std::error::Error;
531///
532/// #[tokio::main]
533/// async fn main() -> Result<(), Box<dyn Error>> {
534///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
535///     let response = get_nlp_functions(&mut client, None).await?;
536///     println!("{:#?}", response);
537///     Ok(())
538/// }
539/// ```
540pub async fn get_nlp_functions(
541    client: &mut SttClient,
542    request: Option<NlpFunctionsRequest>,
543) -> Result<NlpFunctionsResponse, Box<dyn Error>> {
544    let req = request.unwrap_or(NlpFunctionsRequest::default());
545    let request = tonic::Request::new(req);
546    let response = client.nlp_functions(request).await?;
547    Ok(response.get_ref().to_owned())
548}
549
550/// Processes the given text with a given NLP pipeline using the STT-Server as proxy.
551///
552/// # Arguments
553/// * `client` - The client to use to communicate with the server.
554/// * `request` - The request to send to the server.
555///
556/// # Example
557///
558/// ```no_run
559/// use aristech_stt_client::{
560///     get_client, TlsOptions, nlp_process,
561///     stt_service::{NlpFunctionSpec, NlpProcessRequest, NlpSpec},
562/// };
563/// use std::error::Error;
564///
565/// #[tokio::main]
566/// async fn main() -> Result<(), Box<dyn Error>> {
567///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
568///     let response = nlp_process(&mut client, NlpProcessRequest {
569///         text: "hello world".to_string(),
570///         nlp: Some(NlpSpec {{
571///             server_config: "default".to_string(),
572///             functions: vec![NlpFunctionSpec {
573///                 id: "spellcheck-de".to_string(),
574///                 ..NlpFunctionSpec::default()
575///             }],
576///            ..NlpSpec::default()
577///         }}),
578///     }).await?;
579///     println!("{:#?}", response);
580///     Ok(())
581/// }
582/// ```
583pub async fn nlp_process(
584    client: &mut SttClient,
585    request: NlpProcessRequest,
586) -> Result<NlpProcessResponse, Box<dyn Error>> {
587    let request = tonic::Request::new(request);
588    let response = client.nlp_process(request).await?;
589    Ok(response.get_ref().to_owned())
590}
591
592/// Performs speech recognition on a wav file
593///
594/// # Arguments
595/// * `client` - The client to use to communicate with the server.
596/// * `file_path` - The path to the wav file to recognize.
597/// * `config` - The recognition configuration to use. If None is given, the default configuration with locale "en" and the sample rate from the wav file will be used.
598///
599/// # Example
600///
601/// ```no_run
602/// use aristech_stt_client::{get_client, TlsOptions, recognize_file};
603/// use std::error::Error;
604///
605/// #[tokio::main]
606/// async fn main() -> Result<(), Box<dyn Error>> {
607///     let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
608///     let results = recognize_file(&mut client, "my-audio.wav", None).await?;
609///     for result in results {
610///         println!(
611///             "{}",
612///             result
613///                 .chunks
614///                 .get(0)
615///                 .unwrap()
616///                 .alternatives
617///                 .get(0)
618///                 .unwrap()
619///                 .text
620///         );
621///     }
622///     Ok(())
623/// }
624/// ```
625pub async fn recognize_file(
626    client: &mut SttClient,
627    file_path: &str,
628    config: Option<RecognitionConfig>,
629) -> Result<Vec<StreamingRecognitionResponse>, Box<dyn Error>> {
630    let mut responses = Vec::new();
631    // Read the file with hound::WavReader
632    let wav_reader = hound::WavReader::open(file_path)?;
633    let sample_rate_hertz = wav_reader.spec().sample_rate as i64;
634    let spec = config.unwrap_or_default().specification.unwrap_or_default();
635    let initial_request = StreamingRecognitionRequest {
636        streaming_request: Some(StreamingRequest::Config(RecognitionConfig {
637            specification: Some(RecognitionSpec {
638                sample_rate_hertz,      // Set sample_rate_hertz from the WAV file
639                partial_results: false, // We don't want partial results for files
640                ..spec
641            }),
642            // At the moment there is nothing besides the specification in the config so we can use the default
643            ..RecognitionConfig::default()
644        })),
645    };
646
647    let audio_content = std::fs::read(file_path)?;
648    // Remove the header of the wav file
649    let audio_content = &audio_content[44..];
650    let audio_request = StreamingRecognitionRequest {
651        streaming_request: Some(
652            streaming_recognition_request::StreamingRequest::AudioContent(audio_content.to_vec()),
653        ),
654    };
655    // Create an tokio_stream where the first item is the initial request
656    let input_stream = tokio_stream::iter(vec![initial_request, audio_request]);
657    let mut stream = client.streaming_recognize(input_stream).await?.into_inner();
658    while let Some(response) = stream.message().await? {
659        responses.push(response);
660    }
661    Ok(responses)
662}