openai_interface/files/
retrieve.rs

//! This module provides functionality for retrieving a single file object from the OpenAI API.
//!
//! It includes the `RetrieveRequest` struct, which implements the `Get` and `GetNoStream` traits
//! to build the request URL and asynchronously fetch the file's metadata as a `FileObject`.
5//!
6//! # Example
7//!
8//! ```rust
9//! use std::sync::LazyLock;
10//!
11//! use openai_interface::files::retrieve::*;
12//! use openai_interface::{
13//!     files::{list::request::ListFilesRequest, retrieve::request::RetrieveRequest},
14//!     rest::get::{Get, GetNoStream},
15//! };
16//! use anyhow::bail;
17//! use futures_util::future::{self};
18//!
19//! const MODELSCOPE_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1/";
//! static MODELSCOPE_KEY: LazyLock<&str> =
21//!     LazyLock::new(|| include_str!("../../keys/modelstudio_domestic_key").trim());
22//!
23//! #[tokio::main]
24//! async fn main() -> Result<(), anyhow::Error> {
25//!     // first get all files
26//!     let list_request = ListFilesRequest {
27//!         limit: Some(5), // avoid rate limit
28//!         ..Default::default()
29//!     };
30//!
31//!     let list_response = list_request
32//!         .get_response(MODELSCOPE_BASE_URL, &MODELSCOPE_KEY)
33//!         .await?;
34//!
35//!     let futures: Vec<_> = list_response
36//!         .data
37//!         .iter()
38//!         .map(|file_object| {
39//!             let file_id = file_object.id.clone();
40//!             let base_url = MODELSCOPE_BASE_URL.to_string();
41//!             let key = MODELSCOPE_KEY.to_string();
42//!             async move {
43//!                 let retrieve_request = RetrieveRequest {
44//!                     file_id: &file_id,
45//!                     ..Default::default()
46//!                 };
47//!                 retrieve_request.get_response(&base_url, &key).await
48//!             }
49//!         })
50//!         .collect();
51//!
52//!     let results = future::join_all(futures).await;
53//!
54//!     for (i, result) in results.iter().enumerate() {
55//!         match result {
56//!             Ok(file_object) => {
57//!                 assert_eq!(&list_response.data[i].id, &file_object.id);
58//!                 assert_eq!(&list_response.data[i].filename, &file_object.filename);
59//!                 assert_eq!(&list_response.data[i].purpose, &file_object.purpose);
60//!                 // assert_eq!(&list_response.data[i], file_object);
61//!             }
62//!             Err(e) => {
63//!                 bail!(
64//!                     "Failed to get response: {e:#}. The file is: index {i}, {:?}",
65//!                     list_response.data[i]
66//!                 )
67//!             }
68//!         }
69//!     }
70//!
71//!     Ok(())
72//! }
73//! ```
74
75pub mod request {
76    use std::collections::HashMap;
77
78    use url::Url;
79
80    use crate::{
81        errors::OapiError,
82        rest::get::{Get, GetNoStream},
83    };
84
85    /// Query parameters for retrieving a file.
86    #[derive(Debug, Clone, Default)]
87    pub struct RetrieveRequest<'a> {
88        pub file_id: &'a str,
89        pub extra_query: HashMap<&'a str, &'a str>,
90    }
91
92    impl<'a> Get for RetrieveRequest<'a> {
93        /// base_url should look like <https://api.openai.com/v1>
94        fn build_url(&self, base_url: &str) -> Result<String, OapiError> {
95            let mut url = Url::parse(base_url.trim_end_matches('/'))
96                .map_err(|err| OapiError::UrlError(err))?;
97            url.path_segments_mut()
98                .map_err(|_| OapiError::UrlError(url::ParseError::RelativeUrlWithoutBase))?
99                .push("files")
100                .push(self.file_id);
101
102            for (key, value) in &self.extra_query {
103                url.query_pairs_mut().append_pair(key, value);
104            }
105
106            Ok(url.to_string())
107        }
108    }
109
110    impl<'a> GetNoStream for RetrieveRequest<'a> {
111        type Response = crate::files::FileObject;
112    }
113}
114
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use super::*;
    use crate::{
        files::{list::request::ListFilesRequest, retrieve::request::RetrieveRequest},
        rest::get::{Get, GetNoStream},
    };
    use anyhow::bail;
    use futures_util::future::{self};

    const MODELSCOPE_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1/";
    // `static`, not `const`: a `const` is inlined at every use site, so each
    // access would build (and initialise) a brand-new `LazyLock` instead of
    // sharing a single cached value.
    static MODELSCOPE_KEY: LazyLock<&str> =
        LazyLock::new(|| include_str!("../../keys/modelstudio_domestic_key").trim());

    /// A trailing slash on the base URL must not leak into the built path.
    #[test]
    fn test_build_url() {
        let request = request::RetrieveRequest {
            file_id: "file_id",
            ..Default::default()
        };
        let url = request.build_url("https://api.openai.com/v1/").unwrap();
        assert_eq!(url, "https://api.openai.com/v1/files/file_id");
    }

    /// Lists a handful of files, retrieves each of them by ID concurrently and
    /// checks that both endpoints agree on `id`, `filename` and `purpose`.
    ///
    /// Sends requests to `MODELSCOPE_BASE_URL` and needs the key file embedded
    /// by `MODELSCOPE_KEY` to exist at compile time.
    #[tokio::test]
    async fn test_retrieve_file() -> Result<(), anyhow::Error> {
        // first get all files
        let list_request = ListFilesRequest {
            limit: Some(5), // avoid rate limit
            ..Default::default()
        };

        let list_response = list_request
            .get_response(MODELSCOPE_BASE_URL, &MODELSCOPE_KEY)
            .await?;

        let futures: Vec<_> = list_response
            .data
            .iter()
            .map(|file_object| {
                // Only the ID has to be owned by the future; the base URL and
                // the key are `'static` and can be borrowed directly.
                let file_id = file_object.id.clone();
                async move {
                    let retrieve_request = RetrieveRequest {
                        file_id: &file_id,
                        ..Default::default()
                    };
                    retrieve_request
                        .get_response(MODELSCOPE_BASE_URL, &MODELSCOPE_KEY)
                        .await
                }
            })
            .collect();

        // `join_all` yields results in the same order as the input futures, so
        // they can be paired with the listed files positionally.
        let results = future::join_all(futures).await;

        for (i, (listed, result)) in list_response.data.iter().zip(&results).enumerate() {
            match result {
                Ok(retrieved) => {
                    // NOTE(review): whole-object equality was disabled in the
                    // original test; presumably some fields differ between the
                    // list and retrieve responses — confirm before enabling.
                    assert_eq!(&listed.id, &retrieved.id);
                    assert_eq!(&listed.filename, &retrieved.filename);
                    assert_eq!(&listed.purpose, &retrieved.purpose);
                }
                Err(e) => {
                    bail!(
                        "Failed to get response: {e:#}. The file is: index {i}, {:?}",
                        listed
                    )
                }
            }
        }

        Ok(())
    }
}