use anyhow::{anyhow, Result};
use log::{error, info, warn};
use reqwest::{header, Client};
use schemars::{schema_for, JsonSchema};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{json, Value};

use crate::{
    constants::{OPENAI_API_URL, OPENAI_BASE_INSTRUCTIONS, OPENAI_FUNCTION_INSTRUCTIONS},
    domain::{OpenAIDataResponse, OpenAPIChatResponse, OpenAPICompletionsResponse, RateLimit},
    utils::get_tokenizer_old,
};

#[deprecated(
    since = "0.6.1",
    note = "This struct is deprecated. Please use the `llm_models::OpenAIModels` struct for latest functionality."
)]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum OpenAIModels {
    Gpt3_5Turbo,
    Gpt3_5Turbo0613,
    Gpt3_5Turbo16k,
    Gpt4,
    Gpt4_32k,
    TextDavinci003,
    Gpt4Turbo,
    Gpt4o,
}

impl OpenAIModels {
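    /// Returns the OpenAI API model name for this variant.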
    pub fn as_str(&self) -> &'static str {
        match self {
            OpenAIModels::Gpt3_5Turbo => "gpt-3.5-turbo",
            OpenAIModels::Gpt3_5Turbo0613 => "gpt-3.5-turbo-0613",
            OpenAIModels::Gpt3_5Turbo16k => "gpt-3.5-turbo-16k",
            OpenAIModels::Gpt4 => "gpt-4-0613",
            OpenAIModels::Gpt4_32k => "gpt-4-32k",
            OpenAIModels::TextDavinci003 => "text-davinci-003",
            OpenAIModels::Gpt4Turbo => "gpt-4-1106-preview",
            OpenAIModels::Gpt4o => "gpt-4o",
        }
    }

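    /// Returns the default total token allowance (context window size) assumed for this model.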
    pub fn default_max_tokens(&self) -> usize {
        match self {
            OpenAIModels::Gpt3_5Turbo => 4096,
            OpenAIModels::Gpt3_5Turbo0613 => 4096,
            OpenAIModels::Gpt3_5Turbo16k => 16384,
            OpenAIModels::Gpt4 => 8192,
            OpenAIModels::Gpt4_32k => 32768,
            OpenAIModels::TextDavinci003 => 4097,
            OpenAIModels::Gpt4Turbo => 128_000,
            OpenAIModels::Gpt4o => 128_000,
        }
    }

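    /// Returns the API endpoint URL for this model. Chat models use
    /// `/v1/chat/completions`; the legacy completions model uses `/v1/completions`.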
    pub(crate) fn get_endpoint(&self) -> String {
        match self {
            OpenAIModels::Gpt3_5Turbo
            | OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4o
            | OpenAIModels::Gpt4_32k => {
                format!(
                    "{OPENAI_API_URL}/v1/chat/completions",
                    OPENAI_API_URL = *OPENAI_API_URL
                )
            }
            OpenAIModels::TextDavinci003 => format!(
                "{OPENAI_API_URL}/v1/completions",
                OPENAI_API_URL = *OPENAI_API_URL
            ),
        }
    }

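    /// Selects the base system instructions depending on whether function calling is used,
    /// falling back to the model's default when `function_call` is `None`.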
    pub(crate) fn get_base_instructions(&self, function_call: Option<bool>) -> String {
        let function_call = function_call.unwrap_or_else(|| self.function_call_default());
        match function_call {
            true => OPENAI_FUNCTION_INSTRUCTIONS.to_string(),
            false => OPENAI_BASE_INSTRUCTIONS.to_string(),
        }
    }

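    /// Returns whether this model should use OpenAI function calling by default.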
    pub(crate) fn function_call_default(&self) -> bool {
        match self {
            OpenAIModels::TextDavinci003 | OpenAIModels::Gpt3_5Turbo | OpenAIModels::Gpt4_32k => {
                false
            }
            OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4o => true,
        }
    }

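    /// Constructs the request body for this model's endpoint. Chat models receive a
    /// system/user message pair (plus a function definition when `function_call` is true);
    /// the legacy completions model receives a single prompt string embedding the schema.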
    pub(crate) fn get_body(
        &self,
        instructions: &str,
        json_schema: &Value,
        function_call: bool,
        max_tokens: &usize,
        temperature: &u32,
    ) -> serde_json::Value {
        match self {
            OpenAIModels::TextDavinci003 => {
                let schema_string = serde_json::to_string(json_schema).unwrap_or_default();
                let base_instructions = self.get_base_instructions(Some(function_call));
                json!({
                    "model": self.as_str(),
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                    "prompt": format!(
                        "{base_instructions}\n\n
                        Output Json schema:\n
                        {schema_string}\n\n
                        {instructions}",
                    ),
                })
            }
            OpenAIModels::Gpt3_5Turbo
            | OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4o
            | OpenAIModels::Gpt4_32k => {
                let base_instructions = self.get_base_instructions(Some(function_call));
                let system_message = json!({
                    "role": "system",
                    "content": base_instructions,
                });

                match function_call {
                    true => {
                        let user_message = json!({
                            "role": "user",
                            "content": instructions,
                        });

                        let function = json!({
                            "name": "analyze_data",
                            "description": "Use this function to compute the answer based on input data, instructions and your language model. Output should be a fully formed JSON object.",
                            "parameters": json_schema,
                        });

                        let function_call = json!({
                            "name": "analyze_data"
                        });

                        json!({
                            "model": self.as_str(),
                            "temperature": temperature,
                            "messages": vec![
                                system_message,
                                user_message,
                            ],
                            "functions": vec![
                                function,
                            ],
                            "function_call": function_call,
                        })
                    }
                    false => {
                        let schema_string = serde_json::to_string(json_schema).unwrap_or_default();

                        let user_message = json!({
                            "role": "user",
                            "content": format!(
                                "Output Json schema:\n
                                {schema_string}\n\n
                                {instructions}"
                            ),
                        });
                        json!({
                            "model": self.as_str(),
                            "temperature": temperature,
                            "messages": vec![
                                system_message,
                                user_message,
                            ],
                        })
                    }
                }
            }
        }
    }
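
    /// Posts the request body to the model's endpoint and returns the raw response text.
    /// When `debug` is enabled, the response status and body are logged via `info!`.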
    pub async fn call_api(
        &self,
        api_key: &str,
        body: &serde_json::Value,
        debug: bool,
    ) -> Result<String> {
        let model_url = self.get_endpoint();

        let client = Client::new();

        let response = client
            .post(model_url)
            .header(header::CONTENT_TYPE, "application/json")
            .bearer_auth(api_key)
            .json(&body)
            .send()
            .await?;

        let response_status = response.status();
        let response_text = response.text().await?;

        if debug {
            info!(
                "[debug] OpenAI API response: [{}] {:#?}",
                &response_status, &response_text
            );
        }

        Ok(response_text)
    }

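    /// Extracts the answer from the raw API response, concatenating the `text` of the
    /// returned choices (completions API) or the message content / function-call
    /// arguments (chat API).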
    pub(crate) fn get_data(&self, response_text: &str, function_call: bool) -> Result<String> {
        match self {
            OpenAIModels::TextDavinci003 => {
                let completions_response: OpenAPICompletionsResponse =
                    serde_json::from_str(response_text)?;

                match completions_response.choices {
                    Some(choices) => Ok(choices.into_iter().filter_map(|item| item.text).collect()),
                    None => Err(anyhow!(
                        "Unable to retrieve response from OpenAI Completions API"
                    )),
                }
            }
            OpenAIModels::Gpt3_5Turbo
            | OpenAIModels::Gpt3_5Turbo0613
            | OpenAIModels::Gpt3_5Turbo16k
            | OpenAIModels::Gpt4
            | OpenAIModels::Gpt4Turbo
            | OpenAIModels::Gpt4o
            | OpenAIModels::Gpt4_32k => {
                let chat_response: OpenAPIChatResponse = serde_json::from_str(response_text)?;

                match chat_response.choices {
                    Some(choices) => Ok(choices
                        .into_iter()
                        .filter_map(|item| {
                            match function_call {
                                true => item
                                    .message
                                    .function_call
                                    .map(|function_call| function_call.arguments),
                                false => item.message.content,
                            }
                        })
                        .collect()),
                    None => Err(anyhow!("Unable to retrieve response from OpenAI Chat API")),
                }
            }
        }
    }

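    /// Returns the tokens-per-minute and requests-per-minute limits assumed for this model.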
    fn get_rate_limit(&self) -> RateLimit {
        match self {
            OpenAIModels::Gpt3_5Turbo => RateLimit {
                tpm: 2_000_000,
                rpm: 10_000,
            },
            OpenAIModels::Gpt3_5Turbo0613 => RateLimit {
                tpm: 2_000_000,
                rpm: 10_000,
            },
            OpenAIModels::Gpt3_5Turbo16k => RateLimit {
                tpm: 2_000_000,
                rpm: 10_000,
            },
            OpenAIModels::Gpt4 => RateLimit {
                tpm: 300_000,
                rpm: 10_000,
            },
            OpenAIModels::Gpt4Turbo => RateLimit {
                tpm: 2_000_000,
                rpm: 10_000,
            },
            OpenAIModels::Gpt4_32k => RateLimit {
                tpm: 300_000,
                rpm: 10_000,
            },
            OpenAIModels::Gpt4o => RateLimit {
                tpm: 2_000_000,
                rpm: 10_000,
            },
            OpenAIModels::TextDavinci003 => RateLimit {
                tpm: 250_000,
                rpm: 3_000,
            },
        }
    }

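    /// Estimates the maximum number of requests per minute for this model, taking the
    /// lower of the RPM limit and the TPM limit divided by an assumed average request
    /// size of half the model's context window.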
    pub fn get_max_requests(&self) -> usize {
        let rate_limit = self.get_rate_limit();

        let max_requests_from_rpm = rate_limit.rpm;

        let max_tokens_per_minute = rate_limit.tpm;
        let tpm_per_request = (self.default_max_tokens() as f64 * 0.5).ceil() as usize;
        let max_requests_from_tpm = max_tokens_per_minute / tpm_per_request;

        std::cmp::min(max_requests_from_rpm, max_requests_from_tpm)
    }
}

#[deprecated(
    since = "0.6.1",
    note = "This struct is deprecated. Please use the `Completions` struct for latest functionality including GPT-4o."
)]
pub struct OpenAI {
    model: OpenAIModels,
    max_tokens: usize,
    temperature: u32,
    input_json: Option<String>,
    debug: bool,
    function_call: bool,
    api_key: String,
}

impl OpenAI {
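    /// Creates a new client for the given model. `max_tokens` and `temperature`
    /// default to the model's context window and `0` when `None`.
    ///
    /// A minimal usage sketch (the key below is a placeholder):
    ///
    /// ```ignore
    /// let openai = OpenAI::new("sk-...", OpenAIModels::Gpt4o, None, None)
    ///     .debug()
    ///     .function_calling(true);
    /// ```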
    pub fn new(
        open_ai_key: &str,
        model: OpenAIModels,
        max_tokens: Option<usize>,
        temperature: Option<u32>,
    ) -> Self {
        OpenAI {
            max_tokens: max_tokens.unwrap_or_else(|| model.default_max_tokens()),
            function_call: model.function_call_default(),
            model,
            temperature: temperature.unwrap_or(0u32),
            input_json: None,
            debug: false,
            api_key: open_ai_key.to_string(),
        }
    }

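    /// Enables debug logging of request bodies and API responses.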
    pub fn debug(mut self) -> Self {
        self.debug = true;
        self
    }

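    /// Overrides the model's default function-calling behavior.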
    pub fn function_calling(mut self, function_call: bool) -> Self {
        self.function_call = function_call;
        self
    }

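    /// Serializes `input_data` to JSON and appends it, labeled with `input_name`, to the
    /// context sent alongside the instructions. Can be called multiple times to attach
    /// several inputs.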
    pub fn set_context<T: Serialize>(mut self, input_name: &str, input_data: &T) -> Result<Self> {
        let input_json = if let Ok(json) = serde_json::to_string(&input_data) {
            json
        } else {
            return Err(anyhow!("Unable to serialize provided input data."));
        };
        let line_break = match self.input_json {
            Some(_) => "\n\n".to_string(),
            None => "".to_string(),
        };
        let new_json = format!(
            "{}{}{}: {}",
            self.input_json.unwrap_or_default(),
            line_break,
            input_name,
            input_json,
        );
        self.input_json = Some(new_json);
        Ok(self)
    }

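    /// Estimates the number of tokens in the full prompt (base instructions, the provided
    /// instructions, any attached context, and the JSON schema of `T`), padded by a 5%
    /// safety margin.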
    pub fn check_prompt_tokens<T: JsonSchema + DeserializeOwned>(
        &self,
        instructions: &str,
    ) -> Result<usize> {
        let schema = schema_for!(T);
        let json_value: Value = serde_json::to_value(&schema)?;

        let prompt = format!(
            "Instructions:
            {instructions}

            Input data:
            {input_json}

            Respond ONLY with the data portion of a valid Json object. No schema definition required. No other words.",
            instructions = instructions,
            input_json = self.input_json.clone().unwrap_or_default(),
        );

        let full_prompt = format!(
            "{}{}{}",
            self.model.get_base_instructions(Some(self.function_call)),
            prompt,
            serde_json::to_string(&json_value).unwrap_or_default()
        );

        let bpe = get_tokenizer_old(&self.model)?;
        let prompt_tokens = bpe.encode_with_special_tokens(&full_prompt).len();

        Ok((prompt_tokens as f64 * 1.05) as usize)
    }

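    /// Builds the prompt, validates the token budget, calls the API, and deserializes
    /// the model's answer into `T`.
    ///
    /// A minimal sketch, assuming a hypothetical caller-defined `Sentiment` type:
    ///
    /// ```ignore
    /// #[derive(JsonSchema, Deserialize)]
    /// struct Sentiment {
    ///     label: String,
    ///     score: f32,
    /// }
    ///
    /// let sentiment: Sentiment = OpenAI::new(&api_key, OpenAIModels::Gpt4o, None, None)
    ///     .set_context("review", &review_text)?
    ///     .get_answer("Classify the sentiment of the provided review.")
    ///     .await?;
    /// ```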
    pub async fn get_answer<T: JsonSchema + DeserializeOwned>(
        self,
        instructions: &str,
    ) -> Result<T> {
        let schema = schema_for!(T);
        let json_value: Value = serde_json::to_value(&schema)?;

        let prompt = format!(
            "Instructions:
            {instructions}

            Input data:
            {input_json}

            Respond ONLY with the data portion of a valid Json object. No schema definition required. No other words.",
            instructions = instructions,
            input_json = self.input_json.clone().unwrap_or_default(),
        );

        let prompt_tokens = self
            .check_prompt_tokens::<T>(instructions)
            .unwrap_or_default();

        if prompt_tokens >= self.max_tokens {
            return Err(anyhow!(
                "The provided prompt requires more tokens than allocated."
            ));
        }
        let response_tokens = self.max_tokens - prompt_tokens;

        // Warn when the prompt uses more than half of the allocated tokens.
        if prompt_tokens * 2 >= self.max_tokens {
            warn!(
                "{} tokens remaining for response: {} allocated, {} used for prompt",
                response_tokens, self.max_tokens, prompt_tokens,
            );
        };

        let model_body = self.model.get_body(
            &prompt,
            &json_value,
            self.function_call,
            &response_tokens,
            &self.temperature,
        );

        if self.debug {
            info!("[debug] Model body: {:#?}", model_body);
            info!(
                "[debug] Prompt accounts for approx {} tokens, leaving {} tokens for answer.",
                prompt_tokens, response_tokens,
            );
        }

        let response_text = self
            .model
            .call_api(&self.api_key, &model_body, self.debug)
            .await?;

        let response_string = self.model.get_data(&response_text, self.function_call)?;

        if self.debug {
            info!("[debug] OpenAI response data: {}", response_string);
        }

        // Try to deserialize the extracted data directly; if that fails, fall back
        // to parsing the full response wrapper and returning its `data` field.
        match serde_json::from_str::<T>(&response_string) {
            Ok(response) => Ok(response),
            Err(error) => {
                error!("[OpenAI] Response serialization error: {}", &error);
                let response_deser: OpenAIDataResponse<T> = serde_json::from_str(&response_text)
                    .map_err(|error| {
                        error!("[OpenAI] Response serialization error: {}", &error);
                        anyhow!("Error: {}", error)
                    })?;
                Ok(response_deser.data)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::utils::get_tokenizer_old;
    use crate::OpenAIModels;

    #[test]
    fn it_computes_gpt3_5_tokenization() {
        let bpe = get_tokenizer_old(&OpenAIModels::Gpt4_32k).unwrap();
        let tokenized: Result<Vec<_>, _> = bpe
            .split_by_token_iter("This is a test         with a lot of spaces", true)
            .collect();
        let tokenized = tokenized.unwrap();
        assert_eq!(
            tokenized,
            vec!["This", " is", " a", " test", "        ", " with", " a", " lot", " of", " spaces"]
        );
    }

    #[test]
    fn test_gpt3_5turbo_max_requests() {
        let model = OpenAIModels::Gpt3_5Turbo;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(10_000, 2_000_000 / ((4096_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_gpt3_5turbo0613_max_requests() {
        let model = OpenAIModels::Gpt3_5Turbo0613;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(10_000, 2_000_000 / ((4096_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_gpt3_5turbo16k_max_requests() {
        let model = OpenAIModels::Gpt3_5Turbo16k;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(10_000, 2_000_000 / ((16384_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_gpt4_max_requests() {
        let model = OpenAIModels::Gpt4;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(10_000, 300_000 / ((8192_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }
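
    // Additional coverage for the remaining variants, mirroring the tests above and
    // using the tpm/rpm values hard-coded in `get_rate_limit` together with the
    // defaults in `default_max_tokens`.
    #[test]
    fn test_gpt4turbo_max_requests() {
        let model = OpenAIModels::Gpt4Turbo;
        let max_requests = model.get_max_requests();
        let expected_max =
            std::cmp::min(10_000, 2_000_000 / ((128_000_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }

    #[test]
    fn test_textdavinci003_max_requests() {
        let model = OpenAIModels::TextDavinci003;
        let max_requests = model.get_max_requests();
        let expected_max = std::cmp::min(3_000, 250_000 / ((4097_f64 * 0.5).ceil() as usize));
        assert_eq!(max_requests, expected_max);
    }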
}