// nominal_streaming/upload.rs

1use std::io::Read;
2use std::io::Seek;
3use std::path::Path;
4use std::path::PathBuf;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use conjure_error::Error;
9use conjure_http::client::AsyncWriteBody;
10use conjure_http::private::Stream;
11use conjure_object::BearerToken;
12use conjure_object::ResourceIdentifier;
13use conjure_object::SafeLong;
14use conjure_runtime_rustls_platform_verifier::BodyWriter;
15use conjure_runtime_rustls_platform_verifier::Client;
16use futures::StreamExt;
17use nominal_api::clients::ingest::api::AsyncIngestService;
18use nominal_api::clients::ingest::api::AsyncIngestServiceClient;
19use nominal_api::clients::upload::api::AsyncUploadService;
20use nominal_api::clients::upload::api::AsyncUploadServiceClient;
21use nominal_api::objects::api::rids::WorkspaceRid;
22use nominal_api::objects::ingest::api::AvroStreamOpts;
23use nominal_api::objects::ingest::api::CompleteMultipartUploadResponse;
24use nominal_api::objects::ingest::api::DatasetIngestTarget;
25use nominal_api::objects::ingest::api::ExistingDatasetIngestDestination;
26use nominal_api::objects::ingest::api::IngestOptions;
27use nominal_api::objects::ingest::api::IngestRequest;
28use nominal_api::objects::ingest::api::IngestResponse;
29use nominal_api::objects::ingest::api::IngestSource;
30use nominal_api::objects::ingest::api::InitiateMultipartUploadRequest;
31use nominal_api::objects::ingest::api::InitiateMultipartUploadResponse;
32use nominal_api::objects::ingest::api::Part;
33use nominal_api::objects::ingest::api::S3IngestSource;
34use tokio::sync::Semaphore;
35use tracing::error;
36use tracing::info;
37
38use crate::client::NominalApiClients;
39use crate::types::AuthProvider;
40
/// Files strictly smaller than this are sent with a single `upload_file`
/// request; anything at or above goes through the multipart-upload path.
const SMALL_FILE_SIZE_LIMIT: u64 = 512 * 1024 * 1024; // 512 MB
42
/// Drains a queue of Avro file paths, uploading each file to object storage
/// and triggering an ingest job; successfully ingested files are removed
/// from disk by the background loop.
#[derive(Clone)]
pub struct AvroIngestManager {
    // Receiving side of the channel that producers push finished files onto.
    pub upload_queue: async_channel::Receiver<PathBuf>,
}
47
48impl AvroIngestManager {
49    pub fn new(
50        clients: NominalApiClients,
51        http_client: reqwest::Client,
52        handle: tokio::runtime::Handle,
53        opts: UploaderOpts,
54        upload_queue: async_channel::Receiver<PathBuf>,
55        auth_provider: impl AuthProvider + 'static,
56        data_source_rid: ResourceIdentifier,
57    ) -> Self {
58        let uploader = FileObjectStoreUploader::new(
59            clients.upload,
60            clients.ingest,
61            http_client,
62            handle.clone(),
63            opts,
64        );
65
66        let upload_queue_clone = upload_queue.clone();
67
68        handle.spawn(async move {
69            Self::run(upload_queue_clone, uploader, auth_provider, data_source_rid).await;
70        });
71
72        AvroIngestManager { upload_queue }
73    }
74
75    pub async fn run(
76        upload_queue: async_channel::Receiver<PathBuf>,
77        uploader: FileObjectStoreUploader,
78        auth_provider: impl AuthProvider + 'static,
79        data_source_rid: ResourceIdentifier,
80    ) {
81        while let Ok(file_path) = upload_queue.recv().await {
82            let file_name = file_path.to_str().unwrap_or("nmstream_file");
83            let file = std::fs::File::open(&file_path);
84            let Some(token) = auth_provider.token() else {
85                error!("Missing token for upload");
86                continue;
87            };
88            match file {
89                Ok(f) => {
90                    match upload_and_ingest_file(
91                        uploader.clone(),
92                        &token,
93                        auth_provider.workspace_rid(),
94                        f,
95                        file_name,
96                        &file_path,
97                        data_source_rid.clone(),
98                    )
99                    .await
100                    {
101                        Ok(()) => {}
102                        Err(e) => {
103                            error!(
104                                "Error uploading and ingesting file {}: {}",
105                                file_path.display(),
106                                e
107                            );
108                        }
109                    }
110                }
111                Err(e) => {
112                    error!("Failed to open file {}: {:?}", file_path.display(), e);
113                }
114            }
115        }
116    }
117}
118
119async fn upload_and_ingest_file(
120    uploader: FileObjectStoreUploader,
121    token: &BearerToken,
122    workspace_rid: Option<WorkspaceRid>,
123    file: std::fs::File,
124    file_name: &str,
125    file_path: &PathBuf,
126    data_source_rid: ResourceIdentifier,
127) -> Result<(), String> {
128    match uploader.upload(token, file, file_name, workspace_rid).await {
129        Ok(response) => {
130            match uploader
131                .ingest_avro(token, &response, data_source_rid)
132                .await
133            {
134                Ok(ingest_response) => {
135                    info!(
136                        "Successfully uploaded and ingested file {}: {:?}",
137                        file_name, ingest_response
138                    );
139                    if let Err(e) = std::fs::remove_file(file_path) {
140                        Err(format!(
141                            "Failed to remove file {}: {:?}",
142                            file_path.display(),
143                            e
144                        ))
145                    } else {
146                        info!("Removed file {}", file_path.display());
147                        Ok(())
148                    }
149                }
150                Err(e) => Err(format!("Failed to ingest file {file_name}: {e:?}")),
151            }
152        }
153        Err(e) => Err(format!("Failed to upload file {file_name}: {e:?}")),
154    }
155}
156
157#[derive(Debug, thiserror::Error)]
158pub enum UploaderError {
159    #[error("Conjure error: {0}")]
160    Conjure(String),
161    #[error("Failed to initiate multipart upload: {0}")]
162    IOError(#[from] std::io::Error),
163    #[error("Failed to upload part: {0}")]
164    HTTPError(#[from] reqwest::Error),
165    #[error("Error executing upload tasks: {0}")]
166    TokioError(#[from] tokio::task::JoinError),
167    #[error("Error: {0}")]
168    Other(String),
169}
170
/// Tuning knobs for multipart uploads.
#[derive(Debug, Clone)]
pub struct UploaderOpts {
    // Size in bytes of each multipart chunk.
    pub chunk_size: usize,
    // Upper bound on total attempts for a single part upload
    // (the retry loop in `upload_part` stops once this many attempts fail).
    pub max_retries: usize,
    // Maximum number of part uploads in flight at once.
    pub max_concurrent_uploads: usize,
}
177
178impl Default for UploaderOpts {
179    fn default() -> Self {
180        UploaderOpts {
181            chunk_size: 512 * 1024 * 1024, // 512 MB
182            max_retries: 3,
183            max_concurrent_uploads: 1,
184        }
185    }
186}
187
/// Adapter exposing a `std::fs::File` as a conjure `AsyncWriteBody`, so a
/// whole file can be sent as a single-shot request body.
pub struct FileWriteBody {
    // Handle to the file whose contents form the request body.
    file: std::fs::File,
}
191
192impl FileWriteBody {
193    pub fn new(file: std::fs::File) -> Self {
194        FileWriteBody { file }
195    }
196}
197
impl AsyncWriteBody<BodyWriter> for FileWriteBody {
    /// Reads the file from its current cursor position to EOF into memory,
    /// then writes all the bytes to the request body in one call.
    ///
    /// NOTE(review): this buffers the entire file in memory; callers only
    /// route files under `SMALL_FILE_SIZE_LIMIT` here, but nothing in this
    /// impl enforces that — confirm before reusing elsewhere.
    async fn write_body(self: Pin<&mut Self>, w: Pin<&mut BodyWriter>) -> Result<(), Error> {
        // `try_clone` duplicates the OS handle; per std's documentation the
        // clone shares the same underlying cursor as `self.file`.
        let mut file = self
            .file
            .try_clone()
            .map_err(|e| Error::internal_safe(format!("Failed to clone file for upload: {e}")))?;

        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)
            .map_err(|e| Error::internal_safe(format!("Failed to read bytes from file: {e}")))?;

        w.write_bytes(buffer.into())
            .await
            .map_err(|e| Error::internal_safe(format!("Failed to write bytes to body: {e}")))?;

        Ok(())
    }

    /// Rewinds the body so a retry can resend it. Because the clone shares
    /// the original handle's cursor, seeking the clone to 0 also rewinds
    /// `self.file`, and the next `write_body` re-reads from the start.
    async fn reset(self: Pin<&mut Self>) -> bool {
        let Ok(mut file) = self.file.try_clone() else {
            return false;
        };

        use std::io::SeekFrom;

        file.seek(SeekFrom::Start(0)).is_ok()
    }
}
226
/// Uploads files to object storage (single-shot or multipart) and triggers
/// ingest jobs against the Nominal API.
#[derive(Clone)]
pub struct FileObjectStoreUploader {
    // Conjure client for initiating/signing/completing multipart uploads.
    upload_client: AsyncUploadServiceClient<Client>,
    // Conjure client for kicking off ingest jobs.
    ingest_client: AsyncIngestServiceClient<Client>,
    // Plain HTTP client used to PUT chunks to presigned URLs.
    http_client: reqwest::Client,
    // Runtime handle on which part-upload tasks are spawned.
    handle: tokio::runtime::Handle,
    // Chunking / retry / concurrency settings.
    opts: UploaderOpts,
}
235
236impl FileObjectStoreUploader {
237    pub fn new(
238        upload_client: AsyncUploadServiceClient<Client>,
239        ingest_client: AsyncIngestServiceClient<Client>,
240        http_client: reqwest::Client,
241        handle: tokio::runtime::Handle,
242        opts: UploaderOpts,
243    ) -> Self {
244        FileObjectStoreUploader {
245            upload_client,
246            ingest_client,
247            http_client,
248            handle,
249            opts,
250        }
251    }
252
253    pub async fn initiate_upload(
254        &self,
255        token: &BearerToken,
256        file_name: &str,
257        workspace_rid: Option<WorkspaceRid>,
258    ) -> Result<InitiateMultipartUploadResponse, UploaderError> {
259        let request = InitiateMultipartUploadRequest::builder()
260            .filename(file_name)
261            .filetype("application/octet-stream")
262            .workspace(workspace_rid)
263            .build();
264        let response = self
265            .upload_client
266            .initiate_multipart_upload(token, &request)
267            .await
268            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;
269
270        info!("Initiated multipart upload for file: {}", file_name);
271        Ok(response)
272    }
273
274    #[expect(clippy::too_many_arguments)]
275    async fn upload_part(
276        client: AsyncUploadServiceClient<Client>,
277        http_client: reqwest::Client,
278        token: BearerToken,
279        upload_id: String,
280        key: String,
281        part_number: i32,
282        chunk: Vec<u8>,
283        max_retries: usize,
284    ) -> Result<Part, UploaderError> {
285        let mut attempts = 0;
286
287        loop {
288            attempts += 1;
289            match Self::try_upload_part(
290                client.clone(),
291                http_client.clone(),
292                &token,
293                &upload_id,
294                &key,
295                part_number,
296                chunk.clone(),
297            )
298            .await
299            {
300                Ok(part) => return Ok(part),
301                Err(e) if attempts < max_retries => {
302                    error!("Upload attempt {} failed, retrying: {}", attempts, e);
303                    continue;
304                }
305                Err(e) => {
306                    return Err(e);
307                }
308            }
309        }
310    }
311
312    async fn try_upload_part(
313        client: AsyncUploadServiceClient<Client>,
314        http_client: reqwest::Client,
315        token: &BearerToken,
316        upload_id: &str,
317        key: &str,
318        part_number: i32,
319        chunk: Vec<u8>,
320    ) -> Result<Part, UploaderError> {
321        let response = client
322            .sign_part(token, upload_id, key, part_number)
323            .await
324            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;
325
326        let mut request_builder = http_client.put(response.url()).body(chunk);
327
328        for (header_name, header_value) in response.headers() {
329            request_builder = request_builder.header(header_name, header_value);
330        }
331
332        let http_response = request_builder.send().await?;
333        let headers = http_response.headers().clone();
334        let status = http_response.status();
335
336        if !status.is_success() {
337            error!("Failed to upload body");
338            return Err(UploaderError::Other(format!(
339                "Failed to upload part {part_number}: HTTP status {status}"
340            )));
341        }
342
343        let etag = headers
344            .get("etag")
345            .and_then(|v| v.to_str().ok())
346            .unwrap_or("ignored-etag");
347
348        Ok(Part::new(part_number, etag))
349    }
350
351    pub async fn upload_parts<R>(
352        &self,
353        token: &BearerToken,
354        reader: R,
355        key: &str,
356        upload_id: &str,
357    ) -> Result<CompleteMultipartUploadResponse, UploaderError>
358    where
359        R: Read + Send + 'static,
360    {
361        let chunks = ChunkedStreamReader::new(reader, self.opts.chunk_size);
362
363        let parallel_part_uploads = Arc::new(Semaphore::new(self.opts.max_concurrent_uploads));
364        let mut upload_futures = Vec::new();
365
366        futures::pin_mut!(chunks);
367
368        while let Some(entry) = chunks.next().await {
369            let (index, chunk) = entry?;
370            let part_number = (index + 1) as i32;
371
372            let token = token.clone();
373            let key = key.to_string();
374            let upload_id = upload_id.to_string();
375            let parallel_part_uploads = Arc::clone(&parallel_part_uploads);
376            let client = self.upload_client.clone();
377            let http_client = self.http_client.clone();
378            let max_retries = self.opts.max_retries;
379
380            upload_futures.push(self.handle.spawn(async move {
381                let _permit = parallel_part_uploads.acquire().await;
382                Self::upload_part(
383                    client,
384                    http_client,
385                    token,
386                    upload_id,
387                    key,
388                    part_number,
389                    chunk,
390                    max_retries,
391                )
392                .await
393            }));
394        }
395
396        let mut part_responses = futures::future::join_all(upload_futures)
397            .await
398            .into_iter()
399            .map(|result| result.map_err(UploaderError::TokioError)?)
400            .collect::<Result<Vec<_>, _>>()?;
401
402        part_responses.sort_by_key(|part| part.part_number());
403
404        let response = self
405            .upload_client
406            .complete_multipart_upload(token, upload_id, key, &part_responses)
407            .await
408            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;
409
410        Ok(response)
411    }
412
413    pub async fn upload_small_file(
414        &self,
415        token: &BearerToken,
416        file_name: &str,
417        size_bytes: i64,
418        workspace_rid: Option<WorkspaceRid>,
419        file: std::fs::File,
420    ) -> Result<String, UploaderError> {
421        let s3_path = self
422            .upload_client
423            .upload_file(
424                token,
425                file_name,
426                SafeLong::new(size_bytes).ok(),
427                workspace_rid.as_ref(),
428                FileWriteBody::new(file),
429            )
430            .await
431            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;
432
433        Ok(s3_path.as_str().to_string())
434    }
435
436    pub async fn upload<R>(
437        &self,
438        token: &BearerToken,
439        reader: R,
440        file_name: impl Into<&str>,
441        workspace_rid: Option<WorkspaceRid>,
442    ) -> Result<String, UploaderError>
443    where
444        R: Read + Send + 'static,
445    {
446        let file_name = file_name.into();
447        let path = Path::new(file_name);
448        let file_size = std::fs::metadata(path)?.len();
449        if file_size < SMALL_FILE_SIZE_LIMIT {
450            return self
451                .upload_small_file(
452                    token,
453                    file_name,
454                    file_size as i64,
455                    workspace_rid,
456                    std::fs::File::open(path)?,
457                )
458                .await;
459        }
460
461        let initiate_response = self
462            .initiate_upload(token, file_name, workspace_rid)
463            .await?;
464        let upload_id = initiate_response.upload_id();
465        let key = initiate_response.key();
466
467        let response = self.upload_parts(token, reader, key, upload_id).await?;
468
469        let s3_path = response.location().ok_or_else(|| {
470            UploaderError::Other("Upload response did not contain a location".to_string())
471        })?;
472
473        Ok(s3_path.to_string())
474    }
475
476    pub async fn ingest_avro(
477        &self,
478        token: &BearerToken,
479        s3_path: &str,
480        data_source_rid: ResourceIdentifier,
481    ) -> Result<IngestResponse, UploaderError> {
482        let opts = IngestOptions::AvroStream(
483            AvroStreamOpts::builder()
484                .source(IngestSource::S3(S3IngestSource::new(s3_path)))
485                .target(DatasetIngestTarget::Existing(
486                    ExistingDatasetIngestDestination::new(data_source_rid),
487                ))
488                .build(),
489        );
490
491        let request = IngestRequest::new(opts);
492
493        self.ingest_client
494            .ingest(token, &request)
495            .await
496            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))
497    }
498}
499
/// Adapts a synchronous reader into a `Stream` of `(index, chunk)` pairs,
/// where each chunk holds at most `chunk_size` bytes.
pub struct ChunkedStreamReader {
    // Boxed so any `Read + Send` source can be chunked.
    reader: Box<dyn Read + Send>,
    // Maximum size of each emitted chunk, in bytes.
    chunk_size: usize,
    // Zero-based index of the next chunk to emit.
    current_index: usize,
}
505
506impl ChunkedStreamReader {
507    pub fn new<R>(reader: R, chunk_size: usize) -> Self
508    where
509        R: Read + Send + 'static,
510    {
511        Self {
512            reader: Box::new(reader),
513            chunk_size,
514            current_index: 0,
515        }
516    }
517}
518
519impl Stream for ChunkedStreamReader {
520    type Item = Result<(usize, Vec<u8>), std::io::Error>;
521
522    fn poll_next(
523        mut self: Pin<&mut Self>,
524        _cx: &mut std::task::Context<'_>,
525    ) -> std::task::Poll<Option<Self::Item>> {
526        let mut buffer = vec![0u8; self.chunk_size];
527
528        match self.reader.read(&mut buffer) {
529            Ok(0) => std::task::Poll::Ready(None),
530            Ok(n) => {
531                buffer.truncate(n);
532                let index = self.current_index;
533                self.current_index += 1;
534                std::task::Poll::Ready(Some(Ok((index, buffer))))
535            }
536            Err(e) => std::task::Poll::Ready(Some(Err(e))),
537        }
538    }
539}