1pub mod json_schema;
91
92use anyhow::Result;
93use dotenvy::dotenv;
94use fxhash::FxHashMap;
95use json_schema::JsonSchema;
96use serde::{Deserialize, Serialize};
97use std::env;
98use std::process::Command;
99
100#[derive(Debug, Clone, Deserialize, Serialize)]
101pub struct Message {
102 pub role: String,
103 pub content: String,
104 #[serde(skip_serializing_if = "Option::is_none")]
105 pub refusal: Option<String>,
106}
107
108impl Message {
109 pub fn new(role: String, message: String) -> Self {
110 Self {
111 role: String::from(role),
112 content: String::from(message),
113 refusal: None,
114 }
115 }
116}
117
118#[derive(Debug, Clone, Deserialize, Serialize)]
119pub struct ResponseFormat {
120 #[serde(rename = "type")]
121 pub type_name: String,
122 pub json_schema: JsonSchema,
123}
124
125impl ResponseFormat {
126 pub fn new(type_name: String, json_schema: JsonSchema) -> Self {
127 Self {
128 type_name: String::from(type_name),
129 json_schema,
130 }
131 }
132}
133
134#[derive(Debug, Clone, Deserialize, Serialize)]
135pub struct ChatCompletionRequestBody {
136 pub model: String,
138 pub messages: Vec<Message>,
140 #[serde(skip_serializing_if = "Option::is_none")]
142 pub store: Option<bool>,
143 #[serde(skip_serializing_if = "Option::is_none")]
145 pub frequency_penalty: Option<f32>,
146 #[serde(skip_serializing_if = "Option::is_none")]
148 pub logit_bias: Option<FxHashMap<String, i32>>,
149 #[serde(skip_serializing_if = "Option::is_none")]
151 pub logprobs: Option<bool>,
152 #[serde(skip_serializing_if = "Option::is_none")]
154 pub top_logprobs: Option<u8>,
155 #[serde(skip_serializing_if = "Option::is_none")]
157 pub max_completion_tokens: Option<u64>,
158 #[serde(skip_serializing_if = "Option::is_none")]
160 pub n: Option<u32>,
161 #[serde(skip_serializing_if = "Option::is_none")]
163 pub modalities: Option<Vec<String>>,
164 #[serde(skip_serializing_if = "Option::is_none")]
166 pub presence_penalty: Option<f32>,
167 #[serde(skip_serializing_if = "Option::is_none")]
169 pub temperature: Option<f32>,
170 #[serde(skip_serializing_if = "Option::is_none")]
172 pub response_format: Option<ResponseFormat>,
173}
174
175impl ChatCompletionRequestBody {
176 pub fn new(
177 model_id: String,
178 messages: Vec<Message>,
179 store: Option<bool>,
180 frequency_penalty: Option<f32>,
181 logit_bias: Option<FxHashMap<String, i32>>,
182 logprobs: Option<bool>,
183 top_logprobs: Option<u8>,
184 max_completion_tokens: Option<u64>,
185 n: Option<u32>,
186 modalities: Option<Vec<String>>,
187 presence_penalty: Option<f32>,
188 temperature: Option<f32>,
189 response_format: Option<ResponseFormat>,
190 ) -> Self {
191 Self {
192 model: model_id,
193 messages,
194 store: if let Some(value) = store {
195 Option::from(value)
196 } else {
197 None
198 },
199 frequency_penalty: if let Some(value) = frequency_penalty {
200 Option::from(value)
201 } else {
202 None
203 },
204 logit_bias: if let Some(value) = logit_bias {
205 Option::from(value)
206 } else {
207 None
208 },
209 logprobs: if let Some(value) = logprobs {
210 Option::from(value)
211 } else {
212 None
213 },
214 top_logprobs: if let Some(value) = top_logprobs {
215 Option::from(value)
216 } else {
217 None
218 },
219 max_completion_tokens: if let Some(value) = max_completion_tokens {
220 Option::from(value)
221 } else {
222 None
223 },
224 n: if let Some(value) = n {
225 Option::from(value)
226 } else {
227 None
228 },
229 modalities: if let Some(value) = modalities {
230 Option::from(value)
231 } else {
232 None
233 },
234 presence_penalty: if let Some(value) = presence_penalty {
235 Option::from(value)
236 } else {
237 None
238 },
239 temperature: if let Some(value) = temperature {
240 Option::from(value)
241 } else {
242 None
243 },
244 response_format: if let Some(value) = response_format {
245 Option::from(value)
246 } else {
247 None
248 },
249 }
250 }
251
252 pub fn default() -> Self {
253 Self {
254 model: String::default(),
255 messages: Vec::new(),
256 store: None,
257 frequency_penalty: None,
258 logit_bias: None,
259 logprobs: None,
260 top_logprobs: None,
261 max_completion_tokens: None,
262 n: None,
263 modalities: None,
264 presence_penalty: None,
265 temperature: None,
266 response_format: None,
267 }
268 }
269}
270
271#[derive(Debug, Clone, Deserialize, Serialize)]
272pub struct Choice {
273 pub index: u32,
274 pub message: Message,
275 pub finish_reason: String,
276}
277
278#[derive(Debug, Clone, Deserialize, Serialize)]
279pub struct Usage {
280 pub prompt_tokens: u64,
281 pub completion_tokens: u64,
282 pub total_tokens: u64,
283 pub completion_tokens_details: FxHashMap<String, u64>,
284}
285
286#[derive(Debug, Clone, Deserialize, Serialize)]
287pub struct Response {
288 pub id: String,
289 pub object: String,
290 pub created: u64,
291 pub model: String,
292 pub system_fingerprint: String,
293 pub choices: Vec<Choice>,
294 pub usage: Usage,
295}
296
297pub struct OpenAI {
298 api_key: String,
299 pub completion_body: ChatCompletionRequestBody,
300}
301
302impl OpenAI {
303 pub fn new() -> Self {
304 dotenv().ok();
305 let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set.");
306 return Self {
307 api_key,
308 completion_body: ChatCompletionRequestBody::default(),
309 };
310 }
311
312 pub fn model_id(&mut self, model_id: String) -> &mut Self {
313 self.completion_body.model = String::from(model_id);
314 return self;
315 }
316
317 pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
318 self.completion_body.messages = messages;
319 return self;
320 }
321
322 pub fn store(&mut self, store: bool) -> &mut Self {
323 self.completion_body.store = Option::from(store);
324 return self;
325 }
326
327 pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
328 self.completion_body.frequency_penalty = Option::from(frequency_penalty);
329 return self;
330 }
331
332 pub fn logit_bias(&mut self, logit_bias: FxHashMap<String, i32>) -> &mut Self {
333 self.completion_body.logit_bias = Option::from(logit_bias);
334 return self;
335 }
336
337 pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
338 self.completion_body.logprobs = Option::from(logprobs);
339 return self;
340 }
341
342 pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
343 self.completion_body.top_logprobs = Option::from(top_logprobs);
344 return self;
345 }
346
347 pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
348 self.completion_body.max_completion_tokens = Option::from(max_completion_tokens);
349 return self;
350 }
351
352 pub fn n(&mut self, n: u32) -> &mut Self {
353 self.completion_body.n = Option::from(n);
354 return self;
355 }
356
357 pub fn modalities(&mut self, modalities: Vec<String>) -> &mut Self {
358 self.completion_body.modalities = Option::from(modalities);
359 return self;
360 }
361
362 pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
363 self.completion_body.presence_penalty = Option::from(presence_penalty);
364 return self;
365 }
366
367 pub fn temperature(&mut self, temperature: f32) -> &mut Self {
368 self.completion_body.temperature = Option::from(temperature);
369 return self;
370 }
371
372 pub fn response_format(&mut self, response_format: ResponseFormat) -> &mut Self {
373 self.completion_body.response_format = Option::from(response_format);
374 return self;
375 }
376
377 pub fn chat(&mut self) -> Result<Response> {
378 if self.api_key.is_empty() {
380 return Err(anyhow::Error::msg("API key is not set."));
381 }
382 if self.completion_body.model.is_empty() {
383 return Err(anyhow::Error::msg("Model ID is not set."));
384 }
385 if self.completion_body.messages.is_empty() {
386 return Err(anyhow::Error::msg("Messages are not set."));
387 }
388
389 let body = serde_json::to_string(&self.completion_body)?;
390 let url = "https://api.openai.com/v1/chat/completions";
391 let cmd = Command::new("curl")
392 .arg(url)
393 .arg("-H")
394 .arg("Content-Type: application/json")
395 .arg("-H")
396 .arg(format!("Authorization: Bearer {}", self.api_key))
397 .arg("-d")
398 .arg(body)
399 .output()
400 .expect("Failed to execute command");
401
402 let content = String::from_utf8_lossy(&cmd.stdout).to_string();
403
404 match serde_json::from_str::<Response>(&content) {
405 Ok(response) => return Ok(response),
406 Err(e) => {
407 let e_msg = format!("Failed to parse JSON: {} CONTENT: {}", e, content);
408 return Err(anyhow::Error::msg(e_msg));
409 }
410 }
411 }
412}
413
414#[cfg(test)]
415mod tests;