use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

use crate::{
    constants::{OPENAI_API_URL, OPENAI_BASE_INSTRUCTIONS, OPENAI_FUNCTION_INSTRUCTIONS},
    domain::{OpenAIRateLimit, OpenAPIChatResponse, OpenAPICompletionsResponse},
};

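/// Identifiers for the OpenAI models supported by this crate. All variants
/// except `TextDavinci003` target the Chat Completions API; `TextDavinci003`
/// uses the legacy Completions API.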
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum OpenAIModels {
    Gpt3_5Turbo,
    Gpt3_5Turbo0613,
    Gpt3_5Turbo16k,
    Gpt4,
    Gpt4_32k,
    TextDavinci003,
    Gpt4Turbo,
}

impl OpenAIModels {
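    /// Returns the model identifier sent in the `model` field of API
    /// requests. Note that `Gpt4` maps to the pinned `gpt-4-0613` snapshot
    /// and `Gpt4Turbo` to the `gpt-4-1106-preview` snapshot.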
    pub fn as_str(&self) -> &'static str {
        match self {
            OpenAIModels::Gpt3_5Turbo => "gpt-3.5-turbo",
            OpenAIModels::Gpt3_5Turbo0613 => "gpt-3.5-turbo-0613",
            OpenAIModels::Gpt3_5Turbo16k => "gpt-3.5-turbo-16k",
            OpenAIModels::Gpt4 => "gpt-4-0613",
            OpenAIModels::Gpt4_32k => "gpt-4-32k",
            OpenAIModels::TextDavinci003 => "text-davinci-003",
            OpenAIModels::Gpt4Turbo => "gpt-4-1106-preview",
        }
    }

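    /// Maximum context window, in tokens, for each model. These values mirror
    /// OpenAI's published limits at the time the variants were added and are
    /// used when estimating request budgets in `get_max_requests`.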
    pub fn default_max_tokens(&self) -> usize {
        match self {
            OpenAIModels::Gpt3_5Turbo | OpenAIModels::Gpt3_5Turbo0613 => 4096,
            OpenAIModels::Gpt3_5Turbo16k => 16384,
            OpenAIModels::Gpt4 => 8192,
            OpenAIModels::Gpt4_32k => 32768,
            OpenAIModels::TextDavinci003 => 4097,
            OpenAIModels::Gpt4Turbo => 128_000,
        }
    }

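    /// Builds the full request URL for this model: chat models are routed to
    /// `/v1/chat/completions`, while `text-davinci-003` uses the legacy
    /// `/v1/completions` endpoint.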
    pub(crate) fn get_endpoint(&self) -> String {
        match self {
            OpenAIModels::Gpt3_5Turbo
            | OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4_32k => {
                format!(
                    "{OPENAI_API_URL}/v1/chat/completions",
                    OPENAI_API_URL = *OPENAI_API_URL
                )
            }
            OpenAIModels::TextDavinci003 => format!(
                "{OPENAI_API_URL}/v1/completions",
                OPENAI_API_URL = *OPENAI_API_URL
            ),
        }
    }

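    /// Picks the base (system) instructions. Passing `None` defers to the
    /// model's `function_call_default` setting.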
    pub(crate) fn get_base_instructions(&self, function_call: Option<bool>) -> String {
        let function_call = function_call.unwrap_or_else(|| self.function_call_default());
        match function_call {
            true => OPENAI_FUNCTION_INSTRUCTIONS.to_string(),
            false => OPENAI_BASE_INSTRUCTIONS.to_string(),
        }
    }

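    /// Whether requests to this model use OpenAI function calling by default.
    /// Models where this is `false` receive the JSON schema inline in the
    /// prompt instead.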
    pub(crate) fn function_call_default(&self) -> bool {
        match self {
            OpenAIModels::TextDavinci003 | OpenAIModels::Gpt3_5Turbo | OpenAIModels::Gpt4_32k => {
                false
            }
            OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo => true,
        }
    }

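    /// Assembles the JSON request body. For chat models this is a system
    /// message plus a user message; when `function_call` is enabled the JSON
    /// schema is attached as an `analyze_data` function definition, otherwise
    /// it is embedded directly in the user prompt. For `text-davinci-003` the
    /// schema and instructions are concatenated into a single prompt string.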
    pub(crate) fn get_body(
        &self,
        instructions: &str,
        json_schema: &Value,
        function_call: bool,
        max_tokens: &usize,
        temperature: &u32,
    ) -> serde_json::Value {
        match self {
            OpenAIModels::TextDavinci003 => {
                let schema_string = serde_json::to_string(json_schema).unwrap_or_default();
                let base_instructions = self.get_base_instructions(Some(function_call));
                json!({
                    "model": self.as_str(),
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                    "prompt": format!(
                        "{base_instructions}\n\nOutput Json schema:\n{schema_string}\n\n{instructions}"
                    ),
                })
            }
            OpenAIModels::Gpt3_5Turbo
            | OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4_32k => {
                let base_instructions = self.get_base_instructions(Some(function_call));
                let system_message = json!({
                    "role": "system",
                    "content": base_instructions,
                });

                match function_call {
                    true => {
                        let user_message = json!({
                            "role": "user",
                            "content": instructions,
                        });

                        let function = json!({
                            "name": "analyze_data",
                            "description": "Use this function to compute the answer based on input data, instructions and your language model. Output should be a fully formed JSON object.",
                            "parameters": json_schema,
                        });

                        let function_call = json!({
                            "name": "analyze_data"
                        });

                        json!({
                            "model": self.as_str(),
                            "temperature": temperature,
                            "messages": vec![
                                system_message,
                                user_message,
                            ],
                            "functions": vec![
                                function,
                            ],
                            "function_call": function_call,
                        })
                    }
                    false => {
                        let schema_string = serde_json::to_string(json_schema).unwrap_or_default();

                        let user_message = json!({
                            "role": "user",
                            "content": format!(
                                "Output Json schema:\n{schema_string}\n\n{instructions}"
                            ),
                        });
                        json!({
                            "model": self.as_str(),
                            "temperature": temperature,
                            "messages": vec![
                                system_message,
                                user_message,
                            ],
                        })
                    }
                }
            }
        }
    }

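    /// Extracts the generated text from a raw API response body. For the
    /// Completions API this is the choice text; for the Chat API it is either
    /// the `function_call` arguments or the message content, depending on the
    /// `function_call` flag.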
    pub(crate) fn get_data(&self, response_text: &str, function_call: bool) -> Result<String> {
        match self {
            OpenAIModels::TextDavinci003 => {
                let completions_response: OpenAPICompletionsResponse =
                    serde_json::from_str(response_text)?;

                match completions_response.choices {
                    Some(choices) => Ok(choices.into_iter().filter_map(|item| item.text).collect()),
                    None => Err(anyhow!(
                        "Unable to retrieve response from OpenAI Completions API"
                    )),
                }
            }
            OpenAIModels::Gpt3_5Turbo
            | OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4_32k => {
                let chat_response: OpenAPIChatResponse = serde_json::from_str(response_text)?;

                match chat_response.choices {
                    Some(choices) => Ok(choices
                        .into_iter()
                        .filter_map(|item| match function_call {
                            true => item
                                .message
                                .function_call
                                .map(|function_call| function_call.arguments),
                            false => item.message.content,
                        })
                        .collect()),
                    None => Err(anyhow!("Unable to retrieve response from OpenAI Chat API")),
                }
            }
        }
    }

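    /// Default rate limits in tokens per minute (`tpm`) and requests per
    /// minute (`rpm`). These are hard-coded defaults and may not match the
    /// actual limits configured for a given OpenAI account.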
    fn get_rate_limit(&self) -> OpenAIRateLimit {
        match self {
            OpenAIModels::Gpt3_5Turbo | OpenAIModels::Gpt3_5Turbo0613 => OpenAIRateLimit {
                tpm: 90_000,
                rpm: 3_500,
            },
            OpenAIModels::Gpt3_5Turbo16k => OpenAIRateLimit {
                tpm: 180_000,
                rpm: 3_500,
            },
            OpenAIModels::Gpt4 | OpenAIModels::Gpt4Turbo | OpenAIModels::Gpt4_32k => {
                OpenAIRateLimit {
                    tpm: 10_000,
                    rpm: 200,
                }
            }
            OpenAIModels::TextDavinci003 => OpenAIRateLimit {
                tpm: 250_000,
                rpm: 3_000,
            },
        }
    }

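    /// Estimates a safe number of requests per minute by taking the stricter
    /// of the RPM limit and the TPM limit, under the assumption that an
    /// average request consumes about half of the model's context window.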
    pub fn get_max_requests(&self) -> usize {
        let rate_limit = self.get_rate_limit();

        let max_requests_from_rpm = rate_limit.rpm;

        let max_tokens_per_minute = rate_limit.tpm;
        let tpm_per_request = (self.default_max_tokens() as f64 * 0.5).ceil() as usize;
        let max_requests_from_tpm = max_tokens_per_minute / tpm_per_request;

        std::cmp::min(max_requests_from_rpm, max_requests_from_tpm)
    }
}

#[cfg(test)]
mod tests {
    use crate::models::OpenAIModels;
    use crate::utils::get_tokenizer;

    #[test]
    fn it_computes_gpt3_5_tokenization() {
        let bpe = get_tokenizer(&OpenAIModels::Gpt4_32k).unwrap();
        let tokenized: Result<Vec<_>, _> = bpe
            .split_by_token_iter("This is a test         with a lot of spaces", true)
            .collect();
        let tokenized = tokenized.unwrap();
        assert_eq!(
            tokenized,
            vec!["This", " is", " a", " test", "        ", " with", " a", " lot", " of", " spaces"]
        );
    }

    #[test]
    fn test_gpt3_5turbo_max_requests() {
        let model = OpenAIModels::Gpt3_5Turbo;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(3500, 90000 / ((4096_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_gpt3_5turbo0613_max_requests() {
        let model = OpenAIModels::Gpt3_5Turbo0613;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(3500, 90000 / ((4096_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_gpt3_5turbo16k_max_requests() {
        let model = OpenAIModels::Gpt3_5Turbo16k;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(3500, 180000 / ((16384_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_gpt4_max_requests() {
        let model = OpenAIModels::Gpt4;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(200, 10000 / ((8192_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }
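
    // A companion check for the legacy model, added for symmetry with the
    // tests above; expected values come from the hard-coded TextDavinci003
    // limits (250_000 TPM, 3_000 RPM, 4097 max tokens).
    #[test]
    fn test_text_davinci003_max_requests() {
        let model = OpenAIModels::TextDavinci003;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(3000, 250000 / ((4097_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }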
}