Skip to main content

raps_oss/
multipart.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2025 Dmytro Yemelianov

//! Multipart upload operations for the OSS API.
5
6use anyhow::{Context, Result};
7use futures_util::StreamExt;
8use std::path::Path;
9use std::sync::Arc;
10use tokio::io::{AsyncReadExt, AsyncSeekExt, SeekFrom};
11use tokio::sync::Semaphore;
12
13use raps_kernel::progress;
14
15use crate::OssClient;
16use crate::types::*;
17
18impl OssClient {
19    /// Create a fresh multipart upload state with signed URLs
20    #[allow(clippy::too_many_arguments)]
21    async fn start_fresh_upload(
22        &self,
23        bucket_key: &str,
24        object_key: &str,
25        file_path: &Path,
26        total_parts: u32,
27        file_size: u64,
28        chunk_size: u64,
29        file_mtime: i64,
30    ) -> Result<(MultipartUploadState, Option<Vec<String>>)> {
31        let signed = self
32            .get_signed_upload_url(bucket_key, object_key, Some(total_parts), None)
33            .await?;
34        if signed.urls.len() != total_parts as usize {
35            anyhow::bail!(
36                "Expected {} URLs but got {}",
37                total_parts,
38                signed.urls.len()
39            );
40        }
41        let new_state = MultipartUploadState {
42            bucket_key: bucket_key.to_string(),
43            object_key: object_key.to_string(),
44            file_path: file_path.to_string_lossy().to_string(),
45            file_size,
46            chunk_size,
47            total_parts,
48            completed_parts: Vec::new(),
49            part_etags: std::collections::HashMap::new(),
50            upload_key: signed.upload_key,
51            started_at: chrono::Utc::now().timestamp(),
52            file_mtime,
53        };
54        new_state.save()?;
55        Ok((new_state, Some(signed.urls)))
56    }
57
58    /// Upload a large file using multipart upload with resume capability
59    pub async fn upload_multipart(
60        &self,
61        bucket_key: &str,
62        object_key: &str,
63        file_path: &Path,
64        resume: bool,
65    ) -> Result<ObjectInfo> {
66        let metadata = tokio::fs::metadata(file_path)
67            .await
68            .context("Failed to get file metadata")?;
69        let file_size = metadata.len();
70        let file_mtime = metadata
71            .modified()
72            .ok()
73            .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
74            .map(|d| d.as_secs() as i64)
75            .unwrap_or(0);
76
77        let chunk_size = MultipartUploadState::DEFAULT_CHUNK_SIZE;
78        let total_parts = file_size.div_ceil(chunk_size) as u32;
79
80        let (state, initial_urls) = if resume {
81            if let Some(existing_state) = MultipartUploadState::load(bucket_key, object_key)? {
82                if existing_state.can_resume(file_path) {
83                    tracing::info!(
84                        "Resuming upload: {}/{} completed parts",
85                        existing_state.completed_parts.len(),
86                        existing_state.total_parts
87                    );
88                    (existing_state, None)
89                } else {
90                    tracing::info!("File changed since last upload, starting fresh");
91                    MultipartUploadState::delete(bucket_key, object_key)?;
92                    self.start_fresh_upload(
93                        bucket_key,
94                        object_key,
95                        file_path,
96                        total_parts,
97                        file_size,
98                        chunk_size,
99                        file_mtime,
100                    )
101                    .await?
102                }
103            } else {
104                self.start_fresh_upload(
105                    bucket_key,
106                    object_key,
107                    file_path,
108                    total_parts,
109                    file_size,
110                    chunk_size,
111                    file_mtime,
112                )
113                .await?
114            }
115        } else {
116            MultipartUploadState::delete(bucket_key, object_key)?;
117            self.start_fresh_upload(
118                bucket_key,
119                object_key,
120                file_path,
121                total_parts,
122                file_size,
123                chunk_size,
124                file_mtime,
125            )
126            .await?
127        };
128
129        // Create progress bar (hidden in non-interactive mode)
130        let pb = progress::file_progress(file_size, &format!("Uploading {}", object_key));
131
132        // Update progress if resuming
133        if !state.completed_parts.is_empty() {
134            let completed_bytes: u64 = state
135                .completed_parts
136                .iter()
137                .map(|&part| {
138                    let start = (part as u64 - 1) * state.chunk_size;
139                    let end = std::cmp::min(start + state.chunk_size, state.file_size);
140                    end - start
141                })
142                .sum();
143            pb.set_position(completed_bytes);
144            pb.set_message(format!(
145                "Resuming {} ({} parts done)",
146                object_key,
147                state.completed_parts.len()
148            ));
149        } else {
150            pb.set_message(format!("Starting multipart upload for {}", object_key));
151        }
152
153        // Get remaining parts to upload
154        let remaining_parts = state.remaining_parts();
155
156        if remaining_parts.is_empty() {
157            pb.set_message(format!("All parts uploaded, completing {}", object_key));
158        } else {
159            pb.set_message(format!(
160                "Uploading {} ({} parts remaining)",
161                object_key,
162                remaining_parts.len()
163            ));
164        }
165
166        let urls = if let Some(u) = initial_urls {
167            u
168        } else {
169            let signed = self
170                .get_signed_upload_url(bucket_key, object_key, Some(total_parts), None)
171                .await?;
172            signed.urls
173        };
174
175        // Upload remaining parts in parallel with bounded concurrency
176        use futures_util::stream::FuturesUnordered;
177        use tokio::sync::Mutex;
178
179        const MAX_CONCURRENT_UPLOADS: usize = 5;
180        let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_UPLOADS));
181        let upload_key = state.upload_key.clone();
182        let state_mutex = Arc::new(Mutex::new(state));
183        let pb_arc = Arc::new(Mutex::new(pb));
184        let file_path_clone = file_path.to_path_buf();
185
186        // Create upload tasks
187        let upload_tasks: FuturesUnordered<_> = remaining_parts
188            .into_iter()
189            .map(|part_num| {
190                let part_index = (part_num - 1) as usize;
191                let start = (part_num as u64 - 1) * chunk_size;
192                let end = std::cmp::min(start + chunk_size, file_size);
193                let part_size = end - start;
194                let s3_url = urls[part_index].clone();
195                let client = self.http_client.clone();
196                let semaphore = semaphore.clone();
197                let state_mutex = state_mutex.clone();
198                let pb_arc = pb_arc.clone();
199                let object_key = object_key.to_string();
200                let file_path = file_path_clone.clone();
201
202                async move {
203                    // Acquire semaphore permit to limit concurrency
204                    let _permit = semaphore
205                        .acquire()
206                        .await
207                        .map_err(|_| anyhow::anyhow!("Upload cancelled"))?;
208
209                    // Read file chunk
210                    let buffer = {
211                        let mut file =
212                            tokio::fs::File::open(&file_path).await.with_context(|| {
213                                format!("Failed to open file for part {}", part_num)
214                            })?;
215                        file.seek(SeekFrom::Start(start)).await?;
216                        let mut buffer = vec![0u8; part_size as usize];
217                        file.read_exact(&mut buffer).await?;
218                        buffer
219                    };
220
221                    // Upload part with retry logic
222                    let mut attempts = 0;
223                    const MAX_RETRIES: usize = 3;
224                    let mut total_part_network_time = std::time::Duration::ZERO;
225
226                    loop {
227                        attempts += 1;
228
229                        let _part_start = std::time::Instant::now();
230                        let response = client
231                            .put(&s3_url)
232                            .header("Content-Type", "application/octet-stream")
233                            .header("Content-Length", part_size.to_string())
234                            .body(buffer.clone())
235                            .send()
236                            .await;
237                        total_part_network_time += _part_start.elapsed();
238
239                        match response {
240                            Ok(resp) if resp.status().is_success() => {
241                                // Get ETag from response
242                                let etag = resp
243                                    .headers()
244                                    .get("etag")
245                                    .and_then(|v| v.to_str().ok())
246                                    .map(|s| s.trim_matches('"').to_string())
247                                    .unwrap_or_default();
248
249                                // Update state atomically
250                                {
251                                    let mut state_guard = state_mutex.lock().await;
252                                    state_guard.completed_parts.push(part_num);
253                                    state_guard.part_etags.insert(part_num, etag);
254                                    if let Err(e) = state_guard.save() {
255                                        tracing::warn!(error = %e, "Failed to save upload state");
256                                    }
257                                }
258
259                                // Update progress bar
260                                {
261                                    let pb_guard = pb_arc.lock().await;
262                                    pb_guard.set_position(end);
263                                    pb_guard.set_message(format!(
264                                        "Uploading {} ({} parts completed)",
265                                        object_key, part_num
266                                    ));
267                                }
268
269                                raps_kernel::profiler::record_http_request(total_part_network_time);
270                                return Ok::<_, anyhow::Error>(part_num);
271                            }
272                            Ok(resp) => {
273                                let status = resp.status();
274                                let error_text = resp.text().await.unwrap_or_default();
275                                if attempts >= MAX_RETRIES {
276                                    raps_kernel::profiler::record_http_request(
277                                        total_part_network_time,
278                                    );
279                                    anyhow::bail!(
280                                        "Failed to upload part {} after {} attempts ({}): {}",
281                                        part_num,
282                                        attempts,
283                                        status,
284                                        error_text
285                                    );
286                                }
287                                raps_kernel::profiler::record_http_retry();
288                                // Wait before retry with exponential backoff
289                                let delay =
290                                    std::time::Duration::from_millis(100 * (1 << (attempts - 1)));
291                                tokio::time::sleep(delay).await;
292                            }
293                            Err(e) => {
294                                if attempts >= MAX_RETRIES {
295                                    raps_kernel::profiler::record_http_request(
296                                        total_part_network_time,
297                                    );
298                                    anyhow::bail!(
299                                        "Failed to upload part {} after {} attempts: {}",
300                                        part_num,
301                                        attempts,
302                                        e
303                                    );
304                                }
305                                raps_kernel::profiler::record_http_retry();
306                                // Wait before retry
307                                let delay =
308                                    std::time::Duration::from_millis(100 * (1 << (attempts - 1)));
309                                tokio::time::sleep(delay).await;
310                            }
311                        }
312                    }
313                }
314            })
315            .collect();
316
317        // Execute all upload tasks concurrently
318        let mut upload_results = Vec::new();
319        let mut upload_stream = upload_tasks;
320
321        while let Some(result) = upload_stream.next().await {
322            match result {
323                Ok(part_num) => {
324                    upload_results.push(part_num);
325                }
326                Err(e) => {
327                    return Err(e);
328                }
329            }
330        }
331
332        // Get the progress bar back from the Arc<Mutex<>>
333        let pb = match Arc::try_unwrap(pb_arc) {
334            Ok(mutex) => mutex.into_inner(),
335            Err(arc) => arc.lock().await.clone(),
336        };
337
338        // Complete the upload
339        pb.set_message(format!("Completing upload for {}", object_key));
340        let object_info = self
341            .complete_signed_upload(bucket_key, object_key, &upload_key)
342            .await?;
343
344        // Clean up state file
345        MultipartUploadState::delete(bucket_key, object_key)?;
346
347        pb.finish_with_message(format!("Uploaded {} (multipart)", object_key));
348
349        Ok(object_info)
350    }
351}