async_dashscope/operation/file/
mod.rs

1pub mod output;
2pub mod param;
3use crate::{Client, error::DashScopeError};
4
/// Path of the file endpoint handed to the `Client` request helpers in this module.
const FILE_PATH: &str = "files";
6
/// Intended use of an uploaded file.
///
/// The value is sent to the service as the `purpose` form field of an upload request,
/// using the string returned by [`FilePurpose::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilePurpose {
    /// Sent as `"fine-tune"`.
    FineTune,
    /// Sent as `"file-extract"`.
    FileExtract,
    /// Sent as `"batch"`.
    Batch,
}

impl FilePurpose {
    /// Returns the string used for this purpose in upload requests.
    pub fn as_str(&self) -> &'static str {
        match self {
            FilePurpose::FineTune => "fine-tune",
            FilePurpose::FileExtract => "file-extract",
            FilePurpose::Batch => "batch",
        }
    }
}
21
22pub struct File<'a> {
23    client: &'a Client,
24}
25
26impl<'a> File<'a> {
27    pub fn new(client: &'a Client) -> Self {
28        Self { client }
29    }
30
31    /// 上传文件
32    /// 
33    /// # 参数
34    /// 
35    /// * `files` - 要上传的文件路径列表,内容和 prupose 强相关, 例如 fine-tune 用途只能上传 jsonl 文件
36    /// * `purpose` - 文件用途,如 "fine-tune"
37    /// * `descriptions` - 文件描述列表(可选)
38    /// 
39    /// # 返回
40    /// 
41    /// 返回上传结果,包含文件信息
42    pub async fn create(
43        &self, 
44        files: Vec<&str>, 
45        purpose: FilePurpose, 
46        descriptions: Option<Vec<&str>>
47    ) -> Result<crate::operation::file::output::FileUploadOutput, DashScopeError> {
48        use reqwest::multipart;
49        use std::path::Path;
50
51        // 将参数转换为可以在闭包中使用的类型
52        let purpose_str = purpose.as_str().to_string();
53        
54        // 验证文件路径存在并转换为PathBuf
55        let file_paths: Vec<String> = {
56            let mut result = Vec::with_capacity(files.len());
57            for p in files {
58                let path = Path::new(p);
59                if !path.exists() {
60                    return Err(DashScopeError::UploadError(format!("File not found: {}", p)));
61                }
62                result.push(p.to_string());
63            }
64            result
65        };
66        
67        let descriptions: Option<Vec<String>> = descriptions.map(|descs| descs.iter().map(|s| s.to_string()).collect());
68
69        // 使用客户端的post_multipart方法发送请求,自动处理认证和重试
70        self.client.post_multipart(FILE_PATH, move || {
71            let mut form = multipart::Form::new()
72                .text("purpose", purpose_str.clone());
73
74            // 添加描述信息
75            if let Some(descs) = &descriptions {
76                for desc in descs {
77                    form = form.text("descriptions", desc.clone());
78                }
79            };
80
81            // 添加文件 - 每次调用闭包时重新读取文件(用于重试)
82            let mut form_with_files = form;
83            for file_path in &file_paths {
84                let path = Path::new(file_path);
85                let file_name = path.file_name()
86                    .and_then(|name| name.to_str())
87                    .unwrap_or(&format!("file_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos()))
88                    .to_string();
89
90                let file_data = std::fs::read(file_path)
91                    .unwrap_or_else(|e| {
92                        std::panic::resume_unwind(Box::new(DashScopeError::UploadError(
93                            format!("Failed to read file {}: {}", file_path, e)
94                        )));
95                    });
96
97                let part = multipart::Part::bytes(file_data)
98                    .file_name(file_name);
99
100                form_with_files = form_with_files.part("files", part);
101            }
102
103            form_with_files
104        }).await
105    }
106
107    /// 查询文件信息
108    pub async fn retrieve(
109        &self,
110        file_id: &str,
111    ) -> Result<crate::operation::file::output::FileRetrieveOutput, DashScopeError> {
112        // 构建路径
113        let path = format!("files/{}", file_id);
114
115        // 使用客户端的get_with_params方法发送请求,参数为空对象
116        self.client.get_with_params(&path, &()).await
117    }
118
119    /// 查询文件列表
120    pub async fn list(
121        &self,
122        page_no: Option<u64>,
123        page_size: Option<u64>,
124    ) -> Result<crate::operation::file::output::FileListOutput, DashScopeError> {
125        use serde_json::json;
126
127        // 验证参数
128        let validated_page_no = page_no.unwrap_or(1);
129        let validated_page_no = if validated_page_no < 1 {
130            1
131        } else {
132            validated_page_no
133        };
134
135        let validated_page_size = page_size.unwrap_or(10);
136        let validated_page_size = validated_page_size.clamp(1, 100);
137
138        // 构建查询参数
139        let params = json!({
140            "page_no": validated_page_no,
141            "page_size": validated_page_size,
142        });
143
144        // 使用客户端的get方法发送请求
145        self.client.get_with_params("files", &params).await
146    }
147
148    /// 删除文件
149    pub async fn delete(
150        &self,
151        file_id: &str,
152    ) -> Result<crate::operation::file::output::FileDeleteOutput, DashScopeError> {
153        // 构建路径
154        let path = format!("files/{}", file_id);
155
156        // 使用客户端的delete方法发送请求
157        self.client.delete(&path).await
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ConfigBuilder;

    /// Builds a client from `DASHSCOPE_API_KEY`, loading a `.env` file first if one exists.
    ///
    /// Panics when the variable is not set, which fails the calling test.
    fn client_from_env() -> Client {
        let _ = dotenvy::dotenv();
        let api_key = std::env::var("DASHSCOPE_API_KEY").expect("DASHSCOPE_API_KEY must be set");
        let config = ConfigBuilder::default().api_key(api_key).build().unwrap();
        Client::with_config(config)
    }

    /// Lists the first page of files and prints the outcome; request errors are only logged.
    #[tokio::test]
    async fn test_file_operations() {
        let client = client_from_env();
        let file = File::new(&client);

        match file.list(Some(1), Some(10)).await {
            Ok(listing) => println!("Retrieved {} files", listing.data.files.len()),
            Err(err) => eprintln!("Error listing files: {:?}", err),
        }
    }

    /// Fetches one file id from the listing and retrieves that file; errors are only logged.
    #[tokio::test]
    async fn test_file_retrieve() {
        let client = client_from_env();
        let file = File::new(&client);

        // A file id is needed first, so ask the listing for a single entry.
        let listing = match file.list(Some(1), Some(1)).await {
            Ok(listing) => listing,
            Err(err) => {
                eprintln!("Error listing files for retrieve test: {:?}", err);
                return;
            }
        };

        match listing.data.files.first() {
            Some(first_file) => match file.retrieve(&first_file.file_id).await {
                Ok(file_info) => println!("Retrieved file: {}", file_info.data.name),
                Err(err) => eprintln!("Error retrieving file: {:?}", err),
            },
            None => println!("No files found to retrieve"),
        }
    }
}
218}