aristech_stt_client/lib.rs
1//! # Aristech STT-Client
2//! The Aristech STT-Client is a client library for the Aristech STT-Server.
3
4#![warn(missing_docs)]
5
6/// The stt_service module contains types and functions generated from the Aristech STT proto file.
7pub mod stt_service {
8 #![allow(missing_docs)]
9 tonic::include_proto!("ari.stt.v1");
10}
11
12use std::error::Error;
13
14use stt_service::streaming_recognition_request::StreamingRequest;
15use tonic::codegen::InterceptedService;
16use tonic::service::Interceptor;
17use tonic::transport::{Certificate, Channel, ClientTlsConfig};
18
19use stt_service::stt_service_client::SttServiceClient;
20use stt_service::{
21 streaming_recognition_request, AccountInfoRequest, AccountInfoResponse, ModelsRequest,
22 ModelsResponse, NlpFunctionsRequest, NlpFunctionsResponse, NlpProcessRequest,
23 NlpProcessResponse, RecognitionConfig, RecognitionSpec, StreamingRecognitionRequest,
24 StreamingRecognitionResponse,
25};
26
27use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
28
/// Credentials (token and secret) used to authenticate with the server.
#[derive(Clone)]
pub struct Auth {
    /// Token part of the credentials.
    pub token: String,
    /// Secret part of the credentials.
    pub secret: String,
}

impl Auth {
    /// Builds an [`Auth`] by copying `token` and `secret` into owned strings.
    pub fn new(token: &str, secret: &str) -> Self {
        let (token, secret) = (String::from(token), String::from(secret));
        Auth { token, secret }
    }
}
47
48/// The AuthInterceptor struct is used to intercept requests to the server and add the authentication headers.
49pub struct AuthInterceptor {
50 /// The authentication data to add to the headers
51 auth: Option<Auth>,
52}
53
54impl AuthInterceptor {
55 /// Creates a new AuthInterceptor with the given authentication data.
56 fn new(auth: Option<Auth>) -> Self {
57 Self { auth }
58 }
59}
60impl Interceptor for AuthInterceptor {
61 /// Adds the authentication data to the headers of the request.
62 fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
63 if let Some(auth) = &self.auth {
64 let mut request = request;
65 request
66 .metadata_mut()
67 .insert("token", auth.token.parse().unwrap());
68 request
69 .metadata_mut()
70 .insert("secret", auth.secret.parse().unwrap());
71 Ok(request)
72 } else {
73 Ok(request)
74 }
75 }
76}
77
/// The SttClient type is a type alias for the generated [`SttServiceClient`] whose channel is
/// wrapped by the [`AuthInterceptor`].
///
/// This is the client type returned by [`get_client`] and [`SttClientBuilder::build`].
pub type SttClient = SttServiceClient<InterceptedService<Channel, AuthInterceptor>>;
80
81struct ApiKeyData {
82 token: String,
83 secret: String,
84 host: Option<String>,
85}
86
87/// Decodes the given api key into an Auth struct.
88fn decode_api_key(api_key: &str) -> Result<ApiKeyData, Box<dyn Error>> {
89 // The API key is base64 url encoded and has no padding and starts with "at-"
90 let api_key = api_key.trim_start_matches("at-");
91 let key_data = URL_SAFE_NO_PAD.decode(api_key)?;
92 let key_data = String::from_utf8(key_data)?;
93
94 let mut token = None;
95 let mut secret = None;
96 let mut host = None;
97 for line in key_data.lines() {
98 let mut parts = line.splitn(2, ":");
99 let key = match parts.next() {
100 Some(key) => key.trim(),
101 None => continue,
102 };
103 let value = match parts.next() {
104 Some(value) => value.trim(),
105 None => continue,
106 };
107 match key {
108 "token" => token = Some(value.to_string()),
109 "secret" => secret = Some(value.to_string()),
110 "host" => {
111 // If the host doesn't start with http:// or https://, add https://
112 let key_host = value.to_string();
113 host = match key_host.starts_with("http://") || key_host.starts_with("https://") {
114 true => Some(
115 key_host
116 .trim_end_matches('/')
117 .trim_end_matches('/')
118 .to_string(),
119 ),
120 false => Some(format!(
121 "https://{}",
122 key_host.trim_end_matches('/').trim_end_matches('/')
123 )),
124 };
125 }
126 _ => {}
127 }
128 }
129 match (token, secret) {
130 (Some(token), Some(secret)) => Ok(ApiKeyData {
131 token,
132 secret,
133 host,
134 }),
135 _ => Err("API key is missing token or secret".into()),
136 }
137}
138
139/// The SttClientBuilder struct is used to build a SttClient with the given host and tls options.
140#[derive(Default)]
141pub struct SttClientBuilder {
142 host: String,
143 tls: bool,
144 auth: Option<Auth>,
145 ca_certificate: Option<String>,
146}
147
148impl SttClientBuilder {
149 /// Creates a new SttClientBuilder and tries to parse the API key from the environment variable `ARISTECH_STT_API_KEY`.
150 /// If no API key is found or the API key is invalid, the builder will be created without authentication data.
151 /// Tls will be enabled if a valid API key was found or false otherwise.
152 /// To catch any errors with the API key, use the `.api_key` method.
153 ///
154 /// If a valid API key was found, a custom root certificate can be set with the environment variable `ARISTECH_STT_CA_CERTIFICATE` as well if the server uses a self-signed certificate for example.
155 /// The certificate should be the path to the certificate in PEM format. It is also possible to set the certificate with the `.ca_certificate` method.
156 ///
157 /// To create a client without automatically checking the environment variable, use the default constructor.
158 ///
159 /// # Example
160 ///
161 /// ```no_run
162 /// use aristech_stt_client::{SttClientBuilder};
163 ///
164 /// #[tokio::main]
165 /// async fn main() {
166 /// let client = SttClientBuilder::new()
167 /// .build()
168 /// .await
169 /// .unwrap();
170 /// // Use the client
171 /// // ...
172 /// }
173 pub fn new() -> Self {
174 // Try parsing the API key from the environment `ARISTECH_STT_API_KEY`
175 if let Ok(api_key) = std::env::var("ARISTECH_STT_API_KEY") {
176 if let Ok(api_key_data) = decode_api_key(&api_key) {
177 let ca_certificate = match std::env::var("ARISTECH_STT_CA_CERTIFICATE") {
178 Ok(ca_certificate) if !ca_certificate.is_empty() => {
179 // Try to read the certificate from the file
180 match std::fs::read_to_string(ca_certificate) {
181 Ok(ca_certificate) => Some(ca_certificate),
182 Err(_) => None,
183 }
184 }
185 _ => None,
186 };
187 let host = api_key_data.host.unwrap_or_default();
188 return Self {
189 tls: true,
190 host,
191 auth: Some(Auth::new(&api_key_data.token, &api_key_data.secret)),
192 ca_certificate,
193 };
194 }
195 }
196 Self {
197 tls: false,
198 ..Default::default()
199 }
200 }
201
202 /// Attempts to parse the given API key and set the authentication data for the SttClientBuilder.
203 /// When using the `new` method, the builder will automatically try to parse the API key from the environment variable `ARISTECH_STT_API_KEY` but won't fail if the API key is invalid or missing.
204 /// You can use this method to manually set the API key and catch any errors.
205 /// Note that the host from the API key will only be used if no host was set before.
206 ///
207 /// # Arguments
208 /// * `api_key` - The API key to use for the connection.
209 ///
210 /// # Example
211 ///
212 /// ```no_run
213 /// use aristech_stt_client::{SttClientBuilder};
214 ///
215 /// #[tokio::main]
216 /// async fn main() {
217 /// // Use the default constructor to create a new SttClientBuilder without
218 /// // automatically attempting to parse the environment variable `ARISTECH_STT_API_KEY`.
219 /// let client = SttClientBuilder::default()
220 /// .api_key("at-abc123...").unwrap()
221 /// .build()
222 /// .await
223 /// .unwrap();
224 /// // Use the client
225 /// // ...
226 /// }
227 /// ```
228 pub fn api_key(mut self, api_key: &str) -> Result<Self, Box<dyn Error>> {
229 let api_key_data = decode_api_key(api_key)?;
230 if let Some(host) = api_key_data.host {
231 if self.host.is_empty() {
232 self.host = host;
233 }
234 }
235 self.tls = true;
236 self.auth = Some(Auth::new(&api_key_data.token, &api_key_data.secret));
237 Ok(self)
238 }
239
240 /// Allows to set a custom root certificate to use for the connection and enables tls when a certificate is set to Some.
241 /// This is especially useful when the server uses a self-signed certificate.
242 /// The `ca_certificate` should be the content of the certificate in PEM format.
243 ///
244 /// # Arguments
245 /// * `ca_certificate` - The root certificate to use for the connection.
246 ///
247 /// # Example
248 ///
249 /// ```no_run
250 /// use aristech_stt_client::{SttClientBuilder};
251 ///
252 /// #[tokio::main]
253 /// async fn main() {
254 /// let client = SttClientBuilder::new()
255 /// .ca_certificate(Some(std::fs::read_to_string("path/to/certificate.pem").unwrap()))
256 /// .build()
257 /// .await
258 /// .unwrap();
259 /// // Use the client
260 /// // ...
261 /// }
262 /// ```
263 pub fn ca_certificate(mut self, ca_certificate: Option<String>) -> Self {
264 // If a ca_certificate is set, we need to use tls
265 match ca_certificate {
266 Some(_) => self.tls = true,
267 _ => {}
268 }
269 self.ca_certificate = ca_certificate;
270 self
271 }
272
273 /// Sets the auth options for the SttClientBuilder manually and enables tls when auth is set to Some.
274 /// **Note:** Calling `.api_key` after `.auth` will overwrite the auth data.
275 ///
276 /// # Arguments
277 /// * `auth` - The authentication data to use for the connection.
278 ///
279 /// # Example
280 ///
281 /// ```no_run
282 /// use aristech_stt_client::{SttClientBuilder, Auth};
283 ///
284 /// #[tokio::main]
285 /// async fn main() {
286 /// let client = SttClientBuilder::default()
287 /// .host("https://stt.example.com:9424").unwrap()
288 /// .auth(Some(Auth { token: "my-token".to_string(). secret: "my-secret".to_string() }))
289 /// .build()
290 /// .await
291 /// .unwrap();
292 /// // Use the client
293 /// // ...
294 /// }
295 pub fn auth(mut self, auth: Option<Auth>) -> Self {
296 // If auth is set, we need to use tls
297 match auth {
298 Some(_) => self.tls = true,
299 _ => {}
300 }
301 self.auth = auth;
302 self
303 }
304
305 /// Sets the host for the SttClientBuilder manually and enables tls depending on the protocol of the host.
306 /// **Note:** When the API key in the environment variable ARISTECH_STT_API_KEY contains a host or when you call `.api_key` before this call, this will automatically be set to the host from the API key but you can still overwrite it with this call.
307 ///
308 /// # Arguments
309 /// * `host` - The host to connect to (might include the port number e.g. "https://stt.example.com:9424"). Note that the protocol must be included.
310 ///
311 /// # Example
312 ///
313 /// ```no_run
314 /// use aristech_stt_client::{SttClientBuilder};
315 ///
316 /// #[tokio::main]
317 /// async fn main() {
318 /// let client = SttClientBuilder::default()
319 /// .host("https://stt.example.com:9423").unwrap()
320 /// .build()
321 /// .await
322 /// .unwrap();
323 /// // Use the client
324 /// // ...
325 /// }
326 pub fn host(mut self, host: &str) -> Result<Self, Box<dyn Error>> {
327 if host.is_empty() {
328 return Err("Host cannot be empty".into());
329 }
330 if !host.starts_with("http://") && host.starts_with("https://") {
331 return Err("Host must start with http:// or https://".into());
332 }
333 self.tls = host.starts_with("https://");
334 self.host = host.to_string();
335 Ok(self)
336 }
337
338 /// Manually enables or disables tls for the SttClientBuilder.
339 /// **Note:** The other methods will overwrite this setting depending on the given values if called after this method.
340 ///
341 /// # Arguments
342 /// * `tls` - Whether to use tls for the connection.
343 ///
344 /// # Example
345 ///
346 /// ```no_run
347 /// use aristech_stt_client::{SttClientBuilder};
348 ///
349 /// #[tokio::main]
350 /// async fn main() {
351 /// let client = SttClientBuilder::default()
352 /// .host("https://stt.example.com:9424").unwrap()
353 /// .tls(false) // <- This doesn't make much sense because the host obviously uses tls but it's just an example
354 /// .build()
355 /// .await
356 /// .unwrap();
357 ///
358 /// // Use the client
359 /// // ...
360 /// }
361 /// ``````
362 pub fn tls(mut self, tls: bool) -> Self {
363 self.tls = tls;
364 self
365 }
366
367 /// Atttempts to build the SttClient with the given options.
368 ///
369 /// # Example
370 ///
371 /// ```no_run
372 /// use aristech_stt_client::{SttClientBuilder};
373 /// use std::error::Error;
374 ///
375 /// #[tokio::main]
376 /// async fn main() -> Result<(), Box<dyn Error>> {
377 /// let client = SttClientBuilder::new()
378 /// .build()
379 /// .await?;
380 /// // Use the client
381 /// // ...
382 /// Ok(())
383 /// }
384 pub async fn build(self) -> Result<SttClient, Box<dyn Error>> {
385 let tls_options = match self.tls {
386 true => Some(TlsOptions::new(self.auth, self.ca_certificate)),
387 false => None,
388 };
389 get_client(self.host, tls_options).await
390 }
391}
392
393/// The TlsOptions struct holds the tls options needed to communicate with the server.
394#[derive(Clone, Default)]
395pub struct TlsOptions {
396 /// The authentication data to authenticate with the server
397 pub auth: Option<Auth>,
398 /// The root certificate to verify the server's certificate
399 /// This is usually only needed when the server uses a self-signed certificate
400 pub ca_certificate: Option<String>,
401}
402
403impl TlsOptions {
404 /// Creates a new TlsOptions struct with the given authentication data and root certificate.
405 pub fn new(auth: Option<Auth>, ca_certificate: Option<String>) -> Self {
406 Self {
407 auth,
408 ca_certificate,
409 }
410 }
411}
412
413/// Creates a new [SttClient] to communicate with the server.
414///
415/// # Arguments
416/// * `host` - The host to connect to (might include the port number e.g. "https://stt.example.com:9424"). Note that the protocol must be included in the host.
417/// * `tls_options` - The tls options to use when connecting to the server. If None is given, the connection will be unencrypted and unauthenticated (the server must also be configured to communicate without encryption in this case).
418pub async fn get_client(
419 host: String,
420 tls_options: Option<TlsOptions>,
421) -> Result<SttClient, Box<dyn Error>> {
422 // Check if a schema is included in the host
423 // otherwise add http if no tls options are given and https otherwise
424 let host = if host.starts_with("http://") || host.starts_with("https://") {
425 host
426 } else {
427 match tls_options {
428 Some(_) => format!("https://{}", host),
429 None => format!("http://{}", host),
430 }
431 };
432 match tls_options {
433 Some(tls_options) => {
434 let tls = match tls_options.ca_certificate {
435 Some(ca_certificate) => {
436 let ca_certificate = Certificate::from_pem(ca_certificate);
437 ClientTlsConfig::new().ca_certificate(ca_certificate)
438 }
439 None => ClientTlsConfig::with_native_roots(ClientTlsConfig::new()),
440 };
441 let channel = Channel::from_shared(host)?
442 .tls_config(tls)?
443 .connect()
444 .await?;
445 let client: SttServiceClient<InterceptedService<Channel, AuthInterceptor>> =
446 SttServiceClient::with_interceptor(channel, AuthInterceptor::new(tls_options.auth));
447 Ok(client)
448 }
449 None => {
450 let channel = Channel::from_shared(host)?.connect().await?;
451 let client: SttServiceClient<InterceptedService<Channel, AuthInterceptor>> =
452 SttServiceClient::with_interceptor(channel, AuthInterceptor::new(None));
453 Ok(client)
454 }
455 }
456}
457
458/// Gets the list of available models from the server.
459///
460/// # Arguments
461/// * `client` - The client to use to communicate with the server.
462/// * `request` - The request to send to the server. If None is given, the default request will be used.
463///
464/// # Example
465///
466/// ```no_run
467/// use aristech_stt_client::{get_client, TlsOptions, get_models};
468/// use std::error::Error;
469///
470/// #[tokio::main]
471/// async fn main() -> Result<(), Box<dyn Error>> {
472/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
473/// let response = get_models(&mut client, None).await?;
474/// for model in response.model {
475/// println!("{:?}", model);
476/// }
477/// Ok(())
478/// }
479/// ```
480pub async fn get_models(
481 client: &mut SttClient,
482 request: Option<ModelsRequest>,
483) -> Result<ModelsResponse, Box<dyn Error>> {
484 let req = request.unwrap_or(ModelsRequest::default());
485 let request = tonic::Request::new(req);
486 let response = client.models(request).await?;
487 Ok(response.get_ref().to_owned())
488}
489
490/// Gets the account information from the server.
491///
492/// # Arguments
493/// * `client` - The client to use to communicate with the server.
494/// * `request` - The request to send to the server. If None is given, the default request will be used.
495///
496/// # Example
497///
498/// ```no_run
499/// use aristech_stt_client::{get_client, TlsOptions, get_account_info};
500/// use std::error::Error;
501///
502/// #[tokio::main]
503/// async fn main() -> Result<(), Box<dyn Error>> {
504/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
505/// let response = get_account_info(&mut client, None).await?;
506/// println!("{:#?}", response);
507/// Ok(())
508/// }
509/// ```
510pub async fn get_account_info(
511 client: &mut SttClient,
512 request: Option<AccountInfoRequest>,
513) -> Result<AccountInfoResponse, Box<dyn Error>> {
514 let req = request.unwrap_or(AccountInfoRequest::default());
515 let request = tonic::Request::new(req);
516 let response = client.account_info(request).await?;
517 Ok(response.get_ref().to_owned())
518}
519
520/// Gets the list of available NLP functions for each configured NLP-Server.
521///
522/// # Arguments
523/// * `client` - The client to use to communicate with the server.
524/// * `request` - The request to send to the server. If None is given, the default request will be used.
525///
526/// # Example
527///
528/// ```no_run
529/// use aristech_stt_client::{get_client, TlsOptions, get_nlp_functions};
530/// use std::error::Error;
531///
532/// #[tokio::main]
533/// async fn main() -> Result<(), Box<dyn Error>> {
534/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
535/// let response = get_nlp_functions(&mut client, None).await?;
536/// println!("{:#?}", response);
537/// Ok(())
538/// }
539/// ```
540pub async fn get_nlp_functions(
541 client: &mut SttClient,
542 request: Option<NlpFunctionsRequest>,
543) -> Result<NlpFunctionsResponse, Box<dyn Error>> {
544 let req = request.unwrap_or(NlpFunctionsRequest::default());
545 let request = tonic::Request::new(req);
546 let response = client.nlp_functions(request).await?;
547 Ok(response.get_ref().to_owned())
548}
549
550/// Processes the given text with a given NLP pipeline using the STT-Server as proxy.
551///
552/// # Arguments
553/// * `client` - The client to use to communicate with the server.
554/// * `request` - The request to send to the server.
555///
556/// # Example
557///
558/// ```no_run
559/// use aristech_stt_client::{
560/// get_client, TlsOptions, nlp_process,
561/// stt_service::{NlpFunctionSpec, NlpProcessRequest, NlpSpec},
562/// };
563/// use std::error::Error;
564///
565/// #[tokio::main]
566/// async fn main() -> Result<(), Box<dyn Error>> {
567/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
568/// let response = nlp_process(&mut client, NlpProcessRequest {
569/// text: "hello world".to_string(),
570/// nlp: Some(NlpSpec {{
571/// server_config: "default".to_string(),
572/// functions: vec![NlpFunctionSpec {
573/// id: "spellcheck-de".to_string(),
574/// ..NlpFunctionSpec::default()
575/// }],
576/// ..NlpSpec::default()
577/// }}),
578/// }).await?;
579/// println!("{:#?}", response);
580/// Ok(())
581/// }
582/// ```
583pub async fn nlp_process(
584 client: &mut SttClient,
585 request: NlpProcessRequest,
586) -> Result<NlpProcessResponse, Box<dyn Error>> {
587 let request = tonic::Request::new(request);
588 let response = client.nlp_process(request).await?;
589 Ok(response.get_ref().to_owned())
590}
591
592/// Performs speech recognition on a wav file
593///
594/// # Arguments
595/// * `client` - The client to use to communicate with the server.
596/// * `file_path` - The path to the wav file to recognize.
597/// * `config` - The recognition configuration to use. If None is given, the default configuration with locale "en" and the sample rate from the wav file will be used.
598///
599/// # Example
600///
601/// ```no_run
602/// use aristech_stt_client::{get_client, TlsOptions, recognize_file};
603/// use std::error::Error;
604///
605/// #[tokio::main]
606/// async fn main() -> Result<(), Box<dyn Error>> {
607/// let mut client = get_client("https://tts.example.com".to_string(), Some(TlsOptions::default())).await?;
608/// let results = recognize_file(&mut client, "my-audio.wav", None).await?;
609/// for result in results {
610/// println!(
611/// "{}",
612/// result
613/// .chunks
614/// .get(0)
615/// .unwrap()
616/// .alternatives
617/// .get(0)
618/// .unwrap()
619/// .text
620/// );
621/// }
622/// Ok(())
623/// }
624/// ```
625pub async fn recognize_file(
626 client: &mut SttClient,
627 file_path: &str,
628 config: Option<RecognitionConfig>,
629) -> Result<Vec<StreamingRecognitionResponse>, Box<dyn Error>> {
630 let mut responses = Vec::new();
631 // Read the file with hound::WavReader
632 let wav_reader = hound::WavReader::open(file_path)?;
633 let sample_rate_hertz = wav_reader.spec().sample_rate as i64;
634 let spec = config.unwrap_or_default().specification.unwrap_or_default();
635 let initial_request = StreamingRecognitionRequest {
636 streaming_request: Some(StreamingRequest::Config(RecognitionConfig {
637 specification: Some(RecognitionSpec {
638 sample_rate_hertz, // Set sample_rate_hertz from the WAV file
639 partial_results: false, // We don't want partial results for files
640 ..spec
641 }),
642 // At the moment there is nothing besides the specification in the config so we can use the default
643 ..RecognitionConfig::default()
644 })),
645 };
646
647 let audio_content = std::fs::read(file_path)?;
648 // Remove the header of the wav file
649 let audio_content = &audio_content[44..];
650 let audio_request = StreamingRecognitionRequest {
651 streaming_request: Some(
652 streaming_recognition_request::StreamingRequest::AudioContent(audio_content.to_vec()),
653 ),
654 };
655 // Create an tokio_stream where the first item is the initial request
656 let input_stream = tokio_stream::iter(vec![initial_request, audio_request]);
657 let mut stream = client.streaming_recognize(input_stream).await?.into_inner();
658 while let Some(response) = stream.message().await? {
659 responses.push(response);
660 }
661 Ok(responses)
662}