1use std::path::{Path, PathBuf};
6
7use rc_core::Result;
8
/// Part size used when the caller does not choose one (64 MiB).
pub const DEFAULT_PART_SIZE: u64 = 64 * 1024 * 1024;

/// Smallest part size this module will choose (5 MiB).
pub const MIN_PART_SIZE: u64 = 5 * 1024 * 1024;

/// Largest part size this module will choose (5 GiB).
pub const MAX_PART_SIZE: u64 = 5 * 1024 * 1024 * 1024;

/// Maximum number of parts a single multipart upload may be split into.
pub const MAX_PARTS: usize = 10_000;

/// Tuning options for multipart uploads.
///
/// Build one with [`MultipartConfig::new`] (or `Default`) and the chained setters
/// [`part_size`](MultipartConfig::part_size), [`concurrency`](MultipartConfig::concurrency)
/// and [`state_dir`](MultipartConfig::state_dir).
#[derive(Debug, Clone)]
pub struct MultipartConfig {
    /// Preferred size of each part in bytes.
    ///
    /// The `part_size` setter clamps this to `MIN_PART_SIZE..=MAX_PART_SIZE`. Direct field
    /// writes are not validated here, but `calculate_part_size` clamps again before use.
    pub part_size: u64,

    /// Requested number of concurrent part uploads (the setter enforces at least 1).
    pub concurrency: usize,

    /// Optional directory for persisted upload state; `None` by default.
    pub state_dir: Option<PathBuf>,
}

impl Default for MultipartConfig {
    /// 64 MiB parts, concurrency 4, no state directory.
    fn default() -> Self {
        Self {
            part_size: DEFAULT_PART_SIZE,
            concurrency: 4,
            state_dir: None,
        }
    }
}

impl MultipartConfig {
    /// Creates a configuration with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the preferred part size, clamped to `MIN_PART_SIZE..=MAX_PART_SIZE`.
    pub fn part_size(mut self, size: u64) -> Self {
        self.part_size = size.clamp(MIN_PART_SIZE, MAX_PART_SIZE);
        self
    }

    /// Sets the concurrency; values below 1 are raised to 1.
    pub fn concurrency(mut self, n: usize) -> Self {
        self.concurrency = n.max(1);
        self
    }

    /// Sets the directory used to persist upload state.
    pub fn state_dir(mut self, path: impl Into<PathBuf>) -> Self {
        self.state_dir = Some(path.into());
        self
    }

    /// Chooses the part size to use for a file of `file_size` bytes.
    ///
    /// - Files no larger than [`MIN_PART_SIZE`] get `MIN_PART_SIZE`.
    /// - Otherwise the configured part size is used, first clamped to
    ///   `MIN_PART_SIZE..=MAX_PART_SIZE`.
    /// - If that size would need more than [`MAX_PARTS`] parts, the smallest size that fits
    ///   within `MAX_PARTS` is returned instead, again clamped to the same range.
    ///
    /// A file larger than `MAX_PARTS * MAX_PART_SIZE` cannot fit. It gets `MAX_PART_SIZE`
    /// and still needs more than `MAX_PARTS` parts.
    pub fn calculate_part_size(&self, file_size: u64) -> u64 {
        if file_size <= MIN_PART_SIZE {
            return MIN_PART_SIZE;
        }

        // `part_size` is a public field and may bypass the setter's clamp. Without this,
        // a zero value would panic in `div_ceil`, and an out-of-range value would be used as-is.
        let preferred = self.part_size.clamp(MIN_PART_SIZE, MAX_PART_SIZE);

        if file_size.div_ceil(preferred) <= MAX_PARTS as u64 {
            preferred
        } else {
            file_size
                .div_ceil(MAX_PARTS as u64)
                .clamp(MIN_PART_SIZE, MAX_PART_SIZE)
        }
    }
}
83
84#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
86pub struct UploadState {
87 pub upload_id: String,
89
90 pub target: String,
92
93 pub source: Option<String>,
95
96 pub total_size: u64,
98
99 pub part_size: u64,
101
102 pub completed_parts: Vec<CompletedPart>,
104
105 pub last_updated: jiff::Timestamp,
107}
108
109#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
110pub struct CompletedPart {
111 pub part_number: i32,
112 pub etag: String,
113}
114
115impl UploadState {
116 pub fn new(
118 upload_id: impl Into<String>,
119 target: impl Into<String>,
120 total_size: u64,
121 part_size: u64,
122 ) -> Self {
123 Self {
124 upload_id: upload_id.into(),
125 target: target.into(),
126 source: None,
127 total_size,
128 part_size,
129 completed_parts: Vec::new(),
130 last_updated: jiff::Timestamp::now(),
131 }
132 }
133
134 pub fn with_source(mut self, source: impl Into<String>) -> Self {
136 self.source = Some(source.into());
137 self
138 }
139
140 pub fn add_completed_part(&mut self, part_number: i32, etag: String) {
142 self.completed_parts
143 .push(CompletedPart { part_number, etag });
144 self.last_updated = jiff::Timestamp::now();
145 }
146
147 pub fn next_part_number(&self) -> i32 {
149 self.completed_parts
150 .iter()
151 .map(|p| p.part_number)
152 .max()
153 .map(|n| n + 1)
154 .unwrap_or(1)
155 }
156
157 pub fn progress_percent(&self) -> f64 {
159 let completed_bytes = self.completed_parts.len() as u64 * self.part_size;
160 (completed_bytes as f64 / self.total_size as f64 * 100.0).min(100.0)
161 }
162
163 pub fn state_file_path(state_dir: &Path, upload_id: &str) -> PathBuf {
165 let safe_id: String = upload_id
167 .chars()
168 .map(|c| if c.is_alphanumeric() { c } else { '_' })
169 .collect();
170 state_dir.join(format!("upload_{safe_id}.json"))
171 }
172
173 pub fn save(&self, state_dir: &Path) -> Result<()> {
175 let path = Self::state_file_path(state_dir, &self.upload_id);
176
177 if let Some(parent) = path.parent() {
179 std::fs::create_dir_all(parent)?;
180 }
181
182 let json = serde_json::to_string_pretty(self)?;
183 std::fs::write(&path, json)?;
184 Ok(())
185 }
186
187 pub fn load(state_dir: &Path, upload_id: &str) -> Result<Self> {
189 let path = Self::state_file_path(state_dir, upload_id);
190 let content = std::fs::read_to_string(&path)?;
191 let state: Self = serde_json::from_str(&content)?;
192 Ok(state)
193 }
194
195 pub fn delete(state_dir: &Path, upload_id: &str) -> Result<()> {
197 let path = Self::state_file_path(state_dir, upload_id);
198 if path.exists() {
199 std::fs::remove_file(&path)?;
200 }
201 Ok(())
202 }
203
204 pub fn find_pending(state_dir: &Path, target: &str) -> Result<Vec<Self>> {
206 let mut pending = Vec::new();
207
208 if !state_dir.exists() {
209 return Ok(pending);
210 }
211
212 for entry in std::fs::read_dir(state_dir)? {
213 let entry = entry?;
214 let path = entry.path();
215
216 if path.extension().map(|e| e == "json").unwrap_or(false)
217 && let Ok(content) = std::fs::read_to_string(&path)
218 && let Ok(state) = serde_json::from_str::<Self>(&content)
219 && state.target == target
220 {
221 pending.push(state);
222 }
223 }
224
225 Ok(pending)
226 }
227}
228
/// Returns how many parts of `part_size` bytes are needed to cover `file_size` bytes.
///
/// The final part may be shorter than `part_size`; a zero-byte file needs zero parts.
///
/// # Panics
///
/// Panics if `part_size` is zero.
pub fn calculate_parts(file_size: u64, part_size: u64) -> usize {
    let full_parts = file_size / part_size;
    let has_tail = file_size % part_size != 0;
    (full_parts + u64::from(has_tail)) as usize
}
233
/// Returns the half-open byte range `[start, end)` of the 1-based part `part_number`.
///
/// The last part is truncated at `total_size`. Part numbers beyond the end of the data yield
/// the empty range `(total_size, total_size)`, so `end - start` never underflows.
/// Arithmetic saturates instead of overflowing for very large inputs.
///
/// # Panics
///
/// Panics if `part_number` is less than 1 (part numbers are 1-based).
pub fn part_byte_range(part_number: i32, part_size: u64, total_size: u64) -> (u64, u64) {
    assert!(
        part_number >= 1,
        "part numbers are 1-based, got {part_number}"
    );
    // The assert guarantees `part_number >= 1`, so this subtraction cannot underflow.
    let index = u64::from(part_number.unsigned_abs()) - 1;
    let start = index.saturating_mul(part_size).min(total_size);
    let end = start.saturating_add(part_size).min(total_size);
    (start, end)
}
240
#[cfg(test)]
mod tests {
    use super::*;

    // Size helpers so the expected values read in familiar units.
    const MIB: u64 = 1024 * 1024;
    const GIB: u64 = 1024 * MIB;

    #[test]
    fn test_default_config() {
        let cfg = MultipartConfig::default();
        assert_eq!(cfg.part_size, DEFAULT_PART_SIZE);
        assert_eq!(cfg.concurrency, 4);
    }

    #[test]
    fn test_config_builder() {
        let cfg = MultipartConfig::new().part_size(128 * MIB).concurrency(8);
        assert_eq!(cfg.part_size, 128 * MIB);
        assert_eq!(cfg.concurrency, 8);
    }

    #[test]
    fn test_part_size_clamping() {
        // Too small is raised to the minimum...
        assert_eq!(MultipartConfig::new().part_size(1024).part_size, MIN_PART_SIZE);
        // ...and too large is lowered to the maximum.
        assert_eq!(
            MultipartConfig::new().part_size(10 * GIB).part_size,
            MAX_PART_SIZE
        );
    }

    #[test]
    fn test_calculate_part_size_small_file() {
        let chosen = MultipartConfig::default().calculate_part_size(MIB);
        assert_eq!(chosen, MIN_PART_SIZE);
    }

    #[test]
    fn test_calculate_part_size_large_file() {
        let file_size = DEFAULT_PART_SIZE * 20_000;
        let chosen = MultipartConfig::default().calculate_part_size(file_size);
        assert!(calculate_parts(file_size, chosen) <= MAX_PARTS);
    }

    #[test]
    fn test_upload_state() {
        let mut state = UploadState::new("upload-123", "bucket/key", 1000, 100);
        assert_eq!(state.next_part_number(), 1);

        for (part, etag) in [(1, "etag1"), (2, "etag2")] {
            state.add_completed_part(part, etag.to_string());
            assert_eq!(state.next_part_number(), part + 1);
        }
    }

    #[test]
    fn test_progress_percent() {
        let mut state = UploadState::new("upload-123", "bucket/key", 1000, 100);
        assert_eq!(state.progress_percent(), 0.0);

        for (part, etag, expected) in [(1, "etag1", 10.0), (2, "etag2", 20.0)] {
            state.add_completed_part(part, etag.to_string());
            assert_eq!(state.progress_percent(), expected);
        }
    }

    #[test]
    fn test_calculate_parts() {
        for (file_size, part_size, expected) in [(100, 10, 10), (101, 10, 11), (99, 10, 10)] {
            assert_eq!(calculate_parts(file_size, part_size), expected);
        }
    }

    #[test]
    fn test_part_byte_range() {
        let cases = [(1, (0, 100)), (2, (100, 200)), (3, (200, 250))];
        for (part, expected) in cases {
            assert_eq!(part_byte_range(part, 100, 250), expected);
        }
    }

    #[test]
    fn test_calculate_part_size_default_is_sufficient_for_common_sizes() {
        let cfg = MultipartConfig::default();

        for file_size in [100 * MIB, GIB] {
            assert_eq!(cfg.calculate_part_size(file_size), DEFAULT_PART_SIZE);
        }

        let file_size = 500 * GIB;
        let chosen = cfg.calculate_part_size(file_size);
        assert!(calculate_parts(file_size, chosen) <= MAX_PARTS);
        assert!(chosen >= MIN_PART_SIZE);
    }

    #[test]
    fn test_part_byte_range_covers_full_file() {
        let (total, part_size) = (250_u64, 100_u64);
        let parts = calculate_parts(total, part_size);
        assert_eq!(parts, 3);

        // Each range must begin exactly where the previous one ended.
        let covered = (1..=(parts as i32)).fold(0_u64, |expected_start, part| {
            let (start, end) = part_byte_range(part, part_size, total);
            assert_eq!(start, expected_start);
            end
        });
        assert_eq!(covered, total);
    }
}