1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use super::util;
use reqwest::{
blocking::Client,
header::{HeaderMap, HeaderValue},
};
use serde::{Deserialize, Serialize};
use serde_json::{from_str, Value};
use std::error::Error;
use std::io::Read;
/// Endpoint that `GPTClient::new` stores as the client's request URL.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/completions";
/// Model name placed in the `model` field of every request.
const OPENAI_MODEL: &str = "text-davinci-003";
/// Budget shared by the prompt and the completion: a prompt whose length
/// reaches this value is rejected, and the remainder after subtracting the
/// prompt length is requested as `max_tokens`.
// NOTE(review): the prompt is measured with `String::len` (bytes), not in
// tokens — presumably an approximation; confirm this is intended.
const MAX_TOKENS: u32 = 4097;
/// Value placed in the `temperature` field of every request.
const TEMPERATURE: f32 = 0.2;
/// Result alias whose error side is any boxed `std::error::Error`.
type BoxResult<T> = Result<T, Box<dyn Error>>;
/// Request payload that `GPTClient::prompt` serializes to JSON and sends as
/// the POST body.
#[derive(Serialize, Deserialize, Debug)]
struct Prompt {
/// Model name; filled from `OPENAI_MODEL`.
model: String,
/// Prompt text supplied by the caller.
prompt: String,
/// Filled from `TEMPERATURE`.
temperature: f32,
/// Filled with `MAX_TOKENS` minus the prompt's length.
max_tokens: u32,
}
/// Blocking client for the completions endpoint.
pub struct GPTClient {
/// Key sent as a `Bearer` token in the `Authorization` header.
api_key: String,
/// URL that requests are POSTed to.
url: String,
}
impl GPTClient {
    /// Creates a client that authenticates with `api_key` and targets the
    /// default endpoint, `OPENAI_API_URL`.
    pub fn new(api_key: String) -> Self {
        GPTClient {
            api_key,
            url: String::from(OPENAI_API_URL),
        }
    }

    /// Sends `prompt` to the completions endpoint and returns the text of the
    /// first choice in the response.
    ///
    /// The request asks for `MAX_TOKENS` minus the prompt's byte length as
    /// `max_tokens`.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the prompt's byte length is `MAX_TOKENS` or more;
    /// - the API key cannot be encoded as an HTTP header value;
    /// - the request cannot be serialized or sent, or the response body
    ///   cannot be read or is not valid JSON;
    /// - the response has no `choices[0].text` string. In that case the raw
    ///   body is handed to `util::pretty_print`, and the error carries the
    ///   API's `error.message` when the response includes one.
    pub fn prompt(&self, prompt: String) -> BoxResult<String> {
        // Compare in `usize`: narrowing the length to `u32` first could wrap
        // for a huge prompt and let it slip past the limit.
        if prompt.len() >= MAX_TOKENS as usize {
            return Err(format!(
                "Prompt cannot exceed length of {} characters",
                MAX_TOKENS - 1
            )
            .into());
        }
        // Lossless: the guard above bounds the length below `MAX_TOKENS`.
        let prompt_length = prompt.len() as u32;

        let p = Prompt {
            max_tokens: MAX_TOKENS - prompt_length,
            model: String::from(OPENAI_MODEL),
            prompt,
            temperature: TEMPERATURE,
        };

        let mut auth = HeaderValue::from_str(&format!("Bearer {}", self.api_key))?;
        // Keep the bearer token out of `Debug` output of the header map.
        auth.set_sensitive(true);

        let mut headers = HeaderMap::new();
        headers.insert("Authorization", auth);
        // A static literal needs no fallible parse.
        headers.insert("Content-Type", HeaderValue::from_static("application/json"));

        let body = serde_json::to_string(&p)?;
        let client = Client::new();
        let mut res = client.post(&self.url).body(body).headers(headers).send()?;

        let mut response_body = String::new();
        res.read_to_string(&mut response_body)?;

        let json_object: Value = from_str(&response_body)?;
        if let Some(answer) = json_object["choices"][0]["text"].as_str() {
            return Ok(String::from(answer));
        }

        // No completion text: show the raw body, then report the most
        // specific reason available.
        util::pretty_print(&response_body, "json");
        match json_object["error"]["message"].as_str() {
            Some(message) => Err(format!("OpenAI API error: {}", message).into()),
            None => Err("JSON parse error".into()),
        }
    }
}