1use async_trait::async_trait;
2use crate::{Input, JsonRequest, Usage};
3use std::collections::HashMap;
4use serde::{Serialize,Deserialize};
5
/// Request body sent to the `/completions` endpoint.
///
/// Only `model` and `prompt` are always serialized. Every other field is an
/// `Option` marked `skip_serializing_if = "Option::is_none"`, so it is left out
/// of the JSON entirely until the matching `set_*` method gives it a value.
/// Field names are serialized unchanged; their meaning follows the upstream
/// completions API (presumably OpenAI's — confirm against its API reference).
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct CompletionRequest {
    /// Model identifier; always serialized.
    model: String,
    /// Prompt input; always serialized (shape defined by the crate's `Input` type).
    prompt: Input,
    #[serde(skip_serializing_if = "Option::is_none")]
    suffix: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u16>,
    // NOTE(review): if the API answers `stream: true` with a non-JSON (event
    // stream) body, the JSON response handling may not cope — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    echo: Option<bool>,
    /// Stop sequences, serialized as a JSON array of strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    best_of: Option<u16>,
    /// Bias map serialized as a JSON object keyed by string.
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}
51
52
/// Binds [`CompletionRequest`] to the `/completions` endpoint, with
/// [`CompletionSuccess`] as the trait's response type parameter.
#[async_trait(?Send)]
impl JsonRequest<CompletionSuccess> for CompletionRequest {
    // Only the endpoint path is supplied here; the actual send/decode logic is
    // presumably provided by default methods on `JsonRequest` (defined elsewhere).
    const ENDPOINT: &'static str = "/completions";
}
57
58impl CompletionRequest {
59 pub fn new(prompt: Input) -> CompletionRequest {
60 CompletionRequest {
61 model: "text-davinci-003".to_string(),
62 prompt,
63 suffix: None,
64 max_tokens: None,
65 temperature: None,
66 top_p: None,
67 n: None,
68 stream: None,
69 logprobs: None,
70 echo: None,
71 stop: None,
72 presence_penalty: None,
73 frequency_penalty: None,
74 best_of: None,
75 logit_bias: None,
76 user: None,
77 }
78 }
79 pub fn with_model(model: &str, prompt: Input) -> CompletionRequest {
80 CompletionRequest {
81 model: model.to_string(),
82 prompt,
83 suffix: None,
84 max_tokens: None,
85 temperature: None,
86 top_p: None,
87 n: None,
88 stream: None,
89 logprobs: None,
90 echo: None,
91 stop: None,
92 presence_penalty: None,
93 frequency_penalty: None,
94 best_of: None,
95 logit_bias: None,
96 user: None,
97 }
98 }
99
100 pub fn set_suffix(&mut self, suffix: &str) -> &mut Self {
101 self.suffix = Some(suffix.to_string());
102 self
103 }
104
105 pub fn set_max_tokens(&mut self, max_tokens: u32) -> &mut Self {
106 self.max_tokens = Some(max_tokens);
107 self
108 }
109
110 pub fn set_temperature(&mut self, temperature: f32) -> &mut Self {
111 self.temperature = Some(temperature);
112 self
113 }
114
115 pub fn set_top_p(&mut self, top_p: f32) -> &mut Self {
116 self.top_p = Some(top_p);
117 self
118 }
119
120 pub fn set_n(&mut self, n: u16) -> &mut Self {
121 self.n = Some(n);
122 self
123 }
124
125 pub fn set_stream(&mut self, stream: bool) -> &mut Self {
126 self.stream = Some(stream);
127 self
128 }
129
130 pub fn set_logprobs(&mut self, logprobs: u32) -> &mut Self {
131 self.logprobs = Some(logprobs);
132 self
133 }
134
135 pub fn set_echo(&mut self, echo: bool) -> &mut Self {
136 self.echo = Some(echo);
137 self
138 }
139
140 pub fn set_stop(mut self, stop: impl Into<Vec<String>>) -> Self {
141 self.stop = Some(stop.into());
142 self
143 }
144
145 pub fn set_presence_penalty(mut self, presence_penalty: f32) -> Self {
146 self.presence_penalty = Some(presence_penalty);
147 self
148 }
149
150 pub fn set_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
151 self.frequency_penalty = Some(frequency_penalty);
152 self
153 }
154
155 pub fn set_best_of(mut self, best_of: u16) -> Self {
156 self.best_of = Some(best_of);
157 self
158 }
159
160 pub fn set_logit_bias(mut self, logit_bias: HashMap<String, f32>) -> Self {
161 self.logit_bias = Some(logit_bias);
162 self
163 }
164
165 pub fn set_user(mut self, user: String) -> Self {
166 self.user = Some(user);
167 self
168 }
169}
170
/// One entry of [`CompletionSuccess::choices`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    /// Text returned for this choice.
    pub text: String,
    /// Index reported by the API for this choice — presumably its position in
    /// `choices`; confirm.
    pub index: i64,
    // NOTE(review): the OpenAI completions API reference documents a choice's
    // `logprobs` as an object (tokens / token_logprobs / top_logprobs /
    // text_offset) or null, not an integer. If so, any response where it is
    // non-null (e.g. after `CompletionRequest::set_logprobs`) would fail to
    // deserialize into `Option<u32>`. Fixing the type is a breaking change to
    // this public field, so it is only flagged here — confirm and plan a fix.
    pub logprobs: Option<u32>,
    /// Reason string reported by the API for ending this choice, kept raw.
    pub finish_reason: String,
}
178
/// Response type paired with `CompletionRequest` via `JsonRequest` —
/// presumably the decoded body of a successful `/completions` call.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct CompletionSuccess {
    /// Identifier reported by the API for this completion.
    pub id: String,
    /// Object-type tag reported by the API.
    pub object: String,
    /// Creation time — presumably a Unix timestamp in seconds; confirm.
    pub created: i64,
    /// Model name reported by the API.
    pub model: String,
    /// Generated alternatives, one [`CompletionChoice`] each.
    pub choices: Vec<CompletionChoice>,
    /// Token accounting, using the crate's shared `Usage` type.
    pub usage: Usage,
}