// aristech_nlp_client/lib.rs
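//! # Aristech NLP-Client
//!
//! This is a client library for the Aristech NLP-Server. It provides [`get_client`] to connect
//! to the server and helper functions that wrap the server's gRPC endpoints.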
3
4#![warn(missing_docs)]
5
6/// The nlp_server module contains types and functions generated from the Aristech NLP proto file.
7pub mod nlp_service {
8 #![allow(missing_docs)]
9 tonic::include_proto!("aristech.nlp");
10}
11
12use std::error::Error;
13
14use nlp_service::nlp_server_client::NlpServerClient;
15use nlp_service::{
16 Function, FunctionRequest, GetContentRequest, GetContentResponse, GetIntentsRequest,
17 GetProjectsRequest, GetScoreLimitsRequest, GetScoreLimitsResponse, Project,
18 RemoveContentRequest, RemoveContentResponse, RunFunctionsRequest, RunFunctionsResponse,
19 UpdateContentRequest, UpdateContentResponse,
20};
21use tonic::codegen::InterceptedService;
22use tonic::service::Interceptor;
23use tonic::transport::{Certificate, Channel, ClientTlsConfig};
24
25use crate::nlp_service::Intent;
26
/// Credentials (a token and a secret) used to authenticate against the NLP server.
#[derive(Clone, Debug)]
pub struct Auth {
    /// The token that is sent to the server to authenticate
    pub token: String,
    /// The secret that is sent to the server to authenticate
    pub secret: String,
}

impl Auth {
    /// Builds an [`Auth`] by copying the given token and secret into owned strings.
    pub fn new(token: &str, secret: &str) -> Self {
        let token = String::from(token);
        let secret = String::from(secret);
        Auth { token, secret }
    }
}
45
46/// The AuthInterceptor struct is used to intercept requests to the server and add the authentication headers.
47pub struct AuthInterceptor {
48 /// The authentication data to add to the headers
49 auth: Option<Auth>,
50}
51
52impl AuthInterceptor {
53 /// Creates a new AuthInterceptor with the given authentication data.
54 fn new(auth: Option<Auth>) -> Self {
55 Self { auth }
56 }
57}
58impl Interceptor for AuthInterceptor {
59 /// Adds the authentication data to the headers of the request.
60 fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
61 if let Some(auth) = &self.auth {
62 let mut request = request;
63 request
64 .metadata_mut()
65 .insert("token", auth.token.parse().unwrap());
66 request
67 .metadata_mut()
68 .insert("secret", auth.secret.parse().unwrap());
69 Ok(request)
70 } else {
71 Ok(request)
72 }
73 }
74}
75
76/// The NlpClient type is a type alias for the NlpServiceClient with the AuthInterceptor.
77pub type NlpClient = NlpServerClient<InterceptedService<Channel, AuthInterceptor>>;
78
79/// The TlsOptions struct holds the tls options needed to communicate with the server.
80#[derive(Clone, Default, Debug)]
81pub struct TlsOptions {
82 /// The authentication data to authenticate with the server
83 pub auth: Option<Auth>,
84 /// The root certificate to verify the server's certificate
85 /// This is usually only needed when the server uses a self-signed certificate
86 pub ca_certificate: Option<String>,
87}
88
89impl TlsOptions {
90 /// Creates a new TlsOptions struct with the given authentication data and root certificate.
91 pub fn new(auth: Option<Auth>, ca_certificate: Option<String>) -> Self {
92 Self {
93 auth,
94 ca_certificate,
95 }
96 }
97}
98
99/// Creates a new [NlpClient] to communicate with the server.
100///
101/// # Arguments
102/// * `host` - The host to connect to (might include the port number e.g. "https://nlp.example.com:8524"). Note that the protocol must be included in the host.
103/// * `tls_options` - The tls options to use when connecting to the server. If None is given, the connection will be unencrypted and unauthenticated (the server must also be configured to communicate without encryption in this case).
104pub async fn get_client(
105 host: String,
106 tls_options: Option<TlsOptions>,
107) -> Result<NlpClient, Box<dyn Error>> {
108 // Check if a schema is included in the host
109 // otherwise add http if no tls options are given and https otherwise
110 let host = if host.starts_with("http://") || host.starts_with("https://") {
111 host
112 } else {
113 match tls_options {
114 Some(_) => format!("https://{}", host),
115 None => format!("http://{}", host),
116 }
117 };
118 match tls_options {
119 Some(tls_options) => {
120 let tls = match tls_options.ca_certificate {
121 Some(ca_certificate) => {
122 let ca_certificate = Certificate::from_pem(ca_certificate);
123 ClientTlsConfig::new().ca_certificate(ca_certificate)
124 }
125 None => ClientTlsConfig::with_native_roots(ClientTlsConfig::new()),
126 };
127 let channel = Channel::from_shared(host)?
128 .tls_config(tls)?
129 .connect()
130 .await?;
131 let client: NlpServerClient<InterceptedService<Channel, AuthInterceptor>> =
132 NlpServerClient::with_interceptor(channel, AuthInterceptor::new(tls_options.auth));
133 Ok(client)
134 }
135 None => {
136 let channel = Channel::from_shared(host)?.connect().await?;
137 let client: NlpServerClient<InterceptedService<Channel, AuthInterceptor>> =
138 NlpServerClient::with_interceptor(channel, AuthInterceptor::new(None));
139 Ok(client)
140 }
141 }
142}
143
144/// Gets the functions available on the server.
145///
146/// # Arguments
147/// * `client` - The client to use to communicate with the server.
148/// * `request` - The request to send to the server. If None is given, the default request will be used.
149///
150/// # Example
151///
152/// ```no_run
153/// use aristech_nlp_client::{get_client, list_functions, TlsOptions};
154/// use std::error::Error;
155///
156/// #[tokio::main]
157/// async fn main() -> Result<(), Box<dyn Error>> {
158/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
159/// let functions = list_functions(&mut client, None).await?;
160/// for function in functions {
161/// println!("{:?}", function);
162/// }
163/// Ok(())
164/// }
165/// ```
166pub async fn list_functions(
167 client: &mut NlpClient,
168 request: Option<FunctionRequest>,
169) -> Result<Vec<Function>, Box<dyn Error>> {
170 let req = request.unwrap_or_default();
171 let response = client.get_functions(req).await?;
172 let mut stream = response.into_inner();
173 let mut functions = vec![];
174 while let Some(function) = stream.message().await? {
175 functions.push(function);
176 }
177 Ok(functions)
178}
179
180/// Processes the given text with the specified function pipeline.
181///
182/// # Arguments
183/// * `client` - The client to use to communicate with the server.
184/// * `request` - The request to send to the server.
185///
186/// # Example
187///
188/// ```no_run
189/// use aristech_nlp_client::{get_client, process, TlsOptions};
190/// use std::error::Error;
191///
192/// #[tokio::main]
193/// async fn main() -> Result<(), Box<dyn Error>> {
194/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
195/// let request = ProcessRawRequest {
196/// input: "Hello, world!".to_string(),
197/// functions: vec![
198/// Function { id: "spellcheck-de".to_string(), ..Function::default() },
199/// ],
200/// ..ProcessRawRequest::default()
201/// };
202/// let response = run_functions(&mut client, request).await?;
203/// println!("{}", response.output);
204/// Ok(())
205/// }
206pub async fn run_functions(
207 client: &mut NlpClient,
208 request: RunFunctionsRequest,
209) -> Result<RunFunctionsResponse, Box<dyn Error>> {
210 let response = client.run_functions(request).await?;
211 Ok(response.into_inner())
212}
213
214/// Gets the projects available on the server.
215///
216/// # Arguments
217/// * `client` - The client to use to communicate with the server.
218/// * `request` - The request to send to the server. If None is given, the default request will be used.
219///
220/// # Example
221///
222/// ```no_run
223/// use aristech_nlp_client::{get_client, list_projects, TlsOptions};
224/// use std::error::Error;
225///
226/// #[tokio::main]
227/// async fn main() -> Result<(), Box<dyn Error>> {
228/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
229/// let projects = list_projects(&mut client, None).await?;
230/// for project in projects {
231/// println!("{:?}", project);
232/// }
233/// Ok(())
234/// }
235/// ```
236pub async fn list_projects(
237 client: &mut NlpClient,
238 request: Option<GetProjectsRequest>,
239) -> Result<Vec<Project>, Box<dyn Error>> {
240 let req = request.unwrap_or_default();
241 let response = client.get_projects(req).await?;
242 let mut stream = response.into_inner();
243 let mut projects = vec![];
244 while let Some(project) = stream.message().await? {
245 projects.push(project);
246 }
247 Ok(projects)
248}
249
250/// Gets the intents available on the server.
251///
252/// # Arguments
253/// * `client` - The client to use to communicate with the server.
254/// * `request` - The request to send to the server. If None is given, the default request will be used.
255///
256/// # Example
257///
258/// ```no_run
259/// use aristech_nlp_client::{get_client, get_intents, TlsOptions};
260/// use std::error::Error;
261///
262/// #[tokio::main]
263/// async fn main() -> Result<(), Box<dyn Error>> {
264/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
265/// let intents = get_intents(&mut client, None).await?;
266/// for intent in intents {
267/// println!("{:?}", intent);
268/// }
269/// Ok(())
270/// }
271/// ```
272pub async fn get_intents(
273 client: &mut NlpClient,
274 request: GetIntentsRequest,
275) -> Result<Vec<Intent>, Box<dyn Error>> {
276 let response = client.get_intents(request).await?;
277 let mut stream = response.into_inner();
278 let mut intents = vec![];
279 while let Some(intent) = stream.message().await? {
280 intents.push(intent);
281 }
282 Ok(intents)
283}
284
285/// Gets the score limits for the given input.
286///
287/// # Arguments
288/// * `client` - The client to use to communicate with the server.
289/// * `request` - The request to send to the server.
290///
291/// # Example
292///
293/// ```no_run
294/// use aristech_nlp_client::{get_client, get_score_limits, TlsOptions};
295/// use std::error::Error;
296///
297/// #[tokio::main]
298/// async fn main() -> Result<(), Box<dyn Error>> {
299/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
300/// let request = GetScoreLimitsRequest {
301/// input: "Hello, world!".to_string(),
302/// ..GetScoreLimitsRequest::default()
303/// };
304/// let response = get_score_limits(&mut client, request).await?;
305/// println!("{:?}", response);
306/// Ok(())
307/// }
308/// ```
309pub async fn get_score_limits(
310 client: &mut NlpClient,
311 request: GetScoreLimitsRequest,
312) -> Result<GetScoreLimitsResponse, Box<dyn Error>> {
313 let response = client.get_score_limits(request).await?;
314 Ok(response.into_inner())
315}
316
317/// Gets the content for the given request.
318///
319/// # Arguments
320/// * `client` - The client to use to communicate with the server.
321/// * `request` - The request to send to the server.
322///
323/// # Example
324///
325/// ```no_run
326/// use aristech_nlp_client::{get_client, get_content, TlsOptions};
327/// use std::error::Error;
328///
329/// #[tokio::main]
330/// async fn main() -> Result<(), Box<dyn Error>> {
331/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
332/// let request = GetContentRequest {
333/// prompt: "What are the lottery numbers?".to_string(),
334/// threshold: 0.5,
335/// return_payload: true,
336/// num_results: 3,
337/// metadata: Some(ContentMetaData {
338/// project_id: "3f6959e6-cfb5-4eed-8195-033d47b73263".to_string(),
339/// exclude_output_from_search: true,
340/// ..ContentMetaData::default()
341/// }),
342/// ..GetContentRequest::default()
343/// };
344/// let responses = get_content(&mut client, request).await?;
345/// for response in responses {
346/// println!("{:?}", response.items);
347/// }
348/// Ok(())
349/// }
350/// ```
351pub async fn get_content(
352 client: &mut NlpClient,
353 request: GetContentRequest,
354) -> Result<Vec<GetContentResponse>, Box<dyn Error>> {
355 let response = client.get_content(request).await?;
356 let mut stream = response.into_inner();
357 let mut content = vec![];
358 while let Some(c) = stream.message().await? {
359 content.push(c);
360 }
361 Ok(content)
362}
363
364/// Updates the content for the given request.
365///
366/// # Arguments
367/// * `client` - The client to use to communicate with the server.
368/// * `request` - The request to send to the server.
369///
370/// # Example
371///
372/// ```no_run
373/// use aristech_nlp_client::{get_client, update_content, TlsOptions, UpdateContentRequest, ContentMetaData, ContentData, DescriptionMapping, Output, OutputType};
374/// use std::error::Error;
375///
376/// #[tokio::main]
377/// async fn main() -> Result<(), Box<dyn Error>> {
378/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
379/// let request = UpdateContentRequest {
380/// id: "123".to_string(),
381/// metadata: Some(ContentMetaData {
382/// project_id: "456".to_string(),
383/// ..ContentMetaData::default()
384/// }),
385/// content: Some(ContentData {
386/// topic: "lottery".to_string(),
387/// description_mappings: vec![DescriptionMapping {
388/// uuid: "82a2133e-038e-4c83-aa26-de47ed386c55".to_string(),
389/// description: "What are the lottery numbers?".to_string(),
390/// ..DescriptionMapping::default()
391/// }],
392/// output: vec![Output {
393/// r#type: OutputType::Chat.into(),
394/// data: "The latest lottery numbers are {{lotto_numbers}}".into(),
395/// }],
396/// ..ContentData::default()
397/// }),
398/// };
399/// let response = update_content(&mut client, request).await?;
400/// println!("{:?}", response);
401/// Ok(())
402/// }
403/// ```
404pub async fn update_content(
405 client: &mut NlpClient,
406 request: UpdateContentRequest,
407) -> Result<UpdateContentResponse, Box<dyn Error>> {
408 let response = client.update_content(request).await?;
409 Ok(response.into_inner())
410}
411
412/// Removes the content for the given request.
413///
414/// # Arguments
415/// * `client` - The client to use to communicate with the server.
416/// * `request` - The request to send to the server.
417///
418/// # Example
419///
420/// ```no_run
421/// use aristech_nlp_client::{get_client, remove_content, TlsOptions, RemoveContentRequest, ContentMetaData};
422/// use std::error::Error;
423///
424/// #[tokio::main]
425/// async fn main() -> Result<(), Box<dyn Error>> {
426/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
427/// let request = RemoveContentRequest {
428/// id: "123".to_string(),
429/// metadata: Some(ContentMetaData {
430/// project_id: "456".to_string(),
431/// ..ContentMetaData::default()
432/// }),
433/// ..RemoveContentRequest::default()
434/// };
435/// let response = remove_content(&mut client, request).await?;
436/// println!("{:?}", response);
437/// Ok(())
438/// }
439/// ```
440pub async fn remove_content(
441 client: &mut NlpClient,
442 request: RemoveContentRequest,
443) -> Result<RemoveContentResponse, Box<dyn Error>> {
444 let response = client.remove_content(request).await?;
445 Ok(response.into_inner())
446}