1use crate::v1::assistant::{
2 AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, ListAssistant,
3 ListAssistantFile,
4};
5use crate::v1::audio::{
6 AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse,
7 AudioTranslationRequest, AudioTranslationResponse,
8};
9use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse};
10use crate::v1::chat_completion::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
11use crate::v1::chat_completion::chat_completion_stream::{
12 ChatCompletionStream, ChatCompletionStreamRequest, ChatCompletionStreamResponse,
13};
14use crate::v1::common;
15use crate::v1::completion::{CompletionRequest, CompletionResponse};
16use crate::v1::edit::{EditRequest, EditResponse};
17use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse};
18use crate::v1::error::APIError;
19use crate::v1::file::{
20 FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveResponse,
21 FileUploadRequest, FileUploadResponse,
22};
23use crate::v1::fine_tuning::{
24 CancelFineTuningJobRequest, CreateFineTuningJobRequest, FineTuningJobEvent,
25 FineTuningJobObject, FineTuningPagination, ListFineTuningJobEventsRequest,
26 RetrieveFineTuningJobRequest,
27};
28use crate::v1::image::{
29 ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse,
30 ImageVariationRequest, ImageVariationResponse,
31};
32use crate::v1::message::{
33 CreateMessageRequest, ListMessage, ListMessageFile, MessageFileObject, MessageObject,
34 ModifyMessageRequest,
35};
36use crate::v1::model::{ModelResponse, ModelsResponse};
37use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
38use crate::v1::run::{
39 CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject,
40 RunStepObject,
41};
42use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject};
43
44use bytes::Bytes;
45use futures_util::Stream;
46use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
47use reqwest::multipart::{Form, Part};
48use reqwest::{Client, Method, Response};
49use serde::Serialize;
50use serde_json::{to_value, Value};
51use url::Url;
52
53use std::error::Error;
54use std::fs::{create_dir_all, File};
55use std::io::Read;
56use std::io::Write;
57use std::path::Path;
58
59const API_URL_V1: &str = "https://api.openai.com/v1";
60
61#[derive(Default)]
62pub struct OpenAIClientBuilder {
63 api_endpoint: Option<String>,
64 api_key: Option<String>,
65 organization: Option<String>,
66 proxy: Option<String>,
67 timeout: Option<u64>,
68 headers: Option<HeaderMap>,
69}
70
71#[derive(Debug)]
72pub struct OpenAIClient {
73 api_endpoint: String,
74 api_key: Option<String>,
75 organization: Option<String>,
76 proxy: Option<String>,
77 timeout: Option<u64>,
78 headers: Option<HeaderMap>,
79 pub response_headers: Option<HeaderMap>,
80}
81
82impl OpenAIClientBuilder {
83 pub fn new() -> Self {
84 Self::default()
85 }
86
87 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
88 self.api_key = Some(api_key.into());
89 self
90 }
91
92 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
93 self.api_endpoint = Some(endpoint.into());
94 self
95 }
96
97 pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
98 self.organization = Some(organization.into());
99 self
100 }
101
102 pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
103 self.proxy = Some(proxy.into());
104 self
105 }
106
107 pub fn with_timeout(mut self, timeout: u64) -> Self {
108 self.timeout = Some(timeout);
109 self
110 }
111
112 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
113 let headers = self.headers.get_or_insert_with(HeaderMap::new);
114 headers.insert(
115 HeaderName::from_bytes(key.into().as_bytes()).expect("Invalid header name"),
116 HeaderValue::from_str(&value.into()).expect("Invalid header value"),
117 );
118 self
119 }
120
121 pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
122 let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
123 std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
124 });
125
126 Ok(OpenAIClient {
127 api_endpoint,
128 api_key: self.api_key,
129 organization: self.organization,
130 proxy: self.proxy,
131 timeout: self.timeout,
132 headers: self.headers,
133 response_headers: None,
134 })
135 }
136}
137
138impl OpenAIClient {
139 pub fn builder() -> OpenAIClientBuilder {
140 OpenAIClientBuilder::new()
141 }
142
143 async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
144 let url = self
145 .build_url_with_preserved_query(path)
146 .unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
147
148 let client = Client::builder();
149
150 #[cfg(feature = "rustls")]
151 let client = client.use_rustls_tls();
152
153 let client = if let Some(timeout) = self.timeout {
154 client.timeout(std::time::Duration::from_secs(timeout))
155 } else {
156 client
157 };
158
159 let client = if let Some(proxy) = &self.proxy {
160 client.proxy(reqwest::Proxy::all(proxy).unwrap())
161 } else {
162 client
163 };
164
165 let client = client.build().unwrap();
166
167 let mut request = client.request(method, url);
168
169 if let Some(api_key) = &self.api_key {
170 request = request.header("Authorization", format!("Bearer {api_key}"));
171 }
172
173 if let Some(organization) = &self.organization {
174 request = request.header("openai-organization", organization);
175 }
176
177 if let Some(headers) = &self.headers {
178 for (key, value) in headers {
179 request = request.header(key, value);
180 }
181 }
182
183 if Self::is_beta(path) {
184 request = request.header("OpenAI-Beta", "assistants=v2");
185 }
186
187 request
188 }
189
190 async fn post<T: serde::de::DeserializeOwned>(
191 &mut self,
192 path: &str,
193 body: &impl serde::ser::Serialize,
194 ) -> Result<T, APIError> {
195 let request = self.build_request(Method::POST, path).await;
196 let request = request.json(body);
197 let response = request.send().await?;
198 self.handle_response(response).await
199 }
200
201 async fn get<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
202 let request = self.build_request(Method::GET, path).await;
203 let response = request.send().await?;
204 self.handle_response(response).await
205 }
206
207 async fn get_raw(&self, path: &str) -> Result<Bytes, APIError> {
208 let request = self.build_request(Method::GET, path).await;
209 let response = request.send().await?;
210 Ok(response.bytes().await?)
211 }
212
213 async fn delete<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
214 let request = self.build_request(Method::DELETE, path).await;
215 let response = request.send().await?;
216 self.handle_response(response).await
217 }
218
219 async fn post_form<T: serde::de::DeserializeOwned>(
220 &mut self,
221 path: &str,
222 form: Form,
223 ) -> Result<T, APIError> {
224 let request = self.build_request(Method::POST, path).await;
225 let request = request.multipart(form);
226 let response = request.send().await?;
227 self.handle_response(response).await
228 }
229
230 async fn post_form_raw(&self, path: &str, form: Form) -> Result<Bytes, APIError> {
231 let request = self.build_request(Method::POST, path).await;
232 let request = request.multipart(form);
233 let response = request.send().await?;
234 Ok(response.bytes().await?)
235 }
236
237 async fn handle_response<T: serde::de::DeserializeOwned>(
238 &mut self,
239 response: Response,
240 ) -> Result<T, APIError> {
241 let status = response.status();
242 let headers = response.headers().clone();
243 if status.is_success() {
244 let text = response.text().await.unwrap_or_else(|_| "".to_string());
245 match serde_json::from_str::<T>(&text) {
246 Ok(parsed) => {
247 self.response_headers = Some(headers);
248 Ok(parsed)
249 }
250 Err(e) => Err(APIError::CustomError {
251 message: format!("Failed to parse JSON: {e} / response {text}"),
252 }),
253 }
254 } else {
255 let error_message = response
256 .text()
257 .await
258 .unwrap_or_else(|_| "Unknown error".to_string());
259 Err(APIError::CustomError {
260 message: format!("{status}: {error_message}"),
261 })
262 }
263 }
264
265 pub async fn completion(
266 &mut self,
267 req: CompletionRequest,
268 ) -> Result<CompletionResponse, APIError> {
269 self.post("completions", &req).await
270 }
271
272 pub async fn edit(&mut self, req: EditRequest) -> Result<EditResponse, APIError> {
273 self.post("edits", &req).await
274 }
275
276 pub async fn image_generation(
277 &mut self,
278 req: ImageGenerationRequest,
279 ) -> Result<ImageGenerationResponse, APIError> {
280 self.post("images/generations", &req).await
281 }
282
283 pub async fn image_edit(
284 &mut self,
285 req: ImageEditRequest,
286 ) -> Result<ImageEditResponse, APIError> {
287 self.post("images/edits", &req).await
288 }
289
290 pub async fn image_variation(
291 &mut self,
292 req: ImageVariationRequest,
293 ) -> Result<ImageVariationResponse, APIError> {
294 self.post("images/variations", &req).await
295 }
296
297 pub async fn embedding(
298 &mut self,
299 req: EmbeddingRequest,
300 ) -> Result<EmbeddingResponse, APIError> {
301 self.post("embeddings", &req).await
302 }
303
304 pub async fn file_list(&mut self) -> Result<FileListResponse, APIError> {
305 self.get("files").await
306 }
307
308 pub async fn upload_file(
309 &mut self,
310 req: FileUploadRequest,
311 ) -> Result<FileUploadResponse, APIError> {
312 let form = Self::create_form(&req, "file")?;
313 self.post_form("files", form).await
314 }
315
316 pub async fn delete_file(
317 &mut self,
318 req: FileDeleteRequest,
319 ) -> Result<FileDeleteResponse, APIError> {
320 self.delete(&format!("files/{}", req.file_id)).await
321 }
322
323 pub async fn retrieve_file(
324 &mut self,
325 file_id: String,
326 ) -> Result<FileRetrieveResponse, APIError> {
327 self.get(&format!("files/{file_id}")).await
328 }
329
330 pub async fn retrieve_file_content(&self, file_id: String) -> Result<Bytes, APIError> {
331 self.get_raw(&format!("files/{file_id}/content")).await
332 }
333
334 pub async fn chat_completion(
335 &mut self,
336 req: ChatCompletionRequest,
337 ) -> Result<ChatCompletionResponse, APIError> {
338 self.post("chat/completions", &req).await
339 }
340
341 pub async fn chat_completion_stream(
342 &mut self,
343 req: ChatCompletionStreamRequest,
344 ) -> Result<impl Stream<Item = ChatCompletionStreamResponse>, APIError> {
345 let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
346 message: format!("Failed to serialize request: {}", err),
347 })?;
348
349 if let Some(obj) = payload.as_object_mut() {
350 obj.insert("stream".into(), Value::Bool(true));
351 }
352
353 let request = self.build_request(Method::POST, "chat/completions").await;
354 let request = request.json(&payload);
355 let response = request.send().await?;
356
357 if response.status().is_success() {
358 Ok(ChatCompletionStream {
359 response: Box::pin(response.bytes_stream()),
360 buffer: String::new(),
361 first_chunk: true,
362 })
363 } else {
364 let error_text = response
365 .text()
366 .await
367 .unwrap_or_else(|_| String::from("Unknown error"));
368
369 Err(APIError::CustomError {
370 message: error_text,
371 })
372 }
373 }
374
375 pub async fn audio_transcription(
376 &mut self,
377 req: AudioTranscriptionRequest,
378 ) -> Result<AudioTranscriptionResponse, APIError> {
379 if let Some(response_format) = &req.response_format {
381 if response_format != "json" && response_format != "verbose_json" {
382 return Err(APIError::CustomError {
383 message: "response_format must be either 'json' or 'verbose_json' please use audio_transcription_raw".to_string(),
384 });
385 }
386 }
387 let form: Form;
388 if req.clone().file.is_some() {
389 form = Self::create_form(&req, "file")?;
390 } else if let Some(bytes) = req.clone().bytes {
391 form = Self::create_form_from_bytes(&req, bytes)?;
392 } else {
393 return Err(APIError::CustomError {
394 message: "Either file or bytes must be provided".to_string(),
395 });
396 }
397 self.post_form("audio/transcriptions", form).await
398 }
399
400 pub async fn audio_transcription_raw(
401 &mut self,
402 req: AudioTranscriptionRequest,
403 ) -> Result<Bytes, APIError> {
404 if let Some(response_format) = &req.response_format {
406 if response_format != "text" && response_format != "srt" && response_format != "vtt" {
407 return Err(APIError::CustomError {
408 message: "response_format must be either 'text', 'srt' or 'vtt', please use audio_transcription".to_string(),
409 });
410 }
411 }
412 let form: Form;
413 if req.clone().file.is_some() {
414 form = Self::create_form(&req, "file")?;
415 } else if let Some(bytes) = req.clone().bytes {
416 form = Self::create_form_from_bytes(&req, bytes)?;
417 } else {
418 return Err(APIError::CustomError {
419 message: "Either file or bytes must be provided".to_string(),
420 });
421 }
422 self.post_form_raw("audio/transcriptions", form).await
423 }
424
425 pub async fn audio_translation(
426 &mut self,
427 req: AudioTranslationRequest,
428 ) -> Result<AudioTranslationResponse, APIError> {
429 let form = Self::create_form(&req, "file")?;
430 self.post_form("audio/translations", form).await
431 }
432
433 pub async fn audio_speech(
434 &mut self,
435 req: AudioSpeechRequest,
436 ) -> Result<AudioSpeechResponse, APIError> {
437 let request = self.build_request(Method::POST, "audio/speech").await;
438 let request = request.json(&req);
439 let response = request.send().await?;
440 let headers = response.headers().clone();
441 let bytes = response.bytes().await?;
442 let path = Path::new(req.output.as_str());
443 if let Some(parent) = path.parent() {
444 match create_dir_all(parent) {
445 Ok(_) => {}
446 Err(e) => {
447 return Err(APIError::CustomError {
448 message: e.to_string(),
449 })
450 }
451 }
452 }
453 match File::create(path) {
454 Ok(mut file) => match file.write_all(&bytes) {
455 Ok(_) => {}
456 Err(e) => {
457 return Err(APIError::CustomError {
458 message: e.to_string(),
459 })
460 }
461 },
462 Err(e) => {
463 return Err(APIError::CustomError {
464 message: e.to_string(),
465 })
466 }
467 }
468
469 Ok(AudioSpeechResponse {
470 result: true,
471 headers: Some(headers),
472 })
473 }
474
475 pub async fn create_fine_tuning_job(
476 &mut self,
477 req: CreateFineTuningJobRequest,
478 ) -> Result<FineTuningJobObject, APIError> {
479 self.post("fine_tuning/jobs", &req).await
480 }
481
482 pub async fn list_fine_tuning_jobs(
483 &mut self,
484 ) -> Result<FineTuningPagination<FineTuningJobObject>, APIError> {
485 self.get("fine_tuning/jobs").await
486 }
487
488 pub async fn list_fine_tuning_job_events(
489 &mut self,
490 req: ListFineTuningJobEventsRequest,
491 ) -> Result<FineTuningPagination<FineTuningJobEvent>, APIError> {
492 self.get(&format!(
493 "fine_tuning/jobs/{}/events",
494 req.fine_tuning_job_id
495 ))
496 .await
497 }
498
499 pub async fn retrieve_fine_tuning_job(
500 &mut self,
501 req: RetrieveFineTuningJobRequest,
502 ) -> Result<FineTuningJobObject, APIError> {
503 self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id))
504 .await
505 }
506
507 pub async fn cancel_fine_tuning_job(
508 &mut self,
509 req: CancelFineTuningJobRequest,
510 ) -> Result<FineTuningJobObject, APIError> {
511 self.post(
512 &format!("fine_tuning/jobs/{}/cancel", req.fine_tuning_job_id),
513 &req,
514 )
515 .await
516 }
517
518 pub async fn create_moderation(
519 &mut self,
520 req: CreateModerationRequest,
521 ) -> Result<CreateModerationResponse, APIError> {
522 self.post("moderations", &req).await
523 }
524
525 pub async fn create_assistant(
526 &mut self,
527 req: AssistantRequest,
528 ) -> Result<AssistantObject, APIError> {
529 self.post("assistants", &req).await
530 }
531
532 pub async fn retrieve_assistant(
533 &mut self,
534 assistant_id: String,
535 ) -> Result<AssistantObject, APIError> {
536 self.get(&format!("assistants/{assistant_id}")).await
537 }
538
539 pub async fn modify_assistant(
540 &mut self,
541 assistant_id: String,
542 req: AssistantRequest,
543 ) -> Result<AssistantObject, APIError> {
544 self.post(&format!("assistants/{assistant_id}"), &req).await
545 }
546
547 pub async fn delete_assistant(
548 &mut self,
549 assistant_id: String,
550 ) -> Result<common::DeletionStatus, APIError> {
551 self.delete(&format!("assistants/{assistant_id}")).await
552 }
553
554 pub async fn list_assistant(
555 &mut self,
556 limit: Option<i64>,
557 order: Option<String>,
558 after: Option<String>,
559 before: Option<String>,
560 ) -> Result<ListAssistant, APIError> {
561 let url = Self::query_params(limit, order, after, before, "assistants".to_string());
562 self.get(&url).await
563 }
564
565 pub async fn create_assistant_file(
566 &mut self,
567 assistant_id: String,
568 req: AssistantFileRequest,
569 ) -> Result<AssistantFileObject, APIError> {
570 self.post(&format!("assistants/{assistant_id}/files"), &req)
571 .await
572 }
573
574 pub async fn retrieve_assistant_file(
575 &mut self,
576 assistant_id: String,
577 file_id: String,
578 ) -> Result<AssistantFileObject, APIError> {
579 self.get(&format!("assistants/{assistant_id}/files/{file_id}"))
580 .await
581 }
582
583 pub async fn delete_assistant_file(
584 &mut self,
585 assistant_id: String,
586 file_id: String,
587 ) -> Result<common::DeletionStatus, APIError> {
588 self.delete(&format!("assistants/{assistant_id}/files/{file_id}"))
589 .await
590 }
591
592 pub async fn list_assistant_file(
593 &mut self,
594 assistant_id: String,
595 limit: Option<i64>,
596 order: Option<String>,
597 after: Option<String>,
598 before: Option<String>,
599 ) -> Result<ListAssistantFile, APIError> {
600 let url = Self::query_params(
601 limit,
602 order,
603 after,
604 before,
605 format!("assistants/{assistant_id}/files"),
606 );
607 self.get(&url).await
608 }
609
610 pub async fn create_thread(
611 &mut self,
612 req: CreateThreadRequest,
613 ) -> Result<ThreadObject, APIError> {
614 self.post("threads", &req).await
615 }
616
617 pub async fn retrieve_thread(&mut self, thread_id: String) -> Result<ThreadObject, APIError> {
618 self.get(&format!("threads/{thread_id}")).await
619 }
620
621 pub async fn modify_thread(
622 &mut self,
623 thread_id: String,
624 req: ModifyThreadRequest,
625 ) -> Result<ThreadObject, APIError> {
626 self.post(&format!("threads/{thread_id}"), &req).await
627 }
628
629 pub async fn delete_thread(
630 &mut self,
631 thread_id: String,
632 ) -> Result<common::DeletionStatus, APIError> {
633 self.delete(&format!("threads/{thread_id}")).await
634 }
635
636 pub async fn create_message(
637 &mut self,
638 thread_id: String,
639 req: CreateMessageRequest,
640 ) -> Result<MessageObject, APIError> {
641 self.post(&format!("threads/{thread_id}/messages"), &req)
642 .await
643 }
644
645 pub async fn retrieve_message(
646 &mut self,
647 thread_id: String,
648 message_id: String,
649 ) -> Result<MessageObject, APIError> {
650 self.get(&format!("threads/{thread_id}/messages/{message_id}"))
651 .await
652 }
653
654 pub async fn modify_message(
655 &mut self,
656 thread_id: String,
657 message_id: String,
658 req: ModifyMessageRequest,
659 ) -> Result<MessageObject, APIError> {
660 self.post(&format!("threads/{thread_id}/messages/{message_id}"), &req)
661 .await
662 }
663
664 pub async fn list_messages(&mut self, thread_id: String) -> Result<ListMessage, APIError> {
665 self.get(&format!("threads/{thread_id}/messages")).await
666 }
667
668 pub async fn retrieve_message_file(
669 &mut self,
670 thread_id: String,
671 message_id: String,
672 file_id: String,
673 ) -> Result<MessageFileObject, APIError> {
674 self.get(&format!(
675 "threads/{thread_id}/messages/{message_id}/files/{file_id}"
676 ))
677 .await
678 }
679
680 pub async fn list_message_file(
681 &mut self,
682 thread_id: String,
683 message_id: String,
684 limit: Option<i64>,
685 order: Option<String>,
686 after: Option<String>,
687 before: Option<String>,
688 ) -> Result<ListMessageFile, APIError> {
689 let url = Self::query_params(
690 limit,
691 order,
692 after,
693 before,
694 format!("threads/{thread_id}/messages/{message_id}/files"),
695 );
696 self.get(&url).await
697 }
698
699 pub async fn create_run(
700 &mut self,
701 thread_id: String,
702 req: CreateRunRequest,
703 ) -> Result<RunObject, APIError> {
704 self.post(&format!("threads/{thread_id}/runs"), &req).await
705 }
706
707 pub async fn retrieve_run(
708 &mut self,
709 thread_id: String,
710 run_id: String,
711 ) -> Result<RunObject, APIError> {
712 self.get(&format!("threads/{thread_id}/runs/{run_id}"))
713 .await
714 }
715
716 pub async fn modify_run(
717 &mut self,
718 thread_id: String,
719 run_id: String,
720 req: ModifyRunRequest,
721 ) -> Result<RunObject, APIError> {
722 self.post(&format!("threads/{thread_id}/runs/{run_id}"), &req)
723 .await
724 }
725
726 pub async fn list_run(
727 &mut self,
728 thread_id: String,
729 limit: Option<i64>,
730 order: Option<String>,
731 after: Option<String>,
732 before: Option<String>,
733 ) -> Result<ListRun, APIError> {
734 let url = Self::query_params(
735 limit,
736 order,
737 after,
738 before,
739 format!("threads/{thread_id}/runs"),
740 );
741 self.get(&url).await
742 }
743
744 pub async fn cancel_run(
745 &mut self,
746 thread_id: String,
747 run_id: String,
748 ) -> Result<RunObject, APIError> {
749 self.post(
750 &format!("threads/{thread_id}/runs/{run_id}/cancel"),
751 &ModifyRunRequest::default(),
752 )
753 .await
754 }
755
756 pub async fn create_thread_and_run(
757 &mut self,
758 req: CreateThreadAndRunRequest,
759 ) -> Result<RunObject, APIError> {
760 self.post("threads/runs", &req).await
761 }
762
763 pub async fn retrieve_run_step(
764 &mut self,
765 thread_id: String,
766 run_id: String,
767 step_id: String,
768 ) -> Result<RunStepObject, APIError> {
769 self.get(&format!(
770 "threads/{thread_id}/runs/{run_id}/steps/{step_id}"
771 ))
772 .await
773 }
774
775 pub async fn list_run_step(
776 &mut self,
777 thread_id: String,
778 run_id: String,
779 limit: Option<i64>,
780 order: Option<String>,
781 after: Option<String>,
782 before: Option<String>,
783 ) -> Result<ListRunStep, APIError> {
784 let url = Self::query_params(
785 limit,
786 order,
787 after,
788 before,
789 format!("threads/{thread_id}/runs/{run_id}/steps"),
790 );
791 self.get(&url).await
792 }
793
794 pub async fn create_batch(
795 &mut self,
796 req: CreateBatchRequest,
797 ) -> Result<BatchResponse, APIError> {
798 self.post("batches", &req).await
799 }
800
801 pub async fn retrieve_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
802 self.get(&format!("batches/{batch_id}")).await
803 }
804
805 pub async fn cancel_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
806 self.post(
807 &format!("batches/{batch_id}/cancel"),
808 &common::EmptyRequestBody {},
809 )
810 .await
811 }
812
813 pub async fn list_batch(
814 &mut self,
815 after: Option<String>,
816 limit: Option<i64>,
817 ) -> Result<ListBatchResponse, APIError> {
818 let url = Self::query_params(limit, None, after, None, "batches".to_string());
819 self.get(&url).await
820 }
821
822 pub async fn list_models(&mut self) -> Result<ModelsResponse, APIError> {
823 self.get("models").await
824 }
825
826 pub async fn retrieve_model(&mut self, model_id: String) -> Result<ModelResponse, APIError> {
827 self.get(&format!("models/{model_id}")).await
828 }
829
830 pub async fn delete_model(
831 &mut self,
832 model_id: String,
833 ) -> Result<common::DeletionStatus, APIError> {
834 self.delete(&format!("models/{model_id}")).await
835 }
836
837 fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
838 let (base, query_opt) = match self.api_endpoint.split_once('?') {
839 Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
840 None => (self.api_endpoint.trim_end_matches('/'), None),
841 };
842
843 let full_path = format!("{}/{}", base, path.trim_start_matches('/'));
844 let mut url = Url::parse(&full_path)?;
845
846 if let Some(query) = query_opt {
847 for (k, v) in url::form_urlencoded::parse(query.as_bytes()) {
848 url.query_pairs_mut().append_pair(&k, &v);
849 }
850 }
851 Ok(url.to_string())
852 }
853
854 fn query_params(
855 limit: Option<i64>,
856 order: Option<String>,
857 after: Option<String>,
858 before: Option<String>,
859 mut url: String,
860 ) -> String {
861 let mut params = vec![];
862 if let Some(limit) = limit {
863 params.push(format!("limit={limit}"));
864 }
865 if let Some(order) = order {
866 params.push(format!("order={order}"));
867 }
868 if let Some(after) = after {
869 params.push(format!("after={after}"));
870 }
871 if let Some(before) = before {
872 params.push(format!("before={before}"));
873 }
874 if !params.is_empty() {
875 url = format!("{}?{}", url, params.join("&"));
876 }
877 url
878 }
879
880 fn is_beta(path: &str) -> bool {
881 path.starts_with("assistants") || path.starts_with("threads")
882 }
883
884 fn create_form<T>(req: &T, file_field: &str) -> Result<Form, APIError>
885 where
886 T: Serialize,
887 {
888 let json = match serde_json::to_value(req) {
889 Ok(json) => json,
890 Err(e) => {
891 return Err(APIError::CustomError {
892 message: e.to_string(),
893 })
894 }
895 };
896 let file_path = if let Value::Object(map) = &json {
897 map.get(file_field)
898 .and_then(|v| v.as_str())
899 .ok_or(APIError::CustomError {
900 message: format!("Field '{file_field}' not found or not a string"),
901 })?
902 } else {
903 return Err(APIError::CustomError {
904 message: "Request is not a JSON object".to_string(),
905 });
906 };
907
908 let mut file = match File::open(file_path) {
909 Ok(file) => file,
910 Err(e) => {
911 return Err(APIError::CustomError {
912 message: e.to_string(),
913 })
914 }
915 };
916 let mut buffer = Vec::new();
917 match file.read_to_end(&mut buffer) {
918 Ok(_) => {}
919 Err(e) => {
920 return Err(APIError::CustomError {
921 message: e.to_string(),
922 })
923 }
924 }
925
926 let mut form =
927 Form::new().part("file", Part::bytes(buffer).file_name(file_path.to_string()));
928
929 if let Value::Object(map) = json {
930 for (key, value) in map.into_iter() {
931 if key != file_field {
932 match value {
933 Value::String(s) => {
934 form = form.text(key, s);
935 }
936 Value::Number(n) => {
937 form = form.text(key, n.to_string());
938 }
939 _ => {}
940 }
941 }
942 }
943 }
944
945 Ok(form)
946 }
947
948 fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
949 where
950 T: Serialize,
951 {
952 let json = match serde_json::to_value(req) {
953 Ok(json) => json,
954 Err(e) => {
955 return Err(APIError::CustomError {
956 message: e.to_string(),
957 })
958 }
959 };
960
961 let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
962
963 if let Value::Object(map) = json {
964 for (key, value) in map.into_iter() {
965 match value {
966 Value::String(s) => {
967 form = form.text(key, s);
968 }
969 Value::Number(n) => {
970 form = form.text(key, n.to_string());
971 }
972 _ => {}
973 }
974 }
975 }
976
977 Ok(form)
978 }
979}