adaptive_client_rust/
lib.rs

1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use tokio::sync::mpsc;
5
6use anyhow::{Context, Result, anyhow, bail};
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
const MEGABYTE: u64 = 1024 * 1024; // 1 MiB in bytes
/// Smallest part size accepted for chunked uploads (5 MiB); files smaller
/// than this are rejected by `calculate_upload_parts`.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
/// Largest allowed part size (100 MiB); together with `MAX_PARTS_COUNT`
/// this caps the maximum uploadable file size.
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
/// Maximum number of parts a chunked upload may be split into.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size tier boundaries used to pick a chunk size in
// `calculate_upload_parts`.
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Snapshot of chunked-upload progress, in bytes.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    /// Bytes handed to the transport so far (may slightly lead bytes on the
    /// wire; see `AdaptiveClient::upload_part`).
    pub bytes_uploaded: u64,
    /// Total size of the file being uploaded.
    pub total_bytes: u64,
}
33
/// Events yielded by the stream returned from
/// `AdaptiveClient::chunked_upload_dataset`.
#[derive(Debug)]
pub enum UploadEvent {
    /// Incremental progress update; emitted once with zero bytes before the
    /// upload starts, then as sub-chunks are handed to the transport.
    Progress(ChunkedUploadProgress),
    /// Terminal event carrying the dataset record created by the server.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
41
42pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
43    if file_size < MIN_CHUNK_SIZE_BYTES {
44        bail!(
45            "File size ({} bytes) is too small for chunked upload",
46            file_size
47        );
48    }
49
50    let mut chunk_size = if file_size < SIZE_500MB {
51        5 * MEGABYTE
52    } else if file_size < SIZE_10GB {
53        10 * MEGABYTE
54    } else if file_size < SIZE_50GB {
55        50 * MEGABYTE
56    } else {
57        100 * MEGABYTE
58    };
59
60    let mut total_parts = file_size.div_ceil(chunk_size);
61
62    if total_parts > MAX_PARTS_COUNT {
63        chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
64
65        if chunk_size > MAX_CHUNK_SIZE_BYTES {
66            let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
67            bail!(
68                "File size ({} bytes) exceeds maximum uploadable size ({} bytes = {} parts * {} bytes)",
69                file_size,
70                max_file_size,
71                MAX_PARTS_COUNT,
72                MAX_CHUNK_SIZE_BYTES
73            );
74        }
75
76        total_parts = file_size.div_ceil(chunk_size);
77    }
78
79    Ok((total_parts, chunk_size))
80}
81
// Rust bindings for the custom GraphQL scalars used by the queries below;
// graphql_client-generated modules resolve scalar names against type aliases
// in the enclosing scope.
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
89
/// Wall-clock timestamp newtype; deserialized from milliseconds via
/// `serde_utils::deserialize_timestamp_millis`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);
92
93impl<'de> serde::Deserialize<'de> for Timestamp {
94    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
95    where
96        D: serde::Deserializer<'de>,
97    {
98        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
99        Ok(Timestamp(system_time))
100    }
101}
102
/// Number of items requested per page when paginating job listings.
const PAGE_SIZE: usize = 20;

/// Stand-in for the GraphQL multipart `Upload` scalar. The inner index is
/// matched against the multipart form part name (e.g. `Upload(0)` pairs with
/// form part "0" in `upload_dataset` / `publish_recipe`).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
107
// Marker structs for graphql_client. Each derive generates a same-named
// snake_case module (e.g. `get_custom_recipes`) containing `Variables` and
// `ResponseData` types for the referenced .graphql document.

/// Query: list custom recipes in a use case (`src/graphql/list.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Query: fetch a single job by id (`src/graphql/job.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Query: paginated job listing (`src/graphql/jobs.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Mutation: cancel a job (`src/graphql/cancel.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Query: model services of one use case (`src/graphql/models.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Query: all models (`src/graphql/all_models.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
155
156impl Display for get_job::JobStatus {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        match self {
159            get_job::JobStatus::PENDING => write!(f, "Pending"),
160            get_job::JobStatus::RUNNING => write!(f, "Running"),
161            get_job::JobStatus::COMPLETED => write!(f, "Completed"),
162            get_job::JobStatus::FAILED => write!(f, "Failed"),
163            get_job::JobStatus::CANCELED => write!(f, "Canceled"),
164            get_job::JobStatus::Other(_) => write!(f, "Unknown"),
165        }
166    }
167}
168
/// Mutation: publish a custom recipe file (`src/graphql/publish.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Mutation: create a dataset from a single multipart GraphQL upload
/// (`src/graphql/upload_dataset.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Mutation: finalize a chunked REST upload into a dataset
/// (`src/graphql/create_dataset_from_multipart.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Mutation: start a job for a custom recipe (`src/graphql/run.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Query: list use cases (`src/graphql/usecases.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/usecases.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListUseCases;

/// Query: list compute pools (`src/graphql/pools.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Query: fetch one recipe by id or key (`src/graphql/recipe.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;
224
// REST routes (joined onto `rest_base_url`) for the chunked-upload flow.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";

/// Client for the Adaptive API: GraphQL operations plus the REST
/// chunked-upload endpoints, all authenticated with a bearer token.
pub struct AdaptiveClient {
    client: Client,      // shared reqwest HTTP client
    graphql_url: Url,    // base URL + "graphql"
    rest_base_url: Url,  // base for the v1/upload/* REST routes
    auth_token: String,  // bearer token attached to every request
}
235
impl AdaptiveClient {
    /// Creates a client for the API at `api_base_url`, authenticating every
    /// request with `auth_token` as an HTTP bearer token.
    ///
    /// The GraphQL endpoint is derived by joining `"graphql"` onto the base
    /// URL; the REST upload routes are joined onto the base directly.
    ///
    /// # Panics
    /// Panics if `"graphql"` cannot be joined onto `api_base_url`.
    /// NOTE(review): `Url::join` replaces the base's last path segment when
    /// the base does not end in `/` — confirm callers pass a trailing-slash
    /// base URL.
    pub fn new(api_base_url: Url, auth_token: String) -> Self {
        let graphql_url = api_base_url
            .join("graphql")
            .expect("Failed to append graphql to base URL");

        Self {
            client: Client::new(),
            graphql_url,
            rest_base_url: api_base_url,
            auth_token,
        }
    }

    /// Sends the GraphQL operation `T` with `variables` and decodes its data.
    ///
    /// `_query` serves only as a type witness; the request body is built by
    /// `T::build_query(variables)`.
    ///
    /// # Errors
    /// Fails on transport/decoding errors, on GraphQL errors when the
    /// response carries no data, or when the response has neither data nor
    /// errors. GraphQL errors accompanying partial data are not surfaced.
    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
    where
        T: GraphQLQuery,
        T::Variables: serde::Serialize,
        T::ResponseData: DeserializeOwned,
    {
        let request_body = T::build_query(variables);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .json(&request_body)
            .send()
            .await?;

        let response_body: Response<T::ResponseData> = response.json().await?;

        match response_body.data {
            Some(data) => Ok(data),
            None => {
                // No data: prefer the server's GraphQL errors, otherwise a
                // generic failure.
                if let Some(errors) = response_body.errors {
                    bail!("GraphQL errors: {:?}", errors);
                }
                Err(anyhow!("No data returned from GraphQL query"))
            }
        }
    }

    /// Lists the custom recipes defined in `usecase`.
    pub async fn list_recipes(
        &self,
        usecase: &str,
    ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
        let variables = get_custom_recipes::Variables {
            usecase: IdOrKey::from(usecase),
        };

        let response_data = self.execute_query(GetCustomRecipes, variables).await?;
        Ok(response_data.custom_recipes)
    }

    /// Fetches the job with id `job_id`; errors if the server returns no
    /// matching job.
    pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
        let variables = get_job::Variables { id: job_id };

        let response_data = self.execute_query(GetJob, variables).await?;

        match response_data.job {
            Some(job) => Ok(job),
            None => Err(anyhow!("Job with ID '{}' not found", job_id)),
        }
    }

    /// Uploads the file at `dataset` as a new dataset named `name` in
    /// `usecase` via a single multipart GraphQL request: form fields
    /// "operations" (the serialized query) and "map", plus the file as part
    /// "0" — the index referenced by `Upload(0)` in the variables.
    ///
    /// # Errors
    /// Fails if the file cannot be read, on transport errors, or when the
    /// server returns GraphQL errors / no data.
    pub async fn upload_dataset<P: AsRef<Path>>(
        &self,
        usecase: &str,
        name: &str,
        dataset: P,
    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
        let variables = upload_dataset::Variables {
            usecase: IdOrKey::from(usecase),
            file: Upload(0),
            name: Some(name.to_string()),
        };

        let operations = UploadDataset::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        // Maps multipart part "0" onto the `file` variable.
        let file_map = r#"{ "0": ["variables.file"] }"#;

        let dataset_file = reqwest::multipart::Part::file(dataset)
            .await
            .context("Unable to read dataset")?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", dataset_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;

        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
            response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_dataset),
            None => {
                if let Some(errors) = response.errors {
                    bail!("GraphQL errors: {:?}", errors);
                }
                Err(anyhow!("No data returned from GraphQL mutation"))
            }
        }
    }

    /// Publishes the recipe file at `recipe` under `usecase` with the given
    /// `name` and `key`, using the same multipart form layout as
    /// `upload_dataset` ("operations" + "map" + file part "0").
    pub async fn publish_recipe<P: AsRef<Path>>(
        &self,
        usecase: &str,
        name: &str,
        key: &str,
        recipe: P,
    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
        let variables = publish_custom_recipe::Variables {
            usecase: IdOrKey::from(usecase),
            file: Upload(0),
            name: Some(name.to_string()),
            key: Some(key.to_string()),
        };

        let operations = PublishCustomRecipe::build_query(variables);
        let operations = serde_json::to_string(&operations)?;

        let file_map = r#"{ "0": ["variables.file"] }"#;

        let recipe_file = reqwest::multipart::Part::file(recipe)
            .await
            .context("Unable to read recipe")?;

        let form = reqwest::multipart::Form::new()
            .text("operations", operations)
            .text("map", file_map)
            .part("0", recipe_file);

        let response = self
            .client
            .post(self.graphql_url.clone())
            .bearer_auth(&self.auth_token)
            .multipart(form)
            .send()
            .await?;
        let response: Response<
            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
        > = response.json().await?;

        match response.data {
            Some(data) => Ok(data.create_custom_recipe),
            None => {
                if let Some(errors) = response.errors {
                    bail!("GraphQL errors: {:?}", errors);
                }
                Err(anyhow!("No data returned from GraphQL mutation"))
            }
        }
    }

    /// Starts a job running custom recipe `recipe_id` in `usecase` with
    /// `parameters` as the job arguments. `name` and `compute_pool` are
    /// optional; `num_gpus` is forwarded as-is.
    pub async fn run_recipe(
        &self,
        usecase: &str,
        recipe_id: &str,
        parameters: Map<String, Value>,
        name: Option<String>,
        compute_pool: Option<String>,
        num_gpus: u32,
    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
        let variables = run_custom_recipe::Variables {
            input: run_custom_recipe::JobInput {
                recipe: recipe_id.to_string(),
                use_case: usecase.to_string(),
                args: parameters,
                name,
                compute_pool,
                num_gpus: num_gpus as i64,
            },
        };

        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
        Ok(response_data.create_job)
    }

    /// Collects all RUNNING/PENDING custom jobs (optionally filtered by
    /// `usecase`), following cursor pagination until the last page.
    pub async fn list_jobs(
        &self,
        usecase: Option<String>,
    ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
        let mut jobs = Vec::new();
        let mut page = self.list_jobs_page(usecase.clone(), None).await?;
        jobs.extend(page.nodes.iter().cloned());
        while page.page_info.has_next_page {
            page = self
                .list_jobs_page(usecase.clone(), page.page_info.end_cursor)
                .await?;
            jobs.extend(page.nodes.iter().cloned());
        }
        Ok(jobs)
    }

    /// Fetches one page (`PAGE_SIZE` items) of CUSTOM-kind jobs in RUNNING or
    /// PENDING status, starting after cursor `after` (`None` = first page).
    async fn list_jobs_page(
        &self,
        usecase: Option<String>,
        after: Option<String>,
    ) -> Result<list_jobs::ListJobsJobs> {
        let variables = list_jobs::Variables {
            filter: Some(list_jobs::ListJobsFilterInput {
                use_case: usecase,
                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
                status: Some(vec![
                    list_jobs::JobStatus::RUNNING,
                    list_jobs::JobStatus::PENDING,
                ]),
                timerange: None,
                custom_recipes: None,
                artifacts: None,
            }),
            cursor: Some(list_jobs::CursorPageInput {
                first: Some(PAGE_SIZE as i64),
                after,
                before: None,
                last: None,
                offset: None,
            }),
        };

        let response_data = self.execute_query(ListJobs, variables).await?;
        Ok(response_data.jobs)
    }

    /// Requests cancellation of the job with id `job_id`.
    pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
        let variables = cancel_job::Variables { job_id };

        let response_data = self.execute_query(CancelJob, variables).await?;
        Ok(response_data.cancel_job)
    }

    /// Lists the model services of `usecase`; returns an empty Vec when the
    /// use case is unknown to the server.
    pub async fn list_models(
        &self,
        usecase: String,
    ) -> Result<Vec<list_models::ListModelsUseCaseModelServices>> {
        let variables = list_models::Variables {
            use_case_id: usecase,
        };

        let response_data = self.execute_query(ListModels, variables).await?;
        Ok(response_data
            .use_case
            .map(|use_case| use_case.model_services)
            .unwrap_or(Vec::new()))
    }

    /// Lists every model known to the server, across all use cases.
    pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
        let variables = list_all_models::Variables {};

        let response_data = self.execute_query(ListAllModels, variables).await?;
        Ok(response_data.models)
    }

    /// Lists all use cases.
    pub async fn list_usecases(&self) -> Result<Vec<list_use_cases::ListUseCasesUseCases>> {
        let variables = list_use_cases::Variables {};

        let response_data = self.execute_query(ListUseCases, variables).await?;
        Ok(response_data.use_cases)
    }

    /// Lists all compute pools.
    pub async fn list_pools(
        &self,
    ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
        let variables = list_compute_pools::Variables {};

        let response_data = self.execute_query(ListComputePools, variables).await?;
        Ok(response_data.compute_pools)
    }

    /// Fetches one custom recipe from `usecase` by id or key; `Ok(None)` when
    /// the server finds no match.
    pub async fn get_recipe(
        &self,
        usecase: String,
        id_or_key: String,
    ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
        let variables = get_recipe::Variables { usecase, id_or_key };

        let response_data = self.execute_query(GetRecipe, variables).await?;
        Ok(response_data.custom_recipe)
    }

    /// Opens a chunked-upload session for `total_parts` parts and returns the
    /// server-assigned session id.
    ///
    /// The content type is hard-coded to "application/jsonl" and no metadata
    /// is attached.
    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
        let url = self
            .rest_base_url
            .join(INIT_CHUNKED_UPLOAD_ROUTE)
            .context("Failed to construct init upload URL")?;

        let request = InitChunkedUploadRequest {
            content_type: "application/jsonl".to_string(),
            metadata: None,
            total_parts_count: total_parts,
        };

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            bail!(
                "Failed to initialize chunked upload: {} - {}",
                response.status(),
                response.text().await.unwrap_or_default()
            );
        }

        let init_response: InitChunkedUploadResponse = response.json().await?;
        Ok(init_response.session_id)
    }

    /// Uploads one part of a chunked upload, streaming `data` in 64 KiB
    /// sub-chunks and reporting each sub-chunk's length on `progress_tx`.
    ///
    /// Progress reporting is best-effort: `try_send` silently drops updates
    /// when the channel is full. A sub-chunk is counted when it is handed to
    /// the request body stream, so reported progress may slightly lead the
    /// bytes actually on the wire.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self
            .rest_base_url
            .join(UPLOAD_PART_ROUTE)
            .context("Failed to construct upload part URL")?;

        // Pre-split the part into owned sub-chunks so the stream is 'static.
        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            bail!(
                "Failed to upload part {}: {} - {}",
                part_number,
                response.status(),
                response.text().await.unwrap_or_default()
            );
        }

        Ok(())
    }

    /// Best-effort abort of an upload session: the DELETE result is ignored
    /// and this always returns `Ok(())`.
    async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
        let url = self
            .rest_base_url
            .join(ABORT_CHUNKED_UPLOAD_ROUTE)
            .context("Failed to construct abort upload URL")?;

        let request = AbortChunkedUploadRequest {
            session_id: session_id.to_string(),
        };

        let _ = self
            .client
            .delete(url)
            .bearer_auth(&self.auth_token)
            .json(&request)
            .send()
            .await;

        Ok(())
    }

    /// Finalizes a completed chunked upload `session_id` into a dataset named
    /// `name` (with key `key`) in `usecase`.
    async fn create_dataset_from_multipart(
        &self,
        usecase: &str,
        name: &str,
        key: &str,
        session_id: &str,
    ) -> Result<
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    > {
        let variables = create_dataset_from_multipart::Variables {
            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
                use_case: usecase.to_string(),
                name: name.to_string(),
                key: Some(key.to_string()),
                source: None,
                upload_session_id: session_id.to_string(),
            },
        };

        let response_data = self
            .execute_query(CreateDatasetFromMultipart, variables)
            .await?;
        Ok(response_data.create_dataset_from_multipart_upload)
    }

    /// Uploads `dataset` via the chunked REST flow, returning a stream of
    /// [`UploadEvent`]s: an initial zero-progress event, incremental
    /// `Progress` events while parts upload, then a terminal `Complete` event
    /// carrying the created dataset.
    ///
    /// On any part failure, or on failure to finalize the dataset, the
    /// session is aborted (best effort) and the stream yields the error.
    ///
    /// NOTE(review): the file is read with blocking `std::fs` I/O inside the
    /// async stream — confirm this is acceptable on the target runtime.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        usecase: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        let file_size = std::fs::metadata(dataset.as_ref())
            .context("Failed to get file metadata")?
            .len();

        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Emit a 0% event up front so consumers can render the total
            // before any network work happens.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            let mut file =
                File::open(dataset.as_ref()).context("Failed to open dataset file")?;
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            for part_number in 1..=total_parts {
                // NOTE(review): `Read::read` is not guaranteed to fill the
                // buffer; a short read would produce an undersized part —
                // confirm regular-file reads always fill here, or use
                // `read_exact`-style looping.
                let bytes_read = file.read(&mut buffer).context("Failed to read chunk")?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the part upload while forwarding progress messages;
                // `biased` checks completion before draining more progress.
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        biased;
                        result = &mut upload_fut => {
                            // NOTE(review): progress messages still queued
                            // when the future completes are not drained, so
                            // reported progress can lag the true count.
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };

                if let Err(e) = upload_result {
                    // Abort the session (best effort) before surfacing the
                    // part failure through the stream.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(usecase, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(anyhow!("Failed to create dataset: {}", e))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
}