replicate_client/api/
files.rs

1//! Files API for uploading and managing files.
2
3use crate::error::{Error, Result};
4use crate::http::HttpClient;
5use crate::models::file::{FileEncodingStrategy, FileInput};
6use base64::{Engine as _, engine::general_purpose};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11/// Represents a file uploaded to Replicate.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct File {
14    /// The unique ID of the file.
15    pub id: String,
16    /// The name of the file.
17    pub name: String,
18    /// The content type of the file.
19    pub content_type: String,
20    /// The size of the file in bytes.
21    pub size: i64,
22    /// The ETag of the file.
23    pub etag: String,
24    /// File checksums.
25    pub checksums: HashMap<String, String>,
26    /// File metadata.
27    pub metadata: HashMap<String, serde_json::Value>,
28    /// When the file was created.
29    pub created_at: String,
30    /// When the file expires (optional).
31    pub expires_at: Option<String>,
32    /// File URLs.
33    pub urls: HashMap<String, String>,
34}
35
36/// Files API for managing file uploads.
37#[derive(Debug, Clone)]
38pub struct FilesApi {
39    http: HttpClient,
40}
41
42impl FilesApi {
43    /// Create a new Files API instance.
44    pub fn new(http: HttpClient) -> Self {
45        Self { http }
46    }
47
48    /// Upload a file from bytes with optional metadata.
49    pub async fn create_from_bytes(
50        &self,
51        file_content: &[u8],
52        filename: Option<&str>,
53        content_type: Option<&str>,
54        metadata: Option<&HashMap<String, serde_json::Value>>,
55    ) -> Result<File> {
56        let form =
57            HttpClient::create_file_form(file_content, filename, content_type, metadata).await?;
58
59        self.http.post_multipart_json("/v1/files", form).await
60    }
61
62    /// Upload a file from a local path.
63    pub async fn create_from_path(
64        &self,
65        file_path: &Path,
66        metadata: Option<&HashMap<String, serde_json::Value>>,
67    ) -> Result<File> {
68        let form = HttpClient::create_file_form_from_path(file_path, metadata).await?;
69        self.http.post_multipart_json("/v1/files", form).await
70    }
71
72    /// Upload a file from FileInput.
73    pub async fn create_from_file_input(
74        &self,
75        file_input: &FileInput,
76        metadata: Option<&HashMap<String, serde_json::Value>>,
77    ) -> Result<File> {
78        match file_input {
79            FileInput::Path(path) => self.create_from_path(path, metadata).await,
80            FileInput::Bytes {
81                data,
82                filename,
83                content_type,
84            } => {
85                self.create_from_bytes(data, filename.as_deref(), content_type.as_deref(), metadata)
86                    .await
87            }
88            FileInput::Url(_) => Err(Error::InvalidInput(
89                "Cannot upload from URL - file must be local or bytes".to_string(),
90            )),
91        }
92    }
93
94    /// Get a file by ID.
95    pub async fn get(&self, file_id: &str) -> Result<File> {
96        self.http.get_json(&format!("/v1/files/{}", file_id)).await
97    }
98
99    /// List all uploaded files.
100    pub async fn list(&self) -> Result<Vec<File>> {
101        #[derive(Deserialize)]
102        struct ListResponse {
103            results: Vec<File>,
104        }
105
106        let response: ListResponse = self.http.get_json("/v1/files").await?;
107        Ok(response.results)
108    }
109
110    /// Delete a file by ID.
111    pub async fn delete(&self, file_id: &str) -> Result<bool> {
112        let response = self.http.delete(&format!("/v1/files/{}", file_id)).await?;
113        Ok(response.status() == 204)
114    }
115}
116
117/// Helper to process file inputs based on encoding strategy.
118pub async fn process_file_input(
119    file_input: &FileInput,
120    encoding_strategy: &FileEncodingStrategy,
121    files_api: Option<&FilesApi>,
122) -> Result<String> {
123    match encoding_strategy {
124        FileEncodingStrategy::Base64DataUrl => encode_file_as_data_url(file_input).await,
125        FileEncodingStrategy::Multipart => {
126            if let Some(api) = files_api {
127                let file = api.create_from_file_input(file_input, None).await?;
128                // Return the file URL for use in predictions
129                file.urls
130                    .get("get")
131                    .cloned()
132                    .ok_or_else(|| Error::InvalidInput("File missing URL".to_string()))
133            } else {
134                Err(Error::InvalidInput(
135                    "Files API required for multipart upload".to_string(),
136                ))
137            }
138        }
139    }
140}
141
142/// Encode a file input as a base64 data URL.
143async fn encode_file_as_data_url(file_input: &FileInput) -> Result<String> {
144    match file_input {
145        FileInput::Url(_url) => {
146            // For URLs, we can't encode as data URL without downloading
147            Err(Error::InvalidInput(
148                "Cannot encode URL as data URL without downloading".to_string(),
149            ))
150        }
151        FileInput::Path(path) => {
152            let content = tokio::fs::read(path).await?;
153            let content_type = mime_guess::from_path(path)
154                .first_or_octet_stream()
155                .to_string();
156
157            let encoded = general_purpose::STANDARD.encode(&content);
158            Ok(format!("data:{};base64,{}", content_type, encoded))
159        }
160        FileInput::Bytes {
161            data, content_type, ..
162        } => {
163            let content_type = content_type
164                .as_deref()
165                .unwrap_or("application/octet-stream");
166
167            let encoded = general_purpose::STANDARD.encode(data);
168            Ok(format!("data:{};base64,{}", content_type, encoded))
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// In-memory bytes with an explicit content type encode to an exact data URL.
    #[tokio::test]
    async fn test_data_url_encoding() {
        let file_input = FileInput::from_bytes_with_metadata(
            &b"Hello, World!"[..],
            Some("test.txt".to_string()),
            Some("text/plain".to_string()),
        );

        let data_url = encode_file_as_data_url(&file_input).await.unwrap();
        assert_eq!(data_url, "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ==");
    }

    /// A file on disk is read and its content type is guessed from the `.txt`
    /// extension.
    #[tokio::test]
    async fn test_file_path_data_url() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");
        tokio::fs::write(&file_path, b"Test content").await.unwrap();

        let file_input = FileInput::from_path(&file_path);
        let data_url = encode_file_as_data_url(&file_input).await.unwrap();

        // Compare the whole URL: a prefix/substring check would also accept
        // trailing garbage or a wrongly padded payload.
        // "Test content" is 12 bytes, so its base64 form has no padding.
        assert_eq!(data_url, "data:text/plain;base64,VGVzdCBjb250ZW50");
    }
}