// ru_openai/api.rs

use std::fmt::{Display, Formatter};
use std::fs;
use std::pin::Pin;
use crate::{OpenAIApiError, ReturnErrorType, ErrorInfo, stream};
use crate::configuration::Configuration;
use serde_derive::{Deserialize, Serialize};
use reqwest::{Method, multipart::Part};
use tracing::*;
use futures::Stream;
use reqwest_eventsource::RequestBuilderExt;


#[derive(Deserialize, Serialize, Debug)]
pub struct Permission {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub allow_create_engine: bool,
    pub allow_sampling: bool,
    pub allow_logprobs: bool,
    pub allow_search_indices: bool,
    pub allow_view: bool,
    pub allow_fine_tuning: bool,
    pub organization: String,
    pub group: Option<String>,
    pub is_blocking: bool,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub owned_by: String,
    pub permission: Vec<Permission>,
    pub root: String,
    pub parent: Option<String>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ListModelsResponse {
    pub data: Vec<ModelInfo>,
    pub object: String,
}

pub type RetrieveModelResponse = ModelInfo;

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateCompletionRequest {
    /// ID of the model to use.
    /// You can use the List models API to see all of your available models, or see our Model overview for descriptions of them.
    pub model: String,
    /// The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.
    /// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<Vec<String>>,
    /// The suffix that comes after a completion of inserted text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// The maximum number of tokens to generate in the completion.
    /// The token count of your prompt plus max_tokens cannot exceed the model's context length.
    /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    /// We generally recommend altering this or top_p but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    /// We generally recommend altering this or temperature but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// How many completions to generate for each prompt.
    /// Note: Because this parameter generates many completions, it can quickly consume your token quota.
    /// Use carefully and ensure that you have reasonable settings for max_tokens and stop.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    /// Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Include the log probabilities on the logprobs most likely tokens, as well the chosen tokens. For example, if logprobs is 5, the API will return a list of the 5 most likely tokens.
    /// The API will always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response.
    /// The maximum value for logprobs is 5. If you need more than this, please contact us through our Help center and describe your use case.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<i16>,
    /// Echo back the prompt in addition to the completion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Generates best_of completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.
    /// When used with n, best_of controls the number of candidate completions and n specifies how many to return – best_of must be greater than n.
    /// Note: Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for max_tokens and stop.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u16>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    /// Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100.
    /// You can use this tokenizer tool (which works for both GPT-2 and GPT-3) to convert text to token IDs.
    /// Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    /// As an example, you can pass {"50256": -100} to prevent the <|endoftext|> token from being generated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateCompletionResponseChoice {
    pub text: String,
    pub index: i64,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct Usage {
    pub prompt_tokens: i64,
    pub completion_tokens: i64,
    pub total_tokens: i64,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<CreateCompletionResponseChoice>,
    pub usage: Option<Usage>,
}

pub type CreateCompletionResponseStream =
    Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIApiError>> + Send>>;

#[derive(Deserialize, Serialize, Debug)]
pub struct ChatFormat {
    pub role: String,
    pub content: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ChatFormatDelta {
    pub role: Option<String>,
    pub content: Option<String>,
}

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateChatCompletionRequest {
    /// ID of the model to use. Currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported.
    pub model: String,
    /// The messages to generate chat completions for, in the chat format.
    pub messages: Vec<ChatFormat>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    /// We generally recommend altering this or top_p but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    /// We generally recommend altering this or temperature but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// How many chat completion choices to generate for each input message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    /// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
    /// Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateChatCompletionResponseChoice {
    pub message: ChatFormat,
    pub index: i64,
    pub finish_reason: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateChatCompletionResponseChoiceDelta {
    pub delta: ChatFormatDelta,
    pub index: i64,
    pub finish_reason: Option<String>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub choices: Vec<CreateChatCompletionResponseChoice>,
    pub usage: Usage,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<CreateChatCompletionResponseChoiceDelta>,
    // pub usage: Usage,
}

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateEditRequest {
    /// ID of the model to use. You can use the text-davinci-edit-001 or code-davinci-edit-001 model with this endpoint.
    pub model: String,
    /// The input text to use as a starting point for the edit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<String>,
    /// The instruction that tells the model how to edit the prompt.
    pub instruction: String,
    /// How many edits to generate for the input and instruction.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    /// We generally recommend altering this or top_p but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    /// We generally recommend altering this or temperature but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}

pub type CreateChatCompletionResponseStream =
    Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIApiError>> + Send>>;

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEditResponseChoice {
    pub index: i64,
    pub text: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEditResponse {
    pub object: String,
    pub created: i64,
    pub choices: Vec<CreateEditResponseChoice>,
    pub usage: Usage,
}

#[derive(Deserialize, Serialize, Debug)]
pub enum ImageFormat {
    #[serde(rename = "url")]
    URL,
    #[serde(rename = "b64_json")]
    B64JSON,
}

impl Display for ImageFormat {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            ImageFormat::URL => write!(f, "url"),
            ImageFormat::B64JSON => write!(f, "b64_json"),
        }
    }
}

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateImageRequest {
    /// A text description of the desired image(s). The maximum length is 1000 characters.
    pub prompt: String,
    /// The number of images to generate. Must be between 1 and 10.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    /// The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    /// The format in which the generated images are returned. Must be one of ImageFormat::URL or ImageFormat::B64JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ImageFormat>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum CreateImageResponseData {
    #[serde(rename = "url")]
    Url(String),
    #[serde(rename = "b64_json")]
    B64Json(String),
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateImageResponse {
    pub created: i64,
    pub data: Vec<CreateImageResponseData>,
}

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateImageEditRequest {
    /// The image to edit. Must be a valid PNG file, less than 4MB, and square.
    /// If mask is not provided, image must have transparency, which will be used as the mask.
    pub image: String,
    /// An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where image should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as image.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mask: Option<String>,
    /// A text description of the desired image(s). The maximum length is 1000 characters.
    pub prompt: String,
    /// The number of images to generate. Must be between 1 and 10.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    /// The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    /// The format in which the generated images are returned. Must be one of ImageFormat::URL or ImageFormat::B64JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ImageFormat>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

pub type CreateImageEditResponse = CreateImageResponse;

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateImageVariationRequest {
    /// The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square.
    pub image: String,
    /// The number of images to generate. Must be between 1 and 10.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    /// The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    /// The format in which the generated images are returned. Must be one of ImageFormat::URL or ImageFormat::B64JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ImageFormat>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

pub type CreateImageVariationResponse = CreateImageResponse;

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateEmbeddingsRequest {
    /// ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them.
    pub model: String,
    /// Input text to get embeddings for, encoded as a string or array of tokens. To get embeddings for multiple inputs in a single request, pass an array of strings or array of token arrays. Each input must not exceed 8192 tokens in length.
    pub input: Vec<String>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEmbeddingsResponseData {
    pub object: String,
    pub embedding: Vec<f32>,
    pub index: i64,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEmbeddingsResponseUsage {
    pub prompt_tokens: i64,
    pub total_tokens: i64,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEmbeddingsResponse {
    pub object: String,
    pub data: Vec<CreateEmbeddingsResponseData>,
    pub model: String,
    pub usage: CreateEmbeddingsResponseUsage,
}

#[derive(Deserialize, Serialize, Debug, Clone, Copy)]
pub enum CreateTranscriptionResponseFormat {
    #[serde(rename = "json")]
    JSON,
    #[serde(rename = "text")]
    TEXT,
    #[serde(rename = "srt")]
    SRT,
    #[serde(rename = "verbose_json")]
    VERBOSEJSON,
    #[serde(rename = "vtt")]
    VTT,
}

impl Display for CreateTranscriptionResponseFormat {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            CreateTranscriptionResponseFormat::JSON => write!(f, "json"),
            CreateTranscriptionResponseFormat::TEXT => write!(f, "text"),
            CreateTranscriptionResponseFormat::SRT => write!(f, "srt"),
            CreateTranscriptionResponseFormat::VERBOSEJSON => write!(f, "verbose_json"),
            CreateTranscriptionResponseFormat::VTT => write!(f, "vtt"),
        }
    }
}

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateTranscriptionRequest {
    /// The audio file to transcribe, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm.
    pub file: String,
    /// ID of the model to use. Only whisper-1 is currently available.
    pub model: String,
    /// An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<CreateTranscriptionResponseFormat>,
    /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
}

pub enum CreateTranscriptionResponse {
    Text(CreateTranscriptionResponseText),
    Json(CreateTranscriptionResponseJson),
    Srt(CreateTranscriptionResponseSrt),
    VerboseJson(CreateTranscriptionResponseVerboseJson),
    Vtt(CreateTranscriptionResponseVtt),
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseText {
    pub text: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseJson {
    pub text: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseSrt {
    pub text: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct TranscriptionSegment {
    pub id: i64,
    pub seek: i32,
    pub start: f32,
    pub end: f32,
    pub text: String,
    pub tokens: Vec<i64>,
    pub temperature: f32,
    pub avg_logprob: f64,
    pub compression_ratio: f64,
    pub no_speech_prob: f64,
    pub transient: bool,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseVerboseJson {
    pub task: String,
    pub language: String,
    pub duration: f32,
    pub segments: Vec<TranscriptionSegment>,
    pub text: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseVtt {
    pub text: String,
}

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateTranslationRequest {
    /// The audio file to translate, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm.
    pub file: String,
    /// ID of the model to use. Only whisper-1 is currently available.
    pub model: String,
    /// An optional text to guide the model's style or continue a previous audio segment. The prompt should be in English.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<CreateTranscriptionResponseFormat>,
    /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
}

pub type CreateTranslationResponse = CreateTranscriptionResponse;

#[derive(Deserialize, Serialize, Debug)]
pub struct FileInfo {
    pub id: String,
    pub object: String,
    pub bytes: i32,
    pub created_at: i64,
    pub filename: String,
    pub purpose: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ListFilesResponse {
    pub data: Vec<FileInfo>,
    pub object: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct UploadFileRequest {
    /// JSON Lines file to be uploaded.
    /// If the purpose is set to "fine-tune", each line is a JSON record with "prompt" and "completion" fields representing your training examples.
    pub file: String,
    /// The name of the file.
    pub filename: String,
    /// The purpose of the file. Can be "fine-tune" or "test".
    pub purpose: String,
}

pub type UploadFileResponse = FileInfo;

#[derive(Deserialize, Serialize, Debug)]
pub struct DeleteFileResponse {
    pub deleted: bool,
    pub id: String,
    pub object: String,
}

pub type RetrieveFileResponse = FileInfo;

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateFineTuneRequest {
    /// The ID of an uploaded file that contains training data.
    /// See upload file for how to upload a file.
    /// Your dataset must be formatted as a JSONL file, where each training example is a JSON object with the keys "prompt" and "completion".
    /// Additionally, you must upload your file with the purpose fine-tune.
    pub training_file: String,

    /// The ID of an uploaded file that contains validation data.
    /// If you provide this file, the data is used to generate validation metrics periodically during fine-tuning.
    /// These metrics can be viewed in the fine-tuning results file. Your train and validation data should be mutually exclusive.
    /// Your dataset must be formatted as a JSONL file, where each validation example is a JSON object with the keys "prompt" and "completion".
    /// Additionally, you must upload your file with the purpose fine-tune.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_file: Option<String>,

    /// The name of the base model to fine-tune.
    /// You can select one of "ada", "babbage", "curie", "davinci", or a fine-tuned model created after 2022-04-21.
    /// To learn more about these models, see the Models documentation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,

    /// The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n_epochs: Option<i32>,

    /// The batch size to use for training. The batch size is the number of training examples used to train a single forward and backward pass.
    /// By default, the batch size will be dynamically configured to be ~0.2% of the number of examples in the training set, capped at 256 - in general, we've found that larger batch sizes tend to work better for larger datasets.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub batch_size: Option<i32>,

    /// The learning rate multiplier to use for training. The fine-tuning learning rate is the original learning rate used for pretraining multiplied by this value.
    /// By default, the learning rate multiplier is 0.05, 0.1, or 0.2 depending on the final batch_size (larger learning rates tend to perform better with larger batch sizes).
    /// We recommend experimenting with values in the range 0.02 to 0.2 to see what produces the best results.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub learning_rate_multiplier: Option<f32>,

    /// The weight to use for loss on the prompt tokens. This controls how much the model tries to learn to generate the prompt (as compared to the completion which always has a weight of 1.0), and can add a stabilizing effect to training when completions are short.
    /// If prompts are extremely long (relative to completions), it may make sense to reduce this weight so as to avoid over-prioritizing learning the prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_loss_weight: Option<f32>,

    /// If set, we calculate classification-specific metrics such as accuracy and F-1 score using the validation set at the end of every epoch.
    /// These metrics can be viewed in the results file.
    /// In order to compute classification metrics, you must provide a validation_file.
    /// Additionally, you must specify classification_n_classes for multiclass classification or classification_positive_class for binary classification.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compute_classification_metrics: Option<bool>,

    /// The number of classes in a classification task.
    /// This parameter is required for multiclass classification.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub classification_n_classes: Option<i32>,

    /// The positive class in binary classification.
    /// This parameter is needed to generate precision, recall, and F1 metrics when doing binary classification.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub classification_positive_class: Option<String>,

    /// If this is provided, we calculate F-beta scores at the specified beta values.
    /// The F-beta score is a generalization of F-1 score. This is only used for binary classification.
    /// With a beta of 1 (i.e. the F-1 score), precision and recall are given the same weight.
    /// A larger beta score puts more weight on recall and less on precision.
    /// A smaller beta score puts more weight on precision and less on recall.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub classification_betas: Option<Vec<f32>>,

    /// A string of up to 40 characters that will be added to your fine-tuned model name.
    /// For example, a suffix of "custom-model-name" would produce a model name like `ada:ft-your-org:custom-model-name-2022-02-15-04-21-04`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct FineTuneEvent {
    pub object: String,
    pub created_at: i64,
    pub level: String,
    pub message: String,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct FineTuneHyperparams {
    pub batch_size: i32,
    pub learning_rate_multiplier: f32,
    pub prompt_loss_weight: f32,
    pub n_epochs: i32,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateFineTuneResponse {
    pub id: String,
    pub object: String,
    pub model: String,
    pub created_at: i64,
    pub events: Vec<FineTuneEvent>,
    pub fine_tuned_model: Option<String>,
    pub hyperparams: FineTuneHyperparams,
    pub organization_id: String,
    pub result_files: Vec<FileInfo>,
    pub status: String,
    pub validation_files: Vec<FileInfo>,
    pub training_files: Vec<FileInfo>,
    pub updated_at: i64,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ListFineTunesResponse {
    pub object: String,
    pub data: Vec<CreateFineTuneResponse>,
}

pub type RetrieveFineTuneResponse = CreateFineTuneResponse;

pub type CancelFineTuneResponse = CreateFineTuneResponse;

#[derive(Deserialize, Serialize, Debug)]
pub struct ListFineTuneEventsResponse {
    pub object: String,
    pub data: Vec<FineTuneEvent>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct DeleteFineTuneModelResponse {
    pub id: String,
    pub object: String,
    pub deleted: bool,
}

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateModerationRequest {
    /// The input text to classify
    pub input: Vec<String>,
    /// Two content moderation models are available: text-moderation-stable and text-moderation-latest.
    /// The default is text-moderation-latest which will be automatically upgraded over time.
    /// This ensures you are always using our most accurate model.
    /// If you use text-moderation-stable, we will provide advance notice before updating the model.
    /// Accuracy of text-moderation-stable may be slightly lower than for text-moderation-latest.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ModerationCategories {
    pub hate: bool,
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: bool,
    #[serde(rename = "self-harm")]
    pub self_harm: bool,
    pub sexual: bool,
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: bool,
    pub violence: bool,
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: bool,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct ModerationCategoryScores {
    pub hate: f64,
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: f64,
    #[serde(rename = "self-harm")]
    pub self_harm: f64,
    pub sexual: f64,
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: f64,
    pub violence: f64,
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: f64,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateModerationResult {
    pub categories: ModerationCategories,
    pub category_scores: ModerationCategoryScores,
    pub flagged: bool,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct CreateModerationResponse {
    pub id: String,
    pub model: String,
    pub results: Vec<CreateModerationResult>,
}

pub struct OpenAIApi {
    configuration: Configuration,
}

impl OpenAIApi {

    pub fn new(configuration: Configuration) -> Self {
        Self { configuration }
    }

    /// List models
    /// GET https://api.openai.com/v1/models
    /// Lists the currently available models, and provides basic information about each one such as the owner and availability.
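    ///
    /// A minimal usage sketch (doc-test `ignore`d; it assumes this file is exposed as the
    /// crate's `api` module and that a `Configuration` value has been built elsewhere):
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let models = api.list_models().await?;
    /// for model in models.data {
    ///     println!("{} (owned by {})", model.id, model.owned_by);
    /// }
    /// ```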
    pub async fn list_models(self) -> Result<ListModelsResponse, OpenAIApiError> {

        let client_builder = reqwest::Client::builder();
        let request_builder = self
            .configuration
            .apply_to_request(
                client_builder,
                "/models".to_string(),
                Method::GET,
            );
        let response = request_builder
            .send()
            .await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response
                .json::<ListModelsResponse>()
                .await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await.map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Retrieve model
    /// GET https://api.openai.com/v1/models/{model}
    /// Retrieves a model instance, providing basic information about the model such as the owner and permissioning.
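    ///
    /// A minimal usage sketch (doc-test `ignore`d; the model ID is illustrative):
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let model = api.retrieve_model("text-davinci-003".to_string()).await?;
    /// println!("{} is owned by {}", model.id, model.owned_by);
    /// ```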
    pub async fn retrieve_model(self, model: String) -> Result<RetrieveModelResponse, OpenAIApiError> {

        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            format!("/models/{}", model),
            Method::GET,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<RetrieveModelResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Create completion
    /// POST https://api.openai.com/v1/completions
    /// Creates a completion for the provided prompt and parameters
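    ///
    /// A minimal usage sketch (doc-test `ignore`d; the model name is illustrative). Optional
    /// fields left as `None` are omitted from the request body via `skip_serializing_if`:
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateCompletionRequest {
    ///     model: "text-davinci-003".to_string(),
    ///     prompt: Some(vec!["Say this is a test".to_string()]),
    ///     max_tokens: Some(16),
    ///     ..Default::default()
    /// };
    /// let response = api.create_completion(request).await?;
    /// println!("{}", response.choices[0].text);
    /// ```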
    pub async fn create_completion(self, mut request: CreateCompletionRequest) -> Result<CreateCompletionResponse, OpenAIApiError> {

        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/completions".to_string(),
            Method::POST,
        );
        request.stream = None;
        let response = request_builder.json(&request).send().await
            .map_err(OpenAIApiError::from)?;
        info!("response: {:#?}", response);
        if response.status().is_success() {
            response.json::<CreateCompletionResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Creates a completion request for the provided prompt and parameters
    ///
    /// Stream back partial progress. Tokens will be sent as data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
    /// as they become available, with the stream terminated by a data: \[DONE\] message.
    ///
    /// [CreateCompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
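    ///
    /// A minimal consumption sketch (doc-test `ignore`d; `StreamExt` comes from the `futures`
    /// crate and must be imported by the caller):
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateCompletionRequest {
    ///     model: "text-davinci-003".to_string(),
    ///     prompt: Some(vec!["Say this is a test".to_string()]),
    ///     ..Default::default()
    /// };
    /// let mut response_stream = api.create_completion_stream(request).await?;
    /// while let Some(chunk) = response_stream.next().await {
    ///     for choice in chunk?.choices {
    ///         print!("{}", choice.text);
    ///     }
    /// }
    /// ```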
    pub async fn create_completion_stream(self, mut request: CreateCompletionRequest) -> Result<CreateCompletionResponseStream, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/completions".to_string(),
            Method::POST,
        );
        request.stream = Some(true);
        let event_source = request_builder.json(&request).eventsource().unwrap();
        Ok(stream(event_source).await)
    }

    ///
    /// Create chat completion
    /// POST https://api.openai.com/v1/chat/completions
    /// Creates a completion for the chat message
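    ///
    /// A minimal usage sketch (doc-test `ignore`d):
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateChatCompletionRequest {
    ///     model: "gpt-3.5-turbo".to_string(),
    ///     messages: vec![ChatFormat {
    ///         role: "user".to_string(),
    ///         content: "Hello!".to_string(),
    ///     }],
    ///     ..Default::default()
    /// };
    /// let response = api.create_chat_completion(request).await?;
    /// println!("{}", response.choices[0].message.content);
    /// ```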
    pub async fn create_chat_completion(self, mut request: CreateChatCompletionRequest) -> Result<CreateChatCompletionResponse, OpenAIApiError> {

        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/chat/completions".to_string(),
            Method::POST,
        );
        request.stream = None;
        let response = request_builder.json(&request).send().await
            .map_err(OpenAIApiError::from)?;
        info!("response: {:#?}", response);
        if response.status().is_success() {
            response.json::<CreateChatCompletionResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Creates a chat completion request for the provided prompt and parameters
    ///
    /// Stream back partial progress. Tokens will be sent as data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
    /// as they become available, with the stream terminated by a data: \[DONE\] message.
    ///
    /// [CreateChatCompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
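    ///
    /// A minimal consumption sketch (doc-test `ignore`d; `StreamExt` comes from the `futures`
    /// crate). Each chunk carries a `ChatFormatDelta` whose fields may each be `None`:
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateChatCompletionRequest {
    ///     model: "gpt-3.5-turbo".to_string(),
    ///     messages: vec![ChatFormat {
    ///         role: "user".to_string(),
    ///         content: "Hello!".to_string(),
    ///     }],
    ///     ..Default::default()
    /// };
    /// let mut response_stream = api.create_chat_completion_stream(request).await?;
    /// while let Some(chunk) = response_stream.next().await {
    ///     for choice in chunk?.choices {
    ///         if let Some(content) = choice.delta.content {
    ///             print!("{}", content);
    ///         }
    ///     }
    /// }
    /// ```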
    pub async fn create_chat_completion_stream(self, mut request: CreateChatCompletionRequest) -> Result<CreateChatCompletionResponseStream, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/chat/completions".to_string(),
            Method::POST,
        );
        request.stream = Some(true);
        let event_source = request_builder.json(&request).eventsource().unwrap();
        Ok(stream(event_source).await)
    }

    /// Create edit
    /// POST https://api.openai.com/v1/edits
    /// Creates a new edit for the provided input, instruction, and parameters.
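    ///
    /// A minimal usage sketch (doc-test `ignore`d):
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateEditRequest {
    ///     model: "text-davinci-edit-001".to_string(),
    ///     input: Some("What day of the wek is it?".to_string()),
    ///     instruction: "Fix the spelling mistakes".to_string(),
    ///     ..Default::default()
    /// };
    /// let response = api.create_edit(request).await?;
    /// println!("{}", response.choices[0].text);
    /// ```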
    pub async fn create_edit(self, request: CreateEditRequest) -> Result<CreateEditResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/edits".to_string(),
            Method::POST,
        );
        let response = request_builder.json(&request).send().await
            .map_err(OpenAIApiError::from)?;
        info!("response: {:#?}", response);
        if response.status().is_success() {
            response.json::<CreateEditResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    ///
    /// Create image
    /// POST https://api.openai.com/v1/images/generations
    /// Creates an image given a prompt.
    ///
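    /// A minimal usage sketch (doc-test `ignore`d):
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateImageRequest {
    ///     prompt: "A cute baby sea otter".to_string(),
    ///     n: Some(1),
    ///     size: Some("512x512".to_string()),
    ///     response_format: Some(ImageFormat::URL),
    ///     ..Default::default()
    /// };
    /// for image in api.create_image(request).await?.data {
    ///     if let CreateImageResponseData::Url(url) = image {
    ///         println!("{}", url);
    ///     }
    /// }
    /// ```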
    pub async fn create_image(self, request: CreateImageRequest) -> Result<CreateImageResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/images/generations".to_string(),
            Method::POST,
        );
        let response = request_builder.json(&request).send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<CreateImageResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Create image edit
    /// POST https://api.openai.com/v1/images/edits
    /// Creates an edited or extended image given an original image and a prompt.
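    ///
    /// A minimal usage sketch (doc-test `ignore`d; the file names are illustrative). `image`
    /// and `mask` are local file paths, read from disk and uploaded as multipart form parts:
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateImageEditRequest {
    ///     image: "otter.png".to_string(),
    ///     mask: Some("mask.png".to_string()),
    ///     prompt: "A cute baby sea otter wearing a beret".to_string(),
    ///     ..Default::default()
    /// };
    /// let response = api.create_image_edit(request).await?;
    /// ```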
    pub async fn create_image_edit(self, request: CreateImageEditRequest) -> Result<CreateImageEditResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/images/edits".to_string(),
            Method::POST,
        );
        let image_file = fs::read(request.image).unwrap();
        let image_file_part = Part::bytes(image_file)
            .file_name("image.png")
            .mime_str("image/png")
            .unwrap();
        let mut form = reqwest::multipart::Form::new()
            .part("image", image_file_part);
        form = match request.mask {
            Some(mask) => {
                let mask_file = fs::read(mask).unwrap();
                let mask_file_part = Part::bytes(mask_file)
                    .file_name("mask.png")
                    .mime_str("image/png")
                    .unwrap();
                form.part("mask", mask_file_part)
            },
            None => form,
        };
        form = form.text("prompt", request.prompt.clone());
        form = match request.n {
            Some(n) => form.text("n", n.to_string()),
            None => form,
        };
        form = match request.size {
            Some(size) => form.text("size", size),
            None => form,
        };
        form = match request.response_format {
            Some(response_format) => form.text("response_format", response_format.to_string()),
            None => form,
        };
        form = match request.user {
            Some(user) => form.text("user", user),
            None => form,
        };
        let response = request_builder.multipart(form).send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<CreateImageEditResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Create image variation
    /// POST https://api.openai.com/v1/images/variations
    /// Creates a variation of a given image.
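    ///
    /// A minimal usage sketch (doc-test `ignore`d; `image` is a local file path):
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateImageVariationRequest {
    ///     image: "otter.png".to_string(),
    ///     n: Some(2),
    ///     ..Default::default()
    /// };
    /// let response = api.create_image_variation(request).await?;
    /// ```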
    pub async fn create_image_variation(self, request: CreateImageVariationRequest) -> Result<CreateImageVariationResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/images/variations".to_string(),
            Method::POST,
        );
        let image_file = fs::read(request.image).unwrap();
        let image_file_part = Part::bytes(image_file)
            .file_name("image.png")
            .mime_str("image/png")
            .unwrap();
        let mut form = reqwest::multipart::Form::new().part("image", image_file_part);

        form = match request.n {
            Some(n) => form.text("n", n.to_string()),
            None => form,
        };
        form = match request.size {
            Some(size) => form.text("size", size),
            None => form,
        };
        form = match request.response_format {
            Some(response_format) => form.text("response_format", response_format.to_string()),
            None => form,
        };
        form = match request.user {
            Some(user) => form.text("user", user),
            None => form,
        };
        let response = request_builder.multipart(form).send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<CreateImageVariationResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Create embeddings
    /// POST https://api.openai.com/v1/embeddings
    /// Creates an embedding vector representing the input text.
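    ///
    /// A minimal usage sketch (doc-test `ignore`d; the model name is illustrative):
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateEmbeddingsRequest {
    ///     model: "text-embedding-ada-002".to_string(),
    ///     input: vec!["The food was delicious".to_string()],
    ///     user: None,
    /// };
    /// let response = api.create_embeddings(request).await?;
    /// println!("dimensions: {}", response.data[0].embedding.len());
    /// ```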
    pub async fn create_embeddings(self, request: CreateEmbeddingsRequest) -> Result<CreateEmbeddingsResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/embeddings".to_string(),
            Method::POST,
        );
        let response = request_builder.json(&request).send().await
            .map_err(OpenAIApiError::from)?;
        info!("response: {:#?}", response);
        if response.status().is_success() {
            response.json::<CreateEmbeddingsResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Create transcription
    /// POST https://api.openai.com/v1/audio/transcriptions
    /// Transcribes audio into the input language.
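    ///
    /// A minimal usage sketch (doc-test `ignore`d; the file path is illustrative). The file
    /// extension selects the MIME type, and the response variant follows `response_format`
    /// (JSON when unset):
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateTranscriptionRequest {
    ///     file: "speech.mp3".to_string(),
    ///     model: "whisper-1".to_string(),
    ///     ..Default::default()
    /// };
    /// if let CreateTranscriptionResponse::Json(json) = api.create_transcription(request).await? {
    ///     println!("{}", json.text);
    /// }
    /// ```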
    pub async fn create_transcription(self, request: CreateTranscriptionRequest) -> Result<CreateTranscriptionResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/audio/transcriptions".to_string(),
            Method::POST,
        );
        let parts: Vec<&str> = request.file.split('.').collect();
        let suffix = parts[parts.len() - 1];
        let mime_type = Self::get_mime_type_from_suffix(suffix.to_string())?;
        let audio_file = fs::read(request.file.clone()).unwrap();
        let audio_file_part = Part::bytes(audio_file)
            .file_name(format!("audio.{}", suffix))
            .mime_str(mime_type.as_str())
            .unwrap();
        let mut form = reqwest::multipart::Form::new().part("file", audio_file_part)
            .text("model", request.model);
        form = match request.prompt {
            Some(prompt) => form.text("prompt", prompt),
            None => form,
        };
        form = match request.response_format {
            Some(response_format) => form.text("response_format", response_format.to_string()),
            None => form,
        };
        form = match request.temperature {
            Some(temperature) => form.text("temperature", temperature.to_string()),
            None => form,
        };
        form = match request.language {
            Some(language) => form.text("language", language),
            None => form,
        };
        info!("request form: {:#?}", form);
        let response = request_builder.multipart(form).send().await
            .map_err(OpenAIApiError::from)?;
        info!("response: {:#?}", response);
        let rf = request.response_format.clone();
        if response.status().is_success() {
            match rf {
                Some(response_format) => match response_format {
                    CreateTranscriptionResponseFormat::TEXT => {
                        let text = response.text().await
                            .map_err(OpenAIApiError::from)?;
                        let response = CreateTranscriptionResponseText {
                            text,
                        };
                        Ok(CreateTranscriptionResponse::Text(response))
                    },
                    CreateTranscriptionResponseFormat::JSON => {
                        let response = response.json::<CreateTranscriptionResponseJson>().await
                            .map_err(OpenAIApiError::from)?;
                        Ok(CreateTranscriptionResponse::Json(response))
                    },
                    CreateTranscriptionResponseFormat::SRT => {
                        let text = response.text().await
                            .map_err(OpenAIApiError::from)?;
                        let response = CreateTranscriptionResponseSrt {
                            text,
                        };
                        Ok(CreateTranscriptionResponse::Srt(response))
                    },
                    CreateTranscriptionResponseFormat::VTT => {
                        let text = response.text().await
                            .map_err(OpenAIApiError::from)?;
                        let response = CreateTranscriptionResponseVtt {
                            text,
                        };
                        Ok(CreateTranscriptionResponse::Vtt(response))
                    },
                    CreateTranscriptionResponseFormat::VERBOSEJSON => {
                        let response = response.json::<CreateTranscriptionResponseVerboseJson>().await
                            .map_err(OpenAIApiError::from)?;
                        Ok(CreateTranscriptionResponse::VerboseJson(response))
                    },
                },
                None => {
                    let response = response.json::<CreateTranscriptionResponseJson>().await
                        .map_err(OpenAIApiError::from)?;
                    Ok(CreateTranscriptionResponse::Json(response))
                },
            }
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Create translation
    /// POST https://api.openai.com/v1/audio/translations
    /// Translates audio into English.
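    ///
    /// A minimal usage sketch (doc-test `ignore`d; mirrors `create_transcription`, but the
    /// output text is English regardless of the audio language):
    /// ```ignore
    /// let api = OpenAIApi::new(configuration);
    /// let request = CreateTranslationRequest {
    ///     file: "speech_de.mp3".to_string(),
    ///     model: "whisper-1".to_string(),
    ///     ..Default::default()
    /// };
    /// if let CreateTranslationResponse::Json(json) = api.create_translation(request).await? {
    ///     println!("{}", json.text);
    /// }
    /// ```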
    pub async fn create_translation(self, request: CreateTranslationRequest) -> Result<CreateTranslationResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/audio/translations".to_string(),
            Method::POST,
        );
        let parts: Vec<&str> = request.file.split('.').collect();
        let suffix = parts[parts.len() - 1];
1184        let mime_type = Self::get_mime_type_from_suffix(suffix.to_string()).unwrap();
1185        let audio_file = fs::read(request.file.clone()).unwrap();
1186        let audio_file_part = Part::bytes(audio_file)
1187            .file_name(format!("audio.{}", suffix))
1188            .mime_str(mime_type.as_str())
1189            .unwrap();
1190        let mut form = reqwest::multipart::Form::new().part("file", audio_file_part)
1191            .text("model", request.model);
1192        form = match request.prompt {
1193            Some(prompt) => form.text("prompt", prompt),
1194            None => form,
1195        };
1196        form = match request.response_format {
1197            Some(response_format) => form.text("response_format", response_format.to_string()),
1198            None => form,
1199        };
1200        form = match request.temperature {
1201            Some(temperature) => form.text("temperature", temperature.to_string()),
1202            None => form,
1203        };
1204        let response = request_builder.multipart(form).send().await
1205            .map_err(|err| OpenAIApiError::from(err))?;
1206        info!("response: {:#?}", response);
1207        let rf = request.response_format.clone();
1208        if response.status().is_success() {
1209            match rf {
1210                Some(response_format) => match response_format {
1211                    CreateTranscriptionResponseFormat::TEXT => {
1212                        let text = response.text().await
1213                            .map_err(|err| OpenAIApiError::from(err))?;
1214                        let response = CreateTranscriptionResponseText {
1215                            text,
1216                        };
1217                        Ok(CreateTranslationResponse::Text(response))
1218                    },
1219                    CreateTranscriptionResponseFormat::JSON => {
1220                        let response = response.json::<CreateTranscriptionResponseJson>().await
1221                            .map_err(|err| OpenAIApiError::from(err))?;
1222                        Ok(CreateTranslationResponse::Json(response))
1223                    },
1224                    CreateTranscriptionResponseFormat::SRT => {
1225                        let text = response.text().await
1226                            .map_err(|err| OpenAIApiError::from(err))?;
1227                        let response = CreateTranscriptionResponseSrt {
1228                            text,
1229                        };
1230                        Ok(CreateTranslationResponse::Srt(response))
1231                    },
1232                    CreateTranscriptionResponseFormat::VTT => {
1233                        let text = response.text().await
1234                            .map_err(|err| OpenAIApiError::from(err))?;
1235                        let response = CreateTranscriptionResponseVtt {
1236                            text,
1237                        };
1238                        Ok(CreateTranslationResponse::Vtt(response))
1239                    },
1240                    CreateTranscriptionResponseFormat::VERBOSEJSON => {
1241                        let response = response.json::<CreateTranscriptionResponseVerboseJson>().await
1242                            .map_err(|err| OpenAIApiError::from(err))?;
1243                        Ok(CreateTranslationResponse::VerboseJson(response))
1244                    },
1245                },
1246                None => {
1247                    let response = response.json::<CreateTranscriptionResponseJson>().await
1248                        .map_err(|err| OpenAIApiError::from(err))?;
1249                    Ok(CreateTranslationResponse::Json(response))
1250                },
1251            }
1252        } else {
1253            let status = response.status().as_u16() as i32;
1254            let ret_err = response.json::<ReturnErrorType>().await
1255                .map_err(|err| OpenAIApiError::from(err))?;
1256            Err(OpenAIApiError::new(status, ret_err.error))
1257        }
1258        
1259        
1260    }
1261
    /// List files
    /// GET https://api.openai.com/v1/files
    /// Returns a list of files that belong to the user's organization.
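    ///
    /// A minimal usage sketch (hypothetical key; module paths assumed):
    /// ```no_run
    /// # use ru_openai::{api::OpenAIApi, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let files = api.list_files().await?;
    /// assert_eq!(files.object, "list");
    /// # Ok(())
    /// # }
    /// ```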
    pub async fn list_files(self) -> Result<ListFilesResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/files".to_string(),
            Method::GET,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<ListFilesResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Upload file
    /// POST https://api.openai.com/v1/files
    /// Upload a file that contains document(s) to be used across various endpoints/features. Currently, the size of all the files uploaded by one organization can be up to 1 GB. Please contact us if you need to increase the storage limit.
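    ///
    /// A minimal usage sketch (hypothetical key and path; assumes `UploadFileRequest`
    /// exposes exactly the `file` and `purpose` fields used below):
    /// ```no_run
    /// # use ru_openai::{api::*, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let request = UploadFileRequest {
    ///     file: "./misc/training_data.jsonl".to_string(),
    ///     purpose: "fine-tune".to_string(),
    /// };
    /// let uploaded = api.upload_file(request).await?;
    /// # Ok(())
    /// # }
    /// ```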
    pub async fn upload_file(self, request: UploadFileRequest) -> Result<UploadFileResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/files".to_string(),
            Method::POST,
        );
        let file = fs::read(request.file.clone())
            .expect("failed to read upload file");
        let file_part = Part::bytes(file)
            .file_name(request.file.clone())
            .mime_str(mime::APPLICATION_JSON.to_string().as_str())
            .map_err(OpenAIApiError::from)?;
        let form = reqwest::multipart::Form::new()
            .part("file", file_part)
            .text("purpose", request.purpose);
        let response = request_builder.multipart(form).send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<UploadFileResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Delete file
    /// DELETE https://api.openai.com/v1/files/{file_id}
    /// Delete a file.
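    ///
    /// A minimal usage sketch (hypothetical key and file id):
    /// ```no_run
    /// # use ru_openai::{api::OpenAIApi, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let deleted = api.delete_file("file-abc123".to_string()).await?;
    /// # Ok(())
    /// # }
    /// ```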
    pub async fn delete_file(self, file_id: String) -> Result<DeleteFileResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            format!("/files/{}", file_id),
            Method::DELETE,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<DeleteFileResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Retrieve file
    /// GET https://api.openai.com/v1/files/{file_id}
    /// Returns information about a specific file.
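    ///
    /// A minimal usage sketch (hypothetical key and file id):
    /// ```no_run
    /// # use ru_openai::{api::OpenAIApi, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let info = api.retrieve_file("file-abc123".to_string()).await?;
    /// # Ok(())
    /// # }
    /// ```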
    pub async fn retrieve_file(self, file_id: String) -> Result<RetrieveFileResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            format!("/files/{}", file_id),
            Method::GET,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<RetrieveFileResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Retrieve file content
    /// GET https://api.openai.com/v1/files/{file_id}/content
    /// Returns the contents of the specified file.
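    ///
    /// A minimal usage sketch (hypothetical key and file id):
    /// ```no_run
    /// # use ru_openai::{api::OpenAIApi, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let content = api.retrieve_file_content("file-abc123".to_string()).await?;
    /// println!("{}", content);
    /// # Ok(())
    /// # }
    /// ```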
    pub async fn retrieve_file_content(self, file_id: String) -> Result<String, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            format!("/files/{}/content", file_id),
            Method::GET,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.text().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Create fine-tune
    /// POST https://api.openai.com/v1/fine-tunes
    /// Creates a job that fine-tunes a specified model from a given dataset.
    /// Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete.
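    ///
    /// A minimal usage sketch (hypothetical key and file id; the `training_file`
    /// field name follows the OpenAI API and is assumed to exist on `CreateFineTuneRequest`):
    /// ```no_run
    /// # use ru_openai::{api::*, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let request = CreateFineTuneRequest {
    ///     training_file: "file-abc123".to_string(),
    ///     ..Default::default()
    /// };
    /// let job = api.create_fine_tune(request).await?;
    /// # Ok(())
    /// # }
    /// ```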
    pub async fn create_fine_tune(self, request: CreateFineTuneRequest) -> Result<CreateFineTuneResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/fine-tunes".to_string(),
            Method::POST,
        );
        let response = request_builder.json(&request).send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<CreateFineTuneResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// List fine-tunes
    /// GET https://api.openai.com/v1/fine-tunes
    /// List your organization's fine-tuning jobs.
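    ///
    /// A minimal usage sketch (hypothetical key):
    /// ```no_run
    /// # use ru_openai::{api::OpenAIApi, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let jobs = api.list_fine_tunes().await?;
    /// assert_eq!(jobs.object, "list");
    /// # Ok(())
    /// # }
    /// ```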
    pub async fn list_fine_tunes(self) -> Result<ListFineTunesResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/fine-tunes".to_string(),
            Method::GET,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<ListFineTunesResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Retrieve fine-tune
    /// GET https://api.openai.com/v1/fine-tunes/{fine_tune_id}
    /// Gets info about the fine-tune job.
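    ///
    /// A minimal usage sketch (hypothetical key and job id):
    /// ```no_run
    /// # use ru_openai::{api::OpenAIApi, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let job = api.retrieve_fine_tune("ft-abc123".to_string()).await?;
    /// # Ok(())
    /// # }
    /// ```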
    pub async fn retrieve_fine_tune(self, fine_tune_id: String) -> Result<RetrieveFineTuneResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            format!("/fine-tunes/{}", fine_tune_id),
            Method::GET,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<RetrieveFineTuneResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Cancel fine-tune
    /// POST https://api.openai.com/v1/fine-tunes/{fine_tune_id}/cancel
    /// Immediately cancel a fine-tune job.
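    ///
    /// A minimal usage sketch (hypothetical key and job id):
    /// ```no_run
    /// # use ru_openai::{api::OpenAIApi, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let cancelled = api.cancel_fine_tune("ft-abc123".to_string()).await?;
    /// # Ok(())
    /// # }
    /// ```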
    pub async fn cancel_fine_tune(self, fine_tune_id: String) -> Result<CancelFineTuneResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            format!("/fine-tunes/{}/cancel", fine_tune_id),
            Method::POST,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<CancelFineTuneResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// List fine-tune events
    /// GET https://api.openai.com/v1/fine-tunes/{fine_tune_id}/events
    /// Get fine-grained status updates for a fine-tune job.
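    ///
    /// A minimal usage sketch (hypothetical key and job id):
    /// ```no_run
    /// # use ru_openai::{api::OpenAIApi, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let events = api.list_fine_tune_events("ft-abc123".to_string()).await?;
    /// # Ok(())
    /// # }
    /// ```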
    pub async fn list_fine_tune_events(self, fine_tune_id: String) -> Result<ListFineTuneEventsResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            format!("/fine-tunes/{}/events", fine_tune_id),
            Method::GET,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<ListFineTuneEventsResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Delete fine-tune model
    /// DELETE https://api.openai.com/v1/models/{model}
    /// Delete a fine-tuned model. You must have the Owner role in your organization.
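    ///
    /// A minimal usage sketch (hypothetical key; the model name is illustrative):
    /// ```no_run
    /// # use ru_openai::{api::OpenAIApi, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let deleted = api.delete_fine_tune_model("curie:ft-acmeco-2021-03-03-21-44-20".to_string()).await?;
    /// # Ok(())
    /// # }
    /// ```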
    pub async fn delete_fine_tune_model(self, model: String) -> Result<DeleteFineTuneModelResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            format!("/models/{}", model),
            Method::DELETE,
        );
        let response = request_builder.send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<DeleteFineTuneModelResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

    /// Create moderation
    /// POST https://api.openai.com/v1/moderations
    /// Classifies if text violates OpenAI's Content Policy.
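    ///
    /// A minimal usage sketch (hypothetical key):
    /// ```no_run
    /// # use ru_openai::{api::*, configuration::Configuration};
    /// # async fn run() -> Result<(), ru_openai::OpenAIApiError> {
    /// let api = OpenAIApi::new(Configuration::new_personal("sk-...".to_string()));
    /// let request = CreateModerationRequest {
    ///     input: vec!["I want to kill them.".to_string()],
    ///     ..Default::default()
    /// };
    /// let moderation = api.create_moderation(request).await?;
    /// assert!(moderation.results[0].categories.violence);
    /// # Ok(())
    /// # }
    /// ```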
    pub async fn create_moderation(self, request: CreateModerationRequest) -> Result<CreateModerationResponse, OpenAIApiError> {
        let client_builder = reqwest::Client::builder();
        let request_builder = self.configuration.apply_to_request(
            client_builder,
            "/moderations".to_string(),
            Method::POST,
        );
        let response = request_builder.json(&request).send().await
            .map_err(OpenAIApiError::from)?;
        if response.status().is_success() {
            response.json::<CreateModerationResponse>().await
                .map_err(OpenAIApiError::from)
        } else {
            let status = response.status().as_u16() as i32;
            let ret_err = response.json::<ReturnErrorType>().await
                .map_err(OpenAIApiError::from)?;
            Err(OpenAIApiError::new(status, ret_err.error))
        }
    }

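    /// Maps a file-name suffix to the MIME type used for multipart uploads.
    /// Returns a 400 `OpenAIApiError` for suffixes the API does not accept.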
    fn get_mime_type_from_suffix(suffix: String) -> Result<String, OpenAIApiError> {
        match suffix.as_str() {
            "json" => Ok(mime::APPLICATION_JSON.to_string()),
            "txt" => Ok(mime::TEXT_PLAIN.to_string()),
            "html" => Ok(mime::TEXT_HTML.to_string()),
            "pdf" => Ok(mime::APPLICATION_PDF.to_string()),
            "png" => Ok(mime::IMAGE_PNG.to_string()),
            "jpg" | "jpeg" => Ok(mime::IMAGE_JPEG.to_string()),
            "gif" => Ok(mime::IMAGE_GIF.to_string()),
            "svg" => Ok(mime::IMAGE_SVG.to_string()),
            "m4a" => Ok("audio/m4a".to_string()),
            "mp3" => Ok("audio/mp3".to_string()),
            "wav" => Ok("audio/wav".to_string()),
            "flac" => Ok("audio/flac".to_string()),
            "mp4" => Ok("video/mp4".to_string()),
            "mpeg" => Ok("video/mpeg".to_string()),
            "mpga" => Ok("audio/mpeg".to_string()),
            "webm" => Ok("video/webm".to_string()),
            _ => {
                let e = ErrorInfo {
                    message: format!("Unsupported file type: {}", suffix),
                    code: None,
                    message_type: "unsupported_file_type".to_string(),
                    param: None,
                };
                Err(OpenAIApiError::new(400, e))
            },
        }
    }

}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::configuration::Configuration;
    use dotenv::vars;

    /// Builds a client from the `API_KEY` entry in the environment/.env file,
    /// routed through a local debugging proxy.
    fn test_api() -> OpenAIApi {
        let api_key = vars()
            .find(|(key, _)| key == "API_KEY")
            .map(|(_, value)| value)
            .unwrap_or_default();
        let configuration = Configuration::new_personal(api_key)
            .proxy("http://127.0.0.1:7890".to_string());
        OpenAIApi::new(configuration)
    }

    #[tokio::test]
    async fn test_list_models() {
        let response = test_api().list_models().await.unwrap();
        assert_eq!(response.object, "list");
    }

    #[tokio::test]
    async fn test_retrieve_model() {
        let response = test_api().retrieve_model("davinci".to_string()).await.unwrap();
        assert_eq!(response.object, "model");
    }

    #[tokio::test]
    async fn test_create_completion() {
        let request = CreateCompletionRequest {
            model: "text-davinci-003".to_string(),
            prompt: Some(vec!["Once upon a time".to_string()]),
            max_tokens: Some(7),
            temperature: Some(0.7),
            ..Default::default()
        };
        let response = test_api().create_completion(request).await.unwrap();
        assert_eq!(response.object, "text_completion");
    }

    #[tokio::test]
    async fn test_create_chat_completion() {
        let request = CreateChatCompletionRequest {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![ChatFormat { role: "user".to_string(), content: "tell me a story".to_string() }],
            ..Default::default()
        };
        let response = test_api().create_chat_completion(request).await.unwrap();
        assert_eq!(response.object, "chat.completion");
    }

    #[tokio::test]
    async fn test_create_edit() {
        let request = CreateEditRequest {
            model: "text-davinci-edit-001".to_string(),
            // The misspelling of "week" is deliberate: the edit should fix it.
            input: Some("What day of the wek is it?".to_string()),
            instruction: "Fix the spelling mistakes".to_string(),
            ..Default::default()
        };
        let response = test_api().create_edit(request).await.unwrap();
        assert_eq!(response.object, "edit");
    }

    #[tokio::test]
    async fn test_create_image() {
        let request = CreateImageRequest {
            prompt: "A photo of a dog".to_string(),
            n: Some(1),
            size: Some("512x512".to_string()),
            response_format: Some(ImageFormat::URL),
            ..Default::default()
        };
        println!("request: {:#?}", serde_json::to_string(&request).unwrap());
        let response = test_api().create_image(request).await.unwrap();

        assert_eq!(response.data.len(), 1);
        match response.data[0].clone() {
            CreateImageResponseData::Url(url) => {
                assert!(url.starts_with("https://"));
            },
            _ => panic!("unexpected response format"),
        }
    }

    #[tokio::test]
    async fn test_create_transcription() {
        let request = CreateTranscriptionRequest {
            file: "./misc/test_audio.m4a".to_string(),
            model: "whisper-1".to_string(),
            response_format: Some(CreateTranscriptionResponseFormat::JSON),
            ..Default::default()
        };
        println!("request: {:#?}", serde_json::to_string(&request).unwrap());
        let response = test_api().create_transcription(request).await.unwrap();
        match response {
            CreateTranscriptionResponse::Json(content) => {
                assert_eq!(content.text, "你好你好");
            },
            _ => panic!("unexpected response format"),
        };
    }

    #[tokio::test]
    async fn test_create_translation() {
        let request = CreateTranslationRequest {
            file: "./misc/test_audio.m4a".to_string(),
            model: "whisper-1".to_string(),
            response_format: Some(CreateTranscriptionResponseFormat::JSON),
            ..Default::default()
        };
        println!("request: {:#?}", serde_json::to_string(&request).unwrap());
        let response = test_api().create_translation(request).await.unwrap();
        match response {
            CreateTranslationResponse::Json(content) => {
                assert_eq!(content.text, "Ni hao, ni hao.");
            },
            _ => panic!("unexpected response format"),
        };
    }

    #[tokio::test]
    async fn test_list_files() {
        let response = test_api().list_files().await.unwrap();
        assert_eq!(response.object, "list");
    }

    #[tokio::test]
    async fn test_list_fine_tunes() {
        let response = test_api().list_fine_tunes().await.unwrap();
        assert_eq!(response.object, "list");
    }

    #[tokio::test]
    async fn test_create_moderation() {
        let response = test_api().create_moderation(CreateModerationRequest {
            input: vec!["I want to kill them.".to_string()],
            ..Default::default()
        }).await.unwrap();
        assert!(response.results[0].categories.violence);
    }

}