1use std::io::Read;
2use std::io::Seek;
3use std::path::Path;
4use std::path::PathBuf;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use conjure_error::Error;
9use conjure_http::client::AsyncWriteBody;
10use conjure_http::private::Stream;
11use conjure_object::BearerToken;
12use conjure_object::ResourceIdentifier;
13use conjure_object::SafeLong;
14use conjure_runtime_rustls_platform_verifier::BodyWriter;
15use conjure_runtime_rustls_platform_verifier::Client;
16use futures::StreamExt;
17use nominal_api::clients::ingest::api::AsyncIngestService;
18use nominal_api::clients::ingest::api::AsyncIngestServiceClient;
19use nominal_api::clients::upload::api::AsyncUploadService;
20use nominal_api::clients::upload::api::AsyncUploadServiceClient;
21use nominal_api::objects::api::rids::WorkspaceRid;
22use nominal_api::objects::ingest::api::AvroStreamOpts;
23use nominal_api::objects::ingest::api::CompleteMultipartUploadResponse;
24use nominal_api::objects::ingest::api::DatasetIngestTarget;
25use nominal_api::objects::ingest::api::ExistingDatasetIngestDestination;
26use nominal_api::objects::ingest::api::IngestOptions;
27use nominal_api::objects::ingest::api::IngestRequest;
28use nominal_api::objects::ingest::api::IngestResponse;
29use nominal_api::objects::ingest::api::IngestSource;
30use nominal_api::objects::ingest::api::InitiateMultipartUploadRequest;
31use nominal_api::objects::ingest::api::InitiateMultipartUploadResponse;
32use nominal_api::objects::ingest::api::Part;
33use nominal_api::objects::ingest::api::S3IngestSource;
34use tokio::sync::Semaphore;
35use tracing::error;
36use tracing::info;
37
38use crate::client::NominalApiClients;
39use crate::types::AuthProvider;
40
41const SMALL_FILE_SIZE_LIMIT: u64 = 512 * 1024 * 1024; #[derive(Clone)]
44pub struct AvroIngestManager {
45 pub upload_queue: async_channel::Receiver<PathBuf>,
46}
47
impl AvroIngestManager {
    /// Builds the uploader and spawns the background upload/ingest loop on
    /// `handle`. The returned manager retains a clone of `upload_queue`; the
    /// spawned task consumes the same channel.
    pub fn new(
        clients: NominalApiClients,
        http_client: reqwest::Client,
        handle: tokio::runtime::Handle,
        opts: UploaderOpts,
        upload_queue: async_channel::Receiver<PathBuf>,
        auth_provider: impl AuthProvider + 'static,
        data_source_rid: ResourceIdentifier,
    ) -> Self {
        let uploader = FileObjectStoreUploader::new(
            clients.upload,
            clients.ingest,
            http_client,
            handle.clone(),
            opts,
        );

        let upload_queue_clone = upload_queue.clone();

        // Detached task: runs until the sending side of the queue is closed.
        handle.spawn(async move {
            Self::run(upload_queue_clone, uploader, auth_provider, data_source_rid).await;
        });

        AvroIngestManager { upload_queue }
    }

    /// Consumes paths from `upload_queue` until the channel closes, uploading
    /// and ingesting each file in turn. Failures are logged and the loop moves
    /// on; the local file is left on disk in that case.
    pub async fn run(
        upload_queue: async_channel::Receiver<PathBuf>,
        uploader: FileObjectStoreUploader,
        auth_provider: impl AuthProvider + 'static,
        data_source_rid: ResourceIdentifier,
    ) {
        while let Ok(file_path) = upload_queue.recv().await {
            // NOTE(review): the full path string is used as the "file name";
            // `FileObjectStoreUploader::upload` also treats that string as the
            // on-disk path to stat/open, so the fallback literal would fail
            // there — confirm paths are always valid UTF-8.
            let file_name = file_path.to_str().unwrap_or("nmstream_file");
            let file = std::fs::File::open(&file_path);
            // A fresh token is fetched per file; skip the file when absent.
            let Some(token) = auth_provider.token() else {
                error!("Missing token for upload");
                continue;
            };
            match file {
                Ok(f) => {
                    match upload_and_ingest_file(
                        uploader.clone(),
                        &token,
                        auth_provider.workspace_rid(),
                        f,
                        file_name,
                        &file_path,
                        data_source_rid.clone(),
                    )
                    .await
                    {
                        Ok(()) => {}
                        Err(e) => {
                            error!(
                                "Error uploading and ingesting file {}: {}",
                                file_path.display(),
                                e
                            );
                        }
                    }
                }
                Err(e) => {
                    error!("Failed to open file {}: {:?}", file_path.display(), e);
                }
            }
        }
    }
}
118
119async fn upload_and_ingest_file(
120 uploader: FileObjectStoreUploader,
121 token: &BearerToken,
122 workspace_rid: Option<WorkspaceRid>,
123 file: std::fs::File,
124 file_name: &str,
125 file_path: &PathBuf,
126 data_source_rid: ResourceIdentifier,
127) -> Result<(), String> {
128 match uploader.upload(token, file, file_name, workspace_rid).await {
129 Ok(response) => {
130 match uploader
131 .ingest_avro(token, &response, data_source_rid)
132 .await
133 {
134 Ok(ingest_response) => {
135 info!(
136 "Successfully uploaded and ingested file {}: {:?}",
137 file_name, ingest_response
138 );
139 if let Err(e) = std::fs::remove_file(file_path) {
140 Err(format!(
141 "Failed to remove file {}: {:?}",
142 file_path.display(),
143 e
144 ))
145 } else {
146 info!("Removed file {}", file_path.display());
147 Ok(())
148 }
149 }
150 Err(e) => Err(format!("Failed to ingest file {file_name}: {e:?}")),
151 }
152 }
153 Err(e) => Err(format!("Failed to upload file {file_name}: {e:?}")),
154 }
155}
156
157#[derive(Debug, thiserror::Error)]
158pub enum UploaderError {
159 #[error("Conjure error: {0}")]
160 Conjure(String),
161 #[error("Failed to initiate multipart upload: {0}")]
162 IOError(#[from] std::io::Error),
163 #[error("Failed to upload part: {0}")]
164 HTTPError(#[from] reqwest::Error),
165 #[error("Error executing upload tasks: {0}")]
166 TokioError(#[from] tokio::task::JoinError),
167 #[error("Error: {0}")]
168 Other(String),
169}
170
/// Tuning knobs for the multipart upload path.
#[derive(Debug, Clone)]
pub struct UploaderOpts {
    /// Size in bytes of each chunk read from the source and uploaded as a part.
    pub chunk_size: usize,
    /// Total attempts per part (including the first try) before giving up.
    pub max_retries: usize,
    /// Maximum number of part uploads in flight at once.
    pub max_concurrent_uploads: usize,
}
177
178impl Default for UploaderOpts {
179 fn default() -> Self {
180 UploaderOpts {
181 chunk_size: 512 * 1024 * 1024, max_retries: 3,
183 max_concurrent_uploads: 1,
184 }
185 }
186}
187
/// Adapter that lets a `std::fs::File` be sent as a conjure request body.
pub struct FileWriteBody {
    /// Source file; `try_clone`d handles share one cursor, which is why the
    /// `reset` implementation seeks back to the start.
    file: std::fs::File,
}
191
192impl FileWriteBody {
193 pub fn new(file: std::fs::File) -> Self {
194 FileWriteBody { file }
195 }
196}
197
impl AsyncWriteBody<BodyWriter> for FileWriteBody {
    /// Streams the whole file into the request body.
    ///
    /// The handle is cloned (clones share the cursor) and the file is read
    /// fully into memory before writing. With the callers in this file that
    /// is bounded by SMALL_FILE_SIZE_LIMIT, since larger files take the
    /// multipart path instead of this body type.
    async fn write_body(self: Pin<&mut Self>, w: Pin<&mut BodyWriter>) -> Result<(), Error> {
        let mut file = self
            .file
            .try_clone()
            .map_err(|e| Error::internal_safe(format!("Failed to clone file for upload: {e}")))?;

        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)
            .map_err(|e| Error::internal_safe(format!("Failed to read bytes from file: {e}")))?;

        w.write_bytes(buffer.into())
            .await
            .map_err(|e| Error::internal_safe(format!("Failed to write bytes to body: {e}")))?;

        Ok(())
    }

    /// Rewinds the shared file cursor so the body can be re-sent on retry.
    /// Returns `false` if the handle cannot be cloned or the seek fails.
    async fn reset(self: Pin<&mut Self>) -> bool {
        let Ok(mut file) = self.file.try_clone() else {
            return false;
        };

        use std::io::SeekFrom;

        file.seek(SeekFrom::Start(0)).is_ok()
    }
}
226
/// Client bundle for uploading files to object storage — single-shot or
/// multipart depending on size — and requesting ingestion of the result.
#[derive(Clone)]
pub struct FileObjectStoreUploader {
    /// Conjure client for initiating/signing/completing uploads.
    upload_client: AsyncUploadServiceClient<Client>,
    /// Conjure client for submitting ingest requests.
    ingest_client: AsyncIngestServiceClient<Client>,
    /// Plain HTTP client used to PUT part bytes to signed URLs.
    http_client: reqwest::Client,
    /// Runtime handle used to spawn concurrent part-upload tasks.
    handle: tokio::runtime::Handle,
    /// Chunking/retry/concurrency settings.
    opts: UploaderOpts,
}
235
236impl FileObjectStoreUploader {
237 pub fn new(
238 upload_client: AsyncUploadServiceClient<Client>,
239 ingest_client: AsyncIngestServiceClient<Client>,
240 http_client: reqwest::Client,
241 handle: tokio::runtime::Handle,
242 opts: UploaderOpts,
243 ) -> Self {
244 FileObjectStoreUploader {
245 upload_client,
246 ingest_client,
247 http_client,
248 handle,
249 opts,
250 }
251 }
252
253 pub async fn initiate_upload(
254 &self,
255 token: &BearerToken,
256 file_name: &str,
257 workspace_rid: Option<WorkspaceRid>,
258 ) -> Result<InitiateMultipartUploadResponse, UploaderError> {
259 let request = InitiateMultipartUploadRequest::builder()
260 .filename(file_name)
261 .filetype("application/octet-stream")
262 .workspace(workspace_rid)
263 .build();
264 let response = self
265 .upload_client
266 .initiate_multipart_upload(token, &request)
267 .await
268 .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;
269
270 info!("Initiated multipart upload for file: {}", file_name);
271 Ok(response)
272 }
273
274 #[expect(clippy::too_many_arguments)]
275 async fn upload_part(
276 client: AsyncUploadServiceClient<Client>,
277 http_client: reqwest::Client,
278 token: BearerToken,
279 upload_id: String,
280 key: String,
281 part_number: i32,
282 chunk: Vec<u8>,
283 max_retries: usize,
284 ) -> Result<Part, UploaderError> {
285 let mut attempts = 0;
286
287 loop {
288 attempts += 1;
289 match Self::try_upload_part(
290 client.clone(),
291 http_client.clone(),
292 &token,
293 &upload_id,
294 &key,
295 part_number,
296 chunk.clone(),
297 )
298 .await
299 {
300 Ok(part) => return Ok(part),
301 Err(e) if attempts < max_retries => {
302 error!("Upload attempt {} failed, retrying: {}", attempts, e);
303 continue;
304 }
305 Err(e) => {
306 return Err(e);
307 }
308 }
309 }
310 }
311
312 async fn try_upload_part(
313 client: AsyncUploadServiceClient<Client>,
314 http_client: reqwest::Client,
315 token: &BearerToken,
316 upload_id: &str,
317 key: &str,
318 part_number: i32,
319 chunk: Vec<u8>,
320 ) -> Result<Part, UploaderError> {
321 let response = client
322 .sign_part(token, upload_id, key, part_number)
323 .await
324 .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;
325
326 let mut request_builder = http_client.put(response.url()).body(chunk);
327
328 for (header_name, header_value) in response.headers() {
329 request_builder = request_builder.header(header_name, header_value);
330 }
331
332 let http_response = request_builder.send().await?;
333 let headers = http_response.headers().clone();
334 let status = http_response.status();
335
336 if !status.is_success() {
337 error!("Failed to upload body");
338 return Err(UploaderError::Other(format!(
339 "Failed to upload part {part_number}: HTTP status {status}"
340 )));
341 }
342
343 let etag = headers
344 .get("etag")
345 .and_then(|v| v.to_str().ok())
346 .unwrap_or("ignored-etag");
347
348 Ok(Part::new(part_number, etag))
349 }
350
351 pub async fn upload_parts<R>(
352 &self,
353 token: &BearerToken,
354 reader: R,
355 key: &str,
356 upload_id: &str,
357 ) -> Result<CompleteMultipartUploadResponse, UploaderError>
358 where
359 R: Read + Send + 'static,
360 {
361 let chunks = ChunkedStreamReader::new(reader, self.opts.chunk_size);
362
363 let parallel_part_uploads = Arc::new(Semaphore::new(self.opts.max_concurrent_uploads));
364 let mut upload_futures = Vec::new();
365
366 futures::pin_mut!(chunks);
367
368 while let Some(entry) = chunks.next().await {
369 let (index, chunk) = entry?;
370 let part_number = (index + 1) as i32;
371
372 let token = token.clone();
373 let key = key.to_string();
374 let upload_id = upload_id.to_string();
375 let parallel_part_uploads = Arc::clone(¶llel_part_uploads);
376 let client = self.upload_client.clone();
377 let http_client = self.http_client.clone();
378 let max_retries = self.opts.max_retries;
379
380 upload_futures.push(self.handle.spawn(async move {
381 let _permit = parallel_part_uploads.acquire().await;
382 Self::upload_part(
383 client,
384 http_client,
385 token,
386 upload_id,
387 key,
388 part_number,
389 chunk,
390 max_retries,
391 )
392 .await
393 }));
394 }
395
396 let mut part_responses = futures::future::join_all(upload_futures)
397 .await
398 .into_iter()
399 .map(|result| result.map_err(UploaderError::TokioError)?)
400 .collect::<Result<Vec<_>, _>>()?;
401
402 part_responses.sort_by_key(|part| part.part_number());
403
404 let response = self
405 .upload_client
406 .complete_multipart_upload(token, upload_id, key, &part_responses)
407 .await
408 .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;
409
410 Ok(response)
411 }
412
    /// Uploads the whole file in a single `upload_file` request (used for
    /// files under SMALL_FILE_SIZE_LIMIT) and returns the object-store path.
    pub async fn upload_small_file(
        &self,
        token: &BearerToken,
        file_name: &str,
        size_bytes: i64,
        workspace_rid: Option<WorkspaceRid>,
        file: std::fs::File,
    ) -> Result<String, UploaderError> {
        let s3_path = self
            .upload_client
            .upload_file(
                token,
                file_name,
                // NOTE(review): an out-of-range size silently becomes `None`
                // instead of an error; the caller in `upload` only passes
                // sizes below 512 MiB, which fit in a SafeLong.
                SafeLong::new(size_bytes).ok(),
                workspace_rid.as_ref(),
                FileWriteBody::new(file),
            )
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;

        Ok(s3_path.as_str().to_string())
    }
435
    /// Uploads a file, choosing the single-shot or multipart path by on-disk
    /// size, and returns the resulting object-store path.
    ///
    /// NOTE(review): `file_name` is treated as an on-disk path — it is
    /// stat'ed here and, on the small-file path, re-opened (in which case
    /// `reader` is dropped unused). `reader` is only consumed by the
    /// multipart path. Confirm callers always pass a valid path string.
    pub async fn upload<R>(
        &self,
        token: &BearerToken,
        reader: R,
        file_name: impl Into<&str>,
        workspace_rid: Option<WorkspaceRid>,
    ) -> Result<String, UploaderError>
    where
        R: Read + Send + 'static,
    {
        let file_name = file_name.into();
        let path = Path::new(file_name);
        // Size decides the path: below the limit, one request; otherwise
        // multipart. I/O errors convert into UploaderError::IOError via `?`.
        let file_size = std::fs::metadata(path)?.len();
        if file_size < SMALL_FILE_SIZE_LIMIT {
            return self
                .upload_small_file(
                    token,
                    file_name,
                    file_size as i64,
                    workspace_rid,
                    std::fs::File::open(path)?,
                )
                .await;
        }

        let initiate_response = self
            .initiate_upload(token, file_name, workspace_rid)
            .await?;
        let upload_id = initiate_response.upload_id();
        let key = initiate_response.key();

        let response = self.upload_parts(token, reader, key, upload_id).await?;

        // The completion response must carry the final object location.
        let s3_path = response.location().ok_or_else(|| {
            UploaderError::Other("Upload response did not contain a location".to_string())
        })?;

        Ok(s3_path.to_string())
    }
475
476 pub async fn ingest_avro(
477 &self,
478 token: &BearerToken,
479 s3_path: &str,
480 data_source_rid: ResourceIdentifier,
481 ) -> Result<IngestResponse, UploaderError> {
482 let opts = IngestOptions::AvroStream(
483 AvroStreamOpts::builder()
484 .source(IngestSource::S3(S3IngestSource::new(s3_path)))
485 .target(DatasetIngestTarget::Existing(
486 ExistingDatasetIngestDestination::new(data_source_rid),
487 ))
488 .build(),
489 );
490
491 let request = IngestRequest::new(opts);
492
493 self.ingest_client
494 .ingest(token, &request)
495 .await
496 .map_err(|e| UploaderError::Conjure(format!("{e:?}")))
497 }
498}
499
/// Wraps a blocking `Read` source and hands out `(index, chunk)` pairs of at
/// most `chunk_size` bytes through its `Stream` implementation.
pub struct ChunkedStreamReader {
    /// Boxed source; boxing erases the concrete reader type.
    reader: Box<dyn Read + Send>,
    /// Maximum number of bytes per emitted chunk.
    chunk_size: usize,
    /// 0-based index of the next chunk to emit.
    current_index: usize,
}

impl ChunkedStreamReader {
    /// Boxes `reader` and prepares to emit chunks of up to `chunk_size` bytes,
    /// starting from chunk index 0.
    pub fn new<R>(reader: R, chunk_size: usize) -> Self
    where
        R: Read + Send + 'static,
    {
        let reader: Box<dyn Read + Send> = Box::new(reader);
        ChunkedStreamReader {
            reader,
            chunk_size,
            current_index: 0,
        }
    }
}
518
519impl Stream for ChunkedStreamReader {
520 type Item = Result<(usize, Vec<u8>), std::io::Error>;
521
522 fn poll_next(
523 mut self: Pin<&mut Self>,
524 _cx: &mut std::task::Context<'_>,
525 ) -> std::task::Poll<Option<Self::Item>> {
526 let mut buffer = vec![0u8; self.chunk_size];
527
528 match self.reader.read(&mut buffer) {
529 Ok(0) => std::task::Poll::Ready(None),
530 Ok(n) => {
531 buffer.truncate(n);
532 let index = self.current_index;
533 self.current_index += 1;
534 std::task::Poll::Ready(Some(Ok((index, buffer))))
535 }
536 Err(e) => std::task::Poll::Ready(Some(Err(e))),
537 }
538 }
539}