aionic/openai/
mod.rs

1pub mod audio;
2pub mod chat;
3pub mod embeddings;
4pub mod files;
5pub mod fine_tunes;
6pub mod image;
7mod misc;
8pub mod moderations;
9
10pub use audio::{Audio, Response as AudioResponse, ResponseFormat as AudioResponseFormat};
11
12pub use chat::{Chat, Message, MessageRole};
13use chat::{Response, StreamedReponse};
14pub use embeddings::{Embedding, InputType, Response as EmbeddingResponse};
15pub use files::Files;
16use files::{Data as FileData, DeleteResponse, PromptCompletion, Response as FileResponse};
17pub use fine_tunes::{
18    EventResponse as FineTuneEventResponse, FineTune, ListResponse as FineTuneListResponse,
19    Response as FineTuneResponse,
20};
21use image::Size;
22pub use image::{Image, Response as ImageResponse, ResponseDataType};
23use misc::ModelsResponse;
24pub use misc::{Model, OpenAIError, Usage};
25pub use moderations::{Moderation, Response as ModerationResponse};
26
27use reqwest::multipart::{Form, Part};
28use reqwest::{Body, Client, IntoUrl};
29use tokio_util::codec::{BytesCodec, FramedRead};
30
31use rustyline::error::ReadlineError;
32use rustyline::DefaultEditor;
33use serde::Serialize;
34use std::env;
35use std::error::Error;
36use std::fs;
37use std::io::{self, Write};
38use std::path::Path;
39use std::process::exit;
40
41// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
42// = OpenAIConfig TRAIT
43// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
44
/// Supplies the endpoint-specific default configuration used by `OpenAI<C>`.
///
/// NOTE(review): this trait's `default` intentionally shadows
/// `std::default::Default::default` so `OpenAI::new` can call `C::default()`
/// under a single bound; the `Send + Sync` supertraits let configurations be
/// shared across async tasks/threads.
pub trait OpenAIConfig: Send + Sync {
    /// Returns the default configuration for this API endpoint.
    fn default() -> Self;
}
48
49impl OpenAIConfig for Chat {
50    fn default() -> Self {
51        Self {
52            model: Self::get_default_model().into(),
53            messages: vec![],
54            functions: None,
55            function_call: None,
56            temperature: Some(Self::get_default_temperature()),
57            top_p: None,
58            n: None,
59            stream: Some(Self::get_default_stream()),
60            stop: None,
61            max_tokens: Some(Self::get_default_max_tokens()),
62            presence_penalty: None,
63            frequency_penalty: None,
64            logit_bias: None,
65            user: None,
66        }
67    }
68}
69
70impl OpenAIConfig for Image {
71    fn default() -> Self {
72        Self {
73            prompt: None,
74            n: Some(Self::get_default_n()),
75            size: Some(Self::get_default_size().into()),
76            response_format: Some(Self::get_default_response_format().into()),
77            user: None,
78            image: None,
79            mask: None,
80        }
81    }
82}
83
84impl OpenAIConfig for Embedding {
85    fn default() -> Self {
86        Self {
87            model: Self::get_default_model().into(),
88            input: InputType::SingleString(String::new()),
89            user: None,
90        }
91    }
92}
93
94impl OpenAIConfig for Audio {
95    fn default() -> Self {
96        Self {
97            file: String::new(),
98            model: Self::get_default_model().into(),
99            prompt: None,
100            response_format: Some(AudioResponseFormat::get_default_response_format()),
101            temperature: Some(0.0),
102            language: None,
103        }
104    }
105}
106
107impl OpenAIConfig for Files {
108    fn default() -> Self {
109        Self {
110            file: None,
111            purpose: None,
112            file_id: None,
113        }
114    }
115}
116
117impl OpenAIConfig for Moderation {
118    fn default() -> Self {
119        Self {
120            input: String::new(),
121        }
122    }
123}
124
125impl OpenAIConfig for FineTune {
126    fn default() -> Self {
127        Self {
128            training_file: String::new(),
129            validation_file: None,
130            model: Some(Self::get_default_model().into()),
131            n_epochs: Some(Self::get_default_n_epochs()),
132            batch_size: None,
133            learning_rate_multiplier: None,
134            prompt_loss_weight: Some(Self::get_default_prompt_loss_weight()),
135            compute_classification_metrics: Some(Self::get_default_compute_classification_metrics()),
136            classification_n_classes: None,
137            classification_positive_class: None,
138            classification_betas: None,
139            suffix: None,
140        }
141    }
142}
143
144// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
145// = OpenAI SHARED IMPLEMENTATION
146// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
147
/// The `OpenAI` struct is the main entry point for interacting with the `OpenAI` API.
/// It bundles the HTTP client, the API key, and an endpoint-specific
/// configuration `C` (e.g. `Chat`, `Image`, `Audio`) that selects which API
/// the client targets. A boolean flag allows disabling the live terminal
/// printing used by the chat endpoint.
#[derive(Clone, Debug)]
pub struct OpenAI<C: OpenAIConfig> {
    /// The HTTP client used to make requests to the `OpenAI` API.
    pub client: Client,

    /// The API key used to authenticate with the `OpenAI` API
    /// (read from `OPENAI_API_KEY` by `new`).
    pub api_key: String,

    /// A boolean flag to disable the live stream printing of the chat endpoint.
    pub disable_live_stream: bool,

    /// An endpoint specific configuration struct that holds all necessary parameters
    /// for the API call; it is serialized as the JSON body of POST requests.
    pub config: C,
}
167
impl<C: OpenAIConfig + Serialize + Sync + Send + std::fmt::Debug> Default for OpenAI<C> {
    /// Equivalent to [`OpenAI::new`]: reads `OPENAI_API_KEY` and terminates
    /// the process if the variable is unset.
    fn default() -> Self {
        Self::new()
    }
}
173
174impl<C: OpenAIConfig + Serialize + std::fmt::Debug> OpenAI<C> {
175    const OPENAI_API_MODELS_URL: &str = "https://api.openai.com/v1/models";
176    pub fn new() -> Self {
177        env::var("OPENAI_API_KEY").map_or_else(
178            |_| {
179                println!("OPENAI_API_KEY environment variable not set");
180                exit(1);
181            },
182            |api_key| {
183                let client = Client::new();
184                Self {
185                    client,
186                    api_key,
187                    disable_live_stream: false,
188                    config: C::default(),
189                }
190            },
191        )
192    }
193
    /// Allows to batch configure the client with the settings provided in the
    /// endpoint-specific configuration struct `C` (e.g. `Chat`, `Image`, …),
    /// replacing the current configuration wholesale.
    ///
    /// # Arguments
    ///
    /// * `config`: A configuration struct that contains the settings for this endpoint.
    ///
    /// # Returns
    ///
    /// This function returns the instance of the client with the new configuration.
    pub fn with_config(mut self, config: C) -> Self {
        self.config = config;
        self
    }
207
    /// Disables standard output for the instance of `OpenAi`, which is enabled by default.
    /// This is only interesting for the chat completion, as it will otherwise print the
    /// messages of the AI assistant to the terminal. Responses are still
    /// returned to the caller; only the live printing is suppressed.
    pub fn disable_stdout(mut self) -> Self {
        self.disable_live_stream = true;
        self
    }
215
216    pub fn is_valid_temperature(&mut self, temperature: f64, limit: f64) -> bool {
217        (0.0..=limit).contains(&temperature)
218    }
219
220    async fn _make_post_request<S: IntoUrl + Send + Sync>(
221        &mut self,
222        url: S,
223    ) -> Result<reqwest::Response, Box<dyn Error + Send + Sync>> {
224        let res = self
225            .client
226            .post(url)
227            .header("Content-Type", "application/json")
228            .header("Authorization", format!("Bearer {}", self.api_key))
229            .json(&self.config)
230            .send()
231            .await?;
232        Ok(res)
233    }
234
235    async fn _make_delete_request<S: IntoUrl + Send + Sync>(
236        &mut self,
237        url: S,
238    ) -> Result<reqwest::Response, Box<dyn Error + Send + Sync>> {
239        let res = self
240            .client
241            .delete(url)
242            .header("Authorization", format!("Bearer {}", self.api_key))
243            .send()
244            .await?;
245        Ok(res)
246    }
247
248    async fn _make_get_request<S: IntoUrl + Send + Sync>(
249        &mut self,
250        url: S,
251    ) -> Result<reqwest::Response, Box<dyn Error + Send + Sync>> {
252        let res = self
253            .client
254            .get(url)
255            .header("Content-Type", "application/json")
256            .header("Authorization", format!("Bearer {}", self.api_key))
257            .send()
258            .await?;
259        Ok(res)
260    }
261
262    async fn _make_form_request<S: IntoUrl + Send + Sync>(
263        &mut self,
264        url: S,
265        form: Form,
266    ) -> Result<reqwest::Response, Box<dyn Error + Send + Sync>> {
267        let res = self
268            .client
269            .post(url)
270            .header("Authorization", format!("Bearer {}", self.api_key))
271            .multipart(form)
272            .send()
273            .await?;
274        Ok(res)
275    }
276
277    /// Fetches a list of available models from the `OpenAI` API.
278    ///
279    /// This method sends a GET request to the `OpenAI` API and returns a vector of identifiers of
280    /// all available models.
281    ///
282    /// # Returns
283    ///
284    /// A `Result` which is:
285    /// * `Ok` if the request was successful, carrying a `Vec<String>` of model identifiers.
286    /// * `Err` if the request or the parsing failed, carrying the error of type `Box<dyn std::error::Error + Send + Sync>`.
287    ///
288    /// # Errors
289    ///
290    /// This method will return an error if the GET request fails, or if the response from the
291    /// `OpenAI` API cannot be parsed into a `ModelsResponse`.
292    ///
293    /// # Example
294    ///
295    /// ```rust
296    /// use aionic::openai::{OpenAI, Chat};
297    ///
298    ///
299    /// #[tokio::main]
300    /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
301    ///     let mut client = OpenAI::<Chat>::new();
302    ///     match client.models().await {
303    ///         Ok(models) => println!("Models: {:?}", models),
304    ///         Err(e) => println!("Error: {}", e),
305    ///     }
306    ///    Ok(())
307    /// }
308    /// ```
309    ///
310    /// # Note
311    ///
312    /// This method is `async` and needs to be awaited.
313    pub async fn models(
314        &mut self,
315    ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
316        let resp = self._make_get_request(Self::OPENAI_API_MODELS_URL).await?;
317
318        if !resp.status().is_success() {
319            return Err(Box::new(std::io::Error::new(
320                std::io::ErrorKind::Other,
321                format!("Error: {}", resp.status()),
322            )));
323        }
324
325        let data: ModelsResponse = resp.json().await?;
326        let model_ids: Vec<String> = data.data.into_iter().map(|model| model.id).collect();
327        Ok(model_ids)
328    }
329
330    /// Fetches a specific model by identifier from the `OpenAI` API.
331    ///
332    /// This method sends a GET request to the `OpenAI` API for a specific model and returns the `Model`.
333    ///
334    /// # Parameters
335    ///
336    /// * `model`: A `&str` that represents the name of the model to fetch.
337    ///
338    /// # Returns
339    ///
340    /// A `Result` which is:
341    /// * `Ok` if the request was successful, carrying the `Model`.
342    /// * `Err` if the request or the parsing failed, carrying the error of type `Box<dyn std::error::Error + Send + Sync>`.
343    ///
344    /// # Errors
345    ///
346    /// This method will return an error if the GET request fails, or if the response from the
347    /// `OpenAI` API cannot be parsed into a `Model`.
348    ///
349    /// # Example
350    ///
351    /// ```rust
352    /// use aionic::openai::{OpenAI, Chat};
353    ///
354    /// #[tokio::main]
355    /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
356    ///     let mut client = OpenAI::<Chat>::new();
357    ///     match client.check_model("gpt-3.5-turbo").await {
358    ///         Ok(model) => println!("Model: {:?}", model),
359    ///         Err(e) => println!("Error: {}", e),
360    ///     }
361    ///     Ok(())
362    /// }
363    /// ```
364    ///
365    /// # Note
366    ///
367    /// This method is `async` and needs to be awaited.
368    pub async fn check_model(
369        &mut self,
370        model: &str,
371    ) -> Result<Model, Box<dyn std::error::Error + Send + Sync>> {
372        let resp = self
373            ._make_get_request(format!("{}/{}", Self::OPENAI_API_MODELS_URL, model))
374            .await?;
375
376        if !resp.status().is_success() {
377            return Err(Box::new(std::io::Error::new(
378                std::io::ErrorKind::Other,
379                format!("Error: {}", resp.status()),
380            )));
381        }
382        let model: Model = resp.json().await?;
383        Ok(model)
384    }
385
386    /// Creates a file upload part for a multi-part upload operation.
387    ///
388    /// This method reads the file at the given path, prepares it for uploading, and
389    /// returns a `Part` that represents this file in the multi-part upload operation.
390    ///
391    /// # Type Parameters
392    ///
393    /// * `P`: The type of the file path. Must implement the `AsRef<Path>` trait.
394    ///
395    /// # Parameters
396    ///
397    /// * `path`: The path of the file to upload. This can be any type that implements `AsRef<Path>`.
398    ///
399    /// # Returns
400    ///
401    /// A `Result` which is:
402    /// * `Ok` if the file was read successfully and the `Part` was created, carrying the `Part`.
403    /// * `Err` if there was an error reading the file or creating the `Part`, carrying the error of type `Box<dyn Error + Send + Sync>`.
404    ///
405    /// # Errors
406    ///
407    /// This method will return an error if there was an error reading the file at the given path,
408    /// or if there was an error creating the `Part` (for example, if the MIME type was not recognized).
409    ///
410    /// # Example
411    ///
412    /// ```rust
413    /// use aionic::openai::{OpenAI, Chat};
414    ///
415    /// #[tokio::main]
416    /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
417    ///     let mut client = OpenAI::<Chat>::new();
418    ///     match client.create_file_upload_part("path/to/file.txt").await {
419    ///         Ok(part) => println!("Part created successfully."),
420    ///         Err(e) => println!("Error: {}", e),
421    ///     }
422    ///     Ok(())
423    /// }
424    /// ```
425    ///
426    /// # Note
427    ///
428    /// This method is `async` and needs to be awaited.
429    pub async fn create_file_upload_part<P: AsRef<Path> + Send>(
430        &mut self,
431        path: P,
432    ) -> Result<Part, Box<dyn Error + Send + Sync>> {
433        let file_name = path.as_ref().to_str().unwrap().to_string();
434        let streamed_body = self._get_streamed_body(path).await?;
435        let part_stream = Part::stream(streamed_body)
436            .file_name(file_name)
437            .mime_str("application/octet-stream")?;
438        Ok(part_stream)
439    }
440
441    async fn _get_streamed_body<P: AsRef<Path> + Send>(
442        &mut self,
443        path: P,
444    ) -> Result<Body, Box<dyn Error + Send + Sync>> {
445        if !path.as_ref().exists() {
446            return Err(Box::new(std::io::Error::new(
447                std::io::ErrorKind::Other,
448                "Image not found",
449            )));
450        }
451        let file_stream_body = tokio::fs::File::open(path).await?;
452        let stream = FramedRead::new(file_stream_body, BytesCodec::new());
453        let body = Body::wrap_stream(stream);
454        Ok(body)
455    }
456
457    /// A helper function to handle potential errors from `OpenAI` API responses.
458    ///
459    /// # Arguments
460    ///
461    /// * `res` - A `Response` object from the `OpenAI` API call.
462    ///
463    /// # Returns
464    ///
465    /// `Result<Response, Box<dyn std::error::Error + Send + Sync>>`:
466    /// Returns the original `Response` object if the status code indicates success.
467    /// If the status code indicates an error, it will attempt to deserialize the response
468    /// into an `OpenAIError` and returns a `std::io::Error` constructed from the error message.
469    pub async fn handle_api_errors(
470        &mut self,
471        res: reqwest::Response,
472    ) -> Result<reqwest::Response, Box<dyn std::error::Error + Send + Sync>> {
473        if res.status().is_success() {
474            Ok(res)
475        } else {
476            let err_resp: OpenAIError = res.json().await?;
477            Err(Box::new(std::io::Error::new(
478                std::io::ErrorKind::Other,
479                err_resp.error.message,
480            )))
481        }
482    }
483}
484
485// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
486// = OpenAI CHAT IMPLEMENTATION
487// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
488
489impl OpenAI<Chat> {
490    const OPENAI_API_COMPLETIONS_URL: &str = "https://api.openai.com/v1/chat/completions";
491
    /// Sets the model of the AI assistant.
    ///
    /// # Arguments
    ///
    /// * `model`: A string that specifies the model name to be used by the AI assistant
    /// (e.g. `"gpt-3.5-turbo"`; validity can be checked via `check_model`).
    ///
    /// # Returns
    ///
    /// This function returns the instance of the AI assistant with the specified model.
    pub fn set_model<S: Into<String>>(mut self, model: S) -> Self {
        self.config.model = model.into();
        self
    }
505
    /// Sets the maximum number of tokens that the AI model can generate in a single response.
    ///
    /// # Arguments
    ///
    /// * `max_tokens`: An unsigned 64-bit integer that specifies the maximum number of tokens
    /// that the AI model can generate in a single response.
    ///
    /// # Returns
    ///
    /// This function returns the instance of the AI assistant with the specified maximum number of tokens.
    pub fn set_max_tokens(mut self, max_tokens: u64) -> Self {
        self.config.max_tokens = Some(max_tokens);
        self
    }
520
    /// Allows to set the chat history in a specific state.
    /// This replaces the entire existing message history, including any primer.
    ///
    /// # Arguments
    ///
    /// * `messages`: A vector of `Message` structs.
    ///
    /// # Returns
    ///
    /// This function returns the instance of the AI assistant with the specified messages.
    pub fn set_messages(mut self, messages: Vec<Message>) -> Self {
        self.config.messages = messages;
        self
    }
534
    /// Sets the temperature of the AI model's responses.
    ///
    /// The temperature setting adjusts the randomness of the AI's responses.
    /// Higher values produce more random responses, while lower values produce more deterministic responses.
    /// The allowed range of values is between 0.0 and 2.0, with 0.0 being the most
    /// deterministic and 2.0 being the most random. Out-of-range values are
    /// corrected at request time by `ask`.
    ///
    /// # Arguments
    ///
    /// * `temperature`: A float that specifies the temperature.
    ///
    /// # Returns
    ///
    /// This function returns the instance of the AI assistant with the specified temperature.
    pub fn set_temperature(mut self, temperature: f64) -> Self {
        self.config.temperature = Some(temperature);
        self
    }
552
    /// Sets the streaming configuration of the AI assistant.
    ///
    /// If streaming is enabled, the AI assistant will fetch and process the AI's responses as they arrive.
    /// If it's disabled, the assistant will collect all of the AI's responses at once and return them as a single response.
    ///
    /// # Arguments
    ///
    /// * `streamed`: A boolean that specifies whether streaming should be enabled.
    ///
    /// # Returns
    ///
    /// This function returns the instance of the AI assistant with the specified streaming setting.
    pub fn set_stream_responses(mut self, streamed: bool) -> Self {
        self.config.stream = Some(streamed);
        self
    }
569
570    /// Sets a primer message for the AI assistant.
571    ///
572    /// The primer message is inserted at the beginning of the `messages` vector in the `config` struct.
573    /// This can be used to prime the AI model with a certain context or instruction.
574    ///
575    /// # Arguments
576    ///
577    /// * `primer_msg`: A string that specifies the primer message.
578    ///
579    /// # Returns
580    ///
581    /// This function returns the instance of the AI assistant with the specified primer message.
582    pub fn set_primer<S: Into<String>>(mut self, primer_msg: S) -> Self {
583        let msg = Message::new(&MessageRole::System, primer_msg.into());
584        self.config.messages.insert(0, msg);
585        self
586    }
587
    /// Returns the last message in the AI assistant's configuration
    /// (typically the most recent assistant reply when state is persisted).
    ///
    /// # Returns
    ///
    /// This function returns an `Option` that contains a reference to the last `Message`
    /// in the `config` struct if it exists, or `None` if the history is empty.
    pub fn get_last_message(&self) -> Option<&Message> {
        self.config.messages.last()
    }
597
    /// Clears the messages in the AI assistant's configuration to start from a clean state.
    /// This is only necessary in very specific cases. Note that this also
    /// removes any primer message set via `set_primer`.
    ///
    /// # Returns
    ///
    /// This function returns the instance of the AI assistant with no messages in its configuration.
    pub fn clear_state(mut self) -> Self {
        self.config.messages.clear();
        self
    }
608
609    fn _process_delta(
610        &self,
611        line: &str,
612        answer_text: &mut Vec<String>,
613    ) -> Result<(), Box<dyn Error + Send + Sync>> {
614        line.strip_prefix("data: ").map_or(Ok(()), |chunk| {
615            if chunk.starts_with("[DONE]") {
616                return Ok(());
617            }
618            let serde_chunk: Result<StreamedReponse, _> = serde_json::from_str(chunk);
619            match serde_chunk {
620                Ok(chunk) => {
621                    for choice in chunk.choices {
622                        if let Some(content) = choice.delta.content {
623                            let sanitized_content =
624                                content.trim().strip_suffix('\n').unwrap_or(&content);
625                            if !self.disable_live_stream {
626                                print!("{}", sanitized_content);
627                                io::stdout().flush()?;
628                            }
629                            answer_text.push(sanitized_content.to_string());
630                        }
631                    }
632                    Ok(())
633                }
634                Err(_) => Err(Box::new(std::io::Error::new(
635                    std::io::ErrorKind::Other,
636                    "Deserialization Error",
637                ))),
638            }
639        })
640    }
641
642    async fn _ask_openai_streamed(
643        &mut self,
644        res: &mut reqwest::Response,
645        answer_text: &mut Vec<String>,
646    ) -> Result<(), Box<dyn Error + Send + Sync>> {
647        print!("AI: ");
648        loop {
649            let chunk = match res.chunk().await {
650                Ok(Some(chunk)) => chunk,
651                Ok(None) => break,
652                Err(e) => return Err(Box::new(e)),
653            };
654            let chunk_str = String::from_utf8_lossy(&chunk);
655            let lines: Vec<&str> = chunk_str.split('\n').collect();
656            for line in lines {
657                self._process_delta(line, answer_text)?;
658            }
659        }
660        println!();
661        Ok(())
662    }
663
664    /// Makes a request to `OpenAI`'s GPT model and retrieves a response based on the provided `prompt`.
665    ///
666    /// This function accepts a prompt, converts it into a string, and sends a request to the `OpenAI` API.
667    /// Depending on the streaming configuration (`is_streamed`), the function either collects all of the AI's responses
668    /// at once, or fetches and processes them as they arrive.
669    ///
670    /// # Arguments
671    ///
672    /// * `prompt`: A value that implements `Into<String>`. This will be converted into a string and sent to the API as the
673    /// prompt for the AI model.
674    ///
675    /// * `persist_state`: If true, the function will push the AI's response to the `messages` vector in the `config` struct.
676    /// If false, it will remove the last message from the `messages` vector.
677    ///
678    /// # Returns
679    ///
680    /// * `Ok(String)`: A success value containing the AI's response as a string.
681    ///
682    /// * `Err(Box<dyn std::error::Error + Send + Sync>)`: An error value. This is a dynamic error, meaning it could represent
683    /// various kinds of failures. The function will return an error if any step in the process fails, such as making the HTTP request,
684    /// parsing the JSON response, or if there's an issue with the streaming process.
685    ///
686    /// # Errors
687    ///
688    /// This function will return an error if the HTTP request fails, the JSON response from the API cannot be parsed, or if
689    /// an error occurs during streaming.
690    ///
691    /// # Examples
692    ///
693    /// ```rust
694    ///  
695    /// use aionic::openai::chat::Chat;
696    /// use aionic::openai::OpenAI;
697    ///
698    /// #[tokio::main]
699    /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
700    ///     let prompt = "Hello, world!";
701    ///     let mut client = OpenAI::<Chat>::new();
702    ///     let result = client.ask(prompt, true).await;
703    ///     match result {
704    ///         Ok(response) => println!("{}", response),
705    ///         Err(e) => println!("Error: {}", e),
706    ///     }
707    ///     Ok(())
708    ///  }
709    /// ```
710    ///
711    /// # Note
712    ///
713    /// This function is `async` and must be awaited when called.
714    pub async fn ask<P: Into<Message> + Send>(
715        &mut self,
716        prompt: P,
717        persist_state: bool,
718    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
719        let mut answer_chunks: Vec<String> = Vec::new();
720        let is_streamed = self.config.stream.unwrap_or(false);
721        self.config.messages.push(prompt.into());
722        if let Some(temp) = self.config.temperature {
723            // TODO: Add a log warning
724            if !self.is_valid_temperature(temp, 2.0) {
725                self.config.temperature = Some(2.0);
726            }
727        }
728        let mut r = self
729            ._make_post_request(Self::OPENAI_API_COMPLETIONS_URL)
730            .await?;
731        if is_streamed {
732            self._ask_openai_streamed(&mut r, &mut answer_chunks)
733                .await?;
734        } else {
735            let r = r.json::<Response>().await?;
736            if let Some(choices) = r.choices {
737                for choice in choices {
738                    if !self.disable_live_stream {
739                        print!("AI: {}\n", choice.message.content);
740                        io::stdout().flush()?;
741                    }
742                    answer_chunks.push(choice.message.content);
743                }
744            }
745        }
746
747        let answer_text = answer_chunks.join("");
748        if persist_state {
749            self.config
750                .messages
751                .push(Message::new(&MessageRole::Assistant, &answer_text));
752        } else {
753            self.config.messages.pop();
754        }
755        Ok(answer_text)
756    }
757
758    /// Starts a chat session with the AI assistant.
759    ///
760    /// This function uses a Readline-style interface for input and output. The user types a message at the `>>> ` prompt,
761    /// and the message is sent to the AI assistant using the `ask` function. The AI's response is then printed to the console.
762    ///
763    /// If the user enters CTRL-C, the function prints "CTRL-C" and exits the chat session.
764    ///
765    /// If the user enters CTRL-D, the function prints "CTRL-D" and exits the chat session.
766    ///
767    /// If there's an error during readline, the function prints the error message and exits the chat session.
768    ///
769    /// # Returns
770    ///
771    /// * `Ok(())`: A success value indicating that the chat session ended normally.
772    ///
773    /// * `Err(Box<dyn std::error::Error + Send + Sync>)`: An error value. This is a dynamic error, meaning it could represent
774    /// various kinds of failures. The function will return an error if any step in the process fails, such as reading a line
775    /// from the console, or if there's an error in the `ask` function.
776    ///
777    /// # Errors
778    ///
779    /// This function will return an error if the readline fails or if there's an error in the `ask` function.
780    ///
781    /// # Examples
782    ///
783    /// ```rust
784    /// use aionic::openai::chat::Chat;
785    /// use aionic::openai::OpenAI;
786    ///
787    /// #[tokio::main]
788    /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
789    ///     let mut client = OpenAI::<Chat>::new();
790    ///     let result = client.chat().await;
791    ///     match result {
792    ///         Ok(()) => println!("Chat session ended."),
793    ///         Err(e) => println!("Error during chat session: {}", e),
794    ///     }
795    ///     Ok(())
796    /// }
797    /// ```
798    ///
799    /// # Note
800    ///
801    /// This function is `async` and must be awaited when called.
802    pub async fn chat(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
803        let mut rl = DefaultEditor::new()?;
804        let prompt = ">>> ";
805        loop {
806            let readline = rl.readline(prompt);
807            match readline {
808                Ok(line) => {
809                    self.ask(line, true).await?;
810                    println!();
811                }
812                Err(ReadlineError::Interrupted) => {
813                    println!("CTRL-C");
814                    break;
815                }
816                Err(ReadlineError::Eof) => {
817                    println!("CTRL-D");
818                    break;
819                }
820                Err(err) => {
821                    println!("Error: {:?}", err);
822                    break;
823                }
824            }
825        }
826        Ok(())
827    }
828}
829
830// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
831// = OpenAI IMAGE IMPLEMENTATION
832// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
833
834impl OpenAI<Image> {
835    const OPENAI_API_IMAGE_GEN_URL: &str = "https://api.openai.com/v1/images/generations";
836    const OPENAI_API_IMAGE_EDIT_URL: &str = "https://api.openai.com/v1/images/edits";
837    const OPENAI_API_IMAGE_VARIATION_URL: &str = "https://api.openai.com/v1/images/variations";
838
    /// Allows setting the return format of the response. `ResponseDataType` is an enum with the
    /// following variants:
    /// * `Url`: The response will be a vector of URLs to the generated images.
    /// * `Base64Json`: The response will be a vector of base64 encoded images.
    ///
    /// The variant is stored as its string form, as expected by the API.
    pub fn set_response_format(mut self, response_format: &ResponseDataType) -> Self {
        self.config.response_format = Some(response_format.to_string());
        self
    }
847
848    /// Allows setting the number of images to be generated.
849    pub fn set_max_images(mut self, number_of_images: u64) -> Self {
850        self.config.n = Some(number_of_images);
851        self
852    }
853
854    /// Allows setting the dimensions of the generated images.
855    pub fn set_size(mut self, size: &Size) -> Self {
856        self.config.size = Some(size.to_string());
857        self
858    }
859
860    /// Generates an image based on a textual description.
861    ///
862    /// This function sets the prompt to the given string and sends a request to the `OpenAI` API to create an image.
863    /// The function then parses the response and returns a vector of image URLs.
864    ///
865    /// # Arguments
866    ///
867    /// * `prompt`: A string that describes the image to be generated.
868    ///
869    /// # Returns
870    ///
871    /// This function returns a `Result` with a vector of strings on success, each string being a URL to an image.
872    /// If there's an error, it returns a dynamic error.
873    pub async fn create<S: Into<String> + Send>(
874        &mut self,
875        prompt: S,
876    ) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
877        self.config.prompt = Some(prompt.into());
878        if self.config.image.is_some() {
879            self.config.image = None;
880        }
881        if self.config.mask.is_some() {
882            self.config.mask = None;
883        }
884        let res: reqwest::Response = self
885            ._make_post_request(Self::OPENAI_API_IMAGE_GEN_URL)
886            .await?;
887        let handle_res = self.handle_api_errors(res).await?;
888        let image_response: ImageResponse = handle_res.json().await?;
889
890        Ok(self._parse_response(&image_response))
891    }
892
893    /// Modifies an existing image based on a textual description.
894    ///
895    /// This function sets the image and optionally the mask, then sets the prompt to the given string and sends a request to the `OpenAI` API to modify the image.
896    /// The function then parses the response and returns a vector of image URLs.
897    ///
898    /// # Arguments
899    ///
900    /// * `prompt`: A string that describes the modifications to be made to the image.
901    /// * `image_file_path`: A string that specifies the path to the image file to be modified.
902    /// * `mask`: An optional string that specifies the path to a mask file. If the mask is not provided, it is set to `None`.
903    ///
904    /// # Returns
905    ///
906    /// This function returns a `Result` with a vector of strings on success, each string being a URL to an image.
907    /// If there's an error, it returns a dynamic error.
908    pub async fn edit<S: Into<String> + Send>(
909        &mut self,
910        prompt: S,
911        image_file_path: S,
912        mask: Option<S>,
913    ) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
914        self.config.image = Some(image_file_path.into());
915        if let Some(mask) = mask {
916            self.config.mask = Some(mask.into());
917        }
918        self.config.prompt = Some(prompt.into());
919
920        if let Some(n) = self.config.n {
921            // TODO: Add a warning here
922            if !image::Image::is_valid_n(n) {
923                self.config.n = Some(image::Image::get_default_n());
924            }
925        }
926
927        if let Some(size) = self.config.size.as_ref() {
928            // TODO: Add a warning here
929            if !image::Image::is_valid_size(size) {
930                self.config.size = Some(image::Image::get_default_size().into());
931            }
932        }
933
934        if let Some(response_format) = self.config.response_format.as_ref() {
935            // TODO: Add a warning here
936            if !image::Image::is_valid_response_format(response_format) {
937                self.config.response_format =
938                    Some(image::Image::get_default_response_format().into());
939            }
940        }
941
942        let image_response: ImageResponse = self
943            ._make_file_upload_request(Self::OPENAI_API_IMAGE_EDIT_URL)
944            .await?;
945        Ok(self._parse_response(&image_response))
946    }
947
948    /// Generates variations of an existing image.
949    ///
950    /// This function sets the image and sends a request to the `OpenAI` API to create variations of the image.
951    /// The function then parses the response and returns a vector of image URLs.
952    ///
953    /// # Arguments
954    ///
955    /// * `image_file_path`: A string that specifies the path to the image file.
956    ///
957    /// # Returns
958    ///
959    /// This function returns a `Result` with a vector of strings on success, each string being a URL to a new variation of the image.
960    /// If there's an error, it returns a dynamic error.
961    pub async fn variation<S: Into<String> + Send>(
962        &mut self,
963        image_file_path: S,
964    ) -> Result<Vec<String>, Box<dyn Error + Send + Sync>> {
965        self.config.image = Some(image_file_path.into());
966        if self.config.prompt.is_some() {
967            self.config.prompt = None;
968        }
969        if self.config.mask.is_some() {
970            self.config.mask = None;
971        }
972        let image_response: ImageResponse = self
973            ._make_file_upload_request(Self::OPENAI_API_IMAGE_VARIATION_URL)
974            .await?;
975
976        Ok(self._parse_response(&image_response))
977    }
978
979    fn _parse_response(&mut self, image_response: &ImageResponse) -> Vec<String> {
980        image_response
981            .data
982            .iter()
983            .filter_map(|d| {
984                if self.config.response_format == Some("url".into()) {
985                    d.url.clone()
986                } else {
987                    d.b64_json.clone()
988                }
989            })
990            .collect::<Vec<String>>()
991    }
992
993    async fn _make_file_upload_request<S: IntoUrl + Send + Sync>(
994        &mut self,
995        url: S,
996    ) -> Result<ImageResponse, Box<dyn Error + Send + Sync>> {
997        let file_name = self.config.image.as_ref().unwrap();
998        let file_part_stream = self.create_file_upload_part(file_name.to_string()).await?;
999        let mut form = Form::new().part("image", file_part_stream);
1000
1001        if let Some(prompt) = self.config.prompt.as_ref() {
1002            form = form.text("prompt", prompt.clone());
1003        }
1004        if let Some(mask_name) = self.config.mask.as_ref() {
1005            let mask_part_stream = self.create_file_upload_part(mask_name.to_string()).await?;
1006            form = form.part("mask", mask_part_stream);
1007        }
1008
1009        if let Some(response_format) = self.config.response_format.as_ref() {
1010            form = form.text("response_format", response_format.clone());
1011        }
1012
1013        if let Some(size) = self.config.size.as_ref() {
1014            form = form.text("size", size.clone());
1015        }
1016
1017        if let Some(n) = self.config.n {
1018            form = form.text("n", n.to_string());
1019        }
1020
1021        if let Some(user) = self.config.user.as_ref() {
1022            form = form.text("user", user.clone());
1023        }
1024
1025        let res: reqwest::Response = self._make_form_request(url, form).await?;
1026        let handle_res = self.handle_api_errors(res).await?;
1027        let image_response: ImageResponse = handle_res.json().await?;
1028
1029        Ok(image_response)
1030    }
1031}
1032
1033// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1034// = OpenAI EMBEDDINGS IMPLEMENTATION
1035// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1036
1037impl OpenAI<Embedding> {
1038    const OPENAI_API_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
1039
1040    /// Sets the model of the AI assistant.
1041    ///
1042    /// # Arguments
1043    ///
1044    /// * `model`: A string that specifies the model name to be used by the AI assistant.
1045    ///
1046    /// # Returns
1047    ///
1048    /// This function returns the instance of the AI assistant with the specified model.
1049    pub fn set_model<S: Into<String>>(mut self, model: S) -> Self {
1050        self.config.model = model.into();
1051        self
1052    }
1053
1054    /// Sends a POST request to the `OpenAI` API to get embeddings for the given prompt.
1055    ///
1056    /// This method accepts a prompt of type `S` which can be converted into `InputType`
1057    /// (an enum that encapsulates the different types of possible inputs). The method converts
1058    /// the provided prompt into `InputType` and assigns it to the `input` field of the `config`
1059    /// instance variable. It then sends a POST request to the `OpenAI` API and attempts to parse
1060    /// the response as `EmbeddingResponse`.
1061    ///
1062    /// # Type Parameters
1063    ///
1064    /// * `S`: The type of the prompt. Must implement the `Into<InputType>` trait.
1065    ///
1066    /// # Parameters
1067    ///
1068    /// * `prompt`: The prompt for which to get embeddings. Can be a `String`, a `Vec<String>`,
1069    /// a `Vec<u64>`, or a `&str` that is converted into an `InputType`.
1070    ///
1071    /// # Returns
1072    ///
1073    /// A `Result` which is:
1074    /// * `Ok` if the request was successful, carrying the `EmbeddingResponse` which contains the embeddings.
1075    /// * `Err` if the request or the parsing failed, carrying the error of type `Box<dyn std::error::Error + Send + Sync>`.
1076    ///
1077    /// # Errors
1078    ///
1079    /// This method will return an error if the POST request fails, or if the response from the
1080    /// `OpenAI` API cannot be parsed into an `EmbeddingResponse`.
1081    ///
1082    /// # Example
1083    ///
1084    /// ```rust
1085    /// use aionic::openai::embeddings::Embedding;
1086    /// use aionic::openai::OpenAI;
1087    ///
1088    /// #[tokio::main]
1089    /// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1090    ///     let mut client = OpenAI::<Embedding>::new();
1091    ///     let prompt = "Hello, world!";
1092    ///     match client.embed(prompt).await {
1093    ///         Ok(response) => println!("Embeddings: {:?}", response),
1094    ///         Err(e) => println!("Error: {}", e),
1095    ///     }
1096    ///     Ok(())
1097    /// }
1098    /// ```
1099    ///
1100    /// # Note
1101    ///
1102    /// This method is `async` and needs to be awaited.
1103    pub async fn embed<S: Into<InputType> + Send>(
1104        &mut self,
1105        prompt: S,
1106    ) -> Result<EmbeddingResponse, Box<dyn std::error::Error + Send + Sync>> {
1107        self.config.input = prompt.into();
1108        let res: reqwest::Response = self
1109            ._make_post_request(Self::OPENAI_API_EMBEDDINGS_URL)
1110            .await?;
1111        let handled_res = self.handle_api_errors(res).await?;
1112        let embedding: EmbeddingResponse = handled_res.json().await?;
1113        Ok(embedding)
1114    }
1115}
1116
1117// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1118// = OpenAI AUDIO IMPLEMENTATION
1119// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1120
1121impl OpenAI<Audio> {
1122    const OPENAI_API_TRANSCRIPTION_URL: &str = "https://api.openai.com/v1/audio/transcriptions";
1123    const OPENAI_API_TRANSLATION_URL: &str = "https://api.openai.com/v1/audio/translations";
1124
1125    /// Sets the model of the AI assistant.
1126    ///
1127    /// # Arguments
1128    ///
1129    /// * `model`: A string that specifies the model name to be used by the AI assistant.
1130    ///
1131    /// # Returns
1132    ///
1133    /// This function returns the instance of the AI assistant with the specified model.
1134    pub fn set_model<S: Into<String>>(mut self, model: S) -> Self {
1135        self.config.model = model.into();
1136        self
1137    }
1138
1139    /// Sets the optional prompt to giode the model's style of response.
1140    ///
1141    /// # Arguments
1142    ///
1143    /// * `prompt`: An optional string that specifies the prompt to guide the model's style of response.
1144    ///
1145    /// # Returns
1146    ///
1147    /// This function returns the instance of the AI assistant with the specified prompt
1148    pub fn set_prompt<S: Into<String>>(mut self, prompt: S) -> Self {
1149        self.config.prompt = Some(prompt.into());
1150        self
1151    }
1152
1153    /// Sets the required audio file to be transcribed or translated.
1154    ///
1155    /// # Arguments
1156    ///
1157    /// * `file`: A string that specifies the path to the audio file to be transcribed or translated.
1158    /// The path must be a valid path to a file.
1159    ///
1160    /// # Returns
1161    ///
1162    /// This function returns the instance of the AI assistant with the specified audio file.
1163    fn _set_file<P: AsRef<Path> + Send + Sync>(
1164        &mut self,
1165        file: P,
1166    ) -> Result<&mut Self, Box<dyn std::error::Error + Send + Sync>> {
1167        let path = file.as_ref();
1168        if fs::metadata(path)?.is_file() {
1169            let path_str = path.to_str().ok_or("Path is not valid UTF-8")?;
1170            self.config.file = path_str.to_string();
1171            if self._is_valid_mime_time().is_err() {
1172                return Err(Box::new(std::io::Error::new(
1173                    std::io::ErrorKind::InvalidInput,
1174                    format!(
1175                        "Invalid audio file type. Supported types are {:?}",
1176                        Audio::get_supported_file_types()
1177                    ),
1178                )));
1179            }
1180            Ok(self)
1181        } else {
1182            Err(Box::new(std::io::Error::new(
1183                std::io::ErrorKind::InvalidInput,
1184                format!("Path is not a file: {}", path.display()),
1185            )))
1186        }
1187    }
1188
1189    /// Sets the optional audio file format to be returned
1190    ///
1191    /// # Arguments
1192    ///
1193    /// * `format`: An optional enum type that specifies the audio file format to be returned.
1194    /// The default is `AudioResponseFormat::Json`..
1195    ///
1196    /// # Returns
1197    ///
1198    /// This function returns the instance of the AI assistant with the specified audio file format.
1199    pub fn set_response_format(&mut self, format: AudioResponseFormat) -> &mut Self {
1200        self.config.response_format = Some(format);
1201        self
1202    }
1203
1204    fn _is_valid_mime_time(&mut self) -> Result<bool, String> {
1205        Audio::is_file_type_supported(&self.config.file)
1206    }
1207
1208    fn _is_valid_model(&mut self) -> bool {
1209        self.config.model == "whisper-1"
1210    }
1211
1212    fn _sanity_checks(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1213        if let Some(temp) = self.config.temperature {
1214            if !self.is_valid_temperature(temp, 1.0) {
1215                // TODO: Log warning
1216                self.config.temperature = Some(1.0);
1217            }
1218        }
1219
1220        if !self._is_valid_model() {
1221            return Err(Box::new(std::io::Error::new(
1222                std::io::ErrorKind::InvalidInput,
1223                format!(
1224                    "Invalid model. Supported models are {:?}",
1225                    Audio::get_supported_models()
1226                ),
1227            )));
1228        }
1229
1230        if let Some(lang) = &self.config.language {
1231            if !Audio::is_valid_language(lang) {
1232                return Err(Box::new(std::io::Error::new(
1233                    std::io::ErrorKind::InvalidInput,
1234                    format!(
1235                        "Invalid language code. Supported language codes are {:?}",
1236                        Audio::ISO_639_1_CODES
1237                    ),
1238                )));
1239            }
1240        }
1241        Ok(())
1242    }
1243
1244    async fn _form_builder(&mut self) -> Result<Form, Box<dyn std::error::Error + Send + Sync>> {
1245        let file_part_stream = self
1246            .create_file_upload_part(self.config.file.clone())
1247            .await?;
1248        let mut form = Form::new().part("file", file_part_stream);
1249        form = form.text("model", self.config.model.clone());
1250
1251        if let Some(prompt) = self.config.prompt.as_ref() {
1252            form = form.text("prompt", prompt.clone());
1253        }
1254
1255        if let Some(response_format) = self.config.response_format.as_ref() {
1256            form = form.text("response_format", response_format.to_string());
1257        }
1258
1259        if let Some(temp) = self.config.temperature {
1260            form = form.text("temperature", temp.to_string());
1261        }
1262        Ok(form)
1263    }
1264
1265    /// Transcribe an audio file.
1266    ///
1267    /// # Arguments
1268    ///
1269    /// * `audio_file` - The path to the audio file to transcribe.
1270    ///
1271    /// # Returns
1272    ///
1273    /// `Result<AudioResponse, Box<dyn std::error::Error + Send + Sync>>`:
1274    /// An `AudioResponse` object representing the transcription of the audio file,
1275    /// or an error if the request fails.
1276    pub async fn transcribe<P: AsRef<Path> + Sync + Send>(
1277        &mut self,
1278        audio_file: P,
1279    ) -> Result<AudioResponse, Box<dyn std::error::Error + Send + Sync>> {
1280        self._set_file(audio_file)?;
1281        self._sanity_checks()?;
1282        let mut form = self._form_builder().await?;
1283
1284        if let Some(lang) = self.config.language.clone() {
1285            form = form.text("language", lang);
1286        }
1287
1288        let res: reqwest::Response = self
1289            ._make_form_request(Self::OPENAI_API_TRANSCRIPTION_URL, form)
1290            .await?;
1291
1292        let handled_res = self.handle_api_errors(res).await?;
1293        let transcription: AudioResponse = handled_res.json().await?;
1294        Ok(transcription)
1295    }
1296
1297    /// Translate an audio file. Currently only supports translating
1298    /// to English.
1299    ///
1300    /// # Arguments
1301    ///
1302    /// * `audio_file` - The path to the audio file to translate.
1303    ///
1304    /// # Returns
1305    ///
1306    /// `Result<AudioResponse, Box<dyn std::error::Error + Send + Sync>>`:
1307    /// An `AudioResponse` object representing the translation of the audio file,
1308    /// or an error if the request fails.
1309    pub async fn translate<P: AsRef<Path> + Send + Sync>(
1310        &mut self,
1311        audio_file: P,
1312    ) -> Result<AudioResponse, Box<dyn std::error::Error + Send + Sync>> {
1313        self._set_file(audio_file)?;
1314        self._sanity_checks()?;
1315        if self.config.language.is_some() {
1316            self.config.language = None;
1317        }
1318        let form = self._form_builder().await?;
1319        let res: reqwest::Response = self
1320            ._make_form_request(Self::OPENAI_API_TRANSLATION_URL, form)
1321            .await?;
1322        let handled_res = self.handle_api_errors(res).await?;
1323        let translation: AudioResponse = handled_res.json().await?;
1324        Ok(translation)
1325    }
1326}
1327
1328// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1329// = OpenAI FILES IMPLEMENTATION
1330// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1331
1332impl OpenAI<Files> {
1333    const OPENAI_API_LIST_FILES_URL: &str = "https://api.openai.com/v1/files";
1334
1335    /// List all files that have been uploaded.
1336    ///
1337    /// # Returns
1338    ///
1339    /// `Result<FileResponse, Box<dyn std::error::Error + Send + Sync>>`:
1340    /// A `FileResponse` object representing all uploaded files,
1341    /// or an error if the request fails.
1342    pub async fn list(&mut self) -> Result<FileResponse, Box<dyn std::error::Error + Send + Sync>> {
1343        let res: reqwest::Response = self
1344            ._make_get_request(Self::OPENAI_API_LIST_FILES_URL)
1345            .await?;
1346        let handled_res = self.handle_api_errors(res).await?;
1347        let files: FileResponse = handled_res.json().await?;
1348        Ok(files)
1349    }
1350
1351    /// Retrieve the details of a specific file.
1352    ///
1353    /// # Arguments
1354    ///
1355    /// * `file_id` - A string that holds the unique id of the file.
1356    ///
1357    /// # Returns
1358    ///
1359    /// `Result<FileData, Box<dyn std::error::Error + Send + Sync>>`:
1360    /// A `FileData` object representing the file's details,
1361    /// or an error if the request fails.
1362    pub async fn retrieve<S: Into<String> + std::fmt::Display + Sync + Send>(
1363        &mut self,
1364        file_id: S,
1365    ) -> Result<FileData, Box<dyn std::error::Error + Send + Sync>> {
1366        let res: reqwest::Response = self
1367            ._make_get_request(format!("{}/{}", Self::OPENAI_API_LIST_FILES_URL, file_id))
1368            .await?;
1369
1370        let handled_res = self.handle_api_errors(res).await?;
1371        let file: FileData = handled_res.json().await?;
1372        Ok(file)
1373    }
1374
1375    /// Retrieve the content of a specific file.
1376    ///
1377    /// # Arguments
1378    ///
1379    /// * `file_id` - A string that holds the unique id of the file.
1380    ///
1381    /// # Returns
1382    ///
1383    /// `Result<FileData, Box<dyn std::error::Error + Send + Sync>>`:
1384    /// A `FileData` object representing the file's content,
1385    /// or an error if the request fails.
1386    pub async fn retrieve_content<S: Into<String> + std::fmt::Display + Send + Sync>(
1387        &mut self,
1388        file_id: S,
1389    ) -> Result<Vec<PromptCompletion>, Box<dyn std::error::Error + Send + Sync>> {
1390        let res = self
1391            ._make_get_request(format!(
1392                "{}/{}/content",
1393                Self::OPENAI_API_LIST_FILES_URL,
1394                file_id
1395            ))
1396            .await?;
1397
1398        let handled_res = self.handle_api_errors(res).await?;
1399        let files: Vec<PromptCompletion> = handled_res
1400            .text()
1401            .await?
1402            .lines()
1403            .map(serde_json::from_str)
1404            .collect::<Result<Vec<PromptCompletion>, _>>()?;
1405        Ok(files)
1406    }
1407
1408    /// Upload a file to the `OpenAI` API.
1409    ///
1410    /// # Arguments
1411    ///
1412    /// * `file` - The path to the file to upload.
1413    /// * `purpose` - The purpose of the upload (e.g., 'answers', 'questions').
1414    ///
1415    /// # Returns
1416    ///
1417    /// `Result<FileData, Box<dyn std::error::Error + Send + Sync>>`:
1418    /// A `FileData` object representing the uploaded file's details,
1419    /// or an error if the request fails.
1420    pub async fn upload<P: AsRef<Path> + Send + Sync>(
1421        &mut self,
1422        file: P,
1423    ) -> Result<FileData, Box<dyn std::error::Error + Send + Sync>> {
1424        let path = file.as_ref();
1425        if fs::metadata(path)?.is_file() {
1426            let path_str = path.to_str().ok_or("Path is not valid UTF-8")?;
1427            if !std::path::Path::new(path_str)
1428                .extension()
1429                .map_or(false, |ext| ext.eq_ignore_ascii_case("jsonl"))
1430            {
1431                return Err(Box::new(std::io::Error::new(
1432                    std::io::ErrorKind::InvalidInput,
1433                    format!("File must be a .jsonl file: {}", path.display()),
1434                )));
1435            }
1436            self.config.file = Some(path_str.to_string());
1437        } else {
1438            return Err(Box::new(std::io::Error::new(
1439                std::io::ErrorKind::InvalidInput,
1440                format!("Path is not a file: {}", path.display()),
1441            )));
1442        }
1443
1444        let file_part_stream = self.create_file_upload_part(file).await?;
1445        let mut form = Form::new().part("file", file_part_stream);
1446        form = form.text("purpose", "fine-tune");
1447        let res: reqwest::Response = self
1448            ._make_form_request(Self::OPENAI_API_LIST_FILES_URL, form)
1449            .await?;
1450
1451        let handled_res = self.handle_api_errors(res).await?;
1452        let file_data: FileData = handled_res.json().await?;
1453        Ok(file_data)
1454    }
1455
1456    /// Delete a specific file.
1457    ///
1458    /// # Arguments
1459    ///
1460    /// * `file_id` - A string that holds the unique id of the file.
1461    ///
1462    /// # Returns
1463    ///
1464    /// `Result<DeleteResponse, Box<dyn std::error::Error + Send + Sync>>`:
1465    /// A `DeleteResponse` object representing the response from the delete request,
1466    /// or an error if the request fails.
1467    pub async fn delete<S: Into<String> + std::fmt::Display + Send + Sync>(
1468        &mut self,
1469        file_id: S,
1470    ) -> Result<DeleteResponse, Box<dyn std::error::Error + Send + Sync>> {
1471        let res: reqwest::Response = self
1472            ._make_delete_request(format!("{}/{}", Self::OPENAI_API_LIST_FILES_URL, file_id))
1473            .await?;
1474
1475        let handled_res = self.handle_api_errors(res).await?;
1476        let del_resp: DeleteResponse = handled_res.json().await?;
1477        Ok(del_resp)
1478    }
1479}
1480
1481// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1482// = OpenAI FINE-TUNE IMPLEMENTATION
1483// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1484
impl OpenAI<FineTune> {
    const OPENAI_API_FINE_TUNE_URL: &str = "https://api.openai.com/v1/fine-tunes";

    /// Create a fine-tune from an uploaded `training_file`.
    ///
    /// # Arguments
    ///
    /// * `training_file` - The id of an already-uploaded file containing the training data.
    ///
    /// # Returns
    ///
    /// `Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>>`:
    /// A `FineTuneResponse` object representing the result of the fine-tune request,
    /// or an error if the request fails.
    pub async fn create<S: Into<String> + Send + Sync>(
        &mut self,
        training_file: S,
    ) -> Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>> {
        self.config.training_file = training_file.into();
        let res: reqwest::Response = self
            ._make_post_request(Self::OPENAI_API_FINE_TUNE_URL)
            .await?;

        let handled_res = self.handle_api_errors(res).await?;
        let fine_tune_resp: FineTuneResponse = handled_res.json().await?;
        Ok(fine_tune_resp)
    }

    /// List all fine-tunes.
    ///
    /// # Returns
    ///
    /// `Result<FineTuneListResponse, Box<dyn std::error::Error + Send + Sync>>`:
    /// A `FineTuneListResponse` object representing the result of the list fine-tunes request,
    /// or an error if the request fails.
    pub async fn list(
        &mut self,
    ) -> Result<FineTuneListResponse, Box<dyn std::error::Error + Send + Sync>> {
        let res: reqwest::Response = self
            ._make_get_request(Self::OPENAI_API_FINE_TUNE_URL)
            .await?;

        let handled_res = self.handle_api_errors(res).await?;
        let res: FineTuneListResponse = handled_res.json().await?;
        Ok(res)
    }

    /// Get a specific fine-tune by its id
    ///
    /// # Arguments
    ///
    /// * `fine_tune_id` - A string that holds the unique id of the fine-tune job.
    ///
    /// # Returns
    ///
    /// `Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>>`:
    /// A `FineTuneResponse` object representing the result of the get fine-tune request,
    /// or an error if the request fails.
    pub async fn retrieve<S: Into<String> + Send + Sync + std::fmt::Display>(
        &mut self,
        fine_tune_id: S,
    ) -> Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>> {
        let res: reqwest::Response = self
            ._make_get_request(format!(
                "{}/{}",
                Self::OPENAI_API_FINE_TUNE_URL,
                fine_tune_id
            ))
            .await?;

        let handled_res = self.handle_api_errors(res).await?;
        let res: FineTuneResponse = handled_res.json().await?;
        Ok(res)
    }

    /// Immediately cancel a fine-tune job.
    ///
    /// # Arguments
    ///
    /// * `fine_tune_id` - A string that holds the unique id of the fine-tune job.
    ///
    /// # Returns
    ///
    /// `Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>>`:
    /// A `FineTuneResponse` object representing the result of the cancel fine-tune request,
    /// or an error if the request fails.
    pub async fn cancel<S: Into<String> + Send + Sync + std::fmt::Display>(
        &mut self,
        fine_tune_id: S,
    ) -> Result<FineTuneResponse, Box<dyn std::error::Error + Send + Sync>> {
        let url = format!("{}/{}/cancel", Self::OPENAI_API_FINE_TUNE_URL, fine_tune_id);
        // NOTE(review): this request is built by hand rather than via
        // `_make_post_request`, presumably to avoid sending the serialized
        // config as the request body — confirm against the helper's definition.
        let res = self
            .client
            .post(url)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await?;

        let handled_res = self.handle_api_errors(res).await?;
        let res: FineTuneResponse = handled_res.json().await?;
        Ok(res)
    }

    /// Get fine-grained status updates for a fine-tune job.
    ///
    /// # Arguments
    ///
    /// * `fine_tune_id` - A string that holds the unique id of the fine-tune job.
    ///
    /// # Returns
    ///
    /// `Result<FineTuneEventResponse, Box<dyn std::error::Error + Send + Sync>>`:
    /// A `FineTuneEventResponse` object representing the result of the list events request,
    /// or an error if the request fails.
    pub async fn list_events<S: Into<String> + Send + Sync + std::fmt::Display>(
        &mut self,
        fine_tune_id: S,
    ) -> Result<FineTuneEventResponse, Box<dyn std::error::Error + Send + Sync>> {
        let url = format!("{}/{}/events", Self::OPENAI_API_FINE_TUNE_URL, fine_tune_id);
        let res = self._make_get_request(url).await?;

        let handled_res = self.handle_api_errors(res).await?;
        let res: FineTuneEventResponse = handled_res.json().await?;
        Ok(res)
    }

    /// Delete a fine-tuned model. You must have the Owner role in your organization.
    ///
    /// # Arguments
    ///
    /// * `model` - The model to delete
    ///
    /// # Returns
    ///
    /// `Result<DeleteResponse, Box<dyn std::error::Error + Send + Sync>>`:
    /// A `DeleteResponse` object representing the status of the delete request,
    /// or an error if the request fails.
    pub async fn delete_model<S: Into<String> + Send + Sync + std::fmt::Display>(
        &mut self,
        model: S,
    ) -> Result<DeleteResponse, Box<dyn std::error::Error + Send + Sync>> {
        // Model deletion goes through the models endpoint (constant defined on
        // another impl of this type), not the fine-tunes endpoint.
        let url = format!("{}/{}", Self::OPENAI_API_MODELS_URL, model);
        let res = self._make_delete_request(url).await?;

        let handled_res = self.handle_api_errors(res).await?;
        let res: DeleteResponse = handled_res.json().await?;
        Ok(res)
    }
}
1635
1636// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1637// = OpenAI MODERATIONS IMPLEMENTATION
1638// =-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
1639
1640impl OpenAI<Moderation> {
1641    const OPENAI_API_MODERATIONS_URL: &str = "https://api.openai.com/v1/moderations";
1642
1643    /// Create moderation for a classification if text violates `OpenAI`'s Content Policy
1644    ///
1645    /// # Arguments
1646    ///
1647    /// * `input` - The text input to classify
1648    ///
1649    /// # Returns
1650    ///
1651    /// `Result<, Box<dyn std::error::Error + Send + Sync>>`:
1652    /// A `ModerationResponse` object representing the result of the moderation request,
1653    /// or an error if the request fails.
1654    pub async fn moderate<S: Into<String> + Send + Sync>(
1655        &mut self,
1656        input: S,
1657    ) -> Result<ModerationResponse, Box<dyn std::error::Error + Send + Sync>> {
1658        self.config.input = input.into();
1659        let res: reqwest::Response = self
1660            ._make_post_request(Self::OPENAI_API_MODERATIONS_URL)
1661            .await?;
1662
1663        let handled_res = self.handle_api_errors(res).await?;
1664        let mod_resp: ModerationResponse = handled_res.json().await?;
1665        Ok(mod_resp)
1666    }
1667}
1668
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every test below is a live integration test — it sends
    // real requests to the OpenAI API and therefore needs a valid API key
    // and network access to pass. Several also read sample assets from the
    // repository (`./img`, `examples/samples`).

    #[tokio::test]
    async fn test_get_all_models() {
        // The model listing should succeed and contain a well-known model id.
        let mut client = OpenAI::<Chat>::new();
        let models = client.models().await;
        assert!(models.is_ok());
        assert!(models.unwrap().contains(&"gpt-3.5-turbo".to_string()));
    }

    #[tokio::test]
    async fn test_check_model() {
        // A valid model id is accepted.
        let mut client = OpenAI::<Chat>::new();
        let model = client.check_model("gpt-3.5-turbo").await;
        assert!(model.is_ok());
    }

    #[tokio::test]
    async fn test_check_model_error() {
        // An unknown model id must produce an error.
        let mut client = OpenAI::<Chat>::new();
        let model = client.check_model("gpt-turbo").await;
        assert!(model.is_err());
    }

    #[tokio::test]
    async fn test_single_request() {
        // Non-streamed completion: the whole reply arrives in one response.
        let mut client = OpenAI::<Chat>::new().set_stream_responses(false);
        let reply = client.ask("Say this is a test!", false).await;
        assert!(reply.is_ok());
        assert!(reply.unwrap().contains("This is a test"));
    }

    #[tokio::test]
    async fn test_single_request_streamed() {
        // Same prompt via the default (streamed) client configuration.
        let mut client = OpenAI::<Chat>::new();
        let reply = client.ask("Say this is a test!", false).await;
        assert!(reply.is_ok());
        assert!(reply.unwrap().contains("This is a test"));
    }

    #[tokio::test]
    async fn test_create_single_image_url() {
        // Default image client returns exactly one image.
        let mut client = OpenAI::<Image>::new();
        let images = client.create("A beautiful sunset over the sea.").await;
        assert!(images.is_ok());
        assert_eq!(images.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_create_multiple_image_urls() {
        // Requesting two images yields exactly two results.
        let mut client = OpenAI::<Image>::new().set_max_images(2);
        let images = client
            .create("A logo for a library written in Rust that deals with AI")
            .await;
        assert!(images.is_ok());
        assert_eq!(images.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_create_image_b64_json() {
        // Base64-JSON response format still yields one image entry.
        let mut client = OpenAI::<Image>::new().set_response_format(&ResponseDataType::Base64Json);
        let images = client.create("A beautiful sunset over the sea.").await;
        assert!(images.is_ok());
        assert_eq!(images.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_image_variation() {
        // Requires the sample image at ./img/logo.png to exist locally.
        let mut client = OpenAI::<Image>::new();
        let images = client.variation("./img/logo.png").await;
        assert!(images.is_ok());
        assert_eq!(images.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_image_edit() {
        // Edit with no mask (None): the prompt is applied to the whole image.
        let mut client = OpenAI::<Image>::new();
        let images = client
            .edit("Make the background transparent", "./img/logo.png", None)
            .await;
        assert!(images.is_ok());
        assert_eq!(images.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_embedding() {
        // An embedding response should carry at least one data entry.
        let mut client = OpenAI::<Embedding>::new();
        let embedding = client
            .embed("The food was delicious and the waiter...")
            .await;
        assert!(embedding.is_ok());
        assert!(!embedding.unwrap().data.is_empty());
    }

    #[tokio::test]
    async fn test_transcribe() {
        // Requires examples/samples/sample-1.mp3 to exist locally.
        let mut client = OpenAI::<Audio>::new();
        let transcribe = client.transcribe("examples/samples/sample-1.mp3").await;
        assert!(transcribe.is_ok());
    }

    #[tokio::test]
    async fn test_translate() {
        // Requires examples/samples/colours-german.mp3 to exist locally.
        let mut client = OpenAI::<Audio>::new();
        let translate = client
            .translate("examples/samples/colours-german.mp3")
            .await;
        assert!(translate.is_ok());
    }

    #[tokio::test]
    async fn test_list_files() {
        let files = OpenAI::<Files>::new().list().await;
        assert!(files.is_ok());
    }

    #[tokio::test]
    async fn test_delete_non_existing_file() {
        // Deleting an unknown id should fail with the API's exact message.
        let files = OpenAI::<Files>::new().delete("invalid_file_id").await;
        assert!(files.is_err());
        assert_eq!(
            files.unwrap_err().to_string(),
            "No such File object: invalid_file_id"
        );
    }

    #[tokio::test]
    async fn test_upload_non_existing_file() {
        // A missing local path fails before any network call, with the OS error.
        let files = OpenAI::<Files>::new().upload("invalid_file").await;
        assert!(files.is_err());
        assert_eq!(
            files.unwrap_err().to_string(),
            "No such file or directory (os error 2)"
        );
    }

    #[tokio::test]
    async fn test_file_ops() {
        // End-to-end lifecycle: upload -> list -> retrieve -> read -> delete.
        // NOTE(review): the final "no files" assertion presumes the account
        // starts with no other uploaded files — verify before relying on it.
        let test_file = "examples/samples/test.jsonl";
        let mut client = OpenAI::<Files>::new();

        // Upload file
        let fup = client.upload(test_file).await;
        assert!(fup.is_ok());
        let file_id = fup.unwrap().id;
        println!("{}", file_id);

        // Check if file exists online
        let files = client.list().await;
        assert!(files.is_ok());
        assert!(!files.unwrap().data.is_empty());

        // Fetch file metadata
        let file = client.retrieve(&file_id).await;
        assert!(file.is_ok());
        assert_eq!(file.unwrap().id, file_id);

        // Fetch file contents
        let contents = client.retrieve_content(&file_id).await;
        assert!(contents.is_ok());
        assert_eq!(contents.unwrap().len(), 3);

        // Delete file
        // Wait for file to be uploaded for 5 seconds
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let fdel = client.delete(&file_id).await;
        assert!(fdel.is_ok());
        assert_eq!(fdel.unwrap().id, file_id);

        // Verify no files exist anymore
        let files = client.list().await;
        assert!(files.is_ok());
        assert!(files.unwrap().data.is_empty());
    }

    #[tokio::test]
    async fn test_moderation() {
        // Violent text should be flagged under the `violence` category.
        let moderation = OpenAI::<Moderation>::new()
            .moderate("I want to kill them.")
            .await;
        assert!(moderation.is_ok());
        assert!(moderation.unwrap().results[0].categories.violence);
    }

    #[tokio::test]
    async fn test_list_fine_tunes() {
        let tunes = OpenAI::<FineTune>::new().list().await;
        assert!(tunes.is_ok());
    }
}