1use crate::models::{
2 Dataset, IngestionInit, PartUrl, PartUrlsResponse, PushFileOptions, UploadMode,
3 UploadModeOverride,
4};
5use crate::{Error, MarpleDB, ProgressReporter, Result};
6use base64::Engine;
7use futures_util::StreamExt;
8use reqwest::{Body, Response, header::CONTENT_LENGTH};
9use serde_json::Value;
10use std::io::SeekFrom;
11use std::path::{Path, PathBuf};
12use std::sync::{
13 Arc,
14 atomic::{AtomicU64, Ordering},
15};
16use tokio::io::{AsyncReadExt, AsyncSeekExt};
17use tokio::sync::Mutex;
18use tokio_util::io::ReaderStream;
19
/// Maximum size, in bytes, of each chunk yielded by `progress_reporting_stream` (256 KiB).
/// Progress is reported once per yielded chunk.
const PROGRESS_STREAM_CHUNK: usize = 256 * 1024;
21
/// State shared by the workers of a multipart upload.
///
/// `upload_via_multipart` clones this once per worker; the `Arc` fields are
/// shared between the clones.
#[derive(Clone)]
struct MultipartUploadContext {
    /// File being uploaded; every worker opens its own handle to it.
    file_path: PathBuf,
    /// Size in bytes of each part; the final part may be shorter.
    part_size: u64,
    /// Total size in bytes of the file being uploaded.
    total_size: u64,
    /// Running count of bytes handed to the upload streams, across all workers.
    uploaded: Arc<AtomicU64>,
    /// Receives the running byte count as parts are streamed.
    progress: Arc<dyn ProgressReporter>,
}
30
/// One block of an Azure block upload: a byte range of the source file plus
/// the identifier under which the block is staged.
#[derive(Clone)]
struct BlockDescriptor {
    /// Byte offset of the block within the file.
    offset: u64,
    /// Length of the block in bytes.
    length: u64,
    /// Zero-padded decimal block number; it is base64-encoded before being
    /// placed in the request URL or the block list.
    block_id: String,
}
37
/// State shared by the workers of an Azure block upload.
///
/// `upload_via_azure` clones this once per worker; the `Arc` fields are shared
/// between the clones.
#[derive(Clone)]
struct AzureBlockUploadContext {
    /// Presigned URL of the target blob; per-block query parameters are
    /// appended to a copy of it for each request.
    sas_url: Arc<reqwest::Url>,
    /// File being uploaded; every worker opens its own handle to it.
    file_path: PathBuf,
    /// Running count of bytes handed to the upload streams, across all workers.
    uploaded: Arc<AtomicU64>,
    /// Receives the running byte count as blocks are streamed.
    progress: Arc<dyn ProgressReporter>,
}
45
46fn progress_reporting_stream(
47 data: Vec<u8>,
48 uploaded: Arc<AtomicU64>,
49 progress: Arc<dyn ProgressReporter>,
50) -> impl futures_util::Stream<Item = std::io::Result<Vec<u8>>> + Send + 'static {
51 async_stream::stream! {
52 let mut pos = 0;
53 while pos < data.len() {
54 let end = (pos + PROGRESS_STREAM_CHUNK).min(data.len());
55 let chunk = data[pos..end].to_vec();
56 let chunk_len = chunk.len() as u64;
57 let new_uploaded = uploaded.fetch_add(chunk_len, Ordering::Relaxed) + chunk_len;
58 progress.set_position(new_uploaded);
59 yield Ok(chunk);
60 pos = end;
61 }
62 }
63}
64
65fn azure_block_descriptors(total_size: u64, block_size: u64) -> Vec<BlockDescriptor> {
66 if total_size == 0 {
67 return Vec::new();
68 }
69
70 let n_blocks = total_size.div_ceil(block_size);
71 (0..n_blocks as u32)
72 .map(|block_number| {
73 let offset = u64::from(block_number) * block_size;
74 let length = block_size.min(total_size - offset);
75 BlockDescriptor {
76 offset,
77 length,
78 block_id: format!("{block_number:08}"),
79 }
80 })
81 .collect()
82}
83
84async fn ensure_success(response: Response, failure_message: impl Into<String>) -> Result<()> {
85 if response.status().is_success() {
86 Ok(())
87 } else {
88 let context = failure_message.into();
89 let status = response.status();
90 let body = response.text().await.map_err(|source| Error::Storage {
91 context: context.clone(),
92 status: Some(status),
93 body: None,
94 source: Some(source),
95 })?;
96 Err(Error::Storage {
97 context,
98 status: Some(status),
99 body: Some(body),
100 source: None,
101 })
102 }
103}
104
105async fn send_storage(
106 request: reqwest::RequestBuilder,
107 context: impl Into<String>,
108) -> Result<Response> {
109 let context = context.into();
110 request.send().await.map_err(|source| Error::Storage {
111 context,
112 status: None,
113 body: None,
114 source: Some(source),
115 })
116}
117
118impl MarpleDB {
119 async fn init_ingestion(
120 &self,
121 stream_id: i32,
122 dataset_name: &str,
123 file_size: u64,
124 metadata: &crate::Metadata,
125 ) -> Result<IngestionInit> {
126 let body = serde_json::json!({
127 "stream_id": stream_id,
128 "dataset_name": dataset_name,
129 "file_size": file_size,
130 "metadata": metadata,
131 });
132 self.post_json("ingestion", &body).await
133 }
134
135 async fn get_part_urls(
136 &self,
137 ingestion_id: i32,
138 start_part: u32,
139 count: usize,
140 ) -> Result<PartUrlsResponse> {
141 let endpoint = format!("ingestion/{}/upload/part-urls", ingestion_id);
142 self.get_json(
143 &endpoint,
144 &[("start_part", start_part), ("count", count as u32)],
145 )
146 .await
147 }
148
149 async fn complete_upload(&self, ingestion_id: i32) -> Result<()> {
150 let endpoint = format!("ingestion/{}/upload/complete", ingestion_id);
151 self.post_json::<_, Value>(&endpoint, &serde_json::json!({}))
152 .await?;
153 Ok(())
154 }
155
156 async fn abort_upload(&self, ingestion_id: i32, reason: &str) -> Result<()> {
157 let endpoint = format!("ingestion/{}/abort", ingestion_id);
158 self.post_json::<_, Value>(&endpoint, &serde_json::json!({ "reason": reason }))
159 .await?;
160 Ok(())
161 }
162
163 async fn upload_via_single(
164 &self,
165 init: &IngestionInit,
166 file_path: &Path,
167 total_size: u64,
168 progress: Arc<dyn ProgressReporter>,
169 ) -> Result<()> {
170 let url = init.presigned_url.as_deref().ok_or_else(|| {
171 Error::Protocol("single upload mode without presigned_url".to_string())
172 })?;
173 let file = tokio::fs::File::open(file_path).await?;
174 let mut uploaded = 0;
175
176 let mut reader = ReaderStream::new(file);
177 let stream = async_stream::stream! {
178 while let Some(chunk) = reader.next().await {
179 if let Ok(chunk) = &chunk {
180 uploaded += chunk.len() as u64;
181 progress.set_position(uploaded);
182 }
183 yield chunk;
184 }
185 progress.finish();
186 };
187
188 let response = send_storage(
189 self.storage_client
190 .put(url)
191 .header(CONTENT_LENGTH, total_size)
192 .body(Body::wrap_stream(stream)),
193 "storage PUT failed",
194 )
195 .await?;
196 ensure_success(response, "storage PUT failed").await?;
197 Ok(())
198 }
199
200 async fn upload_via_server(
201 &self,
202 init: &IngestionInit,
203 file_path: &Path,
204 file_name: &str,
205 total_size: u64,
206 progress: Arc<dyn ProgressReporter>,
207 ) -> Result<()> {
208 let file = tokio::fs::File::open(file_path).await?;
209 let mut uploaded = 0;
210
211 let mut reader = ReaderStream::new(file);
212 let stream = async_stream::stream! {
213 while let Some(chunk) = reader.next().await {
214 if let Ok(chunk) = &chunk {
215 uploaded += chunk.len() as u64;
216 progress.set_position(uploaded);
217 }
218 yield chunk;
219 }
220 progress.finish();
221 };
222
223 let body = Body::wrap_stream(stream);
224 let part = reqwest::multipart::Part::stream_with_length(body, total_size)
225 .file_name(file_name.to_string())
226 .mime_str("application/octet-stream")
227 .map_err(|source| Error::Storage {
228 context: "building multipart upload body".to_string(),
229 status: None,
230 body: None,
231 source: Some(source),
232 })?;
233 let form = reqwest::multipart::Form::new().part("file", part);
234 let endpoint = format!("ingestion/{}/upload/server", init.ingestion_id);
235 self.post_multipart(&endpoint, form).await?;
236 Ok(())
237 }
238
239 async fn put_block(
240 &self,
241 file: &mut tokio::fs::File,
242 context: &AzureBlockUploadContext,
243 descriptor: BlockDescriptor,
244 ) -> Result<()> {
245 file.seek(SeekFrom::Start(descriptor.offset)).await?;
246 let mut data = vec![0; usize::try_from(descriptor.length)?];
247 file.read_exact(&mut data).await?;
248
249 let stream = progress_reporting_stream(
250 data,
251 Arc::clone(&context.uploaded),
252 Arc::clone(&context.progress),
253 );
254
255 let mut block_url = (*context.sas_url).clone();
256 block_url
257 .query_pairs_mut()
258 .append_pair("comp", "block")
259 .append_pair(
260 "blockid",
261 &base64::engine::general_purpose::STANDARD.encode(descriptor.block_id.as_bytes()),
262 );
263
264 let response = send_storage(
265 self.storage_client
266 .put(block_url)
267 .header(CONTENT_LENGTH, descriptor.length)
268 .body(Body::wrap_stream(stream)),
269 format!("Azure block {} upload failed", descriptor.block_id),
270 )
271 .await?;
272 ensure_success(
273 response,
274 format!("Azure block {} upload failed", descriptor.block_id),
275 )
276 .await?;
277 Ok(())
278 }
279
280 async fn upload_via_azure(
281 &self,
282 init: &IngestionInit,
283 file_path: &Path,
284 total_size: u64,
285 concurrency: usize,
286 progress: Arc<dyn ProgressReporter>,
287 ) -> Result<()> {
288 const AZURE_BLOCK_SIZE: u64 = 64 * 1024 * 1024;
289
290 let url = init.presigned_url.as_deref().ok_or_else(|| {
291 Error::Protocol("azure upload mode without presigned_url".to_string())
292 })?;
293 let sas_url: reqwest::Url = url.parse()?;
294
295 let concurrency = concurrency.max(1);
296 let descriptors = azure_block_descriptors(total_size, AZURE_BLOCK_SIZE);
297 let context = AzureBlockUploadContext {
298 sas_url: Arc::new(sas_url.clone()),
299 file_path: file_path.to_path_buf(),
300 uploaded: Arc::new(AtomicU64::new(0)),
301 progress: Arc::clone(&progress),
302 };
303 let cursor = Arc::new(Mutex::new(descriptors.clone().into_iter()));
304
305 let workers = (0..concurrency).map(|_| {
306 let context = context.clone();
307 let cursor = Arc::clone(&cursor);
308 async move {
309 let mut file = tokio::fs::File::open(&context.file_path).await?;
310 loop {
311 let descriptor = {
312 let mut cursor = cursor.lock().await;
313 cursor.next()
314 };
315 let Some(descriptor) = descriptor else {
316 return Ok::<_, Error>(());
317 };
318
319 self.put_block(&mut file, &context, descriptor).await?;
320 }
321 }
322 });
323 futures_util::future::try_join_all(workers).await?;
324
325 self.commit_azure_block_list(&sas_url, &descriptors).await?;
326 progress.finish();
327 Ok(())
328 }
329
330 async fn commit_azure_block_list(
331 &self,
332 sas_url: &reqwest::Url,
333 descriptors: &[BlockDescriptor],
334 ) -> Result<()> {
335 let mut block_list_url = sas_url.clone();
336 block_list_url
337 .query_pairs_mut()
338 .append_pair("comp", "blocklist");
339
340 let mut xml = String::from("<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<BlockList>\n");
341 for descriptor in descriptors {
342 let block_id =
343 base64::engine::general_purpose::STANDARD.encode(descriptor.block_id.as_bytes());
344 xml.push_str("\t<Uncommitted>");
345 xml.push_str(&block_id);
346 xml.push_str("</Uncommitted>\n");
347 }
348 xml.push_str("</BlockList>");
349 let date = httpdate::fmt_http_date(std::time::SystemTime::now());
350
351 let response = send_storage(
352 self.storage_client
353 .put(block_list_url)
354 .header(reqwest::header::CONTENT_TYPE, "application/xml")
355 .header(reqwest::header::CONTENT_LENGTH, xml.len())
356 .header("x-ms-date", date)
357 .header("x-ms-version", "2022-11-02")
358 .body(xml),
359 "Azure block list commit failed",
360 )
361 .await?;
362 ensure_success(response, "Azure block list commit failed").await
363 }
364
365 async fn put_part(
366 &self,
367 file: &mut tokio::fs::File,
368 context: &MultipartUploadContext,
369 part: PartUrl,
370 ) -> Result<()> {
371 let offset = u64::from(part.part_number - 1) * context.part_size;
372 if offset >= context.total_size {
373 return Err(Error::Protocol(format!(
374 "part {} offset is outside the file",
375 part.part_number
376 )));
377 }
378 let part_len = context.part_size.min(context.total_size - offset);
379
380 file.seek(SeekFrom::Start(offset)).await?;
381 let mut data = vec![0; usize::try_from(part_len)?];
382 file.read_exact(&mut data).await?;
383
384 let stream = progress_reporting_stream(
385 data,
386 Arc::clone(&context.uploaded),
387 Arc::clone(&context.progress),
388 );
389
390 let response = send_storage(
391 self.storage_client
392 .put(part.url)
393 .header(CONTENT_LENGTH, part_len)
394 .body(Body::wrap_stream(stream)),
395 format!("part {} storage PUT failed", part.part_number),
396 )
397 .await?;
398 ensure_success(
399 response,
400 format!("part {} storage PUT failed", part.part_number),
401 )
402 .await?;
403 Ok(())
404 }
405
406 fn signed_parts_stream(
407 &self,
408 ingestion_id: i32,
409 batch_size: usize,
410 ) -> impl futures_util::Stream<Item = Result<PartUrl>> + '_ {
411 async_stream::try_stream! {
412 let mut next_part = Some(1);
413
414 while let Some(start_part) = next_part {
415 let urls = self.get_part_urls(ingestion_id, start_part, batch_size).await?;
416 if urls.parts.is_empty() {
417 Err(Error::Protocol("server returned no multipart upload URLs".to_string()))?;
418 }
419
420 for part in urls.parts {
421 yield part;
422 }
423
424 next_part = urls.next_part;
425 }
426 }
427 }
428
429 async fn upload_via_multipart(
430 &self,
431 init: &IngestionInit,
432 file_path: &Path,
433 total_size: u64,
434 concurrency: usize,
435 progress: Arc<dyn ProgressReporter>,
436 ) -> Result<()> {
437 let part_size = init.part_size.ok_or_else(|| {
438 Error::Protocol("multipart upload mode without part_size".to_string())
439 })?;
440 if part_size == 0 {
441 return Err(Error::Protocol(
442 "multipart upload part_size must be positive".to_string(),
443 ));
444 }
445 let concurrency = concurrency.max(1);
446
447 let uploaded = Arc::new(AtomicU64::new(0));
448 let batch_size = concurrency.max(32);
449 let context = MultipartUploadContext {
450 file_path: file_path.to_path_buf(),
451 part_size,
452 total_size,
453 uploaded,
454 progress: Arc::clone(&progress),
455 };
456 let parts = self.signed_parts_stream(init.ingestion_id, batch_size);
457 let parts = Arc::new(Mutex::new(Box::pin(parts)));
458
459 let workers = (0..concurrency).map(|_| {
460 let context = context.clone();
461 let parts = Arc::clone(&parts);
462 async move {
463 let mut file = tokio::fs::File::open(&context.file_path).await?;
464 loop {
465 let part = {
466 let mut parts = parts.lock().await;
467 parts.next().await.transpose()?
468 };
469 let Some(part) = part else {
470 return Ok::<_, Error>(());
471 };
472
473 self.put_part(&mut file, &context, part).await?;
474 }
475 }
476 });
477 futures_util::future::try_join_all(workers).await?;
478
479 progress.finish();
480 Ok(())
481 }
482
483 #[tracing::instrument(skip_all, fields(stream_id, path = %file_path.as_ref().display()))]
485 pub async fn push_file(
486 &self,
487 stream_id: i32,
488 file_path: impl AsRef<Path>,
489 options: PushFileOptions,
490 ) -> Result<Dataset> {
491 let file_path = file_path.as_ref();
492 let file_name = file_path.file_name().unwrap().to_string_lossy().to_string();
493 let total_size = tokio::fs::metadata(file_path).await?.len();
494
495 let init = self
496 .init_ingestion(stream_id, &file_name, total_size, &options.metadata)
497 .await?;
498 let progress = Arc::clone(&options.progress);
499
500 let upload_result = async {
501 match (options.upload_mode, &init.mode) {
502 (UploadModeOverride::Server, _) | (_, UploadMode::Server) => {
503 tracing::debug!(ingestion_id = init.ingestion_id, "uploading via server");
504 self.upload_via_server(
505 &init,
506 file_path,
507 &file_name,
508 total_size,
509 Arc::clone(&progress),
510 )
511 .await?;
512 }
513 (_, UploadMode::Azure) => {
514 tracing::debug!(
515 ingestion_id = init.ingestion_id,
516 "uploading via Azure blocks"
517 );
518 self.upload_via_azure(
519 &init,
520 file_path,
521 total_size,
522 options.concurrency,
523 Arc::clone(&progress),
524 )
525 .await?;
526 }
527 (_, UploadMode::Single) => {
528 tracing::debug!(ingestion_id = init.ingestion_id, "uploading via single PUT");
529 self.upload_via_single(&init, file_path, total_size, Arc::clone(&progress))
530 .await?;
531 }
532 (_, UploadMode::Multipart) => {
533 tracing::debug!(ingestion_id = init.ingestion_id, "uploading via multipart");
534 self.upload_via_multipart(
535 &init,
536 file_path,
537 total_size,
538 options.concurrency,
539 Arc::clone(&progress),
540 )
541 .await?;
542 }
543 }
544 self.complete_upload(init.ingestion_id).await?;
545 self.get_dataset(stream_id, init.dataset_id).await
546 }
547 .await;
548
549 match upload_result {
550 Ok(dataset) => Ok(dataset),
551 Err(e) => {
552 let _ = self
553 .abort_upload(init.ingestion_id, &format!("{:#}", e))
554 .await;
555 Err(e)
556 }
557 }
558 }
559}