//! GraphQL + REST client for the Adaptive API (`adaptive_client_rust`, lib.rs).
1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
// Chunked-upload sizing limits: parts of 5 MB–100 MB, at most 10 000 parts
// per upload (mirrors common S3-style multipart constraints — confirm against
// the backend's actual contract).
const MEGABYTE: u64 = 1024 * 1024; // 1MB
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
const MAX_PARTS_COUNT: u64 = 10000;

// File-size tier boundaries used by `calculate_upload_parts` to pick a chunk size.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Unified error type for every fallible operation in this client.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    // --- Wrapped lower-level failures -------------------------------------
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    // --- Chunked-upload size validation -----------------------------------
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    // --- GraphQL-level failures -------------------------------------------
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The response contained neither `data` nor `errors`.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    // --- Chunked-upload REST failures -------------------------------------
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),

    /// Non-2xx HTTP response from the GraphQL endpoint.
    #[error("HTTP status error: {status} - {body}")]
    HttpStatusError { status: String, body: String },

    /// Response body was not valid JSON; `body` holds a truncated preview.
    #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
    JsonParseError { error: String, body: String },
}
76
/// Crate-local result alias over [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;
78
/// Progress snapshot emitted while a chunked upload runs.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    /// Bytes handed to the transport so far (counted when enqueued, not when
    /// acknowledged by the server).
    pub bytes_uploaded: u64,
    /// Total size of the file being uploaded.
    pub total_bytes: u64,
}
84
/// Events yielded by the chunked dataset-upload stream.
#[derive(Debug)]
pub enum UploadEvent {
    /// Periodic progress update.
    Progress(ChunkedUploadProgress),
    /// Terminal event carrying the created dataset record.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94    if file_size < MIN_CHUNK_SIZE_BYTES {
95        return Err(AdaptiveError::FileTooSmall {
96            size: file_size,
97            min_size: MIN_CHUNK_SIZE_BYTES,
98        });
99    }
100
101    let mut chunk_size = if file_size < SIZE_500MB {
102        5 * MEGABYTE
103    } else if file_size < SIZE_10GB {
104        10 * MEGABYTE
105    } else if file_size < SIZE_50GB {
106        50 * MEGABYTE
107    } else {
108        100 * MEGABYTE
109    };
110
111    let mut total_parts = file_size.div_ceil(chunk_size);
112
113    if total_parts > MAX_PARTS_COUNT {
114        chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116        if chunk_size > MAX_CHUNK_SIZE_BYTES {
117            let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118            return Err(AdaptiveError::FileTooLarge {
119                size: file_size,
120                max_size: max_file_size,
121            });
122        }
123
124        total_parts = file_size.div_ceil(chunk_size);
125    }
126
127    Ok((total_parts, chunk_size))
128}
129
// Scalar type aliases consumed by the graphql_client-generated code; the
// names presumably match the custom scalars declared in schema.gql — verify
// there before renaming any of them.
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
type KeyInput = String;
138
/// Wrapper around `SystemTime` that deserializes from a millisecond-epoch
/// timestamp (the format used by the API).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);

impl<'de> serde::Deserialize<'de> for Timestamp {
    // Delegates parsing to the shared helper in `serde_utils`.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
        Ok(Timestamp(system_time))
    }
}
151
/// Page size used when paginating job listings.
const PAGE_SIZE: usize = 20;

/// Placeholder scalar for GraphQL multipart uploads: serializes to the index
/// of the file part in the multipart form (see the `"map"` field where it is
/// used with part `"0"`).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
156
// GraphQL operations: each unit struct is expanded by the `GraphQLQuery`
// derive into a module (variables + response types) generated from the query
// document at `query_path`, validated against `schema.gql`.

/// Lists the custom recipes of a project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Fetches a single job by id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Pages through jobs with filtering (see `list_jobs_page`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Cancels a job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Lists the model services of one project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Lists models across all projects.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
204
205impl Display for get_job::JobStatus {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        match self {
208            get_job::JobStatus::PENDING => write!(f, "Pending"),
209            get_job::JobStatus::RUNNING => write!(f, "Running"),
210            get_job::JobStatus::COMPLETED => write!(f, "Completed"),
211            get_job::JobStatus::FAILED => write!(f, "Failed"),
212            get_job::JobStatus::CANCELED => write!(f, "Canceled"),
213            get_job::JobStatus::Other(_) => write!(f, "Unknown"),
214        }
215    }
216}
217
/// Creates a custom recipe from an uploaded file (multipart mutation).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Updates a custom recipe's metadata and optionally its file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/update_recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateCustomRecipe;

/// Creates a dataset from an uploaded file (multipart mutation).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Creates a dataset from a completed chunked-upload session.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Starts a job running a custom recipe.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Lists the caller's projects.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/projects.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListProjects;

/// Lists available compute pools.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Fetches one custom recipe by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// Fetches a grader; response types are additionally serializable.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// Fetches a dataset by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// Fetches a model's configuration; response types are additionally serializable.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;

/// Reports a progress event for a running job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job_progress.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateJobProgress;
313
// REST endpoints for the chunked upload API, joined onto `rest_base_url`.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
317
/// Client for the Adaptive GraphQL + REST APIs.
///
/// Cheap to clone: `reqwest::Client` is internally reference-counted, so
/// clones share the same connection pool.
#[derive(Clone)]
pub struct AdaptiveClient {
    /// Shared HTTP client (connection pool, user-agent).
    client: Client,
    /// Fully-joined `<base>/graphql` endpoint.
    graphql_url: Url,
    /// Base URL for REST routes such as the chunked upload API.
    rest_base_url: Url,
    /// Bearer token attached to every request.
    auth_token: String,
}
325
326impl AdaptiveClient {
327    pub fn new(api_base_url: Url, auth_token: String) -> Self {
328        let graphql_url = api_base_url
329            .join("graphql")
330            .expect("Failed to append graphql to base URL");
331
332        let client = Client::builder()
333            .user_agent(format!(
334                "adaptive-client-rust/{}",
335                env!("CARGO_PKG_VERSION")
336            ))
337            .build()
338            .expect("Failed to build HTTP client");
339
340        Self {
341            client,
342            graphql_url,
343            rest_base_url: api_base_url,
344            auth_token,
345        }
346    }
347
348    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
349    where
350        T: GraphQLQuery,
351        T::Variables: serde::Serialize,
352        T::ResponseData: DeserializeOwned,
353    {
354        let request_body = T::build_query(variables);
355
356        let response = self
357            .client
358            .post(self.graphql_url.clone())
359            .bearer_auth(&self.auth_token)
360            .json(&request_body)
361            .send()
362            .await?;
363
364        let status = response.status();
365        let response_text = response.text().await?;
366
367        if !status.is_success() {
368            return Err(AdaptiveError::HttpStatusError {
369                status: status.to_string(),
370                body: response_text,
371            });
372        }
373
374        let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
375            .map_err(|e| AdaptiveError::JsonParseError {
376                error: e.to_string(),
377                body: response_text.chars().take(500).collect(),
378            })?;
379
380        match response_body.data {
381            Some(data) => Ok(data),
382            None => {
383                if let Some(errors) = response_body.errors {
384                    return Err(AdaptiveError::GraphQLErrors(errors));
385                }
386                Err(AdaptiveError::NoGraphQLData)
387            }
388        }
389    }
390
391    pub async fn list_recipes(
392        &self,
393        project: &str,
394    ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
395        let variables = get_custom_recipes::Variables {
396            project: IdOrKey::from(project),
397        };
398
399        let response_data = self.execute_query(GetCustomRecipes, variables).await?;
400        Ok(response_data.custom_recipes)
401    }
402
403    pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
404        let variables = get_job::Variables { id: job_id };
405
406        let response_data = self.execute_query(GetJob, variables).await?;
407
408        match response_data.job {
409            Some(job) => Ok(job),
410            None => Err(AdaptiveError::JobNotFound(job_id)),
411        }
412    }
413
414    pub async fn upload_dataset<P: AsRef<Path>>(
415        &self,
416        project: &str,
417        name: &str,
418        dataset: P,
419    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
420        let variables = upload_dataset::Variables {
421            project: IdOrKey::from(project),
422            file: Upload(0),
423            name: Some(name.to_string()),
424        };
425
426        let operations = UploadDataset::build_query(variables);
427        let operations = serde_json::to_string(&operations)?;
428
429        let file_map = r#"{ "0": ["variables.file"] }"#;
430
431        let dataset_file = reqwest::multipart::Part::file(dataset).await?;
432
433        let form = reqwest::multipart::Form::new()
434            .text("operations", operations)
435            .text("map", file_map)
436            .part("0", dataset_file);
437
438        let response = self
439            .client
440            .post(self.graphql_url.clone())
441            .bearer_auth(&self.auth_token)
442            .multipart(form)
443            .send()
444            .await?;
445
446        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
447            response.json().await?;
448
449        match response.data {
450            Some(data) => Ok(data.create_dataset),
451            None => {
452                if let Some(errors) = response.errors {
453                    return Err(AdaptiveError::GraphQLErrors(errors));
454                }
455                Err(AdaptiveError::NoGraphQLData)
456            }
457        }
458    }
459
460    pub async fn publish_recipe<P: AsRef<Path>>(
461        &self,
462        project: &str,
463        name: &str,
464        key: &str,
465        recipe: P,
466    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
467        let variables = publish_custom_recipe::Variables {
468            project: IdOrKey::from(project),
469            file: Upload(0),
470            name: Some(name.to_string()),
471            key: Some(key.to_string()),
472        };
473
474        let operations = PublishCustomRecipe::build_query(variables);
475        let operations = serde_json::to_string(&operations)?;
476
477        let file_map = r#"{ "0": ["variables.file"] }"#;
478
479        let recipe_file = reqwest::multipart::Part::file(recipe).await?;
480
481        let form = reqwest::multipart::Form::new()
482            .text("operations", operations)
483            .text("map", file_map)
484            .part("0", recipe_file);
485
486        let response = self
487            .client
488            .post(self.graphql_url.clone())
489            .bearer_auth(&self.auth_token)
490            .multipart(form)
491            .send()
492            .await?;
493        let response: Response<
494            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
495        > = response.json().await?;
496
497        match response.data {
498            Some(data) => Ok(data.create_custom_recipe),
499            None => {
500                if let Some(errors) = response.errors {
501                    return Err(AdaptiveError::GraphQLErrors(errors));
502                }
503                Err(AdaptiveError::NoGraphQLData)
504            }
505        }
506    }
507
    /// Update a custom recipe's metadata and optionally replace its file.
    ///
    /// When `recipe_file` is `Some`, the mutation is sent as a GraphQL
    /// multipart request (query in `operations`, file as form part `"0"`);
    /// otherwise a plain GraphQL mutation goes through `execute_query`.
    pub async fn update_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        id: &str,
        name: Option<String>,
        description: Option<String>,
        labels: Option<Vec<update_custom_recipe::LabelInput>>,
        recipe_file: Option<P>,
    ) -> Result<update_custom_recipe::UpdateCustomRecipeUpdateCustomRecipe> {
        // Fields left as None are presumably not modified server-side — TODO
        // confirm against the schema's UpdateRecipeInput semantics.
        let input = update_custom_recipe::UpdateRecipeInput {
            name,
            description,
            labels,
        };

        match recipe_file {
            Some(file_path) => {
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    // Placeholder index referencing form part "0".
                    file: Some(Upload(0)),
                };

                let operations = UpdateCustomRecipe::build_query(variables);
                let operations = serde_json::to_string(&operations)?;

                // GraphQL multipart spec: map form part "0" onto variables.file.
                let file_map = r#"{ "0": ["variables.file"] }"#;

                let recipe_file = reqwest::multipart::Part::file(file_path).await?;

                let form = reqwest::multipart::Form::new()
                    .text("operations", operations)
                    .text("map", file_map)
                    .part("0", recipe_file);

                let response = self
                    .client
                    .post(self.graphql_url.clone())
                    .bearer_auth(&self.auth_token)
                    .multipart(form)
                    .send()
                    .await?;
                // NOTE(review): unlike execute_query, the HTTP status is not
                // checked here — a non-2xx response surfaces as a JSON decode
                // error rather than HttpStatusError.
                let response: Response<
                    <UpdateCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
                > = response.json().await?;

                match response.data {
                    Some(data) => Ok(data.update_custom_recipe),
                    None => {
                        if let Some(errors) = response.errors {
                            return Err(AdaptiveError::GraphQLErrors(errors));
                        }
                        Err(AdaptiveError::NoGraphQLData)
                    }
                }
            }
            None => {
                // No file to attach: a regular (non-multipart) request suffices.
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    file: None,
                };

                let response_data = self.execute_query(UpdateCustomRecipe, variables).await?;
                Ok(response_data.update_custom_recipe)
            }
        }
    }
578
579    pub async fn run_recipe(
580        &self,
581        project: &str,
582        recipe_id: &str,
583        parameters: Map<String, Value>,
584        name: Option<String>,
585        compute_pool: Option<String>,
586        num_gpus: u32,
587        use_experimental_runner: bool,
588    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
589        let variables = run_custom_recipe::Variables {
590            input: run_custom_recipe::JobInput {
591                recipe: recipe_id.to_string(),
592                project: project.to_string(),
593                args: parameters,
594                name,
595                compute_pool,
596                num_gpus: num_gpus as i64,
597                use_experimental_runner,
598                max_cpu: None,
599                max_ram_gb: None,
600                max_duration_secs: None,
601            },
602        };
603
604        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
605        Ok(response_data.create_job)
606    }
607
608    pub async fn list_jobs(
609        &self,
610        project: Option<String>,
611    ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
612        let mut jobs = Vec::new();
613        let mut page = self.list_jobs_page(project.clone(), None).await?;
614        jobs.extend(page.nodes.iter().cloned());
615        while page.page_info.has_next_page {
616            page = self
617                .list_jobs_page(project.clone(), page.page_info.end_cursor)
618                .await?;
619            jobs.extend(page.nodes.iter().cloned());
620        }
621        Ok(jobs)
622    }
623
624    async fn list_jobs_page(
625        &self,
626        project: Option<String>,
627        after: Option<String>,
628    ) -> Result<list_jobs::ListJobsJobs> {
629        let variables = list_jobs::Variables {
630            filter: Some(list_jobs::ListJobsFilterInput {
631                project,
632                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
633                status: Some(vec![
634                    list_jobs::JobStatus::RUNNING,
635                    list_jobs::JobStatus::PENDING,
636                ]),
637                timerange: None,
638                custom_recipes: None,
639                artifacts: None,
640            }),
641            cursor: Some(list_jobs::CursorPageInput {
642                first: Some(PAGE_SIZE as i64),
643                after,
644                before: None,
645                last: None,
646                offset: None,
647            }),
648        };
649
650        let response_data = self.execute_query(ListJobs, variables).await?;
651        Ok(response_data.jobs)
652    }
653
654    pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
655        let variables = cancel_job::Variables { job_id };
656
657        let response_data = self.execute_query(CancelJob, variables).await?;
658        Ok(response_data.cancel_job)
659    }
660
661    pub async fn update_job_progress(
662        &self,
663        job_id: Uuid,
664        event: update_job_progress::JobProgressEventInput,
665    ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
666        let variables = update_job_progress::Variables { job_id, event };
667
668        let response_data = self.execute_query(UpdateJobProgress, variables).await?;
669        Ok(response_data.update_job_progress)
670    }
671
672    pub async fn list_models(
673        &self,
674        project: String,
675    ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
676        let variables = list_models::Variables { project };
677
678        let response_data = self.execute_query(ListModels, variables).await?;
679        Ok(response_data
680            .project
681            .map(|project| project.model_services)
682            .unwrap_or(Vec::new()))
683    }
684
685    pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
686        let variables = list_all_models::Variables {};
687
688        let response_data = self.execute_query(ListAllModels, variables).await?;
689        Ok(response_data.models)
690    }
691
692    pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
693        let variables = list_projects::Variables {};
694
695        let response_data = self.execute_query(ListProjects, variables).await?;
696        Ok(response_data.projects)
697    }
698
699    pub async fn list_pools(
700        &self,
701    ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
702        let variables = list_compute_pools::Variables {};
703
704        let response_data = self.execute_query(ListComputePools, variables).await?;
705        Ok(response_data.compute_pools)
706    }
707
708    pub async fn get_recipe(
709        &self,
710        project: String,
711        id_or_key: String,
712    ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
713        let variables = get_recipe::Variables { project, id_or_key };
714
715        let response_data = self.execute_query(GetRecipe, variables).await?;
716        Ok(response_data.custom_recipe)
717    }
718
719    pub async fn get_grader(
720        &self,
721        id_or_key: &str,
722        project: &str,
723    ) -> Result<get_grader::GetGraderGrader> {
724        let variables = get_grader::Variables {
725            id: id_or_key.to_string(),
726            project: project.to_string(),
727        };
728
729        let response_data = self.execute_query(GetGrader, variables).await?;
730        Ok(response_data.grader)
731    }
732
733    pub async fn get_dataset(
734        &self,
735        id_or_key: &str,
736        project: &str,
737    ) -> Result<Option<get_dataset::GetDatasetDataset>> {
738        let variables = get_dataset::Variables {
739            id_or_key: id_or_key.to_string(),
740            project: project.to_string(),
741        };
742
743        let response_data = self.execute_query(GetDataset, variables).await?;
744        Ok(response_data.dataset)
745    }
746
747    pub async fn get_model_config(
748        &self,
749        id_or_key: &str,
750    ) -> Result<Option<get_model_config::GetModelConfigModel>> {
751        let variables = get_model_config::Variables {
752            id_or_key: id_or_key.to_string(),
753        };
754
755        let response_data = self.execute_query(GetModelConfig, variables).await?;
756        Ok(response_data.model)
757    }
758
    /// The REST base URL this client was constructed with.
    pub fn base_url(&self) -> &Url {
        &self.rest_base_url
    }
762
763    /// Upload bytes using the chunked upload API and return the session_id.
764    /// This can be used to link the uploaded file to an artifact.
765    pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
766        let file_size = data.len() as u64;
767
768        // Calculate chunk size (same logic as calculate_upload_parts but inline for small files)
769        let chunk_size = if file_size < 5 * 1024 * 1024 {
770            // For files < 5MB, use the whole file as one chunk
771            file_size.max(1)
772        } else if file_size < 500 * 1024 * 1024 {
773            5 * 1024 * 1024
774        } else if file_size < 10 * 1024 * 1024 * 1024 {
775            10 * 1024 * 1024
776        } else {
777            100 * 1024 * 1024
778        };
779
780        let total_parts = file_size.div_ceil(chunk_size).max(1);
781
782        // Initialize upload session
783        let session_id = self
784            .init_chunked_upload_with_content_type(total_parts, content_type)
785            .await?;
786
787        // Upload parts
788        for part_number in 1..=total_parts {
789            let start = ((part_number - 1) * chunk_size) as usize;
790            let end = (part_number * chunk_size).min(file_size) as usize;
791            let chunk = data[start..end].to_vec();
792
793            if let Err(e) = self
794                .upload_part_simple(&session_id, part_number, chunk)
795                .await
796            {
797                let _ = self.abort_chunked_upload(&session_id).await;
798                return Err(e);
799            }
800        }
801
802        Ok(session_id)
803    }
804
805    /// Initialize a chunked upload session.
806    pub async fn init_chunked_upload_with_content_type(
807        &self,
808        total_parts: u64,
809        content_type: &str,
810    ) -> Result<String> {
811        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
812
813        let request = InitChunkedUploadRequest {
814            content_type: content_type.to_string(),
815            metadata: None,
816            total_parts_count: total_parts,
817        };
818
819        let response = self
820            .client
821            .post(url)
822            .bearer_auth(&self.auth_token)
823            .json(&request)
824            .send()
825            .await?;
826
827        if !response.status().is_success() {
828            return Err(AdaptiveError::ChunkedUploadInitFailed {
829                status: response.status().to_string(),
830                body: response.text().await.unwrap_or_default(),
831            });
832        }
833
834        let init_response: InitChunkedUploadResponse = response.json().await?;
835        Ok(init_response.session_id)
836    }
837
838    /// Upload a single part of a chunked upload.
839    pub async fn upload_part_simple(
840        &self,
841        session_id: &str,
842        part_number: u64,
843        data: Vec<u8>,
844    ) -> Result<()> {
845        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
846
847        let response = self
848            .client
849            .post(url)
850            .bearer_auth(&self.auth_token)
851            .query(&[
852                ("session_id", session_id),
853                ("part_number", &part_number.to_string()),
854            ])
855            .header("Content-Type", "application/octet-stream")
856            .body(data)
857            .send()
858            .await?;
859
860        if !response.status().is_success() {
861            return Err(AdaptiveError::ChunkedUploadPartFailed {
862                part_number,
863                status: response.status().to_string(),
864                body: response.text().await.unwrap_or_default(),
865            });
866        }
867
868        Ok(())
869    }
870
871    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
872        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
873
874        let request = InitChunkedUploadRequest {
875            content_type: "application/jsonl".to_string(),
876            metadata: None,
877            total_parts_count: total_parts,
878        };
879
880        let response = self
881            .client
882            .post(url)
883            .bearer_auth(&self.auth_token)
884            .json(&request)
885            .send()
886            .await?;
887
888        if !response.status().is_success() {
889            return Err(AdaptiveError::ChunkedUploadInitFailed {
890                status: response.status().to_string(),
891                body: response.text().await.unwrap_or_default(),
892            });
893        }
894
895        let init_response: InitChunkedUploadResponse = response.json().await?;
896        Ok(init_response.session_id)
897    }
898
    /// Upload one part of a chunked upload, streaming the body and reporting
    /// progress through `progress_tx`.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Granularity of progress reporting, not of the HTTP request itself.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        // Pre-split the part into owned 64 KiB sub-chunks to feed the stream.
        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        // Progress is counted when a sub-chunk is handed to the HTTP stream,
        // not when the server acknowledges it; try_send deliberately drops an
        // update (rather than blocking the upload) when the channel is full.
        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
947
948    /// Abort a chunked upload session.
949    pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
950        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
951
952        let request = AbortChunkedUploadRequest {
953            session_id: session_id.to_string(),
954        };
955
956        let _ = self
957            .client
958            .delete(url)
959            .bearer_auth(&self.auth_token)
960            .json(&request)
961            .send()
962            .await;
963
964        Ok(())
965    }
966
967    async fn create_dataset_from_multipart(
968        &self,
969        project: &str,
970        name: &str,
971        key: &str,
972        session_id: &str,
973    ) -> Result<
974        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
975    > {
976        let variables = create_dataset_from_multipart::Variables {
977            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
978                project: project.to_string(),
979                name: name.to_string(),
980                key: Some(key.to_string()),
981                source: None,
982                upload_session_id: session_id.to_string(),
983            },
984        };
985
986        let response_data = self
987            .execute_query(CreateDatasetFromMultipart, variables)
988            .await?;
989        Ok(response_data.create_dataset_from_multipart_upload)
990    }
991
992    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
993        &'a self,
994        project: &'a str,
995        name: &'a str,
996        key: &'a str,
997        dataset: P,
998    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
999        let file_size = std::fs::metadata(dataset.as_ref())?.len();
1000
1001        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;
1002
1003        let stream = async_stream::try_stream! {
1004            yield UploadEvent::Progress(ChunkedUploadProgress {
1005                bytes_uploaded: 0,
1006                total_bytes: file_size,
1007            });
1008
1009            let session_id = self.init_chunked_upload(total_parts).await?;
1010
1011            let mut file = File::open(dataset.as_ref())?;
1012            let mut buffer = vec![0u8; chunk_size as usize];
1013            let mut bytes_uploaded = 0u64;
1014
1015            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);
1016
1017            for part_number in 1..=total_parts {
1018                let bytes_read = file.read(&mut buffer)?;
1019                let chunk_data = buffer[..bytes_read].to_vec();
1020
1021                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
1022                tokio::pin!(upload_fut);
1023
1024                let upload_result: Result<()> = loop {
1025                    tokio::select! {
1026                        biased;
1027                        result = &mut upload_fut => {
1028                            break result;
1029                        }
1030                        Some(bytes) = progress_rx.recv() => {
1031                            bytes_uploaded += bytes;
1032                            yield UploadEvent::Progress(ChunkedUploadProgress {
1033                                bytes_uploaded,
1034                                total_bytes: file_size,
1035                            });
1036                        }
1037                    }
1038                };
1039
1040                if let Err(e) = upload_result {
1041                    let _ = self.abort_chunked_upload(&session_id).await;
1042                    Err(e)?;
1043                }
1044            }
1045
1046            let create_result = self
1047                .create_dataset_from_multipart(project, name, key, &session_id)
1048                .await;
1049
1050            match create_result {
1051                Ok(response) => {
1052                    yield UploadEvent::Complete(response);
1053                }
1054                Err(e) => {
1055                    let _ = self.abort_chunked_upload(&session_id).await;
1056                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
1057                }
1058            }
1059        };
1060
1061        Ok(Box::pin(stream))
1062    }
1063
1064    /// Download a file from the given URL and write it to the specified path.
1065    /// The URL can be absolute or relative to the API base URL.
1066    pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
1067        use tokio::io::AsyncWriteExt;
1068
1069        let full_url = if url.starts_with("http://") || url.starts_with("https://") {
1070            Url::parse(url)?
1071        } else {
1072            self.rest_base_url.join(url)?
1073        };
1074
1075        let response = self
1076            .client
1077            .get(full_url)
1078            .bearer_auth(&self.auth_token)
1079            .send()
1080            .await?;
1081
1082        if !response.status().is_success() {
1083            return Err(AdaptiveError::HttpError(
1084                response.error_for_status().unwrap_err(),
1085            ));
1086        }
1087
1088        let mut file = tokio::fs::File::create(dest_path).await?;
1089        let mut stream = response.bytes_stream();
1090
1091        while let Some(chunk) = stream.next().await {
1092            let chunk = chunk?;
1093            file.write_all(&chunk).await?;
1094        }
1095
1096        file.flush().await?;
1097        Ok(())
1098    }
1099}