1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
/// One mebibyte (2^20 bytes); the unit for the size constants below.
const MEGABYTE: u64 = 1024 * 1024;
/// Minimum chunked-upload part size; `calculate_upload_parts` rejects smaller files.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
/// Largest part size `calculate_upload_parts` will choose.
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
/// Maximum number of parts `calculate_upload_parts` will plan for one upload.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size thresholds at which `calculate_upload_parts` switches to a larger part size.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Errors returned by [`AdaptiveClient`] operations and [`calculate_upload_parts`].
#[derive(Error, Debug)]
pub enum AdaptiveError {
    /// Failure reported by `reqwest` (sending a request, or reading/decoding a response body).
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Serializing or deserializing JSON with `serde_json` failed.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Local I/O failed (e.g. reading a file to upload or writing a download).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A URL could not be parsed, or a route could not be joined onto the base URL.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// The file is smaller than the minimum size accepted for a chunked upload.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// The file is larger than the biggest chunked upload `calculate_upload_parts` can plan.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The GraphQL response carried no `data` but did carry `errors`.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response carried neither `data` nor `errors`.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    /// The server returned no job for the requested id.
    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    /// The chunked-upload init endpoint answered with a non-success status.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    /// Uploading one part of a chunked upload answered with a non-success status.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// Creating a dataset from an uploaded session failed; holds the underlying error message.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),

    /// A request answered with a non-success HTTP status; holds the status and response body.
    #[error("HTTP status error: {status} - {body}")]
    HttpStatusError { status: String, body: String },

    /// A response body could not be parsed; holds the parser error and a preview of the body.
    #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
    JsonParseError { error: String, body: String },
}
76
/// Result type used throughout this module, with [`AdaptiveError`] as the error type.
type Result<T> = std::result::Result<T, AdaptiveError>;

/// Progress of a chunked upload, measured in bytes.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    /// Bytes reported as sent so far.
    pub bytes_uploaded: u64,
    /// Size of the whole file being uploaded.
    pub total_bytes: u64,
}

/// Events yielded by the stream returned from [`AdaptiveClient::chunked_upload_dataset`].
#[derive(Debug)]
pub enum UploadEvent {
    /// A progress update.
    Progress(ChunkedUploadProgress),
    /// The upload finished and the dataset was created; holds the server's response.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94 if file_size < MIN_CHUNK_SIZE_BYTES {
95 return Err(AdaptiveError::FileTooSmall {
96 size: file_size,
97 min_size: MIN_CHUNK_SIZE_BYTES,
98 });
99 }
100
101 let mut chunk_size = if file_size < SIZE_500MB {
102 5 * MEGABYTE
103 } else if file_size < SIZE_10GB {
104 10 * MEGABYTE
105 } else if file_size < SIZE_50GB {
106 50 * MEGABYTE
107 } else {
108 100 * MEGABYTE
109 };
110
111 let mut total_parts = file_size.div_ceil(chunk_size);
112
113 if total_parts > MAX_PARTS_COUNT {
114 chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116 if chunk_size > MAX_CHUNK_SIZE_BYTES {
117 let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118 return Err(AdaptiveError::FileTooLarge {
119 size: file_size,
120 max_size: max_file_size,
121 });
122 }
123
124 total_parts = file_size.div_ceil(chunk_size);
125 }
126
127 Ok((total_parts, chunk_size))
128}
129
// Rust types for custom GraphQL scalars referenced by the `#[derive(GraphQLQuery)]` types
// below. graphql_client expects each custom scalar to be a type of the same name in scope.
// NOTE(review): the scalar-to-type mapping is inferred from the names — confirm against schema.gql.
type IdOrKey = String;
// The alias name must match the schema scalar exactly, hence the lint allowance.
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
// The alias name must match the schema scalar exactly, hence the lint allowance.
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
type KeyInput = String;
138
139#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
140pub struct Timestamp(pub SystemTime);
141
142impl<'de> serde::Deserialize<'de> for Timestamp {
143 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
144 where
145 D: serde::Deserializer<'de>,
146 {
147 let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
148 Ok(Timestamp(system_time))
149 }
150}
151
/// Number of jobs requested per page by `AdaptiveClient::list_jobs`.
const PAGE_SIZE: usize = 20;

/// Value sent for a GraphQL `Upload` variable in multipart requests.
///
/// Serialized as the wrapped number. The upload methods send `Upload(0)` and map form
/// part `"0"` onto `variables.file`.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
156
/// GraphQL operation from `src/graphql/list.graphql`; used by `AdaptiveClient::list_recipes`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// GraphQL operation from `src/graphql/job.graphql`; used by `AdaptiveClient::get_job`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// GraphQL operation from `src/graphql/jobs.graphql`; used by `AdaptiveClient::list_jobs`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// GraphQL operation from `src/graphql/cancel.graphql`; used by `AdaptiveClient::cancel_job`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// GraphQL operation from `src/graphql/models.graphql`; used by `AdaptiveClient::list_models`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// GraphQL operation from `src/graphql/all_models.graphql`; used by
/// `AdaptiveClient::list_all_models`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
204
205impl Display for get_job::JobStatus {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 match self {
208 get_job::JobStatus::PENDING => write!(f, "Pending"),
209 get_job::JobStatus::RUNNING => write!(f, "Running"),
210 get_job::JobStatus::COMPLETED => write!(f, "Completed"),
211 get_job::JobStatus::FAILED => write!(f, "Failed"),
212 get_job::JobStatus::CANCELED => write!(f, "Canceled"),
213 get_job::JobStatus::Other(_) => write!(f, "Unknown"),
214 }
215 }
216}
217
/// GraphQL operation from `src/graphql/publish.graphql`; used by
/// `AdaptiveClient::publish_recipe` (sent as a multipart request with the recipe file).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// GraphQL operation from `src/graphql/update_recipe.graphql`; used by
/// `AdaptiveClient::update_recipe`, with or without an attached recipe file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/update_recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateCustomRecipe;

/// GraphQL operation from `src/graphql/upload_dataset.graphql`; used by
/// `AdaptiveClient::upload_dataset` (sent as a multipart request with the dataset file).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// GraphQL operation from `src/graphql/create_dataset_from_multipart.graphql`; used to turn
/// a completed chunked-upload session into a dataset.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// GraphQL operation from `src/graphql/run.graphql`; used by `AdaptiveClient::run_recipe`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// GraphQL operation from `src/graphql/projects.graphql`; used by `AdaptiveClient::list_projects`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/projects.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListProjects;

/// GraphQL operation from `src/graphql/pools.graphql`; used by `AdaptiveClient::list_pools`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// GraphQL operation from `src/graphql/recipe.graphql`; used by `AdaptiveClient::get_recipe`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// GraphQL operation from `src/graphql/grader.graphql`; used by `AdaptiveClient::get_grader`.
/// Its response types also derive `Serialize`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// GraphQL operation from `src/graphql/dataset.graphql`; used by `AdaptiveClient::get_dataset`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// GraphQL operation from `src/graphql/model_config.graphql`; used by
/// `AdaptiveClient::get_model_config`. Its response types also derive `Serialize`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;

/// GraphQL operation from `src/graphql/job_progress.graphql`; used by
/// `AdaptiveClient::update_job_progress`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job_progress.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateJobProgress;
313
/// GraphQL operation from `src/graphql/roles.graphql`; used by `AdaptiveClient::list_roles`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/roles.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListRoles;

/// GraphQL operation from `src/graphql/create_role.graphql`; used by `AdaptiveClient::create_role`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_role.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateRole;

/// GraphQL operation from `src/graphql/teams.graphql`; used by `AdaptiveClient::list_teams`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/teams.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListTeams;

/// GraphQL operation from `src/graphql/create_team.graphql`; used by `AdaptiveClient::create_team`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_team.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateTeam;

/// GraphQL operation from `src/graphql/users.graphql`; used by `AdaptiveClient::list_users`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/users.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListUsers;

/// GraphQL operation from `src/graphql/create_user.graphql`; used by `AdaptiveClient::create_user`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_user.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateUser;

/// GraphQL operation from `src/graphql/delete_user.graphql`; used by `AdaptiveClient::delete_user`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/delete_user.graphql",
    response_derives = "Debug, Clone"
)]
pub struct DeleteUser;

/// GraphQL operation from `src/graphql/add_team_member.graphql`; used by
/// `AdaptiveClient::add_team_member`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/add_team_member.graphql",
    response_derives = "Debug, Clone"
)]
pub struct AddTeamMember;

/// GraphQL operation from `src/graphql/remove_team_member.graphql`; used by
/// `AdaptiveClient::remove_team_member`.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/remove_team_member.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RemoveTeamMember;
385
// REST routes of the chunked-upload API, joined onto the client's base URL.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";

/// API client: GraphQL operations plus REST endpoints for chunked uploads and downloads.
///
/// Every request is sent with `auth_token` as a bearer token.
#[derive(Clone)]
pub struct AdaptiveClient {
    // HTTP client built in `new` with a crate-specific user agent.
    client: Client,
    // Base URL joined with `graphql`; all GraphQL operations are POSTed here.
    graphql_url: Url,
    // Base URL that REST routes (uploads, relative download URLs) are joined onto.
    rest_base_url: Url,
    // Bearer token attached to every request.
    auth_token: String,
}
397
398impl AdaptiveClient {
399 pub fn new(api_base_url: Url, auth_token: String) -> Self {
400 let graphql_url = api_base_url
401 .join("graphql")
402 .expect("Failed to append graphql to base URL");
403
404 let client = Client::builder()
405 .user_agent(format!(
406 "adaptive-client-rust/{}",
407 env!("CARGO_PKG_VERSION")
408 ))
409 .build()
410 .expect("Failed to build HTTP client");
411
412 Self {
413 client,
414 graphql_url,
415 rest_base_url: api_base_url,
416 auth_token,
417 }
418 }
419
420 async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
421 where
422 T: GraphQLQuery,
423 T::Variables: serde::Serialize,
424 T::ResponseData: DeserializeOwned,
425 {
426 let request_body = T::build_query(variables);
427
428 let response = self
429 .client
430 .post(self.graphql_url.clone())
431 .bearer_auth(&self.auth_token)
432 .json(&request_body)
433 .send()
434 .await?;
435
436 let status = response.status();
437 let response_text = response.text().await?;
438
439 if !status.is_success() {
440 return Err(AdaptiveError::HttpStatusError {
441 status: status.to_string(),
442 body: response_text,
443 });
444 }
445
446 let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
447 .map_err(|e| AdaptiveError::JsonParseError {
448 error: e.to_string(),
449 body: response_text.chars().take(500).collect(),
450 })?;
451
452 match response_body.data {
453 Some(data) => Ok(data),
454 None => {
455 if let Some(errors) = response_body.errors {
456 return Err(AdaptiveError::GraphQLErrors(errors));
457 }
458 Err(AdaptiveError::NoGraphQLData)
459 }
460 }
461 }
462
463 pub async fn list_recipes(
464 &self,
465 project: &str,
466 ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
467 let variables = get_custom_recipes::Variables {
468 project: IdOrKey::from(project),
469 };
470
471 let response_data = self.execute_query(GetCustomRecipes, variables).await?;
472 Ok(response_data.custom_recipes)
473 }
474
475 pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
476 let variables = get_job::Variables { id: job_id };
477
478 let response_data = self.execute_query(GetJob, variables).await?;
479
480 match response_data.job {
481 Some(job) => Ok(job),
482 None => Err(AdaptiveError::JobNotFound(job_id)),
483 }
484 }
485
486 pub async fn upload_dataset<P: AsRef<Path>>(
487 &self,
488 project: &str,
489 name: &str,
490 dataset: P,
491 ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
492 let variables = upload_dataset::Variables {
493 project: IdOrKey::from(project),
494 file: Upload(0),
495 name: Some(name.to_string()),
496 };
497
498 let operations = UploadDataset::build_query(variables);
499 let operations = serde_json::to_string(&operations)?;
500
501 let file_map = r#"{ "0": ["variables.file"] }"#;
502
503 let dataset_file = reqwest::multipart::Part::file(dataset).await?;
504
505 let form = reqwest::multipart::Form::new()
506 .text("operations", operations)
507 .text("map", file_map)
508 .part("0", dataset_file);
509
510 let response = self
511 .client
512 .post(self.graphql_url.clone())
513 .bearer_auth(&self.auth_token)
514 .multipart(form)
515 .send()
516 .await?;
517
518 let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
519 response.json().await?;
520
521 match response.data {
522 Some(data) => Ok(data.create_dataset),
523 None => {
524 if let Some(errors) = response.errors {
525 return Err(AdaptiveError::GraphQLErrors(errors));
526 }
527 Err(AdaptiveError::NoGraphQLData)
528 }
529 }
530 }
531
532 pub async fn publish_recipe<P: AsRef<Path>>(
533 &self,
534 project: &str,
535 name: &str,
536 key: &str,
537 recipe: P,
538 ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
539 let variables = publish_custom_recipe::Variables {
540 project: IdOrKey::from(project),
541 file: Upload(0),
542 name: Some(name.to_string()),
543 key: Some(key.to_string()),
544 };
545
546 let operations = PublishCustomRecipe::build_query(variables);
547 let operations = serde_json::to_string(&operations)?;
548
549 let file_map = r#"{ "0": ["variables.file"] }"#;
550
551 let recipe_file = reqwest::multipart::Part::file(recipe).await?;
552
553 let form = reqwest::multipart::Form::new()
554 .text("operations", operations)
555 .text("map", file_map)
556 .part("0", recipe_file);
557
558 let response = self
559 .client
560 .post(self.graphql_url.clone())
561 .bearer_auth(&self.auth_token)
562 .multipart(form)
563 .send()
564 .await?;
565 let response: Response<
566 <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
567 > = response.json().await?;
568
569 match response.data {
570 Some(data) => Ok(data.create_custom_recipe),
571 None => {
572 if let Some(errors) = response.errors {
573 return Err(AdaptiveError::GraphQLErrors(errors));
574 }
575 Err(AdaptiveError::NoGraphQLData)
576 }
577 }
578 }
579
580 pub async fn update_recipe<P: AsRef<Path>>(
581 &self,
582 project: &str,
583 id: &str,
584 name: Option<String>,
585 description: Option<String>,
586 labels: Option<Vec<update_custom_recipe::LabelInput>>,
587 recipe_file: Option<P>,
588 ) -> Result<update_custom_recipe::UpdateCustomRecipeUpdateCustomRecipe> {
589 let input = update_custom_recipe::UpdateRecipeInput {
590 name,
591 description,
592 labels,
593 };
594
595 match recipe_file {
596 Some(file_path) => {
597 let variables = update_custom_recipe::Variables {
598 project: IdOrKey::from(project),
599 id: IdOrKey::from(id),
600 input,
601 file: Some(Upload(0)),
602 };
603
604 let operations = UpdateCustomRecipe::build_query(variables);
605 let operations = serde_json::to_string(&operations)?;
606
607 let file_map = r#"{ "0": ["variables.file"] }"#;
608
609 let recipe_file = reqwest::multipart::Part::file(file_path).await?;
610
611 let form = reqwest::multipart::Form::new()
612 .text("operations", operations)
613 .text("map", file_map)
614 .part("0", recipe_file);
615
616 let response = self
617 .client
618 .post(self.graphql_url.clone())
619 .bearer_auth(&self.auth_token)
620 .multipart(form)
621 .send()
622 .await?;
623 let response: Response<
624 <UpdateCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
625 > = response.json().await?;
626
627 match response.data {
628 Some(data) => Ok(data.update_custom_recipe),
629 None => {
630 if let Some(errors) = response.errors {
631 return Err(AdaptiveError::GraphQLErrors(errors));
632 }
633 Err(AdaptiveError::NoGraphQLData)
634 }
635 }
636 }
637 None => {
638 let variables = update_custom_recipe::Variables {
639 project: IdOrKey::from(project),
640 id: IdOrKey::from(id),
641 input,
642 file: None,
643 };
644
645 let response_data = self.execute_query(UpdateCustomRecipe, variables).await?;
646 Ok(response_data.update_custom_recipe)
647 }
648 }
649 }
650
651 pub async fn run_recipe(
652 &self,
653 project: &str,
654 recipe_id: &str,
655 parameters: Map<String, Value>,
656 name: Option<String>,
657 compute_pool: Option<String>,
658 num_gpus: u32,
659 use_experimental_runner: bool,
660 ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
661 let variables = run_custom_recipe::Variables {
662 input: run_custom_recipe::JobInput {
663 recipe: recipe_id.to_string(),
664 project: project.to_string(),
665 args: parameters,
666 name,
667 compute_pool,
668 num_gpus: num_gpus as i64,
669 use_experimental_runner,
670 max_cpu: None,
671 max_ram_gb: None,
672 max_duration_secs: None,
673 },
674 };
675
676 let response_data = self.execute_query(RunCustomRecipe, variables).await?;
677 Ok(response_data.create_job)
678 }
679
680 pub async fn list_jobs(
681 &self,
682 project: Option<String>,
683 ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
684 let mut jobs = Vec::new();
685 let mut page = self.list_jobs_page(project.clone(), None).await?;
686 jobs.extend(page.nodes.iter().cloned());
687 while page.page_info.has_next_page {
688 page = self
689 .list_jobs_page(project.clone(), page.page_info.end_cursor)
690 .await?;
691 jobs.extend(page.nodes.iter().cloned());
692 }
693 Ok(jobs)
694 }
695
696 async fn list_jobs_page(
697 &self,
698 project: Option<String>,
699 after: Option<String>,
700 ) -> Result<list_jobs::ListJobsJobs> {
701 let variables = list_jobs::Variables {
702 filter: Some(list_jobs::ListJobsFilterInput {
703 project,
704 kind: Some(vec![list_jobs::JobKind::CUSTOM]),
705 status: Some(vec![
706 list_jobs::JobStatus::RUNNING,
707 list_jobs::JobStatus::PENDING,
708 ]),
709 timerange: None,
710 custom_recipes: None,
711 artifacts: None,
712 }),
713 cursor: Some(list_jobs::CursorPageInput {
714 first: Some(PAGE_SIZE as i64),
715 after,
716 before: None,
717 last: None,
718 offset: None,
719 }),
720 };
721
722 let response_data = self.execute_query(ListJobs, variables).await?;
723 Ok(response_data.jobs)
724 }
725
726 pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
727 let variables = cancel_job::Variables { job_id };
728
729 let response_data = self.execute_query(CancelJob, variables).await?;
730 Ok(response_data.cancel_job)
731 }
732
733 pub async fn update_job_progress(
734 &self,
735 job_id: Uuid,
736 event: update_job_progress::JobProgressEventInput,
737 ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
738 let variables = update_job_progress::Variables { job_id, event };
739
740 let response_data = self.execute_query(UpdateJobProgress, variables).await?;
741 Ok(response_data.update_job_progress)
742 }
743
744 pub async fn list_models(
745 &self,
746 project: String,
747 ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
748 let variables = list_models::Variables { project };
749
750 let response_data = self.execute_query(ListModels, variables).await?;
751 Ok(response_data
752 .project
753 .map(|project| project.model_services)
754 .unwrap_or(Vec::new()))
755 }
756
757 pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
758 let variables = list_all_models::Variables {};
759
760 let response_data = self.execute_query(ListAllModels, variables).await?;
761 Ok(response_data.models)
762 }
763
764 pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
765 let variables = list_projects::Variables {};
766
767 let response_data = self.execute_query(ListProjects, variables).await?;
768 Ok(response_data.projects)
769 }
770
771 pub async fn list_pools(
772 &self,
773 ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
774 let variables = list_compute_pools::Variables {};
775
776 let response_data = self.execute_query(ListComputePools, variables).await?;
777 Ok(response_data.compute_pools)
778 }
779
780 pub async fn list_roles(&self) -> Result<Vec<list_roles::ListRolesRoles>> {
781 let variables = list_roles::Variables {};
782
783 let response_data = self.execute_query(ListRoles, variables).await?;
784 Ok(response_data.roles)
785 }
786
787 pub async fn create_role(
788 &self,
789 name: &str,
790 key: Option<&str>,
791 permissions: Vec<String>,
792 ) -> Result<create_role::CreateRoleCreateRole> {
793 let variables = create_role::Variables {
794 input: create_role::RoleCreate {
795 name: name.to_string(),
796 key: key.map(|k| k.to_string()),
797 permissions,
798 },
799 };
800
801 let response_data = self.execute_query(CreateRole, variables).await?;
802 Ok(response_data.create_role)
803 }
804
805 pub async fn list_teams(&self) -> Result<Vec<list_teams::ListTeamsTeams>> {
806 let variables = list_teams::Variables {};
807
808 let response_data = self.execute_query(ListTeams, variables).await?;
809 Ok(response_data.teams)
810 }
811
812 pub async fn create_team(
813 &self,
814 name: &str,
815 key: Option<&str>,
816 ) -> Result<create_team::CreateTeamCreateTeam> {
817 let variables = create_team::Variables {
818 input: create_team::TeamCreate {
819 name: name.to_string(),
820 key: key.map(|k| k.to_string()),
821 },
822 };
823
824 let response_data = self.execute_query(CreateTeam, variables).await?;
825 Ok(response_data.create_team)
826 }
827
828 pub async fn list_users(&self) -> Result<Vec<list_users::ListUsersUsers>> {
829 let variables = list_users::Variables {};
830
831 let response_data = self.execute_query(ListUsers, variables).await?;
832 Ok(response_data.users)
833 }
834
835 pub async fn create_user(
836 &self,
837 name: &str,
838 email: Option<&str>,
839 teams: Vec<create_user::UserCreateTeamWithRole>,
840 user_type: Option<create_user::UserType>,
841 generate_api_key: Option<bool>,
842 ) -> Result<create_user::CreateUserCreateUser> {
843 let variables = create_user::Variables {
844 input: create_user::UserCreate {
845 name: name.to_string(),
846 email: email.map(|e| e.to_string()),
847 teams,
848 user_type: user_type.unwrap_or(create_user::UserType::HUMAN),
849 generate_api_key,
850 },
851 };
852
853 let response_data = self.execute_query(CreateUser, variables).await?;
854 Ok(response_data.create_user)
855 }
856
857 pub async fn delete_user(&self, user: &str) -> Result<delete_user::DeleteUserDeleteUser> {
858 let variables = delete_user::Variables {
859 user: user.to_string(),
860 };
861
862 let response_data = self.execute_query(DeleteUser, variables).await?;
863 Ok(response_data.delete_user)
864 }
865
866 pub async fn add_team_member(
867 &self,
868 user: &str,
869 team: &str,
870 role: &str,
871 ) -> Result<add_team_member::AddTeamMemberSetTeamMember> {
872 let variables = add_team_member::Variables {
873 input: add_team_member::TeamMemberSet {
874 user: user.to_string(),
875 team: team.to_string(),
876 role: role.to_string(),
877 },
878 };
879
880 let response_data = self.execute_query(AddTeamMember, variables).await?;
881 Ok(response_data.set_team_member)
882 }
883
884 pub async fn remove_team_member(
885 &self,
886 user: &str,
887 team: &str,
888 ) -> Result<remove_team_member::RemoveTeamMemberRemoveTeamMember> {
889 let variables = remove_team_member::Variables {
890 input: remove_team_member::TeamMemberRemove {
891 user: user.to_string(),
892 team: team.to_string(),
893 },
894 };
895
896 let response_data = self.execute_query(RemoveTeamMember, variables).await?;
897 Ok(response_data.remove_team_member)
898 }
899
900 pub async fn get_recipe(
901 &self,
902 project: String,
903 id_or_key: String,
904 ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
905 let variables = get_recipe::Variables { project, id_or_key };
906
907 let response_data = self.execute_query(GetRecipe, variables).await?;
908 Ok(response_data.custom_recipe)
909 }
910
911 pub async fn get_grader(
912 &self,
913 id_or_key: &str,
914 project: &str,
915 ) -> Result<get_grader::GetGraderGrader> {
916 let variables = get_grader::Variables {
917 id: id_or_key.to_string(),
918 project: project.to_string(),
919 };
920
921 let response_data = self.execute_query(GetGrader, variables).await?;
922 Ok(response_data.grader)
923 }
924
925 pub async fn get_dataset(
926 &self,
927 id_or_key: &str,
928 project: &str,
929 ) -> Result<Option<get_dataset::GetDatasetDataset>> {
930 let variables = get_dataset::Variables {
931 id_or_key: id_or_key.to_string(),
932 project: project.to_string(),
933 };
934
935 let response_data = self.execute_query(GetDataset, variables).await?;
936 Ok(response_data.dataset)
937 }
938
939 pub async fn get_model_config(
940 &self,
941 id_or_key: &str,
942 ) -> Result<Option<get_model_config::GetModelConfigModel>> {
943 let variables = get_model_config::Variables {
944 id_or_key: id_or_key.to_string(),
945 };
946
947 let response_data = self.execute_query(GetModelConfig, variables).await?;
948 Ok(response_data.model)
949 }
950
951 pub fn base_url(&self) -> &Url {
952 &self.rest_base_url
953 }
954
955 pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
958 let file_size = data.len() as u64;
959
960 let chunk_size = if file_size < 5 * 1024 * 1024 {
962 file_size.max(1)
964 } else if file_size < 500 * 1024 * 1024 {
965 5 * 1024 * 1024
966 } else if file_size < 10 * 1024 * 1024 * 1024 {
967 10 * 1024 * 1024
968 } else {
969 100 * 1024 * 1024
970 };
971
972 let total_parts = file_size.div_ceil(chunk_size).max(1);
973
974 let session_id = self
976 .init_chunked_upload_with_content_type(total_parts, content_type)
977 .await?;
978
979 for part_number in 1..=total_parts {
981 let start = ((part_number - 1) * chunk_size) as usize;
982 let end = (part_number * chunk_size).min(file_size) as usize;
983 let chunk = data[start..end].to_vec();
984
985 if let Err(e) = self
986 .upload_part_simple(&session_id, part_number, chunk)
987 .await
988 {
989 let _ = self.abort_chunked_upload(&session_id).await;
990 return Err(e);
991 }
992 }
993
994 Ok(session_id)
995 }
996
997 pub async fn init_chunked_upload_with_content_type(
999 &self,
1000 total_parts: u64,
1001 content_type: &str,
1002 ) -> Result<String> {
1003 let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
1004
1005 let request = InitChunkedUploadRequest {
1006 content_type: content_type.to_string(),
1007 metadata: None,
1008 total_parts_count: total_parts,
1009 };
1010
1011 let response = self
1012 .client
1013 .post(url)
1014 .bearer_auth(&self.auth_token)
1015 .json(&request)
1016 .send()
1017 .await?;
1018
1019 if !response.status().is_success() {
1020 return Err(AdaptiveError::ChunkedUploadInitFailed {
1021 status: response.status().to_string(),
1022 body: response.text().await.unwrap_or_default(),
1023 });
1024 }
1025
1026 let init_response: InitChunkedUploadResponse = response.json().await?;
1027 Ok(init_response.session_id)
1028 }
1029
1030 pub async fn upload_part_simple(
1032 &self,
1033 session_id: &str,
1034 part_number: u64,
1035 data: Vec<u8>,
1036 ) -> Result<()> {
1037 let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
1038
1039 let response = self
1040 .client
1041 .post(url)
1042 .bearer_auth(&self.auth_token)
1043 .query(&[
1044 ("session_id", session_id),
1045 ("part_number", &part_number.to_string()),
1046 ])
1047 .header("Content-Type", "application/octet-stream")
1048 .body(data)
1049 .send()
1050 .await?;
1051
1052 if !response.status().is_success() {
1053 return Err(AdaptiveError::ChunkedUploadPartFailed {
1054 part_number,
1055 status: response.status().to_string(),
1056 body: response.text().await.unwrap_or_default(),
1057 });
1058 }
1059
1060 Ok(())
1061 }
1062
1063 async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
1064 let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
1065
1066 let request = InitChunkedUploadRequest {
1067 content_type: "application/jsonl".to_string(),
1068 metadata: None,
1069 total_parts_count: total_parts,
1070 };
1071
1072 let response = self
1073 .client
1074 .post(url)
1075 .bearer_auth(&self.auth_token)
1076 .json(&request)
1077 .send()
1078 .await?;
1079
1080 if !response.status().is_success() {
1081 return Err(AdaptiveError::ChunkedUploadInitFailed {
1082 status: response.status().to_string(),
1083 body: response.text().await.unwrap_or_default(),
1084 });
1085 }
1086
1087 let init_response: InitChunkedUploadResponse = response.json().await?;
1088 Ok(init_response.session_id)
1089 }
1090
1091 async fn upload_part(
1092 &self,
1093 session_id: &str,
1094 part_number: u64,
1095 data: Vec<u8>,
1096 progress_tx: mpsc::Sender<u64>,
1097 ) -> Result<()> {
1098 const SUB_CHUNK_SIZE: usize = 64 * 1024;
1099
1100 let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
1101
1102 let chunks: Vec<Vec<u8>> = data
1103 .chunks(SUB_CHUNK_SIZE)
1104 .map(|chunk| chunk.to_vec())
1105 .collect();
1106
1107 let stream = futures::stream::iter(chunks).map(move |chunk| {
1108 let len = chunk.len() as u64;
1109 let tx = progress_tx.clone();
1110 let _ = tx.try_send(len);
1111 Ok::<_, std::io::Error>(chunk)
1112 });
1113
1114 let body = reqwest::Body::wrap_stream(stream);
1115
1116 let response = self
1117 .client
1118 .post(url)
1119 .bearer_auth(&self.auth_token)
1120 .query(&[
1121 ("session_id", session_id),
1122 ("part_number", &part_number.to_string()),
1123 ])
1124 .header("Content-Type", "application/octet-stream")
1125 .body(body)
1126 .send()
1127 .await?;
1128
1129 if !response.status().is_success() {
1130 return Err(AdaptiveError::ChunkedUploadPartFailed {
1131 part_number,
1132 status: response.status().to_string(),
1133 body: response.text().await.unwrap_or_default(),
1134 });
1135 }
1136
1137 Ok(())
1138 }
1139
1140 pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
1142 let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
1143
1144 let request = AbortChunkedUploadRequest {
1145 session_id: session_id.to_string(),
1146 };
1147
1148 let _ = self
1149 .client
1150 .delete(url)
1151 .bearer_auth(&self.auth_token)
1152 .json(&request)
1153 .send()
1154 .await;
1155
1156 Ok(())
1157 }
1158
1159 async fn create_dataset_from_multipart(
1160 &self,
1161 project: &str,
1162 name: &str,
1163 key: &str,
1164 session_id: &str,
1165 ) -> Result<
1166 create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
1167 > {
1168 let variables = create_dataset_from_multipart::Variables {
1169 input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
1170 project: project.to_string(),
1171 name: name.to_string(),
1172 key: Some(key.to_string()),
1173 source: None,
1174 upload_session_id: session_id.to_string(),
1175 },
1176 };
1177
1178 let response_data = self
1179 .execute_query(CreateDatasetFromMultipart, variables)
1180 .await?;
1181 Ok(response_data.create_dataset_from_multipart_upload)
1182 }
1183
1184 pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
1185 &'a self,
1186 project: &'a str,
1187 name: &'a str,
1188 key: &'a str,
1189 dataset: P,
1190 ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
1191 let file_size = std::fs::metadata(dataset.as_ref())?.len();
1192
1193 let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;
1194
1195 let stream = async_stream::try_stream! {
1196 yield UploadEvent::Progress(ChunkedUploadProgress {
1197 bytes_uploaded: 0,
1198 total_bytes: file_size,
1199 });
1200
1201 let session_id = self.init_chunked_upload(total_parts).await?;
1202
1203 let mut file = File::open(dataset.as_ref())?;
1204 let mut buffer = vec![0u8; chunk_size as usize];
1205 let mut bytes_uploaded = 0u64;
1206
1207 let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);
1208
1209 for part_number in 1..=total_parts {
1210 let bytes_read = file.read(&mut buffer)?;
1211 let chunk_data = buffer[..bytes_read].to_vec();
1212
1213 let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
1214 tokio::pin!(upload_fut);
1215
1216 let upload_result: Result<()> = loop {
1217 tokio::select! {
1218 biased;
1219 result = &mut upload_fut => {
1220 break result;
1221 }
1222 Some(bytes) = progress_rx.recv() => {
1223 bytes_uploaded += bytes;
1224 yield UploadEvent::Progress(ChunkedUploadProgress {
1225 bytes_uploaded,
1226 total_bytes: file_size,
1227 });
1228 }
1229 }
1230 };
1231
1232 if let Err(e) = upload_result {
1233 let _ = self.abort_chunked_upload(&session_id).await;
1234 Err(e)?;
1235 }
1236 }
1237
1238 let create_result = self
1239 .create_dataset_from_multipart(project, name, key, &session_id)
1240 .await;
1241
1242 match create_result {
1243 Ok(response) => {
1244 yield UploadEvent::Complete(response);
1245 }
1246 Err(e) => {
1247 let _ = self.abort_chunked_upload(&session_id).await;
1248 Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
1249 }
1250 }
1251 };
1252
1253 Ok(Box::pin(stream))
1254 }
1255
1256 pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
1259 use tokio::io::AsyncWriteExt;
1260
1261 let full_url = if url.starts_with("http://") || url.starts_with("https://") {
1262 Url::parse(url)?
1263 } else {
1264 self.rest_base_url.join(url)?
1265 };
1266
1267 let response = self
1268 .client
1269 .get(full_url)
1270 .bearer_auth(&self.auth_token)
1271 .send()
1272 .await?;
1273
1274 if !response.status().is_success() {
1275 return Err(AdaptiveError::HttpError(
1276 response.error_for_status().unwrap_err(),
1277 ));
1278 }
1279
1280 let mut file = tokio::fs::File::create(dest_path).await?;
1281 let mut stream = response.bytes_stream();
1282
1283 while let Some(chunk) = stream.next().await {
1284 let chunk = chunk?;
1285 file.write_all(&chunk).await?;
1286 }
1287
1288 file.flush().await?;
1289 Ok(())
1290 }
1291}