aristech_nlp_client/
lib.rs

1//! # Aristech NLP-Client
2//! This is a client library for the Aristech NLP-Server.
3
4#![warn(missing_docs)]
5
6/// The nlp_server module contains types and functions generated from the Aristech NLP proto file.
7pub mod nlp_service {
8    #![allow(missing_docs)]
9    tonic::include_proto!("aristech.nlp");
10}
11
12use std::error::Error;
13
14use nlp_service::nlp_server_client::NlpServerClient;
15use nlp_service::{
16    Function, FunctionRequest, GetContentRequest, GetContentResponse, GetIntentsRequest,
17    GetIntentsResponse, GetProjectsRequest, GetScoreLimitsRequest, GetScoreLimitsResponse, Project,
18    RemoveContentRequest, RemoveContentResponse, RunFunctionsRequest, RunFunctionsResponse,
19    UpdateContentRequest, UpdateContentResponse,
20};
21use tonic::codegen::InterceptedService;
22use tonic::service::Interceptor;
23use tonic::transport::{Certificate, Channel, ClientTlsConfig};
24
/// The Auth struct holds the token and secret needed to authenticate with the server.
///
/// The `Debug` output redacts the secret so that credentials do not end up in logs when an
/// `Auth` (or a struct containing one) is printed with `{:?}`.
#[derive(Clone)]
pub struct Auth {
    /// The token to authenticate with the server
    pub token: String,
    /// The secret to authenticate with the server
    pub secret: String,
}

impl Auth {
    /// Creates a new Auth struct with the given token and secret.
    pub fn new(token: &str, secret: &str) -> Self {
        Self {
            token: token.to_string(),
            secret: secret.to_string(),
        }
    }
}

impl std::fmt::Debug for Auth {
    /// Formats the credentials with the secret replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Auth")
            .field("token", &self.token)
            .field("secret", &"<redacted>")
            .finish()
    }
}
43
44/// The AuthInterceptor struct is used to intercept requests to the server and add the authentication headers.
45pub struct AuthInterceptor {
46    /// The authentication data to add to the headers
47    auth: Option<Auth>,
48}
49
50impl AuthInterceptor {
51    /// Creates a new AuthInterceptor with the given authentication data.
52    fn new(auth: Option<Auth>) -> Self {
53        Self { auth }
54    }
55}
56impl Interceptor for AuthInterceptor {
57    /// Adds the authentication data to the headers of the request.
58    fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
59        if let Some(auth) = &self.auth {
60            let mut request = request;
61            request
62                .metadata_mut()
63                .insert("token", auth.token.parse().unwrap());
64            request
65                .metadata_mut()
66                .insert("secret", auth.secret.parse().unwrap());
67            Ok(request)
68        } else {
69            Ok(request)
70        }
71    }
72}
73
74/// The NlpClient type is a type alias for the NlpServiceClient with the AuthInterceptor.
75pub type NlpClient = NlpServerClient<InterceptedService<Channel, AuthInterceptor>>;
76
77/// The TlsOptions struct holds the tls options needed to communicate with the server.
78#[derive(Clone, Default, Debug)]
79pub struct TlsOptions {
80    /// The authentication data to authenticate with the server
81    pub auth: Option<Auth>,
82    /// The root certificate to verify the server's certificate
83    /// This is usually only needed when the server uses a self-signed certificate
84    pub ca_certificate: Option<String>,
85}
86
87impl TlsOptions {
88    /// Creates a new TlsOptions struct with the given authentication data and root certificate.
89    pub fn new(auth: Option<Auth>, ca_certificate: Option<String>) -> Self {
90        Self {
91            auth,
92            ca_certificate,
93        }
94    }
95}
96
97/// Creates a new [NlpClient] to communicate with the server.
98///
99/// # Arguments
100/// * `host` - The host to connect to (might include the port number e.g. "https://nlp.example.com:8524"). Note that the protocol must be included in the host.
101/// * `tls_options` - The tls options to use when connecting to the server. If None is given, the connection will be unencrypted and unauthenticated (the server must also be configured to communicate without encryption in this case).
102pub async fn get_client(
103    host: String,
104    tls_options: Option<TlsOptions>,
105) -> Result<NlpClient, Box<dyn Error>> {
106    // Check if a schema is included in the host
107    // otherwise add http if no tls options are given and https otherwise
108    let host = if host.starts_with("http://") || host.starts_with("https://") {
109        host
110    } else {
111        match tls_options {
112            Some(_) => format!("https://{}", host),
113            None => format!("http://{}", host),
114        }
115    };
116    match tls_options {
117        Some(tls_options) => {
118            let tls = match tls_options.ca_certificate {
119                Some(ca_certificate) => {
120                    let ca_certificate = Certificate::from_pem(ca_certificate);
121                    ClientTlsConfig::new().ca_certificate(ca_certificate)
122                }
123                None => ClientTlsConfig::with_native_roots(ClientTlsConfig::new()),
124            };
125            let channel = Channel::from_shared(host)?
126                .tls_config(tls)?
127                .connect()
128                .await?;
129            let client: NlpServerClient<InterceptedService<Channel, AuthInterceptor>> =
130                NlpServerClient::with_interceptor(channel, AuthInterceptor::new(tls_options.auth));
131            Ok(client)
132        }
133        None => {
134            let channel = Channel::from_shared(host)?.connect().await?;
135            let client: NlpServerClient<InterceptedService<Channel, AuthInterceptor>> =
136                NlpServerClient::with_interceptor(channel, AuthInterceptor::new(None));
137            Ok(client)
138        }
139    }
140}
141
142/// Gets the functions available on the server.
143///
144/// # Arguments
145/// * `client` - The client to use to communicate with the server.  
146/// * `request` - The request to send to the server. If None is given, the default request will be used.
147///
148/// # Example
149///
150/// ```no_run
151/// use aristech_nlp_client::{get_client, list_functions, TlsOptions};
152/// use std::error::Error;
153///
154/// #[tokio::main]
155/// async fn main() -> Result<(), Box<dyn Error>> {
156///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
157///     let functions = list_functions(&mut client, None).await?;
158///     for function in functions {
159///         println!("{:?}", function);
160///     }
161///     Ok(())
162/// }
163/// ```
164pub async fn list_functions(
165    client: &mut NlpClient,
166    request: Option<FunctionRequest>,
167) -> Result<Vec<Function>, Box<dyn Error>> {
168    let req = request.unwrap_or_default();
169    let response = client.get_functions(req).await?;
170    let mut stream = response.into_inner();
171    let mut functions = vec![];
172    while let Some(function) = stream.message().await? {
173        functions.push(function);
174    }
175    Ok(functions)
176}
177
178/// Processes the given text with the specified function pipeline.
179///
180/// # Arguments
181/// * `client` - The client to use to communicate with the server.
182/// * `request` - The request to send to the server.
183///
184/// # Example
185///
186/// ```no_run
187/// use aristech_nlp_client::{get_client, process, TlsOptions};
188/// use std::error::Error;
189///
190/// #[tokio::main]
191/// async fn main() -> Result<(), Box<dyn Error>> {
192///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
193///     let request = ProcessRawRequest {
194///         input: "Hello, world!".to_string(),
195///         functions: vec![
196///             Function { id: "spellcheck-de".to_string(), ..Function::default() },
197///         ],
198///        ..ProcessRawRequest::default()
199///     };
200///     let response = run_functions(&mut client, request).await?;
201///     println!("{}", response.output);
202///     Ok(())
203/// }
204pub async fn run_functions(
205    client: &mut NlpClient,
206    request: RunFunctionsRequest,
207) -> Result<RunFunctionsResponse, Box<dyn Error>> {
208    let response = client.run_functions(request).await?;
209    Ok(response.into_inner())
210}
211
212/// Gets the projects available on the server.
213///
214/// # Arguments
215/// * `client` - The client to use to communicate with the server.
216/// * `request` - The request to send to the server. If None is given, the default request will be used.
217///
218/// # Example
219///
220/// ```no_run
221/// use aristech_nlp_client::{get_client, list_projects, TlsOptions};
222/// use std::error::Error;
223///
224/// #[tokio::main]
225/// async fn main() -> Result<(), Box<dyn Error>> {
226///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
227///     let projects = list_projects(&mut client, None).await?;
228///     for project in projects {
229///         println!("{:?}", project);
230///     }
231///     Ok(())
232/// }
233/// ```
234pub async fn list_projects(
235    client: &mut NlpClient,
236    request: Option<GetProjectsRequest>,
237) -> Result<Vec<Project>, Box<dyn Error>> {
238    let req = request.unwrap_or_default();
239    let response = client.get_projects(req).await?;
240    let mut stream = response.into_inner();
241    let mut projects = vec![];
242    while let Some(project) = stream.message().await? {
243        projects.push(project);
244    }
245    Ok(projects)
246}
247
248/// Gets the intents available on the server.
249///
250/// # Arguments
251/// * `client` - The client to use to communicate with the server.
252/// * `request` - The request to send to the server. If None is given, the default request will be used.
253///
254/// # Example
255///
256/// ```no_run
257/// use aristech_nlp_client::{get_client, get_intents, TlsOptions};
258/// use std::error::Error;
259///
260/// #[tokio::main]
261/// async fn main() -> Result<(), Box<dyn Error>> {
262///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
263///     let intents = get_intents(&mut client, None).await?;
264///     for intent in intents {
265///         println!("{:?}", intent);
266///     }
267///     Ok(())
268/// }
269/// ```
270pub async fn get_intents(
271    client: &mut NlpClient,
272    request: GetIntentsRequest,
273) -> Result<Vec<GetIntentsResponse>, Box<dyn Error>> {
274    let response = client.get_intents(request).await?;
275    let mut stream = response.into_inner();
276    let mut intents = vec![];
277    while let Some(intent) = stream.message().await? {
278        intents.push(intent);
279    }
280    Ok(intents)
281}
282
283/// Gets the score limits for the given input.
284///
285/// # Arguments
286/// * `client` - The client to use to communicate with the server.
287/// * `request` - The request to send to the server.
288///
289/// # Example
290///
291/// ```no_run
292/// use aristech_nlp_client::{get_client, get_score_limits, TlsOptions};
293/// use std::error::Error;
294///
295/// #[tokio::main]
296/// async fn main() -> Result<(), Box<dyn Error>> {
297///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
298///     let request = GetScoreLimitsRequest {
299///         input: "Hello, world!".to_string(),
300///         ..GetScoreLimitsRequest::default()
301///     };
302///     let response = get_score_limits(&mut client, request).await?;
303///     println!("{:?}", response);
304///     Ok(())
305/// }
306/// ```
307pub async fn get_score_limits(
308    client: &mut NlpClient,
309    request: GetScoreLimitsRequest,
310) -> Result<GetScoreLimitsResponse, Box<dyn Error>> {
311    let response = client.get_score_limits(request).await?;
312    Ok(response.into_inner())
313}
314
315/// Gets the content for the given request.
316///
317/// # Arguments
318/// * `client` - The client to use to communicate with the server.
319/// * `request` - The request to send to the server.
320///
321/// # Example
322///
323/// ```no_run
324/// use aristech_nlp_client::{get_client, get_content, TlsOptions};
325/// use std::error::Error;
326///
327/// #[tokio::main]
328/// async fn main() -> Result<(), Box<dyn Error>> {
329///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
330///     let request = GetContentRequest {
331///         prompt: "What are the lottery numbers?".to_string(),
332///         threshold: 0.5,
333///         return_payload: true,
334///         num_results: 3,
335///         metadata: Some(ContentMetaData {
336///             project_id: "3f6959e6-cfb5-4eed-8195-033d47b73263".to_string(),
337///             exclude_output_from_search: true,
338///             ..ContentMetaData::default()
339///         }),
340///         ..GetContentRequest::default()
341///     };
342///     let responses = get_content(&mut client, request).await?;
343///     for response in responses {
344///         println!("{:?}", response.items);
345///     }
346///     Ok(())
347/// }
348/// ```
349pub async fn get_content(
350    client: &mut NlpClient,
351    request: GetContentRequest,
352) -> Result<Vec<GetContentResponse>, Box<dyn Error>> {
353    let response = client.get_content(request).await?;
354    let mut stream = response.into_inner();
355    let mut content = vec![];
356    while let Some(c) = stream.message().await? {
357        content.push(c);
358    }
359    Ok(content)
360}
361
362/// Updates the content for the given request.
363///
364/// # Arguments
365/// * `client` - The client to use to communicate with the server.
366/// * `request` - The request to send to the server.
367///
368/// # Example
369///
370/// ```no_run
371/// use aristech_nlp_client::{get_client, update_content, TlsOptions, UpdateContentRequest, ContentMetaData, ContentData, DescriptionMapping, Output, OutputType};
372/// use std::error::Error;
373///
374/// #[tokio::main]
375/// async fn main() -> Result<(), Box<dyn Error>> {
376///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
377///     let request = UpdateContentRequest {
378///         id: "123".to_string(),
379///         metadata: Some(ContentMetaData {
380///             project_id: "456".to_string(),
381///             ..ContentMetaData::default()
382///         }),
383///         content: Some(ContentData {
384///             topic: "lottery".to_string(),
385///             description_mappings: vec![DescriptionMapping {
386///                 uuid: "82a2133e-038e-4c83-aa26-de47ed386c55".to_string(),
387///                 description: "What are the lottery numbers?".to_string(),
388///                 ..DescriptionMapping::default()
389///             }],
390///             output: vec![Output {
391///                 r#type: OutputType::Chat.into(),
392///                 data: "The latest lottery numbers are {{lotto_numbers}}".into(),
393///             }],
394///             ..ContentData::default()
395///         }),
396///     };
397///     let response = update_content(&mut client, request).await?;
398///     println!("{:?}", response);
399///     Ok(())
400/// }
401/// ```
402pub async fn update_content(
403    client: &mut NlpClient,
404    request: UpdateContentRequest,
405) -> Result<UpdateContentResponse, Box<dyn Error>> {
406    let response = client.update_content(request).await?;
407    Ok(response.into_inner())
408}
409
410/// Removes the content for the given request.
411///
412/// # Arguments
413/// * `client` - The client to use to communicate with the server.
414/// * `request` - The request to send to the server.
415///
416/// # Example
417///
418/// ```no_run
419/// use aristech_nlp_client::{get_client, remove_content, TlsOptions, RemoveContentRequest, ContentMetaData};
420/// use std::error::Error;
421///
422/// #[tokio::main]
423/// async fn main() -> Result<(), Box<dyn Error>> {
424///     let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
425///     let request = RemoveContentRequest {
426///         id: "123".to_string(),
427///         metadata: Some(ContentMetaData {
428///             project_id: "456".to_string(),
429///             ..ContentMetaData::default()
430///         }),
431///         ..RemoveContentRequest::default()
432///     };
433///     let response = remove_content(&mut client, request).await?;
434///     println!("{:?}", response);
435///     Ok(())
436/// }
437/// ```
438pub async fn remove_content(
439    client: &mut NlpClient,
440    request: RemoveContentRequest,
441) -> Result<RemoveContentResponse, Box<dyn Error>> {
442    let response = client.remove_content(request).await?;
443    Ok(response.into_inner())
444}