1use std::{env, fmt, sync::Arc, time::Duration};
3
4use crate::built_info;
5use crate::{
6 chat, completion, embedding, factuality,
7 injection, pii, rerank, toxicity, translate,
8 tokenize, models, Result
9};
10use dotenvy;
11use eventsource_client::Client as EventClient;
12use eventsource_client::SSE;
13use futures::TryStreamExt;
14use log::error;
15use reqwest::{
16 header::{HeaderMap, HeaderValue},
17 ClientBuilder, Response, StatusCode,
18};
19use serde::{Deserialize, Serialize};
20use tokio::sync::mpsc::Sender;
21
/// Product name used as the first part of the `User-Agent` header; the crate
/// version is appended wherever the header value is built.
const USER_AGENT: &str = "Prediction Guard Rust Client";
23
24#[derive(Debug, Serialize, Deserialize, Clone, Default)]
26pub struct ApiError {
27 error: String,
28}
29
30impl fmt::Display for ApiError {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 f.write_fmt(format_args!("error {}", self.error))
33 }
34}
35
/// Lets `ApiError` be boxed as a `dyn std::error::Error`; all trait methods keep
/// their default implementations.
impl std::error::Error for ApiError {}
37
/// Connection settings used to build a [`Client`].
pub struct PgEnvironment {
    /// API key; sent as the `x-api-key` header and as a bearer token on streaming calls.
    pub key: String,
    /// Base URL of the API server; endpoint paths are appended to it verbatim.
    pub host: String,
}
43
44impl PgEnvironment {
45 pub fn new(key: String, host: String) -> Self {
52 Self { key, host }
53 }
54
55 pub fn from_env() -> Result<Self> {
61 let _ = dotenvy::dotenv(); Ok(Self {
63 key: env::var("PREDICTIONGUARD_API_KEY")?,
64 host: env::var("PREDICTIONGUARD_URL")?,
65 })
66 }
67}
68
69#[derive(Debug, Clone)]
72pub struct Client {
73 inner: Arc<ClientInner>,
74}
75
76#[derive(Debug)]
77struct ClientInner {
78 server: String,
79 http_client: reqwest::Client,
80 headers: HeaderMap,
81 api_key: String,
82}
83
84impl Client {
85 pub fn new() -> Result<Self> {
88 let pg_env = PgEnvironment::from_env().expect("env keys");
89
90 Self::from_environment(pg_env)
91 }
92
93 pub fn from_environment(pg_env: PgEnvironment) -> Result<Self> {
99 let user_agent = format!("{} v{}", USER_AGENT, built_info::PKG_VERSION);
100
101 let http = ClientBuilder::new()
102 .connect_timeout(Duration::new(30, 0))
103 .read_timeout(Duration::new(30, 0))
104 .timeout(Duration::new(45, 0))
105 .user_agent(user_agent)
106 .build()?;
107
108 let header_key = match HeaderValue::from_str(&pg_env.key) {
109 Ok(x) => x,
110 Err(e) => {
111 return Err(Box::new(e));
112 }
113 };
114
115 let mut header_map = HeaderMap::new();
116 let _ = header_map
117 .insert("x-api-key", header_key)
118 .ok_or("invalid api key");
119
120 let inner = Arc::new(ClientInner {
121 server: pg_env.host.to_string(),
122 http_client: http,
123 headers: header_map,
124 api_key: pg_env.key,
125 });
126
127 Ok(Self { inner })
128 }
129
130 pub async fn check_health(&self) -> Result<String> {
135 let result = self
136 .inner
137 .http_client
138 .get(&self.inner.server)
139 .headers(self.inner.headers.clone())
140 .send()
141 .await?;
142
143 if result.status() != StatusCode::OK {
144 return Err(retrieve_error(result).await);
145 }
146
147 let txt = result.text().await?;
148
149 Ok(txt)
150 }
151
152 pub async fn retrieve_model_list(&self, capability: String) -> Result<Vec<String>> {
161 let url = format!(
162 "{}{}/{}",
163 &self.inner.server,
164 models::PATH,
165 capability
166 );
167
168 let result = self
169 .inner
170 .http_client
171 .get(url)
172 .headers(self.inner.headers.clone())
173 .send()
174 .await?;
175
176 if result.status() != StatusCode::OK {
177 return Err(retrieve_error(result).await);
178 }
179
180 let response_body: models::Response = result.json().await?;
181
182 let retrieve_models_response: Vec<String> = response_body
183 .data
184 .into_iter()
185 .map(|model| model.id)
186 .collect();
187
188 Ok(retrieve_models_response)
189 }
190
191 pub async fn embedding(&self, req: &embedding::Request) -> Result<embedding::Response> {
200 let url = format!("{}{}", &self.inner.server, embedding::PATH);
201
202 let result = self
203 .inner
204 .http_client
205 .post(url)
206 .headers(self.inner.headers.clone())
207 .json(req)
208 .send()
209 .await?;
210
211 if result.status() != StatusCode::OK {
212 return Err(retrieve_error(result).await);
213 }
214
215 let embed_response = result.json::<embedding::Response>().await?;
216
217 Ok(embed_response)
218 }
219
220 pub async fn generate_completion(
229 &self,
230 req: &completion::Request,
231 ) -> Result<completion::Response> {
232 let url = format!("{}{}", &self.inner.server, completion::PATH);
233
234 let result = self
235 .inner
236 .http_client
237 .post(url)
238 .headers(self.inner.headers.clone())
239 .json(req)
240 .send()
241 .await?;
242
243 if result.status() != StatusCode::OK {
244 return Err(retrieve_error(result).await);
245 }
246
247 let comp_response = result.json::<completion::Response>().await?;
248
249 Ok(comp_response)
250 }
251
252 pub async fn generate_chat_completion(
261 &self,
262 req: &chat::Request<chat::Message>,
263 ) -> Result<chat::Response> {
264 let url = format!("{}{}", &self.inner.server, chat::PATH);
265
266 let result = self
267 .inner
268 .http_client
269 .post(url)
270 .headers(self.inner.headers.clone())
271 .json(req)
272 .send()
273 .await?;
274
275 if result.status() != StatusCode::OK {
276 return Err(retrieve_error(result).await);
277 }
278
279 let chat_response = result.json::<chat::Response>().await?;
280
281 Ok(chat_response)
282 }
283
284 pub async fn generate_chat_completion_events<F>(
300 &self,
301 req: &mut chat::Request<chat::Message>,
302 event_handler: &mut F,
303 ) -> Result<Option<chat::ResponseEvents>>
304 where
305 F: FnMut(&String),
306 {
307 let url = format!("{}{}", &self.inner.server, chat::PATH);
308
309 req.stream = true;
310 req.output = None;
311
312 let body = serde_json::to_string(&req)?;
313
314 let user_agent = format!("{} v{}", USER_AGENT, built_info::PKG_VERSION);
315
316 let key = format!("Bearer {}", &self.inner.api_key);
317
318 let client = eventsource_client::ClientBuilder::for_url(&url)?
319 .header("User-Agent", user_agent.as_str())?
320 .header("Authorization", &key)?
321 .method("POST".to_string())
322 .body(body)
323 .build();
324
325 let mut stream = Box::pin(client.stream());
326
327 loop {
328 match stream.try_next().await {
329 Ok(Some(event)) => {
330 match event {
331 SSE::Comment(_) => continue,
332 SSE::Event(evt) => {
333 if evt.data == "[DONE]" {
335 return Ok(None);
336 }
337
338 let resp: chat::ResponseEvents = match serde_json::from_str(&evt.data) {
340 Ok(v) => v,
341 Err(e) => {
342 return Err(Box::from(ApiError {
343 error: format!("error parsing stream response: {}", e),
344 }));
345 }
346 };
347
348 if resp.choices.is_empty() {
349 continue;
351 }
352
353 if resp.choices[0].finish_reason == Some("stop".to_string()) {
355 return Ok(Some(resp));
356 }
357
358 let msg = resp.choices[0].delta.clone().content;
359 event_handler(&msg);
360 }
361 }
362 }
363
364 Ok(None) => continue,
365 Err(e) => match e {
366 eventsource_client::Error::StreamClosed => break,
367 _ => return Err(stream_error_into_api_err(e).await),
368 },
369 }
370 }
371
372 Ok(None)
373 }
374
375 pub async fn generate_chat_completion_events_async(
392 &self,
393 req: &mut chat::Request<chat::Message>,
394 sender: &Sender<String>,
395 ) -> Result<Option<chat::ResponseEvents>> {
396 let url = format!("{}{}", &self.inner.server, chat::PATH);
397
398 req.stream = true;
399 req.output = None;
400
401 let body = serde_json::to_string(&req)?;
402
403 let user_agent = format!("{} v{}", USER_AGENT, built_info::PKG_VERSION);
404
405 let key = format!("Bearer {}", &self.inner.api_key);
406
407 let client = eventsource_client::ClientBuilder::for_url(&url)?
408 .header("User-Agent", user_agent.as_str())?
409 .header("Authorization", &key)?
410 .method("POST".to_string())
411 .body(body)
412 .build();
413
414 let mut stream = Box::pin(client.stream());
415
416 loop {
417 match stream.try_next().await {
418 Ok(Some(event)) => {
419 match event {
420 SSE::Comment(_) => continue,
421 SSE::Event(evt) => {
422 if evt.data.to_lowercase() == "[done]" {
424 let _ = sender.send("stop".to_string()).await;
425 return Ok(None);
426 }
427
428 let resp: chat::ResponseEvents = match serde_json::from_str(&evt.data) {
430 Ok(v) => v,
431 Err(e) => {
432 return Err(Box::from(ApiError {
433 error: format!("error parsing stream response: {}", e),
434 }));
435 }
436 };
437
438 if resp.choices.is_empty() {
439 continue;
441 }
442
443 if resp.choices[0].finish_reason == Some("stop".to_string()) {
445 let _ = sender.send("stop".to_string()).await;
446 return Ok(Some(resp));
447 }
448
449 let msg = resp.choices[0].delta.clone().content;
450
451 match sender.send(msg).await {
452 Ok(_) => (),
453 Err(e) => {
454 error!("generate_chat_completion_events_async - error sending on channel, {e}");
455 }
456 }
457 }
458 }
459 }
460
461 Ok(None) => continue,
462 Err(e) => match e {
463 eventsource_client::Error::StreamClosed => break,
464 _ => return Err(stream_error_into_api_err(e).await),
465 },
466 }
467 }
468
469 Ok(None)
470 }
471
472 pub async fn generate_chat_vision(
481 &self,
482 req: &chat::Request<chat::MessageVision>,
483 ) -> Result<chat::Response> {
484 let url = format!("{}{}", &self.inner.server, chat::PATH);
485
486 let result = self
487 .inner
488 .http_client
489 .post(url)
490 .headers(self.inner.headers.clone())
491 .json(req)
492 .send()
493 .await?;
494
495 if result.status() != StatusCode::OK {
496 return Err(retrieve_error(result).await);
497 }
498
499 let chat_response = result.json::<chat::Response>().await?;
500
501 Ok(chat_response)
502 }
503
504 pub async fn rerank(
513 &self,
514 req: &rerank::Request,
515 ) -> Result<rerank::Response> {
516 let url = format!("{}{}", &self.inner.server, rerank::PATH);
517
518 let result = self
519 .inner
520 .http_client
521 .post(url)
522 .headers(self.inner.headers.clone())
523 .json(req)
524 .send()
525 .await?;
526
527 if result.status() != StatusCode::OK {
528 return Err(retrieve_error(result).await);
529 }
530
531 let rerank_response = result.json::<rerank::Response>().await?;
532
533 Ok(rerank_response)
534 }
535
536 pub async fn check_factuality(
545 &self,
546 req: &factuality::Request,
547 ) -> Result<factuality::Response> {
548 let url = format!("{}{}", &self.inner.server, factuality::PATH);
549
550 let result = self
551 .inner
552 .http_client
553 .post(url)
554 .headers(self.inner.headers.clone())
555 .json(req)
556 .send()
557 .await?;
558
559 if result.status() != StatusCode::OK {
560 return Err(retrieve_error(result).await);
561 }
562
563 let fact_response = result.json::<factuality::Response>().await?;
564
565 Ok(fact_response)
566 }
567
568 pub async fn translate(&self, req: &translate::Request) -> Result<translate::Response> {
577 let url = format!("{}{}", &self.inner.server, translate::PATH);
578
579 let result = self
580 .inner
581 .http_client
582 .post(url)
583 .headers(self.inner.headers.clone())
584 .json(req)
585 .send()
586 .await?;
587
588 if result.status() != StatusCode::OK {
589 return Err(retrieve_error(result).await);
590 }
591
592 let translate_response = result.json::<translate::Response>().await?;
593
594 Ok(translate_response)
595 }
596
597 pub async fn pii(&self, req: &pii::Request) -> Result<pii::Response> {
606 let url = format!("{}{}", &self.inner.server, pii::PATH);
607
608 let result = self
609 .inner
610 .http_client
611 .post(url)
612 .headers(self.inner.headers.clone())
613 .json(req)
614 .send()
615 .await?;
616
617 if result.status() != StatusCode::OK {
618 return Err(retrieve_error(result).await);
619 }
620
621 let pii_response = result.json::<pii::Response>().await?;
622
623 Ok(pii_response)
624 }
625
626 pub async fn injection(&self, req: &injection::Request) -> Result<injection::Response> {
635 let url = format!("{}{}", &self.inner.server, injection::PATH);
636
637 let result = self
638 .inner
639 .http_client
640 .post(url)
641 .headers(self.inner.headers.clone())
642 .json(req)
643 .send()
644 .await?;
645
646 if result.status() != StatusCode::OK {
647 return Err(retrieve_error(result).await);
648 }
649
650 let injection_response = result.json::<injection::Response>().await?;
651
652 Ok(injection_response)
653 }
654
655 pub async fn toxicity(&self, req: &toxicity::Request) -> Result<toxicity::Response> {
664 let url = format!("{}{}", &self.inner.server, toxicity::PATH);
665
666 let result = self
667 .inner
668 .http_client
669 .post(url)
670 .headers(self.inner.headers.clone())
671 .json(req)
672 .send()
673 .await?;
674
675 if result.status() != StatusCode::OK {
676 return Err(retrieve_error(result).await);
677 }
678
679 let toxicity_response = result.json::<toxicity::Response>().await?;
680
681 Ok(toxicity_response)
682 }
683
684 pub async fn tokenize(
693 &self,
694 req: &tokenize::Request,
695 ) -> Result<tokenize::Response> {
696 let url = format!("{}{}", &self.inner.server, tokenize::PATH);
697
698 let result = self
699 .inner
700 .http_client
701 .post(url)
702 .headers(self.inner.headers.clone())
703 .json(req)
704 .send()
705 .await?;
706
707 if result.status() != StatusCode::OK {
708 return Err(retrieve_error(result).await);
709 }
710
711 let token_response = result.json::<tokenize::Response>().await?;
712
713 Ok(token_response)
714 }
715
716
717 pub async fn models(
722 &self,
723 req: Option<&models::Request>
724 ) -> Result<models::Response> {
725 let mut url = format!("{}{}", &self.inner.server, models::PATH);
726
727 if let Some(request) = req {
729 if let Some(capability) = &request.capability {
730 url.push('/');
731 url.push_str(capability);
732 }
733 }
734
735 let result = self
736 .inner
737 .http_client
738 .get(url)
739 .headers(self.inner.headers.clone())
740 .send()
741 .await?;
742
743 if result.status() != StatusCode::OK {
744 return Err(retrieve_error(result).await);
745 }
746
747 let model_response = result.json::<models::Response>().await?;
748
749 Ok(model_response)
750 }
751}
752
753async fn retrieve_error(resp: Response) -> Box<dyn std::error::Error> {
754 let err = match resp.json::<ApiError>().await {
755 Ok(x) => x,
756 Err(e) => return Box::from(format!("error parsing error response, {}", e)),
757 };
758
759 err.into()
760}
761
762async fn stream_error_into_api_err(err: eventsource_client::Error) -> Box<dyn std::error::Error> {
763 let msg = format!("{}", err);
764 Box::from(ApiError {
765 error: msg.to_string(),
766 })
767}