1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
const MEGABYTE: u64 = 1024 * 1024;
/// Smallest part size the chunked-upload flow will use (5 MiB).
/// NOTE(review): these limits look like S3-style multipart limits
/// (5 MiB min part, 10,000 parts) -- confirm against the backend.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
/// Largest part size ever chosen by `calculate_upload_parts`.
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
/// Hard cap on the number of parts per upload session.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size thresholds used to pick a chunk-size tier in
// `calculate_upload_parts`.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// All errors this client can produce.
///
/// Transport (`reqwest`), serialization (`serde_json`), filesystem
/// (`std::io`) and URL-building errors are converted automatically via
/// `#[from]`; the remaining variants are raised explicitly by this crate.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    /// HTTP transport failure from `reqwest`.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// JSON (de)serialization failure.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Local filesystem / I/O failure (reading the dataset file, etc.).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A REST route could not be joined onto the base URL.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// File is below the minimum size for a chunked upload.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// File cannot be covered by the maximum part count/size.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The GraphQL response carried errors instead of data.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response had neither `data` nor `errors`.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    /// The server returned no job for the requested id.
    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    /// The REST `upload/init` call returned a non-success status.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    /// The REST `upload/part` call returned a non-success status.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// Finalizing a chunked upload into a dataset failed.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),
}
70
/// Crate-local alias so fallible APIs default to [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;

/// Snapshot of how far a chunked upload has progressed, in bytes.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    /// Bytes handed to the HTTP client so far (counted as 64 KiB
    /// sub-chunks are streamed out, not as the server acknowledges them).
    pub bytes_uploaded: u64,
    /// Total size of the file being uploaded.
    pub total_bytes: u64,
}
78
/// Events yielded by the stream returned from
/// `AdaptiveClient::chunked_upload_dataset`.
#[derive(Debug)]
pub enum UploadEvent {
    /// Incremental progress while part bytes are being streamed.
    Progress(ChunkedUploadProgress),
    /// Terminal event: the dataset was created from the uploaded parts.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
86
87pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
88 if file_size < MIN_CHUNK_SIZE_BYTES {
89 return Err(AdaptiveError::FileTooSmall {
90 size: file_size,
91 min_size: MIN_CHUNK_SIZE_BYTES,
92 });
93 }
94
95 let mut chunk_size = if file_size < SIZE_500MB {
96 5 * MEGABYTE
97 } else if file_size < SIZE_10GB {
98 10 * MEGABYTE
99 } else if file_size < SIZE_50GB {
100 50 * MEGABYTE
101 } else {
102 100 * MEGABYTE
103 };
104
105 let mut total_parts = file_size.div_ceil(chunk_size);
106
107 if total_parts > MAX_PARTS_COUNT {
108 chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
109
110 if chunk_size > MAX_CHUNK_SIZE_BYTES {
111 let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
112 return Err(AdaptiveError::FileTooLarge {
113 size: file_size,
114 max_size: max_file_size,
115 });
116 }
117
118 total_parts = file_size.div_ceil(chunk_size);
119 }
120
121 Ok((total_parts, chunk_size))
122}
123
// Rust-side bindings for the custom GraphQL scalar types referenced by
// the generated query modules (graphql_client resolves scalars by name,
// so these aliases must match the scalar names in `schema.gql`).
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
131
/// Wrapper around [`SystemTime`] that deserializes from a
/// millisecond-precision timestamp (see `serde_utils`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);

impl<'de> serde::Deserialize<'de> for Timestamp {
    // Delegates to the shared helper that parses epoch milliseconds.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
        Ok(Timestamp(system_time))
    }
}
144
/// Number of jobs requested per page when listing jobs.
const PAGE_SIZE: usize = 20;

/// Placeholder for the GraphQL `Upload` scalar. The wrapped index names
/// the multipart form part carrying the file (see the `map` field built
/// in `upload_dataset`/`publish_recipe`).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
149
/// Query: list custom recipes for a use case (`src/graphql/list.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Query: fetch a single job by id (`src/graphql/job.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Query: paginated job listing (`src/graphql/jobs.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Mutation: cancel a job (`src/graphql/cancel.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Query: model services within one use case (`src/graphql/models.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Query: all models across use cases (`src/graphql/all_models.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
197
198impl Display for get_job::JobStatus {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 match self {
201 get_job::JobStatus::PENDING => write!(f, "Pending"),
202 get_job::JobStatus::RUNNING => write!(f, "Running"),
203 get_job::JobStatus::COMPLETED => write!(f, "Completed"),
204 get_job::JobStatus::FAILED => write!(f, "Failed"),
205 get_job::JobStatus::CANCELED => write!(f, "Canceled"),
206 get_job::JobStatus::Other(_) => write!(f, "Unknown"),
207 }
208 }
209}
210
/// Mutation: publish a custom recipe (`src/graphql/publish.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Mutation: create a dataset from a single multipart file upload
/// (`src/graphql/upload_dataset.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Mutation: create a dataset from a completed chunked-upload session
/// (`src/graphql/create_dataset_from_multipart.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Mutation: start a job running a custom recipe (`src/graphql/run.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Query: list use cases (`src/graphql/usecases.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/usecases.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListUseCases;

/// Query: list compute pools (`src/graphql/pools.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Query: fetch one custom recipe by id or key (`src/graphql/recipe.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;
266
// REST routes (relative to the API base URL) for the chunked-upload flow.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";

/// Client for the Adaptive API: GraphQL operations plus the REST
/// chunked-upload endpoints, authenticated with a bearer token.
pub struct AdaptiveClient {
    client: Client,
    /// `<base>/graphql`, computed once in `AdaptiveClient::new`.
    graphql_url: Url,
    /// Base the REST routes above are joined onto per request.
    rest_base_url: Url,
    auth_token: String,
}
277
278impl AdaptiveClient {
    /// Creates a client for the API at `api_base_url`, authenticating
    /// every request with `auth_token` as a bearer token.
    ///
    /// # Panics
    /// Panics if `"graphql"` cannot be joined onto `api_base_url`.
    ///
    /// NOTE(review): `Url::join` replaces the last path segment when the
    /// base does not end in `/` (e.g. `https://host/api` would become
    /// `https://host/graphql`) -- confirm callers always pass a base URL
    /// with a trailing slash.
    pub fn new(api_base_url: Url, auth_token: String) -> Self {
        let graphql_url = api_base_url
            .join("graphql")
            .expect("Failed to append graphql to base URL");

        Self {
            client: Client::new(),
            graphql_url,
            rest_base_url: api_base_url,
            auth_token,
        }
    }
291
292 async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
293 where
294 T: GraphQLQuery,
295 T::Variables: serde::Serialize,
296 T::ResponseData: DeserializeOwned,
297 {
298 let request_body = T::build_query(variables);
299
300 let response = self
301 .client
302 .post(self.graphql_url.clone())
303 .bearer_auth(&self.auth_token)
304 .json(&request_body)
305 .send()
306 .await?;
307
308 let response_body: Response<T::ResponseData> = response.json().await?;
309
310 match response_body.data {
311 Some(data) => Ok(data),
312 None => {
313 if let Some(errors) = response_body.errors {
314 return Err(AdaptiveError::GraphQLErrors(errors));
315 }
316 Err(AdaptiveError::NoGraphQLData)
317 }
318 }
319 }
320
321 pub async fn list_recipes(
322 &self,
323 usecase: &str,
324 ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
325 let variables = get_custom_recipes::Variables {
326 usecase: IdOrKey::from(usecase),
327 };
328
329 let response_data = self.execute_query(GetCustomRecipes, variables).await?;
330 Ok(response_data.custom_recipes)
331 }
332
333 pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
334 let variables = get_job::Variables { id: job_id };
335
336 let response_data = self.execute_query(GetJob, variables).await?;
337
338 match response_data.job {
339 Some(job) => Ok(job),
340 None => Err(AdaptiveError::JobNotFound(job_id)),
341 }
342 }
343
344 pub async fn upload_dataset<P: AsRef<Path>>(
345 &self,
346 usecase: &str,
347 name: &str,
348 dataset: P,
349 ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
350 let variables = upload_dataset::Variables {
351 usecase: IdOrKey::from(usecase),
352 file: Upload(0),
353 name: Some(name.to_string()),
354 };
355
356 let operations = UploadDataset::build_query(variables);
357 let operations = serde_json::to_string(&operations)?;
358
359 let file_map = r#"{ "0": ["variables.file"] }"#;
360
361 let dataset_file = reqwest::multipart::Part::file(dataset).await?;
362
363 let form = reqwest::multipart::Form::new()
364 .text("operations", operations)
365 .text("map", file_map)
366 .part("0", dataset_file);
367
368 let response = self
369 .client
370 .post(self.graphql_url.clone())
371 .bearer_auth(&self.auth_token)
372 .multipart(form)
373 .send()
374 .await?;
375
376 let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
377 response.json().await?;
378
379 match response.data {
380 Some(data) => Ok(data.create_dataset),
381 None => {
382 if let Some(errors) = response.errors {
383 return Err(AdaptiveError::GraphQLErrors(errors));
384 }
385 Err(AdaptiveError::NoGraphQLData)
386 }
387 }
388 }
389
    /// Publishes a recipe file via a GraphQL multipart request (same
    /// `operations`/`map`/file-part layout as `upload_dataset` -- the two
    /// methods intentionally mirror each other).
    ///
    /// `key` is the identifier the recipe is stored under; `name` is its
    /// display name.
    pub async fn publish_recipe<P: AsRef<Path>>(
        &self,
        usecase: &str,
        name: &str,
        key: &str,
        recipe: P,
    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
        let variables = publish_custom_recipe::Variables {
            usecase: IdOrKey::from(usecase),
            file: Upload(0),
            name: Some(name.to_string()),
            key: Some(key.to_string()),
        };

        let operations = PublishCustomRecipe::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the query's `file` variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let recipe_file = reqwest::multipart::Part::file(recipe).await?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", recipe_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;
        let response: Response<
            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
        > = response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_custom_recipe),
            None => {
                if let Some(errors) = response.errors {
                    return Err(AdaptiveError::GraphQLErrors(errors));
                }
                Err(AdaptiveError::NoGraphQLData)
            }
        }
    }
437
    /// Starts a job that runs `recipe_id` in `usecase`.
    ///
    /// `parameters` is forwarded verbatim as the job's `args`; `name` and
    /// `compute_pool` are optional. `num_gpus` is widened to the schema's
    /// `i64` (lossless from `u32`).
    pub async fn run_recipe(
        &self,
        usecase: &str,
        recipe_id: &str,
        parameters: Map<String, Value>,
        name: Option<String>,
        compute_pool: Option<String>,
        num_gpus: u32,
    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
        let variables = run_custom_recipe::Variables {
            input: run_custom_recipe::JobInput {
                recipe: recipe_id.to_string(),
                use_case: usecase.to_string(),
                args: parameters,
                name,
                compute_pool,
                num_gpus: num_gpus as i64,
            },
        };

        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
        Ok(response_data.create_job)
    }
461
462 pub async fn list_jobs(
463 &self,
464 usecase: Option<String>,
465 ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
466 let mut jobs = Vec::new();
467 let mut page = self.list_jobs_page(usecase.clone(), None).await?;
468 jobs.extend(page.nodes.iter().cloned());
469 while page.page_info.has_next_page {
470 page = self
471 .list_jobs_page(usecase.clone(), page.page_info.end_cursor)
472 .await?;
473 jobs.extend(page.nodes.iter().cloned());
474 }
475 Ok(jobs)
476 }
477
    /// Fetches one page (up to `PAGE_SIZE` entries) of jobs.
    ///
    /// The filter is fixed: only custom-recipe jobs (`JobKind::CUSTOM`)
    /// that are currently `RUNNING` or `PENDING` are returned --
    /// completed, failed and canceled jobs are excluded. `after` is the
    /// cursor from the previous page's `page_info.end_cursor`.
    async fn list_jobs_page(
        &self,
        usecase: Option<String>,
        after: Option<String>,
    ) -> Result<list_jobs::ListJobsJobs> {
        let variables = list_jobs::Variables {
            filter: Some(list_jobs::ListJobsFilterInput {
                use_case: usecase,
                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
                status: Some(vec![
                    list_jobs::JobStatus::RUNNING,
                    list_jobs::JobStatus::PENDING,
                ]),
                timerange: None,
                custom_recipes: None,
                artifacts: None,
            }),
            cursor: Some(list_jobs::CursorPageInput {
                first: Some(PAGE_SIZE as i64),
                after,
                before: None,
                last: None,
                offset: None,
            }),
        };

        let response_data = self.execute_query(ListJobs, variables).await?;
        Ok(response_data.jobs)
    }
507
508 pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
509 let variables = cancel_job::Variables { job_id };
510
511 let response_data = self.execute_query(CancelJob, variables).await?;
512 Ok(response_data.cancel_job)
513 }
514
515 pub async fn list_models(
516 &self,
517 usecase: String,
518 ) -> Result<Vec<list_models::ListModelsUseCaseModelServices>> {
519 let variables = list_models::Variables {
520 use_case_id: usecase,
521 };
522
523 let response_data = self.execute_query(ListModels, variables).await?;
524 Ok(response_data
525 .use_case
526 .map(|use_case| use_case.model_services)
527 .unwrap_or(Vec::new()))
528 }
529
530 pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
531 let variables = list_all_models::Variables {};
532
533 let response_data = self.execute_query(ListAllModels, variables).await?;
534 Ok(response_data.models)
535 }
536
537 pub async fn list_usecases(&self) -> Result<Vec<list_use_cases::ListUseCasesUseCases>> {
538 let variables = list_use_cases::Variables {};
539
540 let response_data = self.execute_query(ListUseCases, variables).await?;
541 Ok(response_data.use_cases)
542 }
543
544 pub async fn list_pools(
545 &self,
546 ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
547 let variables = list_compute_pools::Variables {};
548
549 let response_data = self.execute_query(ListComputePools, variables).await?;
550 Ok(response_data.compute_pools)
551 }
552
553 pub async fn get_recipe(
554 &self,
555 usecase: String,
556 id_or_key: String,
557 ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
558 let variables = get_recipe::Variables { usecase, id_or_key };
559
560 let response_data = self.execute_query(GetRecipe, variables).await?;
561 Ok(response_data.custom_recipe)
562 }
563
    /// Opens a chunked-upload session on the REST API and returns the
    /// server-assigned session id.
    ///
    /// NOTE(review): the content type is hard-coded to
    /// `application/jsonl`, so this flow assumes JSONL datasets --
    /// confirm before reusing it for other payload formats.
    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;

        let request = InitChunkedUploadRequest {
            content_type: "application/jsonl".to_string(),
            metadata: None,
            total_parts_count: total_parts,
        };

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await?;

        // Surface non-2xx responses with the status and body for debugging.
        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadInitFailed {
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        let init_response: InitChunkedUploadResponse = response.json().await?;
        Ok(init_response.session_id)
    }
591
    /// Streams one part of a chunked upload to the REST `upload/part`
    /// route.
    ///
    /// The part's bytes are split into 64 KiB sub-chunks and sent as a
    /// streaming request body; as each sub-chunk is handed to the HTTP
    /// client, its length is reported on `progress_tx`.
    ///
    /// NOTE(review): `try_send` silently drops a progress update when the
    /// channel is full, so reported progress is best-effort.
    /// NOTE(review): collecting `chunks` copies the whole part, roughly
    /// doubling peak memory per part (up to ~100 MiB extra) -- consider a
    /// zero-copy body if this matters.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Granularity of progress reporting.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            // Best-effort: a full channel drops this update.
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
640
    /// Best-effort abort of a chunked-upload session.
    ///
    /// The DELETE result is intentionally ignored (including transport
    /// errors): this is cleanup on a path that is already failing, so the
    /// method always returns `Ok(())`.
    async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;

        let request = AbortChunkedUploadRequest {
            session_id: session_id.to_string(),
        };

        let _ = self
            .client
            .delete(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await;

        Ok(())
    }
658
    /// Finalizes a chunked upload: asks the GraphQL API to create a
    /// dataset from the parts uploaded under `session_id`.
    async fn create_dataset_from_multipart(
        &self,
        usecase: &str,
        name: &str,
        key: &str,
        session_id: &str,
    ) -> Result<
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    > {
        let variables = create_dataset_from_multipart::Variables {
            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
                use_case: usecase.to_string(),
                name: name.to_string(),
                key: Some(key.to_string()),
                source: None,
                upload_session_id: session_id.to_string(),
            },
        };

        let response_data = self
            .execute_query(CreateDatasetFromMultipart, variables)
            .await?;
        Ok(response_data.create_dataset_from_multipart_upload)
    }
683
684 pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
685 &'a self,
686 usecase: &'a str,
687 name: &'a str,
688 key: &'a str,
689 dataset: P,
690 ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
691 let file_size = std::fs::metadata(dataset.as_ref())?.len();
692
693 let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;
694
695 let stream = async_stream::try_stream! {
696 yield UploadEvent::Progress(ChunkedUploadProgress {
697 bytes_uploaded: 0,
698 total_bytes: file_size,
699 });
700
701 let session_id = self.init_chunked_upload(total_parts).await?;
702
703 let mut file = File::open(dataset.as_ref())?;
704 let mut buffer = vec![0u8; chunk_size as usize];
705 let mut bytes_uploaded = 0u64;
706
707 let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);
708
709 for part_number in 1..=total_parts {
710 let bytes_read = file.read(&mut buffer)?;
711 let chunk_data = buffer[..bytes_read].to_vec();
712
713 let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
714 tokio::pin!(upload_fut);
715
716 let upload_result: Result<()> = loop {
717 tokio::select! {
718 biased;
719 result = &mut upload_fut => {
720 break result;
721 }
722 Some(bytes) = progress_rx.recv() => {
723 bytes_uploaded += bytes;
724 yield UploadEvent::Progress(ChunkedUploadProgress {
725 bytes_uploaded,
726 total_bytes: file_size,
727 });
728 }
729 }
730 };
731
732 if let Err(e) = upload_result {
733 let _ = self.abort_chunked_upload(&session_id).await;
734 Err(e)?;
735 }
736 }
737
738 let create_result = self
739 .create_dataset_from_multipart(usecase, name, key, &session_id)
740 .await;
741
742 match create_result {
743 Ok(response) => {
744 yield UploadEvent::Complete(response);
745 }
746 Err(e) => {
747 let _ = self.abort_chunked_upload(&session_id).await;
748 Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
749 }
750 }
751 };
752
753 Ok(Box::pin(stream))
754 }
755}