openai_gpt_client/
client.rs1use log::{debug, info, warn};
2use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
3use reqwest::{Client, Error, Response};
4use serde::{Deserialize, Serialize};
5use serde_json::to_value;
6
7use crate::chat::{ChatMessage, ChatRequest, ChatResponse};
8use crate::model_variants::ModelId;
9use crate::text_completion::{TextCompletionRequest, TextCompletionResponse};
10use crate::GptError;
11
/// Preset of sampling parameters applied to every request an [`OpenAiClient`] sends.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare profiles or use them
/// as map/set keys; the enum is fieldless, so these are trivially correct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientProfile {
    /// Preset used for conversational (chat) requests.
    Chat,
    /// Preset used for code-oriented requests.
    Code,
}
17
18impl ClientProfile {
19 pub fn get_temperature(&self) -> f32 {
20 match self {
21 ClientProfile::Chat => 0.4,
22 ClientProfile::Code => 0.7,
23 }
24 }
25
26 pub fn get_top_p(&self) -> f32 {
27 match self {
28 ClientProfile::Chat => 0.9,
29 ClientProfile::Code => 0.7,
30 }
31 }
32
33 pub fn get_frequency_penalty(&self) -> f32 {
34 match self {
35 ClientProfile::Chat => 0.0,
36 ClientProfile::Code => 0.2,
37 }
38 }
39
40 pub fn get_presence_penalty(&self) -> f32 {
41 match self {
42 ClientProfile::Chat => 0.6,
43 ClientProfile::Code => 0.0,
44 }
45 }
46
47 pub fn get_stop(&self) -> Option<Stop> {
48 None
49 }
50}
51
52#[derive(Debug, Serialize, Deserialize)]
53#[serde(untagged)]
54pub enum Stop {
55 Single(String),
56 Multiple(Vec<String>),
57}
58
59pub struct OpenAiClient {
60 client: Client,
61 profile: ClientProfile,
62}
63
64impl OpenAiClient {
65 pub fn new(api_key: &str, profile: ClientProfile) -> OpenAiClient {
66 let headers = Self::build_headers(api_key);
67
68 let client = Client::builder().default_headers(headers).build().unwrap();
69
70 OpenAiClient { client, profile }
71 }
72
73 fn build_headers(api_key: &str) -> HeaderMap {
74 let mut headers = HeaderMap::new();
75 headers.insert(
76 AUTHORIZATION,
77 HeaderValue::from_str(&format!("Bearer {api_key}")).unwrap(),
78 );
79 headers
80 }
81
82 #[allow(dead_code)]
83 async fn get_request(&self, endpoint: &str) -> Result<Response, Error> {
84 let url = format!("https://api.openai.com/v1/{endpoint}");
85
86 self.client.get(&url).send().await
87 }
88
89 async fn post_request(
90 &self,
91 endpoint: &str,
92 body: serde_json::Value,
93 ) -> Result<Response, Error> {
94 let url = format!("https://api.openai.com/v1/{endpoint}");
95
96 let response = self
98 .client
99 .post(&url)
100 .header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
101 .json(&body)
102 .send()
103 .await?;
104
105 Ok(response)
106 }
107
108 pub async fn completion(
109 &self,
110 model: ModelId,
111 prompt: &str,
112 max_tokens: u16,
113 ) -> Result<String, Error> {
114 let request = TextCompletionRequest {
115 model,
116 prompt: prompt.to_owned(),
117 max_tokens: Some(i32::from(max_tokens)),
118 stop: self.profile.get_stop(),
119 temperature: Some(self.profile.get_temperature() as f64),
120 top_p: Some(self.profile.get_top_p() as f64),
121 frequency_penalty: Some(self.profile.get_frequency_penalty() as f64),
122 presence_penalty: Some(self.profile.get_presence_penalty() as f64),
123 ..Default::default()
124 };
125 debug!("Request: {:?}", request);
126
127 let response = self
128 .post_request("completions", to_value(&request).unwrap())
129 .await?;
130 debug!("Response: {:?}", response);
131
132 let body = response.text().await?;
133 debug!("Body: {:?}", body);
134
135 let completion_response: TextCompletionResponse = serde_json::from_str(&body).unwrap();
137
138 Ok(completion_response.choices.unwrap()[0].text.clone())
140 }
141
142 pub async fn chat(
143 &self,
144 model: ModelId,
145 max_tokens: u16,
146 messages: Vec<ChatMessage>,
147 ) -> Result<ChatMessage, GptError> {
148 info!("Messages:\n {:#?}", messages);
149 let request = ChatRequest {
150 model,
151 messages,
152 max_tokens: Some(i32::from(max_tokens)),
153 stop: self.profile.get_stop(),
154 temperature: Some(self.profile.get_temperature() as f64),
155 top_p: Some(self.profile.get_top_p() as f64),
156 frequency_penalty: Some(self.profile.get_frequency_penalty() as f64),
157 presence_penalty: Some(self.profile.get_presence_penalty() as f64),
158 ..Default::default()
159 };
160 debug!("Request: {:?}", request);
161
162 let response = self
163 .post_request("chat/completions", to_value(&request).unwrap())
164 .await?;
165 debug!("Response: {:?}", response);
166
167 let body = response.text().await?;
168 debug!("Body: {:?}", body);
169
170 let chat_response: ChatResponse = serde_json::from_str(&body)?;
172
173 if let Some(usage) = chat_response.usage {
174 warn!(
175 "Completion: {:?}, total ${}",
176 usage,
177 usage.total_tokens as f32 * 0.002 / 1000.0
178 );
179 }
180
181 Ok(chat_response.choices[0].message.clone())
183 }
184}