aionic/openai/mod.rs
1pub mod audio;
2pub mod chat;
3pub mod embeddings;
4pub mod files;
5pub mod fine_tunes;
6pub mod image;
7mod misc;
8pub mod moderations;
9
10pub use audio::{Audio, Response as AudioResponse, ResponseFormat as AudioResponseFormat};
11
12pub use chat::{Chat, Message, MessageRole};
13use chat::{Response, StreamedReponse};
14pub use embeddings::{Embedding, InputType, Response as EmbeddingResponse};
15pub use files::Files;
16use files::{Data as FileData, DeleteResponse, PromptCompletion, Response as FileResponse};
17pub use fine_tunes::{
18 EventResponse as FineTuneEventResponse, FineTune, ListResponse as FineTuneListResponse,
19 Response as FineTuneResponse,
20};
21use image::Size;
22pub use image::{Image, Response as ImageResponse, ResponseDataType};
23use misc::ModelsResponse;
24pub use misc::{Model, OpenAIError, Usage};
25pub use moderations::{Moderation, Response as ModerationResponse};
26
27use reqwest::multipart::{Form, Part};
28use reqwest::{Body, Client, IntoUrl};
29use tokio_util::codec::{BytesCodec, FramedRead};
30
31use rustyline::error::ReadlineError;
32use rustyline::DefaultEditor;
33use serde::Serialize;
34use std::env;
35use std::error::Error;
36use std::fs;
37use std::io::{self, Write};
38use std::path::Path;
39use std::process::exit;
40
41// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
42// = OpenAIConfig TRAIT
43// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
44
/// Contract implemented by every endpoint-specific configuration struct.
///
/// `OpenAI::new` calls [`OpenAIConfig::default`] to seed its `config` field, so each
/// endpoint type supplies its own initial request parameters through this trait.
pub trait OpenAIConfig: Send + Sync {
    /// Builds the configuration pre-populated with the endpoint's default parameters.
    fn default() -> Self;
}
48
49impl OpenAIConfig for Chat {
50 fn default() -> Self {
51 Self {
52 model: Self::get_default_model().into(),
53 messages: vec![],
54 functions: None,
55 function_call: None,
56 temperature: Some(Self::get_default_temperature()),
57 top_p: None,
58 n: None,
59 stream: Some(Self::get_default_stream()),
60 stop: None,
61 max_tokens: Some(Self::get_default_max_tokens()),
62 presence_penalty: None,
63 frequency_penalty: None,
64 logit_bias: None,
65 user: None,
66 }
67 }
68}
69
70impl OpenAIConfig for Image {
71 fn default() -> Self {
72 Self {
73 prompt: None,
74 n: Some(Self::get_default_n()),
75 size: Some(Self::get_default_size().into()),
76 response_format: Some(Self::get_default_response_format().into()),
77 user: None,
78 image: None,
79 mask: None,
80 }
81 }
82}
83
84impl OpenAIConfig for Embedding {
85 fn default() -> Self {
86 Self {
87 model: Self::get_default_model().into(),
88 input: InputType::SingleString(String::new()),
89 user: None,
90 }
91 }
92}
93
94impl OpenAIConfig for Audio {
95 fn default() -> Self {
96 Self {
97 file: String::new(),
98 model: Self::get_default_model().into(),
99 prompt: None,
100 response_format: Some(AudioResponseFormat::get_default_response_format()),
101 temperature: Some(0.0),
102 language: None,
103 }
104 }
105}
106
107impl OpenAIConfig for Files {
108 fn default() -> Self {
109 Self {
110 file: None,
111 purpose: None,
112 file_id: None,
113 }
114 }
115}
116
117impl OpenAIConfig for Moderation {
118 fn default() -> Self {
119 Self {
120 input: String::new(),
121 }
122 }
123}
124
125impl OpenAIConfig for FineTune {
126 fn default() -> Self {
127 Self {
128 training_file: String::new(),
129 validation_file: None,
130 model: Some(Self::get_default_model().into()),
131 n_epochs: Some(Self::get_default_n_epochs()),
132 batch_size: None,
133 learning_rate_multiplier: None,
134 prompt_loss_weight: Some(Self::get_default_prompt_loss_weight()),
135 compute_classification_metrics: Some(Self::get_default_compute_classification_metrics()),
136 classification_n_classes: None,
137 classification_positive_class: None,
138 classification_betas: None,
139 suffix: None,
140 }
141 }
142}
143
144// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
145// = OpenAI SHARED IMPLEMENTATION
146// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
147
148/// The `OpenAI` struct is the main entry point for interacting with the `OpenAI` API.
149/// It contains the API key, the client, and the configuration for the API call,
150/// such as the chat completion endpoint. It also contains a boolean flag to disable
151/// the live stream of the chat endpoint.
152#[derive(Clone, Debug)]
153pub struct OpenAI<C: OpenAIConfig> {
154 /// The HTTP client used to make requests to the `OpenAI` API.
155 pub client: Client,
156
157 /// The API key used to authenticate with the `OpenAI` API.
158 pub api_key: String,
159
160 /// A boolean flag to disable the live stream of the chat endpoint.
161 pub disable_live_stream: bool,
162
163 /// An endpoint specific configuration struct that holds all necessary parameters
164 /// for the API call.
165 pub config: C,
166}
167
168impl<C: OpenAIConfig + Serialize + Sync + Send + std::fmt::Debug> Default for OpenAI<C> {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174impl<C: OpenAIConfig + Serialize + std::fmt::Debug> OpenAI<C> {
/// URL of the models endpoint; `check_model` appends `/<model>` to it.
const OPENAI_API_MODELS_URL: &'static str = "https://api.openai.com/v1/models";
176 pub fn new() -> Self {
177 env::var("OPENAI_API_KEY").map_or_else(
178 |_| {
179 println!("OPENAI_API_KEY environment variable not set");
180 exit(1);
181 },
182 |api_key| {
183 let client = Client::new();
184 Self {
185 client,
186 api_key,
187 disable_live_stream: false,
188 config: C::default(),
189 }
190 },
191 )
192 }
193
194 /// Allows to batch configure the AI assistant with the settings provided in the `Chat` struct.
195 ///
196 /// # Arguments
197 ///
198 /// * `config`: A `Chat` struct that contains the settings for the AI assistant.
199 ///
200 /// # Returns
201 ///
202 /// This function returns the instance of the AI assistant with the new configuration.
203 pub fn with_config(mut self, config: C) -> Self {
204 self.config = config;
205 self
206 }
207
208 /// Disables standard output for the instance of `OpenAi`, which is enabled by default.
209 /// This is only interesting for the chat completion, as it will otherwise print the
210 /// messages of the AI assistant to the terminal.
211 pub fn disable_stdout(mut self) -> Self {
212 self.disable_live_stream = true;
213 self
214 }
215
216 pub fn is_valid_temperature(&mut self, temperature: f64, limit: f64) -> bool {
217 (0.0..=limit).contains(&temperature)
218 }
219
220 async fn _make_post_request<S: IntoUrl + Send + Sync>(
221 &mut self,
222 url: S,
223 ) -> Result<reqwest::Response, Box<dyn Error + Send + Sync>> {
224 let res = self
225 .client
226 .post(url)
227 .header("Content-Type", "application/json")
228 .header("Authorization", format!("Bearer {}", self.api_key))
229 .json(&self.config)
230 .send()
231 .await?;
232 Ok(res)
233 }
234
235 async fn _make_delete_request<S: IntoUrl + Send + Sync>(
236 &mut self,
237 url: S,
238 ) -> Result<reqwest::Response, Box<dyn Error + Send + Sync>> {
239 let res = self
240 .client
241 .delete(url)
242 .header("Authorization", format!("Bearer {}", self.api_key))
243 .send()
244 .await?;
245 Ok(res)
246 }
247
248 async fn _make_get_request<S: IntoUrl + Send + Sync>(
249 &mut self,
250 url: S,
251 ) -> Result<reqwest::Response, Box<dyn Error + Send + Sync>> {
252 let res = self
253 .client
254 .get(url)
255 .header("Content-Type", "application/json")
256 .header("Authorization", format!("Bearer {}", self.api_key))
257 .send()
258 .await?;
259 Ok(res)
260 }
261
262 async fn _make_form_request<S: IntoUrl + Send + Sync>(
263 &mut self,
264 url: S,
265 form: Form,
266 ) -> Result<reqwest::Response, Box<dyn Error + Send + Sync>> {
267 let res = self
268 .client
269 .post(url)
270 .header("Authorization", format!("Bearer {}", self.api_key))
271 .multipart(form)
272 .send()
273 .await?;
274 Ok(res)
275 }
276
277 /// Fetches a list of available models from the `OpenAI` API.
278 ///
279 /// This method sends a GET request to the `OpenAI` API and returns a vector of identifiers of
280 /// all available models.
281 ///
282 /// # Returns
283 ///
284 /// A `Result` which is:
285 /// * `Ok` if the request was successful, carrying a `Vec<String>` of model identifiers.
286 /// * `Err` if the request or the parsing failed, carrying the error of type `Box<dyn std::error::Error + Send + Sync>`.
287 ///
288 /// # Errors
289 ///
290 /// This method will return an error if the GET request fails, or if the response from the
291 /// `OpenAI` API cannot be parsed into a `ModelsResponse`.
292 ///
293 /// # Example
294 ///
295 /// ```rust
296 /// use aionic::openai::{OpenAI, Chat};
297 ///
298 ///
299 /// #[tokio::main]
300 /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
301 /// let mut client = OpenAI::<Chat>::new();
302 /// match client.models().await {
303 /// Ok(models) => println!("Models: {:?}", models),
304 /// Err(e) => println!("Error: {}", e),
305 /// }
306 /// Ok(())
307 /// }
308 /// ```
309 ///
310 /// # Note
311 ///
312 /// This method is `async` and needs to be awaited.
313 pub async fn models(
314 &mut self,
315 ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
316 let resp = self._make_get_request(Self::OPENAI_API_MODELS_URL).await?;
317
318 if !resp.status().is_success() {
319 return Err(Box::new(std::io::Error::new(
320 std::io::ErrorKind::Other,
321 format!("Error: {}", resp.status()),
322 )));
323 }
324
325 let data: ModelsResponse = resp.json().await?;
326 let model_ids: Vec<String> = data.data.into_iter().map(|model| model.id).collect();
327 Ok(model_ids)
328 }
329
330 /// Fetches a specific model by identifier from the `OpenAI` API.
331 ///
332 /// This method sends a GET request to the `OpenAI` API for a specific model and returns the `Model`.
333 ///
334 /// # Parameters
335 ///
336 /// * `model`: A `&str` that represents the name of the model to fetch.
337 ///
338 /// # Returns
339 ///
340 /// A `Result` which is:
341 /// * `Ok` if the request was successful, carrying the `Model`.
342 /// * `Err` if the request or the parsing failed, carrying the error of type `Box<dyn std::error::Error + Send + Sync>`.
343 ///
344 /// # Errors
345 ///
346 /// This method will return an error if the GET request fails, or if the response from the
347 /// `OpenAI` API cannot be parsed into a `Model`.
348 ///
349 /// # Example
350 ///
351 /// ```rust
352 /// use aionic::openai::{OpenAI, Chat};
353 ///
354 /// #[tokio::main]
355 /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
356 /// let mut client = OpenAI::<Chat>::new();
357 /// match client.check_model("gpt-3.5-turbo").await {
358 /// Ok(model) => println!("Model: {:?}", model),
359 /// Err(e) => println!("Error: {}", e),
360 /// }
361 /// Ok(())
362 /// }
363 /// ```
364 ///
365 /// # Note
366 ///
367 /// This method is `async` and needs to be awaited.
368 pub async fn check_model(
369 &mut self,
370 model: &str,
371 ) -> Result<Model, Box<dyn std::error::Error + Send + Sync>> {
372 let resp = self
373 ._make_get_request(format!("{}/{}", Self::OPENAI_API_MODELS_URL, model))
374 .await?;
375
376 if !resp.status().is_success() {
377 return Err(Box::new(std::io::Error::new(
378 std::io::ErrorKind::Other,
379 format!("Error: {}", resp.status()),
380 )));
381 }
382 let model: Model = resp.json().await?;
383 Ok(model)
384 }
385
386 /// Creates a file upload part for a multi-part upload operation.
387 ///
388 /// This method reads the file at the given path, prepares it for uploading, and
389 /// returns a `Part` that represents this file in the multi-part upload operation.
390 ///
391 /// # Type Parameters
392 ///
393 /// * `P`: The type of the file path. Must implement the `AsRef<Path>` trait.
394 ///
395 /// # Parameters
396 ///
397 /// * `path`: The path of the file to upload. This can be any type that implements `AsRef<Path>`.
398 ///
399 /// # Returns
400 ///
401 /// A `Result` which is:
402 /// * `Ok` if the file was read successfully and the `Part` was created, carrying the `Part`.
403 /// * `Err` if there was an error reading the file or creating the `Part`, carrying the error of type `Box<dyn Error + Send + Sync>`.
404 ///
405 /// # Errors
406 ///
407 /// This method will return an error if there was an error reading the file at the given path,
408 /// or if there was an error creating the `Part` (for example, if the MIME type was not recognized).
409 ///
410 /// # Example
411 ///
412 /// ```rust
413 /// use aionic::openai::{OpenAI, Chat};
414 ///
415 /// #[tokio::main]
416 /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
417 /// let mut client = OpenAI::<Chat>::new();
418 /// match client.create_file_upload_part("path/to/file.txt").await {
419 /// Ok(part) => println!("Part created successfully."),
420 /// Err(e) => println!("Error: {}", e),
421 /// }
422 /// Ok(())
423 /// }
424 /// ```
425 ///
426 /// # Note
427 ///
428 /// This method is `async` and needs to be awaited.
429 pub async fn create_file_upload_part<P: AsRef<Path> + Send>(
430 &mut self,
431 path: P,
432 ) -> Result<Part, Box<dyn Error + Send + Sync>> {
433 let file_name = path.as_ref().to_str().unwrap().to_string();
434 let streamed_body = self._get_streamed_body(path).await?;
435 let part_stream = Part::stream(streamed_body)
436 .file_name(file_name)
437 .mime_str("application/octet-stream")?;
438 Ok(part_stream)
439 }
440
441 async fn _get_streamed_body<P: AsRef<Path> + Send>(
442 &mut self,
443 path: P,
444 ) -> Result<Body, Box<dyn Error + Send + Sync>> {
445 if !path.as_ref().exists() {
446 return Err(Box::new(std::io::Error::new(
447 std::io::ErrorKind::Other,
448 "Image not found",
449 )));
450 }
451 let file_stream_body = tokio::fs::File::open(path).await?;
452 let stream = FramedRead::new(file_stream_body, BytesCodec::new());
453 let body = Body::wrap_stream(stream);
454 Ok(body)
455 }
456
457 /// A helper function to handle potential errors from `OpenAI` API responses.
458 ///
459 /// # Arguments
460 ///
461 /// * `res` - A `Response` object from the `OpenAI` API call.
462 ///
463 /// # Returns
464 ///
465 /// `Result<Response, Box<dyn std::error::Error + Send + Sync>>`:
466 /// Returns the original `Response` object if the status code indicates success.
467 /// If the status code indicates an error, it will attempt to deserialize the response
468 /// into an `OpenAIError` and returns a `std::io::Error` constructed from the error message.
469 pub async fn handle_api_errors(
470 &mut self,
471 res: reqwest::Response,
472 ) -> Result<reqwest::Response, Box<dyn std::error::Error + Send + Sync>> {
473 if res.status().is_success() {
474 Ok(res)
475 } else {
476 let err_resp: OpenAIError = res.json().await?;
477 Err(Box::new(std::io::Error::new(
478 std::io::ErrorKind::Other,
479 err_resp.error.message,
480 )))
481 }
482 }
483}
484
485// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
486// = OpenAI CHAT IMPLEMENTATION
487// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
488
489impl OpenAI<Chat> {
/// URL of the chat completions endpoint that `ask` posts to.
const OPENAI_API_COMPLETIONS_URL: &'static str = "https://api.openai.com/v1/chat/completions";
491
492 /// Sets the model of the AI assistant.
493 ///
494 /// # Arguments
495 ///
496 /// * `model`: A string that specifies the model name to be used by the AI assistant.
497 ///
498 /// # Returns
499 ///
500 /// This function returns the instance of the AI assistant with the specified model.
501 pub fn set_model<S: Into<String>>(mut self, model: S) -> Self {
502 self.config.model = model.into();
503 self
504 }
505
506 /// Sets the maximum number of tokens that the AI model can generate in a single response.
507 ///
508 /// # Arguments
509 ///
510 /// * `max_tokens`: An unsigned 64-bit integer that specifies the maximum number of tokens
511 /// that the AI model can generate in a single response.
512 ///
513 /// # Returns
514 ///
515 /// This function returns the instance of the AI assistant with the specified maximum number of tokens.
516 pub fn set_max_tokens(mut self, max_tokens: u64) -> Self {
517 self.config.max_tokens = Some(max_tokens);
518 self
519 }
520
521 /// Allows to set the chat history in a specific state.
522 ///
523 /// # Arguments
524 ///
525 /// * `messages`: A vector of `Message` structs.
526 ///
527 /// # Returns
528 ///
529 /// This function returns the instance of the AI assistant with the specified messages.
530 pub fn set_messages(mut self, messages: Vec<Message>) -> Self {
531 self.config.messages = messages;
532 self
533 }
534
535 /// Sets the temperature of the AI model's responses.
536 ///
537 /// The temperature setting adjusts the randomness of the AI's responses.
538 /// Higher values produce more random responses, while lower values produce more deterministic responses.
539 /// The allowed range of values is between 0.0 and 2.0, with 0 being the most deterministic and 1 being the most random.
540 ///
541 /// # Arguments
542 ///
543 /// * `temperature`: A float that specifies the temperature.
544 ///
545 /// # Returns
546 ///
547 /// This function returns the instance of the AI assistant with the specified temperature.
548 pub fn set_temperature(mut self, temperature: f64) -> Self {
549 self.config.temperature = Some(temperature);
550 self
551 }
552
553 /// Sets the streaming configuration of the AI assistant.
554 ///
555 /// If streaming is enabled, the AI assistant will fetch and process the AI's responses as they arrive.
556 /// If it's disabled, the assistant will collect all of the AI's responses at once and return them as a single response.
557 ///
558 /// # Arguments
559 ///
560 /// * `streamed`: A boolean that specifies whether streaming should be enabled.
561 ///
562 /// # Returns
563 ///
564 /// This function returns the instance of the AI assistant with the specified streaming setting.
565 pub fn set_stream_responses(mut self, streamed: bool) -> Self {
566 self.config.stream = Some(streamed);
567 self
568 }
569
570 /// Sets a primer message for the AI assistant.
571 ///
572 /// The primer message is inserted at the beginning of the `messages` vector in the `config` struct.
573 /// This can be used to prime the AI model with a certain context or instruction.
574 ///
575 /// # Arguments
576 ///
577 /// * `primer_msg`: A string that specifies the primer message.
578 ///
579 /// # Returns
580 ///
581 /// This function returns the instance of the AI assistant with the specified primer message.
582 pub fn set_primer<S: Into<String>>(mut self, primer_msg: S) -> Self {
583 let msg = Message::new(&MessageRole::System, primer_msg.into());
584 self.config.messages.insert(0, msg);
585 self
586 }
587
588 /// Returns the last message in the AI assistant's configuration.
589 ///
590 /// # Returns
591 ///
592 /// This function returns an `Option` that contains a reference to the last `Message`
593 /// in the `config` struct if it exists, or `None` if it doesn't.
594 pub fn get_last_message(&self) -> Option<&Message> {
595 self.config.messages.last()
596 }
597
598 /// Clears the messages in the AI assistant's configuration to start from a clean state.
599 /// This is only necessary in very specific cases.
600 ///
601 /// # Returns
602 ///
603 /// This function returns the instance of the AI assistant with no messages in its configuration.
604 pub fn clear_state(mut self) -> Self {
605 self.config.messages.clear();
606 self
607 }
608
609 fn _process_delta(
610 &self,
611 line: &str,
612 answer_text: &mut Vec<String>,
613 ) -> Result<(), Box<dyn Error + Send + Sync>> {
614 line.strip_prefix("data: ").map_or(Ok(()), |chunk| {
615 if chunk.starts_with("[DONE]") {
616 return Ok(());
617 }
618 let serde_chunk: Result<StreamedReponse, _> = serde_json::from_str(chunk);
619 match serde_chunk {
620 Ok(chunk) => {
621 for choice in chunk.choices {
622 if let Some(content) = choice.delta.content {
623 let sanitized_content =
624 content.trim().strip_suffix('\n').unwrap_or(&content);
625 if !self.disable_live_stream {
626 print!("{}", sanitized_content);
627 io::stdout().flush()?;
628 }
629 answer_text.push(sanitized_content.to_string());
630 }
631 }
632 Ok(())
633 }
634 Err(_) => Err(Box::new(std::io::Error::new(
635 std::io::ErrorKind::Other,
636 "Deserialization Error",
637 ))),
638 }
639 })
640 }
641
642 async fn _ask_openai_streamed(
643 &mut self,
644 res: &mut reqwest::Response,
645 answer_text: &mut Vec<String>,
646 ) -> Result<(), Box<dyn Error + Send + Sync>> {
647 print!("AI: ");
648 loop {
649 let chunk = match res.chunk().await {
650 Ok(Some(chunk)) => chunk,
651 Ok(None) => break,
652 Err(e) => return Err(Box::new(e)),
653 };
654 let chunk_str = String::from_utf8_lossy(&chunk);
655 let lines: Vec<&str> = chunk_str.split('\n').collect();
656 for line in lines {
657 self._process_delta(line, answer_text)?;
658 }
659 }
660 println!();
661 Ok(())
662 }
663
664 /// Makes a request to `OpenAI`'s GPT model and retrieves a response based on the provided `prompt`.
665 ///
666 /// This function accepts a prompt, converts it into a string, and sends a request to the `OpenAI` API.
667 /// Depending on the streaming configuration (`is_streamed`), the function either collects all of the AI's responses
668 /// at once, or fetches and processes them as they arrive.
669 ///
670 /// # Arguments
671 ///
672 /// * `prompt`: A value that implements `Into<String>`. This will be converted into a string and sent to the API as the
673 /// prompt for the AI model.
674 ///
675 /// * `persist_state`: If true, the function will push the AI's response to the `messages` vector in the `config` struct.
676 /// If false, it will remove the last message from the `messages` vector.
677 ///
678 /// # Returns
679 ///
680 /// * `Ok(String)`: A success value containing the AI's response as a string.
681 ///
682 /// * `Err(Box<dyn std::error::Error + Send + Sync>)`: An error value. This is a dynamic error, meaning it could represent
683 /// various kinds of failures. The function will return an error if any step in the process fails, such as making the HTTP request,
684 /// parsing the JSON response, or if there's an issue with the streaming process.
685 ///
686 /// # Errors
687 ///
688 /// This function will return an error if the HTTP request fails, the JSON response from the API cannot be parsed, or if
689 /// an error occurs during streaming.
690 ///
691 /// # Examples
692 ///
693 /// ```rust
694 ///
695 /// use aionic::openai::chat::Chat;
696 /// use aionic::openai::OpenAI;
697 ///
698 /// #[tokio::main]
699 /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
700 /// let prompt = "Hello, world!";
701 /// let mut client = OpenAI::<Chat>::new();
702 /// let result = client.ask(prompt, true).await;
703 /// match result {
704 /// Ok(response) => println!("{}", response),
705 /// Err(e) => println!("Error: {}", e),
706 /// }
707 /// Ok(())
708 /// }
709 /// ```
710 ///
711 /// # Note
712 ///
713 /// This function is `async` and must be awaited when called.
714 pub async fn ask<P: Into<Message> + Send>(
715 &mut self,
716 prompt: P,
717 persist_state: bool,
718 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
719 let mut answer_chunks: Vec<String> = Vec::new();
720 let is_streamed = self.config.stream.unwrap_or(false);
721 self.config.messages.push(prompt.into());
722 if let Some(temp) = self.config.temperature {
723 // TODO: Add a log warning
724 if !self.is_valid_temperature(temp, 2.0) {
725 self.config.temperature = Some(2.0);
726 }
727 }
728 let mut r = self
729 ._make_post_request(Self::OPENAI_API_COMPLETIONS_URL)
730 .await?;
731 if is_streamed {
732 self._ask_openai_streamed(&mut r, &mut answer_chunks)
733 .await?;
734 } else {
735 let r = r.json::<Response>().await?;
736 if let Some(choices) = r.choices {
737 for choice in choices {
738 if !self.disable_live_stream {
739 print!("AI: {}\n", choice.message.content);
740 io::stdout().flush()?;
741 }
742 answer_chunks.push(choice.message.content);
743 }
744 }
745 }
746
747 let answer_text = answer_chunks.join("");
748 if persist_state {
749 self.config
750 .messages
751 .push(Message::new(&MessageRole::Assistant, &answer_text));
752 } else {
753 self.config.messages.pop();
754 }
755 Ok(answer_text)
756 }
757
758 /// Starts a chat session with the AI assistant.
759 ///
760 /// This function uses a Readline-style interface for input and output. The user types a message at the `>>> ` prompt,
761 /// and the message is sent to the AI assistant using the `ask` function. The AI's response is then printed to the console.
762 ///
763 /// If the user enters CTRL-C, the function prints "CTRL-C" and exits the chat session.
764 ///
765 /// If the user enters CTRL-D, the function prints "CTRL-D" and exits the chat session.
766 ///
767 /// If there's an error during readline, the function prints the error message and exits the chat session.
768 ///
769 /// # Returns
770 ///
771 /// * `Ok(())`: A success value indicating that the chat session ended normally.
772 ///
773 /// * `Err(Box<dyn std::error::Error + Send + Sync>)`: An error value. This is a dynamic error, meaning it could represent
774 /// various kinds of failures. The function will return an error if any step in the process fails, such as reading a line
775 /// from the console, or if there's an error in the `ask` function.
776 ///
777 /// # Errors
778 ///
779 /// This function will return an error if the readline fails or if there's an error in the `ask` function.
780 ///
781 /// # Examples
782 ///
783 /// ```rust
784 /// use aionic::openai::chat::Chat;
785 /// use aionic::openai::OpenAI;
786 ///
787 /// #[tokio::main]
788 /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
789 /// let mut client = OpenAI::<Chat>::new();
790 /// let result = client.chat().await;
791 /// match result {
792 /// Ok(()) => println!("Chat session ended."),
793 /// Err(e) => println!("Error during chat session: {}", e),
794 /// }
795 /// Ok(())
796 /// }
797 /// ```
798 ///
799 /// # Note
800 ///
801 /// This function is `async` and must be awaited when called.
802 pub async fn chat(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
803 let mut rl = DefaultEditor::new()?;
804 let prompt = ">>> ";
805 loop {
806 let readline = rl.readline(prompt);
807 match readline {
808 Ok(line) => {
809 self.ask(line, true).await?;
810 println!();
811 }
812 Err(ReadlineError::Interrupted) => {
813 println!("CTRL-C");
814 break;
815 }
816 Err(ReadlineError::Eof) => {
817 println!("CTRL-D");
818 break;
819 }
820 Err(err) => {
821 println!("Error: {:?}", err);
822 break;
823 }
824 }
825 }
826 Ok(())
827 }
828}
829
830// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
831// = OpenAI IMAGE IMPLEMENTATION
832// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
833
834impl OpenAI<Image> {
/// URL of the image generation endpoint used by `create`.
const OPENAI_API_IMAGE_GEN_URL: &'static str = "https://api.openai.com/v1/images/generations";
/// URL of the image edit endpoint used by `edit`.
const OPENAI_API_IMAGE_EDIT_URL: &'static str = "https://api.openai.com/v1/images/edits";
/// URL of the image variation endpoint used by `variation`.
const OPENAI_API_IMAGE_VARIATION_URL: &'static str =
    "https://api.openai.com/v1/images/variations";
838
839 /// Allows setting the return format of the response. `ResponseDataType` is an enum with the
840 /// following variants:
841 /// * `Url`: The response will be a vector of URLs to the generated images.
842 /// * `Base64Json`: The response will be a vector of base64 encoded images.
843 pub fn set_response_format(mut self, response_format: &ResponseDataType) -> Self {
844 self.config.response_format = Some(response_format.to_string());
845 self
846 }
847
848 /// Allows setting the number of images to be generated.
849 pub fn set_max_images(mut self, number_of_images: u64) -> Self {
850 self.config.n = Some(number_of_images);
851 self
852 }
853
854 /// Allows setting the dimensions of the generated images.
855 pub fn set_size(mut self, size: &Size) -> Self {
856 self.config.size = Some(size.to_string());
857 self
858 }
859
860 /// Generates an image based on a textual description.
861 ///
862 /// This function sets the prompt to the given string and sends a request to the `OpenAI` API to create an image.
863 /// The function then parses the response and returns a vector of image URLs.
864 ///
865 /// # Arguments
866 ///
867 /// * `prompt`: A string that describes the image to be generated.
868 ///
869 /// # Returns
870 ///
871 /// This function returns a `Result` with a vector of strings on success, each string being a URL to an image.
872 /// If there's an error, it returns a dynamic error.
873 pub async fn create<S: Into<String> + Send>(
874 &mut self,
875 prompt: S,
876 ) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
877 self.config.prompt = Some(prompt.into());
878 if self.config.image.is_some() {
879 self.config.image = None;
880 }
881 if self.config.mask.is_some() {
882 self.config.mask = None;
883 }
884 let res: reqwest::Response = self
885 ._make_post_request(Self::OPENAI_API_IMAGE_GEN_URL)
886 .await?;
887 let handle_res = self.handle_api_errors(res).await?;
888 let image_response: ImageResponse = handle_res.json().await?;
889
890 Ok(self._parse_response(&image_response))
891 }
892
893 /// Modifies an existing image based on a textual description.
894 ///
895 /// This function sets the image and optionally the mask, then sets the prompt to the given string and sends a request to the `OpenAI` API to modify the image.
896 /// The function then parses the response and returns a vector of image URLs.
897 ///
898 /// # Arguments
899 ///
900 /// * `prompt`: A string that describes the modifications to be made to the image.
901 /// * `image_file_path`: A string that specifies the path to the image file to be modified.
902 /// * `mask`: An optional string that specifies the path to a mask file. If the mask is not provided, it is set to `None`.
903 ///
904 /// # Returns
905 ///
906 /// This function returns a `Result` with a vector of strings on success, each string being a URL to an image.
907 /// If there's an error, it returns a dynamic error.
908 pub async fn edit<S: Into<String> + Send>(
909 &mut self,
910 prompt: S,
911 image_file_path: S,
912 mask: Option<S>,
913 ) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
914 self.config.image = Some(image_file_path.into());
915 if let Some(mask) = mask {
916 self.config.mask = Some(mask.into());
917 }
918 self.config.prompt = Some(prompt.into());
919
920 if let Some(n) = self.config.n {
921 // TODO: Add a warning here
922 if !image::Image::is_valid_n(n) {
923 self.config.n = Some(image::Image::get_default_n());
924 }
925 }
926
927 if let Some(size) = self.config.size.as_ref() {
928 // TODO: Add a warning here
929 if !image::Image::is_valid_size(size) {
930 self.config.size = Some(image::Image::get_default_size().into());
931 }
932 }
933
934 if let Some(response_format) = self.config.response_format.as_ref() {
935 // TODO: Add a warning here
936 if !image::Image::is_valid_response_format(response_format) {
937 self.config.response_format =
938 Some(image::Image::get_default_response_format().into());
939 }
940 }
941
942 let image_response: ImageResponse = self
943 ._make_file_upload_request(Self::OPENAI_API_IMAGE_EDIT_URL)
944 .await?;
945 Ok(self._parse_response(&image_response))
946 }
947
948 /// Generates variations of an existing image.
949 ///
950 /// This function sets the image and sends a request to the `OpenAI` API to create variations of the image.
951 /// The function then parses the response and returns a vector of image URLs.
952 ///
953 /// # Arguments
954 ///
955 /// * `image_file_path`: A string that specifies the path to the image file.
956 ///
957 /// # Returns
958 ///
959 /// This function returns a `Result` with a vector of strings on success, each string being a URL to a new variation of the image.
960 /// If there's an error, it returns a dynamic error.
961 pub async fn variation<S: Into<String> + Send>(
962 &mut self,
963 image_file_path: S,
964 ) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
965 self.config.image = Some(image_file_path.into());
966 if self.config.prompt.is_some() {
967 self.config.prompt = None;
968 }
969 if self.config.mask.is_some() {
970 self.config.mask = None;
971 }
972 let image_response: ImageResponse = self
973 ._make_file_upload_request(Self::OPENAI_API_IMAGE_VARIATION_URL)
974 .await?;
975
976 Ok(self._parse_response(&image_response))
977 }
978
979 fn _parse_response(&mut self, image_response: &ImageResponse) -> Vec<String> {
980 image_response
981 .data
982 .iter()
983 .filter_map(|d| {
984 if self.config.response_format == Some("url".into()) {
985 d.url.clone()
986 } else {
987 d.b64_json.clone()
988 }
989 })
990 .collect::<Vec<String>>()
991 }
992
993 async fn _make_file_upload_request<S: IntoUrl + Send + Sync>(
994 &mut self,
995 url: S,
996 ) -> Result<ImageResponse, Box<dyn Error + Send + Sync>> {
997 let file_name = self.config.image.as_ref().unwrap();
998 let file_part_stream = self.create_file_upload_part(file_name.to_string()).await?;
999 let mut form = Form::new().part("image", file_part_stream);
1000
1001 if let Some(prompt) = self.config.prompt.as_ref() {
1002 form = form.text("prompt", prompt.clone());
1003 }
1004 if let Some(mask_name) = self.config.mask.as_ref() {
1005 let mask_part_stream = self.create_file_upload_part(mask_name.to_string()).await?;
1006 form = form.part("mask", mask_part_stream);
1007 }
1008
1009 if let Some(response_format) = self.config.response_format.as_ref() {
1010 form = form.text("response_format", response_format.clone());
1011 }
1012
1013 if let Some(size) = self.config.size.as_ref() {
1014 form = form.text("size", size.clone());
1015 }
1016
1017 if let Some(n) = self.config.n {
1018 form = form.text("n", n.to_string());
1019 }
1020
1021 if let Some(user) = self.config.user.as_ref() {
1022 form = form.text("user", user.clone());
1023 }
1024
1025 let res: reqwest::Response = self._make_form_request(url, form).await?;
1026 let handle_res = self.handle_api_errors(res).await?;
1027 let image_response: ImageResponse = handle_res.json().await?;
1028
1029 Ok(image_response)
1030 }
1031}
1032
1033// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1034// = OpenAI EMBEDDINGS IMPLEMENTATION
1035// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1036
1037impl OpenAI<Embedding> {
1038 const OPENAI_API_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
1039
1040 /// Sets the model of the AI assistant.
1041 ///
1042 /// # Arguments
1043 ///
1044 /// * `model`: A string that specifies the model name to be used by the AI assistant.
1045 ///
1046 /// # Returns
1047 ///
1048 /// This function returns the instance of the AI assistant with the specified model.
1049 pub fn set_model<S: Into<String>>(mut self, model: S) -> Self {
1050 self.config.model = model.into();
1051 self
1052 }
1053
1054 /// Sends a POST request to the `OpenAI` API to get embeddings for the given prompt.
1055 ///
1056 /// This method accepts a prompt of type `S` which can be converted into `InputType`
1057 /// (an enum that encapsulates the different types of possible inputs). The method converts
1058 /// the provided prompt into `InputType` and assigns it to the `input` field of the `config`
1059 /// instance variable. It then sends a POST request to the `OpenAI` API and attempts to parse
1060 /// the response as `EmbeddingResponse`.
1061 ///
1062 /// # Type Parameters
1063 ///
1064 /// * `S`: The type of the prompt. Must implement the `Into<InputType>` trait.
1065 ///
1066 /// # Parameters
1067 ///
1068 /// * `prompt`: The prompt for which to get embeddings. Can be a `String`, a `Vec<String>`,
1069 /// a `Vec<u64>`, or a `&str` that is converted into an `InputType`.
1070 ///
1071 /// # Returns
1072 ///
1073 /// A `Result` which is:
1074 /// * `Ok` if the request was successful, carrying the `EmbeddingResponse` which contains the embeddings.
1075 /// * `Err` if the request or the parsing failed, carrying the error of type `Box<dyn std::error::Error + Send + Sync>`.
1076 ///
1077 /// # Errors
1078 ///
1079 /// This method will return an error if the POST request fails, or if the response from the
1080 /// `OpenAI` API cannot be parsed into an `EmbeddingResponse`.
1081 ///
1082 /// # Example
1083 ///
1084 /// ```rust
1085 /// use aionic::openai::embeddings::Embedding;
1086 /// use aionic::openai::OpenAI;
1087 ///
1088 /// #[tokio::main]
1089 /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1090 /// let mut client = OpenAI::<Embedding>::new();
1091 /// let prompt = "Hello, world!";
1092 /// match client.embed(prompt).await {
1093 /// Ok(response) => println!("Embeddings: {:?}", response),
1094 /// Err(e) => println!("Error: {}", e),
1095 /// }
1096 /// Ok(())
1097 /// }
1098 /// ```
1099 ///
1100 /// # Note
1101 ///
1102 /// This method is `async` and needs to be awaited.
1103 pub async fn embed<S: Into<InputType> + Send>(
1104 &mut self,
1105 prompt: S,
1106 ) -> Result<EmbeddingResponse, Box<dyn std::error::Error + Send + Sync>> {
1107 self.config.input = prompt.into();
1108 let res: reqwest::Response = self
1109 ._make_post_request(Self::OPENAI_API_EMBEDDINGS_URL)
1110 .await?;
1111 let handled_res = self.handle_api_errors(res).await?;
1112 let embedding: EmbeddingResponse = handled_res.json().await?;
1113 Ok(embedding)
1114 }
1115}
1116
1117// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1118// = OpenAI AUDIO IMPLEMENTATION
1119// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1120
1121impl OpenAI<Audio> {
1122 const OPENAI_API_TRANSCRIPTION_URL: &str = "https://api.openai.com/v1/audio/transcriptions";
1123 const OPENAI_API_TRANSLATION_URL: &str = "https://api.openai.com/v1/audio/translations";
1124
1125 /// Sets the model of the AI assistant.
1126 ///
1127 /// # Arguments
1128 ///
1129 /// * `model`: A string that specifies the model name to be used by the AI assistant.
1130 ///
1131 /// # Returns
1132 ///
1133 /// This function returns the instance of the AI assistant with the specified model.
1134 pub fn set_model<S: Into<String>>(mut self, model: S) -> Self {
1135 self.config.model = model.into();
1136 self
1137 }
1138
1139 /// Sets the optional prompt to giode the model's style of response.
1140 ///
1141 /// # Arguments
1142 ///
1143 /// * `prompt`: An optional string that specifies the prompt to guide the model's style of response.
1144 ///
1145 /// # Returns
1146 ///
1147 /// This function returns the instance of the AI assistant with the specified prompt
1148 pub fn set_prompt<S: Into<String>>(mut self, prompt: S) -> Self {
1149 self.config.prompt = Some(prompt.into());
1150 self
1151 }
1152
1153 /// Sets the required audio file to be transcribed or translated.
1154 ///
1155 /// # Arguments
1156 ///
1157 /// * `file`: A string that specifies the path to the audio file to be transcribed or translated.
1158 /// The path must be a valid path to a file.
1159 ///
1160 /// # Returns
1161 ///
1162 /// This function returns the instance of the AI assistant with the specified audio file.
1163 fn _set_file<P: AsRef<Path> + Send + Sync>(
1164 &mut self,
1165 file: P,
1166 ) -> Result<&mut Self, Box<dyn std::error::Error + Send + Sync>> {
1167 let path = file.as_ref();
1168 if fs::metadata(path)?.is_file() {
1169 let path_str = path.to_str().ok_or("Path is not valid UTF-8")?;
1170 self.config.file = path_str.to_string();
1171 if self._is_valid_mime_time().is_err() {
1172 return Err(Box::new(std::io::Error::new(
1173 std::io::ErrorKind::InvalidInput,
1174 format!(
1175 "Invalid audio file type. Supported types are {:?}",
1176 Audio::get_supported_file_types()
1177 ),
1178 )));
1179 }
1180 Ok(self)
1181 } else {
1182 Err(Box::new(std::io::Error::new(
1183 std::io::ErrorKind::InvalidInput,
1184 format!("Path is not a file: {}", path.display()),
1185 )))
1186 }
1187 }
1188
1189 /// Sets the optional audio file format to be returned
1190 ///
1191 /// # Arguments
1192 ///
1193 /// * `format`: An optional enum type that specifies the audio file format to be returned.
1194 /// The default is `AudioResponseFormat::Json`..
1195 ///
1196 /// # Returns
1197 ///
1198 /// This function returns the instance of the AI assistant with the specified audio file format.
1199 pub fn set_response_format(&mut self, format: AudioResponseFormat) -> &mut Self {
1200 self.config.response_format = Some(format);
1201 self
1202 }
1203
1204 fn _is_valid_mime_time(&mut self) -> Result<bool, String> {
1205 Audio::is_file_type_supported(&self.config.file)
1206 }
1207
1208 fn _is_valid_model(&mut self) -> bool {
1209 self.config.model == "whisper-1"
1210 }
1211
1212 fn _sanity_checks(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1213 if let Some(temp) = self.config.temperature {
1214 if !self.is_valid_temperature(temp, 1.0) {
1215 // TODO: Log warning
1216 self.config.temperature = Some(1.0);
1217 }
1218 }
1219
1220 if !self._is_valid_model() {
1221 return Err(Box::new(std::io::Error::new(
1222 std::io::ErrorKind::InvalidInput,
1223 format!(
1224 "Invalid model. Supported models are {:?}",
1225 Audio::get_supported_models()
1226 ),
1227 )));
1228 }
1229
1230 if let Some(lang) = &self.config.language {
1231 if !Audio::is_valid_language(lang) {
1232 return Err(Box::new(std::io::Error::new(
1233 std::io::ErrorKind::InvalidInput,
1234 format!(
1235 "Invalid language code. Supported language codes are {:?}",
1236 Audio::ISO_639_1_CODES
1237 ),
1238 )));
1239 }
1240 }
1241 Ok(())
1242 }
1243
1244 async fn _form_builder(&mut self) -> Result<Form, Box<dyn std::error::Error + Send + Sync>> {
1245 let file_part_stream = self
1246 .create_file_upload_part(self.config.file.clone())
1247 .await?;
1248 let mut form = Form::new().part("file", file_part_stream);
1249 form = form.text("model", self.config.model.clone());
1250
1251 if let Some(prompt) = self.config.prompt.as_ref() {
1252 form = form.text("prompt", prompt.clone());
1253 }
1254
1255 if let Some(response_format) = self.config.response_format.as_ref() {
1256 form = form.text("response_format", response_format.to_string());
1257 }
1258
1259 if let Some(temp) = self.config.temperature {
1260 form = form.text("temperature", temp.to_string());
1261 }
1262 Ok(form)
1263 }
1264
1265 /// Transcribe an audio file.
1266 ///
1267 /// # Arguments
1268 ///
1269 /// * `audio_file` - The path to the audio file to transcribe.
1270 ///
1271 /// # Returns
1272 ///
1273 /// `Result<AudioResponse, Box<dyn std::error::Error + Send + Sync>>`:
1274 /// An `AudioResponse` object representing the transcription of the audio file,
1275 /// or an error if the request fails.
1276 pub async fn transcribe<P: AsRef<Path> + Sync + Send>(
1277 &mut self,
1278 audio_file: P,
1279 ) -> Result<AudioResponse, Box<dyn std::error::Error + Send + Sync>> {
1280 self._set_file(audio_file)?;
1281 self._sanity_checks()?;
1282 let mut form = self._form_builder().await?;
1283
1284 if let Some(lang) = self.config.language.clone() {
1285 form = form.text("language", lang);
1286 }
1287
1288 let res: reqwest::Response = self
1289 ._make_form_request(Self::OPENAI_API_TRANSCRIPTION_URL, form)
1290 .await?;
1291
1292 let handled_res = self.handle_api_errors(res).await?;
1293 let transcription: AudioResponse = handled_res.json().await?;
1294 Ok(transcription)
1295 }
1296
1297 /// Translate an audio file. Currently only supports translating
1298 /// to English.
1299 ///
1300 /// # Arguments
1301 ///
1302 /// * `audio_file` - The path to the audio file to translate.
1303 ///
1304 /// # Returns
1305 ///
1306 /// `Result<AudioResponse, Box<dyn std::error::Error + Send + Sync>>`:
1307 /// An `AudioResponse` object representing the translation of the audio file,
1308 /// or an error if the request fails.
1309 pub async fn translate<P: AsRef<Path> + Send + Sync>(
1310 &mut self,
1311 audio_file: P,
1312 ) -> Result<AudioResponse, Box<dyn std::error::Error + Send + Sync>> {
1313 self._set_file(audio_file)?;
1314 self._sanity_checks()?;
1315 if self.config.language.is_some() {
1316 self.config.language = None;
1317 }
1318 let form = self._form_builder().await?;
1319 let res: reqwest::Response = self
1320 ._make_form_request(Self::OPENAI_API_TRANSLATION_URL, form)
1321 .await?;
1322 let handled_res = self.handle_api_errors(res).await?;
1323 let translation: AudioResponse = handled_res.json().await?;
1324 Ok(translation)
1325 }
1326}
1327
1328// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1329// = OpenAI FILES IMPLEMENTATION
1330// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1331
1332impl OpenAI<Files> {
1333 const OPENAI_API_LIST_FILES_URL: &str = "https://api.openai.com/v1/files";
1334
1335 /// List all files that have been uploaded.
1336 ///
1337 /// # Returns
1338 ///
1339 /// `Result<FileResponse, Box<dyn std::error::Error + Send + Sync>>`:
1340 /// A `FileResponse` object representing all uploaded files,
1341 /// or an error if the request fails.
1342 pub async fn list(&mut self) -> Result<FileResponse, Box<dyn std::error::Error + Send + Sync>> {
1343 let res: reqwest::Response = self
1344 ._make_get_request(Self::OPENAI_API_LIST_FILES_URL)
1345 .await?;
1346 let handled_res = self.handle_api_errors(res).await?;
1347 let files: FileResponse = handled_res.json().await?;
1348 Ok(files)
1349 }
1350
1351 /// Retrieve the details of a specific file.
1352 ///
1353 /// # Arguments
1354 ///
1355 /// * `file_id` - A string that holds the unique id of the file.
1356 ///
1357 /// # Returns
1358 ///
1359 /// `Result<FileData, Box<dyn std::error::Error + Send + Sync>>`:
1360 /// A `FileData` object representing the file's details,
1361 /// or an error if the request fails.
1362 pub async fn retrieve<S: Into<String> + std::fmt::Display + Sync + Send>(
1363 &mut self,
1364 file_id: S,
1365 ) -> Result<FileData, Box<dyn std::error::Error + Send + Sync>> {
1366 let res: reqwest::Response = self
1367 ._make_get_request(format!("{}/{}", Self::OPENAI_API_LIST_FILES_URL, file_id))
1368 .await?;
1369
1370 let handled_res = self.handle_api_errors(res).await?;
1371 let file: FileData = handled_res.json().await?;
1372 Ok(file)
1373 }
1374
1375 /// Retrieve the content of a specific file.
1376 ///
1377 /// # Arguments
1378 ///
1379 /// * `file_id` - A string that holds the unique id of the file.
1380 ///
1381 /// # Returns
1382 ///
1383 /// `Result<FileData, Box<dyn std::error::Error + Send + Sync>>`:
1384 /// A `FileData` object representing the file's content,
1385 /// or an error if the request fails.
1386 pub async fn retrieve_content<S: Into<String> + std::fmt::Display + Send + Sync>(
1387 &mut self,
1388 file_id: S,
1389 ) -> Result<Vec<PromptCompletion>, Box<dyn std::error::Error + Send + Sync>> {
1390 let res = self
1391 ._make_get_request(format!(
1392 "{}/{}/content",
1393 Self::OPENAI_API_LIST_FILES_URL,
1394 file_id
1395 ))
1396 .await?;
1397
1398 let handled_res = self.handle_api_errors(res).await?;
1399 let files: Vec<PromptCompletion> = handled_res
1400 .text()
1401 .await?
1402 .lines()
1403 .map(serde_json::from_str)
1404 .collect::<Result<Vec<PromptCompletion>, _>>()?;
1405 Ok(files)
1406 }
1407
1408 /// Upload a file to the `OpenAI` API.
1409 ///
1410 /// # Arguments
1411 ///
1412 /// * `file` - The path to the file to upload.
1413 /// * `purpose` - The purpose of the upload (e.g., 'answers', 'questions').
1414 ///
1415 /// # Returns
1416 ///
1417 /// `Result<FileData, Box<dyn std::error::Error + Send + Sync>>`:
1418 /// A `FileData` object representing the uploaded file's details,
1419 /// or an error if the request fails.
1420 pub async fn upload<P: AsRef<Path> + Send + Sync>(
1421 &mut self,
1422 file: P,
1423 ) -> Result<FileData, Box<dyn std::error::Error + Send + Sync>> {
1424 let path = file.as_ref();
1425 if fs::metadata(path)?.is_file() {
1426 let path_str = path.to_str().ok_or("Path is not valid UTF-8")?;
1427 if !std::path::Path::new(path_str)
1428 .extension()
1429 .map_or(false, |ext| ext.eq_ignore_ascii_case("jsonl"))
1430 {
1431 return Err(Box::new(std::io::Error::new(
1432 std::io::ErrorKind::InvalidInput,
1433 format!("File must be a .jsonl file: {}", path.display()),
1434 )));
1435 }
1436 self.config.file = Some(path_str.to_string());
1437 } else {
1438 return Err(Box::new(std::io::Error::new(
1439 std::io::ErrorKind::InvalidInput,
1440 format!("Path is not a file: {}", path.display()),
1441 )));
1442 }
1443
1444 let file_part_stream = self.create_file_upload_part(file).await?;
1445 let mut form = Form::new().part("file", file_part_stream);
1446 form = form.text("purpose", "fine-tune");
1447 let res: reqwest::Response = self
1448 ._make_form_request(Self::OPENAI_API_LIST_FILES_URL, form)
1449 .await?;
1450
1451 let handled_res = self.handle_api_errors(res).await?;
1452 let file_data: FileData = handled_res.json().await?;
1453 Ok(file_data)
1454 }
1455
1456 /// Delete a specific file.
1457 ///
1458 /// # Arguments
1459 ///
1460 /// * `file_id` - A string that holds the unique id of the file.
1461 ///
1462 /// # Returns
1463 ///
1464 /// `Result<DeleteResponse, Box<dyn std::error::Error + Send + Sync>>`:
1465 /// A `DeleteResponse` object representing the response from the delete request,
1466 /// or an error if the request fails.
1467 pub async fn delete<S: Into<String> + std::fmt::Display + Send + Sync>(
1468 &mut self,
1469 file_id: S,
1470 ) -> Result<DeleteResponse, Box<dyn std::error::Error + Send + Sync>> {
1471 let res: reqwest::Response = self
1472 ._make_delete_request(format!("{}/{}", Self::OPENAI_API_LIST_FILES_URL, file_id))
1473 .await?;
1474
1475 let handled_res = self.handle_api_errors(res).await?;
1476 let del_resp: DeleteResponse = handled_res.json().await?;
1477 Ok(del_resp)
1478 }
1479}
1480
1481// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1482// = OpenAI FINE-TUNE IMPLEMENTATION
1483// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1484
1485impl OpenAI<FineTune> {
1486 const OPENAI_API_FINE_TUNE_URL: &str = "https://api.openai.com/v1/fine-tunes";
1487
1488 /// Create a fine-tune from an uploaded `training_file`.
1489 ///
1490 /// # Arguments
1491 ///
1492 /// * `training_file` - A string that holds the unique id of the file.
1493 ///
1494 /// # Returns
1495 ///
1496 /// `Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>>`:
1497 /// A `FineTuneResponse` object representing the result of the fine-tune request,
1498 /// or an error if the request fails.
1499 pub async fn create<S: Into<String> + Send + Sync>(
1500 &mut self,
1501 training_file: S,
1502 ) -> Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>> {
1503 self.config.training_file = training_file.into();
1504 let res: reqwest::Response = self
1505 ._make_post_request(Self::OPENAI_API_FINE_TUNE_URL)
1506 .await?;
1507
1508 let handled_res = self.handle_api_errors(res).await?;
1509 let fine_tune_resp: FineTuneResponse = handled_res.json().await?;
1510 Ok(fine_tune_resp)
1511 }
1512
1513 /// List all fine-tunes.
1514 ///
1515 /// # Returns
1516 ///
1517 /// `Result<FineTuneListResponse, Box<dyn std::error::Error + Send + Sync>>`:
1518 /// A `FineTuneResponse` object representing the result of the list fine-tunes request,
1519 /// or an error if the request fails.
1520 pub async fn list(
1521 &mut self,
1522 ) -> Result<FineTuneListResponse, Box<dyn std::error::Error + Send + Sync>> {
1523 let res: reqwest::Response = self
1524 ._make_get_request(Self::OPENAI_API_FINE_TUNE_URL)
1525 .await?;
1526
1527 let handled_res = self.handle_api_errors(res).await?;
1528 let res: FineTuneListResponse = handled_res.json().await?;
1529 Ok(res)
1530 }
1531
1532 /// Get a specific fine-tune by its id
1533 ///
1534 /// # Arguments
1535 ///
1536 /// * `fine_tune_id` - A string that holds the unique id of the file.
1537 ///
1538 /// # Returns
1539 ///
1540 /// `Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>>`:
1541 /// A `FineTuneResponse` object representing the result of the get fine-tune request,
1542 /// or an error if the request fails.
1543 pub async fn retrieve<S: Into<String> + Send + Sync + std::fmt::Display>(
1544 &mut self,
1545 fine_tune_id: S,
1546 ) -> Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>> {
1547 let res: reqwest::Response = self
1548 ._make_get_request(format!(
1549 "{}/{}",
1550 Self::OPENAI_API_FINE_TUNE_URL,
1551 fine_tune_id
1552 ))
1553 .await?;
1554
1555 let handled_res = self.handle_api_errors(res).await?;
1556 let res: FineTuneResponse = handled_res.json().await?;
1557 Ok(res)
1558 }
1559
1560 /// Immediately cancel a fine-tune job.
1561 ///
1562 /// # Arguments
1563 ///
1564 /// * `fine_tune_id` - A string that holds the unique id of the file.
1565 ///
1566 /// # Returns
1567 ///
1568 /// `Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>>`:
1569 /// A `FineTuneResponse` object representing the result of the cancel fine-tune request,
1570 /// or an error if the request fails.
1571 pub async fn cancel<S: Into<String> + Send + Sync + std::fmt::Display>(
1572 &mut self,
1573 fine_tune_id: S,
1574 ) -> Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>> {
1575 let url = format!("{}/{}/cancel", Self::OPENAI_API_FINE_TUNE_URL, fine_tune_id);
1576 let res = self
1577 .client
1578 .post(url)
1579 .header("Content-Type", "application/json")
1580 .header("Authorization", format!("Bearer {}", self.api_key))
1581 .send()
1582 .await?;
1583
1584 let handled_res = self.handle_api_errors(res).await?;
1585 let res: FineTuneResponse = handled_res.json().await?;
1586 Ok(res)
1587 }
1588
1589 /// Get fine-grained status updates for a fine-tune job.
1590 ///
1591 /// # Arguments
1592 ///
1593 /// * `fine_tune_id` - A string that holds the unique id of the file.
1594 ///
1595 /// # Returns
1596 ///
1597 /// `Result<FineTuneEventResponse, Box<dyn std::error::Error + Send + Sync>>`:
1598 /// A `FineTuneEventResponse` object representing the result of the list fine-tunes request,
1599 /// or an error if the request fails.
1600 pub async fn list_events<S: Into<String> + Send + Sync + std::fmt::Display>(
1601 &mut self,
1602 fine_tune_id: S,
1603 ) -> Result<FineTuneEventResponse, Box<dyn std::error::Error + Send + Sync>> {
1604 let url = format!("{}/{}/events", Self::OPENAI_API_FINE_TUNE_URL, fine_tune_id);
1605 let res = self._make_get_request(url).await?;
1606
1607 let handled_res = self.handle_api_errors(res).await?;
1608 let res: FineTuneEventResponse = handled_res.json().await?;
1609 Ok(res)
1610 }
1611
1612 /// Delete a fine-tuned model. You must have the Owner role in your organization.
1613 ///
1614 /// # Arguments
1615 ///
1616 /// * `model` - The model to delete
1617 ///
1618 /// # Returns
1619 ///
1620 /// `Result<DeleteResponse, Box<dyn std::error::Error + Send + Sync>>`:
1621 /// A `DeleteResponse` object representing the status of the delete request,
1622 /// or an error if the request fails.
1623 pub async fn delete_model<S: Into<String> + Send + Sync + std::fmt::Display>(
1624 &mut self,
1625 model: S,
1626 ) -> Result<DeleteResponse, Box<dyn std::error::Error + Send + Sync>> {
1627 let url = format!("{}/{}", Self::OPENAI_API_MODELS_URL, model);
1628 let res = self._make_delete_request(url).await?;
1629
1630 let handled_res = self.handle_api_errors(res).await?;
1631 let res: DeleteResponse = handled_res.json().await?;
1632 Ok(res)
1633 }
1634}
1635
1636// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1637// = OpenAI MODERATIONS IMPLEMENTATION
1638// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1639
1640impl OpenAI<Moderation> {
1641 const OPENAI_API_MODERATIONS_URL: &str = "https://api.openai.com/v1/moderations";
1642
1643 /// Create moderation for a classification if text violates `OpenAI`'s Content Policy
1644 ///
1645 /// # Arguments
1646 ///
1647 /// * `input` - The text input to classify
1648 ///
1649 /// # Returns
1650 ///
1651 /// `Result<, Box<dyn std::error::Error + Send + Sync>>`:
1652 /// A `ModerationResponse` object representing the result of the moderation request,
1653 /// or an error if the request fails.
1654 pub async fn moderate<S: Into<String> + Send + Sync>(
1655 &mut self,
1656 input: S,
1657 ) -> Result<ModerationResponse, Box<dyn std::error::Error + Send + Sync>> {
1658 self.config.input = input.into();
1659 let res: reqwest::Response = self
1660 ._make_post_request(Self::OPENAI_API_MODERATIONS_URL)
1661 .await?;
1662
1663 let handled_res = self.handle_api_errors(res).await?;
1664 let mod_resp: ModerationResponse = handled_res.json().await?;
1665 Ok(mod_resp)
1666 }
1667}
1668
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): most of these tests call client methods that send requests to
    // the `OpenAI` API, so they presumably need network access and valid
    // credentials to pass; confirm how `OpenAI::new` obtains its API key.

    /// The model listing succeeds and contains `gpt-3.5-turbo`.
    #[tokio::test]
    async fn test_get_all_models() {
        let mut client = OpenAI::<Chat>::new();
        let listing = client.models().await;
        assert!(listing.is_ok());
        let names = listing.unwrap();
        assert!(names.contains(&"gpt-3.5-turbo".to_string()));
    }

    /// Checking a known model name succeeds.
    #[tokio::test]
    async fn test_check_model() {
        let mut client = OpenAI::<Chat>::new();
        let outcome = client.check_model("gpt-3.5-turbo").await;
        assert!(outcome.is_ok());
    }

    /// Checking an unknown model name fails.
    #[tokio::test]
    async fn test_check_model_error() {
        let mut client = OpenAI::<Chat>::new();
        let outcome = client.check_model("gpt-turbo").await;
        assert!(outcome.is_err());
    }

    /// A single question with streaming disabled yields the expected text.
    #[tokio::test]
    async fn test_single_request() {
        let mut client = OpenAI::<Chat>::new().set_stream_responses(false);
        let answer = client.ask("Say this is a test!", false).await;
        assert!(answer.is_ok());
        let text = answer.unwrap();
        assert!(text.contains("This is a test"));
    }

    /// A single question with the default stream setting yields the expected text.
    #[tokio::test]
    async fn test_single_request_streamed() {
        let mut client = OpenAI::<Chat>::new();
        let answer = client.ask("Say this is a test!", false).await;
        assert!(answer.is_ok());
        let text = answer.unwrap();
        assert!(text.contains("This is a test"));
    }

    /// Creating an image with the default settings returns exactly one entry.
    #[tokio::test]
    async fn test_create_single_image_url() {
        let mut client = OpenAI::<Image>::new();
        let created = client.create("A beautiful sunset over the sea.").await;
        assert!(created.is_ok());
        let entries = created.unwrap();
        assert_eq!(entries.len(), 1);
    }

    /// Requesting two images returns exactly two entries.
    #[tokio::test]
    async fn test_create_multiple_image_urls() {
        let mut client = OpenAI::<Image>::new().set_max_images(2);
        let created = client
            .create("A logo for a library written in Rust that deals with AI")
            .await;
        assert!(created.is_ok());
        let entries = created.unwrap();
        assert_eq!(entries.len(), 2);
    }

    /// Creating an image in base64 JSON format returns exactly one entry.
    #[tokio::test]
    async fn test_create_image_b64_json() {
        let mut client = OpenAI::<Image>::new().set_response_format(&ResponseDataType::Base64Json);
        let created = client.create("A beautiful sunset over the sea.").await;
        assert!(created.is_ok());
        let entries = created.unwrap();
        assert_eq!(entries.len(), 1);
    }

    /// Requesting a variation of the logo returns exactly one entry.
    #[tokio::test]
    async fn test_image_variation() {
        let mut client = OpenAI::<Image>::new();
        let varied = client.variation("./img/logo.png").await;
        assert!(varied.is_ok());
        let entries = varied.unwrap();
        assert_eq!(entries.len(), 1);
    }

    /// Editing the logo without a mask returns exactly one entry.
    #[tokio::test]
    async fn test_image_edit() {
        let mut client = OpenAI::<Image>::new();
        let edited = client
            .edit("Make the background transparent", "./img/logo.png", None)
            .await;
        assert!(edited.is_ok());
        let entries = edited.unwrap();
        assert_eq!(entries.len(), 1);
    }

    /// Embedding a short sentence returns a non-empty data list.
    #[tokio::test]
    async fn test_embedding() {
        let mut client = OpenAI::<Embedding>::new();
        let embedded = client
            .embed("The food was delicious and the waiter...")
            .await;
        assert!(embedded.is_ok());
        let response = embedded.unwrap();
        assert!(!response.data.is_empty());
    }

    /// Transcribing the sample recording succeeds.
    #[tokio::test]
    async fn test_transcribe() {
        let mut client = OpenAI::<Audio>::new();
        let outcome = client.transcribe("examples/samples/sample-1.mp3").await;
        assert!(outcome.is_ok());
    }

    /// Translating the German sample recording succeeds.
    #[tokio::test]
    async fn test_translate() {
        let mut client = OpenAI::<Audio>::new();
        let outcome = client
            .translate("examples/samples/colours-german.mp3")
            .await;
        assert!(outcome.is_ok());
    }

    /// Listing uploaded files succeeds.
    #[tokio::test]
    async fn test_list_files() {
        let listing = OpenAI::<Files>::new().list().await;
        assert!(listing.is_ok());
    }

    /// Deleting an unknown file id fails with the API's error message.
    #[tokio::test]
    async fn test_delete_non_existing_file() {
        let outcome = OpenAI::<Files>::new().delete("invalid_file_id").await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert_eq!(message, "No such File object: invalid_file_id");
    }

    /// Uploading a path that does not exist fails with the OS error message.
    #[tokio::test]
    async fn test_upload_non_existing_file() {
        let outcome = OpenAI::<Files>::new().upload("invalid_file").await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert_eq!(message, "No such file or directory (os error 2)");
    }

    /// Full file life cycle: upload, list, retrieve metadata and content, delete.
    #[tokio::test]
    async fn test_file_ops() {
        let test_file = "examples/samples/test.jsonl";
        let mut client = OpenAI::<Files>::new();

        // Upload file
        let uploaded = client.upload(test_file).await;
        assert!(uploaded.is_ok());
        let file_id = uploaded.unwrap().id;
        println!("{}", file_id);

        // Check if file exists online
        let listing = client.list().await;
        assert!(listing.is_ok());
        let listed = listing.unwrap();
        assert!(!listed.data.is_empty());

        // Fetch file metadata
        let metadata = client.retrieve(&file_id).await;
        assert!(metadata.is_ok());
        assert_eq!(metadata.unwrap().id, file_id);

        // Fetch file contents
        let contents = client.retrieve_content(&file_id).await;
        assert!(contents.is_ok());
        let records = contents.unwrap();
        assert_eq!(records.len(), 3);

        // Delete file
        // Wait for file to be uploaded for 5 seconds
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let deleted = client.delete(&file_id).await;
        assert!(deleted.is_ok());
        assert_eq!(deleted.unwrap().id, file_id);

        // Verify no files exist anymore
        let listing = client.list().await;
        assert!(listing.is_ok());
        let remaining = listing.unwrap();
        assert!(remaining.data.is_empty());
    }

    /// A violent sentence is flagged in the `violence` category.
    #[tokio::test]
    async fn test_moderation() {
        let verdict = OpenAI::<Moderation>::new()
            .moderate("I want to kill them.")
            .await;
        assert!(verdict.is_ok());
        let response = verdict.unwrap();
        assert!(response.results[0].categories.violence);
    }

    /// Listing fine-tunes succeeds.
    #[tokio::test]
    async fn test_list_fine_tunes() {
        let listing = OpenAI::<FineTune>::new().list().await;
        assert!(listing.is_ok());
    }
}