openai_tools/
lib.rs

//! # Usage
//! ## Chat Completion
//!
//! ### Simple Chat
//!
//! ```rust,no_run
//! # use openai_tools::{Message, OpenAI, Response};
//! # fn main() {
//!     let mut openai = OpenAI::new();
//!     let messages = vec![
//!         Message::new(String::from("user"), String::from("Hi there!"))
//!     ];
//!
//!     openai
//!         .model_id(String::from("gpt-4o-mini"))
//!         .messages(messages)
//!         .temperature(1.0);
//!
//!     let response: Response = openai.chat().unwrap();
//!     println!("{}", &response.choices[0].message.content);
//!     // Hello! How can I assist you today?
//! # }
//! ```
//!
//! ### Chat with JSON Schema
//!
//! ```rust,no_run
//! # use openai_tools::{json_schema::JsonSchema, Message, OpenAI, ResponseFormat};
//! # use serde::{Deserialize, Serialize};
//! # use serde_json;
//! # fn main() {
//!     #[derive(Debug, Serialize, Deserialize)]
//!     struct Weather {
//!         location: String,
//!         date: String,
//!         weather: String,
//!         error: String,
//!     }
//!
//!     let mut openai = OpenAI::new();
//!     let messages = vec![Message::new(
//!         String::from("user"),
//!         String::from("Hi there! How's the weather tomorrow in Tokyo? If you can't answer, report an error."),
//!     )];
//!
//!     // Build the JSON schema describing the expected reply.
//!     let mut json_schema = JsonSchema::new("weather".to_string());
//!     json_schema.add_property(
//!         String::from("location"),
//!         String::from("string"),
//!         Option::from(String::from("The location to check the weather for.")),
//!     );
//!     json_schema.add_property(
//!         String::from("date"),
//!         String::from("string"),
//!         Option::from(String::from("The date to check the weather for.")),
//!     );
//!     json_schema.add_property(
//!         String::from("weather"),
//!         String::from("string"),
//!         Option::from(String::from("The weather for the location and date.")),
//!     );
//!     json_schema.add_property(
//!         String::from("error"),
//!         String::from("string"),
//!         Option::from(String::from("Error message. If there is no error, leave this field empty.")),
//!     );
//!
//!     // Configure the chat completion request.
//!     openai
//!         .model_id(String::from("gpt-4o-mini"))
//!         .messages(messages)
//!         .temperature(1.0)
//!         .response_format(ResponseFormat::new(String::from("json_schema"), json_schema));
//!
//!     // Execute the chat request and parse the structured reply.
//!     let response = openai.chat().unwrap();
//!
//!     let answer: Weather = serde_json::from_str::<Weather>(&response.choices[0].message.content).unwrap();
//!     println!("{:?}", answer);
//!     // Weather {
//!     //     location: "Tokyo",
//!     //     date: "2023-10-01",
//!     //     weather: "Temperatures around 25°C with partly cloudy skies and a slight chance of rain.",
//!     //     error: "",
//!     // }
//! # }
//! ```
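/// Helpers for building the JSON schema payload used with structured outputs
/// (see the crate-level example above).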
pub mod json_schema;

use anyhow::Result;
use dotenvy::dotenv;
use fxhash::FxHashMap;
use json_schema::JsonSchema;
use serde::{Deserialize, Serialize};
use std::env;
use std::process::Command;

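/// A single chat message, with a role ("system", "user", "assistant", ...) and text content.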
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
    pub role: String,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<String>,
}

impl Message {
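    /// Creates a message with the given role and content; `refusal` starts as `None`.
    ///
    /// ```rust
    /// # use openai_tools::Message;
    /// let msg = Message::new(String::from("user"), String::from("Hi there!"));
    /// assert_eq!(msg.role, "user");
    /// assert_eq!(msg.content, "Hi there!");
    /// ```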
    pub fn new(role: String, content: String) -> Self {
        Self {
            role,
            content,
            refusal: None,
        }
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseFormat {
    #[serde(rename = "type")]
    pub type_name: String,
    pub json_schema: JsonSchema,
}

impl ResponseFormat {
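    /// Builds a response format; for structured outputs the type name is
    /// "json_schema", as in the crate-level example.
    ///
    /// ```rust
    /// # use openai_tools::{json_schema::JsonSchema, ResponseFormat};
    /// let schema = JsonSchema::new("weather".to_string());
    /// let format = ResponseFormat::new(String::from("json_schema"), schema);
    /// assert_eq!(format.type_name, "json_schema");
    /// ```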
    pub fn new(type_name: String, json_schema: JsonSchema) -> Self {
        Self {
            type_name,
            json_schema,
        }
    }
}

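/// Request body for the chat completions endpoint.
///
/// Optional fields left as `None` are skipped during serialization, so a minimal request carries
/// only `model` and `messages`. A small sketch of the resulting payload:
///
/// ```rust
/// # use openai_tools::{ChatCompletionRequestBody, Message};
/// let mut body = ChatCompletionRequestBody::default();
/// body.model = String::from("gpt-4o-mini");
/// body.messages = vec![Message::new(String::from("user"), String::from("Hi!"))];
/// assert_eq!(
///     serde_json::to_string(&body).unwrap(),
///     r#"{"model":"gpt-4o-mini","messages":[{"role":"user","content":"Hi!"}]}"#
/// );
/// ```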
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct ChatCompletionRequestBody {
    /// ID of the model to use. (<https://platform.openai.com/docs/models#model-endpoint-compatibility>)
    pub model: String,
    /// A list of messages comprising the conversation so far.
    pub messages: Vec<Message>,
    /// Whether or not to store the output of this chat completion request. `false` by default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
    /// -2.0 to 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modifies the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens to an associated bias value from -100 to 100.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<FxHashMap<String, i32>>,
    /// Whether to return log probabilities of the output tokens or not.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    /// 0 to 20. The number of most likely tokens to return at each token position, each with an associated log probability.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
    /// An upper bound for the number of tokens that can be generated for a completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// How many chat completion choices to generate for each input message. 1 by default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Output types that you would like the model to generate for this request. ["text"] for most models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,
    /// -2.0 to 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// 0 to 2. What sampling temperature to use. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An object specifying the format that the model must output. (<https://platform.openai.com/docs/guides/structured-outputs>)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
}

impl ChatCompletionRequestBody {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model_id: String,
        messages: Vec<Message>,
        store: Option<bool>,
        frequency_penalty: Option<f32>,
        logit_bias: Option<FxHashMap<String, i32>>,
        logprobs: Option<bool>,
        top_logprobs: Option<u8>,
        max_completion_tokens: Option<u64>,
        n: Option<u32>,
        modalities: Option<Vec<String>>,
        presence_penalty: Option<f32>,
        temperature: Option<f32>,
        response_format: Option<ResponseFormat>,
    ) -> Self {
        // The `Option` parameters are stored as-is; re-wrapping each one in
        // `if let Some(...)` would be a no-op.
        Self {
            model: model_id,
            messages,
            store,
            frequency_penalty,
            logit_bias,
            logprobs,
            top_logprobs,
            max_completion_tokens,
            n,
            modalities,
            presence_penalty,
            temperature,
            response_format,
        }
    }
}

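/// One generated chat completion choice; the reply text is in `message.content`.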
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
    pub index: u32,
    pub message: Message,
    pub finish_reason: String,
}

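/// Token usage statistics reported by the API for a single request.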
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
    pub completion_tokens_details: FxHashMap<String, u64>,
}

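/// A parsed chat completions response.
///
/// A sketch of deserializing the JSON shape this struct expects (field values are illustrative):
///
/// ```rust
/// # use openai_tools::Response;
/// let raw = r#"{
///     "id": "chatcmpl-123", "object": "chat.completion", "created": 1700000000,
///     "model": "gpt-4o-mini", "system_fingerprint": "fp_abc123",
///     "choices": [{"index": 0,
///                  "message": {"role": "assistant", "content": "Hello!"},
///                  "finish_reason": "stop"}],
///     "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12,
///               "completion_tokens_details": {"reasoning_tokens": 0}}
/// }"#;
/// let response: Response = serde_json::from_str(raw).unwrap();
/// assert_eq!(response.choices[0].message.content, "Hello!");
/// ```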
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

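/// A minimal chat completions client: holds the API key and a request body
/// filled in through the builder-style setters below.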
pub struct OpenAI {
    api_key: String,
    pub completion_body: ChatCompletionRequestBody,
}

impl OpenAI {
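    /// Creates a client, reading `OPENAI_API_KEY` from the environment (a
    /// `.env` file is loaded first via `dotenvy` if present). A missing key is
    /// reported as an error when `chat()` is called.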
    pub fn new() -> Self {
        // Load a `.env` file if present, then read the key from the environment.
        // An empty key is caught by `chat()` rather than panicking here.
        dotenv().ok();
        let api_key = env::var("OPENAI_API_KEY").unwrap_or_default();
        Self {
            api_key,
            completion_body: ChatCompletionRequestBody::default(),
        }
    }

    pub fn model_id(&mut self, model_id: String) -> &mut Self {
        self.completion_body.model = model_id;
        self
    }

    pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
        self.completion_body.messages = messages;
        self
    }

    pub fn store(&mut self, store: bool) -> &mut Self {
        self.completion_body.store = Some(store);
        self
    }

    pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
        self.completion_body.frequency_penalty = Some(frequency_penalty);
        self
    }

    pub fn logit_bias(&mut self, logit_bias: FxHashMap<String, i32>) -> &mut Self {
        self.completion_body.logit_bias = Some(logit_bias);
        self
    }

    pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
        self.completion_body.logprobs = Some(logprobs);
        self
    }

    pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
        self.completion_body.top_logprobs = Some(top_logprobs);
        self
    }

    pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
        self.completion_body.max_completion_tokens = Some(max_completion_tokens);
        self
    }

    pub fn n(&mut self, n: u32) -> &mut Self {
        self.completion_body.n = Some(n);
        self
    }

    pub fn modalities(&mut self, modalities: Vec<String>) -> &mut Self {
        self.completion_body.modalities = Some(modalities);
        self
    }

    pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
        self.completion_body.presence_penalty = Some(presence_penalty);
        self
    }

    pub fn temperature(&mut self, temperature: f32) -> &mut Self {
        self.completion_body.temperature = Some(temperature);
        self
    }

    pub fn response_format(&mut self, response_format: ResponseFormat) -> &mut Self {
        self.completion_body.response_format = Some(response_format);
        self
    }

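    /// Sends the configured request to `https://api.openai.com/v1/chat/completions`
    /// by shelling out to `curl` and parses the JSON reply.
    ///
    /// Errors if the API key, model ID, or messages are missing, if `curl`
    /// cannot be run, or if the response body is not valid JSON. The example is
    /// marked `no_run` because it performs a live API call:
    ///
    /// ```rust,no_run
    /// # use openai_tools::{Message, OpenAI};
    /// let mut openai = OpenAI::new();
    /// openai
    ///     .model_id(String::from("gpt-4o-mini"))
    ///     .messages(vec![Message::new(String::from("user"), String::from("Hi!"))]);
    /// let response = openai.chat().unwrap();
    /// println!("{}", response.choices[0].message.content);
    /// ```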
    pub fn chat(&mut self) -> Result<Response> {
        // Check that the API key is set and the request body is complete.
        if self.api_key.is_empty() {
            return Err(anyhow::Error::msg("API key is not set."));
        }
        if self.completion_body.model.is_empty() {
            return Err(anyhow::Error::msg("Model ID is not set."));
        }
        if self.completion_body.messages.is_empty() {
            return Err(anyhow::Error::msg("Messages are not set."));
        }

        let body = serde_json::to_string(&self.completion_body)?;
        let url = "https://api.openai.com/v1/chat/completions";
        // `?` propagates a failure to spawn `curl` instead of panicking.
        let output = Command::new("curl")
            .arg(url)
            .arg("-H")
            .arg("Content-Type: application/json")
            .arg("-H")
            .arg(format!("Authorization: Bearer {}", self.api_key))
            .arg("-d")
            .arg(body)
            .output()?;

        let content = String::from_utf8_lossy(&output.stdout).to_string();

        match serde_json::from_str::<Response>(&content) {
            Ok(response) => Ok(response),
            Err(e) => Err(anyhow::Error::msg(format!(
                "Failed to parse JSON: {} CONTENT: {}",
                e, content
            ))),
        }
    }
}

#[cfg(test)]
mod tests;