replicate_client/api/
files.rs1use crate::error::{Error, Result};
4use crate::http::HttpClient;
5use crate::models::file::{FileEncodingStrategy, FileInput};
6use base64::{Engine as _, engine::general_purpose};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct File {
14 pub id: String,
16 pub name: String,
18 pub content_type: String,
20 pub size: i64,
22 pub etag: String,
24 pub checksums: HashMap<String, String>,
26 pub metadata: HashMap<String, serde_json::Value>,
28 pub created_at: String,
30 pub expires_at: Option<String>,
32 pub urls: HashMap<String, String>,
34}
35
36#[derive(Debug, Clone)]
38pub struct FilesApi {
39 http: HttpClient,
40}
41
42impl FilesApi {
43 pub fn new(http: HttpClient) -> Self {
45 Self { http }
46 }
47
48 pub async fn create_from_bytes(
50 &self,
51 file_content: &[u8],
52 filename: Option<&str>,
53 content_type: Option<&str>,
54 metadata: Option<&HashMap<String, serde_json::Value>>,
55 ) -> Result<File> {
56 let form =
57 HttpClient::create_file_form(file_content, filename, content_type, metadata).await?;
58
59 self.http.post_multipart_json("/v1/files", form).await
60 }
61
62 pub async fn create_from_path(
64 &self,
65 file_path: &Path,
66 metadata: Option<&HashMap<String, serde_json::Value>>,
67 ) -> Result<File> {
68 let form = HttpClient::create_file_form_from_path(file_path, metadata).await?;
69 self.http.post_multipart_json("/v1/files", form).await
70 }
71
72 pub async fn create_from_file_input(
74 &self,
75 file_input: &FileInput,
76 metadata: Option<&HashMap<String, serde_json::Value>>,
77 ) -> Result<File> {
78 match file_input {
79 FileInput::Path(path) => self.create_from_path(path, metadata).await,
80 FileInput::Bytes {
81 data,
82 filename,
83 content_type,
84 } => {
85 self.create_from_bytes(data, filename.as_deref(), content_type.as_deref(), metadata)
86 .await
87 }
88 FileInput::Url(_) => Err(Error::InvalidInput(
89 "Cannot upload from URL - file must be local or bytes".to_string(),
90 )),
91 }
92 }
93
94 pub async fn get(&self, file_id: &str) -> Result<File> {
96 self.http.get_json(&format!("/v1/files/{}", file_id)).await
97 }
98
99 pub async fn list(&self) -> Result<Vec<File>> {
101 #[derive(Deserialize)]
102 struct ListResponse {
103 results: Vec<File>,
104 }
105
106 let response: ListResponse = self.http.get_json("/v1/files").await?;
107 Ok(response.results)
108 }
109
110 pub async fn delete(&self, file_id: &str) -> Result<bool> {
112 let response = self.http.delete(&format!("/v1/files/{}", file_id)).await?;
113 Ok(response.status() == 204)
114 }
115}
116
117pub async fn process_file_input(
119 file_input: &FileInput,
120 encoding_strategy: &FileEncodingStrategy,
121 files_api: Option<&FilesApi>,
122) -> Result<String> {
123 match encoding_strategy {
124 FileEncodingStrategy::Base64DataUrl => encode_file_as_data_url(file_input).await,
125 FileEncodingStrategy::Multipart => {
126 if let Some(api) = files_api {
127 let file = api.create_from_file_input(file_input, None).await?;
128 file.urls
130 .get("get")
131 .cloned()
132 .ok_or_else(|| Error::InvalidInput("File missing URL".to_string()))
133 } else {
134 Err(Error::InvalidInput(
135 "Files API required for multipart upload".to_string(),
136 ))
137 }
138 }
139 }
140}
141
142async fn encode_file_as_data_url(file_input: &FileInput) -> Result<String> {
144 match file_input {
145 FileInput::Url(_url) => {
146 Err(Error::InvalidInput(
148 "Cannot encode URL as data URL without downloading".to_string(),
149 ))
150 }
151 FileInput::Path(path) => {
152 let content = tokio::fs::read(path).await?;
153 let content_type = mime_guess::from_path(path)
154 .first_or_octet_stream()
155 .to_string();
156
157 let encoded = general_purpose::STANDARD.encode(&content);
158 Ok(format!("data:{};base64,{}", content_type, encoded))
159 }
160 FileInput::Bytes {
161 data, content_type, ..
162 } => {
163 let content_type = content_type
164 .as_deref()
165 .unwrap_or("application/octet-stream");
166
167 let encoded = general_purpose::STANDARD.encode(data);
168 Ok(format!("data:{};base64,{}", content_type, encoded))
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// In-memory bytes with an explicit content type encode to the exact expected data URL.
    #[tokio::test]
    async fn test_data_url_encoding() {
        let input = FileInput::from_bytes_with_metadata(
            &b"Hello, World!"[..],
            Some("test.txt".to_string()),
            Some("text/plain".to_string()),
        );

        let encoded = encode_file_as_data_url(&input).await;

        assert_eq!(
            encoded.unwrap(),
            "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="
        );
    }

    /// A `.txt` file on disk is read, typed as `text/plain`, and base64-encoded.
    #[tokio::test]
    async fn test_file_path_data_url() {
        let scratch = tempdir().unwrap();
        let txt_path = scratch.path().join("test.txt");
        tokio::fs::write(&txt_path, b"Test content").await.unwrap();

        let input = FileInput::from_path(&txt_path);
        let encoded = encode_file_as_data_url(&input).await.unwrap();

        let has_expected_prefix = encoded.strip_prefix("data:text/plain;base64,").is_some();
        assert!(has_expected_prefix);

        let has_expected_payload = encoded.contains("VGVzdCBjb250ZW50");
        assert!(has_expected_payload);
    }
}