Skip to main content

claude_agent/client/
files.rs

1//! Files API client for managing uploaded files.
2
3use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5use url::form_urlencoded;
6
7use super::messages::ErrorResponse;
8use crate::{Error, Result};
9
/// Beta flag sent in the `anthropic-beta` header on every Files API request.
const FILES_API_BETA: &str = "files-api-2025-04-14";
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct File {
14    pub id: String,
15    #[serde(rename = "type")]
16    pub file_type: String,
17    pub filename: String,
18    pub mime_type: String,
19    pub size_bytes: u64,
20    pub created_at: String,
21    #[serde(default)]
22    pub downloadable: bool,
23}
24
/// Parameters for [`FilesClient::upload`]: the file contents plus an optional
/// filename to report to the API.
#[derive(Debug, Clone)]
pub struct UploadFileRequest {
    /// Where the contents come from: in-memory bytes or a path read at upload time.
    pub data: FileData,
    /// Filename sent with the multipart part. When `None`, the upload falls back
    /// to the path's final component (for [`FileData::Path`]) or to `"file"`.
    pub filename: Option<String>,
}

impl UploadFileRequest {
    /// Creates a request that uploads `data` with an explicit MIME type.
    ///
    /// The MIME type is not validated here; an unparseable value is rejected
    /// when the upload builds its multipart part.
    #[must_use]
    pub fn from_bytes(data: Vec<u8>, mime_type: impl Into<String>) -> Self {
        Self {
            data: FileData::Bytes {
                data,
                mime_type: mime_type.into(),
            },
            filename: None,
        }
    }

    /// Creates a request that reads the file at `path` when the upload runs.
    ///
    /// Nothing is read here; the file is opened and its MIME type guessed
    /// from the extension only at upload time.
    #[must_use]
    pub fn from_path(path: impl Into<PathBuf>) -> Self {
        Self {
            data: FileData::Path(path.into()),
            filename: None,
        }
    }

    /// Sets the filename reported to the API, overriding any default.
    // `#[must_use]`: this consumes `self`, so ignoring the result drops the request.
    #[must_use]
    pub fn filename(mut self, filename: impl Into<String>) -> Self {
        self.filename = Some(filename.into());
        self
    }
}

/// Source of the bytes for an [`UploadFileRequest`].
#[derive(Debug, Clone)]
pub enum FileData {
    /// In-memory contents with a caller-supplied MIME type.
    Bytes { data: Vec<u8>, mime_type: String },
    /// Path of a file to read when the upload runs.
    Path(PathBuf),
}
60
61#[derive(Debug, Clone, Deserialize)]
62pub struct FileListResponse {
63    pub data: Vec<File>,
64    pub has_more: bool,
65    pub first_id: Option<String>,
66    pub last_id: Option<String>,
67}
68
69pub struct FileDownload {
70    response: reqwest::Response,
71    pub content_type: String,
72    pub content_length: Option<u64>,
73}
74
75impl FileDownload {
76    pub fn into_response(self) -> reqwest::Response {
77        self.response
78    }
79
80    pub fn bytes_stream(
81        self,
82    ) -> impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> {
83        self.response.bytes_stream()
84    }
85
86    pub async fn bytes(self) -> Result<bytes::Bytes> {
87        self.response.bytes().await.map_err(Error::Network)
88    }
89}
90
91pub struct FilesClient<'a> {
92    client: &'a super::Client,
93}
94
95impl<'a> FilesClient<'a> {
96    pub fn new(client: &'a super::Client) -> Self {
97        Self { client }
98    }
99
100    fn base_url(&self) -> &str {
101        self.client.adapter().base_url()
102    }
103
104    fn api_version(&self) -> &str {
105        &self.client.config().api_version
106    }
107
108    fn build_url(&self, path: &str) -> String {
109        format!("{}/v1/files{}", self.base_url(), path)
110    }
111
112    async fn build_request(&self, method: reqwest::Method, url: &str) -> reqwest::RequestBuilder {
113        if let Err(e) = self.client.adapter().ensure_fresh_credentials().await {
114            tracing::debug!("Proactive credential refresh failed: {}", e);
115        }
116
117        let req = self.client.http().request(method, url);
118        self.client
119            .adapter()
120            .apply_auth_headers(req)
121            .await
122            .header("anthropic-version", self.api_version())
123            .header("anthropic-beta", FILES_API_BETA)
124    }
125
126    pub async fn upload(&self, request: UploadFileRequest) -> Result<File> {
127        let url = self.build_url("");
128
129        let (data, mime_type, filename) = match request.data {
130            FileData::Bytes { data, mime_type } => {
131                let filename = request.filename.unwrap_or_else(|| "file".to_string());
132                (data, mime_type, filename)
133            }
134            FileData::Path(path) => {
135                let filename = request
136                    .filename
137                    .or_else(|| path.file_name().and_then(|n| n.to_str()).map(String::from))
138                    .unwrap_or_else(|| "file".to_string());
139
140                let data = tokio::fs::read(&path).await.map_err(Error::Io)?;
141
142                let mime_type = mime_guess::from_path(&path)
143                    .first_or_octet_stream()
144                    .to_string();
145
146                (data, mime_type, filename)
147            }
148        };
149
150        let part = reqwest::multipart::Part::bytes(data)
151            .file_name(filename)
152            .mime_str(&mime_type)
153            .map_err(|e| Error::Config(e.to_string()))?;
154
155        let form = reqwest::multipart::Form::new().part("file", part);
156
157        let response = self
158            .build_request(reqwest::Method::POST, &url)
159            .await
160            .multipart(form)
161            .send()
162            .await
163            .map_err(Error::Network)?;
164
165        self.handle_response(response).await
166    }
167
168    pub async fn get(&self, file_id: &str) -> Result<File> {
169        let url = self.build_url(&format!("/{}", file_id));
170        let response = self
171            .build_request(reqwest::Method::GET, &url)
172            .await
173            .send()
174            .await
175            .map_err(Error::Network)?;
176        self.handle_response(response).await
177    }
178
179    pub async fn download(&self, file_id: &str) -> Result<FileDownload> {
180        let url = self.build_url(&format!("/{}/content", file_id));
181        let response = self
182            .build_request(reqwest::Method::GET, &url)
183            .await
184            .send()
185            .await
186            .map_err(Error::Network)?;
187
188        if !response.status().is_success() {
189            let status = response.status().as_u16();
190            let error: ErrorResponse = response.json().await.map_err(Error::Network)?;
191            return Err(error.into_error(status));
192        }
193
194        let content_type = response
195            .headers()
196            .get(reqwest::header::CONTENT_TYPE)
197            .and_then(|v| v.to_str().ok())
198            .unwrap_or("application/octet-stream")
199            .to_string();
200
201        let content_length = response
202            .headers()
203            .get(reqwest::header::CONTENT_LENGTH)
204            .and_then(|v| v.to_str().ok())
205            .and_then(|v| v.parse().ok());
206
207        Ok(FileDownload {
208            response,
209            content_type,
210            content_length,
211        })
212    }
213
214    pub async fn download_bytes(&self, file_id: &str) -> Result<Vec<u8>> {
215        let download = self.download(file_id).await?;
216        let bytes = download.bytes().await?;
217        Ok(bytes.to_vec())
218    }
219
220    pub async fn delete(&self, file_id: &str) -> Result<()> {
221        let url = self.build_url(&format!("/{}", file_id));
222        let response = self
223            .build_request(reqwest::Method::DELETE, &url)
224            .await
225            .send()
226            .await
227            .map_err(Error::Network)?;
228        self.handle_response::<serde_json::Value>(response).await?;
229        Ok(())
230    }
231
232    pub async fn list(
233        &self,
234        limit: Option<u32>,
235        after_id: Option<&str>,
236    ) -> Result<FileListResponse> {
237        let mut url = self.build_url("");
238
239        let mut query_params: Vec<(&str, String)> = Vec::new();
240        if let Some(limit) = limit {
241            query_params.push(("limit", limit.to_string()));
242        }
243        if let Some(after_id) = after_id {
244            query_params.push(("after_id", after_id.to_string()));
245        }
246        if !query_params.is_empty() {
247            let encoded: String = form_urlencoded::Serializer::new(String::new())
248                .extend_pairs(query_params.iter().map(|(k, v)| (*k, v.as_str())))
249                .finish();
250            url = format!("{}?{}", url, encoded);
251        }
252
253        let response = self
254            .build_request(reqwest::Method::GET, &url)
255            .await
256            .send()
257            .await
258            .map_err(Error::Network)?;
259        self.handle_response(response).await
260    }
261
262    pub async fn list_all(&self) -> Result<Vec<File>> {
263        let mut all_files = Vec::new();
264        let mut after_id: Option<String> = None;
265
266        loop {
267            let response = self.list(Some(100), after_id.as_deref()).await?;
268            all_files.extend(response.data);
269
270            if !response.has_more {
271                break;
272            }
273            after_id = response.last_id;
274        }
275
276        Ok(all_files)
277    }
278
279    async fn handle_response<T: serde::de::DeserializeOwned>(
280        &self,
281        response: reqwest::Response,
282    ) -> Result<T> {
283        if !response.status().is_success() {
284            let status = response.status().as_u16();
285            let error: ErrorResponse = response.json().await.map_err(Error::Network)?;
286            return Err(error.into_error(status));
287        }
288
289        response.json().await.map_err(Error::Network)
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_upload_request_from_bytes() {
        let request = UploadFileRequest::from_bytes(vec![1, 2, 3], "image/png");
        assert!(request.filename.is_none());
        match request.data {
            FileData::Bytes { data, mime_type } => {
                assert_eq!(data, vec![1, 2, 3]);
                assert_eq!(mime_type, "image/png");
            }
            FileData::Path(path) => panic!("expected in-memory bytes, got path {:?}", path),
        }
    }

    #[test]
    fn test_upload_request_with_filename() {
        let request =
            UploadFileRequest::from_bytes(vec![1, 2, 3], "image/png").filename("test.png");
        assert_eq!(request.filename, Some("test.png".to_string()));
    }

    #[test]
    fn test_upload_request_from_path() {
        let request = UploadFileRequest::from_path("docs/report.pdf");
        assert!(request.filename.is_none());
        match request.data {
            FileData::Path(path) => {
                assert_eq!(path.as_path(), std::path::Path::new("docs/report.pdf"))
            }
            FileData::Bytes { .. } => panic!("expected a path"),
        }
    }

    #[test]
    fn test_file_deserialization() {
        let json = r#"{
            "id": "file_abc123",
            "type": "file",
            "filename": "test.pdf",
            "mime_type": "application/pdf",
            "size_bytes": 1024,
            "created_at": "2025-01-01T00:00:00Z",
            "downloadable": false
        }"#;
        let file: File = serde_json::from_str(json).unwrap();
        assert_eq!(file.id, "file_abc123");
        assert_eq!(file.file_type, "file");
        assert_eq!(file.filename, "test.pdf");
        assert_eq!(file.mime_type, "application/pdf");
        assert_eq!(file.size_bytes, 1024);
        assert!(!file.downloadable);
    }

    #[test]
    fn test_file_downloadable_defaults_to_false() {
        let json = r#"{
            "id": "file_abc123",
            "type": "file",
            "filename": "test.pdf",
            "mime_type": "application/pdf",
            "size_bytes": 1024,
            "created_at": "2025-01-01T00:00:00Z"
        }"#;
        let file: File = serde_json::from_str(json).unwrap();
        assert!(!file.downloadable);
    }

    #[test]
    fn test_file_list_response_deserialization() {
        let json = r#"{
            "data": [],
            "has_more": false,
            "first_id": null,
            "last_id": null
        }"#;
        let response: FileListResponse = serde_json::from_str(json).unwrap();
        assert!(!response.has_more);
        assert!(response.data.is_empty());
    }

    #[test]
    fn test_file_list_response_with_cursor() {
        let json = r#"{
            "data": [{
                "id": "file_a",
                "type": "file",
                "filename": "a.txt",
                "mime_type": "text/plain",
                "size_bytes": 3,
                "created_at": "2025-01-01T00:00:00Z",
                "downloadable": true
            }],
            "has_more": true,
            "first_id": "file_a",
            "last_id": "file_a"
        }"#;
        let response: FileListResponse = serde_json::from_str(json).unwrap();
        assert!(response.has_more);
        assert_eq!(response.data.len(), 1);
        assert!(response.data[0].downloadable);
        assert_eq!(response.first_id.as_deref(), Some("file_a"));
        assert_eq!(response.last_id.as_deref(), Some("file_a"));
    }
}
334        let response: FileListResponse = serde_json::from_str(json).unwrap();
335        assert!(!response.has_more);
336        assert!(response.data.is_empty());
337    }
338}