aristech_nlp_client/
lib.rs

1//! # Aristech NLP-Client
2//! This is a client library for the Aristech NLP-Server.
3
4#![warn(missing_docs)]
5
6/// The nlp_server module contains types and functions generated from the Aristech NLP proto file.
7pub mod nlp_service {
8    #![allow(missing_docs)]
9    tonic::include_proto!("aristech.nlp");
10}
11
12use std::error::Error;
13
14use nlp_service::nlp_server_client::NlpServerClient;
15use nlp_service::{
16    Function, FunctionRequest, GetContentRequest, GetContentResponse, GetIntentsRequest,
17    GetProjectsRequest, GetScoreLimitsRequest, GetScoreLimitsResponse, Project,
18    RemoveContentRequest, RemoveContentResponse, RunFunctionsRequest, RunFunctionsResponse,
19    UpdateContentRequest, UpdateContentResponse,
20};
21use tonic::codegen::InterceptedService;
22use tonic::service::Interceptor;
23use tonic::transport::{Certificate, Channel, ClientTlsConfig};
24
25use crate::nlp_service::Intent;
26
/// The Auth struct holds the token and secret needed to authenticate with the server.
///
/// The `Debug` implementation redacts the secret so credentials do not leak into logs
/// (this also covers types that embed an `Auth` and derive `Debug`).
#[derive(Clone)]
pub struct Auth {
    /// The token to authenticate with the server
    pub token: String,
    /// The secret to authenticate with the server
    pub secret: String,
}

impl std::fmt::Debug for Auth {
    /// Formats the credentials with the secret replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Auth")
            .field("token", &self.token)
            .field("secret", &"<redacted>")
            .finish()
    }
}

impl Auth {
    /// Creates a new Auth struct with the given token and secret.
    pub fn new(token: &str, secret: &str) -> Self {
        Self {
            token: token.to_string(),
            secret: secret.to_string(),
        }
    }
}
45
46/// The AuthInterceptor struct is used to intercept requests to the server and add the authentication headers.
47pub struct AuthInterceptor {
48    /// The authentication data to add to the headers
49    auth: Option<Auth>,
50}
51
52impl AuthInterceptor {
53    /// Creates a new AuthInterceptor with the given authentication data.
54    fn new(auth: Option<Auth>) -> Self {
55        Self { auth }
56    }
57}
58impl Interceptor for AuthInterceptor {
59    /// Adds the authentication data to the headers of the request.
60    fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
61        if let Some(auth) = &self.auth {
62            let mut request = request;
63            request
64                .metadata_mut()
65                .insert("token", auth.token.parse().unwrap());
66            request
67                .metadata_mut()
68                .insert("secret", auth.secret.parse().unwrap());
69            Ok(request)
70        } else {
71            Ok(request)
72        }
73    }
74}
75
76/// The NlpClient type is a type alias for the NlpServiceClient with the AuthInterceptor.
77pub type NlpClient = NlpServerClient<InterceptedService<Channel, AuthInterceptor>>;
78
79/// The TlsOptions struct holds the tls options needed to communicate with the server.
80#[derive(Clone, Default, Debug)]
81pub struct TlsOptions {
82    /// The authentication data to authenticate with the server
83    pub auth: Option<Auth>,
84    /// The root certificate to verify the server's certificate
85    /// This is usually only needed when the server uses a self-signed certificate
86    pub ca_certificate: Option<String>,
87}
88
89impl TlsOptions {
90    /// Creates a new TlsOptions struct with the given authentication data and root certificate.
91    pub fn new(auth: Option<Auth>, ca_certificate: Option<String>) -> Self {
92        Self {
93            auth,
94            ca_certificate,
95        }
96    }
97}
98
99/// Creates a new [NlpClient] to communicate with the server.
100///
101/// # Arguments
102/// * `host` - The host to connect to (might include the port number e.g. "https://nlp.example.com:8524"). Note that the protocol must be included in the host.
103/// * `tls_options` - The tls options to use when connecting to the server. If None is given, the connection will be unencrypted and unauthenticated (the server must also be configured to communicate without encryption in this case).
104pub async fn get_client(
105    host: String,
106    tls_options: Option<TlsOptions>,
107) -> Result<NlpClient, Box<dyn Error>> {
108    // Check if a schema is included in the host
109    // otherwise add http if no tls options are given and https otherwise
110    let host = if host.starts_with("http://") || host.starts_with("https://") {
111        host
112    } else {
113        match tls_options {
114            Some(_) => format!("https://{}", host),
115            None => format!("http://{}", host),
116        }
117    };
118    match tls_options {
119        Some(tls_options) => {
120            let tls = match tls_options.ca_certificate {
121                Some(ca_certificate) => {
122                    let ca_certificate = Certificate::from_pem(ca_certificate);
123                    ClientTlsConfig::new().ca_certificate(ca_certificate)
124                }
125                None => ClientTlsConfig::with_native_roots(ClientTlsConfig::new()),
126            };
127            let channel = Channel::from_shared(host)?
128                .tls_config(tls)?
129                .connect()
130                .await?;
131            let client: NlpServerClient<InterceptedService<Channel, AuthInterceptor>> =
132                NlpServerClient::with_interceptor(channel, AuthInterceptor::new(tls_options.auth));
133            Ok(client)
134        }
135        None => {
136            let channel = Channel::from_shared(host)?.connect().await?;
137            let client: NlpServerClient<InterceptedService<Channel, AuthInterceptor>> =
138                NlpServerClient::with_interceptor(channel, AuthInterceptor::new(None));
139            Ok(client)
140        }
141    }
142}
143
144/// Gets the functions available on the server.
145///
146/// # Arguments
147/// * `client` - The client to use to communicate with the server.  
148/// * `request` - The request to send to the server. If None is given, the default request will be used.
149///
150/// # Example
151///
152/// ```no_run
153/// use aristech_nlp_client::{get_client, list_functions, TlsOptions};
154/// use std::error::Error;
155///
156/// #[tokio::main]
157/// async fn main() -> Result<(), Box<dyn Error>> {
158///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
159///     let functions = list_functions(&mut client, None).await?;
160///     for function in functions {
161///         println!("{:?}", function);
162///     }
163///     Ok(())
164/// }
165/// ```
166pub async fn list_functions(
167    client: &mut NlpClient,
168    request: Option<FunctionRequest>,
169) -> Result<Vec<Function>, Box<dyn Error>> {
170    let req = request.unwrap_or_default();
171    let response = client.get_functions(req).await?;
172    let mut stream = response.into_inner();
173    let mut functions = vec![];
174    while let Some(function) = stream.message().await? {
175        functions.push(function);
176    }
177    Ok(functions)
178}
179
180/// Processes the given text with the specified function pipeline.
181///
182/// # Arguments
183/// * `client` - The client to use to communicate with the server.
184/// * `request` - The request to send to the server.
185///
186/// # Example
187///
188/// ```no_run
189/// use aristech_nlp_client::{get_client, process, TlsOptions};
190/// use std::error::Error;
191///
192/// #[tokio::main]
193/// async fn main() -> Result<(), Box<dyn Error>> {
194///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
195///     let request = ProcessRawRequest {
196///         input: "Hello, world!".to_string(),
197///         functions: vec![
198///             Function { id: "spellcheck-de".to_string(), ..Function::default() },
199///         ],
200///        ..ProcessRawRequest::default()
201///     };
202///     let response = run_functions(&mut client, request).await?;
203///     println!("{}", response.output);
204///     Ok(())
205/// }
206pub async fn run_functions(
207    client: &mut NlpClient,
208    request: RunFunctionsRequest,
209) -> Result<RunFunctionsResponse, Box<dyn Error>> {
210    let response = client.run_functions(request).await?;
211    Ok(response.into_inner())
212}
213
214/// Gets the projects available on the server.
215///
216/// # Arguments
217/// * `client` - The client to use to communicate with the server.
218/// * `request` - The request to send to the server. If None is given, the default request will be used.
219///
220/// # Example
221///
222/// ```no_run
223/// use aristech_nlp_client::{get_client, list_projects, TlsOptions};
224/// use std::error::Error;
225///
226/// #[tokio::main]
227/// async fn main() -> Result<(), Box<dyn Error>> {
228///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
229///     let projects = list_projects(&mut client, None).await?;
230///     for project in projects {
231///         println!("{:?}", project);
232///     }
233///     Ok(())
234/// }
235/// ```
236pub async fn list_projects(
237    client: &mut NlpClient,
238    request: Option<GetProjectsRequest>,
239) -> Result<Vec<Project>, Box<dyn Error>> {
240    let req = request.unwrap_or_default();
241    let response = client.get_projects(req).await?;
242    let mut stream = response.into_inner();
243    let mut projects = vec![];
244    while let Some(project) = stream.message().await? {
245        projects.push(project);
246    }
247    Ok(projects)
248}
249
250/// Gets the intents available on the server.
251///
252/// # Arguments
253/// * `client` - The client to use to communicate with the server.
254/// * `request` - The request to send to the server. If None is given, the default request will be used.
255///
256/// # Example
257///
258/// ```no_run
259/// use aristech_nlp_client::{get_client, get_intents, TlsOptions};
260/// use std::error::Error;
261///
262/// #[tokio::main]
263/// async fn main() -> Result<(), Box<dyn Error>> {
264///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
265///     let intents = get_intents(&mut client, None).await?;
266///     for intent in intents {
267///         println!("{:?}", intent);
268///     }
269///     Ok(())
270/// }
271/// ```
272pub async fn get_intents(
273    client: &mut NlpClient,
274    request: GetIntentsRequest,
275) -> Result<Vec<Intent>, Box<dyn Error>> {
276    let response = client.get_intents(request).await?;
277    let mut stream = response.into_inner();
278    let mut intents = vec![];
279    while let Some(intent) = stream.message().await? {
280        intents.push(intent);
281    }
282    Ok(intents)
283}
284
285/// Gets the score limits for the given input.
286///
287/// # Arguments
288/// * `client` - The client to use to communicate with the server.
289/// * `request` - The request to send to the server.
290///
291/// # Example
292///
293/// ```no_run
294/// use aristech_nlp_client::{get_client, get_score_limits, TlsOptions};
295/// use std::error::Error;
296///
297/// #[tokio::main]
298/// async fn main() -> Result<(), Box<dyn Error>> {
299///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
300///     let request = GetScoreLimitsRequest {
301///         input: "Hello, world!".to_string(),
302///         ..GetScoreLimitsRequest::default()
303///     };
304///     let response = get_score_limits(&mut client, request).await?;
305///     println!("{:?}", response);
306///     Ok(())
307/// }
308/// ```
309pub async fn get_score_limits(
310    client: &mut NlpClient,
311    request: GetScoreLimitsRequest,
312) -> Result<GetScoreLimitsResponse, Box<dyn Error>> {
313    let response = client.get_score_limits(request).await?;
314    Ok(response.into_inner())
315}
316
317/// Gets the content for the given request.
318///
319/// # Arguments
320/// * `client` - The client to use to communicate with the server.
321/// * `request` - The request to send to the server.
322///
323/// # Example
324///
325/// ```no_run
326/// use aristech_nlp_client::{get_client, get_content, TlsOptions};
327/// use std::error::Error;
328///
329/// #[tokio::main]
330/// async fn main() -> Result<(), Box<dyn Error>> {
331///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
332///     let request = GetContentRequest {
333///         prompt: "What are the lottery numbers?".to_string(),
334///         threshold: 0.5,
335///         return_payload: true,
336///         num_results: 3,
337///         metadata: Some(ContentMetaData {
338///             project_id: "3f6959e6-cfb5-4eed-8195-033d47b73263".to_string(),
339///             exclude_output_from_search: true,
340///             ..ContentMetaData::default()
341///         }),
342///         ..GetContentRequest::default()
343///     };
344///     let responses = get_content(&mut client, request).await?;
345///     for response in responses {
346///         println!("{:?}", response.items);
347///     }
348///     Ok(())
349/// }
350/// ```
351pub async fn get_content(
352    client: &mut NlpClient,
353    request: GetContentRequest,
354) -> Result<Vec<GetContentResponse>, Box<dyn Error>> {
355    let response = client.get_content(request).await?;
356    let mut stream = response.into_inner();
357    let mut content = vec![];
358    while let Some(c) = stream.message().await? {
359        content.push(c);
360    }
361    Ok(content)
362}
363
364/// Updates the content for the given request.
365///
366/// # Arguments
367/// * `client` - The client to use to communicate with the server.
368/// * `request` - The request to send to the server.
369///
370/// # Example
371///
372/// ```no_run
373/// use aristech_nlp_client::{get_client, update_content, TlsOptions, UpdateContentRequest, ContentMetaData, ContentData, DescriptionMapping, Output, OutputType};
374/// use std::error::Error;
375///
376/// #[tokio::main]
377/// async fn main() -> Result<(), Box<dyn Error>> {
378///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
379///     let request = UpdateContentRequest {
380///         id: "123".to_string(),
381///         metadata: Some(ContentMetaData {
382///             project_id: "456".to_string(),
383///             ..ContentMetaData::default()
384///         }),
385///         content: Some(ContentData {
386///             topic: "lottery".to_string(),
387///             description_mappings: vec![DescriptionMapping {
388///                 uuid: "82a2133e-038e-4c83-aa26-de47ed386c55".to_string(),
389///                 description: "What are the lottery numbers?".to_string(),
390///                 ..DescriptionMapping::default()
391///             }],
392///             output: vec![Output {
393///                 r#type: OutputType::Chat.into(),
394///                 data: "The latest lottery numbers are {{lotto_numbers}}".into(),
395///             }],
396///             ..ContentData::default()
397///         }),
398///     };
399///     let response = update_content(&mut client, request).await?;
400///     println!("{:?}", response);
401///     Ok(())
402/// }
403/// ```
404pub async fn update_content(
405    client: &mut NlpClient,
406    request: UpdateContentRequest,
407) -> Result<UpdateContentResponse, Box<dyn Error>> {
408    let response = client.update_content(request).await?;
409    Ok(response.into_inner())
410}
411
412/// Removes the content for the given request.
413///
414/// # Arguments
415/// * `client` - The client to use to communicate with the server.
416/// * `request` - The request to send to the server.
417///
418/// # Example
419///
420/// ```no_run
421/// use aristech_nlp_client::{get_client, remove_content, TlsOptions, RemoveContentRequest, ContentMetaData};
422/// use std::error::Error;
423///
424/// #[tokio::main]
425/// async fn main() -> Result<(), Box<dyn Error>> {
426///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
427///     let request = RemoveContentRequest {
428///         id: "123".to_string(),
429///         metadata: Some(ContentMetaData {
430///             project_id: "456".to_string(),
431///             ..ContentMetaData::default()
432///         }),
433///         ..RemoveContentRequest::default()
434///     };
435///     let response = remove_content(&mut client, request).await?;
436///     println!("{:?}", response);
437///     Ok(())
438/// }
439/// ```
440pub async fn remove_content(
441    client: &mut NlpClient,
442    request: RemoveContentRequest,
443) -> Result<RemoveContentResponse, Box<dyn Error>> {
444    let response = client.remove_content(request).await?;
445    Ok(response.into_inner())
446}