openai_safe/models.rs

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

use crate::{
    constants::{OPENAI_API_URL, OPENAI_BASE_INSTRUCTIONS, OPENAI_FUNCTION_INSTRUCTIONS},
    domain::{OpenAIRateLimit, OpenAPIChatResponse, OpenAPICompletionsResponse},
};

#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum OpenAIModels {
    Gpt3_5Turbo,
    Gpt3_5Turbo0613,
    Gpt3_5Turbo16k,
    Gpt4,
    Gpt4_32k,
    TextDavinci003,
    Gpt4Turbo,
}

impl OpenAIModels {
    pub fn as_str(&self) -> &'static str {
        match self {
            //In an API call, you can describe functions to gpt-3.5-turbo-0613 and gpt-4-0613
            //On June 27, 2023, the stable gpt-3.5-turbo was automatically upgraded to gpt-3.5-turbo-0613
            OpenAIModels::Gpt3_5Turbo => "gpt-3.5-turbo",
            OpenAIModels::Gpt3_5Turbo0613 => "gpt-3.5-turbo-0613",
            OpenAIModels::Gpt3_5Turbo16k => "gpt-3.5-turbo-16k",
            OpenAIModels::Gpt4 => "gpt-4-0613",
            OpenAIModels::Gpt4_32k => "gpt-4-32k",
            OpenAIModels::TextDavinci003 => "text-davinci-003",
            OpenAIModels::Gpt4Turbo => "gpt-4-1106-preview",
        }
    }
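
    // Example (sketch): as_str maps each variant to the concrete API model
    // name, e.g. OpenAIModels::Gpt4.as_str() returns "gpt-4-0613" and
    // OpenAIModels::Gpt4Turbo.as_str() returns "gpt-4-1106-preview".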

    pub fn default_max_tokens(&self) -> usize {
        //OpenAI documentation: https://platform.openai.com/docs/models/gpt-3-5
        //This is the maximum number of tokens shared between the prompt and the response
        match self {
            OpenAIModels::Gpt3_5Turbo => 4096,
            OpenAIModels::Gpt3_5Turbo0613 => 4096,
            OpenAIModels::Gpt3_5Turbo16k => 16384,
            OpenAIModels::Gpt4 => 8192,
            OpenAIModels::Gpt4_32k => 32768,
            OpenAIModels::TextDavinci003 => 4097,
            OpenAIModels::Gpt4Turbo => 128_000,
        }
    }
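
    // Example (sketch): because the window is shared, a caller splitting the
    // budget evenly would reserve half of it for the response, e.g. for Gpt4:
    //     let window = OpenAIModels::Gpt4.default_max_tokens(); // 8192
    //     let response_budget = window / 2;                     // 4096
    //     let prompt_budget = window - response_budget;         // 4096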

    pub(crate) fn get_endpoint(&self) -> String {
        //OpenAI documentation: https://platform.openai.com/docs/models/model-endpoint-compatibility
        match self {
            OpenAIModels::Gpt3_5Turbo
            | OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4_32k => {
                format!(
                    "{OPENAI_API_URL}/v1/chat/completions",
                    OPENAI_API_URL = *OPENAI_API_URL
                )
            }
            OpenAIModels::TextDavinci003 => format!(
                "{OPENAI_API_URL}/v1/completions",
                OPENAI_API_URL = *OPENAI_API_URL
            ),
        }
    }
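
    // Example (sketch): assuming OPENAI_API_URL resolves to
    // "https://api.openai.com" (the constant may point elsewhere), the
    // endpoints are:
    //     chat models    -> https://api.openai.com/v1/chat/completions
    //     TextDavinci003 -> https://api.openai.com/v1/completions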

    pub(crate) fn get_base_instructions(&self, function_call: Option<bool>) -> String {
        let function_call = function_call.unwrap_or_else(|| self.function_call_default());
        match function_call {
            true => OPENAI_FUNCTION_INSTRUCTIONS.to_string(),
            false => OPENAI_BASE_INSTRUCTIONS.to_string(),
        }
    }

    pub(crate) fn function_call_default(&self) -> bool {
        //OpenAI documentation: https://platform.openai.com/docs/guides/gpt/function-calling
        match self {
            OpenAIModels::TextDavinci003 | OpenAIModels::Gpt3_5Turbo | OpenAIModels::Gpt4_32k => {
                false
            }
            OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo => true,
        }
    }

    //This method prepares the body of the API call for different models
    pub(crate) fn get_body(
        &self,
        instructions: &str,
        json_schema: &Value,
        function_call: bool,
        max_tokens: &usize,
        temperature: &u32,
    ) -> serde_json::Value {
        match self {
            //https://platform.openai.com/docs/api-reference/completions/create
            //For the Davinci model all text goes into the 'prompt' field of the body
            OpenAIModels::TextDavinci003 => {
                let schema_string = serde_json::to_string(json_schema).unwrap_or_default();
                let base_instructions = self.get_base_instructions(Some(function_call));
                json!({
                    "model": self.as_str(),
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                    "prompt": format!(
                        "{base_instructions}\n\n
                        Output Json schema:\n
                        {schema_string}\n\n
                        {instructions}",
                    ),
                })
            }
            OpenAIModels::Gpt3_5Turbo
            | OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4_32k => {
                let base_instructions = self.get_base_instructions(Some(function_call));
                let system_message = json!({
                    "role": "system",
                    "content": base_instructions,
                });

                match function_call {
                    //If we choose to use function calling
                    //https://platform.openai.com/docs/guides/gpt/function-calling
                    true => {
                        let user_message = json!({
                            "role": "user",
                            "content": instructions,
                        });

                        let function = json!({
                            "name": "analyze_data",
                            "description": "Use this function to compute the answer based on input data, instructions and your language model. Output should be a fully formed JSON object.",
                            "parameters": json_schema,
                        });

                        let function_call = json!({
                            "name": "analyze_data"
                        });

                        //For chat models we omit max_tokens; the API then defaults to 'inf'
                        json!({
                            "model": self.as_str(),
                            "temperature": temperature,
                            "messages": vec![
                                system_message,
                                user_message,
                            ],
                            "functions": vec![
                                function,
                            ],
                            //This forces ChatGPT to use the function definition
                            "function_call": function_call,
                        })
                    }
                    //https://platform.openai.com/docs/guides/chat/introduction
                    false => {
                        let schema_string = serde_json::to_string(json_schema).unwrap_or_default();

                        let user_message = json!({
                            "role": "user",
                            "content": format!(
                                "Output Json schema:\n
                                {schema_string}\n\n
                                {instructions}"
                            ),
                        });
                        //For chat models we omit max_tokens; the API then defaults to 'inf'
                        json!({
                            "model": self.as_str(),
                            "temperature": temperature,
                            "messages": vec![
                                system_message,
                                user_message,
                            ],
                        })
                    }
                }
            }
        }
    }
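
    // Example (sketch): for a chat model with function_call == true, get_body
    // produces JSON of this shape (values abbreviated):
    //     {
    //       "model": "gpt-4-0613",
    //       "temperature": 0,
    //       "messages": [
    //         { "role": "system", "content": "<base instructions>" },
    //         { "role": "user", "content": "<instructions>" }
    //       ],
    //       "functions": [ { "name": "analyze_data", "description": "...", "parameters": <json_schema> } ],
    //       "function_call": { "name": "analyze_data" }
    //     }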

    //This method converts the API response text into the expected struct and extracts the data from it
    pub(crate) fn get_data(&self, response_text: &str, function_call: bool) -> Result<String> {
        match self {
            //https://platform.openai.com/docs/api-reference/completions/create
            OpenAIModels::TextDavinci003 => {
                //Convert the API response to the struct representing the expected response format
                let completions_response: OpenAPICompletionsResponse =
                    serde_json::from_str(response_text)?;

                //Extract the data part
                match completions_response.choices {
                    Some(choices) => Ok(choices.into_iter().filter_map(|item| item.text).collect()),
                    None => Err(anyhow!(
                        "Unable to retrieve response from OpenAI Completions API"
                    )),
                }
            }
            //https://platform.openai.com/docs/guides/chat/introduction
            OpenAIModels::Gpt3_5Turbo
            | OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4_32k => {
                //Convert the API response to the struct representing the expected response format
                let chat_response: OpenAPIChatResponse = serde_json::from_str(response_text)?;

                //Extract the data part
                match chat_response.choices {
                    Some(choices) => Ok(choices
                        .into_iter()
                        .filter_map(|item| {
                            //For a function call the response is in 'arguments'; for a regular call it is in 'content'
                            match function_call {
                                true => item
                                    .message
                                    .function_call
                                    .map(|function_call| function_call.arguments),
                                false => item.message.content,
                            }
                        })
                        .collect()),
                    None => Err(anyhow!("Unable to retrieve response from OpenAI Chat API")),
                }
            }
        }
    }
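
    // Example (sketch): given a function-call chat response such as
    //     { "choices": [ { "message": { "function_call": {
    //         "name": "analyze_data", "arguments": "{\"answer\": 42}" } } } ] }
    // get_data returns the 'arguments' string, i.e. {"answer": 42}; for a
    // regular chat call it returns message.content instead.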

    //This function returns the rate limits for different models
    fn get_rate_limit(&self) -> OpenAIRateLimit {
        //OpenAI documentation: https://platform.openai.com/account/rate-limits
        //Limits are expressed in tokens per minute (tpm) and requests per minute (rpm)
        match self {
            OpenAIModels::Gpt3_5Turbo => OpenAIRateLimit {
                tpm: 90_000,
                rpm: 3_500,
            },
            OpenAIModels::Gpt3_5Turbo0613 => OpenAIRateLimit {
                tpm: 90_000,
                rpm: 3_500,
            },
            OpenAIModels::Gpt3_5Turbo16k => OpenAIRateLimit {
                tpm: 180_000,
                rpm: 3_500,
            },
            OpenAIModels::Gpt4 => OpenAIRateLimit {
                tpm: 10_000,
                rpm: 200,
            },
            OpenAIModels::Gpt4Turbo => OpenAIRateLimit {
                tpm: 10_000,
                rpm: 200,
            },
            OpenAIModels::Gpt4_32k => OpenAIRateLimit {
                tpm: 10_000,
                rpm: 200,
            },
            OpenAIModels::TextDavinci003 => OpenAIRateLimit {
                tpm: 250_000,
                rpm: 3_000,
            },
        }
    }

    //This function checks how many requests can be sent to an OpenAI model within a minute
    pub fn get_max_requests(&self) -> usize {
        let rate_limit = self.get_rate_limit();

        //Check max requests based on rpm
        let max_requests_from_rpm = rate_limit.rpm;

        //Double-check the max number of requests based on tpm
        //Assume we will use ~50% of the allowed tokens per request (for prompt + response)
        let max_tokens_per_minute = rate_limit.tpm;
        let tpm_per_request = (self.default_max_tokens() as f64 * 0.5).ceil() as usize;
        //Then check how many requests we can process
        let max_requests_from_tpm = max_tokens_per_minute / tpm_per_request;

        //To be safe we go with the smaller of the two numbers
        std::cmp::min(max_requests_from_rpm, max_requests_from_tpm)
    }
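
    // Worked example: for Gpt3_5Turbo, rpm = 3_500 and tpm = 90_000 with a
    // 4_096-token window, so tpm_per_request = 2_048 and
    // max_requests_from_tpm = 90_000 / 2_048 = 43 (integer division);
    // min(3_500, 43) = 43. Note that for Gpt4Turbo the tpm-derived bound is
    // 10_000 / 64_000 = 0, so this method currently returns 0 for that model.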
}

#[cfg(test)]
mod tests {
    use crate::models::OpenAIModels;
    use crate::utils::get_tokenizer;

    #[test]
    fn it_computes_gpt3_5_tokenization() {
        //Gpt-3.5 and Gpt-4 family models share the cl100k_base tokenizer, so Gpt4_32k stands in here
        let bpe = get_tokenizer(&OpenAIModels::Gpt4_32k).unwrap();
        let tokenized: Result<Vec<_>, _> = bpe
            .split_by_token_iter("This is a test         with a lot of spaces", true)
            .collect();
        let tokenized = tokenized.unwrap();
        assert_eq!(
            tokenized,
            vec!["This", " is", " a", " test", "        ", " with", " a", " lot", " of", " spaces"]
        );
    }

    // Tests for calculating max requests per model
    #[test]
    fn test_gpt3_5turbo_max_requests() {
        let model = OpenAIModels::Gpt3_5Turbo;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(3500, 90000 / ((4096_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_gpt3_5turbo0613_max_requests() {
        let model = OpenAIModels::Gpt3_5Turbo0613;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(3500, 90000 / ((4096_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_gpt3_5turbo16k_max_requests() {
        let model = OpenAIModels::Gpt3_5Turbo16k;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(3500, 180000 / ((16384_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_gpt4_max_requests() {
        let model = OpenAIModels::Gpt4;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(200, 10000 / ((8192_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }
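
    // A sketch mirroring the tests above for Gpt4_32k: with tpm = 10_000 and a
    // 32_768-token window, 10_000 / 16_384 == 0, so the tpm bound dominates.
    #[test]
    fn test_gpt4_32k_max_requests() {
        let model = OpenAIModels::Gpt4_32k;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(200, 10000 / ((32768_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }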
}