1use thiserror::Error;
2use awc::Client;
3use serde::Deserialize;
4use serde::Serialize;
5use std::collections::HashMap;
6use std::time::Duration;
7
8
9pub async fn completions(
10 prompt: &str,
11 params: &CompletionsParams,
12 api_key: &str,
13) -> Result<CompletionsResponse, ClientError> {
14 let client = Client::default();
15
16 let request = Request {
17 model: params.model.clone(),
18 prompt: prompt.to_string(),
19 temperature: params.temperature,
20 max_tokens: params.max_tokens,
21 top_p: params.top_p,
22 frequency_penalty: params.frequency_penalty,
23 presence_penalty: params.presence_penalty,
24 stop: params.stop.clone(),
25 suffix: params.suffix.clone(),
26 logprobs: params.logprobs,
27 echo: params.echo,
28 best_of: params.best_of,
29 n: params.n,
30 stream: params.stream,
31 logit_bias: params.logit_bias.clone(),
32 user: params.user.clone(),
33 };
34
35 let request = serde_json::to_string(&request)
36 .map_err(|e| ClientError::OtherError(format!("{:?}",e)))?;
37 let response = client
38 .post("https://api.openai.com/v1/completions")
39 .timeout(Duration::from_secs(30))
40 .insert_header(("Content-Type", "application/json"))
41 .insert_header(("Authorization", format!("Bearer {}", api_key)))
42 .send_body(request)
43 .await
44 .map_err(|e| ClientError::NetworkError(format!("{:?}",e)))?
45 .body()
46 .await
47 .map_err(|e| ClientError::NetworkError(format!("{:?}",e)))?;
48 let response_str = std::str::from_utf8(response.as_ref())
49 .map_err(|e| ClientError::OtherError(format!("{:?}",e)))?;
50
51 let completions_response: CompletionsResponse = match serde_json::from_str(response_str) {
52 Ok(response) => response,
53 Err(e1) => {
54 let error_response: ErrorResponse = match serde_json::from_str(response_str) {
55 Ok(response) => response,
56 Err(e2) => {
57 return Err(ClientError::OtherError(format!("{:?} {:?}",e2, e1)));
58 }
59 };
60 return Err(ClientError::APIError(error_response.error.message));
61 }
62 };
63 Ok(completions_response)
64
65}
66
67#[derive(Debug, Error)]
68pub enum ClientError {
69 #[error("OpenAI API error: `{0}`")]
70 APIError(String),
71 #[error("Network error: `{0}`")]
72 NetworkError(String),
73 #[error("Other error: `{0}`")]
74 OtherError(String),
75}
76
77pub async fn completions_pretty (
78 prompt: &str,
79 model: &str,
80 max_tokens: u32,
81 api_key: &str,
82) -> Result<String, ClientError> {
83 let params = CompletionsParams {
84 model: model.to_string(),
85 temperature: 0,
86 max_tokens: max_tokens,
87 top_p: 1.0,
88 frequency_penalty: 0.0,
89 presence_penalty: 0.0,
90 stop: None,
91 suffix: None,
92 n: 1,
93 stream: false,
94 logprobs: None,
95 echo: false,
96 best_of: 1,
97 logit_bias: None,
98 user: None,
99 };
100
101 let res = completions(prompt, ¶ms, api_key).await?;
102 Ok(res.choices[0].text.clone())
103}
104
105
106
107#[derive(Serialize, Deserialize, Debug)]
108pub struct ErrorResponse {
109 pub error: ErrorResponseObject,
110}
111
112#[derive(Serialize, Deserialize, Debug)]
113pub struct ErrorResponseObject {
114 pub message: String,
115 pub r#type: String,
116 pub param: Option<String>,
117 pub code: Option<String>,
118}
119
120#[derive(Deserialize, Serialize)]
121pub struct CompletionsParams {
122 pub model: String,
123 pub temperature: u32,
124 pub max_tokens: u32,
125 pub top_p: f32,
126 pub frequency_penalty: f32,
127 pub presence_penalty: f32,
128 pub stop: Option<Vec<String>>,
129 pub suffix: Option<String>,
130 pub n: u32,
131 pub stream: bool,
132 pub logprobs: Option<u32>,
133 pub echo: bool,
134 pub best_of: u32,
135 pub logit_bias: Option<HashMap<String, i32>>,
136 pub user: Option<String>,
137}
138
139#[derive(Deserialize, Serialize, Debug)]
140pub struct CompletionsResponse {
141 pub id: String,
142 pub object: String,
143 pub created: u32,
144 pub model: String,
145 pub choices: Vec<CompletionsChoice>,
146 pub usage: Usage,
147}
148
149#[derive(Deserialize, Serialize, Debug)]
150pub struct CompletionsChoice {
151 pub text: String,
152 pub index: u32,
153 pub logprobs: Option<String>,
154 pub finish_reason: String,
155}
156
157#[derive(Deserialize, Serialize, Debug)]
158pub struct Usage {
159 pub prompt_tokens: u32,
160 pub completion_tokens: u32,
161 pub total_tokens: u32,
162}
163
164#[derive(Deserialize, Serialize)]
165pub struct Request {
166 pub model: String,
167 pub prompt: String,
168 pub temperature: u32,
169 pub max_tokens: u32,
170 pub top_p: f32,
171 pub frequency_penalty: f32,
172 pub presence_penalty: f32,
173 #[serde(skip_serializing_if = "Option::is_none")]
174 pub stop: Option<Vec<String>>,
175 #[serde(skip_serializing_if = "Option::is_none")]
176 pub suffix: Option<String>,
177 pub n: u32,
178 pub stream: bool,
179 #[serde(skip_serializing_if = "Option::is_none")]
180 pub logprobs: Option<u32>,
181 pub echo: bool,
182 pub best_of: u32,
183 #[serde(skip_serializing_if = "Option::is_none")]
184 pub logit_bias: Option<HashMap<String, i32>>,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 pub user: Option<String>,
187}
188
189#[derive(Deserialize, Serialize)]
190pub struct EditsParams {
191 pub model: String,
192 pub temperature: u32,
193 pub top_p: f32,
194 pub n: u32,
195}
196
197#[derive(Deserialize, Serialize)]
198struct RequestEdit {
199 model: String,
200 input: String,
201 instruction: String,
202 n: u32,
203 temperature: u32,
204 top_p: f32,
205}
206
207#[derive(Deserialize, Serialize, Debug)]
208pub struct EditsResponse {
209 pub object: String,
210 pub created: u32,
211 pub choices: Vec<EditsChoice>,
212 pub usage: Usage,
213}
214
215#[derive(Deserialize, Serialize, Debug)]
216pub struct EditsChoice {
217 pub text: String,
218 pub index: u32,
219}
220
221pub async fn edits(
222 input: &str,
223 instruction: &str,
224 params: &EditsParams,
225 api_key: &str,
226) -> Result<EditsResponse, ClientError> {
227 let client = Client::default();
228
229 let request: RequestEdit = RequestEdit {
230 model: params.model.clone(),
231 input: input.to_string(),
232 instruction: instruction.to_string(),
233 n: params.n,
234 temperature: params.temperature,
235 top_p: params.top_p,
236 };
237
238 let request_string_result = serde_json::to_string(&request);
239 match request_string_result {
240 Ok(request_string) => {
241 let resp_result = client
242 .post("https://api.openai.com/v1/edits")
243 .timeout(Duration::from_secs(30))
244 .insert_header(("Content-Type", "application/json"))
245 .insert_header(("Authorization", format!("Bearer {}", api_key)))
246 .send_body(request_string)
247 .await;
248 match resp_result {
249 Ok(mut resp) => {
250 let bytes_result = resp.body().await;
251 match bytes_result {
252 Ok(bytes) => {
253 let string_result = String::from_utf8(bytes.to_vec());
254 match string_result {
255 Ok(string) => {
256 let parse_result: Result<EditsResponse, serde_json::Error> =
257 serde_json::from_str(string.as_str());
258 match parse_result {
259 Ok(response) => Ok(response),
260 Err(e1) => {
261 let error_result: Result<ErrorResponse, serde_json::Error> =
262 serde_json::from_str(string.as_str());
263 match error_result {
264 Ok(error) => Err(ClientError::APIError(error.error.message)),
265 Err(e2) => Err(ClientError::OtherError(format!("{:?} {:?}",e2, e1))),
266 }
267 },
268 }
269 }
270 Err(e) => Err(ClientError::OtherError(format!("{:?}",e))),
271 }
272 }
273 Err(e) => Err(ClientError::NetworkError(format!("{:?}",e))),
274 }
275 }
276 Err(e) => Err(ClientError::OtherError(format!("{:?}",e))),
277 }
278 }
279 Err(e) => Err(ClientError::OtherError(format!("{:?}",e))),
280 }
281}
282
283pub async fn edits_pretty(input: &str, instruction: &str, model: &str, api_key: &str) -> Result<String, ClientError> {
284 let params = EditsParams {
285 model: model.to_string(),
286 temperature: 0,
287 top_p: 1.0,
288 n: 1,
289 };
290
291 let res = edits(input, instruction, ¶ms, api_key).await?;
292 Ok(res.choices[0].text.clone())
293}
294
295
#[cfg(test)]
mod tests {
    use super::*;
    use dotenv::dotenv;
    use std::env;

    /// Live smoke test: calls `completions_pretty` and then `edits_pretty`
    /// against the real endpoints and prints both results.
    ///
    /// Reads the key from the `OPEN_AI_API_KEY` environment variable (a
    /// `.env` file is loaded first if present) and panics when it is unset or
    /// when either call fails.
    #[actix_rt::test]
    async fn it_works() {
        dotenv().ok();
        let api_key = env::var("OPEN_AI_API_KEY").expect("OPEN_AI_API_KEY must be set");

        let completion_model = "text-davinci-003";
        let token_budget: u32 = 3;
        let completion: String = completions_pretty(
            "Is Madonna president of USA? If you ask yes or not. I say:",
            completion_model,
            token_budget,
            &api_key,
        )
        .await
        .unwrap();
        println!("result: {}", completion);

        let edited: String = edits_pretty(
            "Helsllo, Mick!",
            "Fix grammar",
            "text-davinci-edit-001",
            &api_key,
        )
        .await
        .unwrap();
        println!("result: {}", edited);
    }
}