//! Adaptive platform client (adaptive_client_rust/lib.rs).

1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
/// One megabyte in bytes.
const MEGABYTE: u64 = 1024 * 1024;
/// Smallest file size accepted for a chunked (multipart) upload.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
/// Largest part size ever used for a chunked upload.
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
/// Maximum number of parts a chunked upload may be split into.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size thresholds selecting the chunk-size tier in
// `calculate_upload_parts` (5 MB / 10 MB / 50 MB / 100 MB chunks).
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Errors returned by this crate's API surface.
///
/// Transport, serialization, I/O and URL errors are wrapped via `#[from]`;
/// the remaining variants carry domain-specific failures from the GraphQL
/// and chunked-upload flows.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// File below [`MIN_CHUNK_SIZE_BYTES`]; use the non-chunked upload path.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// File cannot fit even in the largest chunk size times the part cap.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The server returned GraphQL-level errors instead of data.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response carried neither `data` nor `errors`.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// Dataset creation after a completed multipart upload failed.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),
}
70
/// Crate-local alias: all fallible APIs return [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;

/// Snapshot of chunked-upload progress, in bytes.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    /// Bytes handed to the HTTP client so far.
    pub bytes_uploaded: u64,
    /// Total size of the file being uploaded.
    pub total_bytes: u64,
}

/// Events yielded by the stream returned from
/// `AdaptiveClient::chunked_upload_dataset`.
#[derive(Debug)]
pub enum UploadEvent {
    /// Incremental progress update.
    Progress(ChunkedUploadProgress),
    /// Terminal event carrying the dataset created from the upload session.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
86
87pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
88    if file_size < MIN_CHUNK_SIZE_BYTES {
89        return Err(AdaptiveError::FileTooSmall {
90            size: file_size,
91            min_size: MIN_CHUNK_SIZE_BYTES,
92        });
93    }
94
95    let mut chunk_size = if file_size < SIZE_500MB {
96        5 * MEGABYTE
97    } else if file_size < SIZE_10GB {
98        10 * MEGABYTE
99    } else if file_size < SIZE_50GB {
100        50 * MEGABYTE
101    } else {
102        100 * MEGABYTE
103    };
104
105    let mut total_parts = file_size.div_ceil(chunk_size);
106
107    if total_parts > MAX_PARTS_COUNT {
108        chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
109
110        if chunk_size > MAX_CHUNK_SIZE_BYTES {
111            let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
112            return Err(AdaptiveError::FileTooLarge {
113                size: file_size,
114                max_size: max_file_size,
115            });
116        }
117
118        total_parts = file_size.div_ceil(chunk_size);
119    }
120
121    Ok((total_parts, chunk_size))
122}
123
// Type aliases mapping GraphQL custom scalars onto Rust types; the names
// must match the scalar names referenced by the generated query modules.
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
131
132#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
133pub struct Timestamp(pub SystemTime);
134
135impl<'de> serde::Deserialize<'de> for Timestamp {
136    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
137    where
138        D: serde::Deserializer<'de>,
139    {
140        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
141        Ok(Timestamp(system_time))
142    }
143}
144
/// Items requested per page when walking cursor pagination.
const PAGE_SIZE: usize = 20;

/// Stand-in for the GraphQL `Upload` scalar; the wrapped number is the
/// multipart form-part index the file is bound to via the `map` field
/// (see `upload_dataset` / `publish_recipe`).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
149
/// GraphQL operation: list the custom recipes of a use case.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// GraphQL operation: fetch a single job by id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// GraphQL operation: cursor-paginated job listing.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// GraphQL operation: cancel a job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// GraphQL operation: list the model services of a use case.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// GraphQL operation: list every model on the platform.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
197
198impl Display for get_job::JobStatus {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        match self {
201            get_job::JobStatus::PENDING => write!(f, "Pending"),
202            get_job::JobStatus::RUNNING => write!(f, "Running"),
203            get_job::JobStatus::COMPLETED => write!(f, "Completed"),
204            get_job::JobStatus::FAILED => write!(f, "Failed"),
205            get_job::JobStatus::CANCELED => write!(f, "Canceled"),
206            get_job::JobStatus::Other(_) => write!(f, "Unknown"),
207        }
208    }
209}
210
/// GraphQL operation: create a custom recipe from an uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// GraphQL operation: create a dataset from an uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// GraphQL operation: create a dataset from a completed multipart-upload
/// session (see the chunked-upload REST routes).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// GraphQL operation: start a job from a custom recipe.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// GraphQL operation: list use cases.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/usecases.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListUseCases;

/// GraphQL operation: list compute pools.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// GraphQL operation: fetch a custom recipe by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// GraphQL operation: fetch a grader by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// GraphQL operation: fetch a dataset by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// GraphQL operation: fetch a model's configuration by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;
290
// REST endpoints for the chunked-upload flow, joined onto `rest_base_url`.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
294
/// Client for the Adaptive platform: GraphQL operations plus the REST
/// chunked-upload endpoints. Cloning is cheap (`reqwest::Client` is shared).
#[derive(Clone)]
pub struct AdaptiveClient {
    /// HTTP client used for both GraphQL and REST calls.
    client: Client,
    /// Endpoint for all GraphQL requests (`<base>/graphql`).
    graphql_url: Url,
    /// Base URL the REST upload routes are joined onto.
    rest_base_url: Url,
    /// Bearer token attached to every request.
    auth_token: String,
}
302
303impl AdaptiveClient {
304    pub fn new(api_base_url: Url, auth_token: String) -> Self {
305        let graphql_url = api_base_url
306            .join("graphql")
307            .expect("Failed to append graphql to base URL");
308
309        let client = Client::builder()
310            .user_agent(format!(
311                "adaptive-client-rust/{}",
312                env!("CARGO_PKG_VERSION")
313            ))
314            .build()
315            .expect("Failed to build HTTP client");
316
317        Self {
318            client,
319            graphql_url,
320            rest_base_url: api_base_url,
321            auth_token,
322        }
323    }
324
325    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
326    where
327        T: GraphQLQuery,
328        T::Variables: serde::Serialize,
329        T::ResponseData: DeserializeOwned,
330    {
331        let request_body = T::build_query(variables);
332
333        let response = self
334            .client
335            .post(self.graphql_url.clone())
336            .bearer_auth(&self.auth_token)
337            .json(&request_body)
338            .send()
339            .await?;
340
341        let response_body: Response<T::ResponseData> = response.json().await?;
342
343        match response_body.data {
344            Some(data) => Ok(data),
345            None => {
346                if let Some(errors) = response_body.errors {
347                    return Err(AdaptiveError::GraphQLErrors(errors));
348                }
349                Err(AdaptiveError::NoGraphQLData)
350            }
351        }
352    }
353
354    pub async fn list_recipes(
355        &self,
356        usecase: &str,
357    ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
358        let variables = get_custom_recipes::Variables {
359            usecase: IdOrKey::from(usecase),
360        };
361
362        let response_data = self.execute_query(GetCustomRecipes, variables).await?;
363        Ok(response_data.custom_recipes)
364    }
365
366    pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
367        let variables = get_job::Variables { id: job_id };
368
369        let response_data = self.execute_query(GetJob, variables).await?;
370
371        match response_data.job {
372            Some(job) => Ok(job),
373            None => Err(AdaptiveError::JobNotFound(job_id)),
374        }
375    }
376
377    pub async fn upload_dataset<P: AsRef<Path>>(
378        &self,
379        usecase: &str,
380        name: &str,
381        dataset: P,
382    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
383        let variables = upload_dataset::Variables {
384            usecase: IdOrKey::from(usecase),
385            file: Upload(0),
386            name: Some(name.to_string()),
387        };
388
389        let operations = UploadDataset::build_query(variables);
390        let operations = serde_json::to_string(&operations)?;
391
392        let file_map = r#"{ "0": ["variables.file"] }"#;
393
394        let dataset_file = reqwest::multipart::Part::file(dataset).await?;
395
396        let form = reqwest::multipart::Form::new()
397            .text("operations", operations)
398            .text("map", file_map)
399            .part("0", dataset_file);
400
401        let response = self
402            .client
403            .post(self.graphql_url.clone())
404            .bearer_auth(&self.auth_token)
405            .multipart(form)
406            .send()
407            .await?;
408
409        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
410            response.json().await?;
411
412        match response.data {
413            Some(data) => Ok(data.create_dataset),
414            None => {
415                if let Some(errors) = response.errors {
416                    return Err(AdaptiveError::GraphQLErrors(errors));
417                }
418                Err(AdaptiveError::NoGraphQLData)
419            }
420        }
421    }
422
423    pub async fn publish_recipe<P: AsRef<Path>>(
424        &self,
425        usecase: &str,
426        name: &str,
427        key: &str,
428        recipe: P,
429    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
430        let variables = publish_custom_recipe::Variables {
431            usecase: IdOrKey::from(usecase),
432            file: Upload(0),
433            name: Some(name.to_string()),
434            key: Some(key.to_string()),
435        };
436
437        let operations = PublishCustomRecipe::build_query(variables);
438        let operations = serde_json::to_string(&operations)?;
439
440        let file_map = r#"{ "0": ["variables.file"] }"#;
441
442        let recipe_file = reqwest::multipart::Part::file(recipe).await?;
443
444        let form = reqwest::multipart::Form::new()
445            .text("operations", operations)
446            .text("map", file_map)
447            .part("0", recipe_file);
448
449        let response = self
450            .client
451            .post(self.graphql_url.clone())
452            .bearer_auth(&self.auth_token)
453            .multipart(form)
454            .send()
455            .await?;
456        let response: Response<
457            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
458        > = response.json().await?;
459
460        match response.data {
461            Some(data) => Ok(data.create_custom_recipe),
462            None => {
463                if let Some(errors) = response.errors {
464                    return Err(AdaptiveError::GraphQLErrors(errors));
465                }
466                Err(AdaptiveError::NoGraphQLData)
467            }
468        }
469    }
470
471    pub async fn run_recipe(
472        &self,
473        usecase: &str,
474        recipe_id: &str,
475        parameters: Map<String, Value>,
476        name: Option<String>,
477        compute_pool: Option<String>,
478        num_gpus: u32,
479    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
480        let variables = run_custom_recipe::Variables {
481            input: run_custom_recipe::JobInput {
482                recipe: recipe_id.to_string(),
483                use_case: usecase.to_string(),
484                args: parameters,
485                name,
486                compute_pool,
487                num_gpus: num_gpus as i64,
488            },
489        };
490
491        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
492        Ok(response_data.create_job)
493    }
494
495    pub async fn list_jobs(
496        &self,
497        usecase: Option<String>,
498    ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
499        let mut jobs = Vec::new();
500        let mut page = self.list_jobs_page(usecase.clone(), None).await?;
501        jobs.extend(page.nodes.iter().cloned());
502        while page.page_info.has_next_page {
503            page = self
504                .list_jobs_page(usecase.clone(), page.page_info.end_cursor)
505                .await?;
506            jobs.extend(page.nodes.iter().cloned());
507        }
508        Ok(jobs)
509    }
510
    /// Fetch one page (size [`PAGE_SIZE`]) of jobs.
    ///
    /// The filter is fixed: only `CUSTOM`-kind jobs in `RUNNING` or
    /// `PENDING` status, optionally narrowed to one use case. `after` is
    /// the cursor returned by the previous page, or `None` for the first
    /// page.
    async fn list_jobs_page(
        &self,
        usecase: Option<String>,
        after: Option<String>,
    ) -> Result<list_jobs::ListJobsJobs> {
        let variables = list_jobs::Variables {
            filter: Some(list_jobs::ListJobsFilterInput {
                use_case: usecase,
                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
                status: Some(vec![
                    list_jobs::JobStatus::RUNNING,
                    list_jobs::JobStatus::PENDING,
                ]),
                timerange: None,
                custom_recipes: None,
                artifacts: None,
            }),
            cursor: Some(list_jobs::CursorPageInput {
                first: Some(PAGE_SIZE as i64),
                after,
                before: None,
                last: None,
                offset: None,
            }),
        };

        let response_data = self.execute_query(ListJobs, variables).await?;
        Ok(response_data.jobs)
    }
540
541    pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
542        let variables = cancel_job::Variables { job_id };
543
544        let response_data = self.execute_query(CancelJob, variables).await?;
545        Ok(response_data.cancel_job)
546    }
547
548    pub async fn list_models(
549        &self,
550        usecase: String,
551    ) -> Result<Vec<list_models::ListModelsUseCaseModelServices>> {
552        let variables = list_models::Variables {
553            use_case_id: usecase,
554        };
555
556        let response_data = self.execute_query(ListModels, variables).await?;
557        Ok(response_data
558            .use_case
559            .map(|use_case| use_case.model_services)
560            .unwrap_or(Vec::new()))
561    }
562
563    pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
564        let variables = list_all_models::Variables {};
565
566        let response_data = self.execute_query(ListAllModels, variables).await?;
567        Ok(response_data.models)
568    }
569
570    pub async fn list_usecases(&self) -> Result<Vec<list_use_cases::ListUseCasesUseCases>> {
571        let variables = list_use_cases::Variables {};
572
573        let response_data = self.execute_query(ListUseCases, variables).await?;
574        Ok(response_data.use_cases)
575    }
576
577    pub async fn list_pools(
578        &self,
579    ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
580        let variables = list_compute_pools::Variables {};
581
582        let response_data = self.execute_query(ListComputePools, variables).await?;
583        Ok(response_data.compute_pools)
584    }
585
586    pub async fn get_recipe(
587        &self,
588        usecase: String,
589        id_or_key: String,
590    ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
591        let variables = get_recipe::Variables { usecase, id_or_key };
592
593        let response_data = self.execute_query(GetRecipe, variables).await?;
594        Ok(response_data.custom_recipe)
595    }
596
597    pub async fn get_grader(
598        &self,
599        id_or_key: &str,
600        use_case: &str,
601    ) -> Result<get_grader::GetGraderGrader> {
602        let variables = get_grader::Variables {
603            id: id_or_key.to_string(),
604            use_case: use_case.to_string(),
605        };
606
607        let response_data = self.execute_query(GetGrader, variables).await?;
608        Ok(response_data.grader)
609    }
610
611    pub async fn get_dataset(
612        &self,
613        id_or_key: &str,
614        use_case: &str,
615    ) -> Result<Option<get_dataset::GetDatasetDataset>> {
616        let variables = get_dataset::Variables {
617            id_or_key: id_or_key.to_string(),
618            use_case: use_case.to_string(),
619        };
620
621        let response_data = self.execute_query(GetDataset, variables).await?;
622        Ok(response_data.dataset)
623    }
624
625    pub async fn get_model_config(
626        &self,
627        id_or_key: &str,
628    ) -> Result<Option<get_model_config::GetModelConfigModel>> {
629        let variables = get_model_config::Variables {
630            id_or_key: id_or_key.to_string(),
631        };
632
633        let response_data = self.execute_query(GetModelConfig, variables).await?;
634        Ok(response_data.model)
635    }
636
    /// The API base URL this client was constructed with (REST routes are
    /// joined onto it).
    pub fn base_url(&self) -> &Url {
        &self.rest_base_url
    }
640
641    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
642        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
643
644        let request = InitChunkedUploadRequest {
645            content_type: "application/jsonl".to_string(),
646            metadata: None,
647            total_parts_count: total_parts,
648        };
649
650        let response = self
651            .client
652            .post(url)
653            .bearer_auth(&self.auth_token)
654            .json(&request)
655            .send()
656            .await?;
657
658        if !response.status().is_success() {
659            return Err(AdaptiveError::ChunkedUploadInitFailed {
660                status: response.status().to_string(),
661                body: response.text().await.unwrap_or_default(),
662            });
663        }
664
665        let init_response: InitChunkedUploadResponse = response.json().await?;
666        Ok(init_response.session_id)
667    }
668
    /// Upload one part of a chunked-upload session, streaming the body in
    /// 64 KiB sub-chunks. Each sub-chunk's length is reported on
    /// `progress_tx` at the moment it is handed to the HTTP client.
    ///
    /// NOTE(review): progress uses `try_send`, which silently drops the
    /// update when the channel is full — receivers accumulating these
    /// values may undercount. Confirm best-effort reporting is intended.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Granularity of both the streamed body and progress reporting.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        // Copy the part into owned sub-chunks so the body stream owns its
        // data (wrap_stream requires 'static items).
        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            // Best-effort: a full channel drops this progress update.
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
717
718    async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
719        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
720
721        let request = AbortChunkedUploadRequest {
722            session_id: session_id.to_string(),
723        };
724
725        let _ = self
726            .client
727            .delete(url)
728            .bearer_auth(&self.auth_token)
729            .json(&request)
730            .send()
731            .await;
732
733        Ok(())
734    }
735
736    async fn create_dataset_from_multipart(
737        &self,
738        usecase: &str,
739        name: &str,
740        key: &str,
741        session_id: &str,
742    ) -> Result<
743        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
744    > {
745        let variables = create_dataset_from_multipart::Variables {
746            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
747                use_case: usecase.to_string(),
748                name: name.to_string(),
749                key: Some(key.to_string()),
750                source: None,
751                upload_session_id: session_id.to_string(),
752            },
753        };
754
755        let response_data = self
756            .execute_query(CreateDatasetFromMultipart, variables)
757            .await?;
758        Ok(response_data.create_dataset_from_multipart_upload)
759    }
760
    /// Upload a dataset file in parts, yielding `UploadEvent::Progress` as
    /// bytes are handed to the HTTP client and a final
    /// `UploadEvent::Complete` once the dataset is created from the
    /// finished session. On any part or creation failure the session is
    /// aborted (best-effort) before the error surfaces through the stream.
    ///
    /// NOTE(review): the file is opened and read with blocking `std::fs` /
    /// `std::io` calls inside the async stream, which blocks the executor
    /// thread while each chunk is read — confirm whether `tokio::fs` or
    /// `spawn_blocking` should be used instead.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        usecase: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        let file_size = std::fs::metadata(dataset.as_ref())?.len();

        // Validate size and derive the part layout up-front so these errors
        // surface before the stream is first polled.
        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Initial zero-progress event lets consumers render early.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            let mut file = File::open(dataset.as_ref())?;
            // One reusable chunk buffer; the final part may fill it only
            // partially.
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            // Server part numbers are 1-based.
            for part_number in 1..=total_parts {
                // NOTE(review): `read` does not guarantee filling the
                // buffer; a short read mid-file would produce an undersized
                // part. Confirm whether read_exact-style looping is needed.
                let bytes_read = file.read(&mut buffer)?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the upload while draining progress messages so
                // consumers see incremental updates mid-part.
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        biased;
                        result = &mut upload_fut => {
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };

                if let Err(e) = upload_result {
                    // Best-effort cleanup; the original error wins.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(usecase, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    // Best-effort cleanup before reporting the failure.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
832}