1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
/// One megabyte in bytes.
const MEGABYTE: u64 = 1024 * 1024;

/// Smallest file size accepted for a chunked upload; smaller files are
/// rejected with `AdaptiveError::FileTooSmall`.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;

/// Largest chunk size ever used for a single part.
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;

/// Maximum number of parts a single chunked upload may consist of.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size tier boundaries used by `calculate_upload_parts` to pick a
// baseline chunk size.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// All errors surfaced by this client, covering transport, (de)serialization,
/// validation, and server-reported GraphQL/REST failures.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    /// Transport-level failure from `reqwest`.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// JSON (de)serialization failure.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Local filesystem/stream I/O failure.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A URL could not be parsed or joined.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// File is below `MIN_CHUNK_SIZE_BYTES`, so chunked upload is refused.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// File exceeds `MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT`.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The GraphQL response carried server-side errors instead of data.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response carried neither data nor errors.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    /// The server returned no job for the requested id.
    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    /// The REST upload-init endpoint returned a non-success status.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    /// The REST upload-part endpoint returned a non-success status.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// The final create-dataset-from-multipart step failed.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),
}
70
/// Crate-local alias so every fallible API returns `AdaptiveError`.
type Result<T> = std::result::Result<T, AdaptiveError>;

/// Progress snapshot emitted while a chunked upload is in flight.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    // Bytes handed to the HTTP transport so far (may undercount under
    // backpressure — see `upload_part`'s use of `try_send`).
    pub bytes_uploaded: u64,
    // Total size of the file being uploaded.
    pub total_bytes: u64,
}
78
/// Events yielded by the stream returned from `chunked_upload_dataset`.
#[derive(Debug)]
pub enum UploadEvent {
    /// Progress update emitted while parts are being sent.
    Progress(ChunkedUploadProgress),
    /// Terminal event: the dataset was created from the uploaded parts.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
86
87pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
88 if file_size < MIN_CHUNK_SIZE_BYTES {
89 return Err(AdaptiveError::FileTooSmall {
90 size: file_size,
91 min_size: MIN_CHUNK_SIZE_BYTES,
92 });
93 }
94
95 let mut chunk_size = if file_size < SIZE_500MB {
96 5 * MEGABYTE
97 } else if file_size < SIZE_10GB {
98 10 * MEGABYTE
99 } else if file_size < SIZE_50GB {
100 50 * MEGABYTE
101 } else {
102 100 * MEGABYTE
103 };
104
105 let mut total_parts = file_size.div_ceil(chunk_size);
106
107 if total_parts > MAX_PARTS_COUNT {
108 chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
109
110 if chunk_size > MAX_CHUNK_SIZE_BYTES {
111 let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
112 return Err(AdaptiveError::FileTooLarge {
113 size: file_size,
114 max_size: max_file_size,
115 });
116 }
117
118 total_parts = file_size.div_ceil(chunk_size);
119 }
120
121 Ok((total_parts, chunk_size))
122}
123
// Rust-side definitions for custom GraphQL scalars. graphql_client's derive
// resolves scalar names from the schema to types in scope at the derive
// site, so these names must match the scalar names in `schema.gql`.
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
131
/// Timestamp scalar wrapping `SystemTime`, deserialized from a
/// milliseconds-since-epoch value (see `serde_utils::deserialize_timestamp_millis`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);

impl<'de> serde::Deserialize<'de> for Timestamp {
    // Delegates to the shared millisecond-timestamp helper, then wraps the
    // resulting `SystemTime` in the newtype.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
        Ok(Timestamp(system_time))
    }
}
144
/// Number of items requested per page when paginating job listings.
const PAGE_SIZE: usize = 20;

/// GraphQL `Upload` scalar: serialized as the index of the multipart form
/// part that carries the file (see the `"map"` field built in
/// `upload_dataset` / `publish_recipe`).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
149
/// Generated types for the operation in `src/graphql/list.graphql`
/// (listing custom recipes — see `list_recipes`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Generated types for the operation in `src/graphql/job.graphql`
/// (fetching a single job — see `get_job`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Generated types for the operation in `src/graphql/jobs.graphql`
/// (paginated job listing — see `list_jobs`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Generated types for the operation in `src/graphql/cancel.graphql`
/// (job cancellation — see `cancel_job`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Generated types for the operation in `src/graphql/models.graphql`
/// (model services of one use case — see `list_models`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Generated types for the operation in `src/graphql/all_models.graphql`
/// (see `list_all_models`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
197
198impl Display for get_job::JobStatus {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 match self {
201 get_job::JobStatus::PENDING => write!(f, "Pending"),
202 get_job::JobStatus::RUNNING => write!(f, "Running"),
203 get_job::JobStatus::COMPLETED => write!(f, "Completed"),
204 get_job::JobStatus::FAILED => write!(f, "Failed"),
205 get_job::JobStatus::CANCELED => write!(f, "Canceled"),
206 get_job::JobStatus::Other(_) => write!(f, "Unknown"),
207 }
208 }
209}
210
/// Generated types for the operation in `src/graphql/publish.graphql`
/// (recipe publishing — see `publish_recipe`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Generated types for the operation in `src/graphql/upload_dataset.graphql`
/// (single-request dataset upload — see `upload_dataset`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Generated types for the operation in
/// `src/graphql/create_dataset_from_multipart.graphql` (finalizing a chunked
/// upload — see `create_dataset_from_multipart`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Generated types for the operation in `src/graphql/run.graphql`
/// (starting a recipe job — see `run_recipe`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Generated types for the operation in `src/graphql/usecases.graphql`
/// (see `list_usecases`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/usecases.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListUseCases;

/// Generated types for the operation in `src/graphql/pools.graphql`
/// (see `list_pools`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Generated types for the operation in `src/graphql/recipe.graphql`
/// (see `get_recipe`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// Generated types for the operation in `src/graphql/grader.graphql`
/// (see `get_grader`). Response also derives `Serialize` so callers can
/// re-emit it as JSON.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// Generated types for the operation in `src/graphql/dataset.graphql`
/// (see `get_dataset`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// Generated types for the operation in `src/graphql/model_config.graphql`
/// (see `get_model_config`). Response also derives `Serialize`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;
290
// REST routes for the chunked-upload flow, joined onto `rest_base_url`.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
294
/// Client for the Adaptive API: GraphQL operations plus the REST
/// chunked-upload endpoints, all authenticated with a bearer token.
#[derive(Clone)]
pub struct AdaptiveClient {
    // Shared reqwest client; cloning it reuses the same connection pool.
    client: Client,
    // `<base>/graphql`, precomputed in `new`.
    graphql_url: Url,
    // Base URL the REST routes are joined onto.
    rest_base_url: Url,
    // Bearer token attached to every request.
    auth_token: String,
}
302
303impl AdaptiveClient {
304 pub fn new(api_base_url: Url, auth_token: String) -> Self {
305 let graphql_url = api_base_url
306 .join("graphql")
307 .expect("Failed to append graphql to base URL");
308
309 let client = Client::builder()
310 .user_agent(format!(
311 "adaptive-client-rust/{}",
312 env!("CARGO_PKG_VERSION")
313 ))
314 .build()
315 .expect("Failed to build HTTP client");
316
317 Self {
318 client,
319 graphql_url,
320 rest_base_url: api_base_url,
321 auth_token,
322 }
323 }
324
325 async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
326 where
327 T: GraphQLQuery,
328 T::Variables: serde::Serialize,
329 T::ResponseData: DeserializeOwned,
330 {
331 let request_body = T::build_query(variables);
332
333 let response = self
334 .client
335 .post(self.graphql_url.clone())
336 .bearer_auth(&self.auth_token)
337 .json(&request_body)
338 .send()
339 .await?;
340
341 let response_body: Response<T::ResponseData> = response.json().await?;
342
343 match response_body.data {
344 Some(data) => Ok(data),
345 None => {
346 if let Some(errors) = response_body.errors {
347 return Err(AdaptiveError::GraphQLErrors(errors));
348 }
349 Err(AdaptiveError::NoGraphQLData)
350 }
351 }
352 }
353
354 pub async fn list_recipes(
355 &self,
356 usecase: &str,
357 ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
358 let variables = get_custom_recipes::Variables {
359 usecase: IdOrKey::from(usecase),
360 };
361
362 let response_data = self.execute_query(GetCustomRecipes, variables).await?;
363 Ok(response_data.custom_recipes)
364 }
365
366 pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
367 let variables = get_job::Variables { id: job_id };
368
369 let response_data = self.execute_query(GetJob, variables).await?;
370
371 match response_data.job {
372 Some(job) => Ok(job),
373 None => Err(AdaptiveError::JobNotFound(job_id)),
374 }
375 }
376
377 pub async fn upload_dataset<P: AsRef<Path>>(
378 &self,
379 usecase: &str,
380 name: &str,
381 dataset: P,
382 ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
383 let variables = upload_dataset::Variables {
384 usecase: IdOrKey::from(usecase),
385 file: Upload(0),
386 name: Some(name.to_string()),
387 };
388
389 let operations = UploadDataset::build_query(variables);
390 let operations = serde_json::to_string(&operations)?;
391
392 let file_map = r#"{ "0": ["variables.file"] }"#;
393
394 let dataset_file = reqwest::multipart::Part::file(dataset).await?;
395
396 let form = reqwest::multipart::Form::new()
397 .text("operations", operations)
398 .text("map", file_map)
399 .part("0", dataset_file);
400
401 let response = self
402 .client
403 .post(self.graphql_url.clone())
404 .bearer_auth(&self.auth_token)
405 .multipart(form)
406 .send()
407 .await?;
408
409 let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
410 response.json().await?;
411
412 match response.data {
413 Some(data) => Ok(data.create_dataset),
414 None => {
415 if let Some(errors) = response.errors {
416 return Err(AdaptiveError::GraphQLErrors(errors));
417 }
418 Err(AdaptiveError::NoGraphQLData)
419 }
420 }
421 }
422
423 pub async fn publish_recipe<P: AsRef<Path>>(
424 &self,
425 usecase: &str,
426 name: &str,
427 key: &str,
428 recipe: P,
429 ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
430 let variables = publish_custom_recipe::Variables {
431 usecase: IdOrKey::from(usecase),
432 file: Upload(0),
433 name: Some(name.to_string()),
434 key: Some(key.to_string()),
435 };
436
437 let operations = PublishCustomRecipe::build_query(variables);
438 let operations = serde_json::to_string(&operations)?;
439
440 let file_map = r#"{ "0": ["variables.file"] }"#;
441
442 let recipe_file = reqwest::multipart::Part::file(recipe).await?;
443
444 let form = reqwest::multipart::Form::new()
445 .text("operations", operations)
446 .text("map", file_map)
447 .part("0", recipe_file);
448
449 let response = self
450 .client
451 .post(self.graphql_url.clone())
452 .bearer_auth(&self.auth_token)
453 .multipart(form)
454 .send()
455 .await?;
456 let response: Response<
457 <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
458 > = response.json().await?;
459
460 match response.data {
461 Some(data) => Ok(data.create_custom_recipe),
462 None => {
463 if let Some(errors) = response.errors {
464 return Err(AdaptiveError::GraphQLErrors(errors));
465 }
466 Err(AdaptiveError::NoGraphQLData)
467 }
468 }
469 }
470
471 pub async fn run_recipe(
472 &self,
473 usecase: &str,
474 recipe_id: &str,
475 parameters: Map<String, Value>,
476 name: Option<String>,
477 compute_pool: Option<String>,
478 num_gpus: u32,
479 ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
480 let variables = run_custom_recipe::Variables {
481 input: run_custom_recipe::JobInput {
482 recipe: recipe_id.to_string(),
483 use_case: usecase.to_string(),
484 args: parameters,
485 name,
486 compute_pool,
487 num_gpus: num_gpus as i64,
488 },
489 };
490
491 let response_data = self.execute_query(RunCustomRecipe, variables).await?;
492 Ok(response_data.create_job)
493 }
494
495 pub async fn list_jobs(
496 &self,
497 usecase: Option<String>,
498 ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
499 let mut jobs = Vec::new();
500 let mut page = self.list_jobs_page(usecase.clone(), None).await?;
501 jobs.extend(page.nodes.iter().cloned());
502 while page.page_info.has_next_page {
503 page = self
504 .list_jobs_page(usecase.clone(), page.page_info.end_cursor)
505 .await?;
506 jobs.extend(page.nodes.iter().cloned());
507 }
508 Ok(jobs)
509 }
510
    /// Fetches a single page of jobs.
    ///
    /// The filter is fixed: only jobs of kind `CUSTOM` with status `RUNNING`
    /// or `PENDING` are returned, optionally narrowed to `usecase`. `after`
    /// is the cursor returned by the previous page (`None` for the first
    /// page); the page size is `PAGE_SIZE`.
    async fn list_jobs_page(
        &self,
        usecase: Option<String>,
        after: Option<String>,
    ) -> Result<list_jobs::ListJobsJobs> {
        let variables = list_jobs::Variables {
            filter: Some(list_jobs::ListJobsFilterInput {
                use_case: usecase,
                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
                status: Some(vec![
                    list_jobs::JobStatus::RUNNING,
                    list_jobs::JobStatus::PENDING,
                ]),
                timerange: None,
                custom_recipes: None,
                artifacts: None,
            }),
            cursor: Some(list_jobs::CursorPageInput {
                first: Some(PAGE_SIZE as i64),
                after,
                before: None,
                last: None,
                offset: None,
            }),
        };

        let response_data = self.execute_query(ListJobs, variables).await?;
        Ok(response_data.jobs)
    }
540
541 pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
542 let variables = cancel_job::Variables { job_id };
543
544 let response_data = self.execute_query(CancelJob, variables).await?;
545 Ok(response_data.cancel_job)
546 }
547
548 pub async fn list_models(
549 &self,
550 usecase: String,
551 ) -> Result<Vec<list_models::ListModelsUseCaseModelServices>> {
552 let variables = list_models::Variables {
553 use_case_id: usecase,
554 };
555
556 let response_data = self.execute_query(ListModels, variables).await?;
557 Ok(response_data
558 .use_case
559 .map(|use_case| use_case.model_services)
560 .unwrap_or(Vec::new()))
561 }
562
563 pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
564 let variables = list_all_models::Variables {};
565
566 let response_data = self.execute_query(ListAllModels, variables).await?;
567 Ok(response_data.models)
568 }
569
570 pub async fn list_usecases(&self) -> Result<Vec<list_use_cases::ListUseCasesUseCases>> {
571 let variables = list_use_cases::Variables {};
572
573 let response_data = self.execute_query(ListUseCases, variables).await?;
574 Ok(response_data.use_cases)
575 }
576
577 pub async fn list_pools(
578 &self,
579 ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
580 let variables = list_compute_pools::Variables {};
581
582 let response_data = self.execute_query(ListComputePools, variables).await?;
583 Ok(response_data.compute_pools)
584 }
585
586 pub async fn get_recipe(
587 &self,
588 usecase: String,
589 id_or_key: String,
590 ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
591 let variables = get_recipe::Variables { usecase, id_or_key };
592
593 let response_data = self.execute_query(GetRecipe, variables).await?;
594 Ok(response_data.custom_recipe)
595 }
596
597 pub async fn get_grader(
598 &self,
599 id_or_key: &str,
600 use_case: &str,
601 ) -> Result<get_grader::GetGraderGrader> {
602 let variables = get_grader::Variables {
603 id: id_or_key.to_string(),
604 use_case: use_case.to_string(),
605 };
606
607 let response_data = self.execute_query(GetGrader, variables).await?;
608 Ok(response_data.grader)
609 }
610
611 pub async fn get_dataset(
612 &self,
613 id_or_key: &str,
614 use_case: &str,
615 ) -> Result<Option<get_dataset::GetDatasetDataset>> {
616 let variables = get_dataset::Variables {
617 id_or_key: id_or_key.to_string(),
618 use_case: use_case.to_string(),
619 };
620
621 let response_data = self.execute_query(GetDataset, variables).await?;
622 Ok(response_data.dataset)
623 }
624
625 pub async fn get_model_config(
626 &self,
627 id_or_key: &str,
628 ) -> Result<Option<get_model_config::GetModelConfigModel>> {
629 let variables = get_model_config::Variables {
630 id_or_key: id_or_key.to_string(),
631 };
632
633 let response_data = self.execute_query(GetModelConfig, variables).await?;
634 Ok(response_data.model)
635 }
636
    /// Base URL the REST routes are joined onto (the URL the client was
    /// constructed with).
    pub fn base_url(&self) -> &Url {
        &self.rest_base_url
    }
640
    /// Opens a chunked-upload session for `total_parts` parts and returns the
    /// server-assigned session id.
    ///
    /// NOTE(review): the content type is hard-coded to "application/jsonl",
    /// so this flow currently assumes JSONL datasets — confirm before reusing
    /// for other payloads.
    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;

        let request = InitChunkedUploadRequest {
            content_type: "application/jsonl".to_string(),
            metadata: None,
            total_parts_count: total_parts,
        };

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await?;

        // Map non-2xx replies to a typed error, keeping the status and (best
        // effort) the response body for diagnostics.
        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadInitFailed {
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        let init_response: InitChunkedUploadResponse = response.json().await?;
        Ok(init_response.session_id)
    }
668
    /// Uploads one part of a chunked-upload session, streaming the body in
    /// 64 KiB sub-chunks and reporting each sub-chunk's size on
    /// `progress_tx` as it is handed to the transport.
    ///
    /// NOTE(review): progress is sent with `try_send`, so updates are
    /// silently dropped when the channel is full — the receiver may observe
    /// fewer bytes than were actually sent. Progress also reflects bytes
    /// queued for sending, not bytes acknowledged by the server.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Size of each streamed body frame.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        // Copy the part into owned sub-chunks so the body stream owns its
        // data (temporarily doubles this part's memory footprint).
        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            // Best-effort progress signal; dropped if the channel is full.
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        // Map non-2xx replies to a typed error with the part number,
        // status, and (best effort) the response body.
        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
717
    /// Best-effort abort of a chunked-upload session.
    ///
    /// The HTTP result is intentionally ignored: abort runs during error
    /// cleanup, and a failed abort should not mask the original error.
    async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;

        let request = AbortChunkedUploadRequest {
            session_id: session_id.to_string(),
        };

        // Fire and forget; send errors are deliberately swallowed.
        let _ = self
            .client
            .delete(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await;

        Ok(())
    }
735
736 async fn create_dataset_from_multipart(
737 &self,
738 usecase: &str,
739 name: &str,
740 key: &str,
741 session_id: &str,
742 ) -> Result<
743 create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
744 > {
745 let variables = create_dataset_from_multipart::Variables {
746 input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
747 use_case: usecase.to_string(),
748 name: name.to_string(),
749 key: Some(key.to_string()),
750 source: None,
751 upload_session_id: session_id.to_string(),
752 },
753 };
754
755 let response_data = self
756 .execute_query(CreateDatasetFromMultipart, variables)
757 .await?;
758 Ok(response_data.create_dataset_from_multipart_upload)
759 }
760
    /// Uploads `dataset` via the chunked-upload REST flow, then creates the
    /// dataset through GraphQL.
    ///
    /// Returns a stream of [`UploadEvent`]s: `Progress` items while bytes are
    /// being sent, then one final `Complete` carrying the created dataset.
    /// On any part failure, or if the final creation fails, the session is
    /// aborted (best effort) and the stream yields the error.
    ///
    /// NOTE(review): progress flows through a bounded channel fed with
    /// `try_send` in `upload_part`, so under backpressure some updates are
    /// dropped and the last `Progress` may undercount `total_bytes`.
    /// NOTE(review): file reads use blocking `std::fs` I/O inside the async
    /// stream — confirm this is acceptable for the target runtime.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        usecase: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        // Size the upload up front so invalid files fail before any network
        // traffic happens.
        let file_size = std::fs::metadata(dataset.as_ref())?.len();

        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Initial zero-progress event so consumers can render 0%.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            // One reusable read buffer of `chunk_size` bytes.
            let mut file = File::open(dataset.as_ref())?;
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            for part_number in 1..=total_parts {
                let bytes_read = file.read(&mut buffer)?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the upload while draining progress messages so the
                // caller sees progress during (not only after) each part.
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        // `biased` prefers completion over draining further
                        // progress once the upload future is ready.
                        biased;
                        result = &mut upload_fut => {
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };

                // Abort the server-side session before surfacing the error;
                // `Err(e)?` terminates the stream with that error.
                if let Err(e) = upload_result {
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(usecase, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    // Creation failed: clean up the session, then report a
                    // dataset-creation error wrapping the cause.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
832}