use std::io::Read;
use std::io::Seek;
use std::path::Path;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;

use conjure_error::Error;
use conjure_http::client::AsyncWriteBody;
use conjure_http::private::Stream;
use conjure_object::BearerToken;
use conjure_object::ResourceIdentifier;
use conjure_object::SafeLong;
use conjure_runtime_rustls_platform_verifier::conjure_runtime::BodyWriter;
use conjure_runtime_rustls_platform_verifier::PlatformVerifierClient;
use futures::StreamExt;
use nominal_api::api::rids::WorkspaceRid;
use nominal_api::ingest::api::AvroStreamOpts;
use nominal_api::ingest::api::CompleteMultipartUploadResponse;
use nominal_api::ingest::api::DatasetIngestTarget;
use nominal_api::ingest::api::ExistingDatasetIngestDestination;
use nominal_api::ingest::api::IngestOptions;
use nominal_api::ingest::api::IngestRequest;
use nominal_api::ingest::api::IngestResponse;
use nominal_api::ingest::api::IngestServiceAsyncClient;
use nominal_api::ingest::api::IngestSource;
use nominal_api::ingest::api::InitiateMultipartUploadRequest;
use nominal_api::ingest::api::InitiateMultipartUploadResponse;
use nominal_api::ingest::api::Part;
use nominal_api::ingest::api::S3IngestSource;
use nominal_api::upload::api::UploadServiceAsyncClient;
use tokio::sync::Semaphore;
use tracing::error;
use tracing::info;

use crate::client::NominalApiClients;
use crate::types::AuthProvider;

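/// Files below this size (512 MiB) are uploaded in a single request; files at or
/// above it go through the multipart upload path.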
const SMALL_FILE_SIZE_LIMIT: u64 = 512 * 1024 * 1024;

/// Receives completed Avro files from a queue and uploads and ingests each one
/// in a background task.
#[derive(Clone)]
pub struct AvroIngestManager {
    pub upload_queue: async_channel::Receiver<PathBuf>,
}

impl AvroIngestManager {
    pub fn new(
        clients: NominalApiClients,
        http_client: reqwest::Client,
        handle: tokio::runtime::Handle,
        opts: UploaderOpts,
        upload_queue: async_channel::Receiver<PathBuf>,
        auth_provider: impl AuthProvider + 'static,
        data_source_rid: ResourceIdentifier,
    ) -> Self {
        let uploader = FileObjectStoreUploader::new(
            clients.upload,
            clients.ingest,
            http_client,
            handle.clone(),
            opts,
        );

        let upload_queue_clone = upload_queue.clone();

        handle.spawn(async move {
            Self::run(upload_queue_clone, uploader, auth_provider, data_source_rid).await;
        });

        AvroIngestManager { upload_queue }
    }

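    /// Drains the upload queue, uploading and ingesting each file in turn.
    /// Failures are logged and the loop moves on to the next file.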
    pub async fn run(
        upload_queue: async_channel::Receiver<PathBuf>,
        uploader: FileObjectStoreUploader,
        auth_provider: impl AuthProvider + 'static,
        data_source_rid: ResourceIdentifier,
    ) {
        while let Ok(file_path) = upload_queue.recv().await {
            let file_name = file_path.to_str().unwrap_or("nmstream_file");
            let file = std::fs::File::open(&file_path);
            let Some(token) = auth_provider.token() else {
                error!("Missing token for upload");
                continue;
            };
            match file {
                Ok(f) => {
                    match upload_and_ingest_file(
                        uploader.clone(),
                        &token,
                        auth_provider.workspace_rid(),
                        f,
                        file_name,
                        &file_path,
                        data_source_rid.clone(),
                    )
                    .await
                    {
                        Ok(()) => {}
                        Err(e) => {
                            error!(
                                "Error uploading and ingesting file {}: {}",
                                file_path.display(),
                                e
                            );
                        }
                    }
                }
                Err(e) => {
                    error!("Failed to open file {}: {:?}", file_path.display(), e);
                }
            }
        }
    }
}

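/// Uploads `file`, triggers Avro ingest for the uploaded object, and removes the
/// local file once ingest has been accepted.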
async fn upload_and_ingest_file(
    uploader: FileObjectStoreUploader,
    token: &BearerToken,
    workspace_rid: Option<WorkspaceRid>,
    file: std::fs::File,
    file_name: &str,
    file_path: &Path,
    data_source_rid: ResourceIdentifier,
) -> Result<(), String> {
    match uploader.upload(token, file, file_name, workspace_rid).await {
        Ok(response) => {
            match uploader
                .ingest_avro(token, &response, data_source_rid)
                .await
            {
                Ok(ingest_response) => {
                    info!(
                        "Successfully uploaded and ingested file {}: {:?}",
                        file_name, ingest_response
                    );
                    if let Err(e) = std::fs::remove_file(file_path) {
                        Err(format!(
                            "Failed to remove file {}: {:?}",
                            file_path.display(),
                            e
                        ))
                    } else {
                        info!("Removed file {}", file_path.display());
                        Ok(())
                    }
                }
                Err(e) => Err(format!("Failed to ingest file {file_name}: {e:?}")),
            }
        }
        Err(e) => Err(format!("Failed to upload file {file_name}: {e:?}")),
    }
}

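/// Errors produced while uploading or ingesting a file.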
#[derive(Debug, thiserror::Error)]
pub enum UploaderError {
    #[error("Conjure error: {0}")]
    Conjure(String),
    #[error("I/O error: {0}")]
    IOError(#[from] std::io::Error),
    #[error("Failed to upload part: {0}")]
    HTTPError(#[from] reqwest::Error),
    #[error("Error executing upload tasks: {0}")]
    TokioError(#[from] tokio::task::JoinError),
    #[error("Error: {0}")]
    Other(String),
}

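/// Tuning knobs for multipart uploads.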
#[derive(Debug, Clone)]
pub struct UploaderOpts {
    /// Size of each multipart chunk, in bytes.
    pub chunk_size: usize,
    /// Maximum attempts for each part before the upload fails.
    pub max_retries: usize,
    /// Maximum number of parts uploaded in parallel.
    pub max_concurrent_uploads: usize,
}

impl Default for UploaderOpts {
    fn default() -> Self {
        UploaderOpts {
            chunk_size: 512 * 1024 * 1024,
            max_retries: 3,
            max_concurrent_uploads: 1,
        }
    }
}

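/// Adapter that exposes a `std::fs::File` as a Conjure `AsyncWriteBody`.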
pub struct FileWriteBody {
    file: std::fs::File,
}

impl FileWriteBody {
    pub fn new(file: std::fs::File) -> Self {
        FileWriteBody { file }
    }
}

impl AsyncWriteBody<BodyWriter> for FileWriteBody {
    async fn write_body(self: Pin<&mut Self>, w: Pin<&mut BodyWriter>) -> Result<(), Error> {
        // Clone the handle rather than consuming the file; the clone shares the
        // underlying cursor, which reset() rewinds between attempts.
        let mut file = self
            .file
            .try_clone()
            .map_err(|e| Error::internal_safe(format!("Failed to clone file for upload: {e}")))?;

        // Buffer the whole file in memory; this body is only used for files
        // below SMALL_FILE_SIZE_LIMIT.
        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)
            .map_err(|e| Error::internal_safe(format!("Failed to read bytes from file: {e}")))?;

        w.write_bytes(buffer.into())
            .await
            .map_err(|e| Error::internal_safe(format!("Failed to write bytes to body: {e}")))?;

        Ok(())
    }

    async fn reset(self: Pin<&mut Self>) -> bool {
        let Ok(mut file) = self.file.try_clone() else {
            return false;
        };

        use std::io::SeekFrom;

        // Seeking the clone rewinds the shared cursor, so the next write_body
        // call rereads the file from the start.
        file.seek(SeekFrom::Start(0)).is_ok()
    }
}

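/// Uploads files to object storage through the upload service and kicks off
/// ingest jobs for the uploaded objects.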
#[derive(Clone)]
pub struct FileObjectStoreUploader {
    upload_client: UploadServiceAsyncClient<PlatformVerifierClient>,
    ingest_client: IngestServiceAsyncClient<PlatformVerifierClient>,
    http_client: reqwest::Client,
    handle: tokio::runtime::Handle,
    opts: UploaderOpts,
}

impl FileObjectStoreUploader {
    pub fn new(
        upload_client: UploadServiceAsyncClient<PlatformVerifierClient>,
        ingest_client: IngestServiceAsyncClient<PlatformVerifierClient>,
        http_client: reqwest::Client,
        handle: tokio::runtime::Handle,
        opts: UploaderOpts,
    ) -> Self {
        FileObjectStoreUploader {
            upload_client,
            ingest_client,
            http_client,
            handle,
            opts,
        }
    }

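    /// Asks the upload service to start a multipart upload for `file_name`.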
    pub async fn initiate_upload(
        &self,
        token: &BearerToken,
        file_name: &str,
        workspace_rid: Option<WorkspaceRid>,
    ) -> Result<InitiateMultipartUploadResponse, UploaderError> {
        let request = InitiateMultipartUploadRequest::builder()
            .filename(file_name)
            .filetype("application/octet-stream")
            .workspace(workspace_rid)
            .build();
        let response = self
            .upload_client
            .initiate_multipart_upload(token, &request)
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;

        info!("Initiated multipart upload for file: {}", file_name);
        Ok(response)
    }

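    /// Uploads a single part, retrying failed attempts up to `max_retries` times.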
    #[expect(clippy::too_many_arguments)]
    async fn upload_part(
        client: UploadServiceAsyncClient<PlatformVerifierClient>,
        http_client: reqwest::Client,
        token: BearerToken,
        upload_id: String,
        key: String,
        part_number: i32,
        chunk: Vec<u8>,
        max_retries: usize,
    ) -> Result<Part, UploaderError> {
        let mut attempts = 0;

        loop {
            attempts += 1;
            match Self::try_upload_part(
                client.clone(),
                http_client.clone(),
                &token,
                &upload_id,
                &key,
                part_number,
                chunk.clone(),
            )
            .await
            {
                Ok(part) => return Ok(part),
                Err(e) if attempts < max_retries => {
                    error!("Upload attempt {} failed, retrying: {}", attempts, e);
                    continue;
                }
                Err(e) => {
                    return Err(e);
                }
            }
        }
    }

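    /// Signs a part upload URL, PUTs the chunk to it, and returns the part's ETag.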
    async fn try_upload_part(
        client: UploadServiceAsyncClient<PlatformVerifierClient>,
        http_client: reqwest::Client,
        token: &BearerToken,
        upload_id: &str,
        key: &str,
        part_number: i32,
        chunk: Vec<u8>,
    ) -> Result<Part, UploaderError> {
        let response = client
            .sign_part(token, upload_id, key, part_number)
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;

        let mut request_builder = http_client.put(response.url()).body(chunk);

        for (header_name, header_value) in response.headers() {
            request_builder = request_builder.header(header_name, header_value);
        }

        let http_response = request_builder.send().await?;
        let headers = http_response.headers().clone();
        let status = http_response.status();

        if !status.is_success() {
            error!("Failed to upload part {}: HTTP status {}", part_number, status);
            return Err(UploaderError::Other(format!(
                "Failed to upload part {part_number}: HTTP status {status}"
            )));
        }

        let etag = headers
            .get("etag")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("ignored-etag");

        Ok(Part::new(part_number, etag))
    }

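    /// Reads `reader` in `chunk_size` pieces, uploads the parts with bounded
    /// concurrency, and then completes the multipart upload.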
    pub async fn upload_parts<R>(
        &self,
        token: &BearerToken,
        reader: R,
        key: &str,
        upload_id: &str,
    ) -> Result<CompleteMultipartUploadResponse, UploaderError>
    where
        R: Read + Send + 'static,
    {
        let chunks = ChunkedStreamReader::new(reader, self.opts.chunk_size);

        let parallel_part_uploads = Arc::new(Semaphore::new(self.opts.max_concurrent_uploads));
        let mut upload_futures = Vec::new();

        futures::pin_mut!(chunks);

        while let Some(entry) = chunks.next().await {
            let (index, chunk) = entry?;
            let part_number = (index + 1) as i32;

            let token = token.clone();
            let key = key.to_string();
            let upload_id = upload_id.to_string();
            let parallel_part_uploads = Arc::clone(&parallel_part_uploads);
            let client = self.upload_client.clone();
            let http_client = self.http_client.clone();
            let max_retries = self.opts.max_retries;

            upload_futures.push(self.handle.spawn(async move {
                // Hold a semaphore permit for the duration of the part upload to
                // cap concurrency at max_concurrent_uploads.
                let _permit = parallel_part_uploads.acquire().await;
                Self::upload_part(
                    client,
                    http_client,
                    token,
                    upload_id,
                    key,
                    part_number,
                    chunk,
                    max_retries,
                )
                .await
            }));
        }

        let mut part_responses = futures::future::join_all(upload_futures)
            .await
            .into_iter()
            .map(|result| result.map_err(UploaderError::TokioError)?)
            .collect::<Result<Vec<_>, _>>()?;

        // Parts must be listed in ascending part number order to complete the upload.
        part_responses.sort_by_key(|part| part.part_number());

        let response = self
            .upload_client
            .complete_multipart_upload(token, upload_id, key, &part_responses)
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;

        Ok(response)
    }

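    /// Uploads the file as a single request body; used for files under
    /// SMALL_FILE_SIZE_LIMIT. Returns the object store path of the upload.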
    pub async fn upload_small_file(
        &self,
        token: &BearerToken,
        file_name: &str,
        size_bytes: i64,
        workspace_rid: Option<WorkspaceRid>,
        file: std::fs::File,
    ) -> Result<String, UploaderError> {
        let s3_path = self
            .upload_client
            .upload_file(
                token,
                file_name,
                SafeLong::new(size_bytes).ok(),
                workspace_rid.as_ref(),
                FileWriteBody::new(file),
            )
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))?;

        Ok(s3_path.as_str().to_string())
    }

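    /// Uploads a file, choosing between a single-request upload and a multipart
    /// upload based on its size. `file_name` must be the path of the file on
    /// disk, since its metadata is read to determine the size.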
    pub async fn upload<R>(
        &self,
        token: &BearerToken,
        reader: R,
        file_name: &str,
        workspace_rid: Option<WorkspaceRid>,
    ) -> Result<String, UploaderError>
    where
        R: Read + Send + 'static,
    {
        let path = Path::new(file_name);
        let file_size = std::fs::metadata(path)?.len();
        if file_size < SMALL_FILE_SIZE_LIMIT {
            return self
                .upload_small_file(
                    token,
                    file_name,
                    file_size as i64,
                    workspace_rid,
                    std::fs::File::open(path)?,
                )
                .await;
        }

        let initiate_response = self
            .initiate_upload(token, file_name, workspace_rid)
            .await?;
        let upload_id = initiate_response.upload_id();
        let key = initiate_response.key();

        let response = self.upload_parts(token, reader, key, upload_id).await?;

        let s3_path = response.location().ok_or_else(|| {
            UploaderError::Other("Upload response did not contain a location".to_string())
        })?;

        Ok(s3_path.to_string())
    }

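    /// Submits an ingest request that streams the uploaded Avro file from object
    /// storage into the existing dataset identified by `data_source_rid`.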
    pub async fn ingest_avro(
        &self,
        token: &BearerToken,
        s3_path: &str,
        data_source_rid: ResourceIdentifier,
    ) -> Result<IngestResponse, UploaderError> {
        let opts = IngestOptions::AvroStream(
            AvroStreamOpts::builder()
                .source(IngestSource::S3(S3IngestSource::new(s3_path)))
                .target(DatasetIngestTarget::Existing(
                    ExistingDatasetIngestDestination::new(data_source_rid),
                ))
                .build(),
        );

        let request = IngestRequest::new(opts);

        self.ingest_client
            .ingest(token, &request)
            .await
            .map_err(|e| UploaderError::Conjure(format!("{e:?}")))
    }
}

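/// Adapts a blocking `Read` into a `Stream` of `(chunk_index, bytes)` pairs of
/// at most `chunk_size` bytes each.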
pub struct ChunkedStreamReader {
    reader: Box<dyn Read + Send>,
    chunk_size: usize,
    current_index: usize,
}

impl ChunkedStreamReader {
    pub fn new<R>(reader: R, chunk_size: usize) -> Self
    where
        R: Read + Send + 'static,
    {
        Self {
            reader: Box::new(reader),
            chunk_size,
            current_index: 0,
        }
    }
}

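// Note: reads happen synchronously inside poll_next, so polling this stream
// blocks the executor thread until a full chunk has been read.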
impl Stream for ChunkedStreamReader {
    type Item = Result<(usize, Vec<u8>), std::io::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut buffer = vec![0u8; self.chunk_size];
        let mut filled = 0;

        // read() may return fewer bytes than requested, so keep filling the
        // buffer until it is full or the reader hits EOF; otherwise mid-stream
        // chunks could come out undersized and yield invalid multipart parts.
        while filled < self.chunk_size {
            match self.reader.read(&mut buffer[filled..]) {
                Ok(0) => break,
                Ok(n) => filled += n,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return std::task::Poll::Ready(Some(Err(e))),
            }
        }

        if filled == 0 {
            return std::task::Poll::Ready(None);
        }

        buffer.truncate(filled);
        let index = self.current_index;
        self.current_index += 1;
        std::task::Poll::Ready(Some(Ok((index, buffer))))
    }
}