1pub mod json_schema;
91
92use anyhow::Result;
93use dotenvy::dotenv;
94use fxhash::FxHashMap;
95use json_schema::JsonSchema;
96use serde::{Deserialize, Serialize};
97use std::env;
98
/// A single chat message.
///
/// Used both in requests (`ChatCompletionRequestBody::messages`) and in
/// responses (`Choice::message`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
    /// Role string of the message author, sent/received verbatim.
    pub role: String,
    /// Message text.
    ///
    /// NOTE(review): declared as a non-optional `String`, so a response whose
    /// `content` is JSON `null` cannot be deserialized — confirm this cannot
    /// happen for the models in use.
    pub content: String,
    /// Refusal text; omitted from serialized JSON when `None`.
    /// Presumably populated by the API when the model declines — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<String>,
}
106
107impl Message {
108 pub fn new(role: String, message: String) -> Self {
109 Self {
110 role: String::from(role),
111 content: String::from(message),
112 refusal: None,
113 }
114 }
115}
116
/// Value of the request's `response_format` parameter.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseFormat {
    /// Format type name; serialized under the JSON key `"type"` because
    /// `type` is a reserved word in Rust.
    #[serde(rename = "type")]
    pub type_name: String,
    /// Schema serialized under the JSON key `json_schema`.
    pub json_schema: JsonSchema,
}
123
124impl ResponseFormat {
125 pub fn new(type_name: String, json_schema: JsonSchema) -> Self {
126 Self {
127 type_name: String::from(type_name),
128 json_schema,
129 }
130 }
131}
132
/// JSON body for a `POST https://api.openai.com/v1/chat/completions` request.
///
/// Field names are serialized unchanged as JSON keys. Every `Option` field is
/// skipped during serialization when it is `None`, so only explicitly set
/// parameters are sent. See the OpenAI Chat Completions API reference for the
/// meaning and valid range of each parameter.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionRequestBody {
    /// Model identifier; `OpenAI::chat` refuses to send the request while this is empty.
    pub model: String,
    /// Conversation sent to the model; `OpenAI::chat` refuses to send the request while this is empty.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// NOTE(review): keys are presumably token IDs rendered as strings — confirm against the API reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<FxHashMap<String, i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
}
173
174impl ChatCompletionRequestBody {
175 pub fn new(
176 model_id: String,
177 messages: Vec<Message>,
178 store: Option<bool>,
179 frequency_penalty: Option<f32>,
180 logit_bias: Option<FxHashMap<String, i32>>,
181 logprobs: Option<bool>,
182 top_logprobs: Option<u8>,
183 max_completion_tokens: Option<u64>,
184 n: Option<u32>,
185 modalities: Option<Vec<String>>,
186 presence_penalty: Option<f32>,
187 temperature: Option<f32>,
188 response_format: Option<ResponseFormat>,
189 ) -> Self {
190 Self {
191 model: model_id,
192 messages,
193 store: if let Some(value) = store {
194 Option::from(value)
195 } else {
196 None
197 },
198 frequency_penalty: if let Some(value) = frequency_penalty {
199 Option::from(value)
200 } else {
201 None
202 },
203 logit_bias: if let Some(value) = logit_bias {
204 Option::from(value)
205 } else {
206 None
207 },
208 logprobs: if let Some(value) = logprobs {
209 Option::from(value)
210 } else {
211 None
212 },
213 top_logprobs: if let Some(value) = top_logprobs {
214 Option::from(value)
215 } else {
216 None
217 },
218 max_completion_tokens: if let Some(value) = max_completion_tokens {
219 Option::from(value)
220 } else {
221 None
222 },
223 n: if let Some(value) = n {
224 Option::from(value)
225 } else {
226 None
227 },
228 modalities: if let Some(value) = modalities {
229 Option::from(value)
230 } else {
231 None
232 },
233 presence_penalty: if let Some(value) = presence_penalty {
234 Option::from(value)
235 } else {
236 None
237 },
238 temperature: if let Some(value) = temperature {
239 Option::from(value)
240 } else {
241 None
242 },
243 response_format: if let Some(value) = response_format {
244 Option::from(value)
245 } else {
246 None
247 },
248 }
249 }
250
251 pub fn default() -> Self {
252 Self {
253 model: String::default(),
254 messages: Vec::new(),
255 store: None,
256 frequency_penalty: None,
257 logit_bias: None,
258 logprobs: None,
259 top_logprobs: None,
260 max_completion_tokens: None,
261 n: None,
262 modalities: None,
263 presence_penalty: None,
264 temperature: None,
265 response_format: None,
266 }
267 }
268}
269
/// One generated alternative within a [`Response`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
    /// Index of this choice, as reported by the API.
    pub index: u32,
    /// The generated message.
    pub message: Message,
    /// Raw finish-reason string from the API.
    ///
    /// NOTE(review): non-optional, so a JSON `null` here fails deserialization — confirm
    /// the non-streaming endpoint always sends a string.
    pub finish_reason: String,
}
276
/// Token accounting returned with a completion response.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
    /// Prompt token count reported by the API.
    pub prompt_tokens: u64,
    /// Completion token count reported by the API.
    pub completion_tokens: u64,
    /// Total token count reported by the API.
    pub total_tokens: u64,
    /// Breakdown of completion tokens, keyed by the detail name the API uses.
    ///
    /// NOTE(review): this field has no `#[serde(default)]`, so deserialization
    /// fails if the API omits `completion_tokens_details`, and every value must
    /// be an unsigned integer — confirm this holds for all models in use.
    pub completion_tokens_details: FxHashMap<String, u64>,
}
284
/// Parsed body of a successful Chat Completions response, as returned by `OpenAI::chat`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Response {
    /// Response identifier assigned by the API.
    pub id: String,
    /// Object type string reported by the API.
    pub object: String,
    /// Creation time reported by the API; presumably a Unix timestamp in seconds — confirm.
    pub created: u64,
    /// Model that produced the response, as reported by the API.
    pub model: String,
    /// NOTE(review): declared as a non-optional `String`; a JSON `null` here
    /// would make deserialization fail — confirm the API always sends a string.
    pub system_fingerprint: String,
    /// Generated alternatives; the count presumably follows the request's `n` — confirm.
    pub choices: Vec<Choice>,
    /// Token usage for this request.
    pub usage: Usage,
}
295
/// Minimal async client for the OpenAI Chat Completions endpoint.
///
/// Configure the request through the builder-style setters (or by editing
/// `completion_body` directly), then call `chat`.
pub struct OpenAI {
    /// Secret sent as a bearer token in the `Authorization` header; loaded from
    /// `OPENAI_API_KEY` by `new`. Kept private so it is not exposed to callers.
    api_key: String,
    /// Request body that `chat` serializes and sends.
    pub completion_body: ChatCompletionRequestBody,
}
300
301impl OpenAI {
302 pub fn new() -> Self {
303 dotenv().ok();
304 let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set.");
305 return Self {
306 api_key,
307 completion_body: ChatCompletionRequestBody::default(),
308 };
309 }
310
311 pub fn model_id(&mut self, model_id: String) -> &mut Self {
312 self.completion_body.model = String::from(model_id);
313 return self;
314 }
315
316 pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
317 self.completion_body.messages = messages;
318 return self;
319 }
320
321 pub fn store(&mut self, store: bool) -> &mut Self {
322 self.completion_body.store = Option::from(store);
323 return self;
324 }
325
326 pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
327 self.completion_body.frequency_penalty = Option::from(frequency_penalty);
328 return self;
329 }
330
331 pub fn logit_bias(&mut self, logit_bias: FxHashMap<String, i32>) -> &mut Self {
332 self.completion_body.logit_bias = Option::from(logit_bias);
333 return self;
334 }
335
336 pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
337 self.completion_body.logprobs = Option::from(logprobs);
338 return self;
339 }
340
341 pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
342 self.completion_body.top_logprobs = Option::from(top_logprobs);
343 return self;
344 }
345
346 pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
347 self.completion_body.max_completion_tokens = Option::from(max_completion_tokens);
348 return self;
349 }
350
351 pub fn n(&mut self, n: u32) -> &mut Self {
352 self.completion_body.n = Option::from(n);
353 return self;
354 }
355
356 pub fn modalities(&mut self, modalities: Vec<String>) -> &mut Self {
357 self.completion_body.modalities = Option::from(modalities);
358 return self;
359 }
360
361 pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
362 self.completion_body.presence_penalty = Option::from(presence_penalty);
363 return self;
364 }
365
366 pub fn temperature(&mut self, temperature: f32) -> &mut Self {
367 self.completion_body.temperature = Option::from(temperature);
368 return self;
369 }
370
371 pub fn response_format(&mut self, response_format: ResponseFormat) -> &mut Self {
372 self.completion_body.response_format = Option::from(response_format);
373 return self;
374 }
375
376 pub async fn chat(&mut self) -> Result<Response> {
377 if self.api_key.is_empty() {
379 return Err(anyhow::Error::msg("API key is not set."));
380 }
381 if self.completion_body.model.is_empty() {
382 return Err(anyhow::Error::msg("Model ID is not set."));
383 }
384 if self.completion_body.messages.is_empty() {
385 return Err(anyhow::Error::msg("Messages are not set."));
386 }
387
388 let body = serde_json::to_string(&self.completion_body)?;
389 let url = "https://api.openai.com/v1/chat/completions";
390
391 let client = request::Client::new();
392 let mut header = request::header::HeaderMap::new();
393 header.insert(
394 "Content-Type",
395 request::header::HeaderValue::from_static("application/json"),
396 );
397 header.insert(
398 "Authorization",
399 request::header::HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap(),
400 );
401 header.insert(
402 "User-Agent",
403 request::header::HeaderValue::from_static("openai-tools-rust/0.1.0"),
404 );
405 let response = client
406 .post(url)
407 .headers(header)
408 .body(body)
409 .send()
410 .await
411 .expect("Failed to send request");
412 if !response.status().is_success() {
413 let status = response.status();
414 let content = response.text().await.expect("Failed to read response text");
415 let e_msg = format!(
416 "Request failed with status: {} CONTENT: {}",
417 status, content
418 );
419 return Err(anyhow::Error::msg(e_msg));
420 }
421 let content = response.text().await.expect("Failed to read response text");
422
423 match serde_json::from_str::<Response>(&content) {
424 Ok(response) => return Ok(response),
425 Err(e) => {
426 let e_msg = format!("Failed to parse JSON: {} CONTENT: {}", e, content);
427 return Err(anyhow::Error::msg(e_msg));
428 }
429 }
430 }
431}
432
433#[cfg(test)]
434mod tests;