Skip to main content

async_dashscope/operation/file/
mod.rs

1pub mod output;
2pub mod param;
3use crate::{Client, error::DashScopeError};
4
/// Path segment of the DashScope file API, passed to the client's request helpers.
const FILE_PATH: &str = "files";

/// Purpose of an uploaded file, sent as the `purpose` form field when uploading.
///
/// Which file contents are accepted depends on the purpose; for example, `FineTune`
/// only accepts JSONL files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FilePurpose {
    /// Wire value `"fine-tune"`.
    FineTune,
    /// Wire value `"file-extract"`.
    FileExtract,
    /// Wire value `"batch"`.
    Batch,
}

impl FilePurpose {
    /// Returns the string value the API expects for this purpose.
    pub fn as_str(&self) -> &'static str {
        match self {
            FilePurpose::FineTune => "fine-tune",
            FilePurpose::FileExtract => "file-extract",
            FilePurpose::Batch => "batch",
        }
    }
}

impl std::fmt::Display for FilePurpose {
    /// Formats the purpose as its API wire value (same as [`FilePurpose::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
21
22pub struct File<'a> {
23    client: &'a Client,
24}
25
26impl<'a> File<'a> {
27    pub fn new(client: &'a Client) -> Self {
28        Self { client }
29    }
30
31    /// 上传文件
32    /// 
33    /// # 参数
34    /// 
35    /// * `files` - 要上传的文件路径列表,内容和 prupose 强相关, 例如 fine-tune 用途只能上传 jsonl 文件
36    /// * `purpose` - 文件用途,如 "fine-tune"
37    /// * `descriptions` - 文件描述列表(可选)
38    /// 
39    /// # 返回
40    /// 
41    /// 返回上传结果,包含文件信息
42    pub async fn create(
43        &self, 
44        files: Vec<&str>, 
45        purpose: FilePurpose, 
46        descriptions: Option<Vec<&str>>
47    ) -> Result<crate::operation::file::output::FileUploadOutput, DashScopeError> {
48        use reqwest::multipart;
49        use std::path::Path;
50
51        let purpose_str = purpose.as_str().to_string();
52        
53        let file_paths: Vec<String> = {
54            let mut result = Vec::with_capacity(files.len());
55            for p in files {
56                let path = Path::new(p);
57                if !path.exists() {
58                    return Err(DashScopeError::UploadError(format!("File not found: {}", p)));
59                }
60                result.push(p.to_string());
61            }
62            result
63        };
64        
65        let descriptions: Option<Vec<String>> = descriptions.map(|descs| descs.iter().map(|s| s.to_string()).collect());
66
67        self.client.post_multipart(FILE_PATH, move || {
68            let mut form = multipart::Form::new()
69                .text("purpose", purpose_str.clone());
70
71            if let Some(descs) = &descriptions {
72                for desc in descs {
73                    form = form.text("descriptions", desc.clone());
74                }
75            };
76
77            let mut form_with_files = form;
78            for file_path in &file_paths {
79                let path = Path::new(file_path);
80                let file_name = path.file_name()
81                    .and_then(|name| name.to_str())
82                    .unwrap_or(&format!("file_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos()))
83                    .to_string();
84
85                let file_data = std::fs::read(file_path)
86                    .expect("File should exist and be readable since we validated it earlier");
87
88                let part = multipart::Part::bytes(file_data)
89                    .file_name(file_name);
90
91                form_with_files = form_with_files.part("files", part);
92            }
93
94            form_with_files
95        }).await
96    }
97
98    /// 查询文件信息
99    pub async fn retrieve(
100        &self,
101        file_id: &str,
102    ) -> Result<crate::operation::file::output::FileRetrieveOutput, DashScopeError> {
103        // 构建路径
104        let path = format!("files/{}", file_id);
105
106        // 使用客户端的get_with_params方法发送请求,参数为空对象
107        self.client.get_with_params(&path, &()).await
108    }
109
110    /// 查询文件列表
111    pub async fn list(
112        &self,
113        page_no: Option<u64>,
114        page_size: Option<u64>,
115    ) -> Result<crate::operation::file::output::FileListOutput, DashScopeError> {
116        use serde_json::json;
117
118        // 验证参数
119        let validated_page_no = page_no.unwrap_or(1);
120        let validated_page_no = if validated_page_no < 1 {
121            1
122        } else {
123            validated_page_no
124        };
125
126        let validated_page_size = page_size.unwrap_or(10);
127        let validated_page_size = validated_page_size.clamp(1, 100);
128
129        // 构建查询参数
130        let params = json!({
131            "page_no": validated_page_no,
132            "page_size": validated_page_size,
133        });
134
135        // 使用客户端的get方法发送请求
136        self.client.get_with_params("files", &params).await
137    }
138
139    /// 删除文件
140    pub async fn delete(
141        &self,
142        file_id: &str,
143    ) -> Result<crate::operation::file::output::FileDeleteOutput, DashScopeError> {
144        // 构建路径
145        let path = format!("files/{}", file_id);
146
147        // 使用客户端的delete方法发送请求
148        self.client.delete(&path).await
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ConfigBuilder;

    /// Builds a client from `DASHSCOPE_API_KEY`, loading a `.env` file first if one exists.
    ///
    /// Returns `None` (after printing a note) when the key is not available, so these
    /// network-backed tests are skipped instead of panicking in environments without
    /// credentials.
    fn client_from_env() -> Option<Client> {
        let _ = dotenvy::dotenv();
        let api_key = match std::env::var("DASHSCOPE_API_KEY") {
            Ok(key) => key,
            Err(_) => {
                eprintln!("DASHSCOPE_API_KEY not set; skipping DashScope file API test");
                return None;
            }
        };
        let config = ConfigBuilder::default().api_key(api_key).build().unwrap();
        Some(Client::with_config(config))
    }

    /// Smoke test for listing files. API errors are reported, not treated as failures.
    #[tokio::test]
    async fn test_file_operations() {
        let client = match client_from_env() {
            Some(client) => client,
            None => return,
        };
        let file = File::new(&client);

        match file.list(Some(1), Some(10)).await {
            Ok(list_output) => {
                println!("Retrieved {} files", list_output.data.files.len());
            }
            Err(e) => {
                eprintln!("Error listing files: {:?}", e);
            }
        }
    }

    /// Smoke test for retrieving a single file: lists one file to obtain an id, then
    /// retrieves it. API errors are reported, not treated as failures.
    #[tokio::test]
    async fn test_file_retrieve() {
        let client = match client_from_env() {
            Some(client) => client,
            None => return,
        };
        let file = File::new(&client);

        match file.list(Some(1), Some(1)).await {
            Ok(list_output) => {
                if let Some(first_file) = list_output.data.files.first() {
                    match file.retrieve(&first_file.file_id).await {
                        Ok(file_info) => {
                            println!("Retrieved file: {}", file_info.data.name);
                        }
                        Err(e) => {
                            eprintln!("Error retrieving file: {:?}", e);
                        }
                    }
                } else {
                    println!("No files found to retrieve");
                }
            }
            Err(e) => {
                eprintln!("Error listing files for retrieve test: {:?}", e);
            }
        }
    }
}