Skip to main content

adaptive_client_rust/
lib.rs

1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
const MEGABYTE: u64 = 1024 * 1024; // 1MB

// Chunked-upload part limits. NOTE(review): presumed to mirror the backend's
// multipart limits (5 MB–100 MB parts, at most 10_000 parts per session) —
// confirm against the upload service.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
const MAX_PARTS_COUNT: u64 = 10000;

// File-size thresholds selecting the chunk-size tier in `calculate_upload_parts`.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Unified error type for every client operation: HTTP transport, JSON
/// (de)serialization, file IO, URL construction, upload sizing, GraphQL
/// responses, and the chunked-upload REST endpoints.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    /// Transport-level failure surfaced by reqwest.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// JSON encode/decode failure.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Local filesystem failure (e.g. reading a dataset file).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A URL could not be parsed or joined against the base URL.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// File is below the minimum size accepted by the chunked upload API.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// File exceeds the maximum representable upload (max chunk * max parts).
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The server answered with GraphQL-level errors and no usable data.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response had neither `data` nor `errors`.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    /// A job lookup by id returned nothing.
    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    /// The chunked-upload init route returned a non-success status.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    /// A single part upload returned a non-success status.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// Finalizing a multipart upload into a dataset failed.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),

    /// A non-2xx HTTP status from the GraphQL endpoint (body included).
    #[error("HTTP status error: {status} - {body}")]
    HttpStatusError { status: String, body: String },

    /// The response body was not valid JSON; `body` holds a truncated preview.
    #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
    JsonParseError { error: String, body: String },
}
76
/// Crate-local result alias defaulting the error type to [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;
78
/// Snapshot of chunked-upload progress.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    // Bytes handed to the HTTP stack so far. Approximate: `upload_part`
    // counts sub-chunks as they are queued, not when acknowledged.
    pub bytes_uploaded: u64,
    // Total size in bytes of the payload being uploaded.
    pub total_bytes: u64,
}
84
/// Items yielded by the stream returned from `chunked_upload_dataset`.
#[derive(Debug)]
pub enum UploadEvent {
    /// Periodic byte-count update during the upload.
    Progress(ChunkedUploadProgress),
    /// Terminal event carrying the dataset created by the server.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94    if file_size < MIN_CHUNK_SIZE_BYTES {
95        return Err(AdaptiveError::FileTooSmall {
96            size: file_size,
97            min_size: MIN_CHUNK_SIZE_BYTES,
98        });
99    }
100
101    let mut chunk_size = if file_size < SIZE_500MB {
102        5 * MEGABYTE
103    } else if file_size < SIZE_10GB {
104        10 * MEGABYTE
105    } else if file_size < SIZE_50GB {
106        50 * MEGABYTE
107    } else {
108        100 * MEGABYTE
109    };
110
111    let mut total_parts = file_size.div_ceil(chunk_size);
112
113    if total_parts > MAX_PARTS_COUNT {
114        chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116        if chunk_size > MAX_CHUNK_SIZE_BYTES {
117            let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118            return Err(AdaptiveError::FileTooLarge {
119                size: file_size,
120                max_size: max_file_size,
121            });
122        }
123
124        total_parts = file_size.div_ceil(chunk_size);
125    }
126
127    Ok((total_parts, chunk_size))
128}
129
// Scalar type aliases referenced by the graphql_client-generated modules.
// NOTE(review): presumably these names must match the custom scalars declared
// in `schema.gql` — confirm when the schema changes.
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
type KeyInput = String;
138
/// Newtype over [`SystemTime`] with ordering, deserialized via
/// `serde_utils::deserialize_timestamp_millis` (presumably epoch
/// milliseconds — see that helper for the exact wire format).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);
141
142impl<'de> serde::Deserialize<'de> for Timestamp {
143    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
144    where
145        D: serde::Deserializer<'de>,
146    {
147        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
148        Ok(Timestamp(system_time))
149    }
150}
151
// Number of jobs requested per page by `list_jobs_page`.
const PAGE_SIZE: usize = 20;

/// Placeholder for the GraphQL `Upload` scalar: serialized as an index that
/// the multipart `map` form field ties to a file part (see `upload_dataset`).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
156
/// Query: list custom recipes in a project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Query: fetch a single job by id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Query: paginated job listing with filters.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Mutation: cancel a job by id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Query: model services of a single project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Query: all models across projects.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
204
205impl Display for get_job::JobStatus {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        match self {
208            get_job::JobStatus::PENDING => write!(f, "Pending"),
209            get_job::JobStatus::RUNNING => write!(f, "Running"),
210            get_job::JobStatus::COMPLETED => write!(f, "Completed"),
211            get_job::JobStatus::FAILED => write!(f, "Failed"),
212            get_job::JobStatus::CANCELED => write!(f, "Canceled"),
213            get_job::JobStatus::Other(_) => write!(f, "Unknown"),
214        }
215    }
216}
217
/// Mutation: publish (create) a custom recipe from an uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Mutation: create a dataset from a directly-uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Mutation: create a dataset from a completed chunked-upload session.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Mutation: start a job running a custom recipe.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Query: all projects visible to the caller.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/projects.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListProjects;

/// Query: available compute pools.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Query: fetch one custom recipe by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// Query: fetch a grader by id/key within a project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// Query: fetch a dataset by id or key within a project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// Query: fetch a model's configuration by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;

/// Mutation: push a progress event for a running job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job_progress.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateJobProgress;
305
// REST routes for the chunked-upload API, resolved relative to `rest_base_url`.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
309
/// HTTP client for the Adaptive API: GraphQL operations plus the chunked
/// upload / download REST endpoints. Cheap to clone (shares the reqwest pool).
#[derive(Clone)]
pub struct AdaptiveClient {
    // Shared reqwest connection pool.
    client: Client,
    // Base URL with "graphql" appended; all GraphQL posts target this.
    graphql_url: Url,
    // Original API base URL; REST routes are joined onto it.
    rest_base_url: Url,
    // Bearer token attached to every request.
    auth_token: String,
}
317
318impl AdaptiveClient {
    /// Build a client for the given API base URL and bearer token.
    ///
    /// # Panics
    /// Panics if "graphql" cannot be joined onto `api_base_url` or if the
    /// underlying HTTP client fails to build (both indicate misconfiguration).
    ///
    /// NOTE(review): `Url::join` replaces the last path segment when the base
    /// lacks a trailing slash — confirm callers pass e.g. "https://host/api/".
    pub fn new(api_base_url: Url, auth_token: String) -> Self {
        let graphql_url = api_base_url
            .join("graphql")
            .expect("Failed to append graphql to base URL");

        // Identify client version in the User-Agent for server-side diagnostics.
        let client = Client::builder()
            .user_agent(format!(
                "adaptive-client-rust/{}",
                env!("CARGO_PKG_VERSION")
            ))
            .build()
            .expect("Failed to build HTTP client");

        Self {
            client,
            graphql_url,
            rest_base_url: api_base_url,
            auth_token,
        }
    }
339
340    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
341    where
342        T: GraphQLQuery,
343        T::Variables: serde::Serialize,
344        T::ResponseData: DeserializeOwned,
345    {
346        let request_body = T::build_query(variables);
347
348        let response = self
349            .client
350            .post(self.graphql_url.clone())
351            .bearer_auth(&self.auth_token)
352            .json(&request_body)
353            .send()
354            .await?;
355
356        let status = response.status();
357        let response_text = response.text().await?;
358
359        if !status.is_success() {
360            return Err(AdaptiveError::HttpStatusError {
361                status: status.to_string(),
362                body: response_text,
363            });
364        }
365
366        let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
367            .map_err(|e| AdaptiveError::JsonParseError {
368                error: e.to_string(),
369                body: response_text.chars().take(500).collect(),
370            })?;
371
372        match response_body.data {
373            Some(data) => Ok(data),
374            None => {
375                if let Some(errors) = response_body.errors {
376                    return Err(AdaptiveError::GraphQLErrors(errors));
377                }
378                Err(AdaptiveError::NoGraphQLData)
379            }
380        }
381    }
382
383    pub async fn list_recipes(
384        &self,
385        project: &str,
386    ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
387        let variables = get_custom_recipes::Variables {
388            project: IdOrKey::from(project),
389        };
390
391        let response_data = self.execute_query(GetCustomRecipes, variables).await?;
392        Ok(response_data.custom_recipes)
393    }
394
395    pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
396        let variables = get_job::Variables { id: job_id };
397
398        let response_data = self.execute_query(GetJob, variables).await?;
399
400        match response_data.job {
401            Some(job) => Ok(job),
402            None => Err(AdaptiveError::JobNotFound(job_id)),
403        }
404    }
405
    /// Upload a dataset file and create a dataset from it in one request,
    /// using the GraphQL multipart form layout (`operations` + `map` fields
    /// plus the file part).
    pub async fn upload_dataset<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        dataset: P,
    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
        // `Upload(0)` is a placeholder scalar; the actual bytes travel as the
        // multipart part named "0", wired to the variable via `file_map`.
        let variables = upload_dataset::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
        };

        let operations = UploadDataset::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the operation's `file` variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let dataset_file = reqwest::multipart::Part::file(dataset).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", dataset_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;

        // NOTE(review): unlike `execute_query`, the HTTP status is not checked
        // before JSON decoding, so a non-2xx error page surfaces as a JSON
        // decode error rather than `HttpStatusError` — confirm if intended.
        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
            response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_dataset),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
451
    /// Publish a custom recipe from a local file, using the same GraphQL
    /// multipart form layout as `upload_dataset` (`operations` + `map` +
    /// file part "0").
    pub async fn publish_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        key: &str,
        recipe: P,
    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
        // `Upload(0)` is a placeholder; the file itself is part "0" below.
        let variables = publish_custom_recipe::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
            key: Some(key.to_string()),
        };

        let operations = PublishCustomRecipe::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the operation's `file` variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let recipe_file = reqwest::multipart::Part::file(recipe).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", recipe_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;
        // NOTE(review): no HTTP status check before decoding — a non-2xx
        // error page becomes a JSON decode error, unlike `execute_query`.
        let response: Response<
            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
        > = response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_custom_recipe),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
499
500    pub async fn run_recipe(
501        &self,
502        project: &str,
503        recipe_id: &str,
504        parameters: Map<String, Value>,
505        name: Option<String>,
506        compute_pool: Option<String>,
507        num_gpus: u32,
508        use_experimental_runner: bool,
509    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
510        let variables = run_custom_recipe::Variables {
511            input: run_custom_recipe::JobInput {
512                recipe: recipe_id.to_string(),
513                project: project.to_string(),
514                args: parameters,
515                name,
516                compute_pool,
517                num_gpus: num_gpus as i64,
518                use_experimental_runner,
519                max_cpu: None,
520                max_ram_gb: None,
521                max_duration_secs: None,
522            },
523        };
524
525        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
526        Ok(response_data.create_job)
527    }
528
529    pub async fn list_jobs(
530        &self,
531        project: Option<String>,
532    ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
533        let mut jobs = Vec::new();
534        let mut page = self.list_jobs_page(project.clone(), None).await?;
535        jobs.extend(page.nodes.iter().cloned());
536        while page.page_info.has_next_page {
537            page = self
538                .list_jobs_page(project.clone(), page.page_info.end_cursor)
539                .await?;
540            jobs.extend(page.nodes.iter().cloned());
541        }
542        Ok(jobs)
543    }
544
545    async fn list_jobs_page(
546        &self,
547        project: Option<String>,
548        after: Option<String>,
549    ) -> Result<list_jobs::ListJobsJobs> {
550        let variables = list_jobs::Variables {
551            filter: Some(list_jobs::ListJobsFilterInput {
552                project,
553                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
554                status: Some(vec![
555                    list_jobs::JobStatus::RUNNING,
556                    list_jobs::JobStatus::PENDING,
557                ]),
558                timerange: None,
559                custom_recipes: None,
560                artifacts: None,
561            }),
562            cursor: Some(list_jobs::CursorPageInput {
563                first: Some(PAGE_SIZE as i64),
564                after,
565                before: None,
566                last: None,
567                offset: None,
568            }),
569        };
570
571        let response_data = self.execute_query(ListJobs, variables).await?;
572        Ok(response_data.jobs)
573    }
574
575    pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
576        let variables = cancel_job::Variables { job_id };
577
578        let response_data = self.execute_query(CancelJob, variables).await?;
579        Ok(response_data.cancel_job)
580    }
581
582    pub async fn update_job_progress(
583        &self,
584        job_id: Uuid,
585        event: update_job_progress::JobProgressEventInput,
586    ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
587        let variables = update_job_progress::Variables { job_id, event };
588
589        let response_data = self.execute_query(UpdateJobProgress, variables).await?;
590        Ok(response_data.update_job_progress)
591    }
592
593    pub async fn list_models(
594        &self,
595        project: String,
596    ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
597        let variables = list_models::Variables { project };
598
599        let response_data = self.execute_query(ListModels, variables).await?;
600        Ok(response_data
601            .project
602            .map(|project| project.model_services)
603            .unwrap_or(Vec::new()))
604    }
605
606    pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
607        let variables = list_all_models::Variables {};
608
609        let response_data = self.execute_query(ListAllModels, variables).await?;
610        Ok(response_data.models)
611    }
612
613    pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
614        let variables = list_projects::Variables {};
615
616        let response_data = self.execute_query(ListProjects, variables).await?;
617        Ok(response_data.projects)
618    }
619
620    pub async fn list_pools(
621        &self,
622    ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
623        let variables = list_compute_pools::Variables {};
624
625        let response_data = self.execute_query(ListComputePools, variables).await?;
626        Ok(response_data.compute_pools)
627    }
628
629    pub async fn get_recipe(
630        &self,
631        project: String,
632        id_or_key: String,
633    ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
634        let variables = get_recipe::Variables { project, id_or_key };
635
636        let response_data = self.execute_query(GetRecipe, variables).await?;
637        Ok(response_data.custom_recipe)
638    }
639
640    pub async fn get_grader(
641        &self,
642        id_or_key: &str,
643        project: &str,
644    ) -> Result<get_grader::GetGraderGrader> {
645        let variables = get_grader::Variables {
646            id: id_or_key.to_string(),
647            project: project.to_string(),
648        };
649
650        let response_data = self.execute_query(GetGrader, variables).await?;
651        Ok(response_data.grader)
652    }
653
654    pub async fn get_dataset(
655        &self,
656        id_or_key: &str,
657        project: &str,
658    ) -> Result<Option<get_dataset::GetDatasetDataset>> {
659        let variables = get_dataset::Variables {
660            id_or_key: id_or_key.to_string(),
661            project: project.to_string(),
662        };
663
664        let response_data = self.execute_query(GetDataset, variables).await?;
665        Ok(response_data.dataset)
666    }
667
668    pub async fn get_model_config(
669        &self,
670        id_or_key: &str,
671    ) -> Result<Option<get_model_config::GetModelConfigModel>> {
672        let variables = get_model_config::Variables {
673            id_or_key: id_or_key.to_string(),
674        };
675
676        let response_data = self.execute_query(GetModelConfig, variables).await?;
677        Ok(response_data.model)
678    }
679
    /// The REST base URL this client was constructed with.
    pub fn base_url(&self) -> &Url {
        &self.rest_base_url
    }
683
684    /// Upload bytes using the chunked upload API and return the session_id.
685    /// This can be used to link the uploaded file to an artifact.
686    pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
687        let file_size = data.len() as u64;
688
689        // Calculate chunk size (same logic as calculate_upload_parts but inline for small files)
690        let chunk_size = if file_size < 5 * 1024 * 1024 {
691            // For files < 5MB, use the whole file as one chunk
692            file_size.max(1)
693        } else if file_size < 500 * 1024 * 1024 {
694            5 * 1024 * 1024
695        } else if file_size < 10 * 1024 * 1024 * 1024 {
696            10 * 1024 * 1024
697        } else {
698            100 * 1024 * 1024
699        };
700
701        let total_parts = file_size.div_ceil(chunk_size).max(1);
702
703        // Initialize upload session
704        let session_id = self
705            .init_chunked_upload_with_content_type(total_parts, content_type)
706            .await?;
707
708        // Upload parts
709        for part_number in 1..=total_parts {
710            let start = ((part_number - 1) * chunk_size) as usize;
711            let end = (part_number * chunk_size).min(file_size) as usize;
712            let chunk = data[start..end].to_vec();
713
714            if let Err(e) = self
715                .upload_part_simple(&session_id, part_number, chunk)
716                .await
717            {
718                let _ = self.abort_chunked_upload(&session_id).await;
719                return Err(e);
720            }
721        }
722
723        Ok(session_id)
724    }
725
726    /// Initialize a chunked upload session.
727    pub async fn init_chunked_upload_with_content_type(
728        &self,
729        total_parts: u64,
730        content_type: &str,
731    ) -> Result<String> {
732        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
733
734        let request = InitChunkedUploadRequest {
735            content_type: content_type.to_string(),
736            metadata: None,
737            total_parts_count: total_parts,
738        };
739
740        let response = self
741            .client
742            .post(url)
743            .bearer_auth(&self.auth_token)
744            .json(&request)
745            .send()
746            .await?;
747
748        if !response.status().is_success() {
749            return Err(AdaptiveError::ChunkedUploadInitFailed {
750                status: response.status().to_string(),
751                body: response.text().await.unwrap_or_default(),
752            });
753        }
754
755        let init_response: InitChunkedUploadResponse = response.json().await?;
756        Ok(init_response.session_id)
757    }
758
759    /// Upload a single part of a chunked upload.
760    pub async fn upload_part_simple(
761        &self,
762        session_id: &str,
763        part_number: u64,
764        data: Vec<u8>,
765    ) -> Result<()> {
766        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
767
768        let response = self
769            .client
770            .post(url)
771            .bearer_auth(&self.auth_token)
772            .query(&[
773                ("session_id", session_id),
774                ("part_number", &part_number.to_string()),
775            ])
776            .header("Content-Type", "application/octet-stream")
777            .body(data)
778            .send()
779            .await?;
780
781        if !response.status().is_success() {
782            return Err(AdaptiveError::ChunkedUploadPartFailed {
783                part_number,
784                status: response.status().to_string(),
785                body: response.text().await.unwrap_or_default(),
786            });
787        }
788
789        Ok(())
790    }
791
792    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
793        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
794
795        let request = InitChunkedUploadRequest {
796            content_type: "application/jsonl".to_string(),
797            metadata: None,
798            total_parts_count: total_parts,
799        };
800
801        let response = self
802            .client
803            .post(url)
804            .bearer_auth(&self.auth_token)
805            .json(&request)
806            .send()
807            .await?;
808
809        if !response.status().is_success() {
810            return Err(AdaptiveError::ChunkedUploadInitFailed {
811                status: response.status().to_string(),
812                body: response.text().await.unwrap_or_default(),
813            });
814        }
815
816        let init_response: InitChunkedUploadResponse = response.json().await?;
817        Ok(init_response.session_id)
818    }
819
    /// Upload one part while reporting byte-level progress on `progress_tx`.
    ///
    /// The part is split into 64 KiB sub-chunks and streamed as the request
    /// body; each sub-chunk's length is pushed onto `progress_tx` as it is
    /// handed to the HTTP stack. NOTE(review): `try_send` silently drops
    /// progress ticks when the channel is full, and bytes are counted when
    /// queued rather than acknowledged — progress is approximate.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Granularity of progress reporting.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        // Materialize owned sub-chunks so the stream can yield them one by one.
        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        // Emit a progress tick per sub-chunk as it enters the body stream.
        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        // Streaming body lets reqwest send the part without re-buffering it.
        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
868
869    /// Abort a chunked upload session.
870    pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
871        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
872
873        let request = AbortChunkedUploadRequest {
874            session_id: session_id.to_string(),
875        };
876
877        let _ = self
878            .client
879            .delete(url)
880            .bearer_auth(&self.auth_token)
881            .json(&request)
882            .send()
883            .await;
884
885        Ok(())
886    }
887
888    async fn create_dataset_from_multipart(
889        &self,
890        project: &str,
891        name: &str,
892        key: &str,
893        session_id: &str,
894    ) -> Result<
895        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
896    > {
897        let variables = create_dataset_from_multipart::Variables {
898            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
899                project: project.to_string(),
900                name: name.to_string(),
901                key: Some(key.to_string()),
902                source: None,
903                upload_session_id: session_id.to_string(),
904            },
905        };
906
907        let response_data = self
908            .execute_query(CreateDatasetFromMultipart, variables)
909            .await?;
910        Ok(response_data.create_dataset_from_multipart_upload)
911    }
912
    /// Stream a chunked dataset upload, yielding `UploadEvent::Progress`
    /// updates and finally `UploadEvent::Complete` with the created dataset.
    ///
    /// Sizing is validated eagerly via `calculate_upload_parts`, so files
    /// outside the allowed size range fail before any network I/O. On any
    /// part or finalization failure the session is aborted (best-effort)
    /// before the error is surfaced through the stream.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        project: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        let file_size = std::fs::metadata(dataset.as_ref())?.len();

        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Initial 0% event so consumers can render before the first chunk.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            let mut file = File::open(dataset.as_ref())?;
            // One reusable read buffer sized to exactly one chunk.
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            // Sub-chunk progress ticks from `upload_part` arrive here.
            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            for part_number in 1..=total_parts {
                // NOTE(review): a single `read` may return fewer bytes than
                // the buffer size; for regular files it typically fills, but
                // a short read would truncate this part — confirm or loop.
                let bytes_read = file.read(&mut buffer)?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the upload while forwarding progress ticks as stream
                // items; `biased` polls upload completion before progress.
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        biased;
                        result = &mut upload_fut => {
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };

                if let Err(e) = upload_result {
                    // Best-effort server-side cleanup before bailing out.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(project, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    // Finalization failed: abort the session, then surface
                    // the error wrapped as DatasetCreationFailed.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
984
985    /// Download a file from the given URL and write it to the specified path.
986    /// The URL can be absolute or relative to the API base URL.
987    pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
988        use tokio::io::AsyncWriteExt;
989
990        let full_url = if url.starts_with("http://") || url.starts_with("https://") {
991            Url::parse(url)?
992        } else {
993            self.rest_base_url.join(url)?
994        };
995
996        let response = self
997            .client
998            .get(full_url)
999            .bearer_auth(&self.auth_token)
1000            .send()
1001            .await?;
1002
1003        if !response.status().is_success() {
1004            return Err(AdaptiveError::HttpError(
1005                response.error_for_status().unwrap_err(),
1006            ));
1007        }
1008
1009        let mut file = tokio::fs::File::create(dest_path).await?;
1010        let mut stream = response.bytes_stream();
1011
1012        while let Some(chunk) = stream.next().await {
1013            let chunk = chunk?;
1014            file.write_all(&chunk).await?;
1015        }
1016
1017        file.flush().await?;
1018        Ok(())
1019    }
1020}