1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
/// Number of bytes in one megabyte as used by this module (1024 * 1024).
const MEGABYTE: u64 = 1024 * 1024;

/// Smallest file size, in bytes, that `calculate_upload_parts` accepts.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;

/// Largest chunk size, in bytes, that `calculate_upload_parts` will return.
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;

/// Largest number of parts that `calculate_upload_parts` will return.
const MAX_PARTS_COUNT: u64 = 10_000;

// File-size thresholds that select the chunk-size tier.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
28#[derive(Error, Debug)]
29pub enum AdaptiveError {
30 #[error("HTTP error: {0}")]
31 HttpError(#[from] reqwest::Error),
32
33 #[error("JSON error: {0}")]
34 JsonError(#[from] serde_json::Error),
35
36 #[error("IO error: {0}")]
37 IoError(#[from] std::io::Error),
38
39 #[error("URL parse error: {0}")]
40 UrlParseError(#[from] url::ParseError),
41
42 #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
43 FileTooSmall { size: u64, min_size: u64 },
44
45 #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
46 FileTooLarge { size: u64, max_size: u64 },
47
48 #[error("GraphQL errors: {0:?}")]
49 GraphQLErrors(Vec<graphql_client::Error>),
50
51 #[error("No data returned from GraphQL")]
52 NoGraphQLData,
53
54 #[error("Job not found: {0}")]
55 JobNotFound(Uuid),
56
57 #[error("Failed to initialize chunked upload: {status} - {body}")]
58 ChunkedUploadInitFailed { status: String, body: String },
59
60 #[error("Failed to upload part {part_number}: {status} - {body}")]
61 ChunkedUploadPartFailed {
62 part_number: u64,
63 status: String,
64 body: String,
65 },
66
67 #[error("Failed to create dataset: {0}")]
68 DatasetCreationFailed(String),
69
70 #[error("HTTP status error: {status} - {body}")]
71 HttpStatusError { status: String, body: String },
72
73 #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
74 JsonParseError { error: String, body: String },
75}
76
77type Result<T> = std::result::Result<T, AdaptiveError>;
78
/// Byte counters describing how far a chunked upload has progressed.
#[derive(Debug, Clone, Default)]
pub struct ChunkedUploadProgress {
    /// Bytes reported as uploaded (presumably cumulative — confirm against the producer).
    pub bytes_uploaded: u64,
    /// Total size of the upload in bytes.
    pub total_bytes: u64,
}
84
85#[derive(Debug)]
86pub enum UploadEvent {
87 Progress(ChunkedUploadProgress),
88 Complete(
89 create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
90 ),
91}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94 if file_size < MIN_CHUNK_SIZE_BYTES {
95 return Err(AdaptiveError::FileTooSmall {
96 size: file_size,
97 min_size: MIN_CHUNK_SIZE_BYTES,
98 });
99 }
100
101 let mut chunk_size = if file_size < SIZE_500MB {
102 5 * MEGABYTE
103 } else if file_size < SIZE_10GB {
104 10 * MEGABYTE
105 } else if file_size < SIZE_50GB {
106 50 * MEGABYTE
107 } else {
108 100 * MEGABYTE
109 };
110
111 let mut total_parts = file_size.div_ceil(chunk_size);
112
113 if total_parts > MAX_PARTS_COUNT {
114 chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116 if chunk_size > MAX_CHUNK_SIZE_BYTES {
117 let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118 return Err(AdaptiveError::FileTooLarge {
119 size: file_size,
120 max_size: max_file_size,
121 });
122 }
123
124 total_parts = file_size.div_ceil(chunk_size);
125 }
126
127 Ok((total_parts, chunk_size))
128}
129
130type IdOrKey = String;
131#[allow(clippy::upper_case_acronyms)]
132type UUID = Uuid;
133type JsObject = Map<String, Value>;
134type InputDatetime = String;
135#[allow(clippy::upper_case_acronyms)]
136type JSON = Value;
137#[allow(clippy::upper_case_acronyms)]
138type JSONObject = Map<String, Value>;
139type KeyInput = String;
140
141#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
142pub struct Timestamp(pub SystemTime);
143
144impl<'de> serde::Deserialize<'de> for Timestamp {
145 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
146 where
147 D: serde::Deserializer<'de>,
148 {
149 let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
150 Ok(Timestamp(system_time))
151 }
152}
153
/// Number of items requested per page by `AdaptiveClient::list_jobs_page`.
const PAGE_SIZE: usize = 20;
155
156#[derive(Debug, PartialEq, Serialize, Deserialize)]
157pub struct Upload(usize);
158
159#[derive(GraphQLQuery)]
160#[graphql(
161 schema_path = "schema.gql",
162 query_path = "src/graphql/list.graphql",
163 response_derives = "Debug, Clone"
164)]
165pub struct GetCustomRecipes;
166
167#[derive(GraphQLQuery)]
168#[graphql(
169 schema_path = "schema.gql",
170 query_path = "src/graphql/job.graphql",
171 response_derives = "Debug, Clone"
172)]
173pub struct GetJob;
174
175#[derive(GraphQLQuery)]
176#[graphql(
177 schema_path = "schema.gql",
178 query_path = "src/graphql/jobs.graphql",
179 response_derives = "Debug, Clone"
180)]
181pub struct ListJobs;
182
183#[derive(GraphQLQuery)]
184#[graphql(
185 schema_path = "schema.gql",
186 query_path = "src/graphql/cancel.graphql",
187 response_derives = "Debug, Clone"
188)]
189pub struct CancelJob;
190
191#[derive(GraphQLQuery)]
192#[graphql(
193 schema_path = "schema.gql",
194 query_path = "src/graphql/models.graphql",
195 response_derives = "Debug, Clone"
196)]
197pub struct ListModels;
198
199#[derive(GraphQLQuery)]
200#[graphql(
201 schema_path = "schema.gql",
202 query_path = "src/graphql/all_models.graphql",
203 response_derives = "Debug, Clone"
204)]
205pub struct ListAllModels;
206
207impl Display for get_job::JobStatus {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 match self {
210 get_job::JobStatus::PENDING => write!(f, "Pending"),
211 get_job::JobStatus::RUNNING => write!(f, "Running"),
212 get_job::JobStatus::COMPLETED => write!(f, "Completed"),
213 get_job::JobStatus::FAILED => write!(f, "Failed"),
214 get_job::JobStatus::CANCELED => write!(f, "Canceled"),
215 get_job::JobStatus::Other(_) => write!(f, "Unknown"),
216 }
217 }
218}
219
220#[derive(GraphQLQuery)]
221#[graphql(
222 schema_path = "schema.gql",
223 query_path = "src/graphql/publish.graphql",
224 response_derives = "Debug, Clone"
225)]
226pub struct PublishCustomRecipe;
227
228#[derive(GraphQLQuery)]
229#[graphql(
230 schema_path = "schema.gql",
231 query_path = "src/graphql/update_recipe.graphql",
232 response_derives = "Debug, Clone"
233)]
234pub struct UpdateCustomRecipe;
235
236#[derive(GraphQLQuery)]
237#[graphql(
238 schema_path = "schema.gql",
239 query_path = "src/graphql/upload_dataset.graphql",
240 response_derives = "Debug, Clone"
241)]
242pub struct UploadDataset;
243
244#[derive(GraphQLQuery)]
245#[graphql(
246 schema_path = "schema.gql",
247 query_path = "src/graphql/create_dataset_from_multipart.graphql",
248 response_derives = "Debug, Clone"
249)]
250pub struct CreateDatasetFromMultipart;
251
252#[derive(GraphQLQuery)]
253#[graphql(
254 schema_path = "schema.gql",
255 query_path = "src/graphql/run.graphql",
256 response_derives = "Debug, Clone"
257)]
258pub struct RunCustomRecipe;
259
260#[derive(GraphQLQuery)]
261#[graphql(
262 schema_path = "schema.gql",
263 query_path = "src/graphql/projects.graphql",
264 response_derives = "Debug, Clone"
265)]
266pub struct ListProjects;
267
268#[derive(GraphQLQuery)]
269#[graphql(
270 schema_path = "schema.gql",
271 query_path = "src/graphql/pools.graphql",
272 response_derives = "Debug, Clone"
273)]
274pub struct ListComputePools;
275
276#[derive(GraphQLQuery)]
277#[graphql(
278 schema_path = "schema.gql",
279 query_path = "src/graphql/recipe.graphql",
280 response_derives = "Debug, Clone"
281)]
282pub struct GetRecipe;
283
284#[derive(GraphQLQuery)]
285#[graphql(
286 schema_path = "schema.gql",
287 query_path = "src/graphql/grader.graphql",
288 response_derives = "Debug, Clone, Serialize"
289)]
290pub struct GetGrader;
291
292#[derive(GraphQLQuery)]
293#[graphql(
294 schema_path = "schema.gql",
295 query_path = "src/graphql/dataset.graphql",
296 response_derives = "Debug, Clone"
297)]
298pub struct GetDataset;
299
300#[derive(GraphQLQuery)]
301#[graphql(
302 schema_path = "schema.gql",
303 query_path = "src/graphql/model_config.graphql",
304 response_derives = "Debug, Clone, Serialize"
305)]
306pub struct GetModelConfig;
307
308#[derive(GraphQLQuery)]
309#[graphql(
310 schema_path = "schema.gql",
311 query_path = "src/graphql/artifact.graphql",
312 response_derives = "Debug, Clone"
313)]
314pub struct GetArtifact;
315
316#[derive(GraphQLQuery)]
317#[graphql(
318 schema_path = "schema.gql",
319 query_path = "src/graphql/job_progress.graphql",
320 response_derives = "Debug, Clone"
321)]
322pub struct UpdateJobProgress;
323
324#[derive(GraphQLQuery)]
325#[graphql(
326 schema_path = "schema.gql",
327 query_path = "src/graphql/roles.graphql",
328 response_derives = "Debug, Clone"
329)]
330pub struct ListRoles;
331
332#[derive(GraphQLQuery)]
333#[graphql(
334 schema_path = "schema.gql",
335 query_path = "src/graphql/create_role.graphql",
336 response_derives = "Debug, Clone"
337)]
338pub struct CreateRole;
339
340#[derive(GraphQLQuery)]
341#[graphql(
342 schema_path = "schema.gql",
343 query_path = "src/graphql/update_role.graphql",
344 response_derives = "Debug, Clone"
345)]
346pub struct UpdateRole;
347
348#[derive(GraphQLQuery)]
349#[graphql(
350 schema_path = "schema.gql",
351 query_path = "src/graphql/teams.graphql",
352 response_derives = "Debug, Clone"
353)]
354pub struct ListTeams;
355
356#[derive(GraphQLQuery)]
357#[graphql(
358 schema_path = "schema.gql",
359 query_path = "src/graphql/create_team.graphql",
360 response_derives = "Debug, Clone"
361)]
362pub struct CreateTeam;
363
364#[derive(GraphQLQuery)]
365#[graphql(
366 schema_path = "schema.gql",
367 query_path = "src/graphql/users.graphql",
368 response_derives = "Debug, Clone"
369)]
370pub struct ListUsers;
371
372#[derive(GraphQLQuery)]
373#[graphql(
374 schema_path = "schema.gql",
375 query_path = "src/graphql/create_user.graphql",
376 response_derives = "Debug, Clone"
377)]
378pub struct CreateUser;
379
380#[derive(GraphQLQuery)]
381#[graphql(
382 schema_path = "schema.gql",
383 query_path = "src/graphql/delete_user.graphql",
384 response_derives = "Debug, Clone"
385)]
386pub struct DeleteUser;
387
388#[derive(GraphQLQuery)]
389#[graphql(
390 schema_path = "schema.gql",
391 query_path = "src/graphql/add_team_member.graphql",
392 response_derives = "Debug, Clone"
393)]
394pub struct AddTeamMember;
395
396#[derive(GraphQLQuery)]
397#[graphql(
398 schema_path = "schema.gql",
399 query_path = "src/graphql/remove_team_member.graphql",
400 response_derives = "Debug, Clone"
401)]
402pub struct RemoveTeamMember;
403
// REST routes, joined onto the client's base URL.

/// Route that opens a chunked-upload session.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
/// Route that receives one part of a chunked upload.
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
/// Route that aborts a chunked-upload session.
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
407
408#[derive(Clone)]
409pub struct AdaptiveClient {
410 client: Client,
411 graphql_url: Url,
412 rest_base_url: Url,
413 auth_token: String,
414}
415
416impl AdaptiveClient {
417 pub fn new(api_base_url: Url, auth_token: String) -> Self {
418 let graphql_url = api_base_url
419 .join("graphql")
420 .expect("Failed to append graphql to base URL");
421
422 let client = Client::builder()
423 .user_agent(format!(
424 "adaptive-client-rust/{}",
425 env!("CARGO_PKG_VERSION")
426 ))
427 .build()
428 .expect("Failed to build HTTP client");
429
430 Self {
431 client,
432 graphql_url,
433 rest_base_url: api_base_url,
434 auth_token,
435 }
436 }
437
438 async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
439 where
440 T: GraphQLQuery,
441 T::Variables: serde::Serialize,
442 T::ResponseData: DeserializeOwned,
443 {
444 let request_body = T::build_query(variables);
445
446 let response = self
447 .client
448 .post(self.graphql_url.clone())
449 .bearer_auth(&self.auth_token)
450 .json(&request_body)
451 .send()
452 .await?;
453
454 let status = response.status();
455 let response_text = response.text().await?;
456
457 if !status.is_success() {
458 return Err(AdaptiveError::HttpStatusError {
459 status: status.to_string(),
460 body: response_text,
461 });
462 }
463
464 let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
465 .map_err(|e| AdaptiveError::JsonParseError {
466 error: e.to_string(),
467 body: response_text.chars().take(500).collect(),
468 })?;
469
470 match response_body.data {
471 Some(data) => Ok(data),
472 None => {
473 if let Some(errors) = response_body.errors {
474 return Err(AdaptiveError::GraphQLErrors(errors));
475 }
476 Err(AdaptiveError::NoGraphQLData)
477 }
478 }
479 }
480
481 pub async fn list_recipes(
482 &self,
483 project: &str,
484 ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
485 let variables = get_custom_recipes::Variables {
486 project: IdOrKey::from(project),
487 };
488
489 let response_data = self.execute_query(GetCustomRecipes, variables).await?;
490 Ok(response_data.custom_recipes)
491 }
492
493 pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
494 let variables = get_job::Variables { id: job_id };
495
496 let response_data = self.execute_query(GetJob, variables).await?;
497
498 match response_data.job {
499 Some(job) => Ok(job),
500 None => Err(AdaptiveError::JobNotFound(job_id)),
501 }
502 }
503
504 pub async fn upload_dataset<P: AsRef<Path>>(
505 &self,
506 project: &str,
507 name: &str,
508 dataset: P,
509 ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
510 let variables = upload_dataset::Variables {
511 project: IdOrKey::from(project),
512 file: Upload(0),
513 name: Some(name.to_string()),
514 };
515
516 let operations = UploadDataset::build_query(variables);
517 let operations = serde_json::to_string(&operations)?;
518
519 let file_map = r#"{ "0": ["variables.file"] }"#;
520
521 let dataset_file = reqwest::multipart::Part::file(dataset).await?;
522
523 let form = reqwest::multipart::Form::new()
524 .text("operations", operations)
525 .text("map", file_map)
526 .part("0", dataset_file);
527
528 let response = self
529 .client
530 .post(self.graphql_url.clone())
531 .bearer_auth(&self.auth_token)
532 .multipart(form)
533 .send()
534 .await?;
535
536 let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
537 response.json().await?;
538
539 match response.data {
540 Some(data) => Ok(data.create_dataset),
541 None => {
542 if let Some(errors) = response.errors {
543 return Err(AdaptiveError::GraphQLErrors(errors));
544 }
545 Err(AdaptiveError::NoGraphQLData)
546 }
547 }
548 }
549
550 pub async fn publish_recipe<P: AsRef<Path>>(
551 &self,
552 project: &str,
553 name: &str,
554 key: &str,
555 recipe: P,
556 ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
557 let variables = publish_custom_recipe::Variables {
558 project: IdOrKey::from(project),
559 file: Upload(0),
560 name: Some(name.to_string()),
561 key: Some(key.to_string()),
562 };
563
564 let operations = PublishCustomRecipe::build_query(variables);
565 let operations = serde_json::to_string(&operations)?;
566
567 let file_map = r#"{ "0": ["variables.file"] }"#;
568
569 let recipe_file = reqwest::multipart::Part::file(recipe).await?;
570
571 let form = reqwest::multipart::Form::new()
572 .text("operations", operations)
573 .text("map", file_map)
574 .part("0", recipe_file);
575
576 let response = self
577 .client
578 .post(self.graphql_url.clone())
579 .bearer_auth(&self.auth_token)
580 .multipart(form)
581 .send()
582 .await?;
583 let response: Response<
584 <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
585 > = response.json().await?;
586
587 match response.data {
588 Some(data) => Ok(data.create_custom_recipe),
589 None => {
590 if let Some(errors) = response.errors {
591 return Err(AdaptiveError::GraphQLErrors(errors));
592 }
593 Err(AdaptiveError::NoGraphQLData)
594 }
595 }
596 }
597
598 pub async fn update_recipe<P: AsRef<Path>>(
599 &self,
600 project: &str,
601 id: &str,
602 name: Option<String>,
603 description: Option<String>,
604 labels: Option<Vec<update_custom_recipe::LabelInput>>,
605 recipe_file: Option<P>,
606 ) -> Result<update_custom_recipe::UpdateCustomRecipeUpdateCustomRecipe> {
607 let input = update_custom_recipe::UpdateRecipeInput {
608 name,
609 description,
610 labels,
611 };
612
613 match recipe_file {
614 Some(file_path) => {
615 let variables = update_custom_recipe::Variables {
616 project: IdOrKey::from(project),
617 id: IdOrKey::from(id),
618 input,
619 file: Some(Upload(0)),
620 };
621
622 let operations = UpdateCustomRecipe::build_query(variables);
623 let operations = serde_json::to_string(&operations)?;
624
625 let file_map = r#"{ "0": ["variables.file"] }"#;
626
627 let recipe_file = reqwest::multipart::Part::file(file_path).await?;
628
629 let form = reqwest::multipart::Form::new()
630 .text("operations", operations)
631 .text("map", file_map)
632 .part("0", recipe_file);
633
634 let response = self
635 .client
636 .post(self.graphql_url.clone())
637 .bearer_auth(&self.auth_token)
638 .multipart(form)
639 .send()
640 .await?;
641 let response: Response<
642 <UpdateCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
643 > = response.json().await?;
644
645 match response.data {
646 Some(data) => Ok(data.update_custom_recipe),
647 None => {
648 if let Some(errors) = response.errors {
649 return Err(AdaptiveError::GraphQLErrors(errors));
650 }
651 Err(AdaptiveError::NoGraphQLData)
652 }
653 }
654 }
655 None => {
656 let variables = update_custom_recipe::Variables {
657 project: IdOrKey::from(project),
658 id: IdOrKey::from(id),
659 input,
660 file: None,
661 };
662
663 let response_data = self.execute_query(UpdateCustomRecipe, variables).await?;
664 Ok(response_data.update_custom_recipe)
665 }
666 }
667 }
668
669 pub async fn run_recipe(
670 &self,
671 project: &str,
672 recipe_id: &str,
673 parameters: Map<String, Value>,
674 name: Option<String>,
675 compute_pool: Option<String>,
676 num_gpus: u32,
677 use_experimental_runner: bool,
678 ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
679 let variables = run_custom_recipe::Variables {
680 input: run_custom_recipe::JobInput {
681 recipe: recipe_id.to_string(),
682 project: project.to_string(),
683 args: parameters,
684 name,
685 compute_pool,
686 num_gpus: num_gpus as i64,
687 use_experimental_runner,
688 image_tag: None,
689 max_cpu: None,
690 max_ram_gb: None,
691 max_duration_secs: None,
692 resume_artifact_id: None,
693 },
694 };
695
696 let response_data = self.execute_query(RunCustomRecipe, variables).await?;
697 Ok(response_data.create_job)
698 }
699
700 pub async fn list_jobs(
701 &self,
702 project: Option<String>,
703 ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
704 let mut jobs = Vec::new();
705 let mut page = self.list_jobs_page(project.clone(), None).await?;
706 jobs.extend(page.nodes.iter().cloned());
707 while page.page_info.has_next_page {
708 page = self
709 .list_jobs_page(project.clone(), page.page_info.end_cursor)
710 .await?;
711 jobs.extend(page.nodes.iter().cloned());
712 }
713 Ok(jobs)
714 }
715
716 async fn list_jobs_page(
717 &self,
718 project: Option<String>,
719 after: Option<String>,
720 ) -> Result<list_jobs::ListJobsJobs> {
721 let variables = list_jobs::Variables {
722 filter: Some(list_jobs::ListJobsFilterInput {
723 project,
724 kind: Some(vec![list_jobs::JobKind::CUSTOM]),
725 status: Some(vec![
726 list_jobs::JobStatus::RUNNING,
727 list_jobs::JobStatus::PENDING,
728 ]),
729 timerange: None,
730 custom_recipes: None,
731 artifacts: None,
732 created_by: None,
733 name: None,
734 advanced_filter: Box::new(None),
735 }),
736 cursor: Some(list_jobs::CursorPageInput {
737 first: Some(PAGE_SIZE as i64),
738 after,
739 before: None,
740 last: None,
741 offset: None,
742 }),
743 };
744
745 let response_data = self.execute_query(ListJobs, variables).await?;
746 Ok(response_data.jobs)
747 }
748
749 pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
750 let variables = cancel_job::Variables { job_id };
751
752 let response_data = self.execute_query(CancelJob, variables).await?;
753 Ok(response_data.cancel_job)
754 }
755
756 pub async fn update_job_progress(
757 &self,
758 job_id: Uuid,
759 event: update_job_progress::JobProgressEventInput,
760 ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
761 let variables = update_job_progress::Variables { job_id, event };
762
763 let response_data = self.execute_query(UpdateJobProgress, variables).await?;
764 Ok(response_data.update_job_progress)
765 }
766
767 pub async fn list_models(
768 &self,
769 project: String,
770 ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
771 let variables = list_models::Variables { project };
772
773 let response_data = self.execute_query(ListModels, variables).await?;
774 Ok(response_data
775 .project
776 .map(|project| project.model_services)
777 .unwrap_or(Vec::new()))
778 }
779
780 pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
781 let variables = list_all_models::Variables {};
782
783 let response_data = self.execute_query(ListAllModels, variables).await?;
784 Ok(response_data.models)
785 }
786
787 pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
788 let variables = list_projects::Variables {};
789
790 let response_data = self.execute_query(ListProjects, variables).await?;
791 Ok(response_data.projects)
792 }
793
794 pub async fn list_pools(
795 &self,
796 ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
797 let variables = list_compute_pools::Variables {};
798
799 let response_data = self.execute_query(ListComputePools, variables).await?;
800 Ok(response_data.compute_pools)
801 }
802
803 pub async fn list_roles(&self) -> Result<Vec<list_roles::ListRolesRoles>> {
804 let variables = list_roles::Variables {};
805
806 let response_data = self.execute_query(ListRoles, variables).await?;
807 Ok(response_data.roles)
808 }
809
810 pub async fn create_role(
811 &self,
812 name: &str,
813 key: Option<&str>,
814 permissions: Vec<String>,
815 ) -> Result<create_role::CreateRoleCreateRole> {
816 let variables = create_role::Variables {
817 input: create_role::RoleCreate {
818 name: name.to_string(),
819 key: key.map(|k| k.to_string()),
820 permissions,
821 },
822 };
823
824 let response_data = self.execute_query(CreateRole, variables).await?;
825 Ok(response_data.create_role)
826 }
827
828 pub async fn update_role(
829 &self,
830 role: &str,
831 name: Option<&str>,
832 permissions: Option<Vec<String>>,
833 ) -> Result<update_role::UpdateRoleUpdateRole> {
834 let variables = update_role::Variables {
835 input: update_role::RoleUpdate {
836 role: role.to_string(),
837 name: name.map(|n| n.to_string()),
838 permissions,
839 },
840 };
841
842 let response_data = self.execute_query(UpdateRole, variables).await?;
843 Ok(response_data.update_role)
844 }
845
846 pub async fn list_teams(&self) -> Result<Vec<list_teams::ListTeamsTeams>> {
847 let variables = list_teams::Variables {};
848
849 let response_data = self.execute_query(ListTeams, variables).await?;
850 Ok(response_data.teams)
851 }
852
853 pub async fn create_team(
854 &self,
855 name: &str,
856 key: Option<&str>,
857 ) -> Result<create_team::CreateTeamCreateTeam> {
858 let variables = create_team::Variables {
859 input: create_team::TeamCreate {
860 name: name.to_string(),
861 key: key.map(|k| k.to_string()),
862 },
863 };
864
865 let response_data = self.execute_query(CreateTeam, variables).await?;
866 Ok(response_data.create_team)
867 }
868
869 pub async fn list_users(&self) -> Result<Vec<list_users::ListUsersUsers>> {
870 let variables = list_users::Variables {};
871
872 let response_data = self.execute_query(ListUsers, variables).await?;
873 Ok(response_data.users)
874 }
875
876 pub async fn create_user(
877 &self,
878 name: &str,
879 email: Option<&str>,
880 teams: Vec<create_user::UserCreateTeamWithRole>,
881 user_type: Option<create_user::UserType>,
882 generate_api_key: Option<bool>,
883 ) -> Result<create_user::CreateUserCreateUser> {
884 let variables = create_user::Variables {
885 input: create_user::UserCreate {
886 name: name.to_string(),
887 email: email.map(|e| e.to_string()),
888 teams,
889 user_type: user_type.unwrap_or(create_user::UserType::HUMAN),
890 generate_api_key,
891 },
892 };
893
894 let response_data = self.execute_query(CreateUser, variables).await?;
895 Ok(response_data.create_user)
896 }
897
898 pub async fn delete_user(&self, user: &str) -> Result<delete_user::DeleteUserDeleteUser> {
899 let variables = delete_user::Variables {
900 user: user.to_string(),
901 };
902
903 let response_data = self.execute_query(DeleteUser, variables).await?;
904 Ok(response_data.delete_user)
905 }
906
907 pub async fn add_team_member(
908 &self,
909 user: &str,
910 team: &str,
911 role: &str,
912 ) -> Result<add_team_member::AddTeamMemberSetTeamMember> {
913 let variables = add_team_member::Variables {
914 input: add_team_member::TeamMemberSet {
915 user: user.to_string(),
916 team: team.to_string(),
917 role: role.to_string(),
918 },
919 };
920
921 let response_data = self.execute_query(AddTeamMember, variables).await?;
922 Ok(response_data.set_team_member)
923 }
924
925 pub async fn remove_team_member(
926 &self,
927 user: &str,
928 team: &str,
929 ) -> Result<remove_team_member::RemoveTeamMemberRemoveTeamMember> {
930 let variables = remove_team_member::Variables {
931 input: remove_team_member::TeamMemberRemove {
932 user: user.to_string(),
933 team: team.to_string(),
934 },
935 };
936
937 let response_data = self.execute_query(RemoveTeamMember, variables).await?;
938 Ok(response_data.remove_team_member)
939 }
940
941 pub async fn get_recipe(
942 &self,
943 project: String,
944 id_or_key: String,
945 ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
946 let variables = get_recipe::Variables { project, id_or_key };
947
948 let response_data = self.execute_query(GetRecipe, variables).await?;
949 Ok(response_data.custom_recipe)
950 }
951
952 pub async fn get_grader(
953 &self,
954 id_or_key: &str,
955 project: &str,
956 ) -> Result<get_grader::GetGraderGrader> {
957 let variables = get_grader::Variables {
958 id: id_or_key.to_string(),
959 project: project.to_string(),
960 };
961
962 let response_data = self.execute_query(GetGrader, variables).await?;
963 Ok(response_data.grader)
964 }
965
966 pub async fn get_dataset(
967 &self,
968 id_or_key: &str,
969 project: &str,
970 ) -> Result<Option<get_dataset::GetDatasetDataset>> {
971 let variables = get_dataset::Variables {
972 id_or_key: id_or_key.to_string(),
973 project: project.to_string(),
974 };
975
976 let response_data = self.execute_query(GetDataset, variables).await?;
977 Ok(response_data.dataset)
978 }
979
980 pub async fn get_model_config(
981 &self,
982 id_or_key: &str,
983 ) -> Result<Option<get_model_config::GetModelConfigModel>> {
984 let variables = get_model_config::Variables {
985 id_or_key: id_or_key.to_string(),
986 };
987
988 let response_data = self.execute_query(GetModelConfig, variables).await?;
989 Ok(response_data.model)
990 }
991
992 pub async fn get_artifact(
993 &self,
994 project: &str,
995 id: Uuid,
996 ) -> Result<Option<get_artifact::GetArtifactArtifact>> {
997 let variables = get_artifact::Variables {
998 project: project.to_string(),
999 id,
1000 };
1001
1002 let response_data = self.execute_query(GetArtifact, variables).await?;
1003 Ok(response_data.artifact)
1004 }
1005
1006 pub fn base_url(&self) -> &Url {
1007 &self.rest_base_url
1008 }
1009
1010 pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
1013 let file_size = data.len() as u64;
1014
1015 let chunk_size = if file_size < 5 * 1024 * 1024 {
1017 file_size.max(1)
1019 } else if file_size < 500 * 1024 * 1024 {
1020 5 * 1024 * 1024
1021 } else if file_size < 10 * 1024 * 1024 * 1024 {
1022 10 * 1024 * 1024
1023 } else {
1024 100 * 1024 * 1024
1025 };
1026
1027 let total_parts = file_size.div_ceil(chunk_size).max(1);
1028
1029 let session_id = self
1031 .init_chunked_upload_with_content_type(total_parts, content_type)
1032 .await?;
1033
1034 for part_number in 1..=total_parts {
1036 let start = ((part_number - 1) * chunk_size) as usize;
1037 let end = (part_number * chunk_size).min(file_size) as usize;
1038 let chunk = data[start..end].to_vec();
1039
1040 if let Err(e) = self
1041 .upload_part_simple(&session_id, part_number, chunk)
1042 .await
1043 {
1044 let _ = self.abort_chunked_upload(&session_id).await;
1045 return Err(e);
1046 }
1047 }
1048
1049 Ok(session_id)
1050 }
1051
1052 pub async fn init_chunked_upload_with_content_type(
1054 &self,
1055 total_parts: u64,
1056 content_type: &str,
1057 ) -> Result<String> {
1058 let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
1059
1060 let request = InitChunkedUploadRequest {
1061 content_type: content_type.to_string(),
1062 metadata: None,
1063 total_parts_count: total_parts,
1064 };
1065
1066 let response = self
1067 .client
1068 .post(url)
1069 .bearer_auth(&self.auth_token)
1070 .json(&request)
1071 .send()
1072 .await?;
1073
1074 if !response.status().is_success() {
1075 return Err(AdaptiveError::ChunkedUploadInitFailed {
1076 status: response.status().to_string(),
1077 body: response.text().await.unwrap_or_default(),
1078 });
1079 }
1080
1081 let init_response: InitChunkedUploadResponse = response.json().await?;
1082 Ok(init_response.session_id)
1083 }
1084
1085 pub async fn upload_part_simple(
1087 &self,
1088 session_id: &str,
1089 part_number: u64,
1090 data: Vec<u8>,
1091 ) -> Result<()> {
1092 let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
1093
1094 let response = self
1095 .client
1096 .post(url)
1097 .bearer_auth(&self.auth_token)
1098 .query(&[
1099 ("session_id", session_id),
1100 ("part_number", &part_number.to_string()),
1101 ])
1102 .header("Content-Type", "application/octet-stream")
1103 .body(data)
1104 .send()
1105 .await?;
1106
1107 if !response.status().is_success() {
1108 return Err(AdaptiveError::ChunkedUploadPartFailed {
1109 part_number,
1110 status: response.status().to_string(),
1111 body: response.text().await.unwrap_or_default(),
1112 });
1113 }
1114
1115 Ok(())
1116 }
1117
1118 async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
1119 let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
1120
1121 let request = InitChunkedUploadRequest {
1122 content_type: "application/jsonl".to_string(),
1123 metadata: None,
1124 total_parts_count: total_parts,
1125 };
1126
1127 let response = self
1128 .client
1129 .post(url)
1130 .bearer_auth(&self.auth_token)
1131 .json(&request)
1132 .send()
1133 .await?;
1134
1135 if !response.status().is_success() {
1136 return Err(AdaptiveError::ChunkedUploadInitFailed {
1137 status: response.status().to_string(),
1138 body: response.text().await.unwrap_or_default(),
1139 });
1140 }
1141
1142 let init_response: InitChunkedUploadResponse = response.json().await?;
1143 Ok(init_response.session_id)
1144 }
1145
1146 async fn upload_part(
1147 &self,
1148 session_id: &str,
1149 part_number: u64,
1150 data: Vec<u8>,
1151 progress_tx: mpsc::Sender<u64>,
1152 ) -> Result<()> {
1153 const SUB_CHUNK_SIZE: usize = 64 * 1024;
1154
1155 let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
1156
1157 let chunks: Vec<Vec<u8>> = data
1158 .chunks(SUB_CHUNK_SIZE)
1159 .map(|chunk| chunk.to_vec())
1160 .collect();
1161
1162 let stream = futures::stream::iter(chunks).map(move |chunk| {
1163 let len = chunk.len() as u64;
1164 let tx = progress_tx.clone();
1165 let _ = tx.try_send(len);
1166 Ok::<_, std::io::Error>(chunk)
1167 });
1168
1169 let body = reqwest::Body::wrap_stream(stream);
1170
1171 let response = self
1172 .client
1173 .post(url)
1174 .bearer_auth(&self.auth_token)
1175 .query(&[
1176 ("session_id", session_id),
1177 ("part_number", &part_number.to_string()),
1178 ])
1179 .header("Content-Type", "application/octet-stream")
1180 .body(body)
1181 .send()
1182 .await?;
1183
1184 if !response.status().is_success() {
1185 return Err(AdaptiveError::ChunkedUploadPartFailed {
1186 part_number,
1187 status: response.status().to_string(),
1188 body: response.text().await.unwrap_or_default(),
1189 });
1190 }
1191
1192 Ok(())
1193 }
1194
1195 pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
1197 let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
1198
1199 let request = AbortChunkedUploadRequest {
1200 session_id: session_id.to_string(),
1201 };
1202
1203 let _ = self
1204 .client
1205 .delete(url)
1206 .bearer_auth(&self.auth_token)
1207 .json(&request)
1208 .send()
1209 .await;
1210
1211 Ok(())
1212 }
1213
1214 async fn create_dataset_from_multipart(
1215 &self,
1216 project: &str,
1217 name: &str,
1218 key: &str,
1219 session_id: &str,
1220 ) -> Result<
1221 create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
1222 > {
1223 let variables = create_dataset_from_multipart::Variables {
1224 input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
1225 project: project.to_string(),
1226 name: name.to_string(),
1227 key: Some(key.to_string()),
1228 source: None,
1229 upload_session_id: session_id.to_string(),
1230 },
1231 };
1232
1233 let response_data = self
1234 .execute_query(CreateDatasetFromMultipart, variables)
1235 .await?;
1236 Ok(response_data.create_dataset_from_multipart_upload)
1237 }
1238
1239 pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
1240 &'a self,
1241 project: &'a str,
1242 name: &'a str,
1243 key: &'a str,
1244 dataset: P,
1245 ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
1246 let file_size = std::fs::metadata(dataset.as_ref())?.len();
1247
1248 let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;
1249
1250 let stream = async_stream::try_stream! {
1251 yield UploadEvent::Progress(ChunkedUploadProgress {
1252 bytes_uploaded: 0,
1253 total_bytes: file_size,
1254 });
1255
1256 let session_id = self.init_chunked_upload(total_parts).await?;
1257
1258 let mut file = File::open(dataset.as_ref())?;
1259 let mut buffer = vec![0u8; chunk_size as usize];
1260 let mut bytes_uploaded = 0u64;
1261
1262 let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);
1263
1264 for part_number in 1..=total_parts {
1265 let bytes_read = file.read(&mut buffer)?;
1266 let chunk_data = buffer[..bytes_read].to_vec();
1267
1268 let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
1269 tokio::pin!(upload_fut);
1270
1271 let upload_result: Result<()> = loop {
1272 tokio::select! {
1273 biased;
1274 result = &mut upload_fut => {
1275 break result;
1276 }
1277 Some(bytes) = progress_rx.recv() => {
1278 bytes_uploaded += bytes;
1279 yield UploadEvent::Progress(ChunkedUploadProgress {
1280 bytes_uploaded,
1281 total_bytes: file_size,
1282 });
1283 }
1284 }
1285 };
1286
1287 if let Err(e) = upload_result {
1288 let _ = self.abort_chunked_upload(&session_id).await;
1289 Err(e)?;
1290 }
1291 }
1292
1293 let create_result = self
1294 .create_dataset_from_multipart(project, name, key, &session_id)
1295 .await;
1296
1297 match create_result {
1298 Ok(response) => {
1299 yield UploadEvent::Complete(response);
1300 }
1301 Err(e) => {
1302 let _ = self.abort_chunked_upload(&session_id).await;
1303 Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
1304 }
1305 }
1306 };
1307
1308 Ok(Box::pin(stream))
1309 }
1310
1311 pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
1314 use tokio::io::AsyncWriteExt;
1315
1316 let full_url = if url.starts_with("http://") || url.starts_with("https://") {
1317 Url::parse(url)?
1318 } else {
1319 self.rest_base_url.join(url)?
1320 };
1321
1322 let response = self
1323 .client
1324 .get(full_url)
1325 .bearer_auth(&self.auth_token)
1326 .send()
1327 .await?;
1328
1329 if !response.status().is_success() {
1330 return Err(AdaptiveError::HttpError(
1331 response.error_for_status().unwrap_err(),
1332 ));
1333 }
1334
1335 let mut file = tokio::fs::File::create(dest_path).await?;
1336 let mut stream = response.bytes_stream();
1337
1338 while let Some(chunk) = stream.next().await {
1339 let chunk = chunk?;
1340 file.write_all(&chunk).await?;
1341 }
1342
1343 file.flush().await?;
1344 Ok(())
1345 }
1346}