openai_api_rs/v1/
api.rs

1use crate::v1::assistant::{
2    AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, ListAssistant,
3    ListAssistantFile,
4};
5use crate::v1::audio::{
6    AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse,
7    AudioTranslationRequest, AudioTranslationResponse,
8};
9use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse};
10use crate::v1::chat_completion::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
11use crate::v1::chat_completion::chat_completion_stream::{
12    ChatCompletionStream, ChatCompletionStreamRequest, ChatCompletionStreamResponse,
13};
14use crate::v1::common;
15use crate::v1::completion::{CompletionRequest, CompletionResponse};
16use crate::v1::edit::{EditRequest, EditResponse};
17use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse};
18use crate::v1::error::APIError;
19use crate::v1::file::{
20    FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveResponse,
21    FileUploadRequest, FileUploadResponse,
22};
23use crate::v1::fine_tuning::{
24    CancelFineTuningJobRequest, CreateFineTuningJobRequest, FineTuningJobEvent,
25    FineTuningJobObject, FineTuningPagination, ListFineTuningJobEventsRequest,
26    RetrieveFineTuningJobRequest,
27};
28use crate::v1::image::{
29    ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse,
30    ImageVariationRequest, ImageVariationResponse,
31};
32use crate::v1::message::{
33    CreateMessageRequest, ListMessage, ListMessageFile, MessageFileObject, MessageObject,
34    ModifyMessageRequest,
35};
36use crate::v1::model::{ModelResponse, ModelsResponse};
37use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
38use crate::v1::run::{
39    CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject,
40    RunStepObject,
41};
42use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject};
43
44use bytes::Bytes;
45use futures_util::Stream;
46use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
47use reqwest::multipart::{Form, Part};
48use reqwest::{Client, Method, Response};
49use serde::Serialize;
50use serde_json::{to_value, Value};
51use url::Url;
52
53use std::error::Error;
54use std::fs::{create_dir_all, File};
55use std::io::Read;
56use std::io::Write;
57use std::path::Path;
58
/// Default base URL of the OpenAI REST API (version 1).
///
/// Used by [`OpenAIClientBuilder::build`] when no endpoint was set explicitly and the
/// `OPENAI_API_BASE` environment variable is absent.
const API_URL_V1: &str = concat!("https://api.openai.com", "/v1");
60
61#[derive(Default)]
62pub struct OpenAIClientBuilder {
63    api_endpoint: Option<String>,
64    api_key: Option<String>,
65    organization: Option<String>,
66    proxy: Option<String>,
67    timeout: Option<u64>,
68    headers: Option<HeaderMap>,
69}
70
71#[derive(Debug)]
72pub struct OpenAIClient {
73    api_endpoint: String,
74    api_key: Option<String>,
75    organization: Option<String>,
76    proxy: Option<String>,
77    timeout: Option<u64>,
78    headers: Option<HeaderMap>,
79    pub response_headers: Option<HeaderMap>,
80}
81
82impl OpenAIClientBuilder {
83    pub fn new() -> Self {
84        Self::default()
85    }
86
87    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
88        self.api_key = Some(api_key.into());
89        self
90    }
91
92    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
93        self.api_endpoint = Some(endpoint.into());
94        self
95    }
96
97    pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
98        self.organization = Some(organization.into());
99        self
100    }
101
102    pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
103        self.proxy = Some(proxy.into());
104        self
105    }
106
107    pub fn with_timeout(mut self, timeout: u64) -> Self {
108        self.timeout = Some(timeout);
109        self
110    }
111
112    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
113        let headers = self.headers.get_or_insert_with(HeaderMap::new);
114        headers.insert(
115            HeaderName::from_bytes(key.into().as_bytes()).expect("Invalid header name"),
116            HeaderValue::from_str(&value.into()).expect("Invalid header value"),
117        );
118        self
119    }
120
121    pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
122        let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
123            std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
124        });
125
126        Ok(OpenAIClient {
127            api_endpoint,
128            api_key: self.api_key,
129            organization: self.organization,
130            proxy: self.proxy,
131            timeout: self.timeout,
132            headers: self.headers,
133            response_headers: None,
134        })
135    }
136}
137
138impl OpenAIClient {
139    pub fn builder() -> OpenAIClientBuilder {
140        OpenAIClientBuilder::new()
141    }
142
143    async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
144        let url = self
145            .build_url_with_preserved_query(path)
146            .unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
147
148        let client = Client::builder();
149
150        #[cfg(feature = "rustls")]
151        let client = client.use_rustls_tls();
152
153        let client = if let Some(timeout) = self.timeout {
154            client.timeout(std::time::Duration::from_secs(timeout))
155        } else {
156            client
157        };
158
159        let client = if let Some(proxy) = &self.proxy {
160            client.proxy(reqwest::Proxy::all(proxy).unwrap())
161        } else {
162            client
163        };
164
165        let client = client.build().unwrap();
166
167        let mut request = client.request(method, url);
168
169        if let Some(api_key) = &self.api_key {
170            request = request.header("Authorization", format!("Bearer {api_key}"));
171        }
172
173        if let Some(organization) = &self.organization {
174            request = request.header("openai-organization", organization);
175        }
176
177        if let Some(headers) = &self.headers {
178            for (key, value) in headers {
179                request = request.header(key, value);
180            }
181        }
182
183        if Self::is_beta(path) {
184            request = request.header("OpenAI-Beta", "assistants=v2");
185        }
186
187        request
188    }
189
190    async fn post<T: serde::de::DeserializeOwned>(
191        &mut self,
192        path: &str,
193        body: &impl serde::ser::Serialize,
194    ) -> Result<T, APIError> {
195        let request = self.build_request(Method::POST, path).await;
196        let request = request.json(body);
197        let response = request.send().await?;
198        self.handle_response(response).await
199    }
200
201    async fn get<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
202        let request = self.build_request(Method::GET, path).await;
203        let response = request.send().await?;
204        self.handle_response(response).await
205    }
206
207    async fn get_raw(&self, path: &str) -> Result<Bytes, APIError> {
208        let request = self.build_request(Method::GET, path).await;
209        let response = request.send().await?;
210        Ok(response.bytes().await?)
211    }
212
213    async fn delete<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
214        let request = self.build_request(Method::DELETE, path).await;
215        let response = request.send().await?;
216        self.handle_response(response).await
217    }
218
219    async fn post_form<T: serde::de::DeserializeOwned>(
220        &mut self,
221        path: &str,
222        form: Form,
223    ) -> Result<T, APIError> {
224        let request = self.build_request(Method::POST, path).await;
225        let request = request.multipart(form);
226        let response = request.send().await?;
227        self.handle_response(response).await
228    }
229
230    async fn post_form_raw(&self, path: &str, form: Form) -> Result<Bytes, APIError> {
231        let request = self.build_request(Method::POST, path).await;
232        let request = request.multipart(form);
233        let response = request.send().await?;
234        Ok(response.bytes().await?)
235    }
236
237    async fn handle_response<T: serde::de::DeserializeOwned>(
238        &mut self,
239        response: Response,
240    ) -> Result<T, APIError> {
241        let status = response.status();
242        let headers = response.headers().clone();
243        if status.is_success() {
244            let text = response.text().await.unwrap_or_else(|_| "".to_string());
245            match serde_json::from_str::<T>(&text) {
246                Ok(parsed) => {
247                    self.response_headers = Some(headers);
248                    Ok(parsed)
249                }
250                Err(e) => Err(APIError::CustomError {
251                    message: format!("Failed to parse JSON: {e} / response {text}"),
252                }),
253            }
254        } else {
255            let error_message = response
256                .text()
257                .await
258                .unwrap_or_else(|_| "Unknown error".to_string());
259            Err(APIError::CustomError {
260                message: format!("{status}: {error_message}"),
261            })
262        }
263    }
264
265    pub async fn completion(
266        &mut self,
267        req: CompletionRequest,
268    ) -> Result<CompletionResponse, APIError> {
269        self.post("completions", &req).await
270    }
271
272    pub async fn edit(&mut self, req: EditRequest) -> Result<EditResponse, APIError> {
273        self.post("edits", &req).await
274    }
275
276    pub async fn image_generation(
277        &mut self,
278        req: ImageGenerationRequest,
279    ) -> Result<ImageGenerationResponse, APIError> {
280        self.post("images/generations", &req).await
281    }
282
283    pub async fn image_edit(
284        &mut self,
285        req: ImageEditRequest,
286    ) -> Result<ImageEditResponse, APIError> {
287        self.post("images/edits", &req).await
288    }
289
290    pub async fn image_variation(
291        &mut self,
292        req: ImageVariationRequest,
293    ) -> Result<ImageVariationResponse, APIError> {
294        self.post("images/variations", &req).await
295    }
296
297    pub async fn embedding(
298        &mut self,
299        req: EmbeddingRequest,
300    ) -> Result<EmbeddingResponse, APIError> {
301        self.post("embeddings", &req).await
302    }
303
304    pub async fn file_list(&mut self) -> Result<FileListResponse, APIError> {
305        self.get("files").await
306    }
307
308    pub async fn upload_file(
309        &mut self,
310        req: FileUploadRequest,
311    ) -> Result<FileUploadResponse, APIError> {
312        let form = Self::create_form(&req, "file")?;
313        self.post_form("files", form).await
314    }
315
316    pub async fn delete_file(
317        &mut self,
318        req: FileDeleteRequest,
319    ) -> Result<FileDeleteResponse, APIError> {
320        self.delete(&format!("files/{}", req.file_id)).await
321    }
322
323    pub async fn retrieve_file(
324        &mut self,
325        file_id: String,
326    ) -> Result<FileRetrieveResponse, APIError> {
327        self.get(&format!("files/{file_id}")).await
328    }
329
330    pub async fn retrieve_file_content(&self, file_id: String) -> Result<Bytes, APIError> {
331        self.get_raw(&format!("files/{file_id}/content")).await
332    }
333
334    pub async fn chat_completion(
335        &mut self,
336        req: ChatCompletionRequest,
337    ) -> Result<ChatCompletionResponse, APIError> {
338        self.post("chat/completions", &req).await
339    }
340
341    pub async fn chat_completion_stream(
342        &mut self,
343        req: ChatCompletionStreamRequest,
344    ) -> Result<impl Stream<Item = ChatCompletionStreamResponse>, APIError> {
345        let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
346            message: format!("Failed to serialize request: {}", err),
347        })?;
348
349        if let Some(obj) = payload.as_object_mut() {
350            obj.insert("stream".into(), Value::Bool(true));
351        }
352
353        let request = self.build_request(Method::POST, "chat/completions").await;
354        let request = request.json(&payload);
355        let response = request.send().await?;
356
357        if response.status().is_success() {
358            Ok(ChatCompletionStream {
359                response: Box::pin(response.bytes_stream()),
360                buffer: String::new(),
361                first_chunk: true,
362            })
363        } else {
364            let error_text = response
365                .text()
366                .await
367                .unwrap_or_else(|_| String::from("Unknown error"));
368
369            Err(APIError::CustomError {
370                message: error_text,
371            })
372        }
373    }
374
375    pub async fn audio_transcription(
376        &mut self,
377        req: AudioTranscriptionRequest,
378    ) -> Result<AudioTranscriptionResponse, APIError> {
379        // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
380        if let Some(response_format) = &req.response_format {
381            if response_format != "json" && response_format != "verbose_json" {
382                return Err(APIError::CustomError {
383                    message: "response_format must be either 'json' or 'verbose_json' please use audio_transcription_raw".to_string(),
384                });
385            }
386        }
387        let form: Form;
388        if req.clone().file.is_some() {
389            form = Self::create_form(&req, "file")?;
390        } else if let Some(bytes) = req.clone().bytes {
391            form = Self::create_form_from_bytes(&req, bytes)?;
392        } else {
393            return Err(APIError::CustomError {
394                message: "Either file or bytes must be provided".to_string(),
395            });
396        }
397        self.post_form("audio/transcriptions", form).await
398    }
399
400    pub async fn audio_transcription_raw(
401        &mut self,
402        req: AudioTranscriptionRequest,
403    ) -> Result<Bytes, APIError> {
404        // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
405        if let Some(response_format) = &req.response_format {
406            if response_format != "text" && response_format != "srt" && response_format != "vtt" {
407                return Err(APIError::CustomError {
408                    message: "response_format must be either 'text', 'srt' or 'vtt', please use audio_transcription".to_string(),
409                });
410            }
411        }
412        let form: Form;
413        if req.clone().file.is_some() {
414            form = Self::create_form(&req, "file")?;
415        } else if let Some(bytes) = req.clone().bytes {
416            form = Self::create_form_from_bytes(&req, bytes)?;
417        } else {
418            return Err(APIError::CustomError {
419                message: "Either file or bytes must be provided".to_string(),
420            });
421        }
422        self.post_form_raw("audio/transcriptions", form).await
423    }
424
425    pub async fn audio_translation(
426        &mut self,
427        req: AudioTranslationRequest,
428    ) -> Result<AudioTranslationResponse, APIError> {
429        let form = Self::create_form(&req, "file")?;
430        self.post_form("audio/translations", form).await
431    }
432
433    pub async fn audio_speech(
434        &mut self,
435        req: AudioSpeechRequest,
436    ) -> Result<AudioSpeechResponse, APIError> {
437        let request = self.build_request(Method::POST, "audio/speech").await;
438        let request = request.json(&req);
439        let response = request.send().await?;
440        let headers = response.headers().clone();
441        let bytes = response.bytes().await?;
442        let path = Path::new(req.output.as_str());
443        if let Some(parent) = path.parent() {
444            match create_dir_all(parent) {
445                Ok(_) => {}
446                Err(e) => {
447                    return Err(APIError::CustomError {
448                        message: e.to_string(),
449                    })
450                }
451            }
452        }
453        match File::create(path) {
454            Ok(mut file) => match file.write_all(&bytes) {
455                Ok(_) => {}
456                Err(e) => {
457                    return Err(APIError::CustomError {
458                        message: e.to_string(),
459                    })
460                }
461            },
462            Err(e) => {
463                return Err(APIError::CustomError {
464                    message: e.to_string(),
465                })
466            }
467        }
468
469        Ok(AudioSpeechResponse {
470            result: true,
471            headers: Some(headers),
472        })
473    }
474
475    pub async fn create_fine_tuning_job(
476        &mut self,
477        req: CreateFineTuningJobRequest,
478    ) -> Result<FineTuningJobObject, APIError> {
479        self.post("fine_tuning/jobs", &req).await
480    }
481
482    pub async fn list_fine_tuning_jobs(
483        &mut self,
484    ) -> Result<FineTuningPagination<FineTuningJobObject>, APIError> {
485        self.get("fine_tuning/jobs").await
486    }
487
488    pub async fn list_fine_tuning_job_events(
489        &mut self,
490        req: ListFineTuningJobEventsRequest,
491    ) -> Result<FineTuningPagination<FineTuningJobEvent>, APIError> {
492        self.get(&format!(
493            "fine_tuning/jobs/{}/events",
494            req.fine_tuning_job_id
495        ))
496        .await
497    }
498
499    pub async fn retrieve_fine_tuning_job(
500        &mut self,
501        req: RetrieveFineTuningJobRequest,
502    ) -> Result<FineTuningJobObject, APIError> {
503        self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id))
504            .await
505    }
506
507    pub async fn cancel_fine_tuning_job(
508        &mut self,
509        req: CancelFineTuningJobRequest,
510    ) -> Result<FineTuningJobObject, APIError> {
511        self.post(
512            &format!("fine_tuning/jobs/{}/cancel", req.fine_tuning_job_id),
513            &req,
514        )
515        .await
516    }
517
518    pub async fn create_moderation(
519        &mut self,
520        req: CreateModerationRequest,
521    ) -> Result<CreateModerationResponse, APIError> {
522        self.post("moderations", &req).await
523    }
524
525    pub async fn create_assistant(
526        &mut self,
527        req: AssistantRequest,
528    ) -> Result<AssistantObject, APIError> {
529        self.post("assistants", &req).await
530    }
531
532    pub async fn retrieve_assistant(
533        &mut self,
534        assistant_id: String,
535    ) -> Result<AssistantObject, APIError> {
536        self.get(&format!("assistants/{assistant_id}")).await
537    }
538
539    pub async fn modify_assistant(
540        &mut self,
541        assistant_id: String,
542        req: AssistantRequest,
543    ) -> Result<AssistantObject, APIError> {
544        self.post(&format!("assistants/{assistant_id}"), &req).await
545    }
546
547    pub async fn delete_assistant(
548        &mut self,
549        assistant_id: String,
550    ) -> Result<common::DeletionStatus, APIError> {
551        self.delete(&format!("assistants/{assistant_id}")).await
552    }
553
554    pub async fn list_assistant(
555        &mut self,
556        limit: Option<i64>,
557        order: Option<String>,
558        after: Option<String>,
559        before: Option<String>,
560    ) -> Result<ListAssistant, APIError> {
561        let url = Self::query_params(limit, order, after, before, "assistants".to_string());
562        self.get(&url).await
563    }
564
565    pub async fn create_assistant_file(
566        &mut self,
567        assistant_id: String,
568        req: AssistantFileRequest,
569    ) -> Result<AssistantFileObject, APIError> {
570        self.post(&format!("assistants/{assistant_id}/files"), &req)
571            .await
572    }
573
574    pub async fn retrieve_assistant_file(
575        &mut self,
576        assistant_id: String,
577        file_id: String,
578    ) -> Result<AssistantFileObject, APIError> {
579        self.get(&format!("assistants/{assistant_id}/files/{file_id}"))
580            .await
581    }
582
583    pub async fn delete_assistant_file(
584        &mut self,
585        assistant_id: String,
586        file_id: String,
587    ) -> Result<common::DeletionStatus, APIError> {
588        self.delete(&format!("assistants/{assistant_id}/files/{file_id}"))
589            .await
590    }
591
592    pub async fn list_assistant_file(
593        &mut self,
594        assistant_id: String,
595        limit: Option<i64>,
596        order: Option<String>,
597        after: Option<String>,
598        before: Option<String>,
599    ) -> Result<ListAssistantFile, APIError> {
600        let url = Self::query_params(
601            limit,
602            order,
603            after,
604            before,
605            format!("assistants/{assistant_id}/files"),
606        );
607        self.get(&url).await
608    }
609
610    pub async fn create_thread(
611        &mut self,
612        req: CreateThreadRequest,
613    ) -> Result<ThreadObject, APIError> {
614        self.post("threads", &req).await
615    }
616
617    pub async fn retrieve_thread(&mut self, thread_id: String) -> Result<ThreadObject, APIError> {
618        self.get(&format!("threads/{thread_id}")).await
619    }
620
621    pub async fn modify_thread(
622        &mut self,
623        thread_id: String,
624        req: ModifyThreadRequest,
625    ) -> Result<ThreadObject, APIError> {
626        self.post(&format!("threads/{thread_id}"), &req).await
627    }
628
629    pub async fn delete_thread(
630        &mut self,
631        thread_id: String,
632    ) -> Result<common::DeletionStatus, APIError> {
633        self.delete(&format!("threads/{thread_id}")).await
634    }
635
636    pub async fn create_message(
637        &mut self,
638        thread_id: String,
639        req: CreateMessageRequest,
640    ) -> Result<MessageObject, APIError> {
641        self.post(&format!("threads/{thread_id}/messages"), &req)
642            .await
643    }
644
645    pub async fn retrieve_message(
646        &mut self,
647        thread_id: String,
648        message_id: String,
649    ) -> Result<MessageObject, APIError> {
650        self.get(&format!("threads/{thread_id}/messages/{message_id}"))
651            .await
652    }
653
654    pub async fn modify_message(
655        &mut self,
656        thread_id: String,
657        message_id: String,
658        req: ModifyMessageRequest,
659    ) -> Result<MessageObject, APIError> {
660        self.post(&format!("threads/{thread_id}/messages/{message_id}"), &req)
661            .await
662    }
663
664    pub async fn list_messages(&mut self, thread_id: String) -> Result<ListMessage, APIError> {
665        self.get(&format!("threads/{thread_id}/messages")).await
666    }
667
668    pub async fn retrieve_message_file(
669        &mut self,
670        thread_id: String,
671        message_id: String,
672        file_id: String,
673    ) -> Result<MessageFileObject, APIError> {
674        self.get(&format!(
675            "threads/{thread_id}/messages/{message_id}/files/{file_id}"
676        ))
677        .await
678    }
679
680    pub async fn list_message_file(
681        &mut self,
682        thread_id: String,
683        message_id: String,
684        limit: Option<i64>,
685        order: Option<String>,
686        after: Option<String>,
687        before: Option<String>,
688    ) -> Result<ListMessageFile, APIError> {
689        let url = Self::query_params(
690            limit,
691            order,
692            after,
693            before,
694            format!("threads/{thread_id}/messages/{message_id}/files"),
695        );
696        self.get(&url).await
697    }
698
699    pub async fn create_run(
700        &mut self,
701        thread_id: String,
702        req: CreateRunRequest,
703    ) -> Result<RunObject, APIError> {
704        self.post(&format!("threads/{thread_id}/runs"), &req).await
705    }
706
707    pub async fn retrieve_run(
708        &mut self,
709        thread_id: String,
710        run_id: String,
711    ) -> Result<RunObject, APIError> {
712        self.get(&format!("threads/{thread_id}/runs/{run_id}"))
713            .await
714    }
715
716    pub async fn modify_run(
717        &mut self,
718        thread_id: String,
719        run_id: String,
720        req: ModifyRunRequest,
721    ) -> Result<RunObject, APIError> {
722        self.post(&format!("threads/{thread_id}/runs/{run_id}"), &req)
723            .await
724    }
725
726    pub async fn list_run(
727        &mut self,
728        thread_id: String,
729        limit: Option<i64>,
730        order: Option<String>,
731        after: Option<String>,
732        before: Option<String>,
733    ) -> Result<ListRun, APIError> {
734        let url = Self::query_params(
735            limit,
736            order,
737            after,
738            before,
739            format!("threads/{thread_id}/runs"),
740        );
741        self.get(&url).await
742    }
743
744    pub async fn cancel_run(
745        &mut self,
746        thread_id: String,
747        run_id: String,
748    ) -> Result<RunObject, APIError> {
749        self.post(
750            &format!("threads/{thread_id}/runs/{run_id}/cancel"),
751            &ModifyRunRequest::default(),
752        )
753        .await
754    }
755
756    pub async fn create_thread_and_run(
757        &mut self,
758        req: CreateThreadAndRunRequest,
759    ) -> Result<RunObject, APIError> {
760        self.post("threads/runs", &req).await
761    }
762
763    pub async fn retrieve_run_step(
764        &mut self,
765        thread_id: String,
766        run_id: String,
767        step_id: String,
768    ) -> Result<RunStepObject, APIError> {
769        self.get(&format!(
770            "threads/{thread_id}/runs/{run_id}/steps/{step_id}"
771        ))
772        .await
773    }
774
775    pub async fn list_run_step(
776        &mut self,
777        thread_id: String,
778        run_id: String,
779        limit: Option<i64>,
780        order: Option<String>,
781        after: Option<String>,
782        before: Option<String>,
783    ) -> Result<ListRunStep, APIError> {
784        let url = Self::query_params(
785            limit,
786            order,
787            after,
788            before,
789            format!("threads/{thread_id}/runs/{run_id}/steps"),
790        );
791        self.get(&url).await
792    }
793
794    pub async fn create_batch(
795        &mut self,
796        req: CreateBatchRequest,
797    ) -> Result<BatchResponse, APIError> {
798        self.post("batches", &req).await
799    }
800
801    pub async fn retrieve_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
802        self.get(&format!("batches/{batch_id}")).await
803    }
804
805    pub async fn cancel_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
806        self.post(
807            &format!("batches/{batch_id}/cancel"),
808            &common::EmptyRequestBody {},
809        )
810        .await
811    }
812
813    pub async fn list_batch(
814        &mut self,
815        after: Option<String>,
816        limit: Option<i64>,
817    ) -> Result<ListBatchResponse, APIError> {
818        let url = Self::query_params(limit, None, after, None, "batches".to_string());
819        self.get(&url).await
820    }
821
822    pub async fn list_models(&mut self) -> Result<ModelsResponse, APIError> {
823        self.get("models").await
824    }
825
826    pub async fn retrieve_model(&mut self, model_id: String) -> Result<ModelResponse, APIError> {
827        self.get(&format!("models/{model_id}")).await
828    }
829
830    pub async fn delete_model(
831        &mut self,
832        model_id: String,
833    ) -> Result<common::DeletionStatus, APIError> {
834        self.delete(&format!("models/{model_id}")).await
835    }
836
837    fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
838        let (base, query_opt) = match self.api_endpoint.split_once('?') {
839            Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
840            None => (self.api_endpoint.trim_end_matches('/'), None),
841        };
842
843        let full_path = format!("{}/{}", base, path.trim_start_matches('/'));
844        let mut url = Url::parse(&full_path)?;
845
846        if let Some(query) = query_opt {
847            for (k, v) in url::form_urlencoded::parse(query.as_bytes()) {
848                url.query_pairs_mut().append_pair(&k, &v);
849            }
850        }
851        Ok(url.to_string())
852    }
853
854    fn query_params(
855        limit: Option<i64>,
856        order: Option<String>,
857        after: Option<String>,
858        before: Option<String>,
859        mut url: String,
860    ) -> String {
861        let mut params = vec![];
862        if let Some(limit) = limit {
863            params.push(format!("limit={limit}"));
864        }
865        if let Some(order) = order {
866            params.push(format!("order={order}"));
867        }
868        if let Some(after) = after {
869            params.push(format!("after={after}"));
870        }
871        if let Some(before) = before {
872            params.push(format!("before={before}"));
873        }
874        if !params.is_empty() {
875            url = format!("{}?{}", url, params.join("&"));
876        }
877        url
878    }
879
880    fn is_beta(path: &str) -> bool {
881        path.starts_with("assistants") || path.starts_with("threads")
882    }
883
884    fn create_form<T>(req: &T, file_field: &str) -> Result<Form, APIError>
885    where
886        T: Serialize,
887    {
888        let json = match serde_json::to_value(req) {
889            Ok(json) => json,
890            Err(e) => {
891                return Err(APIError::CustomError {
892                    message: e.to_string(),
893                })
894            }
895        };
896        let file_path = if let Value::Object(map) = &json {
897            map.get(file_field)
898                .and_then(|v| v.as_str())
899                .ok_or(APIError::CustomError {
900                    message: format!("Field '{file_field}' not found or not a string"),
901                })?
902        } else {
903            return Err(APIError::CustomError {
904                message: "Request is not a JSON object".to_string(),
905            });
906        };
907
908        let mut file = match File::open(file_path) {
909            Ok(file) => file,
910            Err(e) => {
911                return Err(APIError::CustomError {
912                    message: e.to_string(),
913                })
914            }
915        };
916        let mut buffer = Vec::new();
917        match file.read_to_end(&mut buffer) {
918            Ok(_) => {}
919            Err(e) => {
920                return Err(APIError::CustomError {
921                    message: e.to_string(),
922                })
923            }
924        }
925
926        let mut form =
927            Form::new().part("file", Part::bytes(buffer).file_name(file_path.to_string()));
928
929        if let Value::Object(map) = json {
930            for (key, value) in map.into_iter() {
931                if key != file_field {
932                    match value {
933                        Value::String(s) => {
934                            form = form.text(key, s);
935                        }
936                        Value::Number(n) => {
937                            form = form.text(key, n.to_string());
938                        }
939                        _ => {}
940                    }
941                }
942            }
943        }
944
945        Ok(form)
946    }
947
948    fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
949    where
950        T: Serialize,
951    {
952        let json = match serde_json::to_value(req) {
953            Ok(json) => json,
954            Err(e) => {
955                return Err(APIError::CustomError {
956                    message: e.to_string(),
957                })
958            }
959        };
960
961        let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
962
963        if let Value::Object(map) = json {
964            for (key, value) in map.into_iter() {
965                match value {
966                    Value::String(s) => {
967                        form = form.text(key, s);
968                    }
969                    Value::Number(n) => {
970                        form = form.text(key, n.to_string());
971                    }
972                    _ => {}
973                }
974            }
975        }
976
977        Ok(form)
978    }
979}