//! openai_api_client — crate root (`lib.rs`).
//! Minimal async client for the OpenAI completions and edits endpoints.

1use thiserror::Error;
2use awc::Client;
3use serde::Deserialize;
4use serde::Serialize;
5use std::collections::HashMap;
6use std::time::Duration;
7
8
9pub async fn completions(
10    prompt: &str,
11    params: &CompletionsParams,
12    api_key: &str,
13) -> Result<CompletionsResponse, ClientError> {
14    let client = Client::default();
15
16    let request = Request {
17        model: params.model.clone(),
18        prompt: prompt.to_string(),
19        temperature: params.temperature,
20        max_tokens: params.max_tokens,
21        top_p: params.top_p,
22        frequency_penalty: params.frequency_penalty,
23        presence_penalty: params.presence_penalty,
24        stop: params.stop.clone(),
25        suffix: params.suffix.clone(),
26        logprobs: params.logprobs,
27        echo: params.echo,
28        best_of: params.best_of,
29        n: params.n,
30        stream: params.stream,
31        logit_bias: params.logit_bias.clone(),
32        user: params.user.clone(),
33    };
34
35    let request = serde_json::to_string(&request)
36        .map_err(|e| ClientError::OtherError(format!("{:?}",e)))?;
37    let response = client
38        .post("https://api.openai.com/v1/completions")
39        .timeout(Duration::from_secs(30))
40        .insert_header(("Content-Type", "application/json"))
41        .insert_header(("Authorization", format!("Bearer {}", api_key)))
42        .send_body(request)
43        .await
44        .map_err(|e| ClientError::NetworkError(format!("{:?}",e)))?
45        .body()
46        .await
47        .map_err(|e| ClientError::NetworkError(format!("{:?}",e)))?;
48    let response_str = std::str::from_utf8(response.as_ref())
49        .map_err(|e| ClientError::OtherError(format!("{:?}",e)))?;
50    
51    let completions_response: CompletionsResponse = match serde_json::from_str(response_str) {
52        Ok(response) => response,
53        Err(e1) => {
54            let error_response: ErrorResponse = match serde_json::from_str(response_str) {
55                Ok(response) => response,
56                Err(e2) => {
57                    return Err(ClientError::OtherError(format!("{:?} {:?}",e2, e1)));
58                }
59            };
60            return Err(ClientError::APIError(error_response.error.message));
61        }
62    };
63    Ok(completions_response)
64
65}
66
/// Errors returned by the client functions in this crate.
#[derive(Debug, Error)]
pub enum ClientError {
    /// The API answered with its error envelope; carries the API's own message.
    #[error("OpenAI API error: `{0}`")]
    APIError(String),
    /// The HTTP request or response body transfer failed.
    #[error("Network error: `{0}`")]
    NetworkError(String),
    /// Anything else: JSON (de)serialization or UTF-8 decoding failures.
    #[error("Other error: `{0}`")]
    OtherError(String),
}
76
77pub async fn completions_pretty (
78    prompt: &str,
79    model: &str,
80    max_tokens: u32,
81    api_key: &str,
82) -> Result<String, ClientError> {
83    let params = CompletionsParams {
84        model: model.to_string(),
85        temperature: 0,
86        max_tokens: max_tokens,
87        top_p: 1.0,
88        frequency_penalty: 0.0,
89        presence_penalty: 0.0,
90        stop: None,
91        suffix: None,
92        n: 1,
93        stream: false,
94        logprobs: None,
95        echo: false,
96        best_of: 1,
97        logit_bias: None,
98        user: None,
99    };
100
101    let res = completions(prompt, &params, api_key).await?;
102    Ok(res.choices[0].text.clone())
103}
104
105
106
/// Top-level envelope of an OpenAI error payload (`{"error": {...}}`).
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
    pub error: ErrorResponseObject,
}
111
/// The `error` object inside an OpenAI error payload.
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponseObject {
    /// Human-readable description; this is what gets surfaced as
    /// `ClientError::APIError`.
    pub message: String,
    /// Error category string reported by the API (raw keyword `type`).
    pub r#type: String,
    /// Request parameter the error refers to, when applicable.
    pub param: Option<String>,
    /// Machine-readable error code, when provided.
    pub code: Option<String>,
}
119
/// Caller-supplied tuning parameters for the `/v1/completions` endpoint.
/// Mirrors [`Request`] minus the prompt itself.
#[derive(Deserialize, Serialize)]
pub struct CompletionsParams {
    /// Model identifier, e.g. "text-davinci-003".
    pub model: String,
    /// Sampling temperature. NOTE(review): typed as an integer, so fractional
    /// temperatures (which the API accepts) cannot be expressed — confirm intent.
    pub temperature: u32,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Nucleus-sampling probability mass (per the OpenAI API).
    pub top_p: f32,
    pub frequency_penalty: f32,
    pub presence_penalty: f32,
    /// Optional stop sequences that end generation.
    pub stop: Option<Vec<String>>,
    /// Optional text appended after the completion.
    pub suffix: Option<String>,
    /// Number of completion choices to request.
    pub n: u32,
    /// Whether to request a streamed response. NOTE(review): this client reads
    /// the body in one piece regardless — confirm streaming is ever used.
    pub stream: bool,
    /// Number of log-probabilities to include per token, if any.
    pub logprobs: Option<u32>,
    /// Whether to echo the prompt back in the completion.
    pub echo: bool,
    pub best_of: u32,
    /// Token-bias map forwarded to the API.
    pub logit_bias: Option<HashMap<String, i32>>,
    /// End-user identifier forwarded to the API.
    pub user: Option<String>,
}
138
/// Successful response body from `/v1/completions`.
#[derive(Deserialize, Serialize, Debug)]
pub struct CompletionsResponse {
    /// Response identifier assigned by the API.
    pub id: String,
    /// Object type tag returned by the API.
    pub object: String,
    /// Creation time — presumably a Unix timestamp; verify against API docs.
    pub created: u32,
    /// Model that produced the completion.
    pub model: String,
    /// Generated completions; `completions_pretty` reads the first entry.
    pub choices: Vec<CompletionsChoice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
148
/// One generated completion within a [`CompletionsResponse`].
#[derive(Deserialize, Serialize, Debug)]
pub struct CompletionsChoice {
    /// The generated text.
    pub text: String,
    /// Position of this choice among the `n` requested.
    pub index: u32,
    /// NOTE(review): the API returns a structured object here when logprobs
    /// are requested; `Option<String>` may fail to deserialize then — verify.
    pub logprobs: Option<String>,
    /// Reason generation stopped, as reported by the API.
    pub finish_reason: String,
}
156
/// Token usage accounting attached to completions and edits responses.
#[derive(Deserialize, Serialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens, as reported by the API.
    pub total_tokens: u32,
}
163
/// JSON body sent to `/v1/completions`; combines the prompt with the fields
/// of [`CompletionsParams`]. Optional fields are omitted from the payload
/// when `None` rather than serialized as `null`.
#[derive(Deserialize, Serialize)]
pub struct Request {
    pub model: String,
    pub prompt: String,
    pub temperature: u32,
    pub max_tokens: u32,
    pub top_p: f32,
    pub frequency_penalty: f32,
    pub presence_penalty: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    pub n: u32,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u32>,
    pub echo: bool,
    pub best_of: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
188
/// Caller-supplied tuning parameters for the `/v1/edits` endpoint.
#[derive(Deserialize, Serialize)]
pub struct EditsParams {
    /// Model identifier, e.g. "text-davinci-edit-001".
    pub model: String,
    /// Sampling temperature. NOTE(review): integer-typed, same caveat as
    /// [`CompletionsParams::temperature`].
    pub temperature: u32,
    /// Nucleus-sampling probability mass.
    pub top_p: f32,
    /// Number of edit choices to request.
    pub n: u32,
}
196
/// JSON body sent to `/v1/edits`; combines the input text and instruction
/// with the fields of [`EditsParams`]. Private: only built inside [`edits`].
#[derive(Deserialize, Serialize)]
struct RequestEdit {
    model: String,
    input: String,
    instruction: String,
    n: u32,
    temperature: u32,
    top_p: f32,
}
206
/// Successful response body from `/v1/edits`.
#[derive(Deserialize, Serialize, Debug)]
pub struct EditsResponse {
    /// Object type tag returned by the API.
    pub object: String,
    /// Creation time — presumably a Unix timestamp; verify against API docs.
    pub created: u32,
    /// Generated edits; `edits_pretty` reads the first entry.
    pub choices: Vec<EditsChoice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
214
/// One edited text within an [`EditsResponse`].
#[derive(Deserialize, Serialize, Debug)]
pub struct EditsChoice {
    /// The edited text.
    pub text: String,
    /// Position of this choice among the `n` requested.
    pub index: u32,
}
220
221pub async fn edits(
222    input: &str,
223    instruction: &str,
224    params: &EditsParams,
225    api_key: &str,
226) -> Result<EditsResponse, ClientError> {
227    let client = Client::default();
228
229    let request: RequestEdit = RequestEdit {
230        model: params.model.clone(),
231        input: input.to_string(),
232        instruction: instruction.to_string(),
233        n: params.n,
234        temperature: params.temperature,
235        top_p: params.top_p,
236    };
237
238    let request_string_result = serde_json::to_string(&request);
239    match request_string_result {
240        Ok(request_string) => {
241            let resp_result = client
242                .post("https://api.openai.com/v1/edits")
243                .timeout(Duration::from_secs(30))
244                .insert_header(("Content-Type", "application/json"))
245                .insert_header(("Authorization", format!("Bearer {}", api_key)))
246                .send_body(request_string)
247                .await;
248            match resp_result {
249                Ok(mut resp) => {
250                    let bytes_result = resp.body().await;
251                    match bytes_result {
252                        Ok(bytes) => {
253                            let string_result = String::from_utf8(bytes.to_vec());
254                            match string_result {
255                                Ok(string) => {
256                                    let parse_result: Result<EditsResponse, serde_json::Error> =
257                                        serde_json::from_str(string.as_str());
258                                    match parse_result {
259                                        Ok(response) => Ok(response),
260                                        Err(e1) => {
261                                            let error_result: Result<ErrorResponse, serde_json::Error> =
262                                                serde_json::from_str(string.as_str());
263                                            match error_result {
264                                                Ok(error) => Err(ClientError::APIError(error.error.message)),
265                                                Err(e2) => Err(ClientError::OtherError(format!("{:?} {:?}",e2, e1))),
266                                            }
267                                        },
268                                    }
269                                }
270                                Err(e) => Err(ClientError::OtherError(format!("{:?}",e))),
271                            }
272                        }
273                        Err(e) => Err(ClientError::NetworkError(format!("{:?}",e))),
274                    }
275                }
276                Err(e) => Err(ClientError::OtherError(format!("{:?}",e))),
277            }
278        }
279        Err(e) => Err(ClientError::OtherError(format!("{:?}",e))),
280    }
281}
282
283pub async fn edits_pretty(input: &str, instruction: &str, model: &str, api_key: &str) -> Result<String, ClientError> {
284    let params = EditsParams {
285        model: model.to_string(),
286        temperature: 0,
287        top_p: 1.0,
288        n: 1,
289    };
290
291    let res = edits(input, instruction, &params, api_key).await?;
292    Ok(res.choices[0].text.clone())
293}
294
295
#[cfg(test)]
mod tests {
    use super::*;
    use dotenv::dotenv;
    use std::env;

    // Integration test: requires network access and a valid OPEN_AI_API_KEY
    // (from the process environment or a `.env` file). Exercises both the
    // completions and edits endpoints end to end against the live API.
    #[actix_rt::test]
    async fn it_works() {
        dotenv().ok();
        // Fail fast with a clear message when the key is missing.
        let api_key = env::var("OPEN_AI_API_KEY").expect("OPEN_AI_API_KEY must be set");

        let model = "text-davinci-003";
        let max_tokens: u32 = 3;
        let result: String = completions_pretty(
            "Is Madonna president of USA? If you ask yes or not. I say:",
            model,
            max_tokens,
            &api_key,
        ).await.unwrap();
        println!("result: {}", result);

        // The intentionally misspelled input exercises the edits endpoint.
        let result_edits: String = edits_pretty(
            "Helsllo, Mick!",
            "Fix grammar",
            "text-davinci-edit-001",
            &api_key,
        )
        .await.unwrap();
        println!("result: {}", result_edits);
    }
}