// adaptive_client_rust — lib.rs

1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
const MEGABYTE: u64 = 1024 * 1024; // 1MB
// Chunk-size bounds enforced by `calculate_upload_parts`.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
// Cap on how many parts a single chunked upload may have.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size tier boundaries used to pick a chunk size.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Every failure mode surfaced by this client.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    // Transport-level failure from reqwest.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    // Serialization/deserialization failure from serde_json.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    // Chunked uploads require at least MIN_CHUNK_SIZE_BYTES.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    // File cannot fit in MAX_PARTS_COUNT parts of MAX_CHUNK_SIZE_BYTES.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    // Server returned GraphQL-level errors instead of data.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    // Successful response carrying neither `data` nor `errors`.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    // Non-success status from the chunked-upload init endpoint.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    // Non-success status while uploading one part.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),

    // Generic non-2xx HTTP response with the body echoed back.
    #[error("HTTP status error: {status} - {body}")]
    HttpStatusError { status: String, body: String },

    // Body was not parseable as a GraphQL response; `body` is a preview.
    #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
    JsonParseError { error: String, body: String },
}
76
77type Result<T> = std::result::Result<T, AdaptiveError>;
78
/// Progress snapshot emitted while a chunked upload is in flight.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    // Bytes confirmed uploaded so far.
    pub bytes_uploaded: u64,
    // Total size of the payload being uploaded.
    pub total_bytes: u64,
}
84
/// Events produced while streaming a chunked dataset upload.
#[derive(Debug)]
pub enum UploadEvent {
    // Periodic progress update.
    Progress(ChunkedUploadProgress),
    // Terminal event carrying the created dataset payload.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94    if file_size < MIN_CHUNK_SIZE_BYTES {
95        return Err(AdaptiveError::FileTooSmall {
96            size: file_size,
97            min_size: MIN_CHUNK_SIZE_BYTES,
98        });
99    }
100
101    let mut chunk_size = if file_size < SIZE_500MB {
102        5 * MEGABYTE
103    } else if file_size < SIZE_10GB {
104        10 * MEGABYTE
105    } else if file_size < SIZE_50GB {
106        50 * MEGABYTE
107    } else {
108        100 * MEGABYTE
109    };
110
111    let mut total_parts = file_size.div_ceil(chunk_size);
112
113    if total_parts > MAX_PARTS_COUNT {
114        chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116        if chunk_size > MAX_CHUNK_SIZE_BYTES {
117            let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118            return Err(AdaptiveError::FileTooLarge {
119                size: file_size,
120                max_size: max_file_size,
121            });
122        }
123
124        total_parts = file_size.div_ceil(chunk_size);
125    }
126
127    Ok((total_parts, chunk_size))
128}
129
// Scalar aliases consumed by the `graphql_client` codegen — names are
// presumably matched against custom scalars in `schema.gql` (TODO:
// confirm against the schema file).
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
// Datetimes are sent to the API as strings.
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
type KeyInput = String;
138
/// Timestamp wrapper with a custom deserializer that decodes epoch
/// milliseconds (see the `Deserialize` impl in this file).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);
141
142impl<'de> serde::Deserialize<'de> for Timestamp {
143    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
144    where
145        D: serde::Deserializer<'de>,
146    {
147        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
148        Ok(Timestamp(system_time))
149    }
150}
151
152const PAGE_SIZE: usize = 20;
153
/// Placeholder for the GraphQL multipart `Upload` scalar. The wrapped
/// index refers to the file's key in the multipart `map` field (see
/// `upload_dataset` / `publish_recipe`).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
156
// GraphQL operations. Each `GraphQLQuery` derive generates a module
// (the snake_case struct name) containing typed `Variables` and
// response structs from the listed query file and `schema.gql`.

/// Lists the custom recipes in a project (used by `list_recipes`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Fetches a single job by id (used by `get_job`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Cursor-paged job listing (used by `list_jobs_page`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Cancels a job (used by `cancel_job`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Lists model services of one project (used by `list_models`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Lists all models (used by `list_all_models`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
204
205impl Display for get_job::JobStatus {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        match self {
208            get_job::JobStatus::PENDING => write!(f, "Pending"),
209            get_job::JobStatus::RUNNING => write!(f, "Running"),
210            get_job::JobStatus::COMPLETED => write!(f, "Completed"),
211            get_job::JobStatus::FAILED => write!(f, "Failed"),
212            get_job::JobStatus::CANCELED => write!(f, "Canceled"),
213            get_job::JobStatus::Other(_) => write!(f, "Unknown"),
214        }
215    }
216}
217
/// Creates a custom recipe from an uploaded file (used by `publish_recipe`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Updates recipe metadata and optionally its file (used by `update_recipe`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/update_recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateCustomRecipe;

/// Creates a dataset from a single multipart file (used by `upload_dataset`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Finalizes a dataset from a chunked upload session.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Starts a job from a recipe (used by `run_recipe`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Lists projects (used by `list_projects`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/projects.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListProjects;

/// Lists compute pools (used by `list_pools`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Fetches one recipe by id or key (used by `get_recipe`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// Fetches a grader; response is re-serializable (used by `get_grader`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// Fetches a dataset by id or key (used by `get_dataset`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// Fetches a model config; response is re-serializable (used by `get_model_config`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;

/// Reports a progress event for a job (used by `update_job_progress`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job_progress.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateJobProgress;

/// Lists roles (used by `list_roles`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/roles.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListRoles;

/// Creates a role (used by `create_role`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_role.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateRole;

/// Lists teams (used by `list_teams`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/teams.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListTeams;

/// Creates a team (used by `create_team`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_team.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateTeam;

/// Lists users (used by `list_users`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/users.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListUsers;

/// Creates a user (used by `create_user`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_user.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateUser;

/// Deletes a user (used by `delete_user`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/delete_user.graphql",
    response_derives = "Debug, Clone"
)]
pub struct DeleteUser;

/// Sets a user's role on a team (used by `add_team_member`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/add_team_member.graphql",
    response_derives = "Debug, Clone"
)]
pub struct AddTeamMember;

/// Removes a user from a team (used by `remove_team_member`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/remove_team_member.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RemoveTeamMember;
385
// REST routes of the chunked-upload API, joined onto `rest_base_url`.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
389
/// Authenticated client for the service's GraphQL and REST APIs.
#[derive(Clone)]
pub struct AdaptiveClient {
    // HTTP client reused for all requests.
    client: Client,
    // `<base>/graphql` endpoint used by `execute_query`.
    graphql_url: Url,
    // Base URL for REST routes (chunked-upload endpoints).
    rest_base_url: Url,
    // Bearer token attached to every request.
    auth_token: String,
}
397
398impl AdaptiveClient {
    /// Create a client targeting `api_base_url`, authenticating every
    /// request with `auth_token` as a bearer token.
    ///
    /// # Panics
    /// Panics if `api_base_url` cannot be joined with the `graphql`
    /// path segment or if the HTTP client fails to build.
    pub fn new(api_base_url: Url, auth_token: String) -> Self {
        let graphql_url = api_base_url
            .join("graphql")
            .expect("Failed to append graphql to base URL");

        // Identify this library and its version in the User-Agent.
        let client = Client::builder()
            .user_agent(format!(
                "adaptive-client-rust/{}",
                env!("CARGO_PKG_VERSION")
            ))
            .build()
            .expect("Failed to build HTTP client");

        Self {
            client,
            graphql_url,
            rest_base_url: api_base_url,
            auth_token,
        }
    }
419
    /// POST one GraphQL operation and return its typed `data` payload.
    ///
    /// `_query` only selects which `GraphQLQuery` impl to use; the
    /// request body is built entirely from `variables`.
    ///
    /// # Errors
    /// * `HttpStatusError` for non-2xx responses (full body echoed).
    /// * `JsonParseError` if the body is not a valid GraphQL response
    ///   (body preview truncated to 500 chars).
    /// * `GraphQLErrors` when the server returns errors without data.
    /// * `NoGraphQLData` when neither data nor errors are present.
    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
    where
        T: GraphQLQuery,
        T::Variables: serde::Serialize,
        T::ResponseData: DeserializeOwned,
    {
        let request_body = T::build_query(variables);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .json(&request_body)
            .send()
            .await?;

        // Read the body as text first so it can be echoed in errors.
        let status = response.status();
        let response_text = response.text().await?;

        if !status.is_success() {
            return Err(AdaptiveError::HttpStatusError {
                status: status.to_string(),
                body: response_text,
            });
        }

        let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
            .map_err(|e| AdaptiveError::JsonParseError {
                error: e.to_string(),
                // Keep the error message bounded for huge bodies.
                body: response_text.chars().take(500).collect(),
            })?;

        match response_body.data {
            Some(data) => Ok(data),
            None => {
                if let Some(errors) = response_body.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
462
463    pub async fn list_recipes(
464        &self,
465        project: &str,
466    ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
467        let variables = get_custom_recipes::Variables {
468            project: IdOrKey::from(project),
469        };
470
471        let response_data = self.execute_query(GetCustomRecipes, variables).await?;
472        Ok(response_data.custom_recipes)
473    }
474
475    pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
476        let variables = get_job::Variables { id: job_id };
477
478        let response_data = self.execute_query(GetJob, variables).await?;
479
480        match response_data.job {
481            Some(job) => Ok(job),
482            None => Err(AdaptiveError::JobNotFound(job_id)),
483        }
484    }
485
    /// Create a dataset from a local file via a GraphQL multipart
    /// request (operations + map + file parts).
    ///
    /// # Errors
    /// IO/HTTP failures, `GraphQLErrors` from the server, or
    /// `NoGraphQLData` if the response has neither data nor errors.
    pub async fn upload_dataset<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        dataset: P,
    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
        // `Upload(0)` is a placeholder resolved by the `map` below.
        let variables = upload_dataset::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
        };

        let operations = UploadDataset::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the `file` variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let dataset_file = reqwest::multipart::Part::file(dataset).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", dataset_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;

        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
            response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_dataset),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
531
    /// Publish a new custom recipe from a local file via a GraphQL
    /// multipart request (same wire format as `upload_dataset`).
    ///
    /// # Errors
    /// IO/HTTP failures, `GraphQLErrors` from the server, or
    /// `NoGraphQLData` if the response has neither data nor errors.
    pub async fn publish_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        key: &str,
        recipe: P,
    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
        // `Upload(0)` is a placeholder resolved by the `map` below.
        let variables = publish_custom_recipe::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
            key: Some(key.to_string()),
        };

        let operations = PublishCustomRecipe::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the `file` variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let recipe_file = reqwest::multipart::Part::file(recipe).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", recipe_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;
        let response: Response<
            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
        > = response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_custom_recipe),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
579
    /// Update a recipe's metadata and, optionally, replace its file.
    ///
    /// With `recipe_file = Some(..)` the request is sent as a GraphQL
    /// multipart upload; with `None` it goes through the plain JSON
    /// path (`execute_query`).
    ///
    /// # Errors
    /// IO/HTTP failures, `GraphQLErrors` from the server, or
    /// `NoGraphQLData` if the response has neither data nor errors.
    pub async fn update_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        id: &str,
        name: Option<String>,
        description: Option<String>,
        labels: Option<Vec<update_custom_recipe::LabelInput>>,
        recipe_file: Option<P>,
    ) -> Result<update_custom_recipe::UpdateCustomRecipeUpdateCustomRecipe> {
        let input = update_custom_recipe::UpdateRecipeInput {
            name,
            description,
            labels,
        };

        match recipe_file {
            Some(file_path) => {
                // Multipart path: `Upload(0)` is resolved by `map` below.
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    file: Some(Upload(0)),
                };

                let operations = UpdateCustomRecipe::build_query(variables);
                let operations = serde_json::to_string(&operations)?;

                let file_map = r#"{ "0": ["variables.file"] }"#;

                let recipe_file = reqwest::multipart::Part::file(file_path).await?;

                let form = reqwest::multipart::Form::new()
                    .text("operations", operations)
                    .text("map", file_map)
                    .part("0", recipe_file);

                let response = self
                    .client
                    .post(self.graphql_url.clone())
                    .bearer_auth(&self.auth_token)
                    .multipart(form)
                    .send()
                    .await?;
                let response: Response<
                    <UpdateCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
                > = response.json().await?;

                match response.data {
                    Some(data) => Ok(data.update_custom_recipe),
                    None => {
                        if let Some(errors) = response.errors {
                            return Err(AdaptiveError::GraphQLErrors(errors));
                        }
                        Err(AdaptiveError::NoGraphQLData)
                    }
                }
            }
            None => {
                // Metadata-only path: plain JSON GraphQL request.
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    file: None,
                };

                let response_data = self.execute_query(UpdateCustomRecipe, variables).await?;
                Ok(response_data.update_custom_recipe)
            }
        }
    }
650
651    pub async fn run_recipe(
652        &self,
653        project: &str,
654        recipe_id: &str,
655        parameters: Map<String, Value>,
656        name: Option<String>,
657        compute_pool: Option<String>,
658        num_gpus: u32,
659        use_experimental_runner: bool,
660    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
661        let variables = run_custom_recipe::Variables {
662            input: run_custom_recipe::JobInput {
663                recipe: recipe_id.to_string(),
664                project: project.to_string(),
665                args: parameters,
666                name,
667                compute_pool,
668                num_gpus: num_gpus as i64,
669                use_experimental_runner,
670                max_cpu: None,
671                max_ram_gb: None,
672                max_duration_secs: None,
673            },
674        };
675
676        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
677        Ok(response_data.create_job)
678    }
679
680    pub async fn list_jobs(
681        &self,
682        project: Option<String>,
683    ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
684        let mut jobs = Vec::new();
685        let mut page = self.list_jobs_page(project.clone(), None).await?;
686        jobs.extend(page.nodes.iter().cloned());
687        while page.page_info.has_next_page {
688            page = self
689                .list_jobs_page(project.clone(), page.page_info.end_cursor)
690                .await?;
691            jobs.extend(page.nodes.iter().cloned());
692        }
693        Ok(jobs)
694    }
695
    /// Fetch one page of jobs starting after the `after` cursor.
    ///
    /// NOTE: the filter is hard-coded to CUSTOM jobs in RUNNING or
    /// PENDING status — completed/failed/canceled jobs are never
    /// returned by this listing.
    async fn list_jobs_page(
        &self,
        project: Option<String>,
        after: Option<String>,
    ) -> Result<list_jobs::ListJobsJobs> {
        let variables = list_jobs::Variables {
            filter: Some(list_jobs::ListJobsFilterInput {
                project,
                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
                status: Some(vec![
                    list_jobs::JobStatus::RUNNING,
                    list_jobs::JobStatus::PENDING,
                ]),
                timerange: None,
                custom_recipes: None,
                artifacts: None,
            }),
            // Forward pagination only: `first` + `after`.
            cursor: Some(list_jobs::CursorPageInput {
                first: Some(PAGE_SIZE as i64),
                after,
                before: None,
                last: None,
                offset: None,
            }),
        };

        let response_data = self.execute_query(ListJobs, variables).await?;
        Ok(response_data.jobs)
    }
725
726    pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
727        let variables = cancel_job::Variables { job_id };
728
729        let response_data = self.execute_query(CancelJob, variables).await?;
730        Ok(response_data.cancel_job)
731    }
732
733    pub async fn update_job_progress(
734        &self,
735        job_id: Uuid,
736        event: update_job_progress::JobProgressEventInput,
737    ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
738        let variables = update_job_progress::Variables { job_id, event };
739
740        let response_data = self.execute_query(UpdateJobProgress, variables).await?;
741        Ok(response_data.update_job_progress)
742    }
743
744    pub async fn list_models(
745        &self,
746        project: String,
747    ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
748        let variables = list_models::Variables { project };
749
750        let response_data = self.execute_query(ListModels, variables).await?;
751        Ok(response_data
752            .project
753            .map(|project| project.model_services)
754            .unwrap_or(Vec::new()))
755    }
756
757    pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
758        let variables = list_all_models::Variables {};
759
760        let response_data = self.execute_query(ListAllModels, variables).await?;
761        Ok(response_data.models)
762    }
763
764    pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
765        let variables = list_projects::Variables {};
766
767        let response_data = self.execute_query(ListProjects, variables).await?;
768        Ok(response_data.projects)
769    }
770
771    pub async fn list_pools(
772        &self,
773    ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
774        let variables = list_compute_pools::Variables {};
775
776        let response_data = self.execute_query(ListComputePools, variables).await?;
777        Ok(response_data.compute_pools)
778    }
779
780    pub async fn list_roles(&self) -> Result<Vec<list_roles::ListRolesRoles>> {
781        let variables = list_roles::Variables {};
782
783        let response_data = self.execute_query(ListRoles, variables).await?;
784        Ok(response_data.roles)
785    }
786
787    pub async fn create_role(
788        &self,
789        name: &str,
790        key: Option<&str>,
791        permissions: Vec<String>,
792    ) -> Result<create_role::CreateRoleCreateRole> {
793        let variables = create_role::Variables {
794            input: create_role::RoleCreate {
795                name: name.to_string(),
796                key: key.map(|k| k.to_string()),
797                permissions,
798            },
799        };
800
801        let response_data = self.execute_query(CreateRole, variables).await?;
802        Ok(response_data.create_role)
803    }
804
805    pub async fn list_teams(&self) -> Result<Vec<list_teams::ListTeamsTeams>> {
806        let variables = list_teams::Variables {};
807
808        let response_data = self.execute_query(ListTeams, variables).await?;
809        Ok(response_data.teams)
810    }
811
812    pub async fn create_team(
813        &self,
814        name: &str,
815        key: Option<&str>,
816    ) -> Result<create_team::CreateTeamCreateTeam> {
817        let variables = create_team::Variables {
818            input: create_team::TeamCreate {
819                name: name.to_string(),
820                key: key.map(|k| k.to_string()),
821            },
822        };
823
824        let response_data = self.execute_query(CreateTeam, variables).await?;
825        Ok(response_data.create_team)
826    }
827
828    pub async fn list_users(&self) -> Result<Vec<list_users::ListUsersUsers>> {
829        let variables = list_users::Variables {};
830
831        let response_data = self.execute_query(ListUsers, variables).await?;
832        Ok(response_data.users)
833    }
834
835    pub async fn create_user(
836        &self,
837        name: &str,
838        email: Option<&str>,
839        teams: Vec<create_user::UserCreateTeamWithRole>,
840        user_type: Option<create_user::UserType>,
841        generate_api_key: Option<bool>,
842    ) -> Result<create_user::CreateUserCreateUser> {
843        let variables = create_user::Variables {
844            input: create_user::UserCreate {
845                name: name.to_string(),
846                email: email.map(|e| e.to_string()),
847                teams,
848                user_type: user_type.unwrap_or(create_user::UserType::HUMAN),
849                generate_api_key,
850            },
851        };
852
853        let response_data = self.execute_query(CreateUser, variables).await?;
854        Ok(response_data.create_user)
855    }
856
857    pub async fn delete_user(&self, user: &str) -> Result<delete_user::DeleteUserDeleteUser> {
858        let variables = delete_user::Variables {
859            user: user.to_string(),
860        };
861
862        let response_data = self.execute_query(DeleteUser, variables).await?;
863        Ok(response_data.delete_user)
864    }
865
866    pub async fn add_team_member(
867        &self,
868        user: &str,
869        team: &str,
870        role: &str,
871    ) -> Result<add_team_member::AddTeamMemberSetTeamMember> {
872        let variables = add_team_member::Variables {
873            input: add_team_member::TeamMemberSet {
874                user: user.to_string(),
875                team: team.to_string(),
876                role: role.to_string(),
877            },
878        };
879
880        let response_data = self.execute_query(AddTeamMember, variables).await?;
881        Ok(response_data.set_team_member)
882    }
883
884    pub async fn remove_team_member(
885        &self,
886        user: &str,
887        team: &str,
888    ) -> Result<remove_team_member::RemoveTeamMemberRemoveTeamMember> {
889        let variables = remove_team_member::Variables {
890            input: remove_team_member::TeamMemberRemove {
891                user: user.to_string(),
892                team: team.to_string(),
893            },
894        };
895
896        let response_data = self.execute_query(RemoveTeamMember, variables).await?;
897        Ok(response_data.remove_team_member)
898    }
899
900    pub async fn get_recipe(
901        &self,
902        project: String,
903        id_or_key: String,
904    ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
905        let variables = get_recipe::Variables { project, id_or_key };
906
907        let response_data = self.execute_query(GetRecipe, variables).await?;
908        Ok(response_data.custom_recipe)
909    }
910
911    pub async fn get_grader(
912        &self,
913        id_or_key: &str,
914        project: &str,
915    ) -> Result<get_grader::GetGraderGrader> {
916        let variables = get_grader::Variables {
917            id: id_or_key.to_string(),
918            project: project.to_string(),
919        };
920
921        let response_data = self.execute_query(GetGrader, variables).await?;
922        Ok(response_data.grader)
923    }
924
925    pub async fn get_dataset(
926        &self,
927        id_or_key: &str,
928        project: &str,
929    ) -> Result<Option<get_dataset::GetDatasetDataset>> {
930        let variables = get_dataset::Variables {
931            id_or_key: id_or_key.to_string(),
932            project: project.to_string(),
933        };
934
935        let response_data = self.execute_query(GetDataset, variables).await?;
936        Ok(response_data.dataset)
937    }
938
939    pub async fn get_model_config(
940        &self,
941        id_or_key: &str,
942    ) -> Result<Option<get_model_config::GetModelConfigModel>> {
943        let variables = get_model_config::Variables {
944            id_or_key: id_or_key.to_string(),
945        };
946
947        let response_data = self.execute_query(GetModelConfig, variables).await?;
948        Ok(response_data.model)
949    }
950
    /// Base URL used for REST (non-GraphQL) endpoints.
    pub fn base_url(&self) -> &Url {
        &self.rest_base_url
    }
954
955    /// Upload bytes using the chunked upload API and return the session_id.
956    /// This can be used to link the uploaded file to an artifact.
957    pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
958        let file_size = data.len() as u64;
959
960        // Calculate chunk size (same logic as calculate_upload_parts but inline for small files)
961        let chunk_size = if file_size < 5 * 1024 * 1024 {
962            // For files < 5MB, use the whole file as one chunk
963            file_size.max(1)
964        } else if file_size < 500 * 1024 * 1024 {
965            5 * 1024 * 1024
966        } else if file_size < 10 * 1024 * 1024 * 1024 {
967            10 * 1024 * 1024
968        } else {
969            100 * 1024 * 1024
970        };
971
972        let total_parts = file_size.div_ceil(chunk_size).max(1);
973
974        // Initialize upload session
975        let session_id = self
976            .init_chunked_upload_with_content_type(total_parts, content_type)
977            .await?;
978
979        // Upload parts
980        for part_number in 1..=total_parts {
981            let start = ((part_number - 1) * chunk_size) as usize;
982            let end = (part_number * chunk_size).min(file_size) as usize;
983            let chunk = data[start..end].to_vec();
984
985            if let Err(e) = self
986                .upload_part_simple(&session_id, part_number, chunk)
987                .await
988            {
989                let _ = self.abort_chunked_upload(&session_id).await;
990                return Err(e);
991            }
992        }
993
994        Ok(session_id)
995    }
996
997    /// Initialize a chunked upload session.
998    pub async fn init_chunked_upload_with_content_type(
999        &self,
1000        total_parts: u64,
1001        content_type: &str,
1002    ) -> Result<String> {
1003        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
1004
1005        let request = InitChunkedUploadRequest {
1006            content_type: content_type.to_string(),
1007            metadata: None,
1008            total_parts_count: total_parts,
1009        };
1010
1011        let response = self
1012            .client
1013            .post(url)
1014            .bearer_auth(&self.auth_token)
1015            .json(&request)
1016            .send()
1017            .await?;
1018
1019        if !response.status().is_success() {
1020            return Err(AdaptiveError::ChunkedUploadInitFailed {
1021                status: response.status().to_string(),
1022                body: response.text().await.unwrap_or_default(),
1023            });
1024        }
1025
1026        let init_response: InitChunkedUploadResponse = response.json().await?;
1027        Ok(init_response.session_id)
1028    }
1029
1030    /// Upload a single part of a chunked upload.
1031    pub async fn upload_part_simple(
1032        &self,
1033        session_id: &str,
1034        part_number: u64,
1035        data: Vec<u8>,
1036    ) -> Result<()> {
1037        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
1038
1039        let response = self
1040            .client
1041            .post(url)
1042            .bearer_auth(&self.auth_token)
1043            .query(&[
1044                ("session_id", session_id),
1045                ("part_number", &part_number.to_string()),
1046            ])
1047            .header("Content-Type", "application/octet-stream")
1048            .body(data)
1049            .send()
1050            .await?;
1051
1052        if !response.status().is_success() {
1053            return Err(AdaptiveError::ChunkedUploadPartFailed {
1054                part_number,
1055                status: response.status().to_string(),
1056                body: response.text().await.unwrap_or_default(),
1057            });
1058        }
1059
1060        Ok(())
1061    }
1062
1063    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
1064        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
1065
1066        let request = InitChunkedUploadRequest {
1067            content_type: "application/jsonl".to_string(),
1068            metadata: None,
1069            total_parts_count: total_parts,
1070        };
1071
1072        let response = self
1073            .client
1074            .post(url)
1075            .bearer_auth(&self.auth_token)
1076            .json(&request)
1077            .send()
1078            .await?;
1079
1080        if !response.status().is_success() {
1081            return Err(AdaptiveError::ChunkedUploadInitFailed {
1082                status: response.status().to_string(),
1083                body: response.text().await.unwrap_or_default(),
1084            });
1085        }
1086
1087        let init_response: InitChunkedUploadResponse = response.json().await?;
1088        Ok(init_response.session_id)
1089    }
1090
1091    async fn upload_part(
1092        &self,
1093        session_id: &str,
1094        part_number: u64,
1095        data: Vec<u8>,
1096        progress_tx: mpsc::Sender<u64>,
1097    ) -> Result<()> {
1098        const SUB_CHUNK_SIZE: usize = 64 * 1024;
1099
1100        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
1101
1102        let chunks: Vec<Vec<u8>> = data
1103            .chunks(SUB_CHUNK_SIZE)
1104            .map(|chunk| chunk.to_vec())
1105            .collect();
1106
1107        let stream = futures::stream::iter(chunks).map(move |chunk| {
1108            let len = chunk.len() as u64;
1109            let tx = progress_tx.clone();
1110            let _ = tx.try_send(len);
1111            Ok::<_, std::io::Error>(chunk)
1112        });
1113
1114        let body = reqwest::Body::wrap_stream(stream);
1115
1116        let response = self
1117            .client
1118            .post(url)
1119            .bearer_auth(&self.auth_token)
1120            .query(&[
1121                ("session_id", session_id),
1122                ("part_number", &part_number.to_string()),
1123            ])
1124            .header("Content-Type", "application/octet-stream")
1125            .body(body)
1126            .send()
1127            .await?;
1128
1129        if !response.status().is_success() {
1130            return Err(AdaptiveError::ChunkedUploadPartFailed {
1131                part_number,
1132                status: response.status().to_string(),
1133                body: response.text().await.unwrap_or_default(),
1134            });
1135        }
1136
1137        Ok(())
1138    }
1139
1140    /// Abort a chunked upload session.
1141    pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
1142        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
1143
1144        let request = AbortChunkedUploadRequest {
1145            session_id: session_id.to_string(),
1146        };
1147
1148        let _ = self
1149            .client
1150            .delete(url)
1151            .bearer_auth(&self.auth_token)
1152            .json(&request)
1153            .send()
1154            .await;
1155
1156        Ok(())
1157    }
1158
1159    async fn create_dataset_from_multipart(
1160        &self,
1161        project: &str,
1162        name: &str,
1163        key: &str,
1164        session_id: &str,
1165    ) -> Result<
1166        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
1167    > {
1168        let variables = create_dataset_from_multipart::Variables {
1169            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
1170                project: project.to_string(),
1171                name: name.to_string(),
1172                key: Some(key.to_string()),
1173                source: None,
1174                upload_session_id: session_id.to_string(),
1175            },
1176        };
1177
1178        let response_data = self
1179            .execute_query(CreateDatasetFromMultipart, variables)
1180            .await?;
1181        Ok(response_data.create_dataset_from_multipart_upload)
1182    }
1183
    /// Upload a dataset file via the chunked upload API, yielding
    /// `UploadEvent::Progress` items as bytes are transmitted and a final
    /// `UploadEvent::Complete` with the created dataset.
    ///
    /// Fails up front if the file cannot be stat'ed or if
    /// `calculate_upload_parts` rejects its size; all later failures are
    /// surfaced through the returned stream. On a part or dataset-creation
    /// failure the server-side session is aborted (best-effort) before the
    /// error is yielded.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        project: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        let file_size = std::fs::metadata(dataset.as_ref())?.len();

        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Emit an initial zero-progress event so consumers can render
            // the total before any bytes go out.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            let mut file = File::open(dataset.as_ref())?;
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            // Channel over which upload_part reports per-sub-chunk byte counts.
            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            for part_number in 1..=total_parts {
                // NOTE(review): a single read() may legally return fewer than
                // chunk_size bytes before EOF; this relies on File reads
                // filling the buffer — confirm, or loop until full.
                let bytes_read = file.read(&mut buffer)?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the upload while draining progress messages so
                // Progress events are yielded as the part is transmitted.
                // `biased` checks upload completion first on each iteration.
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        biased;
                        result = &mut upload_fut => {
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };

                if let Err(e) = upload_result {
                    // Best-effort cleanup before propagating the upload error.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(project, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    // Creation failed after all parts uploaded: abort the
                    // session (best-effort) and report the wrapped error.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
1255
1256    /// Download a file from the given URL and write it to the specified path.
1257    /// The URL can be absolute or relative to the API base URL.
1258    pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
1259        use tokio::io::AsyncWriteExt;
1260
1261        let full_url = if url.starts_with("http://") || url.starts_with("https://") {
1262            Url::parse(url)?
1263        } else {
1264            self.rest_base_url.join(url)?
1265        };
1266
1267        let response = self
1268            .client
1269            .get(full_url)
1270            .bearer_auth(&self.auth_token)
1271            .send()
1272            .await?;
1273
1274        if !response.status().is_success() {
1275            return Err(AdaptiveError::HttpError(
1276                response.error_for_status().unwrap_err(),
1277            ));
1278        }
1279
1280        let mut file = tokio::fs::File::create(dest_path).await?;
1281        let mut stream = response.bytes_stream();
1282
1283        while let Some(chunk) = stream.next().await {
1284            let chunk = chunk?;
1285            file.write_all(&chunk).await?;
1286        }
1287
1288        file.flush().await?;
1289        Ok(())
1290    }
1291}