1use anyhow::{anyhow, Result};
2use log::{error, info, warn};
3use reqwest::{header, Client};
4use schemars::{schema_for, JsonSchema};
5use serde::{de::DeserializeOwned, Serialize};
6use serde_json::Value;
7
8use crate::{domain::OpenAIDataResponse, models::OpenAIModels, utils::get_tokenizer};
9
/// Client for the OpenAI chat/completions API.
///
/// Holds the model selection, the overall token budget, optional serialized
/// context to embed in prompts, and the API key used for bearer auth.
pub struct OpenAI {
    // Model to call; also supplies defaults (max tokens, function-calling,
    // endpoint, base instructions) via its methods.
    model: OpenAIModels,
    // Total token budget shared between the prompt and the response.
    max_tokens: usize,
    // Sampling temperature forwarded to the API (0 when not specified).
    temperature: u32,
    // Accumulated context built by `set_context`, formatted as
    // "name: <json>" entries separated by blank lines.
    input_json: Option<String>,
    // When true, request bodies and raw API responses are logged.
    debug: bool,
    // Whether requests use the model's function-calling format.
    function_call: bool,
    // API key sent as a bearer token on every request.
    api_key: String,
}
25
26impl OpenAI {
27 pub fn new(
29 open_ai_key: &str,
30 model: OpenAIModels,
31 max_tokens: Option<usize>,
32 temperature: Option<u32>,
33 ) -> Self {
34 OpenAI {
35 max_tokens: max_tokens.unwrap_or_else(|| model.default_max_tokens()),
37 function_call: model.function_call_default(),
38 model,
39 temperature: temperature.unwrap_or(0u32), input_json: None,
41 debug: false,
42 api_key: open_ai_key.to_string(),
43 }
44 }
45
46 pub fn debug(mut self) -> Self {
50 self.debug = true;
51 self
52 }
53
54 pub fn function_calling(mut self, function_call: bool) -> Self {
58 self.function_call = function_call;
59 self
60 }
61
62 pub fn set_context<T: Serialize>(mut self, input_name: &str, input_data: &T) -> Result<Self> {
68 let input_json = if let Ok(json) = serde_json::to_string(&input_data) {
69 json
70 } else {
71 return Err(anyhow!("Unable serialize provided input data."));
72 };
73 let line_break = match self.input_json {
74 Some(_) => "\n\n".to_string(),
75 None => "".to_string(),
76 };
77 let new_json = format!(
78 "{}{}{}: {}",
79 self.input_json.unwrap_or_default(),
80 line_break,
81 input_name,
82 input_json,
83 );
84 self.input_json = Some(new_json);
85 Ok(self)
86 }
87
88 pub fn check_prompt_tokens<T: JsonSchema + DeserializeOwned>(
93 &self,
94 instructions: &str,
95 ) -> Result<usize> {
96 let schema = schema_for!(T);
98 let json_value: Value = serde_json::to_value(&schema)?;
99
100 let prompt = format!(
101 "Instructions:
102 {instructions}
103
104 Input data:
105 {input_json}
106
107 Respond ONLY with the data portion of a valid Json object. No schema definition required. No other words.",
108 instructions = instructions,
109 input_json = self.input_json.clone().unwrap_or_default(),
110 );
111
112 let full_prompt = format!(
113 "{}{}{}",
114 self.model.get_base_instructions(Some(self.function_call)),
116 prompt,
118 serde_json::to_string(&json_value).unwrap_or_default()
120 );
121
122 let bpe = get_tokenizer(&self.model)?;
124 let prompt_tokens = bpe.encode_with_special_tokens(&full_prompt).len();
125
126 Ok((prompt_tokens as f64 * 1.05) as usize)
128 }
129
130 async fn call_openai_api(&self, body: &serde_json::Value) -> Result<String> {
136 let model_url = self.model.get_endpoint();
138
139 let client = Client::new();
141
142 let response = client
143 .post(model_url)
144 .header(header::CONTENT_TYPE, "application/json")
145 .bearer_auth(&self.api_key)
146 .json(&body)
147 .send()
148 .await?;
149
150 let response_status = response.status();
151 let response_text = response.text().await?;
152
153 if self.debug {
154 info!(
155 "[debug] OpenAI API response: [{}] {:#?}",
156 &response_status, &response_text
157 );
158 }
159
160 Ok(response_text)
161 }
162
163 pub async fn get_answer<T: JsonSchema + DeserializeOwned>(
169 self,
170 instructions: &str,
171 ) -> Result<T> {
172 let schema = schema_for!(T);
174 let json_value: Value = serde_json::to_value(&schema)?;
175
176 let prompt = format!(
177 "Instructions:
178 {instructions}
179
180 Input data:
181 {input_json}
182
183 Respond ONLY with the data portion of a valid Json object. No schema definition required. No other words.",
184 instructions = instructions,
185 input_json = self.input_json.clone().unwrap_or_default(),
186 );
187
188 let prompt_tokens = self
190 .check_prompt_tokens::<T>(instructions)
191 .unwrap_or_default();
192
193 if prompt_tokens >= self.max_tokens {
194 return Err(anyhow!(
195 "The provided prompt requires more tokens than allocated."
196 ));
197 }
198 let response_tokens = self.max_tokens - prompt_tokens;
199
200 if prompt_tokens * 2 >= self.max_tokens {
203 warn!(
204 "{} tokens remaining for response: {} allocated, {} used for prompt",
205 response_tokens.to_string(),
206 self.max_tokens.to_string(),
207 prompt_tokens.to_string(),
208 );
209 };
210
211 let model_body = self.model.get_body(
213 &prompt,
214 &json_value,
215 self.function_call,
216 &response_tokens,
217 &self.temperature,
218 );
219
220 if self.debug {
222 info!("[debug] Model body: {:#?}", model_body);
223 info!(
224 "[debug] Prompt accounts for approx {} tokens, leaving {} tokens for answer.",
225 prompt_tokens.to_string(),
226 response_tokens.to_string(),
227 );
228 }
229
230 let response_text = self.call_openai_api(&model_body).await?;
231
232 let response_string = self.model.get_data(&response_text, self.function_call)?;
234
235 if self.debug {
236 info!("[debug] OpenAI response data: {}", response_string);
237 }
238 let response_deser: anyhow::Result<T, anyhow::Error> =
240 serde_json::from_str(&response_string).map_err(|error| {
241 error!("[OpenAI] Response serialization error: {}", &error);
242 anyhow!("Error: {}", error)
243 });
244 if let Err(_e) = response_deser {
246 let response_deser: OpenAIDataResponse<T> = serde_json::from_str(&response_text)
247 .map_err(|error| {
248 error!("[OpenAI] Response serialization error: {}", &error);
249 anyhow!("Error: {}", error)
250 })?;
251 Ok(response_deser.data)
252 } else {
253 Ok(response_deser.unwrap())
254 }
255 }
256}