// openai_api_rs/v1/api.rs

1use crate::v1::assistant::{
2    AssistantFileObject, AssistantFileRequest, AssistantObject, AssistantRequest, ListAssistant,
3    ListAssistantFile,
4};
5use crate::v1::audio::{
6    AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse,
7    AudioTranslationRequest, AudioTranslationResponse,
8};
9use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse};
10use crate::v1::chat_completion::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
11use crate::v1::chat_completion::chat_completion_stream::{
12    ChatCompletionStream, ChatCompletionStreamRequest, ChatCompletionStreamResponse,
13};
14use crate::v1::common;
15use crate::v1::completion::{CompletionRequest, CompletionResponse};
16use crate::v1::edit::{EditRequest, EditResponse};
17use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse};
18use crate::v1::error::APIError;
19use crate::v1::file::{
20    FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveResponse,
21    FileUploadRequest, FileUploadResponse,
22};
23use crate::v1::fine_tuning::{
24    CancelFineTuningJobRequest, CreateFineTuningJobRequest, FineTuningJobEvent,
25    FineTuningJobObject, FineTuningPagination, ListFineTuningJobEventsRequest,
26    RetrieveFineTuningJobRequest,
27};
28use crate::v1::image::{
29    ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse,
30    ImageVariationRequest, ImageVariationResponse,
31};
32use crate::v1::message::{
33    CreateMessageRequest, ListMessage, ListMessageFile, MessageFileObject, MessageObject,
34    ModifyMessageRequest,
35};
36use crate::v1::model::{ModelResponse, ModelsResponse};
37use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
38use crate::v1::responses::responses::{
39    CountTokensRequest, CountTokensResponse, CreateResponseRequest, ListResponses, ResponseObject,
40};
41use crate::v1::responses::responses_stream::{
42    CreateResponseStreamRequest, ResponseStream, ResponseStreamResponse,
43};
44use crate::v1::run::{
45    CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject,
46    RunStepObject,
47};
48use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject};
49
50use bytes::Bytes;
51use futures_util::Stream;
52use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
53use reqwest::multipart::{Form, Part};
54use reqwest::{Client, Method, Response};
55use serde::Serialize;
56use serde_json::{to_value, Value};
57use url::Url;
58
59use std::error::Error;
60use std::fs::{create_dir_all, File};
61use std::io::Read;
62use std::io::Write;
63use std::path::Path;
64
/// Default base URL of the OpenAI REST API, version 1.
///
/// `OpenAIClientBuilder::build` falls back to this value only when no endpoint was set
/// explicitly and the `OPENAI_API_BASE` environment variable is absent.
const API_URL_V1: &str = "https://api.openai.com/v1";
66
67#[derive(Default)]
68pub struct OpenAIClientBuilder {
69    api_endpoint: Option<String>,
70    api_key: Option<String>,
71    organization: Option<String>,
72    proxy: Option<String>,
73    timeout: Option<u64>,
74    headers: Option<HeaderMap>,
75}
76
77#[derive(Debug)]
78pub struct OpenAIClient {
79    api_endpoint: String,
80    api_key: Option<String>,
81    organization: Option<String>,
82    proxy: Option<String>,
83    timeout: Option<u64>,
84    headers: Option<HeaderMap>,
85    pub response_headers: Option<HeaderMap>,
86}
87
88impl OpenAIClientBuilder {
89    pub fn new() -> Self {
90        Self::default()
91    }
92
93    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
94        self.api_key = Some(api_key.into());
95        self
96    }
97
98    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
99        self.api_endpoint = Some(endpoint.into());
100        self
101    }
102
103    pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
104        self.organization = Some(organization.into());
105        self
106    }
107
108    pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
109        self.proxy = Some(proxy.into());
110        self
111    }
112
113    pub fn with_timeout(mut self, timeout: u64) -> Self {
114        self.timeout = Some(timeout);
115        self
116    }
117
118    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
119        let headers = self.headers.get_or_insert_with(HeaderMap::new);
120        headers.insert(
121            HeaderName::from_bytes(key.into().as_bytes()).expect("Invalid header name"),
122            HeaderValue::from_str(&value.into()).expect("Invalid header value"),
123        );
124        self
125    }
126
127    pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
128        let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
129            std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
130        });
131
132        Ok(OpenAIClient {
133            api_endpoint,
134            api_key: self.api_key,
135            organization: self.organization,
136            proxy: self.proxy,
137            timeout: self.timeout,
138            headers: self.headers,
139            response_headers: None,
140        })
141    }
142}
143
144impl OpenAIClient {
145    pub fn builder() -> OpenAIClientBuilder {
146        OpenAIClientBuilder::new()
147    }
148
149    async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
150        let url = self
151            .build_url_with_preserved_query(path)
152            .unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
153
154        let client = Client::builder();
155
156        #[cfg(feature = "rustls")]
157        let client = client.use_rustls_tls();
158
159        let client = if let Some(timeout) = self.timeout {
160            client.timeout(std::time::Duration::from_secs(timeout))
161        } else {
162            client
163        };
164
165        let client = if let Some(proxy) = &self.proxy {
166            client.proxy(reqwest::Proxy::all(proxy).unwrap())
167        } else {
168            client
169        };
170
171        let client = client.build().unwrap();
172
173        let mut request = client.request(method, url);
174
175        if let Some(api_key) = &self.api_key {
176            request = request.header("Authorization", format!("Bearer {api_key}"));
177        }
178
179        if let Some(organization) = &self.organization {
180            request = request.header("openai-organization", organization);
181        }
182
183        if let Some(headers) = &self.headers {
184            for (key, value) in headers {
185                request = request.header(key, value);
186            }
187        }
188
189        if Self::is_beta(path) {
190            request = request.header("OpenAI-Beta", "assistants=v2");
191        }
192
193        request
194    }
195
196    async fn post<T: serde::de::DeserializeOwned>(
197        &mut self,
198        path: &str,
199        body: &impl serde::ser::Serialize,
200    ) -> Result<T, APIError> {
201        let request = self.build_request(Method::POST, path).await;
202        let request = request.json(body);
203        let response = request.send().await?;
204        self.handle_response(response).await
205    }
206
207    async fn get<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
208        let request = self.build_request(Method::GET, path).await;
209        let response = request.send().await?;
210        self.handle_response(response).await
211    }
212
213    async fn get_raw(&self, path: &str) -> Result<Bytes, APIError> {
214        let request = self.build_request(Method::GET, path).await;
215        let response = request.send().await?;
216        Ok(response.bytes().await?)
217    }
218
219    async fn delete<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
220        let request = self.build_request(Method::DELETE, path).await;
221        let response = request.send().await?;
222        self.handle_response(response).await
223    }
224
225    async fn post_form<T: serde::de::DeserializeOwned>(
226        &mut self,
227        path: &str,
228        form: Form,
229    ) -> Result<T, APIError> {
230        let request = self.build_request(Method::POST, path).await;
231        let request = request.multipart(form);
232        let response = request.send().await?;
233        self.handle_response(response).await
234    }
235
236    async fn post_form_raw(&self, path: &str, form: Form) -> Result<Bytes, APIError> {
237        let request = self.build_request(Method::POST, path).await;
238        let request = request.multipart(form);
239        let response = request.send().await?;
240        Ok(response.bytes().await?)
241    }
242
243    async fn handle_response<T: serde::de::DeserializeOwned>(
244        &mut self,
245        response: Response,
246    ) -> Result<T, APIError> {
247        let status = response.status();
248        let headers = response.headers().clone();
249        if status.is_success() {
250            let text = response.text().await.unwrap_or_else(|_| "".to_string());
251            match serde_json::from_str::<T>(&text) {
252                Ok(parsed) => {
253                    self.response_headers = Some(headers);
254                    Ok(parsed)
255                }
256                Err(e) => Err(APIError::CustomError {
257                    message: format!("Failed to parse JSON: {e} / response {text}"),
258                }),
259            }
260        } else {
261            let error_message = response
262                .text()
263                .await
264                .unwrap_or_else(|_| "Unknown error".to_string());
265            Err(APIError::CustomError {
266                message: format!("{status}: {error_message}"),
267            })
268        }
269    }
270
271    pub async fn completion(
272        &mut self,
273        req: CompletionRequest,
274    ) -> Result<CompletionResponse, APIError> {
275        self.post("completions", &req).await
276    }
277
278    pub async fn edit(&mut self, req: EditRequest) -> Result<EditResponse, APIError> {
279        self.post("edits", &req).await
280    }
281
282    pub async fn image_generation(
283        &mut self,
284        req: ImageGenerationRequest,
285    ) -> Result<ImageGenerationResponse, APIError> {
286        self.post("images/generations", &req).await
287    }
288
289    pub async fn image_edit(
290        &mut self,
291        req: ImageEditRequest,
292    ) -> Result<ImageEditResponse, APIError> {
293        self.post("images/edits", &req).await
294    }
295
296    pub async fn image_variation(
297        &mut self,
298        req: ImageVariationRequest,
299    ) -> Result<ImageVariationResponse, APIError> {
300        self.post("images/variations", &req).await
301    }
302
303    pub async fn embedding(
304        &mut self,
305        req: EmbeddingRequest,
306    ) -> Result<EmbeddingResponse, APIError> {
307        self.post("embeddings", &req).await
308    }
309
310    pub async fn file_list(&mut self) -> Result<FileListResponse, APIError> {
311        self.get("files").await
312    }
313
314    pub async fn upload_file(
315        &mut self,
316        req: FileUploadRequest,
317    ) -> Result<FileUploadResponse, APIError> {
318        let form = Self::create_form(&req, "file")?;
319        self.post_form("files", form).await
320    }
321
322    pub async fn delete_file(
323        &mut self,
324        req: FileDeleteRequest,
325    ) -> Result<FileDeleteResponse, APIError> {
326        self.delete(&format!("files/{}", req.file_id)).await
327    }
328
329    pub async fn retrieve_file(
330        &mut self,
331        file_id: String,
332    ) -> Result<FileRetrieveResponse, APIError> {
333        self.get(&format!("files/{file_id}")).await
334    }
335
336    pub async fn retrieve_file_content(&self, file_id: String) -> Result<Bytes, APIError> {
337        self.get_raw(&format!("files/{file_id}/content")).await
338    }
339
340    pub async fn chat_completion(
341        &mut self,
342        req: ChatCompletionRequest,
343    ) -> Result<ChatCompletionResponse, APIError> {
344        self.post("chat/completions", &req).await
345    }
346
347    pub async fn chat_completion_stream(
348        &mut self,
349        req: ChatCompletionStreamRequest,
350    ) -> Result<impl Stream<Item = ChatCompletionStreamResponse>, APIError> {
351        let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
352            message: format!("Failed to serialize request: {}", err),
353        })?;
354
355        if let Some(obj) = payload.as_object_mut() {
356            obj.insert("stream".into(), Value::Bool(true));
357        }
358
359        let request = self.build_request(Method::POST, "chat/completions").await;
360        let request = request.json(&payload);
361        let response = request.send().await?;
362
363        if response.status().is_success() {
364            Ok(ChatCompletionStream {
365                response: Box::pin(response.bytes_stream()),
366                buffer: String::new(),
367                first_chunk: true,
368            })
369        } else {
370            let error_text = response
371                .text()
372                .await
373                .unwrap_or_else(|_| String::from("Unknown error"));
374
375            Err(APIError::CustomError {
376                message: error_text,
377            })
378        }
379    }
380
381    pub async fn audio_transcription(
382        &mut self,
383        req: AudioTranscriptionRequest,
384    ) -> Result<AudioTranscriptionResponse, APIError> {
385        // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
386        if let Some(response_format) = &req.response_format {
387            if response_format != "json" && response_format != "verbose_json" {
388                return Err(APIError::CustomError {
389                    message: "response_format must be either 'json' or 'verbose_json' please use audio_transcription_raw".to_string(),
390                });
391            }
392        }
393        let form: Form;
394        if req.clone().file.is_some() {
395            form = Self::create_form(&req, "file")?;
396        } else if let Some(bytes) = req.clone().bytes {
397            form = Self::create_form_from_bytes(&req, bytes)?;
398        } else {
399            return Err(APIError::CustomError {
400                message: "Either file or bytes must be provided".to_string(),
401            });
402        }
403        self.post_form("audio/transcriptions", form).await
404    }
405
406    pub async fn audio_transcription_raw(
407        &mut self,
408        req: AudioTranscriptionRequest,
409    ) -> Result<Bytes, APIError> {
410        // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
411        if let Some(response_format) = &req.response_format {
412            if response_format != "text" && response_format != "srt" && response_format != "vtt" {
413                return Err(APIError::CustomError {
414                    message: "response_format must be either 'text', 'srt' or 'vtt', please use audio_transcription".to_string(),
415                });
416            }
417        }
418        let form: Form;
419        if req.clone().file.is_some() {
420            form = Self::create_form(&req, "file")?;
421        } else if let Some(bytes) = req.clone().bytes {
422            form = Self::create_form_from_bytes(&req, bytes)?;
423        } else {
424            return Err(APIError::CustomError {
425                message: "Either file or bytes must be provided".to_string(),
426            });
427        }
428        self.post_form_raw("audio/transcriptions", form).await
429    }
430
431    pub async fn audio_translation(
432        &mut self,
433        req: AudioTranslationRequest,
434    ) -> Result<AudioTranslationResponse, APIError> {
435        let form = Self::create_form(&req, "file")?;
436        self.post_form("audio/translations", form).await
437    }
438
439    pub async fn audio_speech(
440        &mut self,
441        req: AudioSpeechRequest,
442    ) -> Result<AudioSpeechResponse, APIError> {
443        let request = self.build_request(Method::POST, "audio/speech").await;
444        let request = request.json(&req);
445        let response = request.send().await?;
446        let headers = response.headers().clone();
447        let bytes = response.bytes().await?;
448        let path = Path::new(req.output.as_str());
449        if let Some(parent) = path.parent() {
450            match create_dir_all(parent) {
451                Ok(_) => {}
452                Err(e) => {
453                    return Err(APIError::CustomError {
454                        message: e.to_string(),
455                    })
456                }
457            }
458        }
459        match File::create(path) {
460            Ok(mut file) => match file.write_all(&bytes) {
461                Ok(_) => {}
462                Err(e) => {
463                    return Err(APIError::CustomError {
464                        message: e.to_string(),
465                    })
466                }
467            },
468            Err(e) => {
469                return Err(APIError::CustomError {
470                    message: e.to_string(),
471                })
472            }
473        }
474
475        Ok(AudioSpeechResponse {
476            result: true,
477            headers: Some(headers),
478        })
479    }
480
481    pub async fn create_fine_tuning_job(
482        &mut self,
483        req: CreateFineTuningJobRequest,
484    ) -> Result<FineTuningJobObject, APIError> {
485        self.post("fine_tuning/jobs", &req).await
486    }
487
488    pub async fn list_fine_tuning_jobs(
489        &mut self,
490    ) -> Result<FineTuningPagination<FineTuningJobObject>, APIError> {
491        self.get("fine_tuning/jobs").await
492    }
493
494    pub async fn list_fine_tuning_job_events(
495        &mut self,
496        req: ListFineTuningJobEventsRequest,
497    ) -> Result<FineTuningPagination<FineTuningJobEvent>, APIError> {
498        self.get(&format!(
499            "fine_tuning/jobs/{}/events",
500            req.fine_tuning_job_id
501        ))
502        .await
503    }
504
505    pub async fn retrieve_fine_tuning_job(
506        &mut self,
507        req: RetrieveFineTuningJobRequest,
508    ) -> Result<FineTuningJobObject, APIError> {
509        self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id))
510            .await
511    }
512
513    pub async fn cancel_fine_tuning_job(
514        &mut self,
515        req: CancelFineTuningJobRequest,
516    ) -> Result<FineTuningJobObject, APIError> {
517        self.post(
518            &format!("fine_tuning/jobs/{}/cancel", req.fine_tuning_job_id),
519            &req,
520        )
521        .await
522    }
523
524    pub async fn create_moderation(
525        &mut self,
526        req: CreateModerationRequest,
527    ) -> Result<CreateModerationResponse, APIError> {
528        self.post("moderations", &req).await
529    }
530
531    pub async fn create_assistant(
532        &mut self,
533        req: AssistantRequest,
534    ) -> Result<AssistantObject, APIError> {
535        self.post("assistants", &req).await
536    }
537
538    pub async fn retrieve_assistant(
539        &mut self,
540        assistant_id: String,
541    ) -> Result<AssistantObject, APIError> {
542        self.get(&format!("assistants/{assistant_id}")).await
543    }
544
545    pub async fn modify_assistant(
546        &mut self,
547        assistant_id: String,
548        req: AssistantRequest,
549    ) -> Result<AssistantObject, APIError> {
550        self.post(&format!("assistants/{assistant_id}"), &req).await
551    }
552
553    pub async fn delete_assistant(
554        &mut self,
555        assistant_id: String,
556    ) -> Result<common::DeletionStatus, APIError> {
557        self.delete(&format!("assistants/{assistant_id}")).await
558    }
559
560    pub async fn list_assistant(
561        &mut self,
562        limit: Option<i64>,
563        order: Option<String>,
564        after: Option<String>,
565        before: Option<String>,
566    ) -> Result<ListAssistant, APIError> {
567        let url = Self::query_params(limit, order, after, before, "assistants".to_string());
568        self.get(&url).await
569    }
570
571    pub async fn create_assistant_file(
572        &mut self,
573        assistant_id: String,
574        req: AssistantFileRequest,
575    ) -> Result<AssistantFileObject, APIError> {
576        self.post(&format!("assistants/{assistant_id}/files"), &req)
577            .await
578    }
579
580    pub async fn retrieve_assistant_file(
581        &mut self,
582        assistant_id: String,
583        file_id: String,
584    ) -> Result<AssistantFileObject, APIError> {
585        self.get(&format!("assistants/{assistant_id}/files/{file_id}"))
586            .await
587    }
588
589    pub async fn delete_assistant_file(
590        &mut self,
591        assistant_id: String,
592        file_id: String,
593    ) -> Result<common::DeletionStatus, APIError> {
594        self.delete(&format!("assistants/{assistant_id}/files/{file_id}"))
595            .await
596    }
597
598    pub async fn list_assistant_file(
599        &mut self,
600        assistant_id: String,
601        limit: Option<i64>,
602        order: Option<String>,
603        after: Option<String>,
604        before: Option<String>,
605    ) -> Result<ListAssistantFile, APIError> {
606        let url = Self::query_params(
607            limit,
608            order,
609            after,
610            before,
611            format!("assistants/{assistant_id}/files"),
612        );
613        self.get(&url).await
614    }
615
616    pub async fn create_thread(
617        &mut self,
618        req: CreateThreadRequest,
619    ) -> Result<ThreadObject, APIError> {
620        self.post("threads", &req).await
621    }
622
623    pub async fn retrieve_thread(&mut self, thread_id: String) -> Result<ThreadObject, APIError> {
624        self.get(&format!("threads/{thread_id}")).await
625    }
626
627    pub async fn modify_thread(
628        &mut self,
629        thread_id: String,
630        req: ModifyThreadRequest,
631    ) -> Result<ThreadObject, APIError> {
632        self.post(&format!("threads/{thread_id}"), &req).await
633    }
634
635    pub async fn delete_thread(
636        &mut self,
637        thread_id: String,
638    ) -> Result<common::DeletionStatus, APIError> {
639        self.delete(&format!("threads/{thread_id}")).await
640    }
641
642    pub async fn create_message(
643        &mut self,
644        thread_id: String,
645        req: CreateMessageRequest,
646    ) -> Result<MessageObject, APIError> {
647        self.post(&format!("threads/{thread_id}/messages"), &req)
648            .await
649    }
650
651    pub async fn retrieve_message(
652        &mut self,
653        thread_id: String,
654        message_id: String,
655    ) -> Result<MessageObject, APIError> {
656        self.get(&format!("threads/{thread_id}/messages/{message_id}"))
657            .await
658    }
659
660    pub async fn modify_message(
661        &mut self,
662        thread_id: String,
663        message_id: String,
664        req: ModifyMessageRequest,
665    ) -> Result<MessageObject, APIError> {
666        self.post(&format!("threads/{thread_id}/messages/{message_id}"), &req)
667            .await
668    }
669
670    pub async fn list_messages(&mut self, thread_id: String) -> Result<ListMessage, APIError> {
671        self.get(&format!("threads/{thread_id}/messages")).await
672    }
673
674    pub async fn retrieve_message_file(
675        &mut self,
676        thread_id: String,
677        message_id: String,
678        file_id: String,
679    ) -> Result<MessageFileObject, APIError> {
680        self.get(&format!(
681            "threads/{thread_id}/messages/{message_id}/files/{file_id}"
682        ))
683        .await
684    }
685
686    pub async fn list_message_file(
687        &mut self,
688        thread_id: String,
689        message_id: String,
690        limit: Option<i64>,
691        order: Option<String>,
692        after: Option<String>,
693        before: Option<String>,
694    ) -> Result<ListMessageFile, APIError> {
695        let url = Self::query_params(
696            limit,
697            order,
698            after,
699            before,
700            format!("threads/{thread_id}/messages/{message_id}/files"),
701        );
702        self.get(&url).await
703    }
704
705    pub async fn create_run(
706        &mut self,
707        thread_id: String,
708        req: CreateRunRequest,
709    ) -> Result<RunObject, APIError> {
710        self.post(&format!("threads/{thread_id}/runs"), &req).await
711    }
712
713    pub async fn retrieve_run(
714        &mut self,
715        thread_id: String,
716        run_id: String,
717    ) -> Result<RunObject, APIError> {
718        self.get(&format!("threads/{thread_id}/runs/{run_id}"))
719            .await
720    }
721
722    pub async fn modify_run(
723        &mut self,
724        thread_id: String,
725        run_id: String,
726        req: ModifyRunRequest,
727    ) -> Result<RunObject, APIError> {
728        self.post(&format!("threads/{thread_id}/runs/{run_id}"), &req)
729            .await
730    }
731
732    pub async fn list_run(
733        &mut self,
734        thread_id: String,
735        limit: Option<i64>,
736        order: Option<String>,
737        after: Option<String>,
738        before: Option<String>,
739    ) -> Result<ListRun, APIError> {
740        let url = Self::query_params(
741            limit,
742            order,
743            after,
744            before,
745            format!("threads/{thread_id}/runs"),
746        );
747        self.get(&url).await
748    }
749
750    pub async fn cancel_run(
751        &mut self,
752        thread_id: String,
753        run_id: String,
754    ) -> Result<RunObject, APIError> {
755        self.post(
756            &format!("threads/{thread_id}/runs/{run_id}/cancel"),
757            &ModifyRunRequest::default(),
758        )
759        .await
760    }
761
762    pub async fn create_thread_and_run(
763        &mut self,
764        req: CreateThreadAndRunRequest,
765    ) -> Result<RunObject, APIError> {
766        self.post("threads/runs", &req).await
767    }
768
769    pub async fn retrieve_run_step(
770        &mut self,
771        thread_id: String,
772        run_id: String,
773        step_id: String,
774    ) -> Result<RunStepObject, APIError> {
775        self.get(&format!(
776            "threads/{thread_id}/runs/{run_id}/steps/{step_id}"
777        ))
778        .await
779    }
780
781    pub async fn list_run_step(
782        &mut self,
783        thread_id: String,
784        run_id: String,
785        limit: Option<i64>,
786        order: Option<String>,
787        after: Option<String>,
788        before: Option<String>,
789    ) -> Result<ListRunStep, APIError> {
790        let url = Self::query_params(
791            limit,
792            order,
793            after,
794            before,
795            format!("threads/{thread_id}/runs/{run_id}/steps"),
796        );
797        self.get(&url).await
798    }
799
800    pub async fn create_batch(
801        &mut self,
802        req: CreateBatchRequest,
803    ) -> Result<BatchResponse, APIError> {
804        self.post("batches", &req).await
805    }
806
807    pub async fn retrieve_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
808        self.get(&format!("batches/{batch_id}")).await
809    }
810
811    pub async fn cancel_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
812        self.post(
813            &format!("batches/{batch_id}/cancel"),
814            &common::EmptyRequestBody {},
815        )
816        .await
817    }
818
819    pub async fn list_batch(
820        &mut self,
821        after: Option<String>,
822        limit: Option<i64>,
823    ) -> Result<ListBatchResponse, APIError> {
824        let url = Self::query_params(limit, None, after, None, "batches".to_string());
825        self.get(&url).await
826    }
827
828    // Responses API
829    pub async fn create_response(
830        &mut self,
831        req: CreateResponseRequest,
832    ) -> Result<ResponseObject, APIError> {
833        self.post("responses", &req).await
834    }
835
836    pub async fn create_response_stream(
837        &mut self,
838        req: CreateResponseStreamRequest,
839    ) -> Result<impl Stream<Item = ResponseStreamResponse>, APIError> {
840        let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
841            message: format!("Failed to serialize request: {}", err),
842        })?;
843
844        if let Some(obj) = payload.as_object_mut() {
845            obj.insert("stream".into(), Value::Bool(true));
846        }
847
848        let request = self.build_request(Method::POST, "responses").await;
849        let request = request.json(&payload);
850        let response = request.send().await?;
851
852        if response.status().is_success() {
853            Ok(ResponseStream {
854                response: Box::pin(response.bytes_stream()),
855                buffer: String::new(),
856                first_chunk: true,
857            })
858        } else {
859            let error_text = response
860                .text()
861                .await
862                .unwrap_or_else(|_| String::from("Unknown error"));
863
864            Err(APIError::CustomError {
865                message: error_text,
866            })
867        }
868    }
869
870    pub async fn retrieve_response(
871        &mut self,
872        response_id: String,
873    ) -> Result<ResponseObject, APIError> {
874        self.get(&format!("responses/{response_id}")).await
875    }
876
877    pub async fn delete_response(
878        &mut self,
879        response_id: String,
880    ) -> Result<common::DeletionStatus, APIError> {
881        self.delete(&format!("responses/{response_id}")).await
882    }
883
884    pub async fn cancel_response(
885        &mut self,
886        response_id: String,
887    ) -> Result<ResponseObject, APIError> {
888        self.post(
889            &format!("responses/{response_id}/cancel"),
890            &common::EmptyRequestBody {},
891        )
892        .await
893    }
894
895    pub async fn list_response_input_items(
896        &mut self,
897        response_id: String,
898        after: Option<String>,
899        limit: Option<i64>,
900        order: Option<String>,
901    ) -> Result<ListResponses, APIError> {
902        let mut url = format!("responses/{}/input_items", response_id);
903        let mut params = vec![];
904        if let Some(after) = after {
905            params.push(format!("after={}", after));
906        }
907        if let Some(limit) = limit {
908            params.push(format!("limit={}", limit));
909        }
910        if let Some(order) = order {
911            params.push(format!("order={}", order));
912        }
913        if !params.is_empty() {
914            url = format!("{}?{}", url, params.join("&"));
915        }
916        self.get(&url).await
917    }
918
919    pub async fn count_response_input_tokens(
920        &mut self,
921        req: CountTokensRequest,
922    ) -> Result<CountTokensResponse, APIError> {
923        self.post("responses/input_tokens", &req).await
924    }
925
926    pub async fn list_models(&mut self) -> Result<ModelsResponse, APIError> {
927        self.get("models").await
928    }
929
930    pub async fn retrieve_model(&mut self, model_id: String) -> Result<ModelResponse, APIError> {
931        self.get(&format!("models/{model_id}")).await
932    }
933
934    pub async fn delete_model(
935        &mut self,
936        model_id: String,
937    ) -> Result<common::DeletionStatus, APIError> {
938        self.delete(&format!("models/{model_id}")).await
939    }
940
941    fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
942        let (base, query_opt) = match self.api_endpoint.split_once('?') {
943            Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
944            None => (self.api_endpoint.trim_end_matches('/'), None),
945        };
946
947        let full_path = format!("{}/{}", base, path.trim_start_matches('/'));
948        let mut url = Url::parse(&full_path)?;
949
950        if let Some(query) = query_opt {
951            for (k, v) in url::form_urlencoded::parse(query.as_bytes()) {
952                url.query_pairs_mut().append_pair(&k, &v);
953            }
954        }
955        Ok(url.to_string())
956    }
957
958    fn query_params(
959        limit: Option<i64>,
960        order: Option<String>,
961        after: Option<String>,
962        before: Option<String>,
963        mut url: String,
964    ) -> String {
965        let mut params = vec![];
966        if let Some(limit) = limit {
967            params.push(format!("limit={limit}"));
968        }
969        if let Some(order) = order {
970            params.push(format!("order={order}"));
971        }
972        if let Some(after) = after {
973            params.push(format!("after={after}"));
974        }
975        if let Some(before) = before {
976            params.push(format!("before={before}"));
977        }
978        if !params.is_empty() {
979            url = format!("{}?{}", url, params.join("&"));
980        }
981        url
982    }
983
984    fn is_beta(path: &str) -> bool {
985        path.starts_with("assistants") || path.starts_with("threads")
986    }
987
988    fn create_form<T>(req: &T, file_field: &str) -> Result<Form, APIError>
989    where
990        T: Serialize,
991    {
992        let json = match serde_json::to_value(req) {
993            Ok(json) => json,
994            Err(e) => {
995                return Err(APIError::CustomError {
996                    message: e.to_string(),
997                })
998            }
999        };
1000        let file_path = if let Value::Object(map) = &json {
1001            map.get(file_field)
1002                .and_then(|v| v.as_str())
1003                .ok_or(APIError::CustomError {
1004                    message: format!("Field '{file_field}' not found or not a string"),
1005                })?
1006        } else {
1007            return Err(APIError::CustomError {
1008                message: "Request is not a JSON object".to_string(),
1009            });
1010        };
1011
1012        let mut file = match File::open(file_path) {
1013            Ok(file) => file,
1014            Err(e) => {
1015                return Err(APIError::CustomError {
1016                    message: e.to_string(),
1017                })
1018            }
1019        };
1020        let mut buffer = Vec::new();
1021        match file.read_to_end(&mut buffer) {
1022            Ok(_) => {}
1023            Err(e) => {
1024                return Err(APIError::CustomError {
1025                    message: e.to_string(),
1026                })
1027            }
1028        }
1029
1030        let mut form =
1031            Form::new().part("file", Part::bytes(buffer).file_name(file_path.to_string()));
1032
1033        if let Value::Object(map) = json {
1034            for (key, value) in map.into_iter() {
1035                if key != file_field {
1036                    match value {
1037                        Value::String(s) => {
1038                            form = form.text(key, s);
1039                        }
1040                        Value::Number(n) => {
1041                            form = form.text(key, n.to_string());
1042                        }
1043                        _ => {}
1044                    }
1045                }
1046            }
1047        }
1048
1049        Ok(form)
1050    }
1051
1052    fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
1053    where
1054        T: Serialize,
1055    {
1056        let json = match serde_json::to_value(req) {
1057            Ok(json) => json,
1058            Err(e) => {
1059                return Err(APIError::CustomError {
1060                    message: e.to_string(),
1061                })
1062            }
1063        };
1064
1065        let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
1066
1067        if let Value::Object(map) = json {
1068            for (key, value) in map.into_iter() {
1069                match value {
1070                    Value::String(s) => {
1071                        form = form.text(key, s);
1072                    }
1073                    Value::Number(n) => {
1074                        form = form.text(key, n.to_string());
1075                    }
1076                    _ => {}
1077                }
1078            }
1079        }
1080
1081        Ok(form)
1082    }
1083}