nominal_streaming/upload.rs

use std::io::Read;
use std::io::Seek;
use std::path::Path;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;

use conjure_error::Error;
use conjure_http::client::AsyncWriteBody;
use conjure_http::private::Stream;
use conjure_object::BearerToken;
use conjure_object::ResourceIdentifier;
use conjure_object::SafeLong;
use conjure_runtime_rustls_platform_verifier::conjure_runtime::BodyWriter;
use conjure_runtime_rustls_platform_verifier::PlatformVerifierClient;
use futures::StreamExt;
use nominal_api::api::rids::WorkspaceRid;
use nominal_api::ingest::api::AvroStreamOpts;
use nominal_api::ingest::api::CompleteMultipartUploadResponse;
use nominal_api::ingest::api::DatasetIngestTarget;
use nominal_api::ingest::api::ExistingDatasetIngestDestination;
use nominal_api::ingest::api::IngestOptions;
use nominal_api::ingest::api::IngestRequest;
use nominal_api::ingest::api::IngestResponse;
use nominal_api::ingest::api::IngestServiceAsyncClient;
use nominal_api::ingest::api::IngestSource;
use nominal_api::ingest::api::InitiateMultipartUploadRequest;
use nominal_api::ingest::api::InitiateMultipartUploadResponse;
use nominal_api::ingest::api::Part;
use nominal_api::ingest::api::S3IngestSource;
use nominal_api::upload::api::UploadServiceAsyncClient;
use tokio::sync::Semaphore;
use tracing::error;
use tracing::info;

use crate::client::NominalApiClients;
use crate::types::AuthProvider;

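// Files below this threshold take the single-request upload path; larger
// files are uploaded in parts via the multipart flow.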
const SMALL_FILE_SIZE_LIMIT: u64 = 512 * 1024 * 1024; // 512 MiB

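/// Consumes file paths from `upload_queue`, uploading each file and
/// triggering Avro ingestion into the target dataset.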
#[derive(Clone)]
pub struct AvroIngestManager {
    pub upload_queue: async_channel::Receiver<PathBuf>,
}

impl AvroIngestManager {
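    /// Builds the uploader and spawns the queue-processing loop on `handle`.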
    pub fn new(
        clients: NominalApiClients,
        http_client: reqwest::Client,
        handle: tokio::runtime::Handle,
        opts: UploaderOpts,
        upload_queue: async_channel::Receiver<PathBuf>,
        auth_provider: impl AuthProvider + 'static,
        data_source_rid: ResourceIdentifier,
    ) -> Self {
        let uploader = FileObjectStoreUploader::new(
            clients.upload,
            clients.ingest,
            http_client,
            handle.clone(),
            opts,
        );

        let upload_queue_clone = upload_queue.clone();

        handle.spawn(async move {
            Self::run(upload_queue_clone, uploader, auth_provider, data_source_rid).await;
        });

        AvroIngestManager { upload_queue }
    }

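    /// Processes queued paths until the channel closes. Failures to open,
    /// upload, or ingest a file are logged and the loop moves on, so one bad
    /// file does not stall the queue.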
    pub async fn run(
        upload_queue: async_channel::Receiver<PathBuf>,
        uploader: FileObjectStoreUploader,
        auth_provider: impl AuthProvider + 'static,
        data_source_rid: ResourceIdentifier,
    ) {
        while let Ok(file_path) = upload_queue.recv().await {
            // The full path doubles as the uploaded file's name; `upload`
            // also reads the file's size from this path.
            let file_name = file_path.to_str().unwrap_or("nmstream_file");
            let file = std::fs::File::open(&file_path);
            let Some(token) = auth_provider.token() else {
                error!("Missing token for upload");
                continue;
            };
            match file {
                Ok(f) => {
                    match upload_and_ingest_file(
                        uploader.clone(),
                        &token,
                        auth_provider.workspace_rid(),
                        f,
                        file_name,
                        &file_path,
                        data_source_rid.clone(),
                    )
                    .await
                    {
                        Ok(()) => {}
                        Err(e) => {
                            error!(
                                "Error uploading and ingesting file {}: {}",
                                file_path.display(),
                                e
                            );
                        }
                    }
                }
                Err(e) => {
                    error!("Failed to open file {}: {:?}", file_path.display(), e);
                }
            }
        }
    }
}

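/// Uploads one file, requests Avro ingestion of the uploaded object, and
/// deletes the local file once ingestion has been accepted.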
async fn upload_and_ingest_file(
    uploader: FileObjectStoreUploader,
    token: &BearerToken,
    workspace_rid: Option<WorkspaceRid>,
    file: std::fs::File,
    file_name: &str,
    file_path: &Path,
    data_source_rid: ResourceIdentifier,
) -> Result<(), String> {
    match uploader.upload(token, file, file_name, workspace_rid).await {
        Ok(response) => {
            match uploader
                .ingest_avro(token, &response, data_source_rid)
                .await
            {
                Ok(ingest_response) => {
                    info!(
                        "Successfully uploaded and ingested file {}: {:?}",
                        file_name, ingest_response
                    );
                    if let Err(e) = std::fs::remove_file(file_path) {
                        Err(format!(
                            "Failed to remove file {}: {:?}",
                            file_path.display(),
                            e
                        ))
                    } else {
                        info!("Removed file {}", file_path.display());
                        Ok(())
                    }
                }
                Err(e) => Err(format!("Failed to ingest file {file_name}: {e:?}")),
            }
        }
        Err(e) => Err(format!("Failed to upload file {file_name}: {e:?}")),
    }
}

#[derive(Debug, thiserror::Error)]
pub enum UploaderError {
    #[error("Conjure error: {0}")]
    Conjure(String),
    #[error("I/O error: {0}")]
    IOError(#[from] std::io::Error),
    #[error("HTTP error: {0}")]
    HTTPError(#[from] reqwest::Error),
    #[error("Error executing upload tasks: {0}")]
    TokioError(#[from] tokio::task::JoinError),
    #[error("Error: {0}")]
    Other(String),
}

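/// Tuning parameters for multipart uploads: the part size in bytes, the
/// maximum number of attempts per part, and how many parts may be uploaded
/// concurrently.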
#[derive(Debug, Clone)]
pub struct UploaderOpts {
    pub chunk_size: usize,
    pub max_retries: usize,
    pub max_concurrent_uploads: usize,
}

impl Default for UploaderOpts {
    fn default() -> Self {
        UploaderOpts {
            chunk_size: 512 * 1024 * 1024, // 512 MiB
            max_retries: 3,
            max_concurrent_uploads: 1,
        }
    }
}

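/// Wraps a `std::fs::File` so it can be sent as a request body on the
/// single-request upload path.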
pub struct FileWriteBody {
    file: std::fs::File,
}

impl FileWriteBody {
    pub fn new(file: std::fs::File) -> Self {
        FileWriteBody { file }
    }
}

impl AsyncWriteBody<BodyWriter> for FileWriteBody {
    async fn write_body(self: Pin<&mut Self>, w: Pin<&mut BodyWriter>) -> Result<(), Error> {
        let mut file = self
            .file
            .try_clone()
            .map_err(|e| Error::internal_safe(format!("Failed to clone file for upload: {e}")))?;

        // Buffers the whole file in memory; within this module this path is
        // only used for files under SMALL_FILE_SIZE_LIMIT.
        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)
            .map_err(|e| Error::internal_safe(format!("Failed to read bytes from file: {e}")))?;

        w.write_bytes(buffer.into())
            .await
            .map_err(|e| Error::internal_safe(format!("Failed to write bytes to body: {e}")))?;

        Ok(())
    }

    async fn reset(self: Pin<&mut Self>) -> bool {
        let Ok(mut file) = self.file.try_clone() else {
            return false;
        };

        use std::io::SeekFrom;

        // The clone shares the underlying cursor with `self.file`, so this
        // rewinds the original handle as well.
        file.seek(SeekFrom::Start(0)).is_ok()
    }
}

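/// Uploads files to object storage through the Nominal upload service and
/// triggers ingestion through the ingest service.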
#[derive(Clone)]
pub struct FileObjectStoreUploader {
    upload_client: UploadServiceAsyncClient<PlatformVerifierClient>,
    ingest_client: IngestServiceAsyncClient<PlatformVerifierClient>,
    http_client: reqwest::Client,
    handle: tokio::runtime::Handle,
    opts: UploaderOpts,
}

impl FileObjectStoreUploader {
    pub fn new(
        upload_client: UploadServiceAsyncClient<PlatformVerifierClient>,
        ingest_client: IngestServiceAsyncClient<PlatformVerifierClient>,
        http_client: reqwest::Client,
        handle: tokio::runtime::Handle,
        opts: UploaderOpts,
    ) -> Self {
        FileObjectStoreUploader {
            upload_client,
            ingest_client,
            http_client,
            handle,
            opts,
        }
    }

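    /// Asks the upload service to open a multipart upload, returning the
    /// upload ID and object key used by the per-part requests.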
    pub async fn initiate_upload(
        &self,
        token: &BearerToken,
        file_name: &str,
        workspace_rid: Option<WorkspaceRid>,
    ) -> Result<InitiateMultipartUploadResponse, UploaderError> {
        let request = InitiateMultipartUploadRequest::builder()
            .filename(file_name)
            .filetype("application/octet-stream")
            .workspace(workspace_rid)
            .build();
        let response = self
            .upload_client
            .initiate_multipart_upload(token, &request)
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;

        info!("Initiated multipart upload for file: {}", file_name);
        Ok(response)
    }

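    /// Uploads a single part, making up to `max_retries` attempts in total.
    /// Retries are immediate; there is no backoff between attempts.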
    #[expect(clippy::too_many_arguments)]
    async fn upload_part(
        client: UploadServiceAsyncClient<PlatformVerifierClient>,
        http_client: reqwest::Client,
        token: BearerToken,
        upload_id: String,
        key: String,
        part_number: i32,
        chunk: Vec<u8>,
        max_retries: usize,
    ) -> Result<Part, UploaderError> {
        let mut attempts = 0;

        loop {
            attempts += 1;
            match Self::try_upload_part(
                client.clone(),
                http_client.clone(),
                &token,
                &upload_id,
                &key,
                part_number,
                chunk.clone(),
            )
            .await
            {
                Ok(part) => return Ok(part),
                Err(e) if attempts < max_retries => {
                    error!("Upload attempt {} failed, retrying: {}", attempts, e);
                    continue;
                }
                Err(e) => {
                    return Err(e);
                }
            }
        }
    }

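    /// Performs one upload attempt: requests a presigned URL for the part,
    /// PUTs the chunk to it, and builds a `Part` from the returned ETag.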
    async fn try_upload_part(
        client: UploadServiceAsyncClient<PlatformVerifierClient>,
        http_client: reqwest::Client,
        token: &BearerToken,
        upload_id: &str,
        key: &str,
        part_number: i32,
        chunk: Vec<u8>,
    ) -> Result<Part, UploaderError> {
        let response = client
            .sign_part(token, upload_id, key, part_number)
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;

        let mut request_builder = http_client.put(response.url()).body(chunk);

        for (header_name, header_value) in response.headers() {
            request_builder = request_builder.header(header_name, header_value);
        }

        let http_response = request_builder.send().await?;
        let headers = http_response.headers().clone();
        let status = http_response.status();

        if !status.is_success() {
            error!("Failed to upload part {part_number}: HTTP status {status}");
            return Err(UploaderError::Other(format!(
                "Failed to upload part {part_number}: HTTP status {status}"
            )));
        }

        // Fall back to a placeholder when the response carries no usable
        // ETag header.
        let etag = headers
            .get("etag")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("ignored-etag");

        Ok(Part::new(part_number, etag))
    }

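    /// Reads the input in `chunk_size` pieces and uploads each one as a
    /// numbered part, with concurrency bounded by `max_concurrent_uploads`.
    /// Note that chunks are read and tasks spawned eagerly, so parts waiting
    /// on a permit are held in memory.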
    pub async fn upload_parts<R>(
        &self,
        token: &BearerToken,
        reader: R,
        key: &str,
        upload_id: &str,
    ) -> Result<CompleteMultipartUploadResponse, UploaderError>
    where
        R: Read + Send + 'static,
    {
        let chunks = ChunkedStreamReader::new(reader, self.opts.chunk_size);

        let parallel_part_uploads = Arc::new(Semaphore::new(self.opts.max_concurrent_uploads));
        let mut upload_futures = Vec::new();

        futures::pin_mut!(chunks);

        while let Some(entry) = chunks.next().await {
            let (index, chunk) = entry?;
            let part_number = (index + 1) as i32;

            let token = token.clone();
            let key = key.to_string();
            let upload_id = upload_id.to_string();
            let parallel_part_uploads = Arc::clone(&parallel_part_uploads);
            let client = self.upload_client.clone();
            let http_client = self.http_client.clone();
            let max_retries = self.opts.max_retries;

            upload_futures.push(self.handle.spawn(async move {
                // Hold a permit for the duration of the part upload to bound
                // concurrency.
                let _permit = parallel_part_uploads.acquire().await;
                Self::upload_part(
                    client,
                    http_client,
                    token,
                    upload_id,
                    key,
                    part_number,
                    chunk,
                    max_retries,
                )
                .await
            }));
        }

        let mut part_responses = futures::future::join_all(upload_futures)
            .await
            .into_iter()
            .map(|result| result.map_err(UploaderError::TokioError)?)
            .collect::<Result<Vec<_>, _>>()?;

        part_responses.sort_by_key(|part| part.part_number());

        let response = self
            .upload_client
            .complete_multipart_upload(token, upload_id, key, &part_responses)
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;

        Ok(response)
    }

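    /// Uploads a file in a single request, returning the object-store path
    /// of the uploaded object.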
    pub async fn upload_small_file(
        &self,
        token: &BearerToken,
        file_name: &str,
        size_bytes: i64,
        workspace_rid: Option<WorkspaceRid>,
        file: std::fs::File,
    ) -> Result<String, UploaderError> {
        let s3_path = self
            .upload_client
            .upload_file(
                token,
                file_name,
                SafeLong::new(size_bytes).ok(),
                workspace_rid.as_ref(),
                FileWriteBody::new(file),
            )
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;

        Ok(s3_path.as_str().to_string())
    }

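    /// Uploads `file_name`, which must be a readable path on the local
    /// filesystem: files under `SMALL_FILE_SIZE_LIMIT` take the
    /// single-request path, larger files the multipart path. Returns the
    /// object-store path of the uploaded object.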
    pub async fn upload<R>(
        &self,
        token: &BearerToken,
        reader: R,
        file_name: &str,
        workspace_rid: Option<WorkspaceRid>,
    ) -> Result<String, UploaderError>
    where
        R: Read + Send + 'static,
    {
        let path = Path::new(file_name);
        let file_size = std::fs::metadata(path)?.len();
        if file_size < SMALL_FILE_SIZE_LIMIT {
            // The small-file path reopens the file from `path`; `reader` is
            // only consumed by the multipart path below.
            return self
                .upload_small_file(
                    token,
                    file_name,
                    file_size as i64,
                    workspace_rid,
                    std::fs::File::open(path)?,
                )
                .await;
        }

        let initiate_response = self
            .initiate_upload(token, file_name, workspace_rid)
            .await?;
        let upload_id = initiate_response.upload_id();
        let key = initiate_response.key();

        let response = self.upload_parts(token, reader, key, upload_id).await?;

        let s3_path = response.location().ok_or_else(|| {
            UploaderError::Other("Upload response did not contain a location".to_string())
        })?;

        Ok(s3_path.to_string())
    }

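    /// Requests ingestion of the uploaded Avro stream at `s3_path` into the
    /// existing dataset identified by `data_source_rid`.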
    pub async fn ingest_avro(
        &self,
        token: &BearerToken,
        s3_path: &str,
        data_source_rid: ResourceIdentifier,
    ) -> Result<IngestResponse, UploaderError> {
        let opts = IngestOptions::AvroStream(
            AvroStreamOpts::builder()
                .source(IngestSource::S3(S3IngestSource::new(s3_path)))
                .target(DatasetIngestTarget::Existing(
                    ExistingDatasetIngestDestination::new(data_source_rid),
                ))
                .build(),
        );

        let request = IngestRequest::new(opts);

        self.ingest_client
            .ingest(token, &request)
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))
    }
}

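/// Adapts a blocking `Read` into a `Stream` of fixed-size chunks, yielding
/// `(chunk_index, bytes)` pairs. Reads run synchronously inside `poll_next`,
/// so polling this stream blocks the current thread.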
pub struct ChunkedStreamReader {
    reader: Box<dyn Read + Send>,
    chunk_size: usize,
    current_index: usize,
}

impl ChunkedStreamReader {
    pub fn new<R>(reader: R, chunk_size: usize) -> Self
    where
        R: Read + Send + 'static,
    {
        Self {
            reader: Box::new(reader),
            chunk_size,
            current_index: 0,
        }
    }
}

impl Stream for ChunkedStreamReader {
    type Item = Result<(usize, Vec<u8>), std::io::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut buffer = vec![0u8; self.chunk_size];
        let mut filled = 0;

        // `read` may return fewer bytes than requested, so keep reading until
        // the chunk is full or the reader hits EOF; only the final chunk may
        // be short.
        while filled < self.chunk_size {
            match self.reader.read(&mut buffer[filled..]) {
                Ok(0) => break,
                Ok(n) => filled += n,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return std::task::Poll::Ready(Some(Err(e))),
            }
        }

        if filled == 0 {
            return std::task::Poll::Ready(None);
        }

        buffer.truncate(filled);
        let index = self.current_index;
        self.current_index += 1;
        std::task::Poll::Ready(Some(Ok((index, buffer))))
    }
}