aristech_nlp_client/lib.rs
1//! # Aristech NLP-Client
2//! This is a client library for the Aristech NLP-Server.
3
4#![warn(missing_docs)]
5
6/// The nlp_server module contains types and functions generated from the Aristech NLP proto file.
7pub mod nlp_service {
8 #![allow(missing_docs)]
9 tonic::include_proto!("aristech.nlp");
10}
11
12use std::error::Error;
13
14use nlp_service::nlp_server_client::NlpServerClient;
15use nlp_service::{
16 Function, FunctionRequest, GetContentRequest, GetContentResponse, GetIntentsRequest,
17 GetIntentsResponse, GetProjectsRequest, GetScoreLimitsRequest, GetScoreLimitsResponse, Project,
18 RemoveContentRequest, RemoveContentResponse, RunFunctionsRequest, RunFunctionsResponse,
19 UpdateContentRequest, UpdateContentResponse,
20};
21use tonic::codegen::InterceptedService;
22use tonic::service::Interceptor;
23use tonic::transport::{Certificate, Channel, ClientTlsConfig};
24
/// Credentials used to authenticate with the server.
#[derive(Clone, Debug)]
pub struct Auth {
    /// The token to authenticate with the server
    pub token: String,
    /// The secret to authenticate with the server
    pub secret: String,
}

impl Auth {
    /// Builds an [Auth] by copying the given token and secret into owned strings.
    pub fn new(token: &str, secret: &str) -> Self {
        let token = String::from(token);
        let secret = String::from(secret);
        Auth { token, secret }
    }
}
43
44/// The AuthInterceptor struct is used to intercept requests to the server and add the authentication headers.
45pub struct AuthInterceptor {
46 /// The authentication data to add to the headers
47 auth: Option<Auth>,
48}
49
50impl AuthInterceptor {
51 /// Creates a new AuthInterceptor with the given authentication data.
52 fn new(auth: Option<Auth>) -> Self {
53 Self { auth }
54 }
55}
56impl Interceptor for AuthInterceptor {
57 /// Adds the authentication data to the headers of the request.
58 fn call(&mut self, request: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
59 if let Some(auth) = &self.auth {
60 let mut request = request;
61 request
62 .metadata_mut()
63 .insert("token", auth.token.parse().unwrap());
64 request
65 .metadata_mut()
66 .insert("secret", auth.secret.parse().unwrap());
67 Ok(request)
68 } else {
69 Ok(request)
70 }
71 }
72}
73
/// The NlpClient type is a type alias for the generated `NlpServerClient` wrapped with the [AuthInterceptor].
75pub type NlpClient = NlpServerClient<InterceptedService<Channel, AuthInterceptor>>;
76
77/// The TlsOptions struct holds the tls options needed to communicate with the server.
78#[derive(Clone, Default, Debug)]
79pub struct TlsOptions {
80 /// The authentication data to authenticate with the server
81 pub auth: Option<Auth>,
82 /// The root certificate to verify the server's certificate
83 /// This is usually only needed when the server uses a self-signed certificate
84 pub ca_certificate: Option<String>,
85}
86
87impl TlsOptions {
88 /// Creates a new TlsOptions struct with the given authentication data and root certificate.
89 pub fn new(auth: Option<Auth>, ca_certificate: Option<String>) -> Self {
90 Self {
91 auth,
92 ca_certificate,
93 }
94 }
95}
96
97/// Creates a new [NlpClient] to communicate with the server.
98///
99/// # Arguments
100/// * `host` - The host to connect to (might include the port number e.g. "https://nlp.example.com:8524"). Note that the protocol must be included in the host.
101/// * `tls_options` - The tls options to use when connecting to the server. If None is given, the connection will be unencrypted and unauthenticated (the server must also be configured to communicate without encryption in this case).
102pub async fn get_client(
103 host: String,
104 tls_options: Option<TlsOptions>,
105) -> Result<NlpClient, Box<dyn Error>> {
106 // Check if a schema is included in the host
107 // otherwise add http if no tls options are given and https otherwise
108 let host = if host.starts_with("http://") || host.starts_with("https://") {
109 host
110 } else {
111 match tls_options {
112 Some(_) => format!("https://{}", host),
113 None => format!("http://{}", host),
114 }
115 };
116 match tls_options {
117 Some(tls_options) => {
118 let tls = match tls_options.ca_certificate {
119 Some(ca_certificate) => {
120 let ca_certificate = Certificate::from_pem(ca_certificate);
121 ClientTlsConfig::new().ca_certificate(ca_certificate)
122 }
123 None => ClientTlsConfig::with_native_roots(ClientTlsConfig::new()),
124 };
125 let channel = Channel::from_shared(host)?
126 .tls_config(tls)?
127 .connect()
128 .await?;
129 let client: NlpServerClient<InterceptedService<Channel, AuthInterceptor>> =
130 NlpServerClient::with_interceptor(channel, AuthInterceptor::new(tls_options.auth));
131 Ok(client)
132 }
133 None => {
134 let channel = Channel::from_shared(host)?.connect().await?;
135 let client: NlpServerClient<InterceptedService<Channel, AuthInterceptor>> =
136 NlpServerClient::with_interceptor(channel, AuthInterceptor::new(None));
137 Ok(client)
138 }
139 }
140}
141
142/// Gets the functions available on the server.
143///
144/// # Arguments
145/// * `client` - The client to use to communicate with the server.
146/// * `request` - The request to send to the server. If None is given, the default request will be used.
147///
148/// # Example
149///
150/// ```no_run
151/// use aristech_nlp_client::{get_client, list_functions, TlsOptions};
152/// use std::error::Error;
153///
154/// #[tokio::main]
155/// async fn main() -> Result<(), Box<dyn Error>> {
156/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
157/// let functions = list_functions(&mut client, None).await?;
158/// for function in functions {
159/// println!("{:?}", function);
160/// }
161/// Ok(())
162/// }
163/// ```
164pub async fn list_functions(
165 client: &mut NlpClient,
166 request: Option<FunctionRequest>,
167) -> Result<Vec<Function>, Box<dyn Error>> {
168 let req = request.unwrap_or_default();
169 let response = client.get_functions(req).await?;
170 let mut stream = response.into_inner();
171 let mut functions = vec![];
172 while let Some(function) = stream.message().await? {
173 functions.push(function);
174 }
175 Ok(functions)
176}
177
178/// Processes the given text with the specified function pipeline.
179///
180/// # Arguments
181/// * `client` - The client to use to communicate with the server.
182/// * `request` - The request to send to the server.
183///
184/// # Example
185///
186/// ```no_run
187/// use aristech_nlp_client::{get_client, process, TlsOptions};
188/// use std::error::Error;
189///
190/// #[tokio::main]
191/// async fn main() -> Result<(), Box<dyn Error>> {
192/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
193/// let request = ProcessRawRequest {
194/// input: "Hello, world!".to_string(),
195/// functions: vec![
196/// Function { id: "spellcheck-de".to_string(), ..Function::default() },
197/// ],
198/// ..ProcessRawRequest::default()
199/// };
200/// let response = run_functions(&mut client, request).await?;
201/// println!("{}", response.output);
202/// Ok(())
203/// }
204pub async fn run_functions(
205 client: &mut NlpClient,
206 request: RunFunctionsRequest,
207) -> Result<RunFunctionsResponse, Box<dyn Error>> {
208 let response = client.run_functions(request).await?;
209 Ok(response.into_inner())
210}
211
212/// Gets the projects available on the server.
213///
214/// # Arguments
215/// * `client` - The client to use to communicate with the server.
216/// * `request` - The request to send to the server. If None is given, the default request will be used.
217///
218/// # Example
219///
220/// ```no_run
221/// use aristech_nlp_client::{get_client, list_projects, TlsOptions};
222/// use std::error::Error;
223///
224/// #[tokio::main]
225/// async fn main() -> Result<(), Box<dyn Error>> {
226/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
227/// let projects = list_projects(&mut client, None).await?;
228/// for project in projects {
229/// println!("{:?}", project);
230/// }
231/// Ok(())
232/// }
233/// ```
234pub async fn list_projects(
235 client: &mut NlpClient,
236 request: Option<GetProjectsRequest>,
237) -> Result<Vec<Project>, Box<dyn Error>> {
238 let req = request.unwrap_or_default();
239 let response = client.get_projects(req).await?;
240 let mut stream = response.into_inner();
241 let mut projects = vec![];
242 while let Some(project) = stream.message().await? {
243 projects.push(project);
244 }
245 Ok(projects)
246}
247
248/// Gets the intents available on the server.
249///
250/// # Arguments
251/// * `client` - The client to use to communicate with the server.
252/// * `request` - The request to send to the server. If None is given, the default request will be used.
253///
254/// # Example
255///
256/// ```no_run
257/// use aristech_nlp_client::{get_client, get_intents, TlsOptions};
258/// use std::error::Error;
259///
260/// #[tokio::main]
261/// async fn main() -> Result<(), Box<dyn Error>> {
262/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
263/// let intents = get_intents(&mut client, None).await?;
264/// for intent in intents {
265/// println!("{:?}", intent);
266/// }
267/// Ok(())
268/// }
269/// ```
270pub async fn get_intents(
271 client: &mut NlpClient,
272 request: GetIntentsRequest,
273) -> Result<Vec<GetIntentsResponse>, Box<dyn Error>> {
274 let response = client.get_intents(request).await?;
275 let mut stream = response.into_inner();
276 let mut intents = vec![];
277 while let Some(intent) = stream.message().await? {
278 intents.push(intent);
279 }
280 Ok(intents)
281}
282
283/// Gets the score limits for the given input.
284///
285/// # Arguments
286/// * `client` - The client to use to communicate with the server.
287/// * `request` - The request to send to the server.
288///
289/// # Example
290///
291/// ```no_run
292/// use aristech_nlp_client::{get_client, get_score_limits, TlsOptions};
293/// use std::error::Error;
294///
295/// #[tokio::main]
296/// async fn main() -> Result<(), Box<dyn Error>> {
297/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
298/// let request = GetScoreLimitsRequest {
299/// input: "Hello, world!".to_string(),
300/// ..GetScoreLimitsRequest::default()
301/// };
302/// let response = get_score_limits(&mut client, request).await?;
303/// println!("{:?}", response);
304/// Ok(())
305/// }
306/// ```
307pub async fn get_score_limits(
308 client: &mut NlpClient,
309 request: GetScoreLimitsRequest,
310) -> Result<GetScoreLimitsResponse, Box<dyn Error>> {
311 let response = client.get_score_limits(request).await?;
312 Ok(response.into_inner())
313}
314
315/// Gets the content for the given request.
316///
317/// # Arguments
318/// * `client` - The client to use to communicate with the server.
319/// * `request` - The request to send to the server.
320///
321/// # Example
322///
323/// ```no_run
324/// use aristech_nlp_client::{get_client, get_content, TlsOptions};
325/// use std::error::Error;
326///
327/// #[tokio::main]
328/// async fn main() -> Result<(), Box<dyn Error>> {
329/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
330/// let request = GetContentRequest {
331/// prompt: "What are the lottery numbers?".to_string(),
332/// threshold: 0.5,
333/// return_payload: true,
334/// num_results: 3,
335/// metadata: Some(ContentMetaData {
336/// project_id: "3f6959e6-cfb5-4eed-8195-033d47b73263".to_string(),
337/// exclude_output_from_search: true,
338/// ..ContentMetaData::default()
339/// }),
340/// ..GetContentRequest::default()
341/// };
342/// let responses = get_content(&mut client, request).await?;
343/// for response in responses {
344/// println!("{:?}", response.items);
345/// }
346/// Ok(())
347/// }
348/// ```
349pub async fn get_content(
350 client: &mut NlpClient,
351 request: GetContentRequest,
352) -> Result<Vec<GetContentResponse>, Box<dyn Error>> {
353 let response = client.get_content(request).await?;
354 let mut stream = response.into_inner();
355 let mut content = vec![];
356 while let Some(c) = stream.message().await? {
357 content.push(c);
358 }
359 Ok(content)
360}
361
362/// Updates the content for the given request.
363///
364/// # Arguments
365/// * `client` - The client to use to communicate with the server.
366/// * `request` - The request to send to the server.
367///
368/// # Example
369///
370/// ```no_run
371/// use aristech_nlp_client::{get_client, update_content, TlsOptions, UpdateContentRequest, ContentMetaData, ContentData, DescriptionMapping, Output, OutputType};
372/// use std::error::Error;
373///
374/// #[tokio::main]
375/// async fn main() -> Result<(), Box<dyn Error>> {
376/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
377/// let request = UpdateContentRequest {
378/// id: "123".to_string(),
379/// metadata: Some(ContentMetaData {
380/// project_id: "456".to_string(),
381/// ..ContentMetaData::default()
382/// }),
383/// content: Some(ContentData {
384/// topic: "lottery".to_string(),
385/// description_mappings: vec![DescriptionMapping {
386/// uuid: "82a2133e-038e-4c83-aa26-de47ed386c55".to_string(),
387/// description: "What are the lottery numbers?".to_string(),
388/// ..DescriptionMapping::default()
389/// }],
390/// output: vec![Output {
391/// r#type: OutputType::Chat.into(),
392/// data: "The latest lottery numbers are {{lotto_numbers}}".into(),
393/// }],
394/// ..ContentData::default()
395/// }),
396/// };
397/// let response = update_content(&mut client, request).await?;
398/// println!("{:?}", response);
399/// Ok(())
400/// }
401/// ```
402pub async fn update_content(
403 client: &mut NlpClient,
404 request: UpdateContentRequest,
405) -> Result<UpdateContentResponse, Box<dyn Error>> {
406 let response = client.update_content(request).await?;
407 Ok(response.into_inner())
408}
409
410/// Removes the content for the given request.
411///
412/// # Arguments
413/// * `client` - The client to use to communicate with the server.
414/// * `request` - The request to send to the server.
415///
416/// # Example
417///
418/// ```no_run
419/// use aristech_nlp_client::{get_client, remove_content, TlsOptions, RemoveContentRequest, ContentMetaData};
420/// use std::error::Error;
421///
422/// #[tokio::main]
423/// async fn main() -> Result<(), Box<dyn Error>> {
424/// let mut client = get_client("https://nlp.example.com".to_string(), Some(TlsOptions::default())).await?;
425/// let request = RemoveContentRequest {
426/// id: "123".to_string(),
427/// metadata: Some(ContentMetaData {
428/// project_id: "456".to_string(),
429/// ..ContentMetaData::default()
430/// }),
431/// ..RemoveContentRequest::default()
432/// };
433/// let response = remove_content(&mut client, request).await?;
434/// println!("{:?}", response);
435/// Ok(())
436/// }
437/// ```
438pub async fn remove_content(
439 client: &mut NlpClient,
440 request: RemoveContentRequest,
441) -> Result<RemoveContentResponse, Box<dyn Error>> {
442 let response = client.remove_content(request).await?;
443 Ok(response.into_inner())
444}