openai_tools/
lib.rs

//! # Usage
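//!
//! ## Setup
//!
//! `OpenAI::new()` loads environment variables via `dotenvy` and reads the API
//! key from `OPENAI_API_KEY`, so either export the variable or put it in a
//! `.env` file at the project root:
//!
//! ```text
//! OPENAI_API_KEY=<your-api-key>
//! ```
//!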
//! ## Chat Completion
//!
//! ### Simple Chat
//!
//! ```rust,no_run
//! # use openai_tools::{Message, OpenAI, Response};
//! # #[tokio::main]
//! # async fn main() {
//!     let mut openai = OpenAI::new();
//!     let messages = vec![
//!         Message::new(String::from("user"), String::from("Hi there!"))
//!     ];
//!
//!     openai
//!         .model_id(String::from("gpt-4o-mini"))
//!         .messages(messages)
//!         .temperature(1.0);
//!
//!     let response: Response = openai.chat().await.unwrap();
//!     println!("{}", &response.choices[0].message.content);
//!     // Hello! How can I assist you today?
//! # }
//! ```
//!
//! ### Chat with JSON Schema
//!
//! ```rust,no_run
//! # use openai_tools::{json_schema::JsonSchema, Message, OpenAI, ResponseFormat};
//! # use serde::{Deserialize, Serialize};
//! # #[tokio::main]
//! # async fn main() {
//!     #[derive(Debug, Serialize, Deserialize)]
//!     struct Weather {
//!         location: String,
//!         date: String,
//!         weather: String,
//!         error: String,
//!     }
//!
//!     let mut openai = OpenAI::new();
//!     let messages = vec![Message::new(
//!         String::from("user"),
//!         String::from("Hi there! How's the weather tomorrow in Tokyo? If you can't answer, report an error."),
//!     )];
//!
//!     // Build the JSON schema.
//!     let mut json_schema = JsonSchema::new("weather".to_string());
//!     json_schema.add_property(
//!         String::from("location"),
//!         String::from("string"),
//!         Option::from(String::from("The location to check the weather for.")),
//!     );
//!     json_schema.add_property(
//!         String::from("date"),
//!         String::from("string"),
//!         Option::from(String::from("The date to check the weather for.")),
//!     );
//!     json_schema.add_property(
//!         String::from("weather"),
//!         String::from("string"),
//!         Option::from(String::from("The weather for the location and date.")),
//!     );
//!     json_schema.add_property(
//!         String::from("error"),
//!         String::from("string"),
//!         Option::from(String::from("Error message. If there is no error, leave this field empty.")),
//!     );
//!
//!     // Configure the chat completion model.
//!     openai
//!         .model_id(String::from("gpt-4o-mini"))
//!         .messages(messages)
//!         .temperature(1.0)
//!         .response_format(ResponseFormat::new(String::from("json_schema"), json_schema));
//!
//!     // Execute the chat.
//!     let response = openai.chat().await.unwrap();
//!
//!     let answer: Weather = serde_json::from_str(&response.choices[0].message.content).unwrap();
//!     println!("{:?}", answer);
//!     // Weather {
//!     //     location: "Tokyo",
//!     //     date: "2023-10-01",
//!     //     weather: "Temperatures around 25°C with partly cloudy skies and a slight chance of rain.",
//!     //     error: "",
//!     // }
//! # }
//! ```
pub mod json_schema;

use anyhow::Result;
use dotenvy::dotenv;
use fxhash::FxHashMap;
use json_schema::JsonSchema;
use serde::{Deserialize, Serialize};
use std::env;

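/// A single message in a chat conversation: a role (e.g. "system", "user",
/// or "assistant") and its text content.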
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
    pub role: String,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<String>,
}

impl Message {
    pub fn new(role: String, content: String) -> Self {
        Self {
            role,
            content,
            refusal: None,
        }
    }
}

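/// The `response_format` object sent with a request; pairs a format type
/// (e.g. "json_schema") with the schema the model output must follow.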
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseFormat {
    #[serde(rename = "type")]
    pub type_name: String,
    pub json_schema: JsonSchema,
}

impl ResponseFormat {
    pub fn new(type_name: String, json_schema: JsonSchema) -> Self {
        Self {
            type_name,
            json_schema,
        }
    }
}

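/// Request body for the chat completions endpoint. Optional fields are
/// omitted from the serialized JSON when unset.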
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionRequestBody {
    // ID of the model to use. (https://platform.openai.com/docs/models#model-endpoint-compatibility)
    pub model: String,
    // A list of messages comprising the conversation so far.
    pub messages: Vec<Message>,
    // Whether or not to store the output of this chat completion request for the user. false by default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
    // -2.0 ~ 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    // Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens to an associated bias value from -100 to 100.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<FxHashMap<String, i32>>,
    // Whether to return log probabilities of the output tokens or not.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    // 0 ~ 20. Specify the number of most likely tokens to return at each token position, each with an associated log probability.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
    // An upper bound for the number of tokens that can be generated for a completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    // How many chat completion choices to generate for each input message. 1 by default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    // Output types that you would like the model to generate for this request. ["text"] for most models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,
    // -2.0 ~ 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    // 0 ~ 2. What sampling temperature to use. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    // An object specifying the format that the model must output. (https://platform.openai.com/docs/guides/structured-outputs)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
}

impl ChatCompletionRequestBody {
    pub fn new(
        model_id: String,
        messages: Vec<Message>,
        store: Option<bool>,
        frequency_penalty: Option<f32>,
        logit_bias: Option<FxHashMap<String, i32>>,
        logprobs: Option<bool>,
        top_logprobs: Option<u8>,
        max_completion_tokens: Option<u64>,
        n: Option<u32>,
        modalities: Option<Vec<String>>,
        presence_penalty: Option<f32>,
        temperature: Option<f32>,
        response_format: Option<ResponseFormat>,
    ) -> Self {
        Self {
            model: model_id,
            messages,
            store,
            frequency_penalty,
            logit_bias,
            logprobs,
            top_logprobs,
            max_completion_tokens,
            n,
            modalities,
            presence_penalty,
            temperature,
            response_format,
        }
    }
}

impl Default for ChatCompletionRequestBody {
    fn default() -> Self {
        Self {
            model: String::default(),
            messages: Vec::new(),
            store: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            top_logprobs: None,
            max_completion_tokens: None,
            n: None,
            modalities: None,
            presence_penalty: None,
            temperature: None,
            response_format: None,
        }
    }
}

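/// A single completion choice returned by the API.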
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
    pub index: u32,
    pub message: Message,
    pub finish_reason: String,
}

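/// Token usage accounting for a request/response pair.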
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
    pub completion_tokens_details: FxHashMap<String, u64>,
}

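/// A chat completion response returned by the API.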
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

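/// Client for the OpenAI chat completions API. Reads `OPENAI_API_KEY` from
/// the environment (a `.env` file is honored via `dotenvy`) and accumulates
/// request settings through the builder-style setters below.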
pub struct OpenAI {
    api_key: String,
    pub completion_body: ChatCompletionRequestBody,
}

impl OpenAI {
    pub fn new() -> Self {
        dotenv().ok();
        let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set.");
        Self {
            api_key,
            completion_body: ChatCompletionRequestBody::default(),
        }
    }

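    // Builder-style setters: each records a value on `completion_body` and
    // returns `&mut Self` so calls can be chained.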
    pub fn model_id(&mut self, model_id: String) -> &mut Self {
        self.completion_body.model = model_id;
        self
    }

    pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
        self.completion_body.messages = messages;
        self
    }

    pub fn store(&mut self, store: bool) -> &mut Self {
        self.completion_body.store = Some(store);
        self
    }

    pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
        self.completion_body.frequency_penalty = Some(frequency_penalty);
        self
    }

    pub fn logit_bias(&mut self, logit_bias: FxHashMap<String, i32>) -> &mut Self {
        self.completion_body.logit_bias = Some(logit_bias);
        self
    }

    pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
        self.completion_body.logprobs = Some(logprobs);
        self
    }

    pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
        self.completion_body.top_logprobs = Some(top_logprobs);
        self
    }

    pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
        self.completion_body.max_completion_tokens = Some(max_completion_tokens);
        self
    }

    pub fn n(&mut self, n: u32) -> &mut Self {
        self.completion_body.n = Some(n);
        self
    }

    pub fn modalities(&mut self, modalities: Vec<String>) -> &mut Self {
        self.completion_body.modalities = Some(modalities);
        self
    }

    pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
        self.completion_body.presence_penalty = Some(presence_penalty);
        self
    }

    pub fn temperature(&mut self, temperature: f32) -> &mut Self {
        self.completion_body.temperature = Some(temperature);
        self
    }

    pub fn response_format(&mut self, response_format: ResponseFormat) -> &mut Self {
        self.completion_body.response_format = Some(response_format);
        self
    }

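    /// Sends the accumulated request to the chat completions endpoint and
    /// deserializes the reply into a [`Response`]. Returns an error if the
    /// API key, model ID, or messages are missing, if the HTTP request fails,
    /// or if the response body cannot be parsed.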
    pub async fn chat(&mut self) -> Result<Response> {
        // Check that the API key is set and the body is fully built.
        if self.api_key.is_empty() {
            return Err(anyhow::Error::msg("API key is not set."));
        }
        if self.completion_body.model.is_empty() {
            return Err(anyhow::Error::msg("Model ID is not set."));
        }
        if self.completion_body.messages.is_empty() {
            return Err(anyhow::Error::msg("Messages are not set."));
        }

        let body = serde_json::to_string(&self.completion_body)?;
        let url = "https://api.openai.com/v1/chat/completions";

        let client = reqwest::Client::new();
        let mut header = reqwest::header::HeaderMap::new();
        header.insert(
            "Content-Type",
            reqwest::header::HeaderValue::from_static("application/json"),
        );
        header.insert(
            "Authorization",
            reqwest::header::HeaderValue::from_str(&format!("Bearer {}", self.api_key))?,
        );
        header.insert(
            "User-Agent",
            reqwest::header::HeaderValue::from_static("openai-tools-rust/0.1.0"),
        );
        // Propagate network and I/O errors instead of panicking.
        let response = client.post(url).headers(header).body(body).send().await?;
        if !response.status().is_success() {
            let status = response.status();
            let content = response.text().await?;
            return Err(anyhow::Error::msg(format!(
                "Request failed with status: {} CONTENT: {}",
                status, content
            )));
        }
        let content = response.text().await?;

        match serde_json::from_str::<Response>(&content) {
            Ok(response) => Ok(response),
            Err(e) => Err(anyhow::Error::msg(format!(
                "Failed to parse JSON: {} CONTENT: {}",
                e, content
            ))),
        }
    }
}

#[cfg(test)]
mod tests;