1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use tokio::sync::mpsc;
5
6use anyhow::{Context, Result, anyhow, bail};
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
const MEGABYTE: u64 = 1024 * 1024;

/// Smallest file size accepted for a chunked upload (5 MiB); also the
/// smallest chunk size used by `calculate_upload_parts`.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
/// Largest size a single upload part may have (100 MiB).
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
/// Maximum number of parts one chunked upload may be split into.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size tier boundaries used by `calculate_upload_parts` to pick a
// default chunk size before applying the part-count cap.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Progress snapshot emitted while a chunked upload is running.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    // Bytes handed to the upload stream so far (see `UploadEvent::Progress`).
    pub bytes_uploaded: u64,
    // Total size of the file being uploaded, from filesystem metadata.
    pub total_bytes: u64,
}
33
/// Events yielded by the stream returned from
/// `AdaptiveClient::chunked_upload_dataset`.
#[derive(Debug)]
pub enum UploadEvent {
    /// Intermediate progress update; emitted repeatedly during the upload.
    Progress(ChunkedUploadProgress),
    /// Terminal event: the dataset was created from the uploaded parts.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
41
42pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
43 if file_size < MIN_CHUNK_SIZE_BYTES {
44 bail!(
45 "File size ({} bytes) is too small for chunked upload",
46 file_size
47 );
48 }
49
50 let mut chunk_size = if file_size < SIZE_500MB {
51 5 * MEGABYTE
52 } else if file_size < SIZE_10GB {
53 10 * MEGABYTE
54 } else if file_size < SIZE_50GB {
55 50 * MEGABYTE
56 } else {
57 100 * MEGABYTE
58 };
59
60 let mut total_parts = file_size.div_ceil(chunk_size);
61
62 if total_parts > MAX_PARTS_COUNT {
63 chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
64
65 if chunk_size > MAX_CHUNK_SIZE_BYTES {
66 let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
67 bail!(
68 "File size ({} bytes) exceeds maximum uploadable size ({} bytes = {} parts * {} bytes)",
69 file_size,
70 max_file_size,
71 MAX_PARTS_COUNT,
72 MAX_CHUNK_SIZE_BYTES
73 );
74 }
75
76 total_parts = file_size.div_ceil(chunk_size);
77 }
78
79 Ok((total_parts, chunk_size))
80}
81
// Type aliases for custom GraphQL scalars: the `#[derive(GraphQLQuery)]`
// code below resolves scalar names from the schema to in-scope Rust types
// with these exact names, hence the non-idiomatic casing.
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
89
/// Wrapper around [`SystemTime`] so the API's millisecond timestamps can be
/// deserialized via `serde_utils::deserialize_timestamp_millis`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);
92
93impl<'de> serde::Deserialize<'de> for Timestamp {
94 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
95 where
96 D: serde::Deserializer<'de>,
97 {
98 let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
99 Ok(Timestamp(system_time))
100 }
101}
102
// Jobs requested per page when paginating through `list_jobs_page`.
const PAGE_SIZE: usize = 20;
104
/// Stand-in for the GraphQL `Upload` scalar: serializes to an index that the
/// multipart `map` field (e.g. `{ "0": ["variables.file"] }`) associates
/// with a file part of the same name in the form.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
107
// Marker types for `graphql_client`: each derive generates a snake_case
// module (e.g. `get_job`) with `Variables` and `ResponseData` for the
// operation found at `query_path`.

/// Lists custom recipes for a use case (src/graphql/list.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Fetches a single job by id (src/graphql/job.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Lists jobs with filtering and cursor pagination (src/graphql/jobs.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Cancels a job by id (src/graphql/cancel.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Lists model services of one use case (src/graphql/models.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Lists all models across use cases (src/graphql/all_models.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
155
156impl Display for get_job::JobStatus {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 match self {
159 get_job::JobStatus::PENDING => write!(f, "Pending"),
160 get_job::JobStatus::RUNNING => write!(f, "Running"),
161 get_job::JobStatus::COMPLETED => write!(f, "Completed"),
162 get_job::JobStatus::FAILED => write!(f, "Failed"),
163 get_job::JobStatus::CANCELED => write!(f, "Canceled"),
164 get_job::JobStatus::Other(_) => write!(f, "Unknown"),
165 }
166 }
167}
168
/// Creates a custom recipe from an uploaded file (src/graphql/publish.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Creates a dataset from a single multipart-form file upload
/// (src/graphql/upload_dataset.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Creates a dataset from a completed chunked-upload session
/// (src/graphql/create_dataset_from_multipart.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Starts a job running a custom recipe (src/graphql/run.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Lists available use cases (src/graphql/usecases.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/usecases.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListUseCases;

/// Lists compute pools (src/graphql/pools.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Fetches one custom recipe by id or key (src/graphql/recipe.graphql).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;
224
// REST routes (relative to the API base URL) for the chunked-upload flow:
// initialize a session, upload one part, abort the session.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
228
/// Client for the Adaptive API: GraphQL operations plus the REST
/// chunked-upload endpoints.
pub struct AdaptiveClient {
    client: Client,
    // Base URL with "graphql" appended; target of all GraphQL requests.
    graphql_url: Url,
    // Base URL used to build the REST upload routes.
    rest_base_url: Url,
    // Bearer token attached to every request.
    auth_token: String,
}
235
236impl AdaptiveClient {
237 pub fn new(api_base_url: Url, auth_token: String) -> Self {
238 let graphql_url = api_base_url
239 .join("graphql")
240 .expect("Failed to append graphql to base URL");
241
242 Self {
243 client: Client::new(),
244 graphql_url,
245 rest_base_url: api_base_url,
246 auth_token,
247 }
248 }
249
250 async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
251 where
252 T: GraphQLQuery,
253 T::Variables: serde::Serialize,
254 T::ResponseData: DeserializeOwned,
255 {
256 let request_body = T::build_query(variables);
257
258 let response = self
259 .client
260 .post(self.graphql_url.clone())
261 .bearer_auth(&self.auth_token)
262 .json(&request_body)
263 .send()
264 .await?;
265
266 let response_body: Response<T::ResponseData> = response.json().await?;
267
268 match response_body.data {
269 Some(data) => Ok(data),
270 None => {
271 if let Some(errors) = response_body.errors {
272 bail!("GraphQL errors: {:?}", errors);
273 }
274 Err(anyhow!("No data returned from GraphQL query"))
275 }
276 }
277 }
278
279 pub async fn list_recipes(
280 &self,
281 usecase: &str,
282 ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
283 let variables = get_custom_recipes::Variables {
284 usecase: IdOrKey::from(usecase),
285 };
286
287 let response_data = self.execute_query(GetCustomRecipes, variables).await?;
288 Ok(response_data.custom_recipes)
289 }
290
291 pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
292 let variables = get_job::Variables { id: job_id };
293
294 let response_data = self.execute_query(GetJob, variables).await?;
295
296 match response_data.job {
297 Some(job) => Ok(job),
298 None => Err(anyhow!("Job with ID '{}' not found", job_id)),
299 }
300 }
301
302 pub async fn upload_dataset<P: AsRef<Path>>(
303 &self,
304 usecase: &str,
305 name: &str,
306 dataset: P,
307 ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
308 let variables = upload_dataset::Variables {
309 usecase: IdOrKey::from(usecase),
310 file: Upload(0),
311 name: Some(name.to_string()),
312 };
313
314 let operations = UploadDataset::build_query(variables);
315 let operations = serde_json::to_string(&operations)?;
316
317 let file_map = r#"{ "0": ["variables.file"] }"#;
318
319 let dataset_file = reqwest::multipart::Part::file(dataset)
320 .await
321 .context("Unable to read dataset")?;
322
323 let form = reqwest::multipart::Form::new()
324 .text("operations", operations)
325 .text("map", file_map)
326 .part("0", dataset_file);
327
328 let response = self
329 .client
330 .post(self.graphql_url.clone())
331 .bearer_auth(&self.auth_token)
332 .multipart(form)
333 .send()
334 .await?;
335
336 let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
337 response.json().await?;
338
339 match response.data {
340 Some(data) => Ok(data.create_dataset),
341 None => {
342 if let Some(errors) = response.errors {
343 bail!("GraphQL errors: {:?}", errors);
344 }
345 Err(anyhow!("No data returned from GraphQL mutation"))
346 }
347 }
348 }
349
350 pub async fn publish_recipe<P: AsRef<Path>>(
351 &self,
352 usecase: &str,
353 name: &str,
354 key: &str,
355 recipe: P,
356 ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
357 let variables = publish_custom_recipe::Variables {
358 usecase: IdOrKey::from(usecase),
359 file: Upload(0),
360 name: Some(name.to_string()),
361 key: Some(key.to_string()),
362 };
363
364 let operations = PublishCustomRecipe::build_query(variables);
365 let operations = serde_json::to_string(&operations)?;
366
367 let file_map = r#"{ "0": ["variables.file"] }"#;
368
369 let recipe_file = reqwest::multipart::Part::file(recipe)
370 .await
371 .context("Unable to read recipe")?;
372
373 let form = reqwest::multipart::Form::new()
374 .text("operations", operations)
375 .text("map", file_map)
376 .part("0", recipe_file);
377
378 let response = self
379 .client
380 .post(self.graphql_url.clone())
381 .bearer_auth(&self.auth_token)
382 .multipart(form)
383 .send()
384 .await?;
385 let response: Response<
386 <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
387 > = response.json().await?;
388
389 match response.data {
390 Some(data) => Ok(data.create_custom_recipe),
391 None => {
392 if let Some(errors) = response.errors {
393 bail!("GraphQL errors: {:?}", errors);
394 }
395 Err(anyhow!("No data returned from GraphQL mutation"))
396 }
397 }
398 }
399
400 pub async fn run_recipe(
401 &self,
402 usecase: &str,
403 recipe_id: &str,
404 parameters: Map<String, Value>,
405 name: Option<String>,
406 compute_pool: Option<String>,
407 num_gpus: u32,
408 ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
409 let variables = run_custom_recipe::Variables {
410 input: run_custom_recipe::JobInput {
411 recipe: recipe_id.to_string(),
412 use_case: usecase.to_string(),
413 args: parameters,
414 name,
415 compute_pool,
416 num_gpus: num_gpus as i64,
417 },
418 };
419
420 let response_data = self.execute_query(RunCustomRecipe, variables).await?;
421 Ok(response_data.create_job)
422 }
423
424 pub async fn list_jobs(
425 &self,
426 usecase: Option<String>,
427 ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
428 let mut jobs = Vec::new();
429 let mut page = self.list_jobs_page(usecase.clone(), None).await?;
430 jobs.extend(page.nodes.iter().cloned());
431 while page.page_info.has_next_page {
432 page = self
433 .list_jobs_page(usecase.clone(), page.page_info.end_cursor)
434 .await?;
435 jobs.extend(page.nodes.iter().cloned());
436 }
437 Ok(jobs)
438 }
439
440 async fn list_jobs_page(
441 &self,
442 usecase: Option<String>,
443 after: Option<String>,
444 ) -> Result<list_jobs::ListJobsJobs> {
445 let variables = list_jobs::Variables {
446 filter: Some(list_jobs::ListJobsFilterInput {
447 use_case: usecase,
448 kind: Some(vec![list_jobs::JobKind::CUSTOM]),
449 status: Some(vec![
450 list_jobs::JobStatus::RUNNING,
451 list_jobs::JobStatus::PENDING,
452 ]),
453 timerange: None,
454 custom_recipes: None,
455 artifacts: None,
456 }),
457 cursor: Some(list_jobs::CursorPageInput {
458 first: Some(PAGE_SIZE as i64),
459 after,
460 before: None,
461 last: None,
462 offset: None,
463 }),
464 };
465
466 let response_data = self.execute_query(ListJobs, variables).await?;
467 Ok(response_data.jobs)
468 }
469
470 pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
471 let variables = cancel_job::Variables { job_id };
472
473 let response_data = self.execute_query(CancelJob, variables).await?;
474 Ok(response_data.cancel_job)
475 }
476
477 pub async fn list_models(
478 &self,
479 usecase: String,
480 ) -> Result<Vec<list_models::ListModelsUseCaseModelServices>> {
481 let variables = list_models::Variables {
482 use_case_id: usecase,
483 };
484
485 let response_data = self.execute_query(ListModels, variables).await?;
486 Ok(response_data
487 .use_case
488 .map(|use_case| use_case.model_services)
489 .unwrap_or(Vec::new()))
490 }
491
492 pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
493 let variables = list_all_models::Variables {};
494
495 let response_data = self.execute_query(ListAllModels, variables).await?;
496 Ok(response_data.models)
497 }
498
499 pub async fn list_usecases(&self) -> Result<Vec<list_use_cases::ListUseCasesUseCases>> {
500 let variables = list_use_cases::Variables {};
501
502 let response_data = self.execute_query(ListUseCases, variables).await?;
503 Ok(response_data.use_cases)
504 }
505
506 pub async fn list_pools(
507 &self,
508 ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
509 let variables = list_compute_pools::Variables {};
510
511 let response_data = self.execute_query(ListComputePools, variables).await?;
512 Ok(response_data.compute_pools)
513 }
514
515 pub async fn get_recipe(
516 &self,
517 usecase: String,
518 id_or_key: String,
519 ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
520 let variables = get_recipe::Variables { usecase, id_or_key };
521
522 let response_data = self.execute_query(GetRecipe, variables).await?;
523 Ok(response_data.custom_recipe)
524 }
525
526 async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
527 let url = self
528 .rest_base_url
529 .join(INIT_CHUNKED_UPLOAD_ROUTE)
530 .context("Failed to construct init upload URL")?;
531
532 let request = InitChunkedUploadRequest {
533 content_type: "application/jsonl".to_string(),
534 metadata: None,
535 total_parts_count: total_parts,
536 };
537
538 let response = self
539 .client
540 .post(url)
541 .bearer_auth(&self.auth_token)
542 .json(&request)
543 .send()
544 .await?;
545
546 if !response.status().is_success() {
547 bail!(
548 "Failed to initialize chunked upload: {} - {}",
549 response.status(),
550 response.text().await.unwrap_or_default()
551 );
552 }
553
554 let init_response: InitChunkedUploadResponse = response.json().await?;
555 Ok(init_response.session_id)
556 }
557
558 async fn upload_part(
559 &self,
560 session_id: &str,
561 part_number: u64,
562 data: Vec<u8>,
563 progress_tx: mpsc::Sender<u64>,
564 ) -> Result<()> {
565 const SUB_CHUNK_SIZE: usize = 64 * 1024;
566
567 let url = self
568 .rest_base_url
569 .join(UPLOAD_PART_ROUTE)
570 .context("Failed to construct upload part URL")?;
571
572 let chunks: Vec<Vec<u8>> = data
573 .chunks(SUB_CHUNK_SIZE)
574 .map(|chunk| chunk.to_vec())
575 .collect();
576
577 let stream = futures::stream::iter(chunks).map(move |chunk| {
578 let len = chunk.len() as u64;
579 let tx = progress_tx.clone();
580 let _ = tx.try_send(len);
581 Ok::<_, std::io::Error>(chunk)
582 });
583
584 let body = reqwest::Body::wrap_stream(stream);
585
586 let response = self
587 .client
588 .post(url)
589 .bearer_auth(&self.auth_token)
590 .query(&[
591 ("session_id", session_id),
592 ("part_number", &part_number.to_string()),
593 ])
594 .header("Content-Type", "application/octet-stream")
595 .body(body)
596 .send()
597 .await?;
598
599 if !response.status().is_success() {
600 bail!(
601 "Failed to upload part {}: {} - {}",
602 part_number,
603 response.status(),
604 response.text().await.unwrap_or_default()
605 );
606 }
607
608 Ok(())
609 }
610
611 async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
612 let url = self
613 .rest_base_url
614 .join(ABORT_CHUNKED_UPLOAD_ROUTE)
615 .context("Failed to construct abort upload URL")?;
616
617 let request = AbortChunkedUploadRequest {
618 session_id: session_id.to_string(),
619 };
620
621 let _ = self
622 .client
623 .delete(url)
624 .bearer_auth(&self.auth_token)
625 .json(&request)
626 .send()
627 .await;
628
629 Ok(())
630 }
631
632 async fn create_dataset_from_multipart(
633 &self,
634 usecase: &str,
635 name: &str,
636 key: &str,
637 session_id: &str,
638 ) -> Result<
639 create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
640 > {
641 let variables = create_dataset_from_multipart::Variables {
642 input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
643 use_case: usecase.to_string(),
644 name: name.to_string(),
645 key: Some(key.to_string()),
646 source: None,
647 upload_session_id: session_id.to_string(),
648 },
649 };
650
651 let response_data = self
652 .execute_query(CreateDatasetFromMultipart, variables)
653 .await?;
654 Ok(response_data.create_dataset_from_multipart_upload)
655 }
656
657 pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
658 &'a self,
659 usecase: &'a str,
660 name: &'a str,
661 key: &'a str,
662 dataset: P,
663 ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
664 let file_size = std::fs::metadata(dataset.as_ref())
665 .context("Failed to get file metadata")?
666 .len();
667
668 let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;
669
670 let stream = async_stream::try_stream! {
671 yield UploadEvent::Progress(ChunkedUploadProgress {
672 bytes_uploaded: 0,
673 total_bytes: file_size,
674 });
675
676 let session_id = self.init_chunked_upload(total_parts).await?;
677
678 let mut file =
679 File::open(dataset.as_ref()).context("Failed to open dataset file")?;
680 let mut buffer = vec![0u8; chunk_size as usize];
681 let mut bytes_uploaded = 0u64;
682
683 let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);
684
685 for part_number in 1..=total_parts {
686 let bytes_read = file.read(&mut buffer).context("Failed to read chunk")?;
687 let chunk_data = buffer[..bytes_read].to_vec();
688
689 let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
690 tokio::pin!(upload_fut);
691
692 let upload_result: Result<()> = loop {
693 tokio::select! {
694 biased;
695 result = &mut upload_fut => {
696 break result;
697 }
698 Some(bytes) = progress_rx.recv() => {
699 bytes_uploaded += bytes;
700 yield UploadEvent::Progress(ChunkedUploadProgress {
701 bytes_uploaded,
702 total_bytes: file_size,
703 });
704 }
705 }
706 };
707
708 if let Err(e) = upload_result {
709 let _ = self.abort_chunked_upload(&session_id).await;
710 Err(e)?;
711 }
712 }
713
714 let create_result = self
715 .create_dataset_from_multipart(usecase, name, key, &session_id)
716 .await;
717
718 match create_result {
719 Ok(response) => {
720 yield UploadEvent::Complete(response);
721 }
722 Err(e) => {
723 let _ = self.abort_chunked_upload(&session_id).await;
724 Err(anyhow!("Failed to create dataset: {}", e))?;
725 }
726 }
727 };
728
729 Ok(Box::pin(stream))
730 }
731}