1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
/// One megabyte, in bytes.
const MEGABYTE: u64 = 1024 * 1024;

/// Smallest chunk size accepted by `calculate_upload_parts`.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
/// Largest chunk size this client will ever select.
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
/// Hard cap on the number of parts in a single chunked upload.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size tier boundaries used when picking a chunk size.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Unified error type for every operation exposed by this client.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    /// Transport-level failure reported by `reqwest`.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// (De)serialization failure from `serde_json`.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Local file-system failure (reading datasets, writing downloads).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A URL could not be parsed or joined against the base URL.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// Chunked uploads require at least `min_size` bytes.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// The file exceeds what the chunk-size/part-count limits can cover.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The GraphQL server answered with explicit errors and no data.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response contained neither data nor errors.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    /// No job exists with the given id.
    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    /// The REST endpoint rejected the chunked-upload init request.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    /// The REST endpoint rejected one part of a chunked upload.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// Finalizing a chunked upload into a dataset failed.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),

    /// Generic non-success HTTP status with the response body attached.
    #[error("HTTP status error: {status} - {body}")]
    HttpStatusError { status: String, body: String },

    /// The response body was not valid JSON; `body` holds a short preview.
    #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
    JsonParseError { error: String, body: String },
}
76
/// Crate-local result type with the error defaulted to [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;
78
/// Byte-level progress snapshot for a chunked upload.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    // Bytes handed to the HTTP stack so far.
    pub bytes_uploaded: u64,
    // Total size of the file being uploaded.
    pub total_bytes: u64,
}
84
/// Events yielded by the `chunked_upload_dataset` stream.
#[derive(Debug)]
pub enum UploadEvent {
    /// Periodic progress update.
    Progress(ChunkedUploadProgress),
    /// Terminal event carrying the dataset created by the server.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94 if file_size < MIN_CHUNK_SIZE_BYTES {
95 return Err(AdaptiveError::FileTooSmall {
96 size: file_size,
97 min_size: MIN_CHUNK_SIZE_BYTES,
98 });
99 }
100
101 let mut chunk_size = if file_size < SIZE_500MB {
102 5 * MEGABYTE
103 } else if file_size < SIZE_10GB {
104 10 * MEGABYTE
105 } else if file_size < SIZE_50GB {
106 50 * MEGABYTE
107 } else {
108 100 * MEGABYTE
109 };
110
111 let mut total_parts = file_size.div_ceil(chunk_size);
112
113 if total_parts > MAX_PARTS_COUNT {
114 chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116 if chunk_size > MAX_CHUNK_SIZE_BYTES {
117 let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118 return Err(AdaptiveError::FileTooLarge {
119 size: file_size,
120 max_size: max_file_size,
121 });
122 }
123
124 total_parts = file_size.div_ceil(chunk_size);
125 }
126
127 Ok((total_parts, chunk_size))
128}
129
// Scalar type mappings for the `graphql_client`-generated query modules;
// the alias names presumably mirror scalar names in `schema.gql` —
// verify against the schema before renaming any of them.
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
#[allow(clippy::upper_case_acronyms)]
type JSONObject = Map<String, Value>;
type KeyInput = String;
140
/// Wrapper around [`SystemTime`] deserialized from a millisecond-precision
/// timestamp (see the custom `Deserialize` impl).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);
143
144impl<'de> serde::Deserialize<'de> for Timestamp {
145 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
146 where
147 D: serde::Deserializer<'de>,
148 {
149 let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
150 Ok(Timestamp(system_time))
151 }
152}
153
// Number of items requested per page when walking paginated queries.
const PAGE_SIZE: usize = 20;

/// File-upload placeholder scalar; the wrapped index names the multipart
/// form part carrying the actual file (e.g. part "0").
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
158
/// Query: list the custom recipes of a project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Query: fetch a single job by id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Query: list jobs with filtering and cursor pagination.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Mutation: cancel a job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Query: list the model services of one project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Query: list models without a project filter.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
206
207impl Display for get_job::JobStatus {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 match self {
210 get_job::JobStatus::PENDING => write!(f, "Pending"),
211 get_job::JobStatus::RUNNING => write!(f, "Running"),
212 get_job::JobStatus::COMPLETED => write!(f, "Completed"),
213 get_job::JobStatus::FAILED => write!(f, "Failed"),
214 get_job::JobStatus::CANCELED => write!(f, "Canceled"),
215 get_job::JobStatus::Other(_) => write!(f, "Unknown"),
216 }
217 }
218}
219
/// Mutation: create (publish) a custom recipe from an uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Mutation: update a custom recipe (metadata and/or file).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/update_recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateCustomRecipe;

/// Mutation: create a dataset from a single multipart file upload.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Mutation: finalize a chunked-upload session into a dataset.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Mutation: create a job that runs a custom recipe.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Query: list the caller's projects.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/projects.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListProjects;

/// Query: list available compute pools.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Query: fetch a custom recipe by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// Query: fetch a grader (response is also serializable).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// Query: fetch a dataset by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// Query: fetch a model's configuration (response is serializable).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;

/// Query: fetch an artifact by project and id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/artifact.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetArtifact;

/// Mutation: push a progress event onto a job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job_progress.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateJobProgress;
323
// REST routes for the chunked-upload API, relative to `rest_base_url`.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
327
/// Client for the Adaptive API: GraphQL queries/mutations plus the
/// chunked-upload and download REST endpoints.
#[derive(Clone)]
pub struct AdaptiveClient {
    // Shared HTTP client used for all requests.
    client: Client,
    // Fully resolved GraphQL endpoint (`<base>/graphql`).
    graphql_url: Url,
    // Base URL for REST routes and relative download paths.
    rest_base_url: Url,
    // Bearer token attached to every request.
    auth_token: String,
}
335
336impl AdaptiveClient {
337 pub fn new(api_base_url: Url, auth_token: String) -> Self {
338 let graphql_url = api_base_url
339 .join("graphql")
340 .expect("Failed to append graphql to base URL");
341
342 let client = Client::builder()
343 .user_agent(format!(
344 "adaptive-client-rust/{}",
345 env!("CARGO_PKG_VERSION")
346 ))
347 .build()
348 .expect("Failed to build HTTP client");
349
350 Self {
351 client,
352 graphql_url,
353 rest_base_url: api_base_url,
354 auth_token,
355 }
356 }
357
    /// Executes a GraphQL operation and returns its `data` payload.
    ///
    /// The `_query` value only selects the [`GraphQLQuery`] implementation;
    /// the request body is built entirely from `variables`.
    ///
    /// # Errors
    /// - [`AdaptiveError::HttpStatusError`] on a non-2xx response.
    /// - [`AdaptiveError::JsonParseError`] when the body is not a valid
    ///   GraphQL response (keeps a 500-char body preview for diagnostics).
    /// - [`AdaptiveError::GraphQLErrors`] / [`AdaptiveError::NoGraphQLData`]
    ///   when the response carries no data.
    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
    where
        T: GraphQLQuery,
        T::Variables: serde::Serialize,
        T::ResponseData: DeserializeOwned,
    {
        let request_body = T::build_query(variables);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .json(&request_body)
            .send()
            .await?;

        // Read the body as text first so it can be included in errors.
        let status = response.status();
        let response_text = response.text().await?;

        if !status.is_success() {
            return Err(AdaptiveError::HttpStatusError {
                status: status.to_string(),
                body: response_text,
            });
        }

        let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
            .map_err(|e| AdaptiveError::JsonParseError {
                error: e.to_string(),
                body: response_text.chars().take(500).collect(),
            })?;

        match response_body.data {
            Some(data) => Ok(data),
            None => {
                // Prefer the server's explicit errors when present.
                if let Some(errors) = response_body.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
400
401 pub async fn list_recipes(
402 &self,
403 project: &str,
404 ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
405 let variables = get_custom_recipes::Variables {
406 project: IdOrKey::from(project),
407 };
408
409 let response_data = self.execute_query(GetCustomRecipes, variables).await?;
410 Ok(response_data.custom_recipes)
411 }
412
413 pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
414 let variables = get_job::Variables { id: job_id };
415
416 let response_data = self.execute_query(GetJob, variables).await?;
417
418 match response_data.job {
419 Some(job) => Ok(job),
420 None => Err(AdaptiveError::JobNotFound(job_id)),
421 }
422 }
423
    /// Creates a dataset by uploading a file in a single multipart GraphQL
    /// request (`operations`/`map` plus the file part).
    ///
    /// For large files, [`Self::chunked_upload_dataset`] is the chunked
    /// alternative.
    pub async fn upload_dataset<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        dataset: P,
    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
        // `Upload(0)` is a placeholder: the real file travels as the
        // multipart part named "0" (see `file_map` below).
        let variables = upload_dataset::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
        };

        let operations = UploadDataset::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the query's `file` variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let dataset_file = reqwest::multipart::Part::file(dataset).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", dataset_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;

        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
            response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_dataset),
            None => {
                // Prefer the server's explicit errors when present.
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
469
    /// Publishes a custom recipe by uploading its file in a multipart
    /// GraphQL request (same `operations`/`map` scheme as
    /// [`Self::upload_dataset`]).
    pub async fn publish_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        key: &str,
        recipe: P,
    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
        // `Upload(0)` is a placeholder; the file is multipart part "0".
        let variables = publish_custom_recipe::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
            key: Some(key.to_string()),
        };

        let operations = PublishCustomRecipe::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the query's `file` variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let recipe_file = reqwest::multipart::Part::file(recipe).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", recipe_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;
        let response: Response<
            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
        > = response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_custom_recipe),
            None => {
                // Prefer the server's explicit errors when present.
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
517
    /// Updates a custom recipe's metadata and, optionally, its file.
    ///
    /// With `recipe_file = Some(..)` the mutation is sent as a multipart
    /// request carrying the new file; with `None` it is a plain GraphQL
    /// mutation updating only the metadata in `input`.
    pub async fn update_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        id: &str,
        name: Option<String>,
        description: Option<String>,
        labels: Option<Vec<update_custom_recipe::LabelInput>>,
        recipe_file: Option<P>,
    ) -> Result<update_custom_recipe::UpdateCustomRecipeUpdateCustomRecipe> {
        let input = update_custom_recipe::UpdateRecipeInput {
            name,
            description,
            labels,
        };

        match recipe_file {
            Some(file_path) => {
                // `Upload(0)` is a placeholder; the real file travels as
                // multipart part "0" per the `file_map` below.
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    file: Some(Upload(0)),
                };

                let operations = UpdateCustomRecipe::build_query(variables);
                let operations = serde_json::to_string(&operations)?;

                let file_map = r#"{ "0": ["variables.file"] }"#;

                let recipe_file = reqwest::multipart::Part::file(file_path).await?;

                let form = reqwest::multipart::Form::new()
                    .text("operations", operations)
                    .text("map", file_map)
                    .part("0", recipe_file);

                let response = self
                    .client
                    .post(self.graphql_url.clone())
                    .bearer_auth(&self.auth_token)
                    .multipart(form)
                    .send()
                    .await?;
                let response: Response<
                    <UpdateCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
                > = response.json().await?;

                match response.data {
                    Some(data) => Ok(data.update_custom_recipe),
                    None => {
                        // Prefer the server's explicit errors when present.
                        if let Some(errors) = response.errors {
                            return Err(AdaptiveError::GraphQLErrors(errors));
                        }
                        Err(AdaptiveError::NoGraphQLData)
                    }
                }
            }
            None => {
                // Metadata-only update: no file, so the standard JSON
                // GraphQL path suffices.
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    file: None,
                };

                let response_data = self.execute_query(UpdateCustomRecipe, variables).await?;
                Ok(response_data.update_custom_recipe)
            }
        }
    }
588
    /// Creates a job that runs the custom recipe `recipe_id` in `project`.
    ///
    /// `parameters` is passed through as the job's `args` object. Resource
    /// limits (`max_cpu`, `max_ram_gb`, `max_duration_secs`) and artifact
    /// resumption are not exposed by this wrapper and are sent as `None`.
    pub async fn run_recipe(
        &self,
        project: &str,
        recipe_id: &str,
        parameters: Map<String, Value>,
        name: Option<String>,
        compute_pool: Option<String>,
        num_gpus: u32,
        use_experimental_runner: bool,
    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
        let variables = run_custom_recipe::Variables {
            input: run_custom_recipe::JobInput {
                recipe: recipe_id.to_string(),
                project: project.to_string(),
                args: parameters,
                name,
                compute_pool,
                // Widening cast; the GraphQL schema takes an i64 here.
                num_gpus: num_gpus as i64,
                use_experimental_runner,
                max_cpu: None,
                max_ram_gb: None,
                max_duration_secs: None,
                resume_artifact_id: None,
            },
        };

        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
        Ok(response_data.create_job)
    }
618
619 pub async fn list_jobs(
620 &self,
621 project: Option<String>,
622 ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
623 let mut jobs = Vec::new();
624 let mut page = self.list_jobs_page(project.clone(), None).await?;
625 jobs.extend(page.nodes.iter().cloned());
626 while page.page_info.has_next_page {
627 page = self
628 .list_jobs_page(project.clone(), page.page_info.end_cursor)
629 .await?;
630 jobs.extend(page.nodes.iter().cloned());
631 }
632 Ok(jobs)
633 }
634
    /// Fetches one page (of size `PAGE_SIZE`) of jobs.
    ///
    /// NOTE: the filter is hard-coded to CUSTOM-kind jobs in RUNNING or
    /// PENDING state — this backs [`Self::list_jobs`] and is not a
    /// general-purpose job listing.
    async fn list_jobs_page(
        &self,
        project: Option<String>,
        after: Option<String>,
    ) -> Result<list_jobs::ListJobsJobs> {
        let variables = list_jobs::Variables {
            filter: Some(list_jobs::ListJobsFilterInput {
                project,
                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
                status: Some(vec![
                    list_jobs::JobStatus::RUNNING,
                    list_jobs::JobStatus::PENDING,
                ]),
                timerange: None,
                custom_recipes: None,
                artifacts: None,
            }),
            cursor: Some(list_jobs::CursorPageInput {
                first: Some(PAGE_SIZE as i64),
                after,
                before: None,
                last: None,
                offset: None,
            }),
        };

        let response_data = self.execute_query(ListJobs, variables).await?;
        Ok(response_data.jobs)
    }
664
665 pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
666 let variables = cancel_job::Variables { job_id };
667
668 let response_data = self.execute_query(CancelJob, variables).await?;
669 Ok(response_data.cancel_job)
670 }
671
672 pub async fn update_job_progress(
673 &self,
674 job_id: Uuid,
675 event: update_job_progress::JobProgressEventInput,
676 ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
677 let variables = update_job_progress::Variables { job_id, event };
678
679 let response_data = self.execute_query(UpdateJobProgress, variables).await?;
680 Ok(response_data.update_job_progress)
681 }
682
683 pub async fn list_models(
684 &self,
685 project: String,
686 ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
687 let variables = list_models::Variables { project };
688
689 let response_data = self.execute_query(ListModels, variables).await?;
690 Ok(response_data
691 .project
692 .map(|project| project.model_services)
693 .unwrap_or(Vec::new()))
694 }
695
696 pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
697 let variables = list_all_models::Variables {};
698
699 let response_data = self.execute_query(ListAllModels, variables).await?;
700 Ok(response_data.models)
701 }
702
703 pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
704 let variables = list_projects::Variables {};
705
706 let response_data = self.execute_query(ListProjects, variables).await?;
707 Ok(response_data.projects)
708 }
709
710 pub async fn list_pools(
711 &self,
712 ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
713 let variables = list_compute_pools::Variables {};
714
715 let response_data = self.execute_query(ListComputePools, variables).await?;
716 Ok(response_data.compute_pools)
717 }
718
719 pub async fn get_recipe(
720 &self,
721 project: String,
722 id_or_key: String,
723 ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
724 let variables = get_recipe::Variables { project, id_or_key };
725
726 let response_data = self.execute_query(GetRecipe, variables).await?;
727 Ok(response_data.custom_recipe)
728 }
729
730 pub async fn get_grader(
731 &self,
732 id_or_key: &str,
733 project: &str,
734 ) -> Result<get_grader::GetGraderGrader> {
735 let variables = get_grader::Variables {
736 id: id_or_key.to_string(),
737 project: project.to_string(),
738 };
739
740 let response_data = self.execute_query(GetGrader, variables).await?;
741 Ok(response_data.grader)
742 }
743
744 pub async fn get_dataset(
745 &self,
746 id_or_key: &str,
747 project: &str,
748 ) -> Result<Option<get_dataset::GetDatasetDataset>> {
749 let variables = get_dataset::Variables {
750 id_or_key: id_or_key.to_string(),
751 project: project.to_string(),
752 };
753
754 let response_data = self.execute_query(GetDataset, variables).await?;
755 Ok(response_data.dataset)
756 }
757
758 pub async fn get_model_config(
759 &self,
760 id_or_key: &str,
761 ) -> Result<Option<get_model_config::GetModelConfigModel>> {
762 let variables = get_model_config::Variables {
763 id_or_key: id_or_key.to_string(),
764 };
765
766 let response_data = self.execute_query(GetModelConfig, variables).await?;
767 Ok(response_data.model)
768 }
769
770 pub async fn get_artifact(
771 &self,
772 project: &str,
773 id: Uuid,
774 ) -> Result<Option<get_artifact::GetArtifactArtifact>> {
775 let variables = get_artifact::Variables {
776 project: project.to_string(),
777 id,
778 };
779
780 let response_data = self.execute_query(GetArtifact, variables).await?;
781 Ok(response_data.artifact)
782 }
783
    /// Returns the REST base URL this client was constructed with.
    pub fn base_url(&self) -> &Url {
        &self.rest_base_url
    }
787
788 pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
791 let file_size = data.len() as u64;
792
793 let chunk_size = if file_size < 5 * 1024 * 1024 {
795 file_size.max(1)
797 } else if file_size < 500 * 1024 * 1024 {
798 5 * 1024 * 1024
799 } else if file_size < 10 * 1024 * 1024 * 1024 {
800 10 * 1024 * 1024
801 } else {
802 100 * 1024 * 1024
803 };
804
805 let total_parts = file_size.div_ceil(chunk_size).max(1);
806
807 let session_id = self
809 .init_chunked_upload_with_content_type(total_parts, content_type)
810 .await?;
811
812 for part_number in 1..=total_parts {
814 let start = ((part_number - 1) * chunk_size) as usize;
815 let end = (part_number * chunk_size).min(file_size) as usize;
816 let chunk = data[start..end].to_vec();
817
818 if let Err(e) = self
819 .upload_part_simple(&session_id, part_number, chunk)
820 .await
821 {
822 let _ = self.abort_chunked_upload(&session_id).await;
823 return Err(e);
824 }
825 }
826
827 Ok(session_id)
828 }
829
    /// Opens a chunked-upload session for `total_parts` parts of the given
    /// `content_type`, returning the server-assigned session id.
    ///
    /// # Errors
    /// [`AdaptiveError::ChunkedUploadInitFailed`] on a non-success status,
    /// carrying the status and response body.
    pub async fn init_chunked_upload_with_content_type(
        &self,
        total_parts: u64,
        content_type: &str,
    ) -> Result<String> {
        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;

        let request = InitChunkedUploadRequest {
            content_type: content_type.to_string(),
            metadata: None,
            total_parts_count: total_parts,
        };

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadInitFailed {
                status: response.status().to_string(),
                // Body read is itself fallible; fall back to empty.
                body: response.text().await.unwrap_or_default(),
            });
        }

        let init_response: InitChunkedUploadResponse = response.json().await?;
        Ok(init_response.session_id)
    }
862
    /// Uploads one part (1-based `part_number`) of a chunked-upload
    /// session as a single buffered request body.
    ///
    /// # Errors
    /// [`AdaptiveError::ChunkedUploadPartFailed`] on a non-success status.
    pub async fn upload_part_simple(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
    ) -> Result<()> {
        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(data)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                // Body read is itself fallible; fall back to empty.
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
895
896 async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
897 let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
898
899 let request = InitChunkedUploadRequest {
900 content_type: "application/jsonl".to_string(),
901 metadata: None,
902 total_parts_count: total_parts,
903 };
904
905 let response = self
906 .client
907 .post(url)
908 .bearer_auth(&self.auth_token)
909 .json(&request)
910 .send()
911 .await?;
912
913 if !response.status().is_success() {
914 return Err(AdaptiveError::ChunkedUploadInitFailed {
915 status: response.status().to_string(),
916 body: response.text().await.unwrap_or_default(),
917 });
918 }
919
920 let init_response: InitChunkedUploadResponse = response.json().await?;
921 Ok(init_response.session_id)
922 }
923
    /// Uploads one part while reporting progress through `progress_tx`.
    ///
    /// The part is re-split into 64 KiB sub-chunks sent as a streaming
    /// body; each sub-chunk's size is pushed onto `progress_tx` as it is
    /// handed to the HTTP stack (not when the server acknowledges it).
    /// `try_send` is used so a full or closed channel silently drops the
    /// progress update instead of stalling the upload.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Granularity of progress reporting.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            // Best-effort notification; dropped if the channel is full.
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                // Body read is itself fallible; fall back to empty.
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
972
    /// Asks the server to abort an upload session.
    ///
    /// Best-effort by design: any transport error is deliberately ignored,
    /// since this is called from failure paths where the original error
    /// must be the one reported. Always returns `Ok(())`.
    pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;

        let request = AbortChunkedUploadRequest {
            session_id: session_id.to_string(),
        };

        let _ = self
            .client
            .delete(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await;

        Ok(())
    }
991
992 async fn create_dataset_from_multipart(
993 &self,
994 project: &str,
995 name: &str,
996 key: &str,
997 session_id: &str,
998 ) -> Result<
999 create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
1000 > {
1001 let variables = create_dataset_from_multipart::Variables {
1002 input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
1003 project: project.to_string(),
1004 name: name.to_string(),
1005 key: Some(key.to_string()),
1006 source: None,
1007 upload_session_id: session_id.to_string(),
1008 },
1009 };
1010
1011 let response_data = self
1012 .execute_query(CreateDatasetFromMultipart, variables)
1013 .await?;
1014 Ok(response_data.create_dataset_from_multipart_upload)
1015 }
1016
    /// Uploads a dataset file via the chunked REST API and finalizes it
    /// into a dataset, returning a stream of [`UploadEvent`]s.
    ///
    /// The stream yields [`UploadEvent::Progress`] updates while parts are
    /// uploading and terminates with [`UploadEvent::Complete`] carrying
    /// the created dataset. Part count and chunk size come from
    /// [`calculate_upload_parts`], so the file must be at least
    /// [`MIN_CHUNK_SIZE_BYTES`]. On any failure the session is aborted
    /// (best effort) before the error is surfaced through the stream.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        project: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        // Validate size and compute the upload plan before building the
        // stream, so these errors are returned eagerly.
        let file_size = std::fs::metadata(dataset.as_ref())?.len();

        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Initial 0-byte progress event so consumers can render
            // "starting" state immediately.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            let mut file = File::open(dataset.as_ref())?;
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            for part_number in 1..=total_parts {
                // NOTE(review): assumes `read` fills the buffer (or reads
                // the final short chunk) in one call — confirm for the
                // target platforms, since `Read::read` may return short.
                let bytes_read = file.read(&mut buffer)?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the upload while draining progress updates;
                // `biased` polls the upload future first, so progress
                // messages still queued at completion wait for the next
                // part's loop (or are dropped after the last part).
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        biased;
                        result = &mut upload_fut => {
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };

                if let Err(e) = upload_result {
                    // Best-effort cleanup; propagate the upload error.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(project, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    // Best-effort cleanup; report as a creation failure.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
1088
1089 pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
1092 use tokio::io::AsyncWriteExt;
1093
1094 let full_url = if url.starts_with("http://") || url.starts_with("https://") {
1095 Url::parse(url)?
1096 } else {
1097 self.rest_base_url.join(url)?
1098 };
1099
1100 let response = self
1101 .client
1102 .get(full_url)
1103 .bearer_auth(&self.auth_token)
1104 .send()
1105 .await?;
1106
1107 if !response.status().is_success() {
1108 return Err(AdaptiveError::HttpError(
1109 response.error_for_status().unwrap_err(),
1110 ));
1111 }
1112
1113 let mut file = tokio::fs::File::create(dest_path).await?;
1114 let mut stream = response.bytes_stream();
1115
1116 while let Some(chunk) = stream.next().await {
1117 let chunk = chunk?;
1118 file.write_all(&chunk).await?;
1119 }
1120
1121 file.flush().await?;
1122 Ok(())
1123 }
1124}