1use crate::v1::assistant::{
2 AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, ListAssistant,
3 ListAssistantFile,
4};
5use crate::v1::audio::{
6 AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse,
7 AudioTranslationRequest, AudioTranslationResponse,
8};
9use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse};
10use crate::v1::chat_completion::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
11use crate::v1::chat_completion::chat_completion_stream::{
12 ChatCompletionStream, ChatCompletionStreamRequest, ChatCompletionStreamResponse,
13};
14use crate::v1::common;
15use crate::v1::completion::{CompletionRequest, CompletionResponse};
16use crate::v1::edit::{EditRequest, EditResponse};
17use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse};
18use crate::v1::error::APIError;
19use crate::v1::file::{
20 FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveResponse,
21 FileUploadRequest, FileUploadResponse,
22};
23use crate::v1::fine_tuning::{
24 CancelFineTuningJobRequest, CreateFineTuningJobRequest, FineTuningJobEvent,
25 FineTuningJobObject, FineTuningPagination, ListFineTuningJobEventsRequest,
26 RetrieveFineTuningJobRequest,
27};
28use crate::v1::image::{
29 ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse,
30 ImageVariationRequest, ImageVariationResponse,
31};
32use crate::v1::message::{
33 CreateMessageRequest, ListMessage, ListMessageFile, MessageFileObject, MessageObject,
34 ModifyMessageRequest,
35};
36use crate::v1::model::{ModelResponse, ModelsResponse};
37use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
38use crate::v1::responses::{
39 CountTokensRequest, CountTokensResponse, CreateResponseRequest, ListResponses, ResponseObject,
40};
41use crate::v1::run::{
42 CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject,
43 RunStepObject,
44};
45use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject};
46
47use bytes::Bytes;
48use futures_util::Stream;
49use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
50use reqwest::multipart::{Form, Part};
51use reqwest::{Client, Method, Response};
52use serde::Serialize;
53use serde_json::{to_value, Value};
54use url::Url;
55
56use std::error::Error;
57use std::fs::{create_dir_all, File};
58use std::io::Read;
59use std::io::Write;
60use std::path::Path;
61
/// Base URL of the public OpenAI v1 REST API.
///
/// `OpenAIClientBuilder::build` falls back to this value when neither an explicit
/// endpoint nor the `OPENAI_API_BASE` environment variable is provided.
const API_URL_V1: &str = "https://api.openai.com/v1";
63
64#[derive(Default)]
65pub struct OpenAIClientBuilder {
66 api_endpoint: Option<String>,
67 api_key: Option<String>,
68 organization: Option<String>,
69 proxy: Option<String>,
70 timeout: Option<u64>,
71 headers: Option<HeaderMap>,
72}
73
74#[derive(Debug)]
75pub struct OpenAIClient {
76 api_endpoint: String,
77 api_key: Option<String>,
78 organization: Option<String>,
79 proxy: Option<String>,
80 timeout: Option<u64>,
81 headers: Option<HeaderMap>,
82 pub response_headers: Option<HeaderMap>,
83}
84
85impl OpenAIClientBuilder {
86 pub fn new() -> Self {
87 Self::default()
88 }
89
90 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
91 self.api_key = Some(api_key.into());
92 self
93 }
94
95 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
96 self.api_endpoint = Some(endpoint.into());
97 self
98 }
99
100 pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
101 self.organization = Some(organization.into());
102 self
103 }
104
105 pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
106 self.proxy = Some(proxy.into());
107 self
108 }
109
110 pub fn with_timeout(mut self, timeout: u64) -> Self {
111 self.timeout = Some(timeout);
112 self
113 }
114
115 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
116 let headers = self.headers.get_or_insert_with(HeaderMap::new);
117 headers.insert(
118 HeaderName::from_bytes(key.into().as_bytes()).expect("Invalid header name"),
119 HeaderValue::from_str(&value.into()).expect("Invalid header value"),
120 );
121 self
122 }
123
124 pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
125 let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
126 std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
127 });
128
129 Ok(OpenAIClient {
130 api_endpoint,
131 api_key: self.api_key,
132 organization: self.organization,
133 proxy: self.proxy,
134 timeout: self.timeout,
135 headers: self.headers,
136 response_headers: None,
137 })
138 }
139}
140
141impl OpenAIClient {
142 pub fn builder() -> OpenAIClientBuilder {
143 OpenAIClientBuilder::new()
144 }
145
146 async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
147 let url = self
148 .build_url_with_preserved_query(path)
149 .unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
150
151 let client = Client::builder();
152
153 #[cfg(feature = "rustls")]
154 let client = client.use_rustls_tls();
155
156 let client = if let Some(timeout) = self.timeout {
157 client.timeout(std::time::Duration::from_secs(timeout))
158 } else {
159 client
160 };
161
162 let client = if let Some(proxy) = &self.proxy {
163 client.proxy(reqwest::Proxy::all(proxy).unwrap())
164 } else {
165 client
166 };
167
168 let client = client.build().unwrap();
169
170 let mut request = client.request(method, url);
171
172 if let Some(api_key) = &self.api_key {
173 request = request.header("Authorization", format!("Bearer {api_key}"));
174 }
175
176 if let Some(organization) = &self.organization {
177 request = request.header("openai-organization", organization);
178 }
179
180 if let Some(headers) = &self.headers {
181 for (key, value) in headers {
182 request = request.header(key, value);
183 }
184 }
185
186 if Self::is_beta(path) {
187 request = request.header("OpenAI-Beta", "assistants=v2");
188 }
189
190 request
191 }
192
193 async fn post<T: serde::de::DeserializeOwned>(
194 &mut self,
195 path: &str,
196 body: &impl serde::ser::Serialize,
197 ) -> Result<T, APIError> {
198 let request = self.build_request(Method::POST, path).await;
199 let request = request.json(body);
200 let response = request.send().await?;
201 self.handle_response(response).await
202 }
203
204 async fn get<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
205 let request = self.build_request(Method::GET, path).await;
206 let response = request.send().await?;
207 self.handle_response(response).await
208 }
209
210 async fn get_raw(&self, path: &str) -> Result<Bytes, APIError> {
211 let request = self.build_request(Method::GET, path).await;
212 let response = request.send().await?;
213 Ok(response.bytes().await?)
214 }
215
216 async fn delete<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
217 let request = self.build_request(Method::DELETE, path).await;
218 let response = request.send().await?;
219 self.handle_response(response).await
220 }
221
222 async fn post_form<T: serde::de::DeserializeOwned>(
223 &mut self,
224 path: &str,
225 form: Form,
226 ) -> Result<T, APIError> {
227 let request = self.build_request(Method::POST, path).await;
228 let request = request.multipart(form);
229 let response = request.send().await?;
230 self.handle_response(response).await
231 }
232
233 async fn post_form_raw(&self, path: &str, form: Form) -> Result<Bytes, APIError> {
234 let request = self.build_request(Method::POST, path).await;
235 let request = request.multipart(form);
236 let response = request.send().await?;
237 Ok(response.bytes().await?)
238 }
239
240 async fn handle_response<T: serde::de::DeserializeOwned>(
241 &mut self,
242 response: Response,
243 ) -> Result<T, APIError> {
244 let status = response.status();
245 let headers = response.headers().clone();
246 if status.is_success() {
247 let text = response.text().await.unwrap_or_else(|_| "".to_string());
248 match serde_json::from_str::<T>(&text) {
249 Ok(parsed) => {
250 self.response_headers = Some(headers);
251 Ok(parsed)
252 }
253 Err(e) => Err(APIError::CustomError {
254 message: format!("Failed to parse JSON: {e} / response {text}"),
255 }),
256 }
257 } else {
258 let error_message = response
259 .text()
260 .await
261 .unwrap_or_else(|_| "Unknown error".to_string());
262 Err(APIError::CustomError {
263 message: format!("{status}: {error_message}"),
264 })
265 }
266 }
267
268 pub async fn completion(
269 &mut self,
270 req: CompletionRequest,
271 ) -> Result<CompletionResponse, APIError> {
272 self.post("completions", &req).await
273 }
274
275 pub async fn edit(&mut self, req: EditRequest) -> Result<EditResponse, APIError> {
276 self.post("edits", &req).await
277 }
278
279 pub async fn image_generation(
280 &mut self,
281 req: ImageGenerationRequest,
282 ) -> Result<ImageGenerationResponse, APIError> {
283 self.post("images/generations", &req).await
284 }
285
286 pub async fn image_edit(
287 &mut self,
288 req: ImageEditRequest,
289 ) -> Result<ImageEditResponse, APIError> {
290 self.post("images/edits", &req).await
291 }
292
293 pub async fn image_variation(
294 &mut self,
295 req: ImageVariationRequest,
296 ) -> Result<ImageVariationResponse, APIError> {
297 self.post("images/variations", &req).await
298 }
299
300 pub async fn embedding(
301 &mut self,
302 req: EmbeddingRequest,
303 ) -> Result<EmbeddingResponse, APIError> {
304 self.post("embeddings", &req).await
305 }
306
307 pub async fn file_list(&mut self) -> Result<FileListResponse, APIError> {
308 self.get("files").await
309 }
310
311 pub async fn upload_file(
312 &mut self,
313 req: FileUploadRequest,
314 ) -> Result<FileUploadResponse, APIError> {
315 let form = Self::create_form(&req, "file")?;
316 self.post_form("files", form).await
317 }
318
319 pub async fn delete_file(
320 &mut self,
321 req: FileDeleteRequest,
322 ) -> Result<FileDeleteResponse, APIError> {
323 self.delete(&format!("files/{}", req.file_id)).await
324 }
325
326 pub async fn retrieve_file(
327 &mut self,
328 file_id: String,
329 ) -> Result<FileRetrieveResponse, APIError> {
330 self.get(&format!("files/{file_id}")).await
331 }
332
333 pub async fn retrieve_file_content(&self, file_id: String) -> Result<Bytes, APIError> {
334 self.get_raw(&format!("files/{file_id}/content")).await
335 }
336
337 pub async fn chat_completion(
338 &mut self,
339 req: ChatCompletionRequest,
340 ) -> Result<ChatCompletionResponse, APIError> {
341 self.post("chat/completions", &req).await
342 }
343
344 pub async fn chat_completion_stream(
345 &mut self,
346 req: ChatCompletionStreamRequest,
347 ) -> Result<impl Stream<Item = ChatCompletionStreamResponse>, APIError> {
348 let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
349 message: format!("Failed to serialize request: {}", err),
350 })?;
351
352 if let Some(obj) = payload.as_object_mut() {
353 obj.insert("stream".into(), Value::Bool(true));
354 }
355
356 let request = self.build_request(Method::POST, "chat/completions").await;
357 let request = request.json(&payload);
358 let response = request.send().await?;
359
360 if response.status().is_success() {
361 Ok(ChatCompletionStream {
362 response: Box::pin(response.bytes_stream()),
363 buffer: String::new(),
364 first_chunk: true,
365 })
366 } else {
367 let error_text = response
368 .text()
369 .await
370 .unwrap_or_else(|_| String::from("Unknown error"));
371
372 Err(APIError::CustomError {
373 message: error_text,
374 })
375 }
376 }
377
378 pub async fn audio_transcription(
379 &mut self,
380 req: AudioTranscriptionRequest,
381 ) -> Result<AudioTranscriptionResponse, APIError> {
382 if let Some(response_format) = &req.response_format {
384 if response_format != "json" && response_format != "verbose_json" {
385 return Err(APIError::CustomError {
386 message: "response_format must be either 'json' or 'verbose_json' please use audio_transcription_raw".to_string(),
387 });
388 }
389 }
390 let form: Form;
391 if req.clone().file.is_some() {
392 form = Self::create_form(&req, "file")?;
393 } else if let Some(bytes) = req.clone().bytes {
394 form = Self::create_form_from_bytes(&req, bytes)?;
395 } else {
396 return Err(APIError::CustomError {
397 message: "Either file or bytes must be provided".to_string(),
398 });
399 }
400 self.post_form("audio/transcriptions", form).await
401 }
402
403 pub async fn audio_transcription_raw(
404 &mut self,
405 req: AudioTranscriptionRequest,
406 ) -> Result<Bytes, APIError> {
407 if let Some(response_format) = &req.response_format {
409 if response_format != "text" && response_format != "srt" && response_format != "vtt" {
410 return Err(APIError::CustomError {
411 message: "response_format must be either 'text', 'srt' or 'vtt', please use audio_transcription".to_string(),
412 });
413 }
414 }
415 let form: Form;
416 if req.clone().file.is_some() {
417 form = Self::create_form(&req, "file")?;
418 } else if let Some(bytes) = req.clone().bytes {
419 form = Self::create_form_from_bytes(&req, bytes)?;
420 } else {
421 return Err(APIError::CustomError {
422 message: "Either file or bytes must be provided".to_string(),
423 });
424 }
425 self.post_form_raw("audio/transcriptions", form).await
426 }
427
428 pub async fn audio_translation(
429 &mut self,
430 req: AudioTranslationRequest,
431 ) -> Result<AudioTranslationResponse, APIError> {
432 let form = Self::create_form(&req, "file")?;
433 self.post_form("audio/translations", form).await
434 }
435
436 pub async fn audio_speech(
437 &mut self,
438 req: AudioSpeechRequest,
439 ) -> Result<AudioSpeechResponse, APIError> {
440 let request = self.build_request(Method::POST, "audio/speech").await;
441 let request = request.json(&req);
442 let response = request.send().await?;
443 let headers = response.headers().clone();
444 let bytes = response.bytes().await?;
445 let path = Path::new(req.output.as_str());
446 if let Some(parent) = path.parent() {
447 match create_dir_all(parent) {
448 Ok(_) => {}
449 Err(e) => {
450 return Err(APIError::CustomError {
451 message: e.to_string(),
452 })
453 }
454 }
455 }
456 match File::create(path) {
457 Ok(mut file) => match file.write_all(&bytes) {
458 Ok(_) => {}
459 Err(e) => {
460 return Err(APIError::CustomError {
461 message: e.to_string(),
462 })
463 }
464 },
465 Err(e) => {
466 return Err(APIError::CustomError {
467 message: e.to_string(),
468 })
469 }
470 }
471
472 Ok(AudioSpeechResponse {
473 result: true,
474 headers: Some(headers),
475 })
476 }
477
478 pub async fn create_fine_tuning_job(
479 &mut self,
480 req: CreateFineTuningJobRequest,
481 ) -> Result<FineTuningJobObject, APIError> {
482 self.post("fine_tuning/jobs", &req).await
483 }
484
485 pub async fn list_fine_tuning_jobs(
486 &mut self,
487 ) -> Result<FineTuningPagination<FineTuningJobObject>, APIError> {
488 self.get("fine_tuning/jobs").await
489 }
490
491 pub async fn list_fine_tuning_job_events(
492 &mut self,
493 req: ListFineTuningJobEventsRequest,
494 ) -> Result<FineTuningPagination<FineTuningJobEvent>, APIError> {
495 self.get(&format!(
496 "fine_tuning/jobs/{}/events",
497 req.fine_tuning_job_id
498 ))
499 .await
500 }
501
502 pub async fn retrieve_fine_tuning_job(
503 &mut self,
504 req: RetrieveFineTuningJobRequest,
505 ) -> Result<FineTuningJobObject, APIError> {
506 self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id))
507 .await
508 }
509
510 pub async fn cancel_fine_tuning_job(
511 &mut self,
512 req: CancelFineTuningJobRequest,
513 ) -> Result<FineTuningJobObject, APIError> {
514 self.post(
515 &format!("fine_tuning/jobs/{}/cancel", req.fine_tuning_job_id),
516 &req,
517 )
518 .await
519 }
520
521 pub async fn create_moderation(
522 &mut self,
523 req: CreateModerationRequest,
524 ) -> Result<CreateModerationResponse, APIError> {
525 self.post("moderations", &req).await
526 }
527
528 pub async fn create_assistant(
529 &mut self,
530 req: AssistantRequest,
531 ) -> Result<AssistantObject, APIError> {
532 self.post("assistants", &req).await
533 }
534
535 pub async fn retrieve_assistant(
536 &mut self,
537 assistant_id: String,
538 ) -> Result<AssistantObject, APIError> {
539 self.get(&format!("assistants/{assistant_id}")).await
540 }
541
542 pub async fn modify_assistant(
543 &mut self,
544 assistant_id: String,
545 req: AssistantRequest,
546 ) -> Result<AssistantObject, APIError> {
547 self.post(&format!("assistants/{assistant_id}"), &req).await
548 }
549
550 pub async fn delete_assistant(
551 &mut self,
552 assistant_id: String,
553 ) -> Result<common::DeletionStatus, APIError> {
554 self.delete(&format!("assistants/{assistant_id}")).await
555 }
556
557 pub async fn list_assistant(
558 &mut self,
559 limit: Option<i64>,
560 order: Option<String>,
561 after: Option<String>,
562 before: Option<String>,
563 ) -> Result<ListAssistant, APIError> {
564 let url = Self::query_params(limit, order, after, before, "assistants".to_string());
565 self.get(&url).await
566 }
567
568 pub async fn create_assistant_file(
569 &mut self,
570 assistant_id: String,
571 req: AssistantFileRequest,
572 ) -> Result<AssistantFileObject, APIError> {
573 self.post(&format!("assistants/{assistant_id}/files"), &req)
574 .await
575 }
576
577 pub async fn retrieve_assistant_file(
578 &mut self,
579 assistant_id: String,
580 file_id: String,
581 ) -> Result<AssistantFileObject, APIError> {
582 self.get(&format!("assistants/{assistant_id}/files/{file_id}"))
583 .await
584 }
585
586 pub async fn delete_assistant_file(
587 &mut self,
588 assistant_id: String,
589 file_id: String,
590 ) -> Result<common::DeletionStatus, APIError> {
591 self.delete(&format!("assistants/{assistant_id}/files/{file_id}"))
592 .await
593 }
594
595 pub async fn list_assistant_file(
596 &mut self,
597 assistant_id: String,
598 limit: Option<i64>,
599 order: Option<String>,
600 after: Option<String>,
601 before: Option<String>,
602 ) -> Result<ListAssistantFile, APIError> {
603 let url = Self::query_params(
604 limit,
605 order,
606 after,
607 before,
608 format!("assistants/{assistant_id}/files"),
609 );
610 self.get(&url).await
611 }
612
613 pub async fn create_thread(
614 &mut self,
615 req: CreateThreadRequest,
616 ) -> Result<ThreadObject, APIError> {
617 self.post("threads", &req).await
618 }
619
620 pub async fn retrieve_thread(&mut self, thread_id: String) -> Result<ThreadObject, APIError> {
621 self.get(&format!("threads/{thread_id}")).await
622 }
623
624 pub async fn modify_thread(
625 &mut self,
626 thread_id: String,
627 req: ModifyThreadRequest,
628 ) -> Result<ThreadObject, APIError> {
629 self.post(&format!("threads/{thread_id}"), &req).await
630 }
631
632 pub async fn delete_thread(
633 &mut self,
634 thread_id: String,
635 ) -> Result<common::DeletionStatus, APIError> {
636 self.delete(&format!("threads/{thread_id}")).await
637 }
638
639 pub async fn create_message(
640 &mut self,
641 thread_id: String,
642 req: CreateMessageRequest,
643 ) -> Result<MessageObject, APIError> {
644 self.post(&format!("threads/{thread_id}/messages"), &req)
645 .await
646 }
647
648 pub async fn retrieve_message(
649 &mut self,
650 thread_id: String,
651 message_id: String,
652 ) -> Result<MessageObject, APIError> {
653 self.get(&format!("threads/{thread_id}/messages/{message_id}"))
654 .await
655 }
656
657 pub async fn modify_message(
658 &mut self,
659 thread_id: String,
660 message_id: String,
661 req: ModifyMessageRequest,
662 ) -> Result<MessageObject, APIError> {
663 self.post(&format!("threads/{thread_id}/messages/{message_id}"), &req)
664 .await
665 }
666
667 pub async fn list_messages(&mut self, thread_id: String) -> Result<ListMessage, APIError> {
668 self.get(&format!("threads/{thread_id}/messages")).await
669 }
670
671 pub async fn retrieve_message_file(
672 &mut self,
673 thread_id: String,
674 message_id: String,
675 file_id: String,
676 ) -> Result<MessageFileObject, APIError> {
677 self.get(&format!(
678 "threads/{thread_id}/messages/{message_id}/files/{file_id}"
679 ))
680 .await
681 }
682
683 pub async fn list_message_file(
684 &mut self,
685 thread_id: String,
686 message_id: String,
687 limit: Option<i64>,
688 order: Option<String>,
689 after: Option<String>,
690 before: Option<String>,
691 ) -> Result<ListMessageFile, APIError> {
692 let url = Self::query_params(
693 limit,
694 order,
695 after,
696 before,
697 format!("threads/{thread_id}/messages/{message_id}/files"),
698 );
699 self.get(&url).await
700 }
701
702 pub async fn create_run(
703 &mut self,
704 thread_id: String,
705 req: CreateRunRequest,
706 ) -> Result<RunObject, APIError> {
707 self.post(&format!("threads/{thread_id}/runs"), &req).await
708 }
709
710 pub async fn retrieve_run(
711 &mut self,
712 thread_id: String,
713 run_id: String,
714 ) -> Result<RunObject, APIError> {
715 self.get(&format!("threads/{thread_id}/runs/{run_id}"))
716 .await
717 }
718
719 pub async fn modify_run(
720 &mut self,
721 thread_id: String,
722 run_id: String,
723 req: ModifyRunRequest,
724 ) -> Result<RunObject, APIError> {
725 self.post(&format!("threads/{thread_id}/runs/{run_id}"), &req)
726 .await
727 }
728
729 pub async fn list_run(
730 &mut self,
731 thread_id: String,
732 limit: Option<i64>,
733 order: Option<String>,
734 after: Option<String>,
735 before: Option<String>,
736 ) -> Result<ListRun, APIError> {
737 let url = Self::query_params(
738 limit,
739 order,
740 after,
741 before,
742 format!("threads/{thread_id}/runs"),
743 );
744 self.get(&url).await
745 }
746
747 pub async fn cancel_run(
748 &mut self,
749 thread_id: String,
750 run_id: String,
751 ) -> Result<RunObject, APIError> {
752 self.post(
753 &format!("threads/{thread_id}/runs/{run_id}/cancel"),
754 &ModifyRunRequest::default(),
755 )
756 .await
757 }
758
759 pub async fn create_thread_and_run(
760 &mut self,
761 req: CreateThreadAndRunRequest,
762 ) -> Result<RunObject, APIError> {
763 self.post("threads/runs", &req).await
764 }
765
766 pub async fn retrieve_run_step(
767 &mut self,
768 thread_id: String,
769 run_id: String,
770 step_id: String,
771 ) -> Result<RunStepObject, APIError> {
772 self.get(&format!(
773 "threads/{thread_id}/runs/{run_id}/steps/{step_id}"
774 ))
775 .await
776 }
777
778 pub async fn list_run_step(
779 &mut self,
780 thread_id: String,
781 run_id: String,
782 limit: Option<i64>,
783 order: Option<String>,
784 after: Option<String>,
785 before: Option<String>,
786 ) -> Result<ListRunStep, APIError> {
787 let url = Self::query_params(
788 limit,
789 order,
790 after,
791 before,
792 format!("threads/{thread_id}/runs/{run_id}/steps"),
793 );
794 self.get(&url).await
795 }
796
797 pub async fn create_batch(
798 &mut self,
799 req: CreateBatchRequest,
800 ) -> Result<BatchResponse, APIError> {
801 self.post("batches", &req).await
802 }
803
804 pub async fn retrieve_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
805 self.get(&format!("batches/{batch_id}")).await
806 }
807
808 pub async fn cancel_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
809 self.post(
810 &format!("batches/{batch_id}/cancel"),
811 &common::EmptyRequestBody {},
812 )
813 .await
814 }
815
816 pub async fn list_batch(
817 &mut self,
818 after: Option<String>,
819 limit: Option<i64>,
820 ) -> Result<ListBatchResponse, APIError> {
821 let url = Self::query_params(limit, None, after, None, "batches".to_string());
822 self.get(&url).await
823 }
824
825 pub async fn create_response(
827 &mut self,
828 req: CreateResponseRequest,
829 ) -> Result<ResponseObject, APIError> {
830 self.post("responses", &req).await
831 }
832
833 pub async fn retrieve_response(
834 &mut self,
835 response_id: String,
836 ) -> Result<ResponseObject, APIError> {
837 self.get(&format!("responses/{response_id}")).await
838 }
839
840 pub async fn delete_response(
841 &mut self,
842 response_id: String,
843 ) -> Result<common::DeletionStatus, APIError> {
844 self.delete(&format!("responses/{response_id}")).await
845 }
846
847 pub async fn cancel_response(
848 &mut self,
849 response_id: String,
850 ) -> Result<ResponseObject, APIError> {
851 self.post(
852 &format!("responses/{response_id}/cancel"),
853 &common::EmptyRequestBody {},
854 )
855 .await
856 }
857
858 pub async fn list_response_input_items(
859 &mut self,
860 response_id: String,
861 after: Option<String>,
862 limit: Option<i64>,
863 order: Option<String>,
864 ) -> Result<ListResponses, APIError> {
865 let mut url = format!("responses/{}/input_items", response_id);
866 let mut params = vec![];
867 if let Some(after) = after {
868 params.push(format!("after={}", after));
869 }
870 if let Some(limit) = limit {
871 params.push(format!("limit={}", limit));
872 }
873 if let Some(order) = order {
874 params.push(format!("order={}", order));
875 }
876 if !params.is_empty() {
877 url = format!("{}?{}", url, params.join("&"));
878 }
879 self.get(&url).await
880 }
881
882 pub async fn count_response_input_tokens(
883 &mut self,
884 req: CountTokensRequest,
885 ) -> Result<CountTokensResponse, APIError> {
886 self.post("responses/input_tokens", &req).await
887 }
888
889 pub async fn list_models(&mut self) -> Result<ModelsResponse, APIError> {
890 self.get("models").await
891 }
892
893 pub async fn retrieve_model(&mut self, model_id: String) -> Result<ModelResponse, APIError> {
894 self.get(&format!("models/{model_id}")).await
895 }
896
897 pub async fn delete_model(
898 &mut self,
899 model_id: String,
900 ) -> Result<common::DeletionStatus, APIError> {
901 self.delete(&format!("models/{model_id}")).await
902 }
903
904 fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
905 let (base, query_opt) = match self.api_endpoint.split_once('?') {
906 Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
907 None => (self.api_endpoint.trim_end_matches('/'), None),
908 };
909
910 let full_path = format!("{}/{}", base, path.trim_start_matches('/'));
911 let mut url = Url::parse(&full_path)?;
912
913 if let Some(query) = query_opt {
914 for (k, v) in url::form_urlencoded::parse(query.as_bytes()) {
915 url.query_pairs_mut().append_pair(&k, &v);
916 }
917 }
918 Ok(url.to_string())
919 }
920
921 fn query_params(
922 limit: Option<i64>,
923 order: Option<String>,
924 after: Option<String>,
925 before: Option<String>,
926 mut url: String,
927 ) -> String {
928 let mut params = vec![];
929 if let Some(limit) = limit {
930 params.push(format!("limit={limit}"));
931 }
932 if let Some(order) = order {
933 params.push(format!("order={order}"));
934 }
935 if let Some(after) = after {
936 params.push(format!("after={after}"));
937 }
938 if let Some(before) = before {
939 params.push(format!("before={before}"));
940 }
941 if !params.is_empty() {
942 url = format!("{}?{}", url, params.join("&"));
943 }
944 url
945 }
946
947 fn is_beta(path: &str) -> bool {
948 path.starts_with("assistants") || path.starts_with("threads")
949 }
950
951 fn create_form<T>(req: &T, file_field: &str) -> Result<Form, APIError>
952 where
953 T: Serialize,
954 {
955 let json = match serde_json::to_value(req) {
956 Ok(json) => json,
957 Err(e) => {
958 return Err(APIError::CustomError {
959 message: e.to_string(),
960 })
961 }
962 };
963 let file_path = if let Value::Object(map) = &json {
964 map.get(file_field)
965 .and_then(|v| v.as_str())
966 .ok_or(APIError::CustomError {
967 message: format!("Field '{file_field}' not found or not a string"),
968 })?
969 } else {
970 return Err(APIError::CustomError {
971 message: "Request is not a JSON object".to_string(),
972 });
973 };
974
975 let mut file = match File::open(file_path) {
976 Ok(file) => file,
977 Err(e) => {
978 return Err(APIError::CustomError {
979 message: e.to_string(),
980 })
981 }
982 };
983 let mut buffer = Vec::new();
984 match file.read_to_end(&mut buffer) {
985 Ok(_) => {}
986 Err(e) => {
987 return Err(APIError::CustomError {
988 message: e.to_string(),
989 })
990 }
991 }
992
993 let mut form =
994 Form::new().part("file", Part::bytes(buffer).file_name(file_path.to_string()));
995
996 if let Value::Object(map) = json {
997 for (key, value) in map.into_iter() {
998 if key != file_field {
999 match value {
1000 Value::String(s) => {
1001 form = form.text(key, s);
1002 }
1003 Value::Number(n) => {
1004 form = form.text(key, n.to_string());
1005 }
1006 _ => {}
1007 }
1008 }
1009 }
1010 }
1011
1012 Ok(form)
1013 }
1014
1015 fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
1016 where
1017 T: Serialize,
1018 {
1019 let json = match serde_json::to_value(req) {
1020 Ok(json) => json,
1021 Err(e) => {
1022 return Err(APIError::CustomError {
1023 message: e.to_string(),
1024 })
1025 }
1026 };
1027
1028 let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
1029
1030 if let Value::Object(map) = json {
1031 for (key, value) in map.into_iter() {
1032 match value {
1033 Value::String(s) => {
1034 form = form.text(key, s);
1035 }
1036 Value::Number(n) => {
1037 form = form.text(key, n.to_string());
1038 }
1039 _ => {}
1040 }
1041 }
1042 }
1043
1044 Ok(form)
1045 }
1046}