1
2use std::fmt::{Display, Formatter};
3use std::fs;
4use std::pin::Pin;
5use crate::{OpenAIApiError, ReturnErrorType, ErrorInfo, stream};
6use crate::configuration::Configuration;
7use serde_derive::{Deserialize, Serialize};
8use reqwest::{Method, multipart::Part};
9use tracing::*;
10use futures::{Stream};
11use reqwest_eventsource::{RequestBuilderExt};
12
13
/// One permission entry attached to a [`ModelInfo`] by the models endpoints.
#[derive(Deserialize, Serialize, Debug)]
pub struct Permission {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub allow_create_engine: bool,
    pub allow_sampling: bool,
    pub allow_logprobs: bool,
    pub allow_search_indices: bool,
    pub allow_view: bool,
    pub allow_fine_tuning: bool,
    pub organization: String,
    /// `None` when the JSON value is `null` or the key is missing.
    pub group: Option<String>,
    pub is_blocking: bool,
}

/// Metadata for one model, decoded from `GET /models` (inside [`ListModelsResponse`]) and
/// from `GET /models/{model}` (see [`RetrieveModelResponse`]).
#[derive(Deserialize, Serialize, Debug)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub owned_by: String,
    pub permission: Vec<Permission>,
    pub root: String,
    /// `None` when the JSON value is `null` or the key is missing.
    pub parent: Option<String>,
}

/// Body of a successful `GET /models` response (decoded by `OpenAIApi::list_models`).
#[derive(Deserialize, Serialize, Debug)]
pub struct ListModelsResponse {
    pub data: Vec<ModelInfo>,
    pub object: String,
}

/// Body of a successful `GET /models/{model}` response (`OpenAIApi::retrieve_model`).
pub type RetrieveModelResponse = ModelInfo;
47
/// JSON body for `POST /completions`, used by `OpenAIApi::create_completion` and
/// `OpenAIApi::create_completion_stream`.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`. The client
/// methods overwrite `stream` themselves: `None` for the blocking call and `Some(true)` for
/// the streaming call. For the meaning of each parameter see the OpenAI API reference.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateCompletionRequest {
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    /// Overwritten by the `OpenAIApi` completion methods before sending.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<i16>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u16>,
    /// Sent verbatim as whatever JSON value the caller provides.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// One generated choice inside a [`CreateCompletionResponse`].
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateCompletionResponseChoice {
    pub text: String,
    pub index: i64,
    /// Kept as raw JSON; `None` for `null` or a missing key.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: Option<String>,
}

/// Token counters embedded in completion, chat-completion and edit responses.
#[derive(Deserialize, Serialize, Debug)]
pub struct Usage {
    pub prompt_tokens: i64,
    pub completion_tokens: i64,
    pub total_tokens: i64,
}

/// Body of a `POST /completions` response; also the item type of
/// [`CreateCompletionResponseStream`].
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<CreateCompletionResponseChoice>,
    /// Optional — presumably because streamed chunks do not carry usage (TODO confirm).
    pub usage: Option<Usage>,
}

/// Boxed, `Send` stream of decoded completion chunks returned by
/// `OpenAIApi::create_completion_stream`.
pub type CreateCompletionResponseStream =
    Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIApiError>> + Send>>;
143
/// A complete chat message (`role` plus `content`), used both in
/// [`CreateChatCompletionRequest::messages`] and in non-streamed response choices.
#[derive(Deserialize, Serialize, Debug)]
pub struct ChatFormat {
    pub role: String,
    pub content: String,
}

/// Partial chat message carried by a streamed chunk; either field may be absent.
#[derive(Deserialize, Serialize, Debug)]
pub struct ChatFormatDelta {
    pub role: Option<String>,
    pub content: Option<String>,
}

/// JSON body for `POST /chat/completions`, used by `OpenAIApi::create_chat_completion` and
/// `OpenAIApi::create_chat_completion_stream`.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`; `stream` is
/// overwritten by those methods (`None` / `Some(true)` respectively).
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    /// Overwritten by the `OpenAIApi` chat methods before sending.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Sent verbatim as whatever JSON value the caller provides.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// One choice of a non-streamed chat completion.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateChatCompletionResponseChoice {
    pub message: ChatFormat,
    pub index: i64,
    /// Required here: a `null` value would fail to decode.
    pub finish_reason: String,
}

/// One choice of a streamed chat-completion chunk.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateChatCompletionResponseChoiceDelta {
    pub delta: ChatFormatDelta,
    pub index: i64,
    pub finish_reason: Option<String>,
}

/// Body of a non-streamed `POST /chat/completions` response.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub choices: Vec<CreateChatCompletionResponseChoice>,
    pub usage: Usage,
}

/// One decoded chunk of a streamed chat completion.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<CreateChatCompletionResponseChoiceDelta>,
}
230
/// JSON body for `POST /edits` (`OpenAIApi::create_edit`).
///
/// `Option` fields are omitted from the serialized JSON when `None`; `instruction` is
/// always sent.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateEditRequest {
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<String>,
    pub instruction: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}

/// Boxed, `Send` stream of decoded chat chunks returned by
/// `OpenAIApi::create_chat_completion_stream`.
pub type CreateChatCompletionResponseStream =
    Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIApiError>> + Send>>;

/// One edited text inside a [`CreateEditResponse`].
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEditResponseChoice {
    pub index: i64,
    pub text: String,
}

/// Body of a successful `POST /edits` response.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEditResponse {
    pub object: String,
    pub created: i64,
    pub choices: Vec<CreateEditResponseChoice>,
    pub usage: Usage,
}
270
271#[derive(Deserialize, Serialize, Debug)]
272pub enum ImageFormat {
273 #[serde(rename = "url")]
274 URL,
275 #[serde(rename = "b64_json")]
276 B64JSON,
277}
278
279impl Display for ImageFormat {
280 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
281 match self {
282 ImageFormat::URL => write!(f, "url"),
283 ImageFormat::B64JSON => write!(f, "b64_json"),
284 }
285 }
286
287}
288
/// JSON body for `POST /images/generations` (`OpenAIApi::create_image`).
///
/// `Option` fields are omitted from the serialized JSON when `None`.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateImageRequest {
    pub prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ImageFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// One generated image. Uses serde's default (externally tagged) enum form, so
/// `{"url": "..."}` decodes to `Url` and `{"b64_json": "..."}` decodes to `B64Json`.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum CreateImageResponseData {
    #[serde(rename = "url")]
    Url(String),
    #[serde(rename = "b64_json")]
    B64Json(String),
}

/// Body of a successful image generation, edit or variation response.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateImageResponse {
    pub created: i64,
    pub data: Vec<CreateImageResponseData>,
}

/// Parameters for `OpenAIApi::create_image_edit`.
///
/// `image` and `mask` are local file paths that the method reads from disk. The method
/// sends a hand-built multipart form, so the serde attributes here do not shape that
/// request.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateImageEditRequest {
    pub image: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mask: Option<String>,
    pub prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ImageFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// Body of a successful `POST /images/edits` response.
pub type CreateImageEditResponse = CreateImageResponse;

/// Parameters for `OpenAIApi::create_image_variation`; `image` is a local file path that
/// the method reads and uploads as a multipart part.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateImageVariationRequest {
    pub image: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u16>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ImageFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// Body of a successful `POST /images/variations` response.
pub type CreateImageVariationResponse = CreateImageResponse;
366
/// JSON body for `POST /embeddings` (`OpenAIApi::create_embeddings`); `user` is omitted
/// when `None`.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateEmbeddingsRequest {
    pub model: String,
    pub input: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// One embedding vector inside a [`CreateEmbeddingsResponse`].
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEmbeddingsResponseData {
    pub object: String,
    pub embedding: Vec<f32>,
    pub index: i64,
}

/// Token counters of an embeddings response (no completion tokens field).
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEmbeddingsResponseUsage {
    pub prompt_tokens: i64,
    pub total_tokens: i64,
}

/// Body of a successful `POST /embeddings` response.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateEmbeddingsResponse {
    pub object: String,
    pub data: Vec<CreateEmbeddingsResponseData>,
    pub model: String,
    pub usage: CreateEmbeddingsResponseUsage,
}
398
399#[derive(Deserialize, Serialize, Debug, Clone, Copy)]
400pub enum CreateTranscriptionResponseFormat {
401 #[serde(rename = "json")]
402 JSON,
403 #[serde(rename = "text")]
404 TEXT,
405 #[serde(rename = "srt")]
406 SRT,
407 #[serde(rename = "verbose_json")]
408 VERBOSEJSON,
409 #[serde(rename = "vtt")]
410 VTT,
411}
412
413impl Display for CreateTranscriptionResponseFormat {
414 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
415 match self {
416 CreateTranscriptionResponseFormat::JSON => write!(f, "json"),
417 CreateTranscriptionResponseFormat::TEXT => write!(f, "text"),
418 CreateTranscriptionResponseFormat::SRT => write!(f, "srt"),
419 CreateTranscriptionResponseFormat::VERBOSEJSON => write!(f, "verbose_json"),
420 CreateTranscriptionResponseFormat::VTT => write!(f, "vtt"),
421 }
422 }
423}
424
/// Parameters for `OpenAIApi::create_transcription`.
///
/// `file` is a local path. The method reads that file and picks the upload MIME type from
/// the text after the path's last `.`. The request is sent as a hand-built multipart form,
/// so the serde attributes do not shape that call.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateTranscriptionRequest {
    pub file: String,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// Also selects how the response body is decoded; `None` is decoded as JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<CreateTranscriptionResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
}
444
445pub enum CreateTranscriptionResponse {
446 Text(CreateTranscriptionResponseText),
447 Json(CreateTranscriptionResponseJson),
448 Srt(CreateTranscriptionResponseSrt),
449 VerboseJson(CreateTranscriptionResponseVerboseJson),
450 Vtt(CreateTranscriptionResponseVtt),
451}
452
/// `text` response: `create_transcription` stores the raw response body in `text`.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseText {
    pub text: String,
}

/// `json` response (also used when no format was requested): decoded from a JSON object
/// with a `text` key.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseJson {
    pub text: String,
}


/// `srt` response: `create_transcription` stores the raw response body in `text`.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseSrt {
    pub text: String,
}
468
469#[derive(Deserialize, Serialize, Debug)]
470pub struct TranscriptionSegment {
471 pub id: String,
472 pub seek: i32,
473 pub start: f32,
474 pub end: f32,
475 pub text: String,
476 pub tokens: Vec<i64>,
477 pub temperature: f32,
478 pub avg_logprob: f64,
479 pub compression_ratio: f64,
480 pub no_speech_prob: f64,
481 pub transient: bool,
482}
483
/// `verbose_json` response, decoded from JSON.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseVerboseJson {
    pub task: String,
    pub language: String,
    pub duration: f32,
    pub segments: Vec<TranscriptionSegment>,
    pub text: String,
}

/// `vtt` response: `create_transcription` stores the raw response body in `text`.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateTranscriptionResponseVtt {
    pub text: String,
}
497
/// Parameters for `OpenAIApi::create_translation`.
///
/// `file` is a local path whose text after the last `.` selects the upload MIME type. It
/// mirrors [`CreateTranscriptionRequest`] without `language`.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateTranslationRequest {
    pub file: String,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<CreateTranscriptionResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
}

/// Translations are decoded into the same variants as transcriptions.
pub type CreateTranslationResponse = CreateTranscriptionResponse;
516
/// A file object; presumably returned by the `/files` endpoints, whose methods lie outside
/// this view (TODO confirm). Also embedded in fine-tune responses.
#[derive(Deserialize, Serialize, Debug)]
pub struct FileInfo {
    pub id: String,
    pub object: String,
    /// NOTE(review): `i32` caps the representable size near 2 GiB.
    pub bytes: i32,
    pub created_at: i64,
    pub filename: String,
    pub purpose: String,
}

/// List of [`FileInfo`] objects plus the list's `object` tag.
#[derive(Deserialize, Serialize, Debug)]
pub struct ListFilesResponse {
    pub data: Vec<FileInfo>,
    pub object: String,
}

/// Parameters for uploading a file. Presumably `file` is a local path, as in the other
/// upload requests here (TODO confirm against the upload method).
#[derive(Deserialize, Serialize, Debug)]
pub struct UploadFileRequest {
    pub file: String,
    pub filename: String,
    pub purpose: String,
}

/// An upload answers with the created file object.
pub type UploadFileResponse = FileInfo;

/// Confirmation returned after deleting a file.
#[derive(Deserialize, Serialize, Debug)]
pub struct DeleteFileResponse {
    pub deleted: bool,
    pub id: String,
    pub object: String,
}

/// Retrieving a file answers with its file object.
pub type RetrieveFileResponse = FileInfo;
554
/// Body for creating a fine-tune job.
///
/// Only `training_file` is required. Every `Option` field is omitted from the serialized
/// JSON when it is `None`; see the OpenAI fine-tunes API reference for each field's meaning.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateFineTuneRequest {
    pub training_file: String,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_file: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub n_epochs: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub batch_size: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub learning_rate_multiplier: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_loss_weight: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub compute_classification_metrics: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub classification_n_classes: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub classification_positive_class: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub classification_betas: Option<Vec<f32>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
}
627
/// One event in a fine-tune job's history.
#[derive(Deserialize, Serialize, Debug)]
pub struct FineTuneEvent {
    pub object: String,
    pub created_at: i64,
    pub level: String,
    pub message: String,
}

/// Hyperparameters reported on a fine-tune job; every field is required when decoding.
#[derive(Deserialize, Serialize, Debug)]
pub struct FineTuneHyperparams {
    pub batch_size: i32,
    pub learning_rate_multiplier: f32,
    pub prompt_loss_weight: f32,
    pub n_epochs: i32,
}
643
644#[derive(Deserialize, Serialize, Debug)]
645pub struct CreateFineTuneResponse {
646 pub id: String,
647 pub object: String,
648 pub model: String,
649 pub created_at: i64,
650 pub events: Vec<FineTuneEvent>,
651 pub fine_tuned_model: Option<String>,
652 pub hyperparams: FineTuneHyperparams,
653 pub organization_id: String,
654 pub result_files: Vec<FileInfo>,
655 pub status: String,
656 pub validation_files: Vec<FileInfo>,
657 pub training_files: Vec<FileInfo>,
658 pub updated_at: i64,
659}
660
/// List of fine-tune job objects plus the list's `object` tag.
#[derive(Deserialize, Serialize, Debug)]
pub struct ListFineTunesResponse {
    pub object: String,
    pub data: Vec<CreateFineTuneResponse>,
}

/// Retrieving a fine-tune answers with the job object.
pub type RetrieveFineTuneResponse = CreateFineTuneResponse;

/// Cancelling a fine-tune answers with the (updated) job object.
pub type CancelFineTuneResponse = CreateFineTuneResponse;

/// List of events for one fine-tune job.
#[derive(Deserialize, Serialize, Debug)]
pub struct ListFineTuneEventsResponse {
    pub object: String,
    pub data: Vec<FineTuneEvent>,
}

/// Confirmation returned after deleting a fine-tuned model.
#[derive(Deserialize, Serialize, Debug)]
pub struct DeleteFineTuneModelResponse {
    pub id: String,
    pub object: String,
    pub deleted: bool,
}
683
/// Body for a moderation request; `model` is omitted from the JSON when `None`.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct CreateModerationRequest {
    pub input: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}

/// Per-category flags of one moderation result. Keys containing `/` or `-` on the wire are
/// mapped to snake_case fields through the serde renames.
#[derive(Deserialize, Serialize, Debug)]
pub struct ModerationCategories {
    pub hate: bool,
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: bool,
    #[serde(rename = "self-harm")]
    pub self_harm: bool,
    pub sexual: bool,
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: bool,
    pub violence: bool,
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: bool,
}

/// Per-category scores of one moderation result, using the same key mapping as
/// [`ModerationCategories`].
#[derive(Deserialize, Serialize, Debug)]
pub struct ModerationCategoryScores {
    pub hate: f64,
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: f64,
    #[serde(rename = "self-harm")]
    pub self_harm: f64,
    pub sexual: f64,
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: f64,
    pub violence: f64,
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: f64,
}

/// Moderation outcome for one input string.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateModerationResult {
    pub categories: ModerationCategories,
    pub category_scores: ModerationCategoryScores,
    pub flagged: bool,
}

/// Body of a moderation response.
#[derive(Deserialize, Serialize, Debug)]
pub struct CreateModerationResponse {
    pub id: String,
    pub model: String,
    pub results: Vec<CreateModerationResult>,
}
740
/// Client for the OpenAI HTTP API.
///
/// Each request method builds its HTTP request through the stored [`Configuration`]
/// (`apply_to_request`). The request methods shown here take `self` by value, so one
/// value serves one call.
pub struct OpenAIApi {
    configuration: Configuration,
}
744
745impl OpenAIApi {
746
747 pub fn new(configuration: Configuration) -> Self {
748 Self { configuration }
749 }
750
751 pub async fn list_models(self) -> Result<ListModelsResponse, OpenAIApiError> {
755
756 let client_builder = reqwest::Client::builder();
757 let request_builder = self
758 .configuration
759 .apply_to_request(
760 client_builder,
761 "/models".to_string(),
762 Method::GET,
763 );
764 let response = request_builder
765 .send()
766 .await
767 .map_err(OpenAIApiError::from)?;
768 if response.status().is_success() {
769 response
770 .json::<ListModelsResponse>()
771 .await
772 .map_err(OpenAIApiError::from)
773 } else {
774 let status = response.status().as_u16() as i32;
775 let ret_err = response.json::<ReturnErrorType>().await.map_err( OpenAIApiError::from)?;
776 Err(OpenAIApiError::new(status, ret_err.error))
777 }
778 }
779
780 pub async fn retrieve_model(self, model: String) -> Result<RetrieveModelResponse, OpenAIApiError> {
784
785 let client_builder = reqwest::Client::builder();
786 let request_builder = self.configuration.apply_to_request(
787 client_builder,
788 format!("/models/{}", model),
789 Method::GET,
790 );
791 let response = request_builder.send().await
792 .map_err(|err| OpenAIApiError::from(err))?;
793 if response.status().is_success() {
794 response.json::<RetrieveModelResponse>().await
795 .map_err(|err| OpenAIApiError::from(err))
796 } else {
797 let status = response.status().as_u16() as i32;
798 let ret_err = response.json::<ReturnErrorType>().await
799 .map_err(|err| OpenAIApiError::from(err))?;
800 Err(OpenAIApiError::new(status, ret_err.error))
801 }
802 }
803
804 pub async fn create_completion(self, mut request: CreateCompletionRequest) -> Result<CreateCompletionResponse, OpenAIApiError> {
808
809 let client_builder = reqwest::Client::builder();
810 let request_builder = self.configuration.apply_to_request(
811 client_builder,
812 "/completions".to_string(),
813 Method::POST,
814 );
815 request.stream = None;
816 let response = request_builder.json(&request).send().await
817 .map_err(|err| OpenAIApiError::from(err))?;
818 info!("response: {:#?}", response);
819 if response.status().is_success() {
820 response.json::<CreateCompletionResponse>().await
821 .map_err(|err| OpenAIApiError::from(err))
822 } else {
823 let status = response.status().as_u16() as i32;
824 let ret_err = response.json::<ReturnErrorType>().await
825 .map_err(|err| OpenAIApiError::from(err))?;
826 Err(OpenAIApiError::new(status, ret_err.error))
827 }
828 }
829
830 pub async fn create_completion_stream(self, mut request: CreateCompletionRequest) -> Result<CreateCompletionResponseStream, OpenAIApiError> {
838 let client_builder = reqwest::Client::builder();
839 let request_builder = self.configuration.apply_to_request(
840 client_builder,
841 "/completions".to_string(),
842 Method::POST,
843 );
844 request.stream = Some(true);
845 let event_source = request_builder.json(&request).eventsource().unwrap();
846 Ok(stream(event_source).await)
847 }
848
849 pub async fn create_chat_completion(self, mut request: CreateChatCompletionRequest) -> Result<CreateChatCompletionResponse, OpenAIApiError> {
854
855 let client_builder = reqwest::Client::builder();
856 let request_builder = self.configuration.apply_to_request(
857 client_builder,
858 "/chat/completions".to_string(),
859 Method::POST,
860 );
861 request.stream = None;
862 let response = request_builder.json(&request).send().await
863 .map_err(|err| OpenAIApiError::from(err))?;
864 info!("response: {:#?}", response);
865 if response.status().is_success() {
867 response.json::<CreateChatCompletionResponse>().await
868 .map_err(|err| OpenAIApiError::from(err))
869 } else {
870 let status = response.status().as_u16() as i32;
871 let ret_err = response.json::<ReturnErrorType>().await
872 .map_err(|err| OpenAIApiError::from(err))?;
873 Err(OpenAIApiError::new(status, ret_err.error))
874 }
875 }
876
877 pub async fn create_chat_completion_stream(self, mut request: CreateChatCompletionRequest) -> Result<CreateChatCompletionResponseStream, OpenAIApiError> {
885 let client_builder = reqwest::Client::builder();
886 let request_builder = self.configuration.apply_to_request(
887 client_builder,
888 "/chat/completions".to_string(),
889 Method::POST,
890 );
891 request.stream = Some(true);
892 let event_source = request_builder.json(&request).eventsource().unwrap();
893 Ok(stream(event_source).await)
894 }
895
896 pub async fn create_edit(self, request: CreateEditRequest) -> Result<CreateEditResponse, OpenAIApiError> {
900 let client_builder = reqwest::Client::builder();
901 let request_builder = self.configuration.apply_to_request(
902 client_builder,
903 "/edits".to_string(),
904 Method::POST,
905 );
906 let response = request_builder.json(&request).send().await
907 .map_err(|err| OpenAIApiError::from(err))?;
908 info!("response: {:#?}", response);
909 if response.status().is_success() {
911 response.json::<CreateEditResponse>().await
912 .map_err(|err| OpenAIApiError::from(err))
913 } else {
914 let status = response.status().as_u16() as i32;
915 let ret_err = response.json::<ReturnErrorType>().await
916 .map_err(|err| OpenAIApiError::from(err))?;
917 Err(OpenAIApiError::new(status, ret_err.error))
918 }
919 }
920
921 pub async fn create_image(self, request: CreateImageRequest) -> Result<CreateImageResponse, OpenAIApiError>{
927 let client_builder = reqwest::Client::builder();
928 let request_builder = self.configuration.apply_to_request(
929 client_builder,
930 "/images/generations".to_string(),
931 Method::POST,
932 );
933 let response = request_builder.json(&request).send().await
934 .map_err(|err| OpenAIApiError::from(err))?;
935 if response.status().is_success() {
937 response.json::<CreateImageResponse>().await
938 .map_err(|err| OpenAIApiError::from(err))
939 } else {
940 let status = response.status().as_u16() as i32;
941 let ret_err = response.json::<ReturnErrorType>().await
942 .map_err(|err| OpenAIApiError::from(err))?;
943 Err(OpenAIApiError::new(status, ret_err.error))
944 }
945 }
946
947
948 pub async fn create_image_edit(self, request: CreateImageEditRequest) -> Result<CreateImageEditResponse, OpenAIApiError> {
952 let client_builder = reqwest::Client::builder();
953 let request_builder = self.configuration.apply_to_request(
954 client_builder,
955 "/images/edits".to_string(),
956 Method::POST,
957 );
958 let image_file = fs::read(request.image).unwrap();
959 let image_file_part = Part::bytes(image_file)
960 .file_name("image.png")
961 .mime_str("image/png")
962 .unwrap();
963 let mut form = reqwest::multipart::Form::new()
964 .part("image", image_file_part);
965 form = match request.mask {
966 Some(mask) => {
967 let mask_file = fs::read(mask).unwrap();
968 let mask_file_part = Part::bytes(mask_file)
969 .file_name("mask.png")
970 .mime_str("image/png")
971 .unwrap();
972 form.part("mask", mask_file_part)
973 },
974 None => form,
975 };
976 form = form.text("prompt", request.prompt.clone());
977 form = match request.n {
978 Some(n) => form.text("n", n.to_string()),
979 None => form,
980 };
981 form = match request.size {
982 Some(size) => form.text("size", size),
983 None => form,
984 };
985 form = match request.response_format {
986 Some(response_format) => form.text("response_format", response_format.to_string()),
987 None => form,
988 };
989 form = match request.user {
990 Some(user) => form.text("user", user),
991 None => form,
992 };
993 let response = request_builder.multipart(form).send().await
994 .map_err(|err| OpenAIApiError::from(err))?;
995 if response.status().is_success() {
997 response.json::<CreateImageEditResponse>().await
998 .map_err(|err| OpenAIApiError::from(err))
999 } else {
1000 let status = response.status().as_u16() as i32;
1001 let ret_err = response.json::<ReturnErrorType>().await
1002 .map_err(|err| OpenAIApiError::from(err))?;
1003 Err(OpenAIApiError::new(status, ret_err.error))
1004 }
1005 }
1006
1007 pub async fn create_image_variation(self, request: CreateImageVariationRequest) -> Result<CreateImageVariationResponse, OpenAIApiError> {
1011 let client_builder = reqwest::Client::builder();
1012 let request_builder = self.configuration.apply_to_request(
1013 client_builder,
1014 "/images/variations".to_string(),
1015 Method::POST,
1016 );
1017 let image_file = fs::read(request.image).unwrap();
1018 let image_file_part = Part::bytes(image_file)
1019 .file_name("image.png")
1020 .mime_str("image/png")
1021 .unwrap();
1022 let mut form = reqwest::multipart::Form::new().part("image", image_file_part);
1023
1024 form = match request.n {
1025 Some(n) => form.text("n", n.to_string()),
1026 None => form,
1027 };
1028 form = match request.size {
1029 Some(size) => form.text("size", size),
1030 None => form,
1031 };
1032 form = match request.response_format {
1033 Some(response_format) => form.text("response_format", response_format.to_string()),
1034 None => form,
1035 };
1036 form = match request.user {
1037 Some(user) => form.text("user", user),
1038 None => form,
1039 };
1040 let response = request_builder.multipart(form).send().await
1041 .map_err(|err| OpenAIApiError::from(err))?;
1042 if response.status().is_success() {
1044 response.json::<CreateImageVariationResponse>().await
1045 .map_err(|err| OpenAIApiError::from(err))
1046 } else {
1047 let status = response.status().as_u16() as i32;
1048 let ret_err = response.json::<ReturnErrorType>().await
1049 .map_err(|err| OpenAIApiError::from(err))?;
1050 Err(OpenAIApiError::new(status, ret_err.error))
1051 }
1052 }
1053
1054 pub async fn create_embeddings(self, request: CreateEmbeddingsRequest) -> Result<CreateEmbeddingsResponse, OpenAIApiError> {
1058 let client_builder = reqwest::Client::builder();
1059 let request_builder = self.configuration.apply_to_request(
1060 client_builder,
1061 "/embeddings".to_string(),
1062 Method::POST,
1063 );
1064 let response = request_builder.json(&request).send().await
1065 .map_err(|err| OpenAIApiError::from(err))?;
1066 info!("response: {:#?}", response);
1067 if response.status().is_success() {
1068 response.json::<CreateEmbeddingsResponse>().await
1069 .map_err(|err| OpenAIApiError::from(err))
1070 } else {
1071 let status = response.status().as_u16() as i32;
1072 let ret_err = response.json::<ReturnErrorType>().await
1073 .map_err(|err| OpenAIApiError::from(err))?;
1074 Err(OpenAIApiError::new(status, ret_err.error))
1075 }
1076 }
1077
1078 pub async fn create_transcription(self, request: CreateTranscriptionRequest) -> Result<CreateTranscriptionResponse, OpenAIApiError> {
1082 let client_builder = reqwest::Client::builder();
1083 let request_builder = self.configuration.apply_to_request(
1084 client_builder,
1085 "/audio/transcriptions".to_string(),
1086 Method::POST,
1087 );
1088 let parts: Vec<&str> = request.file.split('.').collect();
1089 let suffix = parts[parts.len() - 1];
1090 let mime_type = Self::get_mime_type_from_suffix(suffix.to_string())?;
1091 let audio_file = fs::read(request.file.clone()).unwrap();
1092 let audio_file_part = Part::bytes(audio_file)
1093 .file_name(format!("audio.{}", suffix))
1094 .mime_str(mime_type.as_str())
1095 .unwrap();
1096 let mut form = reqwest::multipart::Form::new().part("file", audio_file_part)
1097 .text("model", request.model);
1098 form = match request.prompt {
1099 Some(prompt) => form.text("prompt", prompt),
1100 None => form,
1101 };
1102 form = match request.response_format {
1103 Some(response_format) => form.text("response_format", response_format.to_string()),
1104 None => form,
1105 };
1106 form = match request.temperature {
1107 Some(temperature) => form.text("temperature", temperature.to_string()),
1108 None => form,
1109 };
1110 form = match request.language {
1111 Some(language) => form.text("language", language),
1112 None => form,
1113 };
1114 info!("request form: {:#?}", form);
1115 let response = request_builder.multipart(form).send().await
1116 .map_err(|err| OpenAIApiError::from(err))?;
1117 println!("response: {:#?}", response);
1118 let rf = request.response_format.clone();
1119 if response.status().is_success() {
1120 match rf {
1121 Some(response_format) => match response_format {
1122 CreateTranscriptionResponseFormat::TEXT => {
1123 let text = response.text().await
1124 .map_err(|err| OpenAIApiError::from(err))?;
1125 let response = CreateTranscriptionResponseText {
1126 text,
1127 };
1128 Ok(CreateTranscriptionResponse::Text(response))
1129 },
1130 CreateTranscriptionResponseFormat::JSON => {
1131 let response = response.json::<CreateTranscriptionResponseJson>().await
1132 .map_err(|err| OpenAIApiError::from(err)).unwrap();
1133 Ok(CreateTranscriptionResponse::Json(response))
1134 },
1135 CreateTranscriptionResponseFormat::SRT => {
1136 let text = response.text().await
1137 .map_err(|err| OpenAIApiError::from(err))?;
1138 let response = CreateTranscriptionResponseSrt {
1139 text,
1140 };
1141 Ok(CreateTranscriptionResponse::Srt(response))
1142 },
1143 CreateTranscriptionResponseFormat::VTT => {
1144 let text = response.text().await
1145 .map_err(|err| OpenAIApiError::from(err))?;
1146 let response = CreateTranscriptionResponseVtt {
1147 text,
1148 };
1149 Ok(CreateTranscriptionResponse::Vtt(response))
1150 },
1151 CreateTranscriptionResponseFormat::VERBOSEJSON => {
1152 let response = response.json::<CreateTranscriptionResponseVerboseJson>().await
1153 .map_err(|err| OpenAIApiError::from(err))?;
1154 Ok(CreateTranscriptionResponse::VerboseJson(response))
1155 },
1156 },
1157 None => {
1158 let response = response.json::<CreateTranscriptionResponseJson>().await
1159 .map_err(|err| OpenAIApiError::from(err))?;
1160 Ok(CreateTranscriptionResponse::Json(response))
1161 },
1162 }
1163 } else {
1164 let status = response.status().as_u16() as i32;
1165 let ret_err = response.json::<ReturnErrorType>().await
1166 .map_err(|err| OpenAIApiError::from(err))?;
1167 Err(OpenAIApiError::new(status, ret_err.error))
1168 }
1169
1170 }
1171
1172 pub async fn create_translation(self, request: CreateTranslationRequest) -> Result<CreateTranslationResponse, OpenAIApiError> {
1176 let client_builder = reqwest::Client::builder();
1177 let request_builder = self.configuration.apply_to_request(
1178 client_builder,
1179 "/audio/translations".to_string(),
1180 Method::POST,
1181 );
1182 let parts: Vec<&str> = request.file.split('.').collect();
1183 let suffix = parts[parts.len() - 1];
1184 let mime_type = Self::get_mime_type_from_suffix(suffix.to_string()).unwrap();
1185 let audio_file = fs::read(request.file.clone()).unwrap();
1186 let audio_file_part = Part::bytes(audio_file)
1187 .file_name(format!("audio.{}", suffix))
1188 .mime_str(mime_type.as_str())
1189 .unwrap();
1190 let mut form = reqwest::multipart::Form::new().part("file", audio_file_part)
1191 .text("model", request.model);
1192 form = match request.prompt {
1193 Some(prompt) => form.text("prompt", prompt),
1194 None => form,
1195 };
1196 form = match request.response_format {
1197 Some(response_format) => form.text("response_format", response_format.to_string()),
1198 None => form,
1199 };
1200 form = match request.temperature {
1201 Some(temperature) => form.text("temperature", temperature.to_string()),
1202 None => form,
1203 };
1204 let response = request_builder.multipart(form).send().await
1205 .map_err(|err| OpenAIApiError::from(err))?;
1206 info!("response: {:#?}", response);
1207 let rf = request.response_format.clone();
1208 if response.status().is_success() {
1209 match rf {
1210 Some(response_format) => match response_format {
1211 CreateTranscriptionResponseFormat::TEXT => {
1212 let text = response.text().await
1213 .map_err(|err| OpenAIApiError::from(err))?;
1214 let response = CreateTranscriptionResponseText {
1215 text,
1216 };
1217 Ok(CreateTranslationResponse::Text(response))
1218 },
1219 CreateTranscriptionResponseFormat::JSON => {
1220 let response = response.json::<CreateTranscriptionResponseJson>().await
1221 .map_err(|err| OpenAIApiError::from(err))?;
1222 Ok(CreateTranslationResponse::Json(response))
1223 },
1224 CreateTranscriptionResponseFormat::SRT => {
1225 let text = response.text().await
1226 .map_err(|err| OpenAIApiError::from(err))?;
1227 let response = CreateTranscriptionResponseSrt {
1228 text,
1229 };
1230 Ok(CreateTranslationResponse::Srt(response))
1231 },
1232 CreateTranscriptionResponseFormat::VTT => {
1233 let text = response.text().await
1234 .map_err(|err| OpenAIApiError::from(err))?;
1235 let response = CreateTranscriptionResponseVtt {
1236 text,
1237 };
1238 Ok(CreateTranslationResponse::Vtt(response))
1239 },
1240 CreateTranscriptionResponseFormat::VERBOSEJSON => {
1241 let response = response.json::<CreateTranscriptionResponseVerboseJson>().await
1242 .map_err(|err| OpenAIApiError::from(err))?;
1243 Ok(CreateTranslationResponse::VerboseJson(response))
1244 },
1245 },
1246 None => {
1247 let response = response.json::<CreateTranscriptionResponseJson>().await
1248 .map_err(|err| OpenAIApiError::from(err))?;
1249 Ok(CreateTranslationResponse::Json(response))
1250 },
1251 }
1252 } else {
1253 let status = response.status().as_u16() as i32;
1254 let ret_err = response.json::<ReturnErrorType>().await
1255 .map_err(|err| OpenAIApiError::from(err))?;
1256 Err(OpenAIApiError::new(status, ret_err.error))
1257 }
1258
1259
1260 }
1261
1262 pub async fn list_files(self) -> Result<ListFilesResponse, OpenAIApiError> {
1266 let client_builder = reqwest::Client::builder();
1267 let request_builder = self.configuration.apply_to_request(
1268 client_builder,
1269 "/files".to_string(),
1270 Method::GET,
1271 );
1272 let response = request_builder.send().await
1273 .map_err(|err| OpenAIApiError::from(err))?;
1274 if response.status().is_success() {
1275 response.json::<ListFilesResponse>().await
1276 .map_err(|err| OpenAIApiError::from(err))
1277 } else {
1278 let status = response.status().as_u16() as i32;
1279 let ret_err = response.json::<ReturnErrorType>().await
1280 .map_err(|err| OpenAIApiError::from(err))?;
1281 Err(OpenAIApiError::new(status, ret_err.error))
1282 }
1283 }
1284
1285 pub async fn upload_file(self, request: UploadFileRequest) -> Result<UploadFileResponse, OpenAIApiError> {
1289 let client_builder = reqwest::Client::builder();
1290 let request_builder = self.configuration.apply_to_request(
1291 client_builder,
1292 "/files".to_string(),
1293 Method::POST,
1294 );
1295 let file = fs::read(request.file.clone()).unwrap();
1296 let file_part = Part::bytes(file)
1297 .file_name(request.file.clone())
1298 .mime_str(mime::APPLICATION_JSON.to_string().as_str())
1299 .unwrap();
1300 let form = reqwest::multipart::Form::new().part("file", file_part)
1301 .text("purpose", request.purpose);
1302 let response = request_builder.multipart(form).send().await
1303 .map_err(|err| OpenAIApiError::from(err))?;
1304 if response.status().is_success() {
1305 response.json::<UploadFileResponse>().await
1306 .map_err(|err| OpenAIApiError::from(err))
1307 } else {
1308 let status = response.status().as_u16() as i32;
1309 let ret_err = response.json::<ReturnErrorType>().await
1310 .map_err(|err| OpenAIApiError::from(err))?;
1311 Err(OpenAIApiError::new(status, ret_err.error))
1312 }
1313 }
1314
1315 pub async fn delete_file(self, file_id: String) -> Result<DeleteFileResponse, OpenAIApiError> {
1319 let client_builder = reqwest::Client::builder();
1320 let request_builder = self.configuration.apply_to_request(
1321 client_builder,
1322 format!("/files/{}", file_id),
1323 Method::DELETE,
1324 );
1325 let response = request_builder.send().await
1326 .map_err(|err| OpenAIApiError::from(err))?;
1327 if response.status().is_success() {
1328 response.json::<DeleteFileResponse>().await
1329 .map_err(|err| OpenAIApiError::from(err))
1330 } else {
1331 let status = response.status().as_u16() as i32;
1332 let ret_err = response.json::<ReturnErrorType>().await
1333 .map_err(|err| OpenAIApiError::from(err))?;
1334 Err(OpenAIApiError::new(status, ret_err.error))
1335 }
1336 }
1337
1338 pub async fn retrieve_file(self, file_id: String) -> Result<RetrieveFileResponse, OpenAIApiError> {
1342 let client_builder = reqwest::Client::builder();
1343 let request_builder = self.configuration.apply_to_request(
1344 client_builder,
1345 format!("/files/{}", file_id),
1346 Method::GET,
1347 );
1348 let response = request_builder.send().await
1349 .map_err(|err| OpenAIApiError::from(err))?;
1350 if response.status().is_success() {
1351 response.json::<RetrieveFileResponse>().await
1352 .map_err(|err| OpenAIApiError::from(err))
1353 } else {
1354 let status = response.status().as_u16() as i32;
1355 let ret_err = response.json::<ReturnErrorType>().await
1356 .map_err(|err| OpenAIApiError::from(err))?;
1357 Err(OpenAIApiError::new(status, ret_err.error))
1358 }
1359 }
1360
1361 pub async fn retrieve_file_content(self, file_id: String) -> Result<String, OpenAIApiError> {
1365 let client_builder = reqwest::Client::builder();
1366 let request_builder = self.configuration.apply_to_request(
1367 client_builder,
1368 format!("/files/{}/content", file_id),
1369 Method::GET,
1370 );
1371 let response = request_builder.send().await
1372 .map_err(|err| OpenAIApiError::from(err))?;
1373 if response.status().is_success() {
1374 response.text().await
1375 .map_err(|err| OpenAIApiError::from(err))
1376 } else {
1377 let status = response.status().as_u16() as i32;
1378 let ret_err = response.json::<ReturnErrorType>().await
1379 .map_err(|err| OpenAIApiError::from(err))?;
1380 Err(OpenAIApiError::new(status, ret_err.error))
1381 }
1382 }
1383
1384 pub async fn create_fine_tune(self, request: CreateFineTuneRequest) -> Result<CreateFineTuneResponse, OpenAIApiError> {
1389 let client_builder = reqwest::Client::builder();
1390 let request_builder = self.configuration.apply_to_request(
1391 client_builder,
1392 "/fine-tunes".to_string(),
1393 Method::POST,
1394 );
1395 let response = request_builder.json(&request).send().await
1396 .map_err(|err| OpenAIApiError::from(err))?;
1397 if response.status().is_success() {
1398 response.json::<CreateFineTuneResponse>().await
1399 .map_err(|err| OpenAIApiError::from(err))
1400 } else {
1401 let status = response.status().as_u16() as i32;
1402 let ret_err = response.json::<ReturnErrorType>().await
1403 .map_err(|err| OpenAIApiError::from(err))?;
1404 Err(OpenAIApiError::new(status, ret_err.error))
1405 }
1406 }
1407
1408 pub async fn list_fine_tunes(self) -> Result<ListFineTunesResponse, OpenAIApiError> {
1412 let client_builder = reqwest::Client::builder();
1413 let request_builder = self.configuration.apply_to_request(
1414 client_builder,
1415 "/fine-tunes".to_string(),
1416 Method::GET,
1417 );
1418 let response = request_builder.send().await
1419 .map_err(|err| OpenAIApiError::from(err))?;
1420 if response.status().is_success() {
1421 response.json::<ListFineTunesResponse>().await
1422 .map_err(|err| OpenAIApiError::from(err))
1423 } else {
1424 let status = response.status().as_u16() as i32;
1425 let ret_err = response.json::<ReturnErrorType>().await
1426 .map_err(|err| OpenAIApiError::from(err))?;
1427 Err(OpenAIApiError::new(status, ret_err.error))
1428 }
1429 }
1430
1431 pub async fn retrieve_fine_tune(self, fine_tune_id: String) -> Result<RetrieveFineTuneResponse, OpenAIApiError> {
1435 let client_builder = reqwest::Client::builder();
1436 let request_builder = self.configuration.apply_to_request(
1437 client_builder,
1438 format!("/fine-tunes/{}", fine_tune_id),
1439 Method::GET,
1440 );
1441 let response = request_builder.send().await
1442 .map_err(|err| OpenAIApiError::from(err))?;
1443 if response.status().is_success() {
1444 response.json::<RetrieveFineTuneResponse>().await
1445 .map_err(|err| OpenAIApiError::from(err))
1446 } else {
1447 let status = response.status().as_u16() as i32;
1448 let ret_err = response.json::<ReturnErrorType>().await
1449 .map_err(|err| OpenAIApiError::from(err))?;
1450 Err(OpenAIApiError::new(status, ret_err.error))
1451 }
1452 }
1453
1454 pub async fn cancel_fine_tune(self, fine_tune_id: String) -> Result<CancelFineTuneResponse, OpenAIApiError> {
1458 let client_builder = reqwest::Client::builder();
1459 let request_builder = self.configuration.apply_to_request(
1460 client_builder,
1461 format!("/fine-tunes/{}/cancel", fine_tune_id),
1462 Method::POST,
1463 );
1464 let response = request_builder.send().await
1465 .map_err(|err| OpenAIApiError::from(err))?;
1466 if response.status().is_success() {
1467 response.json::<CancelFineTuneResponse>().await
1468 .map_err(|err| OpenAIApiError::from(err))
1469 } else {
1470 let status = response.status().as_u16() as i32;
1471 let ret_err = response.json::<ReturnErrorType>().await
1472 .map_err(|err| OpenAIApiError::from(err))?;
1473 Err(OpenAIApiError::new(status, ret_err.error))
1474 }
1475 }
1476
1477 pub async fn list_fine_tune_events(self, fine_tune_id: String) -> Result<ListFineTuneEventsResponse, OpenAIApiError> {
1481 let client_builder = reqwest::Client::builder();
1482 let request_builder = self.configuration.apply_to_request(
1483 client_builder,
1484 format!("/fine-tunes/{}/events", fine_tune_id),
1485 Method::GET,
1486 );
1487 let response = request_builder.send().await
1488 .map_err(|err| OpenAIApiError::from(err))?;
1489 if response.status().is_success() {
1490 response.json::<ListFineTuneEventsResponse>().await
1491 .map_err(|err| OpenAIApiError::from(err))
1492 } else {
1493 let status = response.status().as_u16() as i32;
1494 let ret_err = response.json::<ReturnErrorType>().await
1495 .map_err(|err| OpenAIApiError::from(err))?;
1496 Err(OpenAIApiError::new(status, ret_err.error))
1497 }
1498 }
1499
1500 pub async fn delete_fine_tune_model(self, model: String) -> Result<DeleteFineTuneModelResponse, OpenAIApiError> {
1504 let client_builder = reqwest::Client::builder();
1505 let request_builder = self.configuration.apply_to_request(
1506 client_builder,
1507 format!("/models/{}", model),
1508 Method::DELETE,
1509 );
1510 let response = request_builder.send().await
1511 .map_err(|err| OpenAIApiError::from(err))?;
1512 if response.status().is_success() {
1513 response.json::<DeleteFineTuneModelResponse>().await
1514 .map_err(|err| OpenAIApiError::from(err))
1515 } else {
1516 let status = response.status().as_u16() as i32;
1517 let ret_err = response.json::<ReturnErrorType>().await
1518 .map_err(|err| OpenAIApiError::from(err))?;
1519 Err(OpenAIApiError::new(status, ret_err.error))
1520 }
1521 }
1522
1523 pub async fn create_moderation(self, request: CreateModerationRequest) -> Result<CreateModerationResponse, OpenAIApiError> {
1527 let client_builder = reqwest::Client::builder();
1528 let request_builder = self.configuration.apply_to_request(
1529 client_builder,
1530 "/moderations".to_string(),
1531 Method::POST,
1532 );
1533 let response = request_builder.json(&request).send().await
1534 .map_err(|err| OpenAIApiError::from(err))?;
1535 if response.status().is_success() {
1536 response.json::<CreateModerationResponse>().await
1537 .map_err(|err| OpenAIApiError::from(err))
1538 } else {
1539 let status = response.status().as_u16() as i32;
1540 let ret_err = response.json::<ReturnErrorType>().await
1541 .map_err(|err| OpenAIApiError::from(err))?;
1542 Err(OpenAIApiError::new(status, ret_err.error))
1543 }
1544 }
1545
1546 fn get_mime_type_from_suffix(suffix: String) -> Result<String, OpenAIApiError> {
1547 match suffix.as_str() {
1548 "json" => Ok(mime::APPLICATION_JSON.to_string()),
1549 "txt" => Ok(mime::TEXT_PLAIN.to_string()),
1550 "html" => Ok(mime::TEXT_HTML.to_string()),
1551 "pdf" => Ok(mime::APPLICATION_PDF.to_string()),
1552 "png" => Ok(mime::IMAGE_PNG.to_string()),
1553 "jpg" => Ok(mime::IMAGE_JPEG.to_string()),
1554 "jpeg" => Ok(mime::IMAGE_JPEG.to_string()),
1555 "gif" => Ok(mime::IMAGE_GIF.to_string()),
1556 "svg" => Ok(mime::IMAGE_SVG.to_string()),
1557 "m4a" => Ok("audio/m4a".to_string()),
1558 "mp3" => Ok("audio/mp3".to_string()),
1559 "wav" => Ok("audio/wav".to_string()),
1560 "flac" => Ok("audio/flac".to_string()),
1561 "mp4" => Ok("video/mp4".to_string()),
1562 "mpeg" => Ok("video/mpeg".to_string()),
1563 "mpga" => Ok("audio/mpeg".to_string()),
1564 "webm" => Ok("video/webm".to_string()),
1565 _ => {
1566 let e = ErrorInfo {
1567 message: format!("Unsupported file type: {}", suffix),
1568 code: None,
1569 message_type: "unsupported_file_type".to_string(),
1570 param: None,
1571 };
1572 Err(OpenAIApiError::new(400, e))
1573 },
1574 }
1575 }
1576
1577}
1578
#[cfg(test)]
mod tests {
    //! Live integration tests: every test sends real requests to the API using
    //! the `API_KEY` variable and routes them through an HTTP proxy.
    use super::*;
    use crate::configuration::Configuration;
    use dotenv::vars;

    /// Proxy used when no `PROXY` variable is set (the previously hard-coded value).
    const DEFAULT_PROXY: &str = "http://127.0.0.1:7890";

    /// Returns the value of `key` from `dotenv::vars()`, if present.
    fn env_var(key: &str) -> Option<String> {
        vars().find(|(name, _)| name == key).map(|(_, value)| value)
    }

    /// Builds the client shared by all tests.
    ///
    /// `API_KEY` defaults to an empty string and `PROXY` defaults to `DEFAULT_PROXY`.
    fn test_api() -> OpenAIApi {
        let api_key = env_var("API_KEY").unwrap_or_default();
        let proxy = env_var("PROXY").unwrap_or_else(|| DEFAULT_PROXY.to_string());
        let configuration = Configuration::new_personal(api_key).proxy(proxy);
        OpenAIApi::new(configuration)
    }

    #[tokio::test]
    async fn test_list_models() {
        let response = test_api().list_models().await.unwrap();
        assert_eq!(response.object, "list");
    }

    #[tokio::test]
    async fn test_retrieve_model() {
        let response = test_api().retrieve_model("davinci".to_string()).await.unwrap();
        assert_eq!(response.object, "model");
    }

    #[tokio::test]
    async fn test_create_completion() {
        let request = CreateCompletionRequest {
            model: "text-davinci-003".to_string(),
            prompt: Some(vec!["Once upon a time".to_string()]),
            max_tokens: Some(7),
            temperature: Some(0.7),
            ..Default::default()
        };
        let response = test_api().create_completion(request).await.unwrap();
        assert_eq!(response.object, "text_completion");
    }

    #[tokio::test]
    async fn test_create_chat_completion() {
        let request = CreateChatCompletionRequest {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![ChatFormat { role: "user".to_string(), content: "tell me a story".to_string() }],
            ..Default::default()
        };
        let response = test_api().create_chat_completion(request).await.unwrap();
        assert_eq!(response.object, "chat.completion");
    }

    #[tokio::test]
    async fn test_create_edit() {
        let request = CreateEditRequest {
            model: "text-davinci-edit-001".to_string(),
            input: Some("What day of the wek is it?".to_string()),
            instruction: "Fix the spelling mistakes".to_string(),
            ..Default::default()
        };
        let response = test_api().create_edit(request).await.unwrap();
        assert_eq!(response.object, "edit");
    }

    #[tokio::test]
    async fn test_create_image() {
        let request = CreateImageRequest {
            prompt: "A photo of a dog".to_string(),
            n: Some(1),
            size: Some("512x512".to_string()),
            response_format: Some(ImageFormat::URL),
            ..Default::default()
        };
        println!("request: {:#?}", serde_json::to_string(&request).unwrap());
        let response = test_api().create_image(request).await.unwrap();

        assert_eq!(response.data.len(), 1);
        match &response.data[0] {
            CreateImageResponseData::Url(url) => assert!(url.starts_with("https://")),
            _ => panic!("error response format"),
        }
    }

    #[tokio::test]
    async fn test_create_transcription() {
        let request = CreateTranscriptionRequest {
            file: "./misc/test_audio.m4a".to_string(),
            model: "whisper-1".to_string(),
            response_format: Some(CreateTranscriptionResponseFormat::JSON),
            ..Default::default()
        };
        println!("request: {:#?}", serde_json::to_string(&request).unwrap());
        let response = test_api().create_transcription(request).await.unwrap();
        match response {
            CreateTranscriptionResponse::Json(content) => assert_eq!(content.text, "你好你好"),
            _ => panic!("expected a JSON transcription response"),
        }
    }

    #[tokio::test]
    async fn test_create_translation() {
        let request = CreateTranslationRequest {
            file: "./misc/test_audio.m4a".to_string(),
            model: "whisper-1".to_string(),
            response_format: Some(CreateTranscriptionResponseFormat::JSON),
            ..Default::default()
        };
        println!("request: {:#?}", serde_json::to_string(&request).unwrap());
        let response = test_api().create_translation(request).await.unwrap();
        match response {
            CreateTranslationResponse::Json(content) => assert_eq!(content.text, "Ni hao, ni hao."),
            _ => panic!("expected a JSON translation response"),
        }
    }

    #[tokio::test]
    async fn test_list_files() {
        let response = test_api().list_files().await.unwrap();
        assert_eq!(response.object, "list");
    }

    #[tokio::test]
    async fn test_list_fine_tunes() {
        let response = test_api().list_fine_tunes().await.unwrap();
        assert_eq!(response.object, "list");
    }

    #[tokio::test]
    async fn test_create_moderation() {
        let response = test_api()
            .create_moderation(CreateModerationRequest {
                input: vec!["I want to kill them.".to_string()],
                ..Default::default()
            })
            .await
            .unwrap();
        assert!(response.results[0].categories.violence);
    }
}
1790
1791
1792
1793
1794