1use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5use url::form_urlencoded;
6
7use super::messages::ErrorResponse;
8use crate::{Error, Result};
9
/// Beta flag sent in the `anthropic-beta` header of every Files API request
/// to opt in to the Files API.
const FILES_API_BETA: &str = "files-api-2025-04-14";
11
/// Metadata for a file in the Files API, as returned by `FilesClient::upload`,
/// `FilesClient::get`, and (per entry) `FilesClient::list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct File {
    /// File identifier (e.g. `file_abc123`), the value passed as `file_id`
    /// to the other `FilesClient` methods — presumably; confirm against API docs.
    pub id: String,
    /// Object type tag, (de)serialized under the JSON key `type` (e.g. `"file"`).
    #[serde(rename = "type")]
    pub file_type: String,
    /// Name of the file.
    pub filename: String,
    /// MIME type of the file content (e.g. `application/pdf`).
    pub mime_type: String,
    /// Size of the file content in bytes.
    pub size_bytes: u64,
    /// Creation timestamp kept as the raw string sent by the server
    /// (e.g. `2025-01-01T00:00:00Z`); it is not parsed here.
    pub created_at: String,
    /// Whether the file can be downloaded — presumably via
    /// `FilesClient::download`. Defaults to `false` when absent from the payload.
    #[serde(default)]
    pub downloadable: bool,
}
24
/// A file to send through `FilesClient::upload`.
///
/// Create one with [`UploadFileRequest::from_bytes`] or
/// [`UploadFileRequest::from_path`], then optionally override the multipart
/// filename with [`UploadFileRequest::filename`].
#[derive(Debug, Clone)]
pub struct UploadFileRequest {
    /// Where the uploaded content comes from.
    pub data: FileData,
    /// Filename for the multipart part. When `None`, the upload derives one
    /// from the path (for [`FileData::Path`]) or falls back to `"file"`.
    pub filename: Option<String>,
}

impl UploadFileRequest {
    /// Builds a request that uploads the in-memory `data`, labelled with
    /// `mime_type`. No filename is set.
    pub fn from_bytes(data: Vec<u8>, mime_type: impl Into<String>) -> Self {
        let mime_type: String = mime_type.into();
        let payload = FileData::Bytes { data, mime_type };
        Self {
            data: payload,
            filename: None,
        }
    }

    /// Builds a request that uploads the file at `path` when sent. The file is
    /// not touched until the upload happens. No filename is set.
    pub fn from_path(path: impl Into<PathBuf>) -> Self {
        let location: PathBuf = path.into();
        Self {
            data: FileData::Path(location),
            filename: None,
        }
    }

    /// Sets (or replaces) the filename sent with the upload.
    pub fn filename(self, filename: impl Into<String>) -> Self {
        Self {
            filename: Some(filename.into()),
            ..self
        }
    }
}

/// Source of the bytes for an upload.
#[derive(Debug, Clone)]
pub enum FileData {
    /// In-memory content with an explicitly supplied MIME type.
    Bytes { data: Vec<u8>, mime_type: String },
    /// A file on disk. The upload reads it and guesses the MIME type from the
    /// path, falling back to `application/octet-stream`.
    Path(PathBuf),
}
60
/// One page of results from `FilesClient::list`.
#[derive(Debug, Clone, Deserialize)]
pub struct FileListResponse {
    /// Files on this page.
    pub data: Vec<File>,
    /// Whether more files exist after this page.
    pub has_more: bool,
    /// ID of the first file on this page, if any.
    pub first_id: Option<String>,
    /// ID of the last file on this page, if any. `FilesClient::list_all` passes
    /// it as the `after_id` cursor for the next page.
    pub last_id: Option<String>,
}
67}
68
69pub struct FileDownload {
70 response: reqwest::Response,
71 pub content_type: String,
72 pub content_length: Option<u64>,
73}
74
75impl FileDownload {
76 pub fn into_response(self) -> reqwest::Response {
77 self.response
78 }
79
80 pub fn bytes_stream(
81 self,
82 ) -> impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> {
83 self.response.bytes_stream()
84 }
85
86 pub async fn bytes(self) -> Result<bytes::Bytes> {
87 self.response.bytes().await.map_err(Error::Network)
88 }
89}
90
/// Client for the Files API endpoints under `/v1/files`.
///
/// Borrows the shared `super::Client` for the HTTP transport, configuration
/// (API version), and authentication headers.
pub struct FilesClient<'a> {
    client: &'a super::Client,
}
94
95impl<'a> FilesClient<'a> {
96 pub fn new(client: &'a super::Client) -> Self {
97 Self { client }
98 }
99
100 fn base_url(&self) -> &str {
101 self.client.adapter().base_url()
102 }
103
104 fn api_version(&self) -> &str {
105 &self.client.config().api_version
106 }
107
108 fn build_url(&self, path: &str) -> String {
109 format!("{}/v1/files{}", self.base_url(), path)
110 }
111
112 async fn build_request(&self, method: reqwest::Method, url: &str) -> reqwest::RequestBuilder {
113 if let Err(e) = self.client.adapter().ensure_fresh_credentials().await {
114 tracing::debug!("Proactive credential refresh failed: {}", e);
115 }
116
117 let req = self.client.http().request(method, url);
118 self.client
119 .adapter()
120 .apply_auth_headers(req)
121 .await
122 .header("anthropic-version", self.api_version())
123 .header("anthropic-beta", FILES_API_BETA)
124 }
125
126 pub async fn upload(&self, request: UploadFileRequest) -> Result<File> {
127 let url = self.build_url("");
128
129 let (data, mime_type, filename) = match request.data {
130 FileData::Bytes { data, mime_type } => {
131 let filename = request.filename.unwrap_or_else(|| "file".to_string());
132 (data, mime_type, filename)
133 }
134 FileData::Path(path) => {
135 let filename = request
136 .filename
137 .or_else(|| path.file_name().and_then(|n| n.to_str()).map(String::from))
138 .unwrap_or_else(|| "file".to_string());
139
140 let data = tokio::fs::read(&path).await.map_err(Error::Io)?;
141
142 let mime_type = mime_guess::from_path(&path)
143 .first_or_octet_stream()
144 .to_string();
145
146 (data, mime_type, filename)
147 }
148 };
149
150 let part = reqwest::multipart::Part::bytes(data)
151 .file_name(filename)
152 .mime_str(&mime_type)
153 .map_err(|e| Error::Config(e.to_string()))?;
154
155 let form = reqwest::multipart::Form::new().part("file", part);
156
157 let response = self
158 .build_request(reqwest::Method::POST, &url)
159 .await
160 .multipart(form)
161 .send()
162 .await
163 .map_err(Error::Network)?;
164
165 self.handle_response(response).await
166 }
167
168 pub async fn get(&self, file_id: &str) -> Result<File> {
169 let url = self.build_url(&format!("/{}", file_id));
170 let response = self
171 .build_request(reqwest::Method::GET, &url)
172 .await
173 .send()
174 .await
175 .map_err(Error::Network)?;
176 self.handle_response(response).await
177 }
178
179 pub async fn download(&self, file_id: &str) -> Result<FileDownload> {
180 let url = self.build_url(&format!("/{}/content", file_id));
181 let response = self
182 .build_request(reqwest::Method::GET, &url)
183 .await
184 .send()
185 .await
186 .map_err(Error::Network)?;
187
188 if !response.status().is_success() {
189 let status = response.status().as_u16();
190 let error: ErrorResponse = response.json().await.map_err(Error::Network)?;
191 return Err(error.into_error(status));
192 }
193
194 let content_type = response
195 .headers()
196 .get(reqwest::header::CONTENT_TYPE)
197 .and_then(|v| v.to_str().ok())
198 .unwrap_or("application/octet-stream")
199 .to_string();
200
201 let content_length = response
202 .headers()
203 .get(reqwest::header::CONTENT_LENGTH)
204 .and_then(|v| v.to_str().ok())
205 .and_then(|v| v.parse().ok());
206
207 Ok(FileDownload {
208 response,
209 content_type,
210 content_length,
211 })
212 }
213
214 pub async fn download_bytes(&self, file_id: &str) -> Result<Vec<u8>> {
215 let download = self.download(file_id).await?;
216 let bytes = download.bytes().await?;
217 Ok(bytes.to_vec())
218 }
219
220 pub async fn delete(&self, file_id: &str) -> Result<()> {
221 let url = self.build_url(&format!("/{}", file_id));
222 let response = self
223 .build_request(reqwest::Method::DELETE, &url)
224 .await
225 .send()
226 .await
227 .map_err(Error::Network)?;
228 self.handle_response::<serde_json::Value>(response).await?;
229 Ok(())
230 }
231
232 pub async fn list(
233 &self,
234 limit: Option<u32>,
235 after_id: Option<&str>,
236 ) -> Result<FileListResponse> {
237 let mut url = self.build_url("");
238
239 let mut query_params: Vec<(&str, String)> = Vec::new();
240 if let Some(limit) = limit {
241 query_params.push(("limit", limit.to_string()));
242 }
243 if let Some(after_id) = after_id {
244 query_params.push(("after_id", after_id.to_string()));
245 }
246 if !query_params.is_empty() {
247 let encoded: String = form_urlencoded::Serializer::new(String::new())
248 .extend_pairs(query_params.iter().map(|(k, v)| (*k, v.as_str())))
249 .finish();
250 url = format!("{}?{}", url, encoded);
251 }
252
253 let response = self
254 .build_request(reqwest::Method::GET, &url)
255 .await
256 .send()
257 .await
258 .map_err(Error::Network)?;
259 self.handle_response(response).await
260 }
261
262 pub async fn list_all(&self) -> Result<Vec<File>> {
263 let mut all_files = Vec::new();
264 let mut after_id: Option<String> = None;
265
266 loop {
267 let response = self.list(Some(100), after_id.as_deref()).await?;
268 all_files.extend(response.data);
269
270 if !response.has_more {
271 break;
272 }
273 after_id = response.last_id;
274 }
275
276 Ok(all_files)
277 }
278
279 async fn handle_response<T: serde::de::DeserializeOwned>(
280 &self,
281 response: reqwest::Response,
282 ) -> Result<T> {
283 if !response.status().is_success() {
284 let status = response.status().as_u16();
285 let error: ErrorResponse = response.json().await.map_err(Error::Network)?;
286 return Err(error.into_error(status));
287 }
288
289 response.json().await.map_err(Error::Network)
290 }
291}
292
#[cfg(test)]
mod tests {
    //! Unit tests for upload-request construction and response deserialization.
    use super::*;

    #[test]
    fn test_upload_request_from_bytes() {
        let req = UploadFileRequest::from_bytes(vec![1, 2, 3], "image/png");
        assert_eq!(req.filename, None);
    }

    #[test]
    fn test_upload_request_with_filename() {
        let req = UploadFileRequest::from_bytes(vec![1, 2, 3], "image/png")
            .filename("test.png");
        assert_eq!(req.filename.as_deref(), Some("test.png"));
    }

    #[test]
    fn test_file_deserialization() {
        const RAW: &str = r#"{
            "id": "file_abc123",
            "type": "file",
            "filename": "test.pdf",
            "mime_type": "application/pdf",
            "size_bytes": 1024,
            "created_at": "2025-01-01T00:00:00Z",
            "downloadable": false
        }"#;
        let parsed: File = serde_json::from_str(RAW).unwrap();
        assert_eq!(parsed.id, "file_abc123");
        assert_eq!(parsed.filename, "test.pdf");
    }

    #[test]
    fn test_file_list_response_deserialization() {
        const RAW: &str = r#"{
            "data": [],
            "has_more": false,
            "first_id": null,
            "last_id": null
        }"#;
        let page: FileListResponse = serde_json::from_str(RAW).unwrap();
        assert!(page.data.is_empty());
        assert!(!page.has_more);
    }
}