openai_api_rs/v1/
api.rs

1use crate::v1::assistant::{
2    AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, ListAssistant,
3    ListAssistantFile,
4};
5use crate::v1::audio::{
6    AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse,
7    AudioTranslationRequest, AudioTranslationResponse,
8};
9use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse};
10use crate::v1::chat_completion::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
11use crate::v1::chat_completion::chat_completion_stream::{
12    ChatCompletionStream, ChatCompletionStreamRequest, ChatCompletionStreamResponse,
13};
14use crate::v1::common;
15use crate::v1::completion::{CompletionRequest, CompletionResponse};
16use crate::v1::edit::{EditRequest, EditResponse};
17use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse};
18use crate::v1::error::APIError;
19use crate::v1::file::{
20    FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveResponse,
21    FileUploadRequest, FileUploadResponse,
22};
23use crate::v1::fine_tuning::{
24    CancelFineTuningJobRequest, CreateFineTuningJobRequest, FineTuningJobEvent,
25    FineTuningJobObject, FineTuningPagination, ListFineTuningJobEventsRequest,
26    RetrieveFineTuningJobRequest,
27};
28use crate::v1::image::{
29    ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse,
30    ImageVariationRequest, ImageVariationResponse,
31};
32use crate::v1::message::{
33    CreateMessageRequest, ListMessage, ListMessageFile, MessageFileObject, MessageObject,
34    ModifyMessageRequest,
35};
36use crate::v1::model::{ModelResponse, ModelsResponse};
37use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
38use crate::v1::responses::{
39    CountTokensRequest, CountTokensResponse, CreateResponseRequest, ListResponses, ResponseObject,
40};
41use crate::v1::run::{
42    CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject,
43    RunStepObject,
44};
45use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject};
46
47use bytes::Bytes;
48use futures_util::Stream;
49use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
50use reqwest::multipart::{Form, Part};
51use reqwest::{Client, Method, Response};
52use serde::Serialize;
53use serde_json::{to_value, Value};
54use url::Url;
55
56use std::error::Error;
57use std::fs::{create_dir_all, File};
58use std::io::Read;
59use std::io::Write;
60use std::path::Path;
61
62const API_URL_V1: &str = "https://api.openai.com/v1";
63
64#[derive(Default)]
65pub struct OpenAIClientBuilder {
66    api_endpoint: Option<String>,
67    api_key: Option<String>,
68    organization: Option<String>,
69    proxy: Option<String>,
70    timeout: Option<u64>,
71    headers: Option<HeaderMap>,
72}
73
74#[derive(Debug)]
75pub struct OpenAIClient {
76    api_endpoint: String,
77    api_key: Option<String>,
78    organization: Option<String>,
79    proxy: Option<String>,
80    timeout: Option<u64>,
81    headers: Option<HeaderMap>,
82    pub response_headers: Option<HeaderMap>,
83}
84
85impl OpenAIClientBuilder {
86    pub fn new() -> Self {
87        Self::default()
88    }
89
90    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
91        self.api_key = Some(api_key.into());
92        self
93    }
94
95    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
96        self.api_endpoint = Some(endpoint.into());
97        self
98    }
99
100    pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
101        self.organization = Some(organization.into());
102        self
103    }
104
105    pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
106        self.proxy = Some(proxy.into());
107        self
108    }
109
110    pub fn with_timeout(mut self, timeout: u64) -> Self {
111        self.timeout = Some(timeout);
112        self
113    }
114
115    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
116        let headers = self.headers.get_or_insert_with(HeaderMap::new);
117        headers.insert(
118            HeaderName::from_bytes(key.into().as_bytes()).expect("Invalid header name"),
119            HeaderValue::from_str(&value.into()).expect("Invalid header value"),
120        );
121        self
122    }
123
124    pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
125        let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
126            std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
127        });
128
129        Ok(OpenAIClient {
130            api_endpoint,
131            api_key: self.api_key,
132            organization: self.organization,
133            proxy: self.proxy,
134            timeout: self.timeout,
135            headers: self.headers,
136            response_headers: None,
137        })
138    }
139}
140
141impl OpenAIClient {
142    pub fn builder() -> OpenAIClientBuilder {
143        OpenAIClientBuilder::new()
144    }
145
146    async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
147        let url = self
148            .build_url_with_preserved_query(path)
149            .unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
150
151        let client = Client::builder();
152
153        #[cfg(feature = "rustls")]
154        let client = client.use_rustls_tls();
155
156        let client = if let Some(timeout) = self.timeout {
157            client.timeout(std::time::Duration::from_secs(timeout))
158        } else {
159            client
160        };
161
162        let client = if let Some(proxy) = &self.proxy {
163            client.proxy(reqwest::Proxy::all(proxy).unwrap())
164        } else {
165            client
166        };
167
168        let client = client.build().unwrap();
169
170        let mut request = client.request(method, url);
171
172        if let Some(api_key) = &self.api_key {
173            request = request.header("Authorization", format!("Bearer {api_key}"));
174        }
175
176        if let Some(organization) = &self.organization {
177            request = request.header("openai-organization", organization);
178        }
179
180        if let Some(headers) = &self.headers {
181            for (key, value) in headers {
182                request = request.header(key, value);
183            }
184        }
185
186        if Self::is_beta(path) {
187            request = request.header("OpenAI-Beta", "assistants=v2");
188        }
189
190        request
191    }
192
193    async fn post<T: serde::de::DeserializeOwned>(
194        &mut self,
195        path: &str,
196        body: &impl serde::ser::Serialize,
197    ) -> Result<T, APIError> {
198        let request = self.build_request(Method::POST, path).await;
199        let request = request.json(body);
200        let response = request.send().await?;
201        self.handle_response(response).await
202    }
203
204    async fn get<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
205        let request = self.build_request(Method::GET, path).await;
206        let response = request.send().await?;
207        self.handle_response(response).await
208    }
209
210    async fn get_raw(&self, path: &str) -> Result<Bytes, APIError> {
211        let request = self.build_request(Method::GET, path).await;
212        let response = request.send().await?;
213        Ok(response.bytes().await?)
214    }
215
216    async fn delete<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
217        let request = self.build_request(Method::DELETE, path).await;
218        let response = request.send().await?;
219        self.handle_response(response).await
220    }
221
222    async fn post_form<T: serde::de::DeserializeOwned>(
223        &mut self,
224        path: &str,
225        form: Form,
226    ) -> Result<T, APIError> {
227        let request = self.build_request(Method::POST, path).await;
228        let request = request.multipart(form);
229        let response = request.send().await?;
230        self.handle_response(response).await
231    }
232
233    async fn post_form_raw(&self, path: &str, form: Form) -> Result<Bytes, APIError> {
234        let request = self.build_request(Method::POST, path).await;
235        let request = request.multipart(form);
236        let response = request.send().await?;
237        Ok(response.bytes().await?)
238    }
239
240    async fn handle_response<T: serde::de::DeserializeOwned>(
241        &mut self,
242        response: Response,
243    ) -> Result<T, APIError> {
244        let status = response.status();
245        let headers = response.headers().clone();
246        if status.is_success() {
247            let text = response.text().await.unwrap_or_else(|_| "".to_string());
248            match serde_json::from_str::<T>(&text) {
249                Ok(parsed) => {
250                    self.response_headers = Some(headers);
251                    Ok(parsed)
252                }
253                Err(e) => Err(APIError::CustomError {
254                    message: format!("Failed to parse JSON: {e} / response {text}"),
255                }),
256            }
257        } else {
258            let error_message = response
259                .text()
260                .await
261                .unwrap_or_else(|_| "Unknown error".to_string());
262            Err(APIError::CustomError {
263                message: format!("{status}: {error_message}"),
264            })
265        }
266    }
267
268    pub async fn completion(
269        &mut self,
270        req: CompletionRequest,
271    ) -> Result<CompletionResponse, APIError> {
272        self.post("completions", &req).await
273    }
274
275    pub async fn edit(&mut self, req: EditRequest) -> Result<EditResponse, APIError> {
276        self.post("edits", &req).await
277    }
278
279    pub async fn image_generation(
280        &mut self,
281        req: ImageGenerationRequest,
282    ) -> Result<ImageGenerationResponse, APIError> {
283        self.post("images/generations", &req).await
284    }
285
286    pub async fn image_edit(
287        &mut self,
288        req: ImageEditRequest,
289    ) -> Result<ImageEditResponse, APIError> {
290        self.post("images/edits", &req).await
291    }
292
293    pub async fn image_variation(
294        &mut self,
295        req: ImageVariationRequest,
296    ) -> Result<ImageVariationResponse, APIError> {
297        self.post("images/variations", &req).await
298    }
299
300    pub async fn embedding(
301        &mut self,
302        req: EmbeddingRequest,
303    ) -> Result<EmbeddingResponse, APIError> {
304        self.post("embeddings", &req).await
305    }
306
307    pub async fn file_list(&mut self) -> Result<FileListResponse, APIError> {
308        self.get("files").await
309    }
310
311    pub async fn upload_file(
312        &mut self,
313        req: FileUploadRequest,
314    ) -> Result<FileUploadResponse, APIError> {
315        let form = Self::create_form(&req, "file")?;
316        self.post_form("files", form).await
317    }
318
319    pub async fn delete_file(
320        &mut self,
321        req: FileDeleteRequest,
322    ) -> Result<FileDeleteResponse, APIError> {
323        self.delete(&format!("files/{}", req.file_id)).await
324    }
325
326    pub async fn retrieve_file(
327        &mut self,
328        file_id: String,
329    ) -> Result<FileRetrieveResponse, APIError> {
330        self.get(&format!("files/{file_id}")).await
331    }
332
333    pub async fn retrieve_file_content(&self, file_id: String) -> Result<Bytes, APIError> {
334        self.get_raw(&format!("files/{file_id}/content")).await
335    }
336
337    pub async fn chat_completion(
338        &mut self,
339        req: ChatCompletionRequest,
340    ) -> Result<ChatCompletionResponse, APIError> {
341        self.post("chat/completions", &req).await
342    }
343
344    pub async fn chat_completion_stream(
345        &mut self,
346        req: ChatCompletionStreamRequest,
347    ) -> Result<impl Stream<Item = ChatCompletionStreamResponse>, APIError> {
348        let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
349            message: format!("Failed to serialize request: {}", err),
350        })?;
351
352        if let Some(obj) = payload.as_object_mut() {
353            obj.insert("stream".into(), Value::Bool(true));
354        }
355
356        let request = self.build_request(Method::POST, "chat/completions").await;
357        let request = request.json(&payload);
358        let response = request.send().await?;
359
360        if response.status().is_success() {
361            Ok(ChatCompletionStream {
362                response: Box::pin(response.bytes_stream()),
363                buffer: String::new(),
364                first_chunk: true,
365            })
366        } else {
367            let error_text = response
368                .text()
369                .await
370                .unwrap_or_else(|_| String::from("Unknown error"));
371
372            Err(APIError::CustomError {
373                message: error_text,
374            })
375        }
376    }
377
378    pub async fn audio_transcription(
379        &mut self,
380        req: AudioTranscriptionRequest,
381    ) -> Result<AudioTranscriptionResponse, APIError> {
382        // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
383        if let Some(response_format) = &req.response_format {
384            if response_format != "json" && response_format != "verbose_json" {
385                return Err(APIError::CustomError {
386                    message: "response_format must be either 'json' or 'verbose_json' please use audio_transcription_raw".to_string(),
387                });
388            }
389        }
390        let form: Form;
391        if req.clone().file.is_some() {
392            form = Self::create_form(&req, "file")?;
393        } else if let Some(bytes) = req.clone().bytes {
394            form = Self::create_form_from_bytes(&req, bytes)?;
395        } else {
396            return Err(APIError::CustomError {
397                message: "Either file or bytes must be provided".to_string(),
398            });
399        }
400        self.post_form("audio/transcriptions", form).await
401    }
402
403    pub async fn audio_transcription_raw(
404        &mut self,
405        req: AudioTranscriptionRequest,
406    ) -> Result<Bytes, APIError> {
407        // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
408        if let Some(response_format) = &req.response_format {
409            if response_format != "text" && response_format != "srt" && response_format != "vtt" {
410                return Err(APIError::CustomError {
411                    message: "response_format must be either 'text', 'srt' or 'vtt', please use audio_transcription".to_string(),
412                });
413            }
414        }
415        let form: Form;
416        if req.clone().file.is_some() {
417            form = Self::create_form(&req, "file")?;
418        } else if let Some(bytes) = req.clone().bytes {
419            form = Self::create_form_from_bytes(&req, bytes)?;
420        } else {
421            return Err(APIError::CustomError {
422                message: "Either file or bytes must be provided".to_string(),
423            });
424        }
425        self.post_form_raw("audio/transcriptions", form).await
426    }
427
428    pub async fn audio_translation(
429        &mut self,
430        req: AudioTranslationRequest,
431    ) -> Result<AudioTranslationResponse, APIError> {
432        let form = Self::create_form(&req, "file")?;
433        self.post_form("audio/translations", form).await
434    }
435
436    pub async fn audio_speech(
437        &mut self,
438        req: AudioSpeechRequest,
439    ) -> Result<AudioSpeechResponse, APIError> {
440        let request = self.build_request(Method::POST, "audio/speech").await;
441        let request = request.json(&req);
442        let response = request.send().await?;
443        let headers = response.headers().clone();
444        let bytes = response.bytes().await?;
445        let path = Path::new(req.output.as_str());
446        if let Some(parent) = path.parent() {
447            match create_dir_all(parent) {
448                Ok(_) => {}
449                Err(e) => {
450                    return Err(APIError::CustomError {
451                        message: e.to_string(),
452                    })
453                }
454            }
455        }
456        match File::create(path) {
457            Ok(mut file) => match file.write_all(&bytes) {
458                Ok(_) => {}
459                Err(e) => {
460                    return Err(APIError::CustomError {
461                        message: e.to_string(),
462                    })
463                }
464            },
465            Err(e) => {
466                return Err(APIError::CustomError {
467                    message: e.to_string(),
468                })
469            }
470        }
471
472        Ok(AudioSpeechResponse {
473            result: true,
474            headers: Some(headers),
475        })
476    }
477
478    pub async fn create_fine_tuning_job(
479        &mut self,
480        req: CreateFineTuningJobRequest,
481    ) -> Result<FineTuningJobObject, APIError> {
482        self.post("fine_tuning/jobs", &req).await
483    }
484
485    pub async fn list_fine_tuning_jobs(
486        &mut self,
487    ) -> Result<FineTuningPagination<FineTuningJobObject>, APIError> {
488        self.get("fine_tuning/jobs").await
489    }
490
491    pub async fn list_fine_tuning_job_events(
492        &mut self,
493        req: ListFineTuningJobEventsRequest,
494    ) -> Result<FineTuningPagination<FineTuningJobEvent>, APIError> {
495        self.get(&format!(
496            "fine_tuning/jobs/{}/events",
497            req.fine_tuning_job_id
498        ))
499        .await
500    }
501
502    pub async fn retrieve_fine_tuning_job(
503        &mut self,
504        req: RetrieveFineTuningJobRequest,
505    ) -> Result<FineTuningJobObject, APIError> {
506        self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id))
507            .await
508    }
509
510    pub async fn cancel_fine_tuning_job(
511        &mut self,
512        req: CancelFineTuningJobRequest,
513    ) -> Result<FineTuningJobObject, APIError> {
514        self.post(
515            &format!("fine_tuning/jobs/{}/cancel", req.fine_tuning_job_id),
516            &req,
517        )
518        .await
519    }
520
521    pub async fn create_moderation(
522        &mut self,
523        req: CreateModerationRequest,
524    ) -> Result<CreateModerationResponse, APIError> {
525        self.post("moderations", &req).await
526    }
527
528    pub async fn create_assistant(
529        &mut self,
530        req: AssistantRequest,
531    ) -> Result<AssistantObject, APIError> {
532        self.post("assistants", &req).await
533    }
534
535    pub async fn retrieve_assistant(
536        &mut self,
537        assistant_id: String,
538    ) -> Result<AssistantObject, APIError> {
539        self.get(&format!("assistants/{assistant_id}")).await
540    }
541
542    pub async fn modify_assistant(
543        &mut self,
544        assistant_id: String,
545        req: AssistantRequest,
546    ) -> Result<AssistantObject, APIError> {
547        self.post(&format!("assistants/{assistant_id}"), &req).await
548    }
549
550    pub async fn delete_assistant(
551        &mut self,
552        assistant_id: String,
553    ) -> Result<common::DeletionStatus, APIError> {
554        self.delete(&format!("assistants/{assistant_id}")).await
555    }
556
557    pub async fn list_assistant(
558        &mut self,
559        limit: Option<i64>,
560        order: Option<String>,
561        after: Option<String>,
562        before: Option<String>,
563    ) -> Result<ListAssistant, APIError> {
564        let url = Self::query_params(limit, order, after, before, "assistants".to_string());
565        self.get(&url).await
566    }
567
568    pub async fn create_assistant_file(
569        &mut self,
570        assistant_id: String,
571        req: AssistantFileRequest,
572    ) -> Result<AssistantFileObject, APIError> {
573        self.post(&format!("assistants/{assistant_id}/files"), &req)
574            .await
575    }
576
577    pub async fn retrieve_assistant_file(
578        &mut self,
579        assistant_id: String,
580        file_id: String,
581    ) -> Result<AssistantFileObject, APIError> {
582        self.get(&format!("assistants/{assistant_id}/files/{file_id}"))
583            .await
584    }
585
586    pub async fn delete_assistant_file(
587        &mut self,
588        assistant_id: String,
589        file_id: String,
590    ) -> Result<common::DeletionStatus, APIError> {
591        self.delete(&format!("assistants/{assistant_id}/files/{file_id}"))
592            .await
593    }
594
595    pub async fn list_assistant_file(
596        &mut self,
597        assistant_id: String,
598        limit: Option<i64>,
599        order: Option<String>,
600        after: Option<String>,
601        before: Option<String>,
602    ) -> Result<ListAssistantFile, APIError> {
603        let url = Self::query_params(
604            limit,
605            order,
606            after,
607            before,
608            format!("assistants/{assistant_id}/files"),
609        );
610        self.get(&url).await
611    }
612
613    pub async fn create_thread(
614        &mut self,
615        req: CreateThreadRequest,
616    ) -> Result<ThreadObject, APIError> {
617        self.post("threads", &req).await
618    }
619
620    pub async fn retrieve_thread(&mut self, thread_id: String) -> Result<ThreadObject, APIError> {
621        self.get(&format!("threads/{thread_id}")).await
622    }
623
624    pub async fn modify_thread(
625        &mut self,
626        thread_id: String,
627        req: ModifyThreadRequest,
628    ) -> Result<ThreadObject, APIError> {
629        self.post(&format!("threads/{thread_id}"), &req).await
630    }
631
632    pub async fn delete_thread(
633        &mut self,
634        thread_id: String,
635    ) -> Result<common::DeletionStatus, APIError> {
636        self.delete(&format!("threads/{thread_id}")).await
637    }
638
639    pub async fn create_message(
640        &mut self,
641        thread_id: String,
642        req: CreateMessageRequest,
643    ) -> Result<MessageObject, APIError> {
644        self.post(&format!("threads/{thread_id}/messages"), &req)
645            .await
646    }
647
648    pub async fn retrieve_message(
649        &mut self,
650        thread_id: String,
651        message_id: String,
652    ) -> Result<MessageObject, APIError> {
653        self.get(&format!("threads/{thread_id}/messages/{message_id}"))
654            .await
655    }
656
657    pub async fn modify_message(
658        &mut self,
659        thread_id: String,
660        message_id: String,
661        req: ModifyMessageRequest,
662    ) -> Result<MessageObject, APIError> {
663        self.post(&format!("threads/{thread_id}/messages/{message_id}"), &req)
664            .await
665    }
666
667    pub async fn list_messages(&mut self, thread_id: String) -> Result<ListMessage, APIError> {
668        self.get(&format!("threads/{thread_id}/messages")).await
669    }
670
671    pub async fn retrieve_message_file(
672        &mut self,
673        thread_id: String,
674        message_id: String,
675        file_id: String,
676    ) -> Result<MessageFileObject, APIError> {
677        self.get(&format!(
678            "threads/{thread_id}/messages/{message_id}/files/{file_id}"
679        ))
680        .await
681    }
682
683    pub async fn list_message_file(
684        &mut self,
685        thread_id: String,
686        message_id: String,
687        limit: Option<i64>,
688        order: Option<String>,
689        after: Option<String>,
690        before: Option<String>,
691    ) -> Result<ListMessageFile, APIError> {
692        let url = Self::query_params(
693            limit,
694            order,
695            after,
696            before,
697            format!("threads/{thread_id}/messages/{message_id}/files"),
698        );
699        self.get(&url).await
700    }
701
702    pub async fn create_run(
703        &mut self,
704        thread_id: String,
705        req: CreateRunRequest,
706    ) -> Result<RunObject, APIError> {
707        self.post(&format!("threads/{thread_id}/runs"), &req).await
708    }
709
710    pub async fn retrieve_run(
711        &mut self,
712        thread_id: String,
713        run_id: String,
714    ) -> Result<RunObject, APIError> {
715        self.get(&format!("threads/{thread_id}/runs/{run_id}"))
716            .await
717    }
718
719    pub async fn modify_run(
720        &mut self,
721        thread_id: String,
722        run_id: String,
723        req: ModifyRunRequest,
724    ) -> Result<RunObject, APIError> {
725        self.post(&format!("threads/{thread_id}/runs/{run_id}"), &req)
726            .await
727    }
728
729    pub async fn list_run(
730        &mut self,
731        thread_id: String,
732        limit: Option<i64>,
733        order: Option<String>,
734        after: Option<String>,
735        before: Option<String>,
736    ) -> Result<ListRun, APIError> {
737        let url = Self::query_params(
738            limit,
739            order,
740            after,
741            before,
742            format!("threads/{thread_id}/runs"),
743        );
744        self.get(&url).await
745    }
746
747    pub async fn cancel_run(
748        &mut self,
749        thread_id: String,
750        run_id: String,
751    ) -> Result<RunObject, APIError> {
752        self.post(
753            &format!("threads/{thread_id}/runs/{run_id}/cancel"),
754            &ModifyRunRequest::default(),
755        )
756        .await
757    }
758
759    pub async fn create_thread_and_run(
760        &mut self,
761        req: CreateThreadAndRunRequest,
762    ) -> Result<RunObject, APIError> {
763        self.post("threads/runs", &req).await
764    }
765
766    pub async fn retrieve_run_step(
767        &mut self,
768        thread_id: String,
769        run_id: String,
770        step_id: String,
771    ) -> Result<RunStepObject, APIError> {
772        self.get(&format!(
773            "threads/{thread_id}/runs/{run_id}/steps/{step_id}"
774        ))
775        .await
776    }
777
778    pub async fn list_run_step(
779        &mut self,
780        thread_id: String,
781        run_id: String,
782        limit: Option<i64>,
783        order: Option<String>,
784        after: Option<String>,
785        before: Option<String>,
786    ) -> Result<ListRunStep, APIError> {
787        let url = Self::query_params(
788            limit,
789            order,
790            after,
791            before,
792            format!("threads/{thread_id}/runs/{run_id}/steps"),
793        );
794        self.get(&url).await
795    }
796
797    pub async fn create_batch(
798        &mut self,
799        req: CreateBatchRequest,
800    ) -> Result<BatchResponse, APIError> {
801        self.post("batches", &req).await
802    }
803
804    pub async fn retrieve_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
805        self.get(&format!("batches/{batch_id}")).await
806    }
807
808    pub async fn cancel_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
809        self.post(
810            &format!("batches/{batch_id}/cancel"),
811            &common::EmptyRequestBody {},
812        )
813        .await
814    }
815
816    pub async fn list_batch(
817        &mut self,
818        after: Option<String>,
819        limit: Option<i64>,
820    ) -> Result<ListBatchResponse, APIError> {
821        let url = Self::query_params(limit, None, after, None, "batches".to_string());
822        self.get(&url).await
823    }
824
825    // Responses API
826    pub async fn create_response(
827        &mut self,
828        req: CreateResponseRequest,
829    ) -> Result<ResponseObject, APIError> {
830        self.post("responses", &req).await
831    }
832
833    pub async fn retrieve_response(
834        &mut self,
835        response_id: String,
836    ) -> Result<ResponseObject, APIError> {
837        self.get(&format!("responses/{response_id}")).await
838    }
839
840    pub async fn delete_response(
841        &mut self,
842        response_id: String,
843    ) -> Result<common::DeletionStatus, APIError> {
844        self.delete(&format!("responses/{response_id}")).await
845    }
846
847    pub async fn cancel_response(
848        &mut self,
849        response_id: String,
850    ) -> Result<ResponseObject, APIError> {
851        self.post(
852            &format!("responses/{response_id}/cancel"),
853            &common::EmptyRequestBody {},
854        )
855        .await
856    }
857
858    pub async fn list_response_input_items(
859        &mut self,
860        response_id: String,
861        after: Option<String>,
862        limit: Option<i64>,
863        order: Option<String>,
864    ) -> Result<ListResponses, APIError> {
865        let mut url = format!("responses/{}/input_items", response_id);
866        let mut params = vec![];
867        if let Some(after) = after {
868            params.push(format!("after={}", after));
869        }
870        if let Some(limit) = limit {
871            params.push(format!("limit={}", limit));
872        }
873        if let Some(order) = order {
874            params.push(format!("order={}", order));
875        }
876        if !params.is_empty() {
877            url = format!("{}?{}", url, params.join("&"));
878        }
879        self.get(&url).await
880    }
881
882    pub async fn count_response_input_tokens(
883        &mut self,
884        req: CountTokensRequest,
885    ) -> Result<CountTokensResponse, APIError> {
886        self.post("responses/input_tokens", &req).await
887    }
888
889    pub async fn list_models(&mut self) -> Result<ModelsResponse, APIError> {
890        self.get("models").await
891    }
892
893    pub async fn retrieve_model(&mut self, model_id: String) -> Result<ModelResponse, APIError> {
894        self.get(&format!("models/{model_id}")).await
895    }
896
897    pub async fn delete_model(
898        &mut self,
899        model_id: String,
900    ) -> Result<common::DeletionStatus, APIError> {
901        self.delete(&format!("models/{model_id}")).await
902    }
903
904    fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
905        let (base, query_opt) = match self.api_endpoint.split_once('?') {
906            Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
907            None => (self.api_endpoint.trim_end_matches('/'), None),
908        };
909
910        let full_path = format!("{}/{}", base, path.trim_start_matches('/'));
911        let mut url = Url::parse(&full_path)?;
912
913        if let Some(query) = query_opt {
914            for (k, v) in url::form_urlencoded::parse(query.as_bytes()) {
915                url.query_pairs_mut().append_pair(&k, &v);
916            }
917        }
918        Ok(url.to_string())
919    }
920
921    fn query_params(
922        limit: Option<i64>,
923        order: Option<String>,
924        after: Option<String>,
925        before: Option<String>,
926        mut url: String,
927    ) -> String {
928        let mut params = vec![];
929        if let Some(limit) = limit {
930            params.push(format!("limit={limit}"));
931        }
932        if let Some(order) = order {
933            params.push(format!("order={order}"));
934        }
935        if let Some(after) = after {
936            params.push(format!("after={after}"));
937        }
938        if let Some(before) = before {
939            params.push(format!("before={before}"));
940        }
941        if !params.is_empty() {
942            url = format!("{}?{}", url, params.join("&"));
943        }
944        url
945    }
946
947    fn is_beta(path: &str) -> bool {
948        path.starts_with("assistants") || path.starts_with("threads")
949    }
950
951    fn create_form<T>(req: &T, file_field: &str) -> Result<Form, APIError>
952    where
953        T: Serialize,
954    {
955        let json = match serde_json::to_value(req) {
956            Ok(json) => json,
957            Err(e) => {
958                return Err(APIError::CustomError {
959                    message: e.to_string(),
960                })
961            }
962        };
963        let file_path = if let Value::Object(map) = &json {
964            map.get(file_field)
965                .and_then(|v| v.as_str())
966                .ok_or(APIError::CustomError {
967                    message: format!("Field '{file_field}' not found or not a string"),
968                })?
969        } else {
970            return Err(APIError::CustomError {
971                message: "Request is not a JSON object".to_string(),
972            });
973        };
974
975        let mut file = match File::open(file_path) {
976            Ok(file) => file,
977            Err(e) => {
978                return Err(APIError::CustomError {
979                    message: e.to_string(),
980                })
981            }
982        };
983        let mut buffer = Vec::new();
984        match file.read_to_end(&mut buffer) {
985            Ok(_) => {}
986            Err(e) => {
987                return Err(APIError::CustomError {
988                    message: e.to_string(),
989                })
990            }
991        }
992
993        let mut form =
994            Form::new().part("file", Part::bytes(buffer).file_name(file_path.to_string()));
995
996        if let Value::Object(map) = json {
997            for (key, value) in map.into_iter() {
998                if key != file_field {
999                    match value {
1000                        Value::String(s) => {
1001                            form = form.text(key, s);
1002                        }
1003                        Value::Number(n) => {
1004                            form = form.text(key, n.to_string());
1005                        }
1006                        _ => {}
1007                    }
1008                }
1009            }
1010        }
1011
1012        Ok(form)
1013    }
1014
1015    fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
1016    where
1017        T: Serialize,
1018    {
1019        let json = match serde_json::to_value(req) {
1020            Ok(json) => json,
1021            Err(e) => {
1022                return Err(APIError::CustomError {
1023                    message: e.to_string(),
1024                })
1025            }
1026        };
1027
1028        let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
1029
1030        if let Value::Object(map) = json {
1031            for (key, value) in map.into_iter() {
1032                match value {
1033                    Value::String(s) => {
1034                        form = form.text(key, s);
1035                    }
1036                    Value::Number(n) => {
1037                        form = form.text(key, n.to_string());
1038                    }
1039                    _ => {}
1040                }
1041            }
1042        }
1043
1044        Ok(form)
1045    }
1046}