Skip to main content

adaptive_client_rust/
lib.rs

1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
19const MEGABYTE: u64 = 1024 * 1024; // 1MB
20pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
21const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
22const MAX_PARTS_COUNT: u64 = 10000;
23
24const SIZE_500MB: u64 = 500 * MEGABYTE;
25const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
26const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Errors returned by [`AdaptiveClient`] and the upload helpers in this module.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    /// Any `reqwest::Error`, e.g. from sending a request or reading/decoding its body.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// JSON serialization failure (e.g. while building multipart `operations`).
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Local I/O failure, e.g. opening a file to attach to a multipart request.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A REST route could not be joined onto the base URL.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// The file is smaller than `MIN_CHUNK_SIZE_BYTES`, the minimum accepted by
    /// `calculate_upload_parts`.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// The file would need more than `MAX_PARTS_COUNT` parts even with
    /// `MAX_CHUNK_SIZE_BYTES`-sized parts.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The GraphQL response had no `data` but did carry `errors`.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response had neither `data` nor `errors`.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    /// `get_job` received a `null` job for this id.
    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    /// The chunked-upload init route answered with a non-success status.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    /// The upload-part route answered with a non-success status for `part_number`.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// Dataset creation failed. Not constructed in the code visible here — presumably used by
    /// the chunked dataset upload flow; verify.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),

    /// The GraphQL endpoint answered with a non-success status; `body` is the full response text.
    #[error("HTTP status error: {status} - {body}")]
    HttpStatusError { status: String, body: String },

    /// The response body was not a valid GraphQL response; `body` holds its first 500 characters.
    #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
    JsonParseError { error: String, body: String },
}
76
/// Crate-local result type whose error is always [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;
78
/// Byte counters describing how far a chunked upload has progressed.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    /// Bytes uploaded so far (how this is accumulated is outside this view — TODO confirm).
    pub bytes_uploaded: u64,
    /// Total number of bytes the upload will send.
    pub total_bytes: u64,
}
84
/// Event reported while a chunked dataset upload runs (presumably emitted by
/// `chunked_upload_dataset` — verify).
#[derive(Debug)]
pub enum UploadEvent {
    /// Intermediate progress update.
    Progress(ChunkedUploadProgress),
    /// Final result: the dataset returned by the `CreateDatasetFromMultipart` mutation.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94    if file_size < MIN_CHUNK_SIZE_BYTES {
95        return Err(AdaptiveError::FileTooSmall {
96            size: file_size,
97            min_size: MIN_CHUNK_SIZE_BYTES,
98        });
99    }
100
101    let mut chunk_size = if file_size < SIZE_500MB {
102        5 * MEGABYTE
103    } else if file_size < SIZE_10GB {
104        10 * MEGABYTE
105    } else if file_size < SIZE_50GB {
106        50 * MEGABYTE
107    } else {
108        100 * MEGABYTE
109    };
110
111    let mut total_parts = file_size.div_ceil(chunk_size);
112
113    if total_parts > MAX_PARTS_COUNT {
114        chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116        if chunk_size > MAX_CHUNK_SIZE_BYTES {
117            let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118            return Err(AdaptiveError::FileTooLarge {
119                size: file_size,
120                max_size: max_file_size,
121            });
122        }
123
124        total_parts = file_size.div_ceil(chunk_size);
125    }
126
127    Ok((total_parts, chunk_size))
128}
129
// Rust types for the custom GraphQL scalars in `schema.gql`. The `GraphQLQuery` derives below
// presumably resolve custom scalars by their schema name, which is why these aliases keep the
// schema's spelling — verify against schema.gql before renaming any of them.
type IdOrKey = String;
// Named after the GraphQL scalar, hence the non-Rust casing.
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
// Named after the GraphQL scalar, hence the non-Rust casing.
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
// Named after the GraphQL scalar, hence the non-Rust casing.
#[allow(clippy::upper_case_acronyms)]
type JSONObject = Map<String, Value>;
type KeyInput = String;
140
141#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
142pub struct Timestamp(pub SystemTime);
143
144impl<'de> serde::Deserialize<'de> for Timestamp {
145    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
146    where
147        D: serde::Deserializer<'de>,
148    {
149        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
150        Ok(Timestamp(system_time))
151    }
152}
153
/// Number of items requested per page in cursor pagination (`list_jobs_page`).
const PAGE_SIZE: usize = 20;

/// Placeholder value for GraphQL `Upload` variables in multipart requests.
///
/// The multipart helpers always pass `Upload(0)` and map form part `"0"` to
/// `variables.file`; the wrapped number presumably mirrors that form-part index — verify
/// against the server's multipart handling.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
158
// GraphQL operations. Each `GraphQLQuery` derive generates a snake_case module (e.g.
// `get_custom_recipes`) holding that operation's `Variables` and response types, generated
// from the listed `.graphql` file against `schema.gql`.

/// `src/graphql/list.graphql`; used by `AdaptiveClient::list_recipes`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// `src/graphql/job.graphql`; used by `AdaptiveClient::get_job`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// `src/graphql/jobs.graphql`; used by `AdaptiveClient::list_jobs_page`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// `src/graphql/cancel.graphql`; used by `AdaptiveClient::cancel_job`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// `src/graphql/models.graphql`; used by `AdaptiveClient::list_models`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// `src/graphql/all_models.graphql`; used by `AdaptiveClient::list_all_models`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
206
207impl Display for get_job::JobStatus {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        match self {
210            get_job::JobStatus::PENDING => write!(f, "Pending"),
211            get_job::JobStatus::RUNNING => write!(f, "Running"),
212            get_job::JobStatus::COMPLETED => write!(f, "Completed"),
213            get_job::JobStatus::FAILED => write!(f, "Failed"),
214            get_job::JobStatus::CANCELED => write!(f, "Canceled"),
215            get_job::JobStatus::Other(_) => write!(f, "Unknown"),
216        }
217    }
218}
219
/// `src/graphql/publish.graphql` (multipart upload); used by `AdaptiveClient::publish_recipe`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// `src/graphql/update_recipe.graphql`; used by `AdaptiveClient::update_recipe`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/update_recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateCustomRecipe;

/// `src/graphql/upload_dataset.graphql` (multipart upload); used by
/// `AdaptiveClient::upload_dataset`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// `src/graphql/create_dataset_from_multipart.graphql`; used by
/// `AdaptiveClient::create_dataset_from_multipart`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// `src/graphql/run.graphql`; used by `AdaptiveClient::run_recipe`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// `src/graphql/projects.graphql`; used by `AdaptiveClient::list_projects`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/projects.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListProjects;

/// `src/graphql/pools.graphql`; used by `AdaptiveClient::list_pools`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// `src/graphql/recipe.graphql`; used by `AdaptiveClient::get_recipe`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// `src/graphql/grader.graphql`; used by `AdaptiveClient::get_grader`. Response types also
/// derive `Serialize`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// `src/graphql/dataset.graphql`; used by `AdaptiveClient::get_dataset`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// `src/graphql/model_config.graphql`; used by `AdaptiveClient::get_model_config`. Response
/// types also derive `Serialize`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;

/// `src/graphql/artifact.graphql`; used by `AdaptiveClient::get_artifact`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/artifact.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetArtifact;

/// `src/graphql/job_progress.graphql`; used by `AdaptiveClient::update_job_progress`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job_progress.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateJobProgress;
323
// REST routes of the chunked upload API, resolved against `rest_base_url` with `Url::join`.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";

/// Client for the Adaptive API: GraphQL operations plus the REST chunked-upload routes.
///
/// Every request is sent with `auth_token` as a bearer token.
#[derive(Clone)]
pub struct AdaptiveClient {
    // Shared HTTP client (carries the crate user agent).
    client: Client,
    // Endpoint for all GraphQL requests (`<base>/graphql`).
    graphql_url: Url,
    // Base URL the REST upload routes are joined onto.
    rest_base_url: Url,
    // Bearer token attached to every request.
    auth_token: String,
}
335
336impl AdaptiveClient {
337    pub fn new(api_base_url: Url, auth_token: String) -> Self {
338        let graphql_url = api_base_url
339            .join("graphql")
340            .expect("Failed to append graphql to base URL");
341
342        let client = Client::builder()
343            .user_agent(format!(
344                "adaptive-client-rust/{}",
345                env!("CARGO_PKG_VERSION")
346            ))
347            .build()
348            .expect("Failed to build HTTP client");
349
350        Self {
351            client,
352            graphql_url,
353            rest_base_url: api_base_url,
354            auth_token,
355        }
356    }
357
358    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
359    where
360        T: GraphQLQuery,
361        T::Variables: serde::Serialize,
362        T::ResponseData: DeserializeOwned,
363    {
364        let request_body = T::build_query(variables);
365
366        let response = self
367            .client
368            .post(self.graphql_url.clone())
369            .bearer_auth(&self.auth_token)
370            .json(&request_body)
371            .send()
372            .await?;
373
374        let status = response.status();
375        let response_text = response.text().await?;
376
377        if !status.is_success() {
378            return Err(AdaptiveError::HttpStatusError {
379                status: status.to_string(),
380                body: response_text,
381            });
382        }
383
384        let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
385            .map_err(|e| AdaptiveError::JsonParseError {
386                error: e.to_string(),
387                body: response_text.chars().take(500).collect(),
388            })?;
389
390        match response_body.data {
391            Some(data) => Ok(data),
392            None => {
393                if let Some(errors) = response_body.errors {
394                    return Err(AdaptiveError::GraphQLErrors(errors));
395                }
396                Err(AdaptiveError::NoGraphQLData)
397            }
398        }
399    }
400
401    pub async fn list_recipes(
402        &self,
403        project: &str,
404    ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
405        let variables = get_custom_recipes::Variables {
406            project: IdOrKey::from(project),
407        };
408
409        let response_data = self.execute_query(GetCustomRecipes, variables).await?;
410        Ok(response_data.custom_recipes)
411    }
412
413    pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
414        let variables = get_job::Variables { id: job_id };
415
416        let response_data = self.execute_query(GetJob, variables).await?;
417
418        match response_data.job {
419            Some(job) => Ok(job),
420            None => Err(AdaptiveError::JobNotFound(job_id)),
421        }
422    }
423
424    pub async fn upload_dataset<P: AsRef<Path>>(
425        &self,
426        project: &str,
427        name: &str,
428        dataset: P,
429    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
430        let variables = upload_dataset::Variables {
431            project: IdOrKey::from(project),
432            file: Upload(0),
433            name: Some(name.to_string()),
434        };
435
436        let operations = UploadDataset::build_query(variables);
437        let operations = serde_json::to_string(&operations)?;
438
439        let file_map = r#"{ "0": ["variables.file"] }"#;
440
441        let dataset_file = reqwest::multipart::Part::file(dataset).await?;
442
443        let form = reqwest::multipart::Form::new()
444            .text("operations", operations)
445            .text("map", file_map)
446            .part("0", dataset_file);
447
448        let response = self
449            .client
450            .post(self.graphql_url.clone())
451            .bearer_auth(&self.auth_token)
452            .multipart(form)
453            .send()
454            .await?;
455
456        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
457            response.json().await?;
458
459        match response.data {
460            Some(data) => Ok(data.create_dataset),
461            None => {
462                if let Some(errors) = response.errors {
463                    return Err(AdaptiveError::GraphQLErrors(errors));
464                }
465                Err(AdaptiveError::NoGraphQLData)
466            }
467        }
468    }
469
470    pub async fn publish_recipe<P: AsRef<Path>>(
471        &self,
472        project: &str,
473        name: &str,
474        key: &str,
475        recipe: P,
476    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
477        let variables = publish_custom_recipe::Variables {
478            project: IdOrKey::from(project),
479            file: Upload(0),
480            name: Some(name.to_string()),
481            key: Some(key.to_string()),
482        };
483
484        let operations = PublishCustomRecipe::build_query(variables);
485        let operations = serde_json::to_string(&operations)?;
486
487        let file_map = r#"{ "0": ["variables.file"] }"#;
488
489        let recipe_file = reqwest::multipart::Part::file(recipe).await?;
490
491        let form = reqwest::multipart::Form::new()
492            .text("operations", operations)
493            .text("map", file_map)
494            .part("0", recipe_file);
495
496        let response = self
497            .client
498            .post(self.graphql_url.clone())
499            .bearer_auth(&self.auth_token)
500            .multipart(form)
501            .send()
502            .await?;
503        let response: Response<
504            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
505        > = response.json().await?;
506
507        match response.data {
508            Some(data) => Ok(data.create_custom_recipe),
509            None => {
510                if let Some(errors) = response.errors {
511                    return Err(AdaptiveError::GraphQLErrors(errors));
512                }
513                Err(AdaptiveError::NoGraphQLData)
514            }
515        }
516    }
517
518    pub async fn update_recipe<P: AsRef<Path>>(
519        &self,
520        project: &str,
521        id: &str,
522        name: Option<String>,
523        description: Option<String>,
524        labels: Option<Vec<update_custom_recipe::LabelInput>>,
525        recipe_file: Option<P>,
526    ) -> Result<update_custom_recipe::UpdateCustomRecipeUpdateCustomRecipe> {
527        let input = update_custom_recipe::UpdateRecipeInput {
528            name,
529            description,
530            labels,
531        };
532
533        match recipe_file {
534            Some(file_path) => {
535                let variables = update_custom_recipe::Variables {
536                    project: IdOrKey::from(project),
537                    id: IdOrKey::from(id),
538                    input,
539                    file: Some(Upload(0)),
540                };
541
542                let operations = UpdateCustomRecipe::build_query(variables);
543                let operations = serde_json::to_string(&operations)?;
544
545                let file_map = r#"{ "0": ["variables.file"] }"#;
546
547                let recipe_file = reqwest::multipart::Part::file(file_path).await?;
548
549                let form = reqwest::multipart::Form::new()
550                    .text("operations", operations)
551                    .text("map", file_map)
552                    .part("0", recipe_file);
553
554                let response = self
555                    .client
556                    .post(self.graphql_url.clone())
557                    .bearer_auth(&self.auth_token)
558                    .multipart(form)
559                    .send()
560                    .await?;
561                let response: Response<
562                    <UpdateCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
563                > = response.json().await?;
564
565                match response.data {
566                    Some(data) => Ok(data.update_custom_recipe),
567                    None => {
568                        if let Some(errors) = response.errors {
569                            return Err(AdaptiveError::GraphQLErrors(errors));
570                        }
571                        Err(AdaptiveError::NoGraphQLData)
572                    }
573                }
574            }
575            None => {
576                let variables = update_custom_recipe::Variables {
577                    project: IdOrKey::from(project),
578                    id: IdOrKey::from(id),
579                    input,
580                    file: None,
581                };
582
583                let response_data = self.execute_query(UpdateCustomRecipe, variables).await?;
584                Ok(response_data.update_custom_recipe)
585            }
586        }
587    }
588
589    pub async fn run_recipe(
590        &self,
591        project: &str,
592        recipe_id: &str,
593        parameters: Map<String, Value>,
594        name: Option<String>,
595        compute_pool: Option<String>,
596        num_gpus: u32,
597        use_experimental_runner: bool,
598    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
599        let variables = run_custom_recipe::Variables {
600            input: run_custom_recipe::JobInput {
601                recipe: recipe_id.to_string(),
602                project: project.to_string(),
603                args: parameters,
604                name,
605                compute_pool,
606                num_gpus: num_gpus as i64,
607                use_experimental_runner,
608                max_cpu: None,
609                max_ram_gb: None,
610                max_duration_secs: None,
611                resume_artifact_id: None,
612            },
613        };
614
615        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
616        Ok(response_data.create_job)
617    }
618
619    pub async fn list_jobs(
620        &self,
621        project: Option<String>,
622    ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
623        let mut jobs = Vec::new();
624        let mut page = self.list_jobs_page(project.clone(), None).await?;
625        jobs.extend(page.nodes.iter().cloned());
626        while page.page_info.has_next_page {
627            page = self
628                .list_jobs_page(project.clone(), page.page_info.end_cursor)
629                .await?;
630            jobs.extend(page.nodes.iter().cloned());
631        }
632        Ok(jobs)
633    }
634
635    async fn list_jobs_page(
636        &self,
637        project: Option<String>,
638        after: Option<String>,
639    ) -> Result<list_jobs::ListJobsJobs> {
640        let variables = list_jobs::Variables {
641            filter: Some(list_jobs::ListJobsFilterInput {
642                project,
643                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
644                status: Some(vec![
645                    list_jobs::JobStatus::RUNNING,
646                    list_jobs::JobStatus::PENDING,
647                ]),
648                timerange: None,
649                custom_recipes: None,
650                artifacts: None,
651            }),
652            cursor: Some(list_jobs::CursorPageInput {
653                first: Some(PAGE_SIZE as i64),
654                after,
655                before: None,
656                last: None,
657                offset: None,
658            }),
659        };
660
661        let response_data = self.execute_query(ListJobs, variables).await?;
662        Ok(response_data.jobs)
663    }
664
665    pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
666        let variables = cancel_job::Variables { job_id };
667
668        let response_data = self.execute_query(CancelJob, variables).await?;
669        Ok(response_data.cancel_job)
670    }
671
672    pub async fn update_job_progress(
673        &self,
674        job_id: Uuid,
675        event: update_job_progress::JobProgressEventInput,
676    ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
677        let variables = update_job_progress::Variables { job_id, event };
678
679        let response_data = self.execute_query(UpdateJobProgress, variables).await?;
680        Ok(response_data.update_job_progress)
681    }
682
683    pub async fn list_models(
684        &self,
685        project: String,
686    ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
687        let variables = list_models::Variables { project };
688
689        let response_data = self.execute_query(ListModels, variables).await?;
690        Ok(response_data
691            .project
692            .map(|project| project.model_services)
693            .unwrap_or(Vec::new()))
694    }
695
696    pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
697        let variables = list_all_models::Variables {};
698
699        let response_data = self.execute_query(ListAllModels, variables).await?;
700        Ok(response_data.models)
701    }
702
703    pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
704        let variables = list_projects::Variables {};
705
706        let response_data = self.execute_query(ListProjects, variables).await?;
707        Ok(response_data.projects)
708    }
709
710    pub async fn list_pools(
711        &self,
712    ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
713        let variables = list_compute_pools::Variables {};
714
715        let response_data = self.execute_query(ListComputePools, variables).await?;
716        Ok(response_data.compute_pools)
717    }
718
719    pub async fn get_recipe(
720        &self,
721        project: String,
722        id_or_key: String,
723    ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
724        let variables = get_recipe::Variables { project, id_or_key };
725
726        let response_data = self.execute_query(GetRecipe, variables).await?;
727        Ok(response_data.custom_recipe)
728    }
729
730    pub async fn get_grader(
731        &self,
732        id_or_key: &str,
733        project: &str,
734    ) -> Result<get_grader::GetGraderGrader> {
735        let variables = get_grader::Variables {
736            id: id_or_key.to_string(),
737            project: project.to_string(),
738        };
739
740        let response_data = self.execute_query(GetGrader, variables).await?;
741        Ok(response_data.grader)
742    }
743
744    pub async fn get_dataset(
745        &self,
746        id_or_key: &str,
747        project: &str,
748    ) -> Result<Option<get_dataset::GetDatasetDataset>> {
749        let variables = get_dataset::Variables {
750            id_or_key: id_or_key.to_string(),
751            project: project.to_string(),
752        };
753
754        let response_data = self.execute_query(GetDataset, variables).await?;
755        Ok(response_data.dataset)
756    }
757
758    pub async fn get_model_config(
759        &self,
760        id_or_key: &str,
761    ) -> Result<Option<get_model_config::GetModelConfigModel>> {
762        let variables = get_model_config::Variables {
763            id_or_key: id_or_key.to_string(),
764        };
765
766        let response_data = self.execute_query(GetModelConfig, variables).await?;
767        Ok(response_data.model)
768    }
769
770    pub async fn get_artifact(
771        &self,
772        project: &str,
773        id: Uuid,
774    ) -> Result<Option<get_artifact::GetArtifactArtifact>> {
775        let variables = get_artifact::Variables {
776            project: project.to_string(),
777            id,
778        };
779
780        let response_data = self.execute_query(GetArtifact, variables).await?;
781        Ok(response_data.artifact)
782    }
783
784    pub fn base_url(&self) -> &Url {
785        &self.rest_base_url
786    }
787
788    /// Upload bytes using the chunked upload API and return the session_id.
789    /// This can be used to link the uploaded file to an artifact.
790    pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
791        let file_size = data.len() as u64;
792
793        // Calculate chunk size (same logic as calculate_upload_parts but inline for small files)
794        let chunk_size = if file_size < 5 * 1024 * 1024 {
795            // For files < 5MB, use the whole file as one chunk
796            file_size.max(1)
797        } else if file_size < 500 * 1024 * 1024 {
798            5 * 1024 * 1024
799        } else if file_size < 10 * 1024 * 1024 * 1024 {
800            10 * 1024 * 1024
801        } else {
802            100 * 1024 * 1024
803        };
804
805        let total_parts = file_size.div_ceil(chunk_size).max(1);
806
807        // Initialize upload session
808        let session_id = self
809            .init_chunked_upload_with_content_type(total_parts, content_type)
810            .await?;
811
812        // Upload parts
813        for part_number in 1..=total_parts {
814            let start = ((part_number - 1) * chunk_size) as usize;
815            let end = (part_number * chunk_size).min(file_size) as usize;
816            let chunk = data[start..end].to_vec();
817
818            if let Err(e) = self
819                .upload_part_simple(&session_id, part_number, chunk)
820                .await
821            {
822                let _ = self.abort_chunked_upload(&session_id).await;
823                return Err(e);
824            }
825        }
826
827        Ok(session_id)
828    }
829
830    /// Initialize a chunked upload session.
831    pub async fn init_chunked_upload_with_content_type(
832        &self,
833        total_parts: u64,
834        content_type: &str,
835    ) -> Result<String> {
836        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
837
838        let request = InitChunkedUploadRequest {
839            content_type: content_type.to_string(),
840            metadata: None,
841            total_parts_count: total_parts,
842        };
843
844        let response = self
845            .client
846            .post(url)
847            .bearer_auth(&self.auth_token)
848            .json(&request)
849            .send()
850            .await?;
851
852        if !response.status().is_success() {
853            return Err(AdaptiveError::ChunkedUploadInitFailed {
854                status: response.status().to_string(),
855                body: response.text().await.unwrap_or_default(),
856            });
857        }
858
859        let init_response: InitChunkedUploadResponse = response.json().await?;
860        Ok(init_response.session_id)
861    }
862
863    /// Upload a single part of a chunked upload.
864    pub async fn upload_part_simple(
865        &self,
866        session_id: &str,
867        part_number: u64,
868        data: Vec<u8>,
869    ) -> Result<()> {
870        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
871
872        let response = self
873            .client
874            .post(url)
875            .bearer_auth(&self.auth_token)
876            .query(&[
877                ("session_id", session_id),
878                ("part_number", &part_number.to_string()),
879            ])
880            .header("Content-Type", "application/octet-stream")
881            .body(data)
882            .send()
883            .await?;
884
885        if !response.status().is_success() {
886            return Err(AdaptiveError::ChunkedUploadPartFailed {
887                part_number,
888                status: response.status().to_string(),
889                body: response.text().await.unwrap_or_default(),
890            });
891        }
892
893        Ok(())
894    }
895
896    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
897        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
898
899        let request = InitChunkedUploadRequest {
900            content_type: "application/jsonl".to_string(),
901            metadata: None,
902            total_parts_count: total_parts,
903        };
904
905        let response = self
906            .client
907            .post(url)
908            .bearer_auth(&self.auth_token)
909            .json(&request)
910            .send()
911            .await?;
912
913        if !response.status().is_success() {
914            return Err(AdaptiveError::ChunkedUploadInitFailed {
915                status: response.status().to_string(),
916                body: response.text().await.unwrap_or_default(),
917            });
918        }
919
920        let init_response: InitChunkedUploadResponse = response.json().await?;
921        Ok(init_response.session_id)
922    }
923
924    async fn upload_part(
925        &self,
926        session_id: &str,
927        part_number: u64,
928        data: Vec<u8>,
929        progress_tx: mpsc::Sender<u64>,
930    ) -> Result<()> {
931        const SUB_CHUNK_SIZE: usize = 64 * 1024;
932
933        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
934
935        let chunks: Vec<Vec<u8>> = data
936            .chunks(SUB_CHUNK_SIZE)
937            .map(|chunk| chunk.to_vec())
938            .collect();
939
940        let stream = futures::stream::iter(chunks).map(move |chunk| {
941            let len = chunk.len() as u64;
942            let tx = progress_tx.clone();
943            let _ = tx.try_send(len);
944            Ok::<_, std::io::Error>(chunk)
945        });
946
947        let body = reqwest::Body::wrap_stream(stream);
948
949        let response = self
950            .client
951            .post(url)
952            .bearer_auth(&self.auth_token)
953            .query(&[
954                ("session_id", session_id),
955                ("part_number", &part_number.to_string()),
956            ])
957            .header("Content-Type", "application/octet-stream")
958            .body(body)
959            .send()
960            .await?;
961
962        if !response.status().is_success() {
963            return Err(AdaptiveError::ChunkedUploadPartFailed {
964                part_number,
965                status: response.status().to_string(),
966                body: response.text().await.unwrap_or_default(),
967            });
968        }
969
970        Ok(())
971    }
972
973    /// Abort a chunked upload session.
974    pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
975        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
976
977        let request = AbortChunkedUploadRequest {
978            session_id: session_id.to_string(),
979        };
980
981        let _ = self
982            .client
983            .delete(url)
984            .bearer_auth(&self.auth_token)
985            .json(&request)
986            .send()
987            .await;
988
989        Ok(())
990    }
991
992    async fn create_dataset_from_multipart(
993        &self,
994        project: &str,
995        name: &str,
996        key: &str,
997        session_id: &str,
998    ) -> Result<
999        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
1000    > {
1001        let variables = create_dataset_from_multipart::Variables {
1002            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
1003                project: project.to_string(),
1004                name: name.to_string(),
1005                key: Some(key.to_string()),
1006                source: None,
1007                upload_session_id: session_id.to_string(),
1008            },
1009        };
1010
1011        let response_data = self
1012            .execute_query(CreateDatasetFromMultipart, variables)
1013            .await?;
1014        Ok(response_data.create_dataset_from_multipart_upload)
1015    }
1016
1017    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
1018        &'a self,
1019        project: &'a str,
1020        name: &'a str,
1021        key: &'a str,
1022        dataset: P,
1023    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
1024        let file_size = std::fs::metadata(dataset.as_ref())?.len();
1025
1026        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;
1027
1028        let stream = async_stream::try_stream! {
1029            yield UploadEvent::Progress(ChunkedUploadProgress {
1030                bytes_uploaded: 0,
1031                total_bytes: file_size,
1032            });
1033
1034            let session_id = self.init_chunked_upload(total_parts).await?;
1035
1036            let mut file = File::open(dataset.as_ref())?;
1037            let mut buffer = vec![0u8; chunk_size as usize];
1038            let mut bytes_uploaded = 0u64;
1039
1040            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);
1041
1042            for part_number in 1..=total_parts {
1043                let bytes_read = file.read(&mut buffer)?;
1044                let chunk_data = buffer[..bytes_read].to_vec();
1045
1046                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
1047                tokio::pin!(upload_fut);
1048
1049                let upload_result: Result<()> = loop {
1050                    tokio::select! {
1051                        biased;
1052                        result = &mut upload_fut => {
1053                            break result;
1054                        }
1055                        Some(bytes) = progress_rx.recv() => {
1056                            bytes_uploaded += bytes;
1057                            yield UploadEvent::Progress(ChunkedUploadProgress {
1058                                bytes_uploaded,
1059                                total_bytes: file_size,
1060                            });
1061                        }
1062                    }
1063                };
1064
1065                if let Err(e) = upload_result {
1066                    let _ = self.abort_chunked_upload(&session_id).await;
1067                    Err(e)?;
1068                }
1069            }
1070
1071            let create_result = self
1072                .create_dataset_from_multipart(project, name, key, &session_id)
1073                .await;
1074
1075            match create_result {
1076                Ok(response) => {
1077                    yield UploadEvent::Complete(response);
1078                }
1079                Err(e) => {
1080                    let _ = self.abort_chunked_upload(&session_id).await;
1081                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
1082                }
1083            }
1084        };
1085
1086        Ok(Box::pin(stream))
1087    }
1088
1089    /// Download a file from the given URL and write it to the specified path.
1090    /// The URL can be absolute or relative to the API base URL.
1091    pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
1092        use tokio::io::AsyncWriteExt;
1093
1094        let full_url = if url.starts_with("http://") || url.starts_with("https://") {
1095            Url::parse(url)?
1096        } else {
1097            self.rest_base_url.join(url)?
1098        };
1099
1100        let response = self
1101            .client
1102            .get(full_url)
1103            .bearer_auth(&self.auth_token)
1104            .send()
1105            .await?;
1106
1107        if !response.status().is_success() {
1108            return Err(AdaptiveError::HttpError(
1109                response.error_for_status().unwrap_err(),
1110            ));
1111        }
1112
1113        let mut file = tokio::fs::File::create(dest_path).await?;
1114        let mut stream = response.bytes_stream();
1115
1116        while let Some(chunk) = stream.next().await {
1117            let chunk = chunk?;
1118            file.write_all(&chunk).await?;
1119        }
1120
1121        file.flush().await?;
1122        Ok(())
1123    }
1124}