prediction_guard/
client.rs

1//! Used to connect to the Prediction Guard API.
2use std::{env, fmt, sync::Arc, time::Duration};
3
4use crate::built_info;
5use crate::{
6    chat, completion, embedding, factuality,
7    injection, pii, rerank, toxicity, translate,
8    tokenize, models, Result
9};
10use dotenvy;
11use eventsource_client::Client as EventClient;
12use eventsource_client::SSE;
13use futures::TryStreamExt;
14use log::error;
15use reqwest::{
16    header::{HeaderMap, HeaderValue},
17    ClientBuilder, Response, StatusCode,
18};
19use serde::{Deserialize, Serialize};
20use tokio::sync::mpsc::Sender;
21
22const USER_AGENT: &str = "Prediction Guard Rust Client";
23
24/// The base error that is returned from the API calls.
25#[derive(Debug, Serialize, Deserialize, Clone, Default)]
26pub struct ApiError {
27    error: String,
28}
29
30impl fmt::Display for ApiError {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        f.write_fmt(format_args!("error {}", self.error))
33    }
34}
35
36impl std::error::Error for ApiError {}
37
/// Connection settings for the Prediction Guard API.
pub struct PgEnvironment {
    /// The Prediction Guard API key.
    pub key: String,
    /// The Prediction Guard URL.
    pub host: String,
}
43
44impl PgEnvironment {
45    /// Specify Prediction Guard API configuration manually
46    ///
47    /// ## Arguments:
48    ///
49    /// * `key` - the Prediction Guard API key
50    /// * `host` - the Prediction Guard URL
51    pub fn new(key: String, host: String) -> Self {
52        Self { key, host }
53    }
54
55    /// Loads Prediction Guard API configuration from either
56    /// a `.env` file or environment variables. Expects to find
57    /// the `PREDICTIONGUARD_API_KEY` and `PREDICTIONGUARD_URL` environment variables.
58    ///
59    /// Returns an error if the environment variables are not found.
60    pub fn from_env() -> Result<Self> {
61        let _ = dotenvy::dotenv(); // Ignoring error - it's ok to not have .env files
62        Ok(Self {
63            key: env::var("PREDICTIONGUARD_API_KEY")?,
64            host: env::var("PREDICTIONGUARD_URL")?,
65        })
66    }
67}
68
69/// Handles the connectivity to the Prediction Guard API. It is safe to be
70/// used across threads.
71#[derive(Debug, Clone)]
72pub struct Client {
73    inner: Arc<ClientInner>,
74}
75
76#[derive(Debug)]
77struct ClientInner {
78    server: String,
79    http_client: reqwest::Client,
80    headers: HeaderMap,
81    api_key: String,
82}
83
84impl Client {
85    /// Creates a new instance of client to be used. Assumes the Prediction Guard keys,
86    /// PREDICTIONGUARD_API_KEY and PREDICTIONGUARD_URL are set in the environment.
87    pub fn new() -> Result<Self> {
88        let pg_env = PgEnvironment::from_env().expect("env keys");
89
90        Self::from_environment(pg_env)
91    }
92
93    /// Creates a new instance of client to be used with a particular Prediction Guard environment.
94    ///
95    ///  ## Arguments:
96    ///
97    ///  * `pg_env` - the prediction guard environment to connect to.
98    pub fn from_environment(pg_env: PgEnvironment) -> Result<Self> {
99        let user_agent = format!("{} v{}", USER_AGENT, built_info::PKG_VERSION);
100
101        let http = ClientBuilder::new()
102            .connect_timeout(Duration::new(30, 0))
103            .read_timeout(Duration::new(30, 0))
104            .timeout(Duration::new(45, 0))
105            .user_agent(user_agent)
106            .build()?;
107
108        let header_key = match HeaderValue::from_str(&pg_env.key) {
109            Ok(x) => x,
110            Err(e) => {
111                return Err(Box::new(e));
112            }
113        };
114
115        let mut header_map = HeaderMap::new();
116        let _ = header_map
117            .insert("x-api-key", header_key)
118            .ok_or("invalid api key");
119
120        let inner = Arc::new(ClientInner {
121            server: pg_env.host.to_string(),
122            http_client: http,
123            headers: header_map,
124            api_key: pg_env.key,
125        });
126
127        Ok(Self { inner })
128    }
129
130    /// Calls the health endpoint.
131    ///
132    /// Returns the text response from the server. A 200 (Ok) status code is expected from
133    /// Prediction Guard api. Any other status code is considered an error.
134    pub async fn check_health(&self) -> Result<String> {
135        let result = self
136            .inner
137            .http_client
138            .get(&self.inner.server)
139            .headers(self.inner.headers.clone())
140            .send()
141            .await?;
142
143        if result.status() != StatusCode::OK {
144            return Err(retrieve_error(result).await);
145        }
146
147        let txt = result.text().await?;
148
149        Ok(txt)
150    }
151
152    /// Retrieves the list of models available for a set capability
153    ///
154    /// ## Arguments:
155    ///
156    /// * `capability` - The capability of models to sort by.
157    ///
158    /// Returns a vector of strings with the model names. A 200 (Ok) status code is expected from the Prediction Guard api.
159    /// Any other status code is considered an error.
160    pub async fn retrieve_model_list(&self, capability: String) -> Result<Vec<String>> {
161        let url = format!(
162            "{}{}/{}",
163            &self.inner.server,
164            models::PATH,
165            capability
166        );
167
168        let result = self
169            .inner
170            .http_client
171            .get(url)
172            .headers(self.inner.headers.clone())
173            .send()
174            .await?;
175
176        if result.status() != StatusCode::OK {
177            return Err(retrieve_error(result).await);
178        }
179
180        let response_body: models::Response = result.json().await?;
181
182        let retrieve_models_response: Vec<String> = response_body
183            .data
184            .into_iter()
185            .map(|model| model.id)
186            .collect();
187
188        Ok(retrieve_models_response)
189    }
190
191    /// Calls the embedding endpoint.
192    ///
193    /// ## Arguments:
194    ///
195    /// * `req` - An instance of [`embedding::Request`]
196    ///
197    /// Returns a [`embedding::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
198    /// is considered an error.
199    pub async fn embedding(&self, req: &embedding::Request) -> Result<embedding::Response> {
200        let url = format!("{}{}", &self.inner.server, embedding::PATH);
201
202        let result = self
203            .inner
204            .http_client
205            .post(url)
206            .headers(self.inner.headers.clone())
207            .json(req)
208            .send()
209            .await?;
210
211        if result.status() != StatusCode::OK {
212            return Err(retrieve_error(result).await);
213        }
214
215        let embed_response = result.json::<embedding::Response>().await?;
216
217        Ok(embed_response)
218    }
219
220    /// Calls the generate completion endpoint.
221    ///
222    /// ## Arguments:
223    ///
224    /// * `req` - An instance of [`completion::Request`]
225    ///
226    /// Returns a [`completion::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
227    /// is considered an error.
228    pub async fn generate_completion(
229        &self,
230        req: &completion::Request,
231    ) -> Result<completion::Response> {
232        let url = format!("{}{}", &self.inner.server, completion::PATH);
233
234        let result = self
235            .inner
236            .http_client
237            .post(url)
238            .headers(self.inner.headers.clone())
239            .json(req)
240            .send()
241            .await?;
242
243        if result.status() != StatusCode::OK {
244            return Err(retrieve_error(result).await);
245        }
246
247        let comp_response = result.json::<completion::Response>().await?;
248
249        Ok(comp_response)
250    }
251
252    /// Calls the generate chat completion endpoint.
253    ///
254    /// ## Arguments:
255    ///
256    /// * `req` - An instance of [`chat::Request::<Message>`]
257    ///
258    /// Returns an instance of [`chat::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
259    /// is considered an error.
260    pub async fn generate_chat_completion(
261        &self,
262        req: &chat::Request<chat::Message>,
263    ) -> Result<chat::Response> {
264        let url = format!("{}{}", &self.inner.server, chat::PATH);
265
266        let result = self
267            .inner
268            .http_client
269            .post(url)
270            .headers(self.inner.headers.clone())
271            .json(req)
272            .send()
273            .await?;
274
275        if result.status() != StatusCode::OK {
276            return Err(retrieve_error(result).await);
277        }
278
279        let chat_response = result.json::<chat::Response>().await?;
280
281        Ok(chat_response)
282    }
283
284    /// Calls the generate chat completion endpoint.
285    ///
286    /// ## Arguments:
287    ///
288    /// * `req` - An instance of [`chat::Request::<Message>`]
289    /// * `event_handler` - Event handler function that is called when a server side event is raised.
290    ///
291    /// Returns an instance of [`chat::Response`].
292    ///
293    /// The generated text is returned via events from the server. The event handler function gets called
294    /// every time the client receives an event response with data. Once the server terminates the events the call returns.
295    /// The entire [`chat::Response`] response is then returned to the caller.
296    ///
297    /// A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
298    /// is considered an error.
299    pub async fn generate_chat_completion_events<F>(
300        &self,
301        req: &mut chat::Request<chat::Message>,
302        event_handler: &mut F,
303    ) -> Result<Option<chat::ResponseEvents>>
304    where
305        F: FnMut(&String),
306    {
307        let url = format!("{}{}", &self.inner.server, chat::PATH);
308
309        req.stream = true;
310        req.output = None;
311
312        let body = serde_json::to_string(&req)?;
313
314        let user_agent = format!("{} v{}", USER_AGENT, built_info::PKG_VERSION);
315
316        let key = format!("Bearer {}", &self.inner.api_key);
317
318        let client = eventsource_client::ClientBuilder::for_url(&url)?
319            .header("User-Agent", user_agent.as_str())?
320            .header("Authorization", &key)?
321            .method("POST".to_string())
322            .body(body)
323            .build();
324
325        let mut stream = Box::pin(client.stream());
326
327        loop {
328            match stream.try_next().await {
329                Ok(Some(event)) => {
330                    match event {
331                        SSE::Comment(_) => continue,
332                        SSE::Event(evt) => {
333                            // Check for [DONE]
334                            if evt.data == "[DONE]" {
335                                return Ok(None);
336                            }
337
338                            // JSON Response
339                            let resp: chat::ResponseEvents = match serde_json::from_str(&evt.data) {
340                                Ok(v) => v,
341                                Err(e) => {
342                                    return Err(Box::from(ApiError {
343                                        error: format!("error parsing stream response: {}", e),
344                                    }));
345                                }
346                            };
347
348                            if resp.choices.is_empty() {
349                                // No data to stream or Done
350                                continue;
351                            }
352
353                            // Finish Reason == Stop That is the final Response.
354                            if resp.choices[0].finish_reason == Some("stop".to_string()) {
355                                return Ok(Some(resp));
356                            }
357
358                            let msg = resp.choices[0].delta.clone().content;
359                            event_handler(&msg);
360                        }
361                    }
362                }
363
364                Ok(None) => continue,
365                Err(e) => match e {
366                    eventsource_client::Error::StreamClosed => break,
367                    _ => return Err(stream_error_into_api_err(e).await),
368                },
369            }
370        }
371
372        Ok(None)
373    }
374
375    /// Calls the generate chat completion endpoint.
376    ///
377    /// ## Arguments:
378    ///
379    /// * `req` - An instance of [`chat::Request::<Message>`]
380    /// * `sender` - A sender instance for a channel where there is a receiver waiting for a message.
381    ///
382    /// Returns an instance of [`chat::Response`].
383    ///
384    /// The generated text is returned via events from the server. The sender gets called
385    /// every time the client receives an event response with data. Once the server terminates the events the call returns.
386    /// The receiver should handle the `stop` message which means there are no more messages to receive and exit.
387    /// The entire [`chat::Response`] response is then returned to the caller.
388    ///
389    /// A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
390    /// is considered an error.
391    pub async fn generate_chat_completion_events_async(
392        &self,
393        req: &mut chat::Request<chat::Message>,
394        sender: &Sender<String>,
395    ) -> Result<Option<chat::ResponseEvents>> {
396        let url = format!("{}{}", &self.inner.server, chat::PATH);
397
398        req.stream = true;
399        req.output = None;
400
401        let body = serde_json::to_string(&req)?;
402
403        let user_agent = format!("{} v{}", USER_AGENT, built_info::PKG_VERSION);
404
405        let key = format!("Bearer {}", &self.inner.api_key);
406
407        let client = eventsource_client::ClientBuilder::for_url(&url)?
408            .header("User-Agent", user_agent.as_str())?
409            .header("Authorization", &key)?
410            .method("POST".to_string())
411            .body(body)
412            .build();
413
414        let mut stream = Box::pin(client.stream());
415
416        loop {
417            match stream.try_next().await {
418                Ok(Some(event)) => {
419                    match event {
420                        SSE::Comment(_) => continue,
421                        SSE::Event(evt) => {
422                            // Check for [DONE]
423                            if evt.data.to_lowercase() == "[done]" {
424                                let _ = sender.send("stop".to_string()).await;
425                                return Ok(None);
426                            }
427
428                            // JSON Response
429                            let resp: chat::ResponseEvents = match serde_json::from_str(&evt.data) {
430                                Ok(v) => v,
431                                Err(e) => {
432                                    return Err(Box::from(ApiError {
433                                        error: format!("error parsing stream response: {}", e),
434                                    }));
435                                }
436                            };
437
438                            if resp.choices.is_empty() {
439                                // No data to stream or Done
440                                continue;
441                            }
442
443                            // Finish Reason == Stop That is the final Response.
444                            if resp.choices[0].finish_reason == Some("stop".to_string()) {
445                                let _ = sender.send("stop".to_string()).await;
446                                return Ok(Some(resp));
447                            }
448
449                            let msg = resp.choices[0].delta.clone().content;
450
451                            match sender.send(msg).await {
452                                Ok(_) => (),
453                                Err(e) => {
454                                    error!("generate_chat_completion_events_async - error sending on channel, {e}");
455                                }
456                            }
457                        }
458                    }
459                }
460
461                Ok(None) => continue,
462                Err(e) => match e {
463                    eventsource_client::Error::StreamClosed => break,
464                    _ => return Err(stream_error_into_api_err(e).await),
465                },
466            }
467        }
468
469        Ok(None)
470    }
471
472    /// Calls the generate chat completion endpoint for chat vision.
473    ///
474    /// ## Arguments:
475    ///
476    /// * `req` - An instance of [`chat::Request::<MessageVision>`]
477    ///
478    /// Returns an instance of [`chat::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
479    /// is considered an error.
480    pub async fn generate_chat_vision(
481        &self,
482        req: &chat::Request<chat::MessageVision>,
483    ) -> Result<chat::Response> {
484        let url = format!("{}{}", &self.inner.server, chat::PATH);
485
486        let result = self
487            .inner
488            .http_client
489            .post(url)
490            .headers(self.inner.headers.clone())
491            .json(req)
492            .send()
493            .await?;
494
495        if result.status() != StatusCode::OK {
496            return Err(retrieve_error(result).await);
497        }
498
499        let chat_response = result.json::<chat::Response>().await?;
500
501        Ok(chat_response)
502    }
503
504    /// Calls the rerank endpoint.
505    ///
506    /// ## Arguments:
507    ///
508    /// * `req` - An instance of [`rerank::Request`]
509    ///
510    /// Returns an instance of [`rerank::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
511    /// is considered an error.
512    pub async fn rerank(
513        &self,
514        req: &rerank::Request,
515    ) -> Result<rerank::Response> {
516        let url = format!("{}{}", &self.inner.server, rerank::PATH);
517
518        let result = self
519            .inner
520            .http_client
521            .post(url)
522            .headers(self.inner.headers.clone())
523            .json(req)
524            .send()
525            .await?;
526
527        if result.status() != StatusCode::OK {
528            return Err(retrieve_error(result).await);
529        }
530
531        let rerank_response = result.json::<rerank::Response>().await?;
532
533        Ok(rerank_response)
534    }
535
536    /// Calls the factuality check endpoint.
537    ///
538    /// ## Arguments:
539    ///
540    /// * `req` - An instance of [`factuality::Request`]
541    ///
542    /// Returns am instance of [`factuality::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
543    /// is considered an error.
544    pub async fn check_factuality(
545        &self,
546        req: &factuality::Request,
547    ) -> Result<factuality::Response> {
548        let url = format!("{}{}", &self.inner.server, factuality::PATH);
549
550        let result = self
551            .inner
552            .http_client
553            .post(url)
554            .headers(self.inner.headers.clone())
555            .json(req)
556            .send()
557            .await?;
558
559        if result.status() != StatusCode::OK {
560            return Err(retrieve_error(result).await);
561        }
562
563        let fact_response = result.json::<factuality::Response>().await?;
564
565        Ok(fact_response)
566    }
567
568    /// Calls the translate endpoint.
569    ///
570    /// ## Arguments:
571    ///
572    /// `req` - Instance of [`translate::Request`]
573    ///
574    /// Returns a [`translate::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
575    /// is considered an error.
576    pub async fn translate(&self, req: &translate::Request) -> Result<translate::Response> {
577        let url = format!("{}{}", &self.inner.server, translate::PATH);
578
579        let result = self
580            .inner
581            .http_client
582            .post(url)
583            .headers(self.inner.headers.clone())
584            .json(req)
585            .send()
586            .await?;
587
588        if result.status() != StatusCode::OK {
589            return Err(retrieve_error(result).await);
590        }
591
592        let translate_response = result.json::<translate::Response>().await?;
593
594        Ok(translate_response)
595    }
596
597    /// Calls the PII endpoint that is used to remove/detect PII information in the request.
598    ///
599    /// ## Arguments:
600    ///
601    /// `req` - An instance of [`pii::Request`]
602    ///
603    /// Returns an instance of [`pii::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api.
604    /// Any other status code is considered an error.
605    pub async fn pii(&self, req: &pii::Request) -> Result<pii::Response> {
606        let url = format!("{}{}", &self.inner.server, pii::PATH);
607
608        let result = self
609            .inner
610            .http_client
611            .post(url)
612            .headers(self.inner.headers.clone())
613            .json(req)
614            .send()
615            .await?;
616
617        if result.status() != StatusCode::OK {
618            return Err(retrieve_error(result).await);
619        }
620
621        let pii_response = result.json::<pii::Response>().await?;
622
623        Ok(pii_response)
624    }
625
626    /// Calls the injection check endpoint.
627    ///
628    /// ## Arguments:
629    ///
630    /// `req` - Instance of [`injection::Request`]
631    ///
632    /// Returns an instance of [`injection::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
633    /// is considered an error.
634    pub async fn injection(&self, req: &injection::Request) -> Result<injection::Response> {
635        let url = format!("{}{}", &self.inner.server, injection::PATH);
636
637        let result = self
638            .inner
639            .http_client
640            .post(url)
641            .headers(self.inner.headers.clone())
642            .json(req)
643            .send()
644            .await?;
645
646        if result.status() != StatusCode::OK {
647            return Err(retrieve_error(result).await);
648        }
649
650        let injection_response = result.json::<injection::Response>().await?;
651
652        Ok(injection_response)
653    }
654
655    /// Calls the injection check endpoint.
656    ///
657    /// ## Arguments:
658    ///
659    /// `req` - An instance of [`toxicity::Request`]
660    ///
661    /// Returns an instance of [`toxicity::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
662    /// is considered an error.
663    pub async fn toxicity(&self, req: &toxicity::Request) -> Result<toxicity::Response> {
664        let url = format!("{}{}", &self.inner.server, toxicity::PATH);
665
666        let result = self
667            .inner
668            .http_client
669            .post(url)
670            .headers(self.inner.headers.clone())
671            .json(req)
672            .send()
673            .await?;
674
675        if result.status() != StatusCode::OK {
676            return Err(retrieve_error(result).await);
677        }
678
679        let toxicity_response = result.json::<toxicity::Response>().await?;
680
681        Ok(toxicity_response)
682    }
683
684    /// Calls the tokenize endpoint.
685    ///
686    /// ## Arguments:
687    ///
688    /// * `req` - An instance of [`tokenize::Request`]
689    ///
690    /// Returns an instance of [`tokenize::Response`]. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
691    /// is considered an error.
692    pub async fn tokenize(
693        &self,
694        req: &tokenize::Request,
695    ) -> Result<tokenize::Response> {
696        let url = format!("{}{}", &self.inner.server, tokenize::PATH);
697
698        let result = self
699            .inner
700            .http_client
701            .post(url)
702            .headers(self.inner.headers.clone())
703            .json(req)
704            .send()
705            .await?;
706
707        if result.status() != StatusCode::OK {
708            return Err(retrieve_error(result).await);
709        }
710
711        let token_response = result.json::<tokenize::Response>().await?;
712
713        Ok(token_response)
714    }
715
716
717    /// Retrieves the list of models available.
718    ///
719    /// Returns a vector with the model metadata. A 200 (Ok) status code is expected from the Prediction Guard api. Any other status code
720    /// is considered an error.
721    pub async fn models(
722        &self,
723        req: Option<&models::Request>
724    ) -> Result<models::Response> {
725        let mut url = format!("{}{}", &self.inner.server, models::PATH);
726
727        // If `req` is Some, append it to the URL
728        if let Some(request) = req {
729            if let Some(capability) = &request.capability {
730                url.push('/');
731                url.push_str(capability);
732            }
733        }
734
735        let result = self
736            .inner
737            .http_client
738            .get(url)
739            .headers(self.inner.headers.clone())
740            .send()
741            .await?;
742
743        if result.status() != StatusCode::OK {
744            return Err(retrieve_error(result).await);
745        }
746
747        let model_response = result.json::<models::Response>().await?;
748
749        Ok(model_response)
750    }
751}
752
753async fn retrieve_error(resp: Response) -> Box<dyn std::error::Error> {
754    let err = match resp.json::<ApiError>().await {
755        Ok(x) => x,
756        Err(e) => return Box::from(format!("error parsing error response, {}", e)),
757    };
758
759    err.into()
760}
761
762async fn stream_error_into_api_err(err: eventsource_client::Error) -> Box<dyn std::error::Error> {
763    let msg = format!("{}", err);
764    Box::from(ApiError {
765        error: msg.to_string(),
766    })
767}