1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
/// One mebibyte, in bytes.
const MEGABYTE: u64 = 1024 * 1024;
/// Smallest chunk size the chunked-upload endpoint accepts per part.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
/// Largest chunk size ever used for a single part.
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
/// Upper bound on the number of parts in one upload session.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size tier boundaries used by `calculate_upload_parts` to pick a
// chunk size.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Error type for every fallible operation in this client.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    /// Transport-level failure from `reqwest`.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// JSON (de)serialization failure.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Local filesystem / IO failure.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A URL could not be parsed or joined.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// File below `MIN_CHUNK_SIZE_BYTES`, too small for a chunked upload.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// File cannot fit into `MAX_PARTS_COUNT` parts of the maximum chunk size.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The server returned GraphQL-level errors instead of data.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response carried neither `data` nor `errors`.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    /// No job exists with the requested id.
    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    /// The chunked-upload init route returned a non-success status.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    /// Uploading one part of a chunked upload failed.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// Finalizing a chunked upload into a dataset failed.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),

    /// A non-2xx HTTP response with its raw body.
    #[error("HTTP status error: {status} - {body}")]
    HttpStatusError { status: String, body: String },

    /// The response body was not valid JSON; `body` is a truncated preview.
    #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
    JsonParseError { error: String, body: String },
}
76
/// Crate-local result alias over [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;

/// Byte-level progress of an in-flight chunked upload.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    /// Bytes handed to the transport so far.
    pub bytes_uploaded: u64,
    /// Total size of the payload being uploaded.
    pub total_bytes: u64,
}

/// Events yielded by the chunked dataset-upload stream.
#[derive(Debug)]
pub enum UploadEvent {
    /// Incremental progress update.
    Progress(ChunkedUploadProgress),
    /// Terminal event carrying the created dataset record.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94 if file_size < MIN_CHUNK_SIZE_BYTES {
95 return Err(AdaptiveError::FileTooSmall {
96 size: file_size,
97 min_size: MIN_CHUNK_SIZE_BYTES,
98 });
99 }
100
101 let mut chunk_size = if file_size < SIZE_500MB {
102 5 * MEGABYTE
103 } else if file_size < SIZE_10GB {
104 10 * MEGABYTE
105 } else if file_size < SIZE_50GB {
106 50 * MEGABYTE
107 } else {
108 100 * MEGABYTE
109 };
110
111 let mut total_parts = file_size.div_ceil(chunk_size);
112
113 if total_parts > MAX_PARTS_COUNT {
114 chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116 if chunk_size > MAX_CHUNK_SIZE_BYTES {
117 let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118 return Err(AdaptiveError::FileTooLarge {
119 size: file_size,
120 max_size: max_file_size,
121 });
122 }
123
124 total_parts = file_size.div_ceil(chunk_size);
125 }
126
127 Ok((total_parts, chunk_size))
128}
129
// Rust-side bindings for the custom GraphQL scalars referenced by the
// generated query modules below (graphql_client resolves scalar names to
// in-scope types).
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
type KeyInput = String;
138
/// Timestamp scalar wrapping [`SystemTime`], deserialized from epoch
/// milliseconds.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);

impl<'de> serde::Deserialize<'de> for Timestamp {
    // Delegates the millisecond decoding to the shared `serde_utils` helper.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
        Ok(Timestamp(system_time))
    }
}
151
/// Number of items requested per page when listing jobs.
const PAGE_SIZE: usize = 20;

/// GraphQL `Upload` scalar placeholder: the wrapped integer is the index of
/// the multipart part carrying the file (wired up through the `map` field in
/// the multipart upload methods).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
156
// Each struct below derives `GraphQLQuery`, which generates a same-named
// snake_case module containing `Variables` and `ResponseData` types for the
// referenced .graphql document.

/// Lists the custom recipes in a project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Fetches a single job by id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Lists jobs with filtering and cursor pagination.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Requests cancellation of a job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Lists the model services of one project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Lists every model visible to the caller.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
204
205impl Display for get_job::JobStatus {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 match self {
208 get_job::JobStatus::PENDING => write!(f, "Pending"),
209 get_job::JobStatus::RUNNING => write!(f, "Running"),
210 get_job::JobStatus::COMPLETED => write!(f, "Completed"),
211 get_job::JobStatus::FAILED => write!(f, "Failed"),
212 get_job::JobStatus::CANCELED => write!(f, "Canceled"),
213 get_job::JobStatus::Other(_) => write!(f, "Unknown"),
214 }
215 }
216}
217
/// Creates a custom recipe from an uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Updates a custom recipe's metadata and optionally its file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/update_recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateCustomRecipe;

/// Creates a dataset from a single multipart file upload.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Finalizes a chunked-upload session into a dataset.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Starts a job for a custom recipe.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Lists the caller's projects.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/projects.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListProjects;

/// Lists the available compute pools.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Fetches a custom recipe by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// Fetches a grader by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// Fetches a dataset by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// Fetches a model's configuration by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;

/// Reports a progress event for a job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job_progress.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateJobProgress;
313
// REST routes for the chunked-upload API, joined onto `rest_base_url`.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
317
/// Authenticated client for the Adaptive GraphQL API and its chunked-upload
/// REST routes.
#[derive(Clone)]
pub struct AdaptiveClient {
    /// Shared HTTP client used for both GraphQL and REST calls.
    client: Client,
    /// Fully resolved GraphQL endpoint (`<base>/graphql`).
    graphql_url: Url,
    /// Base URL for the REST upload/download routes.
    rest_base_url: Url,
    /// Bearer token attached to every request.
    auth_token: String,
}
325
326impl AdaptiveClient {
327 pub fn new(api_base_url: Url, auth_token: String) -> Self {
328 let graphql_url = api_base_url
329 .join("graphql")
330 .expect("Failed to append graphql to base URL");
331
332 let client = Client::builder()
333 .user_agent(format!(
334 "adaptive-client-rust/{}",
335 env!("CARGO_PKG_VERSION")
336 ))
337 .build()
338 .expect("Failed to build HTTP client");
339
340 Self {
341 client,
342 graphql_url,
343 rest_base_url: api_base_url,
344 auth_token,
345 }
346 }
347
    /// POST a GraphQL operation to `graphql_url` and return its `data`
    /// payload.
    ///
    /// `_query` is only a type witness; `graphql_client` builds the request
    /// body from `T` and `variables`.
    ///
    /// # Errors
    ///
    /// * [`AdaptiveError::HttpStatusError`] on a non-2xx response.
    /// * [`AdaptiveError::JsonParseError`] when the body is not a valid
    ///   GraphQL response (body preview truncated to 500 chars).
    /// * [`AdaptiveError::GraphQLErrors`] when the server reports errors with
    ///   no data; [`AdaptiveError::NoGraphQLData`] when both are absent.
    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
    where
        T: GraphQLQuery,
        T::Variables: serde::Serialize,
        T::ResponseData: DeserializeOwned,
    {
        let request_body = T::build_query(variables);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .json(&request_body)
            .send()
            .await?;

        // Capture the status before consuming the body.
        let status = response.status();
        let response_text = response.text().await?;

        if !status.is_success() {
            return Err(AdaptiveError::HttpStatusError {
                status: status.to_string(),
                body: response_text,
            });
        }

        let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
            .map_err(|e| AdaptiveError::JsonParseError {
                error: e.to_string(),
                // Keep only a short preview so the error stays readable.
                body: response_text.chars().take(500).collect(),
            })?;

        match response_body.data {
            Some(data) => Ok(data),
            None => {
                if let Some(errors) = response_body.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
390
391 pub async fn list_recipes(
392 &self,
393 project: &str,
394 ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
395 let variables = get_custom_recipes::Variables {
396 project: IdOrKey::from(project),
397 };
398
399 let response_data = self.execute_query(GetCustomRecipes, variables).await?;
400 Ok(response_data.custom_recipes)
401 }
402
403 pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
404 let variables = get_job::Variables { id: job_id };
405
406 let response_data = self.execute_query(GetJob, variables).await?;
407
408 match response_data.job {
409 Some(job) => Ok(job),
410 None => Err(AdaptiveError::JobNotFound(job_id)),
411 }
412 }
413
    /// Upload a dataset file to `project` using the GraphQL multipart
    /// request convention (operations + map + file parts in one POST).
    pub async fn upload_dataset<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        dataset: P,
    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
        // `Upload(0)` is a placeholder; the real file travels as multipart
        // part "0", wired up via the `map` field below.
        let variables = upload_dataset::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
        };

        let operations = UploadDataset::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the `file` GraphQL variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let dataset_file = reqwest::multipart::Part::file(dataset).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", dataset_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;

        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
            response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_dataset),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
459
    /// Publish a recipe file to `project` under `key`, via a GraphQL
    /// multipart upload (operations + map + file in one POST).
    pub async fn publish_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        key: &str,
        recipe: P,
    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
        // `Upload(0)` is a placeholder; the file travels as multipart part "0".
        let variables = publish_custom_recipe::Variables {
            project: IdOrKey::from(project),
            file: Upload(0),
            name: Some(name.to_string()),
            key: Some(key.to_string()),
        };

        let operations = PublishCustomRecipe::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the `file` GraphQL variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let recipe_file = reqwest::multipart::Part::file(recipe).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", recipe_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;
        let response: Response<
            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
        > = response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_custom_recipe),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
507
    /// Update a custom recipe's metadata and/or replace its file.
    ///
    /// When `recipe_file` is `Some`, the mutation travels as a GraphQL
    /// multipart request carrying the new file; otherwise a plain GraphQL
    /// mutation updates metadata only.
    pub async fn update_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        id: &str,
        name: Option<String>,
        description: Option<String>,
        labels: Option<Vec<update_custom_recipe::LabelInput>>,
        recipe_file: Option<P>,
    ) -> Result<update_custom_recipe::UpdateCustomRecipeUpdateCustomRecipe> {
        let input = update_custom_recipe::UpdateRecipeInput {
            name,
            description,
            labels,
        };

        match recipe_file {
            Some(file_path) => {
                // `Upload(0)` is a placeholder; the file is multipart part "0".
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    file: Some(Upload(0)),
                };

                let operations = UpdateCustomRecipe::build_query(variables);
                let operations = serde_json::to_string(&operations)?;

                // Maps multipart part "0" onto the `file` GraphQL variable.
                let file_map = r#"{ "0": ["variables.file"] }"#;

                let recipe_file = reqwest::multipart::Part::file(file_path).await?;

                let form = reqwest::multipart::Form::new()
                    .text("operations", operations)
                    .text("map", file_map)
                    .part("0", recipe_file);

                let response = self
                    .client
                    .post(self.graphql_url.clone())
                    .bearer_auth(&self.auth_token)
                    .multipart(form)
                    .send()
                    .await?;
                let response: Response<
                    <UpdateCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
                > = response.json().await?;

                match response.data {
                    Some(data) => Ok(data.update_custom_recipe),
                    None => {
                        if let Some(errors) = response.errors {
                            return Err(AdaptiveError::GraphQLErrors(errors));
                        }
                        Err(AdaptiveError::NoGraphQLData)
                    }
                }
            }
            None => {
                // Metadata-only update: no multipart request needed.
                let variables = update_custom_recipe::Variables {
                    project: IdOrKey::from(project),
                    id: IdOrKey::from(id),
                    input,
                    file: None,
                };

                let response_data = self.execute_query(UpdateCustomRecipe, variables).await?;
                Ok(response_data.update_custom_recipe)
            }
        }
    }
578
579 pub async fn run_recipe(
580 &self,
581 project: &str,
582 recipe_id: &str,
583 parameters: Map<String, Value>,
584 name: Option<String>,
585 compute_pool: Option<String>,
586 num_gpus: u32,
587 use_experimental_runner: bool,
588 ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
589 let variables = run_custom_recipe::Variables {
590 input: run_custom_recipe::JobInput {
591 recipe: recipe_id.to_string(),
592 project: project.to_string(),
593 args: parameters,
594 name,
595 compute_pool,
596 num_gpus: num_gpus as i64,
597 use_experimental_runner,
598 max_cpu: None,
599 max_ram_gb: None,
600 max_duration_secs: None,
601 },
602 };
603
604 let response_data = self.execute_query(RunCustomRecipe, variables).await?;
605 Ok(response_data.create_job)
606 }
607
608 pub async fn list_jobs(
609 &self,
610 project: Option<String>,
611 ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
612 let mut jobs = Vec::new();
613 let mut page = self.list_jobs_page(project.clone(), None).await?;
614 jobs.extend(page.nodes.iter().cloned());
615 while page.page_info.has_next_page {
616 page = self
617 .list_jobs_page(project.clone(), page.page_info.end_cursor)
618 .await?;
619 jobs.extend(page.nodes.iter().cloned());
620 }
621 Ok(jobs)
622 }
623
    /// Fetch one page (up to `PAGE_SIZE` entries) of pending/running custom
    /// jobs, starting after the `after` cursor when given.
    async fn list_jobs_page(
        &self,
        project: Option<String>,
        after: Option<String>,
    ) -> Result<list_jobs::ListJobsJobs> {
        let variables = list_jobs::Variables {
            // Only CUSTOM jobs that are still active are listed.
            filter: Some(list_jobs::ListJobsFilterInput {
                project,
                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
                status: Some(vec![
                    list_jobs::JobStatus::RUNNING,
                    list_jobs::JobStatus::PENDING,
                ]),
                timerange: None,
                custom_recipes: None,
                artifacts: None,
            }),
            cursor: Some(list_jobs::CursorPageInput {
                first: Some(PAGE_SIZE as i64),
                after,
                before: None,
                last: None,
                offset: None,
            }),
        };

        let response_data = self.execute_query(ListJobs, variables).await?;
        Ok(response_data.jobs)
    }
653
654 pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
655 let variables = cancel_job::Variables { job_id };
656
657 let response_data = self.execute_query(CancelJob, variables).await?;
658 Ok(response_data.cancel_job)
659 }
660
661 pub async fn update_job_progress(
662 &self,
663 job_id: Uuid,
664 event: update_job_progress::JobProgressEventInput,
665 ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
666 let variables = update_job_progress::Variables { job_id, event };
667
668 let response_data = self.execute_query(UpdateJobProgress, variables).await?;
669 Ok(response_data.update_job_progress)
670 }
671
672 pub async fn list_models(
673 &self,
674 project: String,
675 ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
676 let variables = list_models::Variables { project };
677
678 let response_data = self.execute_query(ListModels, variables).await?;
679 Ok(response_data
680 .project
681 .map(|project| project.model_services)
682 .unwrap_or(Vec::new()))
683 }
684
685 pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
686 let variables = list_all_models::Variables {};
687
688 let response_data = self.execute_query(ListAllModels, variables).await?;
689 Ok(response_data.models)
690 }
691
692 pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
693 let variables = list_projects::Variables {};
694
695 let response_data = self.execute_query(ListProjects, variables).await?;
696 Ok(response_data.projects)
697 }
698
699 pub async fn list_pools(
700 &self,
701 ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
702 let variables = list_compute_pools::Variables {};
703
704 let response_data = self.execute_query(ListComputePools, variables).await?;
705 Ok(response_data.compute_pools)
706 }
707
708 pub async fn get_recipe(
709 &self,
710 project: String,
711 id_or_key: String,
712 ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
713 let variables = get_recipe::Variables { project, id_or_key };
714
715 let response_data = self.execute_query(GetRecipe, variables).await?;
716 Ok(response_data.custom_recipe)
717 }
718
719 pub async fn get_grader(
720 &self,
721 id_or_key: &str,
722 project: &str,
723 ) -> Result<get_grader::GetGraderGrader> {
724 let variables = get_grader::Variables {
725 id: id_or_key.to_string(),
726 project: project.to_string(),
727 };
728
729 let response_data = self.execute_query(GetGrader, variables).await?;
730 Ok(response_data.grader)
731 }
732
733 pub async fn get_dataset(
734 &self,
735 id_or_key: &str,
736 project: &str,
737 ) -> Result<Option<get_dataset::GetDatasetDataset>> {
738 let variables = get_dataset::Variables {
739 id_or_key: id_or_key.to_string(),
740 project: project.to_string(),
741 };
742
743 let response_data = self.execute_query(GetDataset, variables).await?;
744 Ok(response_data.dataset)
745 }
746
747 pub async fn get_model_config(
748 &self,
749 id_or_key: &str,
750 ) -> Result<Option<get_model_config::GetModelConfigModel>> {
751 let variables = get_model_config::Variables {
752 id_or_key: id_or_key.to_string(),
753 };
754
755 let response_data = self.execute_query(GetModelConfig, variables).await?;
756 Ok(response_data.model)
757 }
758
    /// The base URL this client was constructed with (used for REST routes).
    pub fn base_url(&self) -> &Url {
        &self.rest_base_url
    }
762
763 pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
766 let file_size = data.len() as u64;
767
768 let chunk_size = if file_size < 5 * 1024 * 1024 {
770 file_size.max(1)
772 } else if file_size < 500 * 1024 * 1024 {
773 5 * 1024 * 1024
774 } else if file_size < 10 * 1024 * 1024 * 1024 {
775 10 * 1024 * 1024
776 } else {
777 100 * 1024 * 1024
778 };
779
780 let total_parts = file_size.div_ceil(chunk_size).max(1);
781
782 let session_id = self
784 .init_chunked_upload_with_content_type(total_parts, content_type)
785 .await?;
786
787 for part_number in 1..=total_parts {
789 let start = ((part_number - 1) * chunk_size) as usize;
790 let end = (part_number * chunk_size).min(file_size) as usize;
791 let chunk = data[start..end].to_vec();
792
793 if let Err(e) = self
794 .upload_part_simple(&session_id, part_number, chunk)
795 .await
796 {
797 let _ = self.abort_chunked_upload(&session_id).await;
798 return Err(e);
799 }
800 }
801
802 Ok(session_id)
803 }
804
    /// Start a chunked-upload session for `total_parts` parts of the given
    /// MIME type, returning the server-assigned session id.
    ///
    /// # Errors
    ///
    /// Returns [`AdaptiveError::ChunkedUploadInitFailed`] (with the response
    /// body) on a non-success status.
    pub async fn init_chunked_upload_with_content_type(
        &self,
        total_parts: u64,
        content_type: &str,
    ) -> Result<String> {
        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;

        let request = InitChunkedUploadRequest {
            content_type: content_type.to_string(),
            metadata: None,
            total_parts_count: total_parts,
        };

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadInitFailed {
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        let init_response: InitChunkedUploadResponse = response.json().await?;
        Ok(init_response.session_id)
    }
837
838 pub async fn upload_part_simple(
840 &self,
841 session_id: &str,
842 part_number: u64,
843 data: Vec<u8>,
844 ) -> Result<()> {
845 let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;
846
847 let response = self
848 .client
849 .post(url)
850 .bearer_auth(&self.auth_token)
851 .query(&[
852 ("session_id", session_id),
853 ("part_number", &part_number.to_string()),
854 ])
855 .header("Content-Type", "application/octet-stream")
856 .body(data)
857 .send()
858 .await?;
859
860 if !response.status().is_success() {
861 return Err(AdaptiveError::ChunkedUploadPartFailed {
862 part_number,
863 status: response.status().to_string(),
864 body: response.text().await.unwrap_or_default(),
865 });
866 }
867
868 Ok(())
869 }
870
871 async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
872 let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
873
874 let request = InitChunkedUploadRequest {
875 content_type: "application/jsonl".to_string(),
876 metadata: None,
877 total_parts_count: total_parts,
878 };
879
880 let response = self
881 .client
882 .post(url)
883 .bearer_auth(&self.auth_token)
884 .json(&request)
885 .send()
886 .await?;
887
888 if !response.status().is_success() {
889 return Err(AdaptiveError::ChunkedUploadInitFailed {
890 status: response.status().to_string(),
891 body: response.text().await.unwrap_or_default(),
892 });
893 }
894
895 let init_response: InitChunkedUploadResponse = response.json().await?;
896 Ok(init_response.session_id)
897 }
898
    /// Upload one part of a chunked-upload session, streaming the body in
    /// 64 KiB sub-chunks so byte-level progress can be reported.
    ///
    /// Progress is sent on `progress_tx` as each sub-chunk is handed to the
    /// request body stream (not when the server acknowledges it). `try_send`
    /// is best-effort: updates are silently dropped if the channel is full.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Granularity of progress reporting.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        // Materialize owned sub-chunks so the stream is 'static.
        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            // Best-effort progress notification.
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
947
948 pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
950 let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
951
952 let request = AbortChunkedUploadRequest {
953 session_id: session_id.to_string(),
954 };
955
956 let _ = self
957 .client
958 .delete(url)
959 .bearer_auth(&self.auth_token)
960 .json(&request)
961 .send()
962 .await;
963
964 Ok(())
965 }
966
    /// Finalize a completed chunked-upload session into a dataset record in
    /// `project`.
    async fn create_dataset_from_multipart(
        &self,
        project: &str,
        name: &str,
        key: &str,
        session_id: &str,
    ) -> Result<
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    > {
        let variables = create_dataset_from_multipart::Variables {
            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
                project: project.to_string(),
                name: name.to_string(),
                key: Some(key.to_string()),
                source: None,
                upload_session_id: session_id.to_string(),
            },
        };

        let response_data = self
            .execute_query(CreateDatasetFromMultipart, variables)
            .await?;
        Ok(response_data.create_dataset_from_multipart_upload)
    }
991
    /// Upload `dataset` to `project` in chunks, yielding [`UploadEvent`]s:
    /// progress updates during the transfer and a final `Complete` carrying
    /// the created dataset.
    ///
    /// Part count and chunk size are validated up front (so size errors are
    /// returned eagerly); the returned stream is otherwise lazy — no network
    /// activity happens until it is polled. On any failure the upload
    /// session is aborted best-effort before the error is yielded.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        project: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        let file_size = std::fs::metadata(dataset.as_ref())?.len();

        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Initial zero-byte event so consumers can render the total.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            let mut file = File::open(dataset.as_ref())?;
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            // Parts report their streamed byte counts over this channel.
            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            for part_number in 1..=total_parts {
                // NOTE(review): a single read() is assumed to fill the
                // chunk; a short read would shrink this part — confirm the
                // upload endpoint tolerates that.
                let bytes_read = file.read(&mut buffer)?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the upload while draining progress updates so the
                // consumer sees incremental byte counts. `biased` prefers
                // completion over further progress events.
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        biased;
                        result = &mut upload_fut => {
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };

                if let Err(e) = upload_result {
                    // Best-effort abort; surface the original error.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(project, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    // Finalization failed: abort the session, then report.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
1063
1064 pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
1067 use tokio::io::AsyncWriteExt;
1068
1069 let full_url = if url.starts_with("http://") || url.starts_with("https://") {
1070 Url::parse(url)?
1071 } else {
1072 self.rest_base_url.join(url)?
1073 };
1074
1075 let response = self
1076 .client
1077 .get(full_url)
1078 .bearer_auth(&self.auth_token)
1079 .send()
1080 .await?;
1081
1082 if !response.status().is_success() {
1083 return Err(AdaptiveError::HttpError(
1084 response.error_for_status().unwrap_err(),
1085 ));
1086 }
1087
1088 let mut file = tokio::fs::File::create(dest_path).await?;
1089 let mut stream = response.bytes_stream();
1090
1091 while let Some(chunk) = stream.next().await {
1092 let chunk = chunk?;
1093 file.write_all(&chunk).await?;
1094 }
1095
1096 file.flush().await?;
1097 Ok(())
1098 }
1099}