1use crate::v1::assistant::{
2 AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, ListAssistant,
3 ListAssistantFile,
4};
5use crate::v1::audio::{
6 AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse,
7 AudioTranslationRequest, AudioTranslationResponse,
8};
9use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse};
10use crate::v1::chat_completion::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
11use crate::v1::chat_completion::chat_completion_stream::{
12 ChatCompletionStream, ChatCompletionStreamRequest, ChatCompletionStreamResponse,
13};
14use crate::v1::common;
15use crate::v1::completion::{CompletionRequest, CompletionResponse};
16use crate::v1::edit::{EditRequest, EditResponse};
17use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse};
18use crate::v1::error::APIError;
19use crate::v1::file::{
20 FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveResponse,
21 FileUploadRequest, FileUploadResponse,
22};
23use crate::v1::fine_tuning::{
24 CancelFineTuningJobRequest, CreateFineTuningJobRequest, FineTuningJobEvent,
25 FineTuningJobObject, FineTuningPagination, ListFineTuningJobEventsRequest,
26 RetrieveFineTuningJobRequest,
27};
28use crate::v1::image::{
29 ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse,
30 ImageVariationRequest, ImageVariationResponse,
31};
32use crate::v1::message::{
33 CreateMessageRequest, ListMessage, ListMessageFile, MessageFileObject, MessageObject,
34 ModifyMessageRequest,
35};
36use crate::v1::model::{ModelResponse, ModelsResponse};
37use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
38use crate::v1::responses::responses::{
39 CountTokensRequest, CountTokensResponse, CreateResponseRequest, ListResponses, ResponseObject,
40};
41use crate::v1::responses::responses_stream::{
42 CreateResponseStreamRequest, ResponseStream, ResponseStreamResponse,
43};
44use crate::v1::run::{
45 CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject,
46 RunStepObject,
47};
48use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject};
49
50use bytes::Bytes;
51use futures_util::Stream;
52use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
53use reqwest::multipart::{Form, Part};
54use reqwest::{Client, Method, Response};
55use serde::Serialize;
56use serde_json::{to_value, Value};
57use url::Url;
58
59use std::error::Error;
60use std::fs::{create_dir_all, File};
61use std::io::Read;
62use std::io::Write;
63use std::path::Path;
64
65const API_URL_V1: &str = "https://api.openai.com/v1";
66
67#[derive(Default)]
68pub struct OpenAIClientBuilder {
69 api_endpoint: Option<String>,
70 api_key: Option<String>,
71 organization: Option<String>,
72 proxy: Option<String>,
73 timeout: Option<u64>,
74 headers: Option<HeaderMap>,
75}
76
77#[derive(Debug)]
78pub struct OpenAIClient {
79 api_endpoint: String,
80 api_key: Option<String>,
81 organization: Option<String>,
82 proxy: Option<String>,
83 timeout: Option<u64>,
84 headers: Option<HeaderMap>,
85 pub response_headers: Option<HeaderMap>,
86}
87
88impl OpenAIClientBuilder {
89 pub fn new() -> Self {
90 Self::default()
91 }
92
93 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
94 self.api_key = Some(api_key.into());
95 self
96 }
97
98 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
99 self.api_endpoint = Some(endpoint.into());
100 self
101 }
102
103 pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
104 self.organization = Some(organization.into());
105 self
106 }
107
108 pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
109 self.proxy = Some(proxy.into());
110 self
111 }
112
113 pub fn with_timeout(mut self, timeout: u64) -> Self {
114 self.timeout = Some(timeout);
115 self
116 }
117
118 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
119 let headers = self.headers.get_or_insert_with(HeaderMap::new);
120 headers.insert(
121 HeaderName::from_bytes(key.into().as_bytes()).expect("Invalid header name"),
122 HeaderValue::from_str(&value.into()).expect("Invalid header value"),
123 );
124 self
125 }
126
127 pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
128 let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
129 std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
130 });
131
132 Ok(OpenAIClient {
133 api_endpoint,
134 api_key: self.api_key,
135 organization: self.organization,
136 proxy: self.proxy,
137 timeout: self.timeout,
138 headers: self.headers,
139 response_headers: None,
140 })
141 }
142}
143
144impl OpenAIClient {
145 pub fn builder() -> OpenAIClientBuilder {
146 OpenAIClientBuilder::new()
147 }
148
149 async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
150 let url = self
151 .build_url_with_preserved_query(path)
152 .unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
153
154 let client = Client::builder();
155
156 #[cfg(feature = "rustls")]
157 let client = client.use_rustls_tls();
158
159 let client = if let Some(timeout) = self.timeout {
160 client.timeout(std::time::Duration::from_secs(timeout))
161 } else {
162 client
163 };
164
165 let client = if let Some(proxy) = &self.proxy {
166 client.proxy(reqwest::Proxy::all(proxy).unwrap())
167 } else {
168 client
169 };
170
171 let client = client.build().unwrap();
172
173 let mut request = client.request(method, url);
174
175 if let Some(api_key) = &self.api_key {
176 request = request.header("Authorization", format!("Bearer {api_key}"));
177 }
178
179 if let Some(organization) = &self.organization {
180 request = request.header("openai-organization", organization);
181 }
182
183 if let Some(headers) = &self.headers {
184 for (key, value) in headers {
185 request = request.header(key, value);
186 }
187 }
188
189 if Self::is_beta(path) {
190 request = request.header("OpenAI-Beta", "assistants=v2");
191 }
192
193 request
194 }
195
196 async fn post<T: serde::de::DeserializeOwned>(
197 &mut self,
198 path: &str,
199 body: &impl serde::ser::Serialize,
200 ) -> Result<T, APIError> {
201 let request = self.build_request(Method::POST, path).await;
202 let request = request.json(body);
203 let response = request.send().await?;
204 self.handle_response(response).await
205 }
206
207 async fn get<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
208 let request = self.build_request(Method::GET, path).await;
209 let response = request.send().await?;
210 self.handle_response(response).await
211 }
212
213 async fn get_raw(&self, path: &str) -> Result<Bytes, APIError> {
214 let request = self.build_request(Method::GET, path).await;
215 let response = request.send().await?;
216 Ok(response.bytes().await?)
217 }
218
219 async fn delete<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
220 let request = self.build_request(Method::DELETE, path).await;
221 let response = request.send().await?;
222 self.handle_response(response).await
223 }
224
225 async fn post_form<T: serde::de::DeserializeOwned>(
226 &mut self,
227 path: &str,
228 form: Form,
229 ) -> Result<T, APIError> {
230 let request = self.build_request(Method::POST, path).await;
231 let request = request.multipart(form);
232 let response = request.send().await?;
233 self.handle_response(response).await
234 }
235
236 async fn post_form_raw(&self, path: &str, form: Form) -> Result<Bytes, APIError> {
237 let request = self.build_request(Method::POST, path).await;
238 let request = request.multipart(form);
239 let response = request.send().await?;
240 Ok(response.bytes().await?)
241 }
242
243 async fn handle_response<T: serde::de::DeserializeOwned>(
244 &mut self,
245 response: Response,
246 ) -> Result<T, APIError> {
247 let status = response.status();
248 let headers = response.headers().clone();
249 if status.is_success() {
250 let text = response.text().await.unwrap_or_else(|_| "".to_string());
251 match serde_json::from_str::<T>(&text) {
252 Ok(parsed) => {
253 self.response_headers = Some(headers);
254 Ok(parsed)
255 }
256 Err(e) => Err(APIError::CustomError {
257 message: format!("Failed to parse JSON: {e} / response {text}"),
258 }),
259 }
260 } else {
261 let error_message = response
262 .text()
263 .await
264 .unwrap_or_else(|_| "Unknown error".to_string());
265 Err(APIError::CustomError {
266 message: format!("{status}: {error_message}"),
267 })
268 }
269 }
270
271 pub async fn completion(
272 &mut self,
273 req: CompletionRequest,
274 ) -> Result<CompletionResponse, APIError> {
275 self.post("completions", &req).await
276 }
277
278 pub async fn edit(&mut self, req: EditRequest) -> Result<EditResponse, APIError> {
279 self.post("edits", &req).await
280 }
281
282 pub async fn image_generation(
283 &mut self,
284 req: ImageGenerationRequest,
285 ) -> Result<ImageGenerationResponse, APIError> {
286 self.post("images/generations", &req).await
287 }
288
289 pub async fn image_edit(
290 &mut self,
291 req: ImageEditRequest,
292 ) -> Result<ImageEditResponse, APIError> {
293 self.post("images/edits", &req).await
294 }
295
296 pub async fn image_variation(
297 &mut self,
298 req: ImageVariationRequest,
299 ) -> Result<ImageVariationResponse, APIError> {
300 self.post("images/variations", &req).await
301 }
302
303 pub async fn embedding(
304 &mut self,
305 req: EmbeddingRequest,
306 ) -> Result<EmbeddingResponse, APIError> {
307 self.post("embeddings", &req).await
308 }
309
310 pub async fn file_list(&mut self) -> Result<FileListResponse, APIError> {
311 self.get("files").await
312 }
313
314 pub async fn upload_file(
315 &mut self,
316 req: FileUploadRequest,
317 ) -> Result<FileUploadResponse, APIError> {
318 let form = Self::create_form(&req, "file")?;
319 self.post_form("files", form).await
320 }
321
322 pub async fn delete_file(
323 &mut self,
324 req: FileDeleteRequest,
325 ) -> Result<FileDeleteResponse, APIError> {
326 self.delete(&format!("files/{}", req.file_id)).await
327 }
328
329 pub async fn retrieve_file(
330 &mut self,
331 file_id: String,
332 ) -> Result<FileRetrieveResponse, APIError> {
333 self.get(&format!("files/{file_id}")).await
334 }
335
336 pub async fn retrieve_file_content(&self, file_id: String) -> Result<Bytes, APIError> {
337 self.get_raw(&format!("files/{file_id}/content")).await
338 }
339
340 pub async fn chat_completion(
341 &mut self,
342 req: ChatCompletionRequest,
343 ) -> Result<ChatCompletionResponse, APIError> {
344 self.post("chat/completions", &req).await
345 }
346
347 pub async fn chat_completion_stream(
348 &mut self,
349 req: ChatCompletionStreamRequest,
350 ) -> Result<impl Stream<Item = ChatCompletionStreamResponse>, APIError> {
351 let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
352 message: format!("Failed to serialize request: {}", err),
353 })?;
354
355 if let Some(obj) = payload.as_object_mut() {
356 obj.insert("stream".into(), Value::Bool(true));
357 }
358
359 let request = self.build_request(Method::POST, "chat/completions").await;
360 let request = request.json(&payload);
361 let response = request.send().await?;
362
363 if response.status().is_success() {
364 Ok(ChatCompletionStream {
365 response: Box::pin(response.bytes_stream()),
366 buffer: String::new(),
367 first_chunk: true,
368 })
369 } else {
370 let error_text = response
371 .text()
372 .await
373 .unwrap_or_else(|_| String::from("Unknown error"));
374
375 Err(APIError::CustomError {
376 message: error_text,
377 })
378 }
379 }
380
381 pub async fn audio_transcription(
382 &mut self,
383 req: AudioTranscriptionRequest,
384 ) -> Result<AudioTranscriptionResponse, APIError> {
385 if let Some(response_format) = &req.response_format {
387 if response_format != "json" && response_format != "verbose_json" {
388 return Err(APIError::CustomError {
389 message: "response_format must be either 'json' or 'verbose_json' please use audio_transcription_raw".to_string(),
390 });
391 }
392 }
393 let form: Form;
394 if req.clone().file.is_some() {
395 form = Self::create_form(&req, "file")?;
396 } else if let Some(bytes) = req.clone().bytes {
397 form = Self::create_form_from_bytes(&req, bytes)?;
398 } else {
399 return Err(APIError::CustomError {
400 message: "Either file or bytes must be provided".to_string(),
401 });
402 }
403 self.post_form("audio/transcriptions", form).await
404 }
405
406 pub async fn audio_transcription_raw(
407 &mut self,
408 req: AudioTranscriptionRequest,
409 ) -> Result<Bytes, APIError> {
410 if let Some(response_format) = &req.response_format {
412 if response_format != "text" && response_format != "srt" && response_format != "vtt" {
413 return Err(APIError::CustomError {
414 message: "response_format must be either 'text', 'srt' or 'vtt', please use audio_transcription".to_string(),
415 });
416 }
417 }
418 let form: Form;
419 if req.clone().file.is_some() {
420 form = Self::create_form(&req, "file")?;
421 } else if let Some(bytes) = req.clone().bytes {
422 form = Self::create_form_from_bytes(&req, bytes)?;
423 } else {
424 return Err(APIError::CustomError {
425 message: "Either file or bytes must be provided".to_string(),
426 });
427 }
428 self.post_form_raw("audio/transcriptions", form).await
429 }
430
431 pub async fn audio_translation(
432 &mut self,
433 req: AudioTranslationRequest,
434 ) -> Result<AudioTranslationResponse, APIError> {
435 let form = Self::create_form(&req, "file")?;
436 self.post_form("audio/translations", form).await
437 }
438
439 pub async fn audio_speech(
440 &mut self,
441 req: AudioSpeechRequest,
442 ) -> Result<AudioSpeechResponse, APIError> {
443 let request = self.build_request(Method::POST, "audio/speech").await;
444 let request = request.json(&req);
445 let response = request.send().await?;
446 let headers = response.headers().clone();
447 let bytes = response.bytes().await?;
448 let path = Path::new(req.output.as_str());
449 if let Some(parent) = path.parent() {
450 match create_dir_all(parent) {
451 Ok(_) => {}
452 Err(e) => {
453 return Err(APIError::CustomError {
454 message: e.to_string(),
455 })
456 }
457 }
458 }
459 match File::create(path) {
460 Ok(mut file) => match file.write_all(&bytes) {
461 Ok(_) => {}
462 Err(e) => {
463 return Err(APIError::CustomError {
464 message: e.to_string(),
465 })
466 }
467 },
468 Err(e) => {
469 return Err(APIError::CustomError {
470 message: e.to_string(),
471 })
472 }
473 }
474
475 Ok(AudioSpeechResponse {
476 result: true,
477 headers: Some(headers),
478 })
479 }
480
481 pub async fn create_fine_tuning_job(
482 &mut self,
483 req: CreateFineTuningJobRequest,
484 ) -> Result<FineTuningJobObject, APIError> {
485 self.post("fine_tuning/jobs", &req).await
486 }
487
488 pub async fn list_fine_tuning_jobs(
489 &mut self,
490 ) -> Result<FineTuningPagination<FineTuningJobObject>, APIError> {
491 self.get("fine_tuning/jobs").await
492 }
493
494 pub async fn list_fine_tuning_job_events(
495 &mut self,
496 req: ListFineTuningJobEventsRequest,
497 ) -> Result<FineTuningPagination<FineTuningJobEvent>, APIError> {
498 self.get(&format!(
499 "fine_tuning/jobs/{}/events",
500 req.fine_tuning_job_id
501 ))
502 .await
503 }
504
505 pub async fn retrieve_fine_tuning_job(
506 &mut self,
507 req: RetrieveFineTuningJobRequest,
508 ) -> Result<FineTuningJobObject, APIError> {
509 self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id))
510 .await
511 }
512
513 pub async fn cancel_fine_tuning_job(
514 &mut self,
515 req: CancelFineTuningJobRequest,
516 ) -> Result<FineTuningJobObject, APIError> {
517 self.post(
518 &format!("fine_tuning/jobs/{}/cancel", req.fine_tuning_job_id),
519 &req,
520 )
521 .await
522 }
523
524 pub async fn create_moderation(
525 &mut self,
526 req: CreateModerationRequest,
527 ) -> Result<CreateModerationResponse, APIError> {
528 self.post("moderations", &req).await
529 }
530
531 pub async fn create_assistant(
532 &mut self,
533 req: AssistantRequest,
534 ) -> Result<AssistantObject, APIError> {
535 self.post("assistants", &req).await
536 }
537
538 pub async fn retrieve_assistant(
539 &mut self,
540 assistant_id: String,
541 ) -> Result<AssistantObject, APIError> {
542 self.get(&format!("assistants/{assistant_id}")).await
543 }
544
545 pub async fn modify_assistant(
546 &mut self,
547 assistant_id: String,
548 req: AssistantRequest,
549 ) -> Result<AssistantObject, APIError> {
550 self.post(&format!("assistants/{assistant_id}"), &req).await
551 }
552
553 pub async fn delete_assistant(
554 &mut self,
555 assistant_id: String,
556 ) -> Result<common::DeletionStatus, APIError> {
557 self.delete(&format!("assistants/{assistant_id}")).await
558 }
559
560 pub async fn list_assistant(
561 &mut self,
562 limit: Option<i64>,
563 order: Option<String>,
564 after: Option<String>,
565 before: Option<String>,
566 ) -> Result<ListAssistant, APIError> {
567 let url = Self::query_params(limit, order, after, before, "assistants".to_string());
568 self.get(&url).await
569 }
570
571 pub async fn create_assistant_file(
572 &mut self,
573 assistant_id: String,
574 req: AssistantFileRequest,
575 ) -> Result<AssistantFileObject, APIError> {
576 self.post(&format!("assistants/{assistant_id}/files"), &req)
577 .await
578 }
579
580 pub async fn retrieve_assistant_file(
581 &mut self,
582 assistant_id: String,
583 file_id: String,
584 ) -> Result<AssistantFileObject, APIError> {
585 self.get(&format!("assistants/{assistant_id}/files/{file_id}"))
586 .await
587 }
588
589 pub async fn delete_assistant_file(
590 &mut self,
591 assistant_id: String,
592 file_id: String,
593 ) -> Result<common::DeletionStatus, APIError> {
594 self.delete(&format!("assistants/{assistant_id}/files/{file_id}"))
595 .await
596 }
597
598 pub async fn list_assistant_file(
599 &mut self,
600 assistant_id: String,
601 limit: Option<i64>,
602 order: Option<String>,
603 after: Option<String>,
604 before: Option<String>,
605 ) -> Result<ListAssistantFile, APIError> {
606 let url = Self::query_params(
607 limit,
608 order,
609 after,
610 before,
611 format!("assistants/{assistant_id}/files"),
612 );
613 self.get(&url).await
614 }
615
616 pub async fn create_thread(
617 &mut self,
618 req: CreateThreadRequest,
619 ) -> Result<ThreadObject, APIError> {
620 self.post("threads", &req).await
621 }
622
623 pub async fn retrieve_thread(&mut self, thread_id: String) -> Result<ThreadObject, APIError> {
624 self.get(&format!("threads/{thread_id}")).await
625 }
626
627 pub async fn modify_thread(
628 &mut self,
629 thread_id: String,
630 req: ModifyThreadRequest,
631 ) -> Result<ThreadObject, APIError> {
632 self.post(&format!("threads/{thread_id}"), &req).await
633 }
634
635 pub async fn delete_thread(
636 &mut self,
637 thread_id: String,
638 ) -> Result<common::DeletionStatus, APIError> {
639 self.delete(&format!("threads/{thread_id}")).await
640 }
641
642 pub async fn create_message(
643 &mut self,
644 thread_id: String,
645 req: CreateMessageRequest,
646 ) -> Result<MessageObject, APIError> {
647 self.post(&format!("threads/{thread_id}/messages"), &req)
648 .await
649 }
650
651 pub async fn retrieve_message(
652 &mut self,
653 thread_id: String,
654 message_id: String,
655 ) -> Result<MessageObject, APIError> {
656 self.get(&format!("threads/{thread_id}/messages/{message_id}"))
657 .await
658 }
659
660 pub async fn modify_message(
661 &mut self,
662 thread_id: String,
663 message_id: String,
664 req: ModifyMessageRequest,
665 ) -> Result<MessageObject, APIError> {
666 self.post(&format!("threads/{thread_id}/messages/{message_id}"), &req)
667 .await
668 }
669
670 pub async fn list_messages(&mut self, thread_id: String) -> Result<ListMessage, APIError> {
671 self.get(&format!("threads/{thread_id}/messages")).await
672 }
673
674 pub async fn retrieve_message_file(
675 &mut self,
676 thread_id: String,
677 message_id: String,
678 file_id: String,
679 ) -> Result<MessageFileObject, APIError> {
680 self.get(&format!(
681 "threads/{thread_id}/messages/{message_id}/files/{file_id}"
682 ))
683 .await
684 }
685
686 pub async fn list_message_file(
687 &mut self,
688 thread_id: String,
689 message_id: String,
690 limit: Option<i64>,
691 order: Option<String>,
692 after: Option<String>,
693 before: Option<String>,
694 ) -> Result<ListMessageFile, APIError> {
695 let url = Self::query_params(
696 limit,
697 order,
698 after,
699 before,
700 format!("threads/{thread_id}/messages/{message_id}/files"),
701 );
702 self.get(&url).await
703 }
704
705 pub async fn create_run(
706 &mut self,
707 thread_id: String,
708 req: CreateRunRequest,
709 ) -> Result<RunObject, APIError> {
710 self.post(&format!("threads/{thread_id}/runs"), &req).await
711 }
712
713 pub async fn retrieve_run(
714 &mut self,
715 thread_id: String,
716 run_id: String,
717 ) -> Result<RunObject, APIError> {
718 self.get(&format!("threads/{thread_id}/runs/{run_id}"))
719 .await
720 }
721
722 pub async fn modify_run(
723 &mut self,
724 thread_id: String,
725 run_id: String,
726 req: ModifyRunRequest,
727 ) -> Result<RunObject, APIError> {
728 self.post(&format!("threads/{thread_id}/runs/{run_id}"), &req)
729 .await
730 }
731
732 pub async fn list_run(
733 &mut self,
734 thread_id: String,
735 limit: Option<i64>,
736 order: Option<String>,
737 after: Option<String>,
738 before: Option<String>,
739 ) -> Result<ListRun, APIError> {
740 let url = Self::query_params(
741 limit,
742 order,
743 after,
744 before,
745 format!("threads/{thread_id}/runs"),
746 );
747 self.get(&url).await
748 }
749
750 pub async fn cancel_run(
751 &mut self,
752 thread_id: String,
753 run_id: String,
754 ) -> Result<RunObject, APIError> {
755 self.post(
756 &format!("threads/{thread_id}/runs/{run_id}/cancel"),
757 &ModifyRunRequest::default(),
758 )
759 .await
760 }
761
762 pub async fn create_thread_and_run(
763 &mut self,
764 req: CreateThreadAndRunRequest,
765 ) -> Result<RunObject, APIError> {
766 self.post("threads/runs", &req).await
767 }
768
769 pub async fn retrieve_run_step(
770 &mut self,
771 thread_id: String,
772 run_id: String,
773 step_id: String,
774 ) -> Result<RunStepObject, APIError> {
775 self.get(&format!(
776 "threads/{thread_id}/runs/{run_id}/steps/{step_id}"
777 ))
778 .await
779 }
780
781 pub async fn list_run_step(
782 &mut self,
783 thread_id: String,
784 run_id: String,
785 limit: Option<i64>,
786 order: Option<String>,
787 after: Option<String>,
788 before: Option<String>,
789 ) -> Result<ListRunStep, APIError> {
790 let url = Self::query_params(
791 limit,
792 order,
793 after,
794 before,
795 format!("threads/{thread_id}/runs/{run_id}/steps"),
796 );
797 self.get(&url).await
798 }
799
800 pub async fn create_batch(
801 &mut self,
802 req: CreateBatchRequest,
803 ) -> Result<BatchResponse, APIError> {
804 self.post("batches", &req).await
805 }
806
807 pub async fn retrieve_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
808 self.get(&format!("batches/{batch_id}")).await
809 }
810
811 pub async fn cancel_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
812 self.post(
813 &format!("batches/{batch_id}/cancel"),
814 &common::EmptyRequestBody {},
815 )
816 .await
817 }
818
819 pub async fn list_batch(
820 &mut self,
821 after: Option<String>,
822 limit: Option<i64>,
823 ) -> Result<ListBatchResponse, APIError> {
824 let url = Self::query_params(limit, None, after, None, "batches".to_string());
825 self.get(&url).await
826 }
827
828 pub async fn create_response(
830 &mut self,
831 req: CreateResponseRequest,
832 ) -> Result<ResponseObject, APIError> {
833 self.post("responses", &req).await
834 }
835
836 pub async fn create_response_stream(
837 &mut self,
838 req: CreateResponseStreamRequest,
839 ) -> Result<impl Stream<Item = ResponseStreamResponse>, APIError> {
840 let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
841 message: format!("Failed to serialize request: {}", err),
842 })?;
843
844 if let Some(obj) = payload.as_object_mut() {
845 obj.insert("stream".into(), Value::Bool(true));
846 }
847
848 let request = self.build_request(Method::POST, "responses").await;
849 let request = request.json(&payload);
850 let response = request.send().await?;
851
852 if response.status().is_success() {
853 Ok(ResponseStream {
854 response: Box::pin(response.bytes_stream()),
855 buffer: String::new(),
856 first_chunk: true,
857 })
858 } else {
859 let error_text = response
860 .text()
861 .await
862 .unwrap_or_else(|_| String::from("Unknown error"));
863
864 Err(APIError::CustomError {
865 message: error_text,
866 })
867 }
868 }
869
870 pub async fn retrieve_response(
871 &mut self,
872 response_id: String,
873 ) -> Result<ResponseObject, APIError> {
874 self.get(&format!("responses/{response_id}")).await
875 }
876
877 pub async fn delete_response(
878 &mut self,
879 response_id: String,
880 ) -> Result<common::DeletionStatus, APIError> {
881 self.delete(&format!("responses/{response_id}")).await
882 }
883
884 pub async fn cancel_response(
885 &mut self,
886 response_id: String,
887 ) -> Result<ResponseObject, APIError> {
888 self.post(
889 &format!("responses/{response_id}/cancel"),
890 &common::EmptyRequestBody {},
891 )
892 .await
893 }
894
895 pub async fn list_response_input_items(
896 &mut self,
897 response_id: String,
898 after: Option<String>,
899 limit: Option<i64>,
900 order: Option<String>,
901 ) -> Result<ListResponses, APIError> {
902 let mut url = format!("responses/{}/input_items", response_id);
903 let mut params = vec![];
904 if let Some(after) = after {
905 params.push(format!("after={}", after));
906 }
907 if let Some(limit) = limit {
908 params.push(format!("limit={}", limit));
909 }
910 if let Some(order) = order {
911 params.push(format!("order={}", order));
912 }
913 if !params.is_empty() {
914 url = format!("{}?{}", url, params.join("&"));
915 }
916 self.get(&url).await
917 }
918
919 pub async fn count_response_input_tokens(
920 &mut self,
921 req: CountTokensRequest,
922 ) -> Result<CountTokensResponse, APIError> {
923 self.post("responses/input_tokens", &req).await
924 }
925
926 pub async fn list_models(&mut self) -> Result<ModelsResponse, APIError> {
927 self.get("models").await
928 }
929
930 pub async fn retrieve_model(&mut self, model_id: String) -> Result<ModelResponse, APIError> {
931 self.get(&format!("models/{model_id}")).await
932 }
933
934 pub async fn delete_model(
935 &mut self,
936 model_id: String,
937 ) -> Result<common::DeletionStatus, APIError> {
938 self.delete(&format!("models/{model_id}")).await
939 }
940
941 fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
942 let (base, query_opt) = match self.api_endpoint.split_once('?') {
943 Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
944 None => (self.api_endpoint.trim_end_matches('/'), None),
945 };
946
947 let full_path = format!("{}/{}", base, path.trim_start_matches('/'));
948 let mut url = Url::parse(&full_path)?;
949
950 if let Some(query) = query_opt {
951 for (k, v) in url::form_urlencoded::parse(query.as_bytes()) {
952 url.query_pairs_mut().append_pair(&k, &v);
953 }
954 }
955 Ok(url.to_string())
956 }
957
958 fn query_params(
959 limit: Option<i64>,
960 order: Option<String>,
961 after: Option<String>,
962 before: Option<String>,
963 mut url: String,
964 ) -> String {
965 let mut params = vec![];
966 if let Some(limit) = limit {
967 params.push(format!("limit={limit}"));
968 }
969 if let Some(order) = order {
970 params.push(format!("order={order}"));
971 }
972 if let Some(after) = after {
973 params.push(format!("after={after}"));
974 }
975 if let Some(before) = before {
976 params.push(format!("before={before}"));
977 }
978 if !params.is_empty() {
979 url = format!("{}?{}", url, params.join("&"));
980 }
981 url
982 }
983
984 fn is_beta(path: &str) -> bool {
985 path.starts_with("assistants") || path.starts_with("threads")
986 }
987
988 fn create_form<T>(req: &T, file_field: &str) -> Result<Form, APIError>
989 where
990 T: Serialize,
991 {
992 let json = match serde_json::to_value(req) {
993 Ok(json) => json,
994 Err(e) => {
995 return Err(APIError::CustomError {
996 message: e.to_string(),
997 })
998 }
999 };
1000 let file_path = if let Value::Object(map) = &json {
1001 map.get(file_field)
1002 .and_then(|v| v.as_str())
1003 .ok_or(APIError::CustomError {
1004 message: format!("Field '{file_field}' not found or not a string"),
1005 })?
1006 } else {
1007 return Err(APIError::CustomError {
1008 message: "Request is not a JSON object".to_string(),
1009 });
1010 };
1011
1012 let mut file = match File::open(file_path) {
1013 Ok(file) => file,
1014 Err(e) => {
1015 return Err(APIError::CustomError {
1016 message: e.to_string(),
1017 })
1018 }
1019 };
1020 let mut buffer = Vec::new();
1021 match file.read_to_end(&mut buffer) {
1022 Ok(_) => {}
1023 Err(e) => {
1024 return Err(APIError::CustomError {
1025 message: e.to_string(),
1026 })
1027 }
1028 }
1029
1030 let mut form =
1031 Form::new().part("file", Part::bytes(buffer).file_name(file_path.to_string()));
1032
1033 if let Value::Object(map) = json {
1034 for (key, value) in map.into_iter() {
1035 if key != file_field {
1036 match value {
1037 Value::String(s) => {
1038 form = form.text(key, s);
1039 }
1040 Value::Number(n) => {
1041 form = form.text(key, n.to_string());
1042 }
1043 _ => {}
1044 }
1045 }
1046 }
1047 }
1048
1049 Ok(form)
1050 }
1051
1052 fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
1053 where
1054 T: Serialize,
1055 {
1056 let json = match serde_json::to_value(req) {
1057 Ok(json) => json,
1058 Err(e) => {
1059 return Err(APIError::CustomError {
1060 message: e.to_string(),
1061 })
1062 }
1063 };
1064
1065 let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
1066
1067 if let Value::Object(map) = json {
1068 for (key, value) in map.into_iter() {
1069 match value {
1070 Value::String(s) => {
1071 form = form.text(key, s);
1072 }
1073 Value::Number(n) => {
1074 form = form.text(key, n.to_string());
1075 }
1076 _ => {}
1077 }
1078 }
1079 }
1080
1081 Ok(form)
1082 }
1083}