1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
/// One megabyte (MiB) in bytes.
const MEGABYTE: u64 = 1024 * 1024;

/// Smallest size accepted for a chunked upload; files below this are rejected
/// with [`AdaptiveError::FileTooSmall`].
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;

/// Largest chunk ever sent as a single part.
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;

/// Maximum number of parts a chunked upload session may contain.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size thresholds used to pick a chunk-size tier in
// `calculate_upload_parts` and `upload_bytes`.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// All error conditions surfaced by [`AdaptiveClient`] and the upload helpers.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    /// Transport-level failure from the underlying HTTP client.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// Failed to serialize a request or deserialize a response body.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Local file-system failure (reading upload sources, writing downloads).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A route could not be joined onto / parsed against the base URL.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// File is below the minimum size accepted for chunked upload.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// File cannot be uploaded even with maximum-size chunks.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The GraphQL server answered with errors and no data.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response contained neither data nor errors.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    /// A job lookup by id returned nothing.
    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    /// The REST endpoint rejected the chunked-upload init request.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    /// The REST endpoint rejected one uploaded part.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// Finalizing an upload session into a dataset failed.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),

    /// Non-success HTTP status on a GraphQL request.
    #[error("HTTP status error: {status} - {body}")]
    HttpStatusError { status: String, body: String },

    /// Response body was not valid JSON; `body` holds a truncated preview.
    #[error("Failed to parse JSON response: {error}. Body preview: {body}")]
    JsonParseError { error: String, body: String },
}
76
/// Crate-local result alias: every client operation fails with [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;
78
/// Byte-level progress of an in-flight chunked upload.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    // Bytes reported as sent so far.
    pub bytes_uploaded: u64,
    // Total size of the file being uploaded.
    pub total_bytes: u64,
}

/// Events emitted by [`AdaptiveClient::chunked_upload_dataset`]'s stream:
/// zero or more progress updates, then a single completion carrying the
/// created dataset.
#[derive(Debug)]
pub enum UploadEvent {
    Progress(ChunkedUploadProgress),
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
92
93pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
94 if file_size < MIN_CHUNK_SIZE_BYTES {
95 return Err(AdaptiveError::FileTooSmall {
96 size: file_size,
97 min_size: MIN_CHUNK_SIZE_BYTES,
98 });
99 }
100
101 let mut chunk_size = if file_size < SIZE_500MB {
102 5 * MEGABYTE
103 } else if file_size < SIZE_10GB {
104 10 * MEGABYTE
105 } else if file_size < SIZE_50GB {
106 50 * MEGABYTE
107 } else {
108 100 * MEGABYTE
109 };
110
111 let mut total_parts = file_size.div_ceil(chunk_size);
112
113 if total_parts > MAX_PARTS_COUNT {
114 chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
115
116 if chunk_size > MAX_CHUNK_SIZE_BYTES {
117 let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
118 return Err(AdaptiveError::FileTooLarge {
119 size: file_size,
120 max_size: max_file_size,
121 });
122 }
123
124 total_parts = file_size.div_ceil(chunk_size);
125 }
126
127 Ok((total_parts, chunk_size))
128}
129
// Type aliases mapping the schema's custom GraphQL scalars onto Rust types;
// the `GraphQLQuery` derive resolves scalar names against these identifiers.
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
type KeyInput = String;
138
/// Wrapper around [`SystemTime`] for the schema's timestamp scalar,
/// deserialized from milliseconds since the Unix epoch.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);

impl<'de> serde::Deserialize<'de> for Timestamp {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Delegate the millisecond parsing to the shared serde helper.
        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
        Ok(Timestamp(system_time))
    }
}
151
/// Page size used when paginating job listings.
const PAGE_SIZE: usize = 20;

/// Placeholder for the GraphQL `Upload` scalar: carries the index of the
/// multipart form part that holds the actual file bytes.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
156
/// GraphQL operation: list the custom recipes in a project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// GraphQL operation: fetch a single job by id.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// GraphQL operation: list jobs (paginated, filterable).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// GraphQL operation: cancel a job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// GraphQL operation: list the model services of one project.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// GraphQL operation: list all models across the deployment.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
204
205impl Display for get_job::JobStatus {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 match self {
208 get_job::JobStatus::PENDING => write!(f, "Pending"),
209 get_job::JobStatus::RUNNING => write!(f, "Running"),
210 get_job::JobStatus::COMPLETED => write!(f, "Completed"),
211 get_job::JobStatus::FAILED => write!(f, "Failed"),
212 get_job::JobStatus::CANCELED => write!(f, "Canceled"),
213 get_job::JobStatus::Other(_) => write!(f, "Unknown"),
214 }
215 }
216}
217
/// GraphQL operation: create a custom recipe from an uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// GraphQL operation: create a dataset from an uploaded file.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// GraphQL operation: turn a completed chunked-upload session into a dataset.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// GraphQL operation: create a job running a custom recipe.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// GraphQL operation: list all visible projects.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/projects.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListProjects;

/// GraphQL operation: list the available compute pools.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// GraphQL operation: fetch a custom recipe by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;

/// GraphQL operation: fetch a grader (response is also serializable).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/grader.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetGrader;

/// GraphQL operation: fetch a dataset by id or key.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetDataset;

/// GraphQL operation: fetch a model's configuration (response serializable).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/model_config.graphql",
    response_derives = "Debug, Clone, Serialize"
)]
pub struct GetModelConfig;

/// GraphQL operation: report a progress event for a job.
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job_progress.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UpdateJobProgress;
305
// REST routes of the chunked-upload API, joined onto `rest_base_url`.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
309
/// HTTP client for the Adaptive API: GraphQL for most operations plus a small
/// REST surface for chunked uploads and file downloads. Cheap to clone — the
/// inner reqwest client shares its connection pool across clones.
#[derive(Clone)]
pub struct AdaptiveClient {
    // Shared reqwest client.
    client: Client,
    // Base URL with "graphql" appended.
    graphql_url: Url,
    // Base URL that REST routes are joined onto.
    rest_base_url: Url,
    // Bearer token attached to every request.
    auth_token: String,
}
317
318impl AdaptiveClient {
319 pub fn new(api_base_url: Url, auth_token: String) -> Self {
320 let graphql_url = api_base_url
321 .join("graphql")
322 .expect("Failed to append graphql to base URL");
323
324 let client = Client::builder()
325 .user_agent(format!(
326 "adaptive-client-rust/{}",
327 env!("CARGO_PKG_VERSION")
328 ))
329 .build()
330 .expect("Failed to build HTTP client");
331
332 Self {
333 client,
334 graphql_url,
335 rest_base_url: api_base_url,
336 auth_token,
337 }
338 }
339
    /// Execute a typed GraphQL operation against the GraphQL endpoint.
    ///
    /// The `_query` value is only used to select the [`GraphQLQuery`] impl;
    /// the request body is built from `variables`.
    ///
    /// # Errors
    /// * [`AdaptiveError::HttpStatusError`] on a non-2xx HTTP status.
    /// * [`AdaptiveError::JsonParseError`] when the body is not valid JSON
    ///   (only the first 500 characters are kept as a preview).
    /// * [`AdaptiveError::GraphQLErrors`] / [`AdaptiveError::NoGraphQLData`]
    ///   when the response carries no `data`.
    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
    where
        T: GraphQLQuery,
        T::Variables: serde::Serialize,
        T::ResponseData: DeserializeOwned,
    {
        let request_body = T::build_query(variables);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .json(&request_body)
            .send()
            .await?;

        // Capture the status before consuming the body so both are available
        // for error reporting.
        let status = response.status();
        let response_text = response.text().await?;

        if !status.is_success() {
            return Err(AdaptiveError::HttpStatusError {
                status: status.to_string(),
                body: response_text,
            });
        }

        let response_body: Response<T::ResponseData> = serde_json::from_str(&response_text)
            .map_err(|e| AdaptiveError::JsonParseError {
                error: e.to_string(),
                body: response_text.chars().take(500).collect(),
            })?;

        match response_body.data {
            Some(data) => Ok(data),
            None => {
                if let Some(errors) = response_body.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
382
383 pub async fn list_recipes(
384 &self,
385 project: &str,
386 ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
387 let variables = get_custom_recipes::Variables {
388 project: IdOrKey::from(project),
389 };
390
391 let response_data = self.execute_query(GetCustomRecipes, variables).await?;
392 Ok(response_data.custom_recipes)
393 }
394
395 pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
396 let variables = get_job::Variables { id: job_id };
397
398 let response_data = self.execute_query(GetJob, variables).await?;
399
400 match response_data.job {
401 Some(job) => Ok(job),
402 None => Err(AdaptiveError::JobNotFound(job_id)),
403 }
404 }
405
    /// Upload a dataset file and create a dataset record in a single GraphQL
    /// multipart request (graphql-multipart-request style: an `operations`
    /// JSON document plus a `map` binding form part "0" to `variables.file`).
    pub async fn upload_dataset<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        dataset: P,
    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
        let variables = upload_dataset::Variables {
            project: IdOrKey::from(project),
            // Placeholder index; the actual bytes travel as multipart part "0".
            file: Upload(0),
            name: Some(name.to_string()),
        };

        let operations = UploadDataset::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        let file_map = r#"{ "0": ["variables.file"] }"#;

        let dataset_file = reqwest::multipart::Part::file(dataset).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", dataset_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;

        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
            response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_dataset),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }

    /// Upload a recipe file and register it as a custom recipe, using the
    /// same multipart-request shape as [`Self::upload_dataset`].
    pub async fn publish_recipe<P: AsRef<Path>>(
        &self,
        project: &str,
        name: &str,
        key: &str,
        recipe: P,
    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
        let variables = publish_custom_recipe::Variables {
            project: IdOrKey::from(project),
            // Placeholder index; the actual bytes travel as multipart part "0".
            file: Upload(0),
            name: Some(name.to_string()),
            key: Some(key.to_string()),
        };

        let operations = PublishCustomRecipe::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        let file_map = r#"{ "0": ["variables.file"] }"#;

        let recipe_file = reqwest::multipart::Part::file(recipe).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", recipe_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;
        let response: Response<
            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
        > = response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_custom_recipe),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
499
    /// Create a job that runs the custom recipe `recipe_id` in `project`.
    ///
    /// * `parameters` — recipe arguments, forwarded as the job's `args` object.
    /// * `name` — optional display name for the job.
    /// * `compute_pool` — optional pool; `None` leaves the choice to the server.
    /// * `num_gpus` — number of GPUs to request.
    /// * `use_experimental_runner` — opt into the experimental runner.
    ///
    /// CPU / RAM / duration caps are not exposed here and are always sent as
    /// `None`.
    pub async fn run_recipe(
        &self,
        project: &str,
        recipe_id: &str,
        parameters: Map<String, Value>,
        name: Option<String>,
        compute_pool: Option<String>,
        num_gpus: u32,
        use_experimental_runner: bool,
    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
        let variables = run_custom_recipe::Variables {
            input: run_custom_recipe::JobInput {
                recipe: recipe_id.to_string(),
                project: project.to_string(),
                args: parameters,
                name,
                compute_pool,
                num_gpus: num_gpus as i64, // widening cast, cannot truncate
                use_experimental_runner,
                max_cpu: None,
                max_ram_gb: None,
                max_duration_secs: None,
            },
        };

        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
        Ok(response_data.create_job)
    }
528
529 pub async fn list_jobs(
530 &self,
531 project: Option<String>,
532 ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
533 let mut jobs = Vec::new();
534 let mut page = self.list_jobs_page(project.clone(), None).await?;
535 jobs.extend(page.nodes.iter().cloned());
536 while page.page_info.has_next_page {
537 page = self
538 .list_jobs_page(project.clone(), page.page_info.end_cursor)
539 .await?;
540 jobs.extend(page.nodes.iter().cloned());
541 }
542 Ok(jobs)
543 }
544
    /// Fetch one page of jobs.
    ///
    /// The filter is fixed: only CUSTOM-kind jobs in RUNNING or PENDING state
    /// are returned, `PAGE_SIZE` at a time, starting after the `after` cursor.
    async fn list_jobs_page(
        &self,
        project: Option<String>,
        after: Option<String>,
    ) -> Result<list_jobs::ListJobsJobs> {
        let variables = list_jobs::Variables {
            filter: Some(list_jobs::ListJobsFilterInput {
                project,
                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
                status: Some(vec![
                    list_jobs::JobStatus::RUNNING,
                    list_jobs::JobStatus::PENDING,
                ]),
                timerange: None,
                custom_recipes: None,
                artifacts: None,
            }),
            cursor: Some(list_jobs::CursorPageInput {
                first: Some(PAGE_SIZE as i64),
                after,
                before: None,
                last: None,
                offset: None,
            }),
        };

        let response_data = self.execute_query(ListJobs, variables).await?;
        Ok(response_data.jobs)
    }
574
    /// Request cancellation of the job with id `job_id`.
    pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
        let variables = cancel_job::Variables { job_id };

        let response_data = self.execute_query(CancelJob, variables).await?;
        Ok(response_data.cancel_job)
    }

    /// Report a progress event for a running job.
    pub async fn update_job_progress(
        &self,
        job_id: Uuid,
        event: update_job_progress::JobProgressEventInput,
    ) -> Result<update_job_progress::UpdateJobProgressUpdateJobProgress> {
        let variables = update_job_progress::Variables { job_id, event };

        let response_data = self.execute_query(UpdateJobProgress, variables).await?;
        Ok(response_data.update_job_progress)
    }
592
593 pub async fn list_models(
594 &self,
595 project: String,
596 ) -> Result<Vec<list_models::ListModelsProjectModelServices>> {
597 let variables = list_models::Variables { project };
598
599 let response_data = self.execute_query(ListModels, variables).await?;
600 Ok(response_data
601 .project
602 .map(|project| project.model_services)
603 .unwrap_or(Vec::new()))
604 }
605
    /// List every model registered on the deployment, across projects.
    pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
        let variables = list_all_models::Variables {};

        let response_data = self.execute_query(ListAllModels, variables).await?;
        Ok(response_data.models)
    }

    /// List all projects visible to the authenticated user.
    pub async fn list_projects(&self) -> Result<Vec<list_projects::ListProjectsProjects>> {
        let variables = list_projects::Variables {};

        let response_data = self.execute_query(ListProjects, variables).await?;
        Ok(response_data.projects)
    }

    /// List the available compute pools.
    pub async fn list_pools(
        &self,
    ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
        let variables = list_compute_pools::Variables {};

        let response_data = self.execute_query(ListComputePools, variables).await?;
        Ok(response_data.compute_pools)
    }

    /// Fetch a custom recipe by id or key; `None` when it does not exist.
    pub async fn get_recipe(
        &self,
        project: String,
        id_or_key: String,
    ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
        let variables = get_recipe::Variables { project, id_or_key };

        let response_data = self.execute_query(GetRecipe, variables).await?;
        Ok(response_data.custom_recipe)
    }

    /// Fetch a grader by id or key within `project`.
    pub async fn get_grader(
        &self,
        id_or_key: &str,
        project: &str,
    ) -> Result<get_grader::GetGraderGrader> {
        let variables = get_grader::Variables {
            id: id_or_key.to_string(),
            project: project.to_string(),
        };

        let response_data = self.execute_query(GetGrader, variables).await?;
        Ok(response_data.grader)
    }

    /// Fetch a dataset by id or key; `None` when it does not exist.
    pub async fn get_dataset(
        &self,
        id_or_key: &str,
        project: &str,
    ) -> Result<Option<get_dataset::GetDatasetDataset>> {
        let variables = get_dataset::Variables {
            id_or_key: id_or_key.to_string(),
            project: project.to_string(),
        };

        let response_data = self.execute_query(GetDataset, variables).await?;
        Ok(response_data.dataset)
    }

    /// Fetch a model's configuration by id or key; `None` when it does not
    /// exist.
    pub async fn get_model_config(
        &self,
        id_or_key: &str,
    ) -> Result<Option<get_model_config::GetModelConfigModel>> {
        let variables = get_model_config::Variables {
            id_or_key: id_or_key.to_string(),
        };

        let response_data = self.execute_query(GetModelConfig, variables).await?;
        Ok(response_data.model)
    }

    /// The REST base URL this client was constructed with.
    pub fn base_url(&self) -> &Url {
        &self.rest_base_url
    }
683
684 pub async fn upload_bytes(&self, data: &[u8], content_type: &str) -> Result<String> {
687 let file_size = data.len() as u64;
688
689 let chunk_size = if file_size < 5 * 1024 * 1024 {
691 file_size.max(1)
693 } else if file_size < 500 * 1024 * 1024 {
694 5 * 1024 * 1024
695 } else if file_size < 10 * 1024 * 1024 * 1024 {
696 10 * 1024 * 1024
697 } else {
698 100 * 1024 * 1024
699 };
700
701 let total_parts = file_size.div_ceil(chunk_size).max(1);
702
703 let session_id = self
705 .init_chunked_upload_with_content_type(total_parts, content_type)
706 .await?;
707
708 for part_number in 1..=total_parts {
710 let start = ((part_number - 1) * chunk_size) as usize;
711 let end = (part_number * chunk_size).min(file_size) as usize;
712 let chunk = data[start..end].to_vec();
713
714 if let Err(e) = self
715 .upload_part_simple(&session_id, part_number, chunk)
716 .await
717 {
718 let _ = self.abort_chunked_upload(&session_id).await;
719 return Err(e);
720 }
721 }
722
723 Ok(session_id)
724 }
725
    /// Start a chunked upload session expecting `total_parts` parts of the
    /// given `content_type`, returning the new session id.
    ///
    /// # Errors
    /// [`AdaptiveError::ChunkedUploadInitFailed`] on a non-success status,
    /// with the response body attached for diagnostics.
    pub async fn init_chunked_upload_with_content_type(
        &self,
        total_parts: u64,
        content_type: &str,
    ) -> Result<String> {
        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;

        let request = InitChunkedUploadRequest {
            content_type: content_type.to_string(),
            metadata: None,
            total_parts_count: total_parts,
        };

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadInitFailed {
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        let init_response: InitChunkedUploadResponse = response.json().await?;
        Ok(init_response.session_id)
    }
758
    /// Upload one part of a chunked session as a single in-memory body, with
    /// no progress reporting (see [`Self::upload_part`] for the streaming
    /// variant).
    pub async fn upload_part_simple(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
    ) -> Result<()> {
        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            // The session and part are identified via query parameters; the
            // body is the raw part bytes.
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(data)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
791
792 async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
793 let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
794
795 let request = InitChunkedUploadRequest {
796 content_type: "application/jsonl".to_string(),
797 metadata: None,
798 total_parts_count: total_parts,
799 };
800
801 let response = self
802 .client
803 .post(url)
804 .bearer_auth(&self.auth_token)
805 .json(&request)
806 .send()
807 .await?;
808
809 if !response.status().is_success() {
810 return Err(AdaptiveError::ChunkedUploadInitFailed {
811 status: response.status().to_string(),
812 body: response.text().await.unwrap_or_default(),
813 });
814 }
815
816 let init_response: InitChunkedUploadResponse = response.json().await?;
817 Ok(init_response.session_id)
818 }
819
    /// Upload one part of a chunked session, streaming the body in 64 KiB
    /// sub-chunks and reporting each sub-chunk's size on `progress_tx` as the
    /// stream yields it.
    ///
    /// NOTE(review): progress uses `try_send`, so updates are silently dropped
    /// when the channel is full — reported byte totals may undercount. Confirm
    /// consumers treat this as display-only.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Granularity of progress reporting.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        // Wrap the sub-chunks in a stream so reqwest pulls them lazily; each
        // one reports its size when produced.
        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
868
869 pub async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
871 let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
872
873 let request = AbortChunkedUploadRequest {
874 session_id: session_id.to_string(),
875 };
876
877 let _ = self
878 .client
879 .delete(url)
880 .bearer_auth(&self.auth_token)
881 .json(&request)
882 .send()
883 .await;
884
885 Ok(())
886 }
887
    /// Finalize a completed chunked-upload session into a dataset via GraphQL.
    async fn create_dataset_from_multipart(
        &self,
        project: &str,
        name: &str,
        key: &str,
        session_id: &str,
    ) -> Result<
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    > {
        let variables = create_dataset_from_multipart::Variables {
            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
                project: project.to_string(),
                name: name.to_string(),
                key: Some(key.to_string()),
                source: None,
                upload_session_id: session_id.to_string(),
            },
        };

        let response_data = self
            .execute_query(CreateDatasetFromMultipart, variables)
            .await?;
        Ok(response_data.create_dataset_from_multipart_upload)
    }
912
    /// Upload a dataset file via the chunked REST API, yielding progress
    /// events, then register it as a dataset through GraphQL.
    ///
    /// Returns a stream emitting [`UploadEvent::Progress`] items as bytes are
    /// sent and finishing with one [`UploadEvent::Complete`]. On failure the
    /// upload session is aborted (best effort) before the error surfaces
    /// through the stream.
    ///
    /// NOTE(review): the file is read with blocking `std::fs` I/O from inside
    /// the async stream — presumably acceptable for this client's call sites;
    /// confirm if it runs on a shared runtime.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        project: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        let file_size = std::fs::metadata(dataset.as_ref())?.len();

        // Validate the size and fix the part layout before the stream is
        // even constructed, so size errors surface synchronously.
        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Initial 0% event so consumers can render immediately.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            let mut file = File::open(dataset.as_ref())?;
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            // Upload tasks report byte counts on this channel so the stream
            // can interleave progress events with the upload itself.
            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            for part_number in 1..=total_parts {
                let bytes_read = file.read(&mut buffer)?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the upload while draining progress messages; `biased`
                // makes completion win once both branches are ready.
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        biased;
                        result = &mut upload_fut => {
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };

                if let Err(e) = upload_result {
                    // Best-effort cleanup; the part error takes precedence.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(project, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    // Creation failed after all parts uploaded: abort so the
                    // server can reclaim the stored parts.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
984
    /// Download `url` to `dest_path`, streaming the body to disk.
    ///
    /// Absolute `http://` / `https://` URLs are used as-is; anything else is
    /// treated as a path relative to the client's REST base URL.
    pub async fn download_file_to_path(&self, url: &str, dest_path: &Path) -> Result<()> {
        use tokio::io::AsyncWriteExt;

        let full_url = if url.starts_with("http://") || url.starts_with("https://") {
            Url::parse(url)?
        } else {
            self.rest_base_url.join(url)?
        };

        let response = self
            .client
            .get(full_url)
            .bearer_auth(&self.auth_token)
            .send()
            .await?;

        if !response.status().is_success() {
            // error_for_status() is guaranteed Err here because the status is
            // non-success, so unwrap_err cannot panic.
            return Err(AdaptiveError::HttpError(
                response.error_for_status().unwrap_err(),
            ));
        }

        // Stream chunks straight to the file instead of buffering the whole
        // body in memory.
        let mut file = tokio::fs::File::create(dest_path).await?;
        let mut stream = response.bytes_stream();

        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            file.write_all(&chunk).await?;
        }

        file.flush().await?;
        Ok(())
    }
1020}