prediction_guard/completion.rs
1//! Data types that are used for the completion endpoints.
2use serde::{self, Deserialize, Serialize};
3
4use crate::pii;
5
/// Path of the base completions endpoint.
///
/// Starts with `/`; presumably joined onto the client's base URL by the caller (confirm
/// against the client code).
pub const PATH: &str = "/completions";
8
9/// Allows to request PII check and Injection check on the inputs in the chat request.
10#[derive(Debug, Default, Deserialize, Serialize)]
11pub struct RequestInput {
12 pub(crate) block_prompt_injection: bool,
13 pub(crate) pii: Option<pii::InputMethod>,
14 pub(crate) pii_replace_method: Option<pii::ReplaceMethod>,
15}
16
17/// Allows for checking the output of the request for factuality and toxicity.
18#[derive(Debug, Deserialize, Serialize)]
19pub struct RequestOutput {
20 pub factuality: bool,
21 pub toxicity: bool,
22}
23
24/// Completion request for the base completion endpoint.
25#[derive(Debug, Deserialize, Default, Serialize)]
26pub struct Request {
27 pub(crate) model: String,
28 pub(crate) prompt: String,
29 pub(crate) max_tokens: Option<i64>,
30 pub(crate) temperature: Option<f64>,
31 pub(crate) top_p: Option<f64>,
32 pub(crate) top_k: Option<i64>,
33 pub(crate) input: Option<RequestInput>,
34 pub(crate) output: Option<RequestOutput>,
35}
36
37impl Request {
38 /// Creates a new request for completion.
39 ///
40 /// ## Arguments
41 ///
42 /// * `model` - The model to be used for the request.
43 /// * `prompt` - The prompt to be used for the completion request.
44 pub fn new(model: String, prompt: String) -> Self {
45 Self {
46 model,
47 prompt,
48 ..Default::default()
49 }
50 }
51
52 /// Sets the max tokens for the request.
53 ///
54 /// ## Arguments
55 ///
56 /// * `max` - The maximum number of tokens to be returned in the response.
57 pub fn max_tokens(mut self, max: i64) -> Request {
58 self.max_tokens = Some(max);
59 self
60 }
61
62 /// Sets the temperature for the request.
63 ///
64 /// ## Arguments
65 ///
66 /// * `temp` - The temperature setting for the request. Used to control randomness.
67 pub fn temperature(mut self, temp: f64) -> Request {
68 self.temperature = Some(temp);
69 self
70 }
71
72 /// Sets the Top p for the request.
73 ///
74 /// ## Arguments
75 ///
76 /// * `top` - The Top p setting for the request. Used to control randomness.
77 pub fn top_p(mut self, top: f64) -> Request {
78 self.top_p = Some(top);
79 self
80 }
81
82 /// Sets the Top k for the request.
83 ///
84 /// ## Arguments
85 ///
86 /// * `top_k` - The Top k setting for the request. Used to control randomness.
87 pub fn top_k(mut self, top_k: i64) -> Request {
88 self.top_k = Some(top_k);
89 self
90 }
91
92 /// Sets the input parameters for the request, to check for prompt injection and PII.
93 ///
94 /// ## Arguments
95 ///
96 /// * `block_prompt_injection` - Determines whether to check for prompt injection in
97 /// the request.
98 /// * `pii` - Sets the `pii::InputMethod` and the `pii::ReplacementMethod`.
99 pub fn input(
100 mut self,
101 block_prompt_injection: bool,
102 pii: Option<(pii::InputMethod, pii::ReplaceMethod)>,
103 ) -> Request {
104 match self.input {
105 Some(ref mut x) => {
106 // set values on request input
107 x.block_prompt_injection = block_prompt_injection;
108 if let Some(p) = pii {
109 x.pii = Some(p.0);
110 x.pii_replace_method = Some(p.1);
111 }
112 }
113 None => {
114 // create request input
115 let mut input = RequestInput {
116 block_prompt_injection,
117 ..Default::default()
118 };
119
120 if let Some(p) = pii {
121 input.pii = Some(p.0);
122 input.pii_replace_method = Some(p.1);
123 }
124 self.input = Some(input);
125 }
126 }
127 self
128 }
129
130 /// Sets the output parameters for the request, to check for factuality and toxicity.
131 ///
132 /// ## Arguments
133 ///
134 /// * `check_factuality` - Determines whether to check for factuality in the response.
135 /// * `check_toxicity` - Determines whether to check for toxicity in the response.
136 pub fn output(mut self, check_factuality: bool, check_toxicity: bool) -> Request {
137 match self.output {
138 Some(ref mut x) => {
139 x.factuality = check_factuality;
140 x.toxicity = check_toxicity;
141 }
142 None => {
143 self.output = Some(RequestOutput {
144 toxicity: check_toxicity,
145 factuality: check_factuality,
146 })
147 }
148 };
149 self
150 }
151}
152
153/// Represents a choice in the base completion response.
154#[derive(Debug, Default, Deserialize, Serialize)]
155#[serde(default)]
156pub struct Choice {
157 pub text: String,
158 pub index: i64,
159}
160
161/// Completion response for the base completetion endpoint.
162#[derive(Debug, Default, Deserialize, Serialize)]
163#[serde(default)]
164pub struct Response {
165 pub id: String,
166 pub object: String,
167 pub model: String,
168 pub created: i64,
169 pub choices: Vec<Choice>,
170}