1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
use crate::error;
use openai_api_client::{
ClientError, CompletionsParams, CompletionsResponse, ErrorResponse, Request,
};
use reqwest::{header, ClientBuilder};
use std::time::Duration;
/// Model identifier sent by [`completions`] in the request's `model` field.
// A plain string constant has no need for a fixed address, so `const` is the idiomatic form.
// NOTE(review): OpenAI's deprecations page lists `text-davinci-003` as shut down
// (2024-01-04), with `gpt-3.5-turbo-instruct` as the suggested replacement on
// `/v1/completions` — confirm and update the model if requests start failing.
const TEXT_DAVINCI_MODEL: &str = "text-davinci-003";
pub async fn completions(prompt: &str, api_key: &str) -> std::result::Result<String, error::Errpr> {
let model = TEXT_DAVINCI_MODEL;
let max_tokens: u32 = 3000;
let res = completions_inner(prompt, model, max_tokens, api_key).await?;
Ok(res)
}
pub async fn completions_inner(
prompt: &str,
model: &str,
max_tokens: u32,
api_key: &str,
) -> std::result::Result<String, ClientError> {
let params = CompletionsParams {
model: model.to_string(),
temperature: 0,
max_tokens,
top_p: 1.0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
stop: None,
suffix: None,
n: 1,
stream: false,
logprobs: None,
echo: false,
best_of: 1,
logit_bias: None,
user: None,
};
let request = Request {
model: params.model.clone(),
prompt: prompt.to_string(),
temperature: params.temperature,
max_tokens: params.max_tokens,
top_p: params.top_p,
frequency_penalty: params.frequency_penalty,
presence_penalty: params.presence_penalty,
stop: params.stop.clone(),
suffix: params.suffix.clone(),
logprobs: params.logprobs,
echo: params.echo,
best_of: params.best_of,
n: params.n,
stream: params.stream,
logit_bias: params.logit_bias.clone(),
user: params.user.clone(),
};
let request = serde_json::to_string(&request).unwrap();
let mut header = header::HeaderMap::new();
header.insert("Content-Type", "application/json".parse().unwrap());
header.insert(
"Authorization",
format!("Bearer {api_key}").parse().unwrap(),
);
let client = ClientBuilder::new()
.default_headers(header)
.build()
.unwrap();
let response = client
.post("https://api.openai.com/v1/completions")
.timeout(Duration::from_secs(60))
.body(request)
.send()
.await
.map_err(|e| ClientError::NetworkError(format!("{e:?}")))?
.bytes()
.await
.map_err(|e| ClientError::NetworkError(format!("{e:?}")))?;
let response_str =
std::str::from_utf8(&response).map_err(|e| ClientError::OtherError(format!("{e:?}")))?;
let completions_response: CompletionsResponse = match serde_json::from_str(response_str) {
Ok(response) => response,
Err(e1) => {
let error_response: ErrorResponse = match serde_json::from_str(response_str) {
Ok(response) => response,
Err(e2) => {
return Err(ClientError::OtherError(format!("{e2:?} {e1:?}")));
}
};
return Err(ClientError::APIError(error_response.error.message));
}
};
Ok(completions_response.choices[0].text.clone())
}