Skip to main content

marple_db/
upload.rs

1use crate::models::{
2    Dataset, IngestionInit, PartUrl, PartUrlsResponse, PushFileOptions, UploadMode,
3    UploadModeOverride,
4};
5use crate::{Error, MarpleDB, ProgressReporter, Result};
6use base64::Engine;
7use futures_util::StreamExt;
8use reqwest::{Body, Response, header::CONTENT_LENGTH};
9use serde_json::Value;
10use std::io::SeekFrom;
11use std::path::{Path, PathBuf};
12use std::sync::{
13    Arc,
14    atomic::{AtomicU64, Ordering},
15};
16use tokio::io::{AsyncReadExt, AsyncSeekExt};
17use tokio::sync::Mutex;
18use tokio_util::io::ReaderStream;
19
/// Granularity (256 KiB) at which in-memory upload bodies are handed to the HTTP client, and
/// therefore how often progress is reported while a part/block is being sent.
const PROGRESS_STREAM_CHUNK: usize = 1 << 18;
21
22#[derive(Clone)]
23struct MultipartUploadContext {
24    file_path: PathBuf,
25    part_size: u64,
26    total_size: u64,
27    uploaded: Arc<AtomicU64>,
28    progress: Arc<dyn ProgressReporter>,
29}
30
/// One Azure block: a byte range of the source file and the id it is uploaded under.
#[derive(Clone)]
struct BlockDescriptor {
    /// Block id before base64 encoding.
    block_id: String,
    /// Byte offset of the block within the file.
    offset: u64,
    /// Number of bytes in the block.
    length: u64,
}
37
38#[derive(Clone)]
39struct AzureBlockUploadContext {
40    sas_url: Arc<reqwest::Url>,
41    file_path: PathBuf,
42    uploaded: Arc<AtomicU64>,
43    progress: Arc<dyn ProgressReporter>,
44}
45
46fn progress_reporting_stream(
47    data: Vec<u8>,
48    uploaded: Arc<AtomicU64>,
49    progress: Arc<dyn ProgressReporter>,
50) -> impl futures_util::Stream<Item = std::io::Result<Vec<u8>>> + Send + 'static {
51    async_stream::stream! {
52        let mut pos = 0;
53        while pos < data.len() {
54            let end = (pos + PROGRESS_STREAM_CHUNK).min(data.len());
55            let chunk = data[pos..end].to_vec();
56            let chunk_len = chunk.len() as u64;
57            let new_uploaded = uploaded.fetch_add(chunk_len, Ordering::Relaxed) + chunk_len;
58            progress.set_position(new_uploaded);
59            yield Ok(chunk);
60            pos = end;
61        }
62    }
63}
64
65fn azure_block_descriptors(total_size: u64, block_size: u64) -> Vec<BlockDescriptor> {
66    if total_size == 0 {
67        return Vec::new();
68    }
69
70    let n_blocks = total_size.div_ceil(block_size);
71    (0..n_blocks as u32)
72        .map(|block_number| {
73            let offset = u64::from(block_number) * block_size;
74            let length = block_size.min(total_size - offset);
75            BlockDescriptor {
76                offset,
77                length,
78                block_id: format!("{block_number:08}"),
79            }
80        })
81        .collect()
82}
83
84async fn ensure_success(response: Response, failure_message: impl Into<String>) -> Result<()> {
85    if response.status().is_success() {
86        Ok(())
87    } else {
88        let context = failure_message.into();
89        let status = response.status();
90        let body = response.text().await.map_err(|source| Error::Storage {
91            context: context.clone(),
92            status: Some(status),
93            body: None,
94            source: Some(source),
95        })?;
96        Err(Error::Storage {
97            context,
98            status: Some(status),
99            body: Some(body),
100            source: None,
101        })
102    }
103}
104
105async fn send_storage(
106    request: reqwest::RequestBuilder,
107    context: impl Into<String>,
108) -> Result<Response> {
109    let context = context.into();
110    request.send().await.map_err(|source| Error::Storage {
111        context,
112        status: None,
113        body: None,
114        source: Some(source),
115    })
116}
117
118impl MarpleDB {
119    async fn init_ingestion(
120        &self,
121        stream_id: i32,
122        dataset_name: &str,
123        file_size: u64,
124        metadata: &crate::Metadata,
125    ) -> Result<IngestionInit> {
126        let body = serde_json::json!({
127            "stream_id": stream_id,
128            "dataset_name": dataset_name,
129            "file_size": file_size,
130            "metadata": metadata,
131        });
132        self.post_json("ingestion", &body).await
133    }
134
135    async fn get_part_urls(
136        &self,
137        ingestion_id: i32,
138        start_part: u32,
139        count: usize,
140    ) -> Result<PartUrlsResponse> {
141        let endpoint = format!("ingestion/{}/upload/part-urls", ingestion_id);
142        self.get_json(
143            &endpoint,
144            &[("start_part", start_part), ("count", count as u32)],
145        )
146        .await
147    }
148
149    async fn complete_upload(&self, ingestion_id: i32) -> Result<()> {
150        let endpoint = format!("ingestion/{}/upload/complete", ingestion_id);
151        self.post_json::<_, Value>(&endpoint, &serde_json::json!({}))
152            .await?;
153        Ok(())
154    }
155
156    async fn abort_upload(&self, ingestion_id: i32, reason: &str) -> Result<()> {
157        let endpoint = format!("ingestion/{}/abort", ingestion_id);
158        self.post_json::<_, Value>(&endpoint, &serde_json::json!({ "reason": reason }))
159            .await?;
160        Ok(())
161    }
162
163    async fn upload_via_single(
164        &self,
165        init: &IngestionInit,
166        file_path: &Path,
167        total_size: u64,
168        progress: Arc<dyn ProgressReporter>,
169    ) -> Result<()> {
170        let url = init.presigned_url.as_deref().ok_or_else(|| {
171            Error::Protocol("single upload mode without presigned_url".to_string())
172        })?;
173        let file = tokio::fs::File::open(file_path).await?;
174        let mut uploaded = 0;
175
176        let mut reader = ReaderStream::new(file);
177        let stream = async_stream::stream! {
178            while let Some(chunk) = reader.next().await {
179                if let Ok(chunk) = &chunk {
180                    uploaded += chunk.len() as u64;
181                    progress.set_position(uploaded);
182                }
183                yield chunk;
184            }
185            progress.finish();
186        };
187
188        let response = send_storage(
189            self.storage_client
190                .put(url)
191                .header(CONTENT_LENGTH, total_size)
192                .body(Body::wrap_stream(stream)),
193            "storage PUT failed",
194        )
195        .await?;
196        ensure_success(response, "storage PUT failed").await?;
197        Ok(())
198    }
199
200    async fn upload_via_server(
201        &self,
202        init: &IngestionInit,
203        file_path: &Path,
204        file_name: &str,
205        total_size: u64,
206        progress: Arc<dyn ProgressReporter>,
207    ) -> Result<()> {
208        let file = tokio::fs::File::open(file_path).await?;
209        let mut uploaded = 0;
210
211        let mut reader = ReaderStream::new(file);
212        let stream = async_stream::stream! {
213            while let Some(chunk) = reader.next().await {
214                if let Ok(chunk) = &chunk {
215                    uploaded += chunk.len() as u64;
216                    progress.set_position(uploaded);
217                }
218                yield chunk;
219            }
220            progress.finish();
221        };
222
223        let body = Body::wrap_stream(stream);
224        let part = reqwest::multipart::Part::stream_with_length(body, total_size)
225            .file_name(file_name.to_string())
226            .mime_str("application/octet-stream")
227            .map_err(|source| Error::Storage {
228                context: "building multipart upload body".to_string(),
229                status: None,
230                body: None,
231                source: Some(source),
232            })?;
233        let form = reqwest::multipart::Form::new().part("file", part);
234        let endpoint = format!("ingestion/{}/upload/server", init.ingestion_id);
235        self.post_multipart(&endpoint, form).await?;
236        Ok(())
237    }
238
239    async fn put_block(
240        &self,
241        file: &mut tokio::fs::File,
242        context: &AzureBlockUploadContext,
243        descriptor: BlockDescriptor,
244    ) -> Result<()> {
245        file.seek(SeekFrom::Start(descriptor.offset)).await?;
246        let mut data = vec![0; usize::try_from(descriptor.length)?];
247        file.read_exact(&mut data).await?;
248
249        let stream = progress_reporting_stream(
250            data,
251            Arc::clone(&context.uploaded),
252            Arc::clone(&context.progress),
253        );
254
255        let mut block_url = (*context.sas_url).clone();
256        block_url
257            .query_pairs_mut()
258            .append_pair("comp", "block")
259            .append_pair(
260                "blockid",
261                &base64::engine::general_purpose::STANDARD.encode(descriptor.block_id.as_bytes()),
262            );
263
264        let response = send_storage(
265            self.storage_client
266                .put(block_url)
267                .header(CONTENT_LENGTH, descriptor.length)
268                .body(Body::wrap_stream(stream)),
269            format!("Azure block {} upload failed", descriptor.block_id),
270        )
271        .await?;
272        ensure_success(
273            response,
274            format!("Azure block {} upload failed", descriptor.block_id),
275        )
276        .await?;
277        Ok(())
278    }
279
280    async fn upload_via_azure(
281        &self,
282        init: &IngestionInit,
283        file_path: &Path,
284        total_size: u64,
285        concurrency: usize,
286        progress: Arc<dyn ProgressReporter>,
287    ) -> Result<()> {
288        const AZURE_BLOCK_SIZE: u64 = 64 * 1024 * 1024;
289
290        let url = init.presigned_url.as_deref().ok_or_else(|| {
291            Error::Protocol("azure upload mode without presigned_url".to_string())
292        })?;
293        let sas_url: reqwest::Url = url.parse()?;
294
295        let concurrency = concurrency.max(1);
296        let descriptors = azure_block_descriptors(total_size, AZURE_BLOCK_SIZE);
297        let context = AzureBlockUploadContext {
298            sas_url: Arc::new(sas_url.clone()),
299            file_path: file_path.to_path_buf(),
300            uploaded: Arc::new(AtomicU64::new(0)),
301            progress: Arc::clone(&progress),
302        };
303        let cursor = Arc::new(Mutex::new(descriptors.clone().into_iter()));
304
305        let workers = (0..concurrency).map(|_| {
306            let context = context.clone();
307            let cursor = Arc::clone(&cursor);
308            async move {
309                let mut file = tokio::fs::File::open(&context.file_path).await?;
310                loop {
311                    let descriptor = {
312                        let mut cursor = cursor.lock().await;
313                        cursor.next()
314                    };
315                    let Some(descriptor) = descriptor else {
316                        return Ok::<_, Error>(());
317                    };
318
319                    self.put_block(&mut file, &context, descriptor).await?;
320                }
321            }
322        });
323        futures_util::future::try_join_all(workers).await?;
324
325        self.commit_azure_block_list(&sas_url, &descriptors).await?;
326        progress.finish();
327        Ok(())
328    }
329
330    async fn commit_azure_block_list(
331        &self,
332        sas_url: &reqwest::Url,
333        descriptors: &[BlockDescriptor],
334    ) -> Result<()> {
335        let mut block_list_url = sas_url.clone();
336        block_list_url
337            .query_pairs_mut()
338            .append_pair("comp", "blocklist");
339
340        let mut xml = String::from("<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<BlockList>\n");
341        for descriptor in descriptors {
342            let block_id =
343                base64::engine::general_purpose::STANDARD.encode(descriptor.block_id.as_bytes());
344            xml.push_str("\t<Uncommitted>");
345            xml.push_str(&block_id);
346            xml.push_str("</Uncommitted>\n");
347        }
348        xml.push_str("</BlockList>");
349        let date = httpdate::fmt_http_date(std::time::SystemTime::now());
350
351        let response = send_storage(
352            self.storage_client
353                .put(block_list_url)
354                .header(reqwest::header::CONTENT_TYPE, "application/xml")
355                .header(reqwest::header::CONTENT_LENGTH, xml.len())
356                .header("x-ms-date", date)
357                .header("x-ms-version", "2022-11-02")
358                .body(xml),
359            "Azure block list commit failed",
360        )
361        .await?;
362        ensure_success(response, "Azure block list commit failed").await
363    }
364
365    async fn put_part(
366        &self,
367        file: &mut tokio::fs::File,
368        context: &MultipartUploadContext,
369        part: PartUrl,
370    ) -> Result<()> {
371        let offset = u64::from(part.part_number - 1) * context.part_size;
372        if offset >= context.total_size {
373            return Err(Error::Protocol(format!(
374                "part {} offset is outside the file",
375                part.part_number
376            )));
377        }
378        let part_len = context.part_size.min(context.total_size - offset);
379
380        file.seek(SeekFrom::Start(offset)).await?;
381        let mut data = vec![0; usize::try_from(part_len)?];
382        file.read_exact(&mut data).await?;
383
384        let stream = progress_reporting_stream(
385            data,
386            Arc::clone(&context.uploaded),
387            Arc::clone(&context.progress),
388        );
389
390        let response = send_storage(
391            self.storage_client
392                .put(part.url)
393                .header(CONTENT_LENGTH, part_len)
394                .body(Body::wrap_stream(stream)),
395            format!("part {} storage PUT failed", part.part_number),
396        )
397        .await?;
398        ensure_success(
399            response,
400            format!("part {} storage PUT failed", part.part_number),
401        )
402        .await?;
403        Ok(())
404    }
405
406    fn signed_parts_stream(
407        &self,
408        ingestion_id: i32,
409        batch_size: usize,
410    ) -> impl futures_util::Stream<Item = Result<PartUrl>> + '_ {
411        async_stream::try_stream! {
412            let mut next_part = Some(1);
413
414            while let Some(start_part) = next_part {
415                let urls = self.get_part_urls(ingestion_id, start_part, batch_size).await?;
416                if urls.parts.is_empty() {
417                    Err(Error::Protocol("server returned no multipart upload URLs".to_string()))?;
418                }
419
420                for part in urls.parts {
421                    yield part;
422                }
423
424                next_part = urls.next_part;
425            }
426        }
427    }
428
429    async fn upload_via_multipart(
430        &self,
431        init: &IngestionInit,
432        file_path: &Path,
433        total_size: u64,
434        concurrency: usize,
435        progress: Arc<dyn ProgressReporter>,
436    ) -> Result<()> {
437        let part_size = init.part_size.ok_or_else(|| {
438            Error::Protocol("multipart upload mode without part_size".to_string())
439        })?;
440        if part_size == 0 {
441            return Err(Error::Protocol(
442                "multipart upload part_size must be positive".to_string(),
443            ));
444        }
445        let concurrency = concurrency.max(1);
446
447        let uploaded = Arc::new(AtomicU64::new(0));
448        let batch_size = concurrency.max(32);
449        let context = MultipartUploadContext {
450            file_path: file_path.to_path_buf(),
451            part_size,
452            total_size,
453            uploaded,
454            progress: Arc::clone(&progress),
455        };
456        let parts = self.signed_parts_stream(init.ingestion_id, batch_size);
457        let parts = Arc::new(Mutex::new(Box::pin(parts)));
458
459        let workers = (0..concurrency).map(|_| {
460            let context = context.clone();
461            let parts = Arc::clone(&parts);
462            async move {
463                let mut file = tokio::fs::File::open(&context.file_path).await?;
464                loop {
465                    let part = {
466                        let mut parts = parts.lock().await;
467                        parts.next().await.transpose()?
468                    };
469                    let Some(part) = part else {
470                        return Ok::<_, Error>(());
471                    };
472
473                    self.put_part(&mut file, &context, part).await?;
474                }
475            }
476        });
477        futures_util::future::try_join_all(workers).await?;
478
479        progress.finish();
480        Ok(())
481    }
482
483    /// Uploads a file to a stream and returns the created dataset.
484    #[tracing::instrument(skip_all, fields(stream_id, path = %file_path.as_ref().display()))]
485    pub async fn push_file(
486        &self,
487        stream_id: i32,
488        file_path: impl AsRef<Path>,
489        options: PushFileOptions,
490    ) -> Result<Dataset> {
491        let file_path = file_path.as_ref();
492        let file_name = file_path.file_name().unwrap().to_string_lossy().to_string();
493        let total_size = tokio::fs::metadata(file_path).await?.len();
494
495        let init = self
496            .init_ingestion(stream_id, &file_name, total_size, &options.metadata)
497            .await?;
498        let progress = Arc::clone(&options.progress);
499
500        let upload_result = async {
501            match (options.upload_mode, &init.mode) {
502                (UploadModeOverride::Server, _) | (_, UploadMode::Server) => {
503                    tracing::debug!(ingestion_id = init.ingestion_id, "uploading via server");
504                    self.upload_via_server(
505                        &init,
506                        file_path,
507                        &file_name,
508                        total_size,
509                        Arc::clone(&progress),
510                    )
511                    .await?;
512                }
513                (_, UploadMode::Azure) => {
514                    tracing::debug!(
515                        ingestion_id = init.ingestion_id,
516                        "uploading via Azure blocks"
517                    );
518                    self.upload_via_azure(
519                        &init,
520                        file_path,
521                        total_size,
522                        options.concurrency,
523                        Arc::clone(&progress),
524                    )
525                    .await?;
526                }
527                (_, UploadMode::Single) => {
528                    tracing::debug!(ingestion_id = init.ingestion_id, "uploading via single PUT");
529                    self.upload_via_single(&init, file_path, total_size, Arc::clone(&progress))
530                        .await?;
531                }
532                (_, UploadMode::Multipart) => {
533                    tracing::debug!(ingestion_id = init.ingestion_id, "uploading via multipart");
534                    self.upload_via_multipart(
535                        &init,
536                        file_path,
537                        total_size,
538                        options.concurrency,
539                        Arc::clone(&progress),
540                    )
541                    .await?;
542                }
543            }
544            self.complete_upload(init.ingestion_id).await?;
545            self.get_dataset(stream_id, init.dataset_id).await
546        }
547        .await;
548
549        match upload_result {
550            Ok(dataset) => Ok(dataset),
551            Err(e) => {
552                let _ = self
553                    .abort_upload(init.ingestion_id, &format!("{:#}", e))
554                    .await;
555                Err(e)
556            }
557        }
558    }
559}