//! adaptive_client_rust — lib.rs
//! Client for the Adaptive API: GraphQL operations plus REST chunked uploads.
1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
const MEGABYTE: u64 = 1024 * 1024; // 1MB

// Chunked-upload limits: parts between 5MB and 100MB, at most 10,000 parts
// per upload (matches typical S3-style multipart constraints — TODO confirm
// against the server's documented limits).
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
const MAX_PARTS_COUNT: u64 = 10000;

// File-size thresholds used by `calculate_upload_parts` to pick a chunk-size tier.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// All errors surfaced by this client: transport failures (HTTP/IO/URL),
/// parse failures (JSON, GraphQL), and domain-specific upload/job errors.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    // Raised by `calculate_upload_parts` for files under the minimum chunk size.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    // Raised when even maximum-size chunks cannot fit within the part limit.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    // Server returned GraphQL-level errors (no usable `data` field).
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    // Server returned neither `data` nor `errors`.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),

    // Non-2xx HTTP response with its raw body.
    #[error("HTTP status error: {status} - {body}")]
    HttpStatusError { status: String, body: String },

    // 2xx response whose body was not valid JSON; `body` is a truncated preview.
    #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
    JsonParseError { error: String, body: String },
}
76
/// Crate-local alias: every fallible API returns [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;

/// Progress snapshot for an in-flight chunked upload.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    pub bytes_uploaded: u64,
    pub total_bytes: u64,
}

/// Items emitted while uploading a dataset in chunks.
#[derive(Debug)]
pub enum UploadEvent {
    /// Periodic byte-count update.
    Progress(ChunkedUploadProgress),
    /// Terminal event carrying the dataset created from the multipart upload.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94    if file_size < MIN_CHUNK_SIZE_BYTES {
95        return Err(AdaptiveError::FileTooSmall {
96            size: file_size,
97            min_size: MIN_CHUNK_SIZE_BYTES,
98        });
99    }
100
101    let mut chunk_size = if file_size < SIZE_500MB {
102        5 * MEGABYTE
103    } else if file_size < SIZE_10GB {
104        10 * MEGABYTE
105    } else if file_size < SIZE_50GB {
106        50 * MEGABYTE
107    } else {
108        100 * MEGABYTE
109    };
110
111    let mut total_parts = file_size.div_ceil(chunk_size);
112
113    if total_parts > MAX_PARTS_COUNT {
114        chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116        if chunk_size > MAX_CHUNK_SIZE_BYTES {
117            let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118            return Err(AdaptiveError::FileTooLarge {
119                size: file_size,
120                max_size: max_file_size,
121            });
122        }
123
124        total_parts = file_size.div_ceil(chunk_size);
125    }
126
127    Ok((total_parts, chunk_size))
128}
129
// Scalar type mappings for the `graphql_client` derive: every custom scalar
// named in the GraphQL schema must resolve to a same-named Rust type in scope,
// hence the non-idiomatic casing (allowed explicitly below).
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
#[allow(clippy::upper_case_acronyms)]
type JSONObject = Map<String, Value>;
type KeyInput = String;
140
141#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
142pub struct Timestamp(pub SystemTime);
143
144impl<'de> serde::Deserialize<'de> for Timestamp {
145    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
146    where
147        D: serde::Deserializer<'de>,
148    {
149        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
150        Ok(Timestamp(system_time))
151    }
152}
153
// Page size for cursor-paginated list queries (see `list_jobs_page`).
const PAGE_SIZE: usize = 20;

/// Placeholder for the GraphQL `Upload` scalar: carries the index of the file
/// part in the multipart request map (e.g. `Upload(0)` binds to part `"0"`).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
158
// ---------------------------------------------------------------------------
// GraphQL operations. Each struct derives `GraphQLQuery`, which generates a
// snake_case module of the same name (Variables + ResponseData types) from
// the paired .graphql document and `schema.gql`.
// ---------------------------------------------------------------------------

/// Lists custom recipes in a project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Fetches a single job by id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Cursor-paginated job listing.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Cancels a job by id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Lists model services within one project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Lists models across all projects.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
206
207impl Display for get_job::JobStatus {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        match self {
210            get_job::JobStatus::PENDING => write!(f, "Pending"),
211            get_job::JobStatus::RUNNING => write!(f, "Running"),
212            get_job::JobStatus::COMPLETED => write!(f, "Completed"),
213            get_job::JobStatus::FAILED => write!(f, "Failed"),
214            get_job::JobStatus::CANCELED => write!(f, "Canceled"),
215            get_job::JobStatus::Other(_) => write!(f, "Unknown"),
216        }
217    }
218}
219
/// Creates a custom recipe from an uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Updates recipe metadata and optionally its file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/update_recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateCustomRecipe;

/// Creates a dataset from a single uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Creates a dataset from a completed chunked (multipart) upload session.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Starts a job running a custom recipe.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Lists projects visible to the caller.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/projects.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListProjects;

/// Lists available compute pools.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Fetches a single custom recipe.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// Fetches a grader; response is re-serializable (extra `Serialize` derive).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// Fetches a dataset by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// Fetches a model's configuration; response is re-serializable.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;

/// Fetches an artifact by UUID.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/artifact.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetArtifact;

/// Pushes a progress event onto a job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job_progress.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateJobProgress;

/// Lists roles.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/roles.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListRoles;

/// Creates a role with a permission set.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_role.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateRole;

/// Updates a role's name and/or permissions.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/update_role.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateRole;

/// Lists teams.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/teams.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListTeams;

/// Creates a team.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_team.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateTeam;

/// Lists users.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/users.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListUsers;

/// Creates a user, optionally with team memberships and an API key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_user.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateUser;

/// Deletes a user.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/delete_user.graphql",
    response_derives = "Debug, Clone"
)]
pub struct DeleteUser;

/// Sets a user's team membership (with role).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/add_team_member.graphql",
    response_derives = "Debug, Clone"
)]
pub struct AddTeamMember;

/// Removes a user from a team.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/remove_team_member.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RemoveTeamMember;
403
// REST endpoints for the chunked-upload flow, relative to the client's base URL.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
407
/// Client for the Adaptive API: GraphQL operations plus REST chunked uploads.
/// Derives `Clone` so one client can be shared across tasks.
#[derive(Clone)]
pub struct AdaptiveClient {
    client: Client,
    graphql_url: Url,   // <base>/graphql
    rest_base_url: Url, // base for the v1/upload/* REST routes
    auth_token: String, // sent as a Bearer token on every request
}
415
416impl AdaptiveClient {
417    pub fn new(api_base_url: Url, auth_token: String) -> Self {
418        let graphql_url = api_base_url
419            .join("graphql")
420            .expect("Failed to append graphql to base URL");
421
422        let client = Client::builder()
423            .user_agent(format!(
424                "adaptive-client-rust/{}",
425                env!("CARGO_PKG_VERSION")
426            ))
427            .build()
428            .expect("Failed to build HTTP client");
429
430        Self {
431            client,
432            graphql_url,
433            rest_base_url: api_base_url,
434            auth_token,
435        }
436    }
437
438    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
439    where
440        T: GraphQLQuery,
441        T::Variables: serde::Serialize,
442        T::ResponseData: DeserializeOwned,
443    {
444        let request_body = T::build_query(variables);
445
446        let response = self
447            .client
448            .post(self.graphql_url.clone())
449            .bearer_auth(&self.auth_token)
450            .json(&request_body)
451            .send()
452            .await?;
453
454        let status = response.status();
455        let response_text = response.text().await?;
456
457        if !status.is_success() {
458            return Err(AdaptiveError::HttpStatusError {
459                status: status.to_string(),
460                body: response_text,
461            });
462        }
463
464        let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
465            .map_err(|e| AdaptiveError::JsonParseError {
466                error: e.to_string(),
467                body: response_text.chars().take(500).collect(),
468            })?;
469
470        match response_body.data {
471            Some(data) => Ok(data),
472            None => {
473                if let Some(errors) = response_body.errors {
474                    return Err(AdaptiveError::GraphQLErrors(errors));
475                }
476                Err(AdaptiveError::NoGraphQLData)
477            }
478        }
479    }
480
481    pub async fn list_recipes(
482        &self,
483        project: &str,
484    ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
485        let variables = get_custom_recipes::Variables {
486            project: IdOrKey::from(project),
487        };
488
489        let response_data = self.execute_query(GetCustomRecipes, variables).await?;
490        Ok(response_data.custom_recipes)
491    }
492
493    pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
494        let variables = get_job::Variables { id: job_id };
495
496        let response_data = self.execute_query(GetJob, variables).await?;
497
498        match response_data.job {
499            Some(job) => Ok(job),
500            None => Err(AdaptiveError::JobNotFound(job_id)),
501        }
502    }
503
    /// Uploads `dataset` as a new dataset named `name` in `project`, using a
    /// single multipart GraphQL request (graphql-multipart-request protocol).
    pub async fn upload_dataset<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        dataset: P,
    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
        // `Upload(0)` is a placeholder: the actual bytes travel as form part
        // "0", which the `map` field below binds to `variables.file`.
        let variables = upload_dataset::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
        };

        let operations = UploadDataset::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        let file_map = r#"{ "0": ["variables.file"] }"#;

        let dataset_file = reqwest::multipart::Part::file(dataset).await?;

        // Part order matters: operations, then map, then the file part(s).
        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", dataset_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;

        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
            response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_dataset),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
549
    /// Publishes `recipe` as a new custom recipe (`name`, `key`) in `project`,
    /// using a single multipart GraphQL request.
    pub async fn publish_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        key: &str,
        recipe: P,
    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
        // `Upload(0)` is a placeholder bound to form part "0" via `map`.
        let variables = publish_custom_recipe::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
            key: Some(key.to_string()),
        };

        let operations = PublishCustomRecipe::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        let file_map = r#"{ "0": ["variables.file"] }"#;

        let recipe_file = reqwest::multipart::Part::file(recipe).await?;

        // Part order matters: operations, then map, then the file part(s).
        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", recipe_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;
        let response: Response<
            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
        > = response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_custom_recipe),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
597
    /// Updates a recipe's metadata and, when `recipe_file` is given, replaces
    /// its file via a multipart GraphQL request; metadata-only updates go
    /// through the plain JSON transport.
    pub async fn update_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        id: &str,
        name: Option<String>,
        description: Option<String>,
        labels: Option<Vec<update_custom_recipe::LabelInput>>,
        recipe_file: Option<P>,
    ) -> Result<update_custom_recipe::UpdateCustomRecipeUpdateCustomRecipe> {
        let input = update_custom_recipe::UpdateRecipeInput {
            name,
            description,
            labels,
        };

        match recipe_file {
            // With a file: multipart request, `Upload(0)` bound to part "0".
            Some(file_path) => {
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    file: Some(Upload(0)),
                };

                let operations = UpdateCustomRecipe::build_query(variables);
                let operations = serde_json::to_string(&operations)?;

                let file_map = r#"{ "0": ["variables.file"] }"#;

                let recipe_file = reqwest::multipart::Part::file(file_path).await?;

                // Part order matters: operations, then map, then the file.
                let form = reqwest::multipart::Form::new()
                    .text("operations", operations)
                    .text("map", file_map)
                    .part("0", recipe_file);

                let response = self
                    .client
                    .post(self.graphql_url.clone())
                    .bearer_auth(&self.auth_token)
                    .multipart(form)
                    .send()
                    .await?;
                let response: Response<
                    <UpdateCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
                > = response.json().await?;

                match response.data {
                    Some(data) => Ok(data.update_custom_recipe),
                    None => {
                        if let Some(errors) = response.errors {
                            return Err(AdaptiveError::GraphQLErrors(errors));
                        }
                        Err(AdaptiveError::NoGraphQLData)
                    }
                }
            }
            // Metadata only: regular JSON GraphQL request.
            None => {
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    file: None,
                };

                let response_data = self.execute_query(UpdateCustomRecipe, variables).await?;
                Ok(response_data.update_custom_recipe)
            }
        }
    }
668
669    pub async fn run_recipe(
670        &self,
671        project: &str,
672        recipe_id: &str,
673        parameters: Map<String, Value>,
674        name: Option<String>,
675        compute_pool: Option<String>,
676        num_gpus: u32,
677        use_experimental_runner: bool,
678    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
679        let variables = run_custom_recipe::Variables {
680            input: run_custom_recipe::JobInput {
681                recipe: recipe_id.to_string(),
682                project: project.to_string(),
683                args: parameters,
684                name,
685                compute_pool,
686                num_gpus: num_gpus as i64,
687                use_experimental_runner,
688                image_tag: None,
689                max_cpu: None,
690                max_ram_gb: None,
691                max_duration_secs: None,
692                resume_artifact_id: None,
693            },
694        };
695
696        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
697        Ok(response_data.create_job)
698    }
699
700    pub async fn list_jobs(
701        &self,
702        project: Option<String>,
703    ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
704        let mut jobs = Vec::new();
705        let mut page = self.list_jobs_page(project.clone(), None).await?;
706        jobs.extend(page.nodes.iter().cloned());
707        while page.page_info.has_next_page {
708            page = self
709                .list_jobs_page(project.clone(), page.page_info.end_cursor)
710                .await?;
711            jobs.extend(page.nodes.iter().cloned());
712        }
713        Ok(jobs)
714    }
715
716    async fn list_jobs_page(
717        &self,
718        project: Option<String>,
719        after: Option<String>,
720    ) -> Result<list_jobs::ListJobsJobs> {
721        let variables = list_jobs::Variables {
722            filter: Some(list_jobs::ListJobsFilterInput {
723                project,
724                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
725                status: Some(vec![
726                    list_jobs::JobStatus::RUNNING,
727                    list_jobs::JobStatus::PENDING,
728                ]),
729                timerange: None,
730                custom_recipes: None,
731                artifacts: None,
732                created_by: None,
733                name: None,
734                advanced_filter: Box::new(None),
735            }),
736            cursor: Some(list_jobs::CursorPageInput {
737                first: Some(PAGE_SIZE as i64),
738                after,
739                before: None,
740                last: None,
741                offset: None,
742            }),
743        };
744
745        let response_data = self.execute_query(ListJobs, variables).await?;
746        Ok(response_data.jobs)
747    }
748
749    pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
750        let variables = cancel_job::Variables { job_id };
751
752        let response_data = self.execute_query(CancelJob, variables).await?;
753        Ok(response_data.cancel_job)
754    }
755
756    pub async fn update_job_progress(
757        &self,
758        job_id: Uuid,
759        event: update_job_progress::JobProgressEventInput,
760    ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
761        let variables = update_job_progress::Variables { job_id, event };
762
763        let response_data = self.execute_query(UpdateJobProgress, variables).await?;
764        Ok(response_data.update_job_progress)
765    }
766
767    pub async fn list_models(
768        &self,
769        project: String,
770    ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
771        let variables = list_models::Variables { project };
772
773        let response_data = self.execute_query(ListModels, variables).await?;
774        Ok(response_data
775            .project
776            .map(|project| project.model_services)
777            .unwrap_or(Vec::new()))
778    }
779
780    pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
781        let variables = list_all_models::Variables {};
782
783        let response_data = self.execute_query(ListAllModels, variables).await?;
784        Ok(response_data.models)
785    }
786
787    pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
788        let variables = list_projects::Variables {};
789
790        let response_data = self.execute_query(ListProjects, variables).await?;
791        Ok(response_data.projects)
792    }
793
794    pub async fn list_pools(
795        &self,
796    ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
797        let variables = list_compute_pools::Variables {};
798
799        let response_data = self.execute_query(ListComputePools, variables).await?;
800        Ok(response_data.compute_pools)
801    }
802
803    pub async fn list_roles(&self) -> Result<Vec<list_roles::ListRolesRoles>> {
804        let variables = list_roles::Variables {};
805
806        let response_data = self.execute_query(ListRoles, variables).await?;
807        Ok(response_data.roles)
808    }
809
810    pub async fn create_role(
811        &self,
812        name: &str,
813        key: Option<&str>,
814        permissions: Vec<String>,
815    ) -> Result<create_role::CreateRoleCreateRole> {
816        let variables = create_role::Variables {
817            input: create_role::RoleCreate {
818                name: name.to_string(),
819                key: key.map(|k| k.to_string()),
820                permissions,
821            },
822        };
823
824        let response_data = self.execute_query(CreateRole, variables).await?;
825        Ok(response_data.create_role)
826    }
827
828    pub async fn update_role(
829        &self,
830        role: &str,
831        name: Option<&str>,
832        permissions: Option<Vec<String>>,
833    ) -> Result<update_role::UpdateRoleUpdateRole> {
834        let variables = update_role::Variables {
835            input: update_role::RoleUpdate {
836                role: role.to_string(),
837                name: name.map(|n| n.to_string()),
838                permissions,
839            },
840        };
841
842        let response_data = self.execute_query(UpdateRole, variables).await?;
843        Ok(response_data.update_role)
844    }
845
846    pub async fn list_teams(&self) -> Result<Vec<list_teams::ListTeamsTeams>> {
847        let variables = list_teams::Variables {};
848
849        let response_data = self.execute_query(ListTeams, variables).await?;
850        Ok(response_data.teams)
851    }
852
853    pub async fn create_team(
854        &self,
855        name: &str,
856        key: Option<&str>,
857    ) -> Result<create_team::CreateTeamCreateTeam> {
858        let variables = create_team::Variables {
859            input: create_team::TeamCreate {
860                name: name.to_string(),
861                key: key.map(|k| k.to_string()),
862            },
863        };
864
865        let response_data = self.execute_query(CreateTeam, variables).await?;
866        Ok(response_data.create_team)
867    }
868
869    pub async fn list_users(&self) -> Result<Vec<list_users::ListUsersUsers>> {
870        let variables = list_users::Variables {};
871
872        let response_data = self.execute_query(ListUsers, variables).await?;
873        Ok(response_data.users)
874    }
875
876    pub async fn create_user(
877        &self,
878        name: &str,
879        email: Option<&str>,
880        teams: Vec<create_user::UserCreateTeamWithRole>,
881        user_type: Option<create_user::UserType>,
882        generate_api_key: Option<bool>,
883    ) -> Result<create_user::CreateUserCreateUser> {
884        let variables = create_user::Variables {
885            input: create_user::UserCreate {
886                name: name.to_string(),
887                email: email.map(|e| e.to_string()),
888                teams,
889                user_type: user_type.unwrap_or(create_user::UserType::HUMAN),
890                generate_api_key,
891            },
892        };
893
894        let response_data = self.execute_query(CreateUser, variables).await?;
895        Ok(response_data.create_user)
896    }
897
898    pub async fn delete_user(&self, user: &str) -> Result<delete_user::DeleteUserDeleteUser> {
899        let variables = delete_user::Variables {
900            user: user.to_string(),
901        };
902
903        let response_data = self.execute_query(DeleteUser, variables).await?;
904        Ok(response_data.delete_user)
905    }
906
907    pub async fn add_team_member(
908        &self,
909        user: &str,
910        team: &str,
911        role: &str,
912    ) -> Result<add_team_member::AddTeamMemberSetTeamMember> {
913        let variables = add_team_member::Variables {
914            input: add_team_member::TeamMemberSet {
915                user: user.to_string(),
916                team: team.to_string(),
917                role: role.to_string(),
918            },
919        };
920
921        let response_data = self.execute_query(AddTeamMember, variables).await?;
922        Ok(response_data.set_team_member)
923    }
924
925    pub async fn remove_team_member(
926        &self,
927        user: &str,
928        team: &str,
929    ) -> Result<remove_team_member::RemoveTeamMemberRemoveTeamMember> {
930        let variables = remove_team_member::Variables {
931            input: remove_team_member::TeamMemberRemove {
932                user: user.to_string(),
933                team: team.to_string(),
934            },
935        };
936
937        let response_data = self.execute_query(RemoveTeamMember, variables).await?;
938        Ok(response_data.remove_team_member)
939    }
940
941    pub async fn get_recipe(
942        &self,
943        project: String,
944        id_or_key: String,
945    ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
946        let variables = get_recipe::Variables { project, id_or_key };
947
948        let response_data = self.execute_query(GetRecipe, variables).await?;
949        Ok(response_data.custom_recipe)
950    }
951
952    pub async fn get_grader(
953        &self,
954        id_or_key: &str,
955        project: &str,
956    ) -> Result<get_grader::GetGraderGrader> {
957        let variables = get_grader::Variables {
958            id: id_or_key.to_string(),
959            project: project.to_string(),
960        };
961
962        let response_data = self.execute_query(GetGrader, variables).await?;
963        Ok(response_data.grader)
964    }
965
966    pub async fn get_dataset(
967        &self,
968        id_or_key: &str,
969        project: &str,
970    ) -> Result<Option<get_dataset::GetDatasetDataset>> {
971        let variables = get_dataset::Variables {
972            id_or_key: id_or_key.to_string(),
973            project: project.to_string(),
974        };
975
976        let response_data = self.execute_query(GetDataset, variables).await?;
977        Ok(response_data.dataset)
978    }
979
980    pub async fn get_model_config(
981        &self,
982        id_or_key: &str,
983    ) -> Result<Option<get_model_config::GetModelConfigModel>> {
984        let variables = get_model_config::Variables {
985            id_or_key: id_or_key.to_string(),
986        };
987
988        let response_data = self.execute_query(GetModelConfig, variables).await?;
989        Ok(response_data.model)
990    }
991
992    pub async fn get_artifact(
993        &self,
994        project: &str,
995        id: Uuid,
996    ) -> Result<Option<get_artifact::GetArtifactArtifact>> {
997        let variables = get_artifact::Variables {
998            project: project.to_string(),
999            id,
1000        };
1001
1002        let response_data = self.execute_query(GetArtifact, variables).await?;
1003        Ok(response_data.artifact)
1004    }
1005
    /// The base URL used for REST (non-GraphQL) requests.
    pub fn base_url(&self) -> &Url {
        &self.rest_base_url
    }
1009
1010    /// Upload bytes using the chunked upload API and return the session_id.
1011    /// This can be used to link the uploaded file to an artifact.
1012    pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
1013        let file_size = data.len() as u64;
1014
1015        // Calculate chunk size (same logic as calculate_upload_parts but inline for small files)
1016        let chunk_size = if file_size < 5 * 1024 * 1024 {
1017            // For files < 5MB, use the whole file as one chunk
1018            file_size.max(1)
1019        } else if file_size < 500 * 1024 * 1024 {
1020            5 * 1024 * 1024
1021        } else if file_size < 10 * 1024 * 1024 * 1024 {
1022            10 * 1024 * 1024
1023        } else {
1024            100 * 1024 * 1024
1025        };
1026
1027        let total_parts = file_size.div_ceil(chunk_size).max(1);
1028
1029        // Initialize upload session
1030        let session_id = self
1031            .init_chunked_upload_with_content_type(total_parts, content_type)
1032            .await?;
1033
1034        // Upload parts
1035        for part_number in 1..=total_parts {
1036            let start = ((part_number - 1) * chunk_size) as usize;
1037            let end = (part_number * chunk_size).min(file_size) as usize;
1038            let chunk = data[start..end].to_vec();
1039
1040            if let Err(e) = self
1041                .upload_part_simple(&session_id, part_number, chunk)
1042                .await
1043            {
1044                let _ = self.abort_chunked_upload(&session_id).await;
1045                return Err(e);
1046            }
1047        }
1048
1049        Ok(session_id)
1050    }
1051
1052    /// Initialize a chunked upload session.
1053    pub async fn init_chunked_upload_with_content_type(
1054        &self,
1055        total_parts: u64,
1056        content_type: &str,
1057    ) -> Result<String> {
1058        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
1059
1060        let request = InitChunkedUploadRequest {
1061            content_type: content_type.to_string(),
1062            metadata: None,
1063            total_parts_count: total_parts,
1064        };
1065
1066        let response = self
1067            .client
1068            .post(url)
1069            .bearer_auth(&self.auth_token)
1070            .json(&request)
1071            .send()
1072            .await?;
1073
1074        if !response.status().is_success() {
1075            return Err(AdaptiveError::ChunkedUploadInitFailed {
1076                status: response.status().to_string(),
1077                body: response.text().await.unwrap_or_default(),
1078            });
1079        }
1080
1081        let init_response: InitChunkedUploadResponse = response.json().await?;
1082        Ok(init_response.session_id)
1083    }
1084
1085    /// Upload a single part of a chunked upload.
1086    pub async fn upload_part_simple(
1087        &self,
1088        session_id: &str,
1089        part_number: u64,
1090        data: Vec<u8>,
1091    ) -> Result<()> {
1092        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
1093
1094        let response = self
1095            .client
1096            .post(url)
1097            .bearer_auth(&self.auth_token)
1098            .query(&[
1099                ("session_id", session_id),
1100                ("part_number", &part_number.to_string()),
1101            ])
1102            .header("Content-Type", "application/octet-stream")
1103            .body(data)
1104            .send()
1105            .await?;
1106
1107        if !response.status().is_success() {
1108            return Err(AdaptiveError::ChunkedUploadPartFailed {
1109                part_number,
1110                status: response.status().to_string(),
1111                body: response.text().await.unwrap_or_default(),
1112            });
1113        }
1114
1115        Ok(())
1116    }
1117
1118    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
1119        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
1120
1121        let request = InitChunkedUploadRequest {
1122            content_type: "application/jsonl".to_string(),
1123            metadata: None,
1124            total_parts_count: total_parts,
1125        };
1126
1127        let response = self
1128            .client
1129            .post(url)
1130            .bearer_auth(&self.auth_token)
1131            .json(&request)
1132            .send()
1133            .await?;
1134
1135        if !response.status().is_success() {
1136            return Err(AdaptiveError::ChunkedUploadInitFailed {
1137                status: response.status().to_string(),
1138                body: response.text().await.unwrap_or_default(),
1139            });
1140        }
1141
1142        let init_response: InitChunkedUploadResponse = response.json().await?;
1143        Ok(init_response.session_id)
1144    }
1145
    /// Upload a single part of a chunked upload, streaming the body in
    /// 64 KiB sub-chunks and reporting byte counts through `progress_tx`.
    ///
    /// Progress is reported as each sub-chunk is handed to the HTTP body
    /// stream (not when the server acknowledges it), and `try_send` drops
    /// an update if the channel is full — so progress is best-effort.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Granularity of progress reporting within a single part.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        // Pre-split into owned sub-chunks: `data.chunks` borrows `data`,
        // but the body stream must own its items.
        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            // Best-effort: silently dropped when the channel is full.
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
1194
1195    /// Abort a chunked upload session.
1196    pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
1197        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
1198
1199        let request = AbortChunkedUploadRequest {
1200            session_id: session_id.to_string(),
1201        };
1202
1203        let _ = self
1204            .client
1205            .delete(url)
1206            .bearer_auth(&self.auth_token)
1207            .json(&request)
1208            .send()
1209            .await;
1210
1211        Ok(())
1212    }
1213
1214    async fn create_dataset_from_multipart(
1215        &self,
1216        project: &str,
1217        name: &str,
1218        key: &str,
1219        session_id: &str,
1220    ) -> Result<
1221        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
1222    > {
1223        let variables = create_dataset_from_multipart::Variables {
1224            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
1225                project: project.to_string(),
1226                name: name.to_string(),
1227                key: Some(key.to_string()),
1228                source: None,
1229                upload_session_id: session_id.to_string(),
1230            },
1231        };
1232
1233        let response_data = self
1234            .execute_query(CreateDatasetFromMultipart, variables)
1235            .await?;
1236        Ok(response_data.create_dataset_from_multipart_upload)
1237    }
1238
    /// Upload a dataset file via the chunked upload API, yielding
    /// `UploadEvent::Progress` events followed by a final
    /// `UploadEvent::Complete` with the created dataset.
    ///
    /// On any part-upload or dataset-creation failure the upload session is
    /// aborted (best-effort) before the error is surfaced through the
    /// stream.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        project: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        let file_size = std::fs::metadata(dataset.as_ref())?.len();

        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Initial 0% progress event so consumers see the total up front.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            let mut file = File::open(dataset.as_ref())?;
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            // upload_part reports fine-grained byte counts through this
            // channel while its future is being driven below.
            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            for part_number in 1..=total_parts {
                let bytes_read = file.read(&mut buffer)?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the upload while draining progress updates; `biased`
                // checks upload completion before further progress events.
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        biased;
                        result = &mut upload_fut => {
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };
                // NOTE(review): progress messages still queued when the
                // upload future completes are not drained, so the reported
                // bytes_uploaded can lag file_size — confirm whether
                // consumers expect a final 100% progress event.

                if let Err(e) = upload_result {
                    // Best-effort cleanup; the upload error is what matters.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(project, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    // The parts were uploaded but the dataset mutation
                    // failed: abort the session, then report the failure.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
1310
1311    /// Download a file from the given URL and write it to the specified path.
1312    /// The URL can be absolute or relative to the API base URL.
1313    pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
1314        use tokio::io::AsyncWriteExt;
1315
1316        let full_url = if url.starts_with("http://") || url.starts_with("https://") {
1317            Url::parse(url)?
1318        } else {
1319            self.rest_base_url.join(url)?
1320        };
1321
1322        let response = self
1323            .client
1324            .get(full_url)
1325            .bearer_auth(&self.auth_token)
1326            .send()
1327            .await?;
1328
1329        if !response.status().is_success() {
1330            return Err(AdaptiveError::HttpError(
1331                response.error_for_status().unwrap_err(),
1332            ));
1333        }
1334
1335        let mut file = tokio::fs::File::create(dest_path).await?;
1336        let mut stream = response.bytes_stream();
1337
1338        while let Some(chunk) = stream.next().await {
1339            let chunk = chunk?;
1340            file.write_all(&chunk).await?;
1341        }
1342
1343        file.flush().await?;
1344        Ok(())
1345    }
1346}