prediction_guard/
completion.rs

//! Data types used by the completion endpoint.
2use serde::{self, Deserialize, Serialize};
3
4use crate::pii;
5
/// URL path of the base completions endpoint.
pub const PATH: &str = "/completions";
8
9/// Allows to request PII check and Injection check on the inputs in the chat request.
10#[derive(Debug, Default, Deserialize, Serialize)]
11pub struct RequestInput {
12    pub(crate) block_prompt_injection: bool,
13    pub(crate) pii: Option<pii::InputMethod>,
14    pub(crate) pii_replace_method: Option<pii::ReplaceMethod>,
15}
16
17/// Allows for checking the output of the request for factuality and toxicity.
18#[derive(Debug, Deserialize, Serialize)]
19pub struct RequestOutput {
20    pub factuality: bool,
21    pub toxicity: bool,
22}
23
24/// Completion request for the base completion endpoint.
25#[derive(Debug, Deserialize, Default, Serialize)]
26pub struct Request {
27    pub(crate) model: String,
28    pub(crate) prompt: String,
29    pub(crate) max_tokens: Option<i64>,
30    pub(crate) temperature: Option<f64>,
31    pub(crate) top_p: Option<f64>,
32    pub(crate) top_k: Option<i64>,
33    pub(crate) input: Option<RequestInput>,
34    pub(crate) output: Option<RequestOutput>,
35}
36
37impl Request {
38    /// Creates a new request for completion.
39    ///
40    /// ## Arguments
41    ///
42    /// * `model` - The model to be used for the request.
43    /// * `prompt` - The prompt to be used for the completion request.
44    pub fn new(model: String, prompt: String) -> Self {
45        Self {
46            model,
47            prompt,
48            ..Default::default()
49        }
50    }
51
52    /// Sets the max tokens for the request.
53    ///
54    /// ## Arguments
55    ///
56    /// * `max` - The maximum number of tokens to be returned in the response.
57    pub fn max_tokens(mut self, max: i64) -> Request {
58        self.max_tokens = Some(max);
59        self
60    }
61
62    /// Sets the temperature for the request.
63    ///
64    /// ## Arguments
65    ///
66    /// * `temp` - The temperature setting for the request. Used to control randomness.
67    pub fn temperature(mut self, temp: f64) -> Request {
68        self.temperature = Some(temp);
69        self
70    }
71
72    /// Sets the Top p for the request.
73    ///
74    /// ## Arguments
75    ///
76    /// * `top` - The Top p setting for the request. Used to control randomness.
77    pub fn top_p(mut self, top: f64) -> Request {
78        self.top_p = Some(top);
79        self
80    }
81
82    /// Sets the Top k for the request.
83    ///
84    /// ## Arguments
85    ///
86    /// * `top_k` - The Top k setting for the request. Used to control randomness.
87    pub fn top_k(mut self, top_k: i64) -> Request {
88        self.top_k = Some(top_k);
89        self
90    }
91
92    /// Sets the input parameters for the request, to check for prompt injection and PII.
93    ///
94    /// ## Arguments
95    ///
96    /// * `block_prompt_injection` - Determines whether to check for prompt injection in
97    ///   the request.
98    /// * `pii` - Sets the `pii::InputMethod` and the `pii::ReplacementMethod`.
99    pub fn input(
100        mut self,
101        block_prompt_injection: bool,
102        pii: Option<(pii::InputMethod, pii::ReplaceMethod)>,
103    ) -> Request {
104        match self.input {
105            Some(ref mut x) => {
106                // set values on request input
107                x.block_prompt_injection = block_prompt_injection;
108                if let Some(p) = pii {
109                    x.pii = Some(p.0);
110                    x.pii_replace_method = Some(p.1);
111                }
112            }
113            None => {
114                // create request input
115                let mut input = RequestInput {
116                    block_prompt_injection,
117                    ..Default::default()
118                };
119
120                if let Some(p) = pii {
121                    input.pii = Some(p.0);
122                    input.pii_replace_method = Some(p.1);
123                }
124                self.input = Some(input);
125            }
126        }
127        self
128    }
129
130    /// Sets the output parameters for the request, to check for factuality and toxicity.
131    ///
132    /// ## Arguments
133    ///
134    /// * `check_factuality` - Determines whether to check for factuality in the response.
135    /// * `check_toxicity` - Determines whether to check for toxicity in the response.
136    pub fn output(mut self, check_factuality: bool, check_toxicity: bool) -> Request {
137        match self.output {
138            Some(ref mut x) => {
139                x.factuality = check_factuality;
140                x.toxicity = check_toxicity;
141            }
142            None => {
143                self.output = Some(RequestOutput {
144                    toxicity: check_toxicity,
145                    factuality: check_factuality,
146                })
147            }
148        };
149        self
150    }
151}
152
153/// Represents a choice in the base completion response.
154#[derive(Debug, Default, Deserialize, Serialize)]
155#[serde(default)]
156pub struct Choice {
157    pub text: String,
158    pub index: i64,
159}
160
161/// Completion response for the base completetion endpoint.
162#[derive(Debug, Default, Deserialize, Serialize)]
163#[serde(default)]
164pub struct Response {
165    pub id: String,
166    pub object: String,
167    pub model: String,
168    pub created: i64,
169    pub choices: Vec<Choice>,
170}