//! Adaptive client (`adaptive_client_rust/lib.rs`): typed GraphQL operations
//! plus a chunked-upload REST flow for large dataset files.

1use std::{fmt::Display, fs::File, io::Read, path::Path, time::SystemTime};
2
3use futures::{StreamExt, stream::BoxStream};
4use thiserror::Error;
5use tokio::sync::mpsc;
6
7use graphql_client::{GraphQLQuery, Response};
8use reqwest::Client;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::{Map, Value};
11use url::Url;
12use uuid::Uuid;
13
14mod rest_types;
15mod serde_utils;
16
17use rest_types::{AbortChunkedUploadRequest, InitChunkedUploadRequest, InitChunkedUploadResponse};
18
/// Number of bytes in one mebibyte.
const MEGABYTE: u64 = 1024 * 1024; // 1MB
/// Smallest part size the chunked-upload flow will use (5 MiB); files below
/// this are rejected with `AdaptiveError::FileTooSmall`.
pub const MIN_CHUNK_SIZE_BYTES: u64 = 5 * MEGABYTE;
/// Largest allowed size of a single uploaded part (100 MiB).
const MAX_CHUNK_SIZE_BYTES: u64 = 100 * MEGABYTE;
/// Maximum number of parts a single chunked upload may consist of.
const MAX_PARTS_COUNT: u64 = 10000;

// File-size tier boundaries used by `calculate_upload_parts` to pick a
// default chunk size (bigger files get bigger chunks).
const SIZE_500MB: u64 = 500 * MEGABYTE;
const SIZE_10GB: u64 = 10 * 1024 * MEGABYTE;
const SIZE_50GB: u64 = 50 * 1024 * MEGABYTE;
27
/// Unified error type for every client operation.
///
/// Transport, serialization, filesystem, and URL failures convert
/// automatically via the `#[from]` impls; the remaining variants carry
/// API-level failure details.
#[derive(Error, Debug)]
pub enum AdaptiveError {
    /// Transport-level failure from `reqwest` (connect, send, decode).
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),

    /// JSON (de)serialization failure.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Local filesystem failure (e.g. reading the dataset file).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// A route could not be joined onto the configured base URL.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// File is below `MIN_CHUNK_SIZE_BYTES`, too small for chunked upload.
    #[error("File too small for chunked upload: {size} bytes (minimum: {min_size} bytes)")]
    FileTooSmall { size: u64, min_size: u64 },

    /// File cannot fit in `MAX_PARTS_COUNT` parts of `MAX_CHUNK_SIZE_BYTES`.
    #[error("File too large: {size} bytes exceeds maximum {max_size} bytes")]
    FileTooLarge { size: u64, max_size: u64 },

    /// The GraphQL server returned errors and no data.
    #[error("GraphQL errors: {0:?}")]
    GraphQLErrors(Vec<graphql_client::Error>),

    /// The GraphQL response contained neither data nor errors.
    #[error("No data returned from GraphQL")]
    NoGraphQLData,

    /// No job exists with the given id.
    #[error("Job not found: {0}")]
    JobNotFound(Uuid),

    /// The upload-init REST endpoint returned a non-success status.
    #[error("Failed to initialize chunked upload: {status} - {body}")]
    ChunkedUploadInitFailed { status: String, body: String },

    /// Uploading a single part returned a non-success status.
    #[error("Failed to upload part {part_number}: {status} - {body}")]
    ChunkedUploadPartFailed {
        part_number: u64,
        status: String,
        body: String,
    },

    /// Finalizing the multipart upload into a dataset failed.
    #[error("Failed to create dataset: {0}")]
    DatasetCreationFailed(String),
}
70
/// Crate-local result alias pinning the error type to [`AdaptiveError`].
type Result<T> = std::result::Result<T, AdaptiveError>;
72
/// Point-in-time snapshot of chunked-upload progress, in bytes.
#[derive(Clone, Debug, Default)]
pub struct ChunkedUploadProgress {
    /// Bytes handed to the HTTP transport so far.
    pub bytes_uploaded: u64,
    /// Total size of the file being uploaded.
    pub total_bytes: u64,
}
78
/// Events yielded by the stream returned from `chunked_upload_dataset`.
#[derive(Debug)]
pub enum UploadEvent {
    /// Periodic byte-count update while parts are being streamed.
    Progress(ChunkedUploadProgress),
    /// Terminal event carrying the dataset created from the upload session.
    Complete(
        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
    ),
}
86
87pub fn calculate_upload_parts(file_size: u64) -> Result<(u64, u64)> {
88    if file_size < MIN_CHUNK_SIZE_BYTES {
89        return Err(AdaptiveError::FileTooSmall {
90            size: file_size,
91            min_size: MIN_CHUNK_SIZE_BYTES,
92        });
93    }
94
95    let mut chunk_size = if file_size < SIZE_500MB {
96        5 * MEGABYTE
97    } else if file_size < SIZE_10GB {
98        10 * MEGABYTE
99    } else if file_size < SIZE_50GB {
100        50 * MEGABYTE
101    } else {
102        100 * MEGABYTE
103    };
104
105    let mut total_parts = file_size.div_ceil(chunk_size);
106
107    if total_parts > MAX_PARTS_COUNT {
108        chunk_size = file_size.div_ceil(MAX_PARTS_COUNT);
109
110        if chunk_size > MAX_CHUNK_SIZE_BYTES {
111            let max_file_size = MAX_CHUNK_SIZE_BYTES * MAX_PARTS_COUNT;
112            return Err(AdaptiveError::FileTooLarge {
113                size: file_size,
114                max_size: max_file_size,
115            });
116        }
117
118        total_parts = file_size.div_ceil(chunk_size);
119    }
120
121    Ok((total_parts, chunk_size))
122}
123
// Scalar type mappings: `graphql_client` generated code resolves the
// schema's custom scalars to these local aliases, so their names must match
// the scalar names in `schema.gql` (hence the clippy allowances).
type IdOrKey = String;
#[allow(clippy::upper_case_acronyms)]
type UUID = Uuid;
type JsObject = Map<String, Value>;
type InputDatetime = String;
#[allow(clippy::upper_case_acronyms)]
type JSON = Value;
131
/// Newtype over [`SystemTime`] so the GraphQL timestamp scalar can get a
/// custom millisecond-based `Deserialize` impl.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Timestamp(pub SystemTime);
134
135impl<'de> serde::Deserialize<'de> for Timestamp {
136    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
137    where
138        D: serde::Deserializer<'de>,
139    {
140        let system_time = serde_utils::deserialize_timestamp_millis(deserializer)?;
141        Ok(Timestamp(system_time))
142    }
143}
144
/// Number of items requested per page when cursor-paginating job listings.
const PAGE_SIZE: usize = 20;

/// GraphQL `Upload` scalar: serialized as the index of the file part in the
/// multipart request's `map` (see `upload_dataset` / `publish_recipe`).
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Upload(usize);
149
/// Lists custom recipes for a use case (`src/graphql/list.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/list.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetCustomRecipes;

/// Fetches one job by id (`src/graphql/job.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/job.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetJob;

/// Cursor-paginated job listing (`src/graphql/jobs.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/jobs.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListJobs;

/// Cancels a running job (`src/graphql/cancel.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/cancel.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CancelJob;

/// Lists model services of one use case (`src/graphql/models.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListModels;

/// Lists every model on the platform (`src/graphql/all_models.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/all_models.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListAllModels;
197
198impl Display for get_job::JobStatus {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        match self {
201            get_job::JobStatus::PENDING => write!(f, "Pending"),
202            get_job::JobStatus::RUNNING => write!(f, "Running"),
203            get_job::JobStatus::COMPLETED => write!(f, "Completed"),
204            get_job::JobStatus::FAILED => write!(f, "Failed"),
205            get_job::JobStatus::CANCELED => write!(f, "Canceled"),
206            get_job::JobStatus::Other(_) => write!(f, "Unknown"),
207        }
208    }
209}
210
/// Publishes a recipe file as a custom recipe (`src/graphql/publish.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/publish.graphql",
    response_derives = "Debug, Clone"
)]
pub struct PublishCustomRecipe;

/// Single-request dataset upload (`src/graphql/upload_dataset.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/upload_dataset.graphql",
    response_derives = "Debug, Clone"
)]
pub struct UploadDataset;

/// Finalizes a chunked-upload session into a dataset
/// (`src/graphql/create_dataset_from_multipart.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/create_dataset_from_multipart.graphql",
    response_derives = "Debug, Clone"
)]
pub struct CreateDatasetFromMultipart;

/// Starts a job for a custom recipe (`src/graphql/run.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/run.graphql",
    response_derives = "Debug, Clone"
)]
pub struct RunCustomRecipe;

/// Lists use cases (`src/graphql/usecases.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/usecases.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListUseCases;

/// Lists compute pools (`src/graphql/pools.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/pools.graphql",
    response_derives = "Debug, Clone"
)]
pub struct ListComputePools;

/// Looks up one custom recipe by id or key (`src/graphql/recipe.graphql`).
#[derive(GraphQLQuery)]
#[graphql(
    schema_path = "schema.gql",
    query_path = "src/graphql/recipe.graphql",
    response_derives = "Debug, Clone"
)]
pub struct GetRecipe;
266
// REST routes for the chunked-upload API; each is joined onto
// `AdaptiveClient::rest_base_url` at call time.
const INIT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/init";
const UPLOAD_PART_ROUTE: &str = "v1/upload/part";
const ABORT_CHUNKED_UPLOAD_ROUTE: &str = "v1/upload/abort";
270
/// Client for the Adaptive API: typed GraphQL operations plus the chunked
/// upload REST endpoints, all authenticated with a bearer token.
pub struct AdaptiveClient {
    client: Client,
    /// `<base>/graphql`, precomputed in [`AdaptiveClient::new`].
    graphql_url: Url,
    /// Base URL the `v1/upload/*` REST routes are joined onto.
    rest_base_url: Url,
    auth_token: String,
}
277
278impl AdaptiveClient {
279    pub fn new(api_base_url: Url, auth_token: String) -> Self {
280        let graphql_url = api_base_url
281            .join("graphql")
282            .expect("Failed to append graphql to base URL");
283
284        Self {
285            client: Client::new(),
286            graphql_url,
287            rest_base_url: api_base_url,
288            auth_token,
289        }
290    }
291
292    async fn execute_query<T>(&self, _query: T, variables: T::Variables) -> Result<T::ResponseData>
293    where
294        T: GraphQLQuery,
295        T::Variables: serde::Serialize,
296        T::ResponseData: DeserializeOwned,
297    {
298        let request_body = T::build_query(variables);
299
300        let response = self
301            .client
302            .post(self.graphql_url.clone())
303            .bearer_auth(&self.auth_token)
304            .json(&request_body)
305            .send()
306            .await?;
307
308        let response_body: Response<T::ResponseData> = response.json().await?;
309
310        match response_body.data {
311            Some(data) => Ok(data),
312            None => {
313                if let Some(errors) = response_body.errors {
314                    return Err(AdaptiveError::GraphQLErrors(errors));
315                }
316                Err(AdaptiveError::NoGraphQLData)
317            }
318        }
319    }
320
321    pub async fn list_recipes(
322        &self,
323        usecase: &str,
324    ) -> Result<Vec<get_custom_recipes::GetCustomRecipesCustomRecipes>> {
325        let variables = get_custom_recipes::Variables {
326            usecase: IdOrKey::from(usecase),
327        };
328
329        let response_data = self.execute_query(GetCustomRecipes, variables).await?;
330        Ok(response_data.custom_recipes)
331    }
332
333    pub async fn get_job(&self, job_id: Uuid) -> Result<get_job::GetJobJob> {
334        let variables = get_job::Variables { id: job_id };
335
336        let response_data = self.execute_query(GetJob, variables).await?;
337
338        match response_data.job {
339            Some(job) => Ok(job),
340            None => Err(AdaptiveError::JobNotFound(job_id)),
341        }
342    }
343
344    pub async fn upload_dataset<P: AsRef<Path>>(
345        &self,
346        usecase: &str,
347        name: &str,
348        dataset: P,
349    ) -> Result<upload_dataset::UploadDatasetCreateDataset> {
350        let variables = upload_dataset::Variables {
351            usecase: IdOrKey::from(usecase),
352            file: Upload(0),
353            name: Some(name.to_string()),
354        };
355
356        let operations = UploadDataset::build_query(variables);
357        let operations = serde_json::to_string(&operations)?;
358
359        let file_map = r#"{ "0": ["variables.file"] }"#;
360
361        let dataset_file = reqwest::multipart::Part::file(dataset).await?;
362
363        let form = reqwest::multipart::Form::new()
364            .text("operations", operations)
365            .text("map", file_map)
366            .part("0", dataset_file);
367
368        let response = self
369            .client
370            .post(self.graphql_url.clone())
371            .bearer_auth(&self.auth_token)
372            .multipart(form)
373            .send()
374            .await?;
375
376        let response: Response<<UploadDataset as graphql_client::GraphQLQuery>::ResponseData> =
377            response.json().await?;
378
379        match response.data {
380            Some(data) => Ok(data.create_dataset),
381            None => {
382                if let Some(errors) = response.errors {
383                    return Err(AdaptiveError::GraphQLErrors(errors));
384                }
385                Err(AdaptiveError::NoGraphQLData)
386            }
387        }
388    }
389
390    pub async fn publish_recipe<P: AsRef<Path>>(
391        &self,
392        usecase: &str,
393        name: &str,
394        key: &str,
395        recipe: P,
396    ) -> Result<publish_custom_recipe::PublishCustomRecipeCreateCustomRecipe> {
397        let variables = publish_custom_recipe::Variables {
398            usecase: IdOrKey::from(usecase),
399            file: Upload(0),
400            name: Some(name.to_string()),
401            key: Some(key.to_string()),
402        };
403
404        let operations = PublishCustomRecipe::build_query(variables);
405        let operations = serde_json::to_string(&operations)?;
406
407        let file_map = r#"{ "0": ["variables.file"] }"#;
408
409        let recipe_file = reqwest::multipart::Part::file(recipe).await?;
410
411        let form = reqwest::multipart::Form::new()
412            .text("operations", operations)
413            .text("map", file_map)
414            .part("0", recipe_file);
415
416        let response = self
417            .client
418            .post(self.graphql_url.clone())
419            .bearer_auth(&self.auth_token)
420            .multipart(form)
421            .send()
422            .await?;
423        let response: Response<
424            <PublishCustomRecipe as graphql_client::GraphQLQuery>::ResponseData,
425        > = response.json().await?;
426
427        match response.data {
428            Some(data) => Ok(data.create_custom_recipe),
429            None => {
430                if let Some(errors) = response.errors {
431                    return Err(AdaptiveError::GraphQLErrors(errors));
432                }
433                Err(AdaptiveError::NoGraphQLData)
434            }
435        }
436    }
437
438    pub async fn run_recipe(
439        &self,
440        usecase: &str,
441        recipe_id: &str,
442        parameters: Map<String, Value>,
443        name: Option<String>,
444        compute_pool: Option<String>,
445        num_gpus: u32,
446    ) -> Result<run_custom_recipe::RunCustomRecipeCreateJob> {
447        let variables = run_custom_recipe::Variables {
448            input: run_custom_recipe::JobInput {
449                recipe: recipe_id.to_string(),
450                use_case: usecase.to_string(),
451                args: parameters,
452                name,
453                compute_pool,
454                num_gpus: num_gpus as i64,
455            },
456        };
457
458        let response_data = self.execute_query(RunCustomRecipe, variables).await?;
459        Ok(response_data.create_job)
460    }
461
462    pub async fn list_jobs(
463        &self,
464        usecase: Option<String>,
465    ) -> Result<Vec<list_jobs::ListJobsJobsNodes>> {
466        let mut jobs = Vec::new();
467        let mut page = self.list_jobs_page(usecase.clone(), None).await?;
468        jobs.extend(page.nodes.iter().cloned());
469        while page.page_info.has_next_page {
470            page = self
471                .list_jobs_page(usecase.clone(), page.page_info.end_cursor)
472                .await?;
473            jobs.extend(page.nodes.iter().cloned());
474        }
475        Ok(jobs)
476    }
477
    /// Fetches one cursor page of jobs.
    ///
    /// NOTE: the filter is hard-coded to CUSTOM jobs in RUNNING or PENDING
    /// status, so `list_jobs` never returns completed/failed/canceled jobs.
    async fn list_jobs_page(
        &self,
        usecase: Option<String>,
        after: Option<String>,
    ) -> Result<list_jobs::ListJobsJobs> {
        let variables = list_jobs::Variables {
            filter: Some(list_jobs::ListJobsFilterInput {
                use_case: usecase,
                kind: Some(vec![list_jobs::JobKind::CUSTOM]),
                status: Some(vec![
                    list_jobs::JobStatus::RUNNING,
                    list_jobs::JobStatus::PENDING,
                ]),
                timerange: None,
                custom_recipes: None,
                artifacts: None,
            }),
            // Forward-only pagination: PAGE_SIZE items after `after`.
            cursor: Some(list_jobs::CursorPageInput {
                first: Some(PAGE_SIZE as i64),
                after,
                before: None,
                last: None,
                offset: None,
            }),
        };

        let response_data = self.execute_query(ListJobs, variables).await?;
        Ok(response_data.jobs)
    }
507
508    pub async fn cancel_job(&self, job_id: Uuid) -> Result<cancel_job::CancelJobCancelJob> {
509        let variables = cancel_job::Variables { job_id };
510
511        let response_data = self.execute_query(CancelJob, variables).await?;
512        Ok(response_data.cancel_job)
513    }
514
515    pub async fn list_models(
516        &self,
517        usecase: String,
518    ) -> Result<Vec<list_models::ListModelsUseCaseModelServices>> {
519        let variables = list_models::Variables {
520            use_case_id: usecase,
521        };
522
523        let response_data = self.execute_query(ListModels, variables).await?;
524        Ok(response_data
525            .use_case
526            .map(|use_case| use_case.model_services)
527            .unwrap_or(Vec::new()))
528    }
529
530    pub async fn list_all_models(&self) -> Result<Vec<list_all_models::ListAllModelsModels>> {
531        let variables = list_all_models::Variables {};
532
533        let response_data = self.execute_query(ListAllModels, variables).await?;
534        Ok(response_data.models)
535    }
536
537    pub async fn list_usecases(&self) -> Result<Vec<list_use_cases::ListUseCasesUseCases>> {
538        let variables = list_use_cases::Variables {};
539
540        let response_data = self.execute_query(ListUseCases, variables).await?;
541        Ok(response_data.use_cases)
542    }
543
544    pub async fn list_pools(
545        &self,
546    ) -> Result<Vec<list_compute_pools::ListComputePoolsComputePools>> {
547        let variables = list_compute_pools::Variables {};
548
549        let response_data = self.execute_query(ListComputePools, variables).await?;
550        Ok(response_data.compute_pools)
551    }
552
553    pub async fn get_recipe(
554        &self,
555        usecase: String,
556        id_or_key: String,
557    ) -> Result<Option<get_recipe::GetRecipeCustomRecipe>> {
558        let variables = get_recipe::Variables { usecase, id_or_key };
559
560        let response_data = self.execute_query(GetRecipe, variables).await?;
561        Ok(response_data.custom_recipe)
562    }
563
564    async fn init_chunked_upload(&self, total_parts: u64) -> Result<String> {
565        let url = self.rest_base_url.join(INIT_CHUNKED_UPLOAD_ROUTE)?;
566
567        let request = InitChunkedUploadRequest {
568            content_type: "application/jsonl".to_string(),
569            metadata: None,
570            total_parts_count: total_parts,
571        };
572
573        let response = self
574            .client
575            .post(url)
576            .bearer_auth(&self.auth_token)
577            .json(&request)
578            .send()
579            .await?;
580
581        if !response.status().is_success() {
582            return Err(AdaptiveError::ChunkedUploadInitFailed {
583                status: response.status().to_string(),
584                body: response.text().await.unwrap_or_default(),
585            });
586        }
587
588        let init_response: InitChunkedUploadResponse = response.json().await?;
589        Ok(init_response.session_id)
590    }
591
    /// Streams one part's bytes to the upload-part REST endpoint.
    ///
    /// Progress is reported through `progress_tx` in 64 KiB increments as
    /// sub-chunks are handed to the HTTP body stream — i.e. when queued for
    /// sending, not when acknowledged by the server.
    ///
    /// # Errors
    /// [`AdaptiveError::ChunkedUploadPartFailed`] on any non-2xx response,
    /// or a transport error from `reqwest`.
    async fn upload_part(
        &self,
        session_id: &str,
        part_number: u64,
        data: Vec<u8>,
        progress_tx: mpsc::Sender<u64>,
    ) -> Result<()> {
        // Granularity of progress reports fed into the body stream.
        const SUB_CHUNK_SIZE: usize = 64 * 1024;

        let url = self.rest_base_url.join(UPLOAD_PART_ROUTE)?;

        // NOTE(review): this copies the whole part into sub-chunk Vecs,
        // briefly doubling the part's memory footprint.
        let chunks: Vec<Vec<u8>> = data
            .chunks(SUB_CHUNK_SIZE)
            .map(|chunk| chunk.to_vec())
            .collect();

        let stream = futures::stream::iter(chunks).map(move |chunk| {
            let len = chunk.len() as u64;
            let tx = progress_tx.clone();
            // Best-effort, non-blocking report: if the channel is full or
            // closed the update is silently dropped (this closure is sync,
            // so it cannot await `send`).
            let _ = tx.try_send(len);
            Ok::<_, std::io::Error>(chunk)
        });

        let body = reqwest::Body::wrap_stream(stream);

        let response = self
            .client
            .post(url)
            .bearer_auth(&self.auth_token)
            // The endpoint identifies the part via query parameters.
            .query(&[
                ("session_id", session_id),
                ("part_number", &part_number.to_string()),
            ])
            .header("Content-Type", "application/octet-stream")
            .body(body)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(AdaptiveError::ChunkedUploadPartFailed {
                part_number,
                status: response.status().to_string(),
                // Best effort: an unreadable error body becomes "".
                body: response.text().await.unwrap_or_default(),
            });
        }

        Ok(())
    }
640
641    async fn abort_chunked_upload(&self, session_id: &str) -> Result<()> {
642        let url = self.rest_base_url.join(ABORT_CHUNKED_UPLOAD_ROUTE)?;
643
644        let request = AbortChunkedUploadRequest {
645            session_id: session_id.to_string(),
646        };
647
648        let _ = self
649            .client
650            .delete(url)
651            .bearer_auth(&self.auth_token)
652            .json(&request)
653            .send()
654            .await;
655
656        Ok(())
657    }
658
659    async fn create_dataset_from_multipart(
660        &self,
661        usecase: &str,
662        name: &str,
663        key: &str,
664        session_id: &str,
665    ) -> Result<
666        create_dataset_from_multipart::CreateDatasetFromMultipartCreateDatasetFromMultipartUpload,
667    > {
668        let variables = create_dataset_from_multipart::Variables {
669            input: create_dataset_from_multipart::DatasetCreateFromMultipartUpload {
670                use_case: usecase.to_string(),
671                name: name.to_string(),
672                key: Some(key.to_string()),
673                source: None,
674                upload_session_id: session_id.to_string(),
675            },
676        };
677
678        let response_data = self
679            .execute_query(CreateDatasetFromMultipart, variables)
680            .await?;
681        Ok(response_data.create_dataset_from_multipart_upload)
682    }
683
    /// Uploads `dataset` in parts and finalizes it as a dataset named `name`.
    ///
    /// Returns a stream that yields [`UploadEvent::Progress`] updates while
    /// parts are sent and a final [`UploadEvent::Complete`] with the created
    /// dataset. On any part or finalization failure the session is aborted
    /// (best effort) before the error is surfaced through the stream.
    ///
    /// # Errors
    /// The outer `Result` fails if the file cannot be stat'ed or its size is
    /// outside the chunked-upload limits; later failures arrive as stream
    /// items.
    pub fn chunked_upload_dataset<'a, P: AsRef<Path> + Send + 'a>(
        &'a self,
        usecase: &'a str,
        name: &'a str,
        key: &'a str,
        dataset: P,
    ) -> Result<BoxStream<'a, Result<UploadEvent>>> {
        let file_size = std::fs::metadata(dataset.as_ref())?.len();

        let (total_parts, chunk_size) = calculate_upload_parts(file_size)?;

        let stream = async_stream::try_stream! {
            // Initial 0-byte progress event so consumers can render totals
            // before the session is even opened.
            yield UploadEvent::Progress(ChunkedUploadProgress {
                bytes_uploaded: 0,
                total_bytes: file_size,
            });

            let session_id = self.init_chunked_upload(total_parts).await?;

            let mut file = File::open(dataset.as_ref())?;
            let mut buffer = vec![0u8; chunk_size as usize];
            let mut bytes_uploaded = 0u64;

            // Progress updates flow from `upload_part`'s body stream back
            // into this task; capacity 64 matches try_send's best-effort use.
            let (progress_tx, mut progress_rx) = mpsc::channel::<u64>(64);

            for part_number in 1..=total_parts {
                // NOTE(review): a single read() is not guaranteed to fill the
                // buffer; a short read would desynchronize part sizes from
                // `total_parts`. Confirm, or consider a read_exact-style loop.
                let bytes_read = file.read(&mut buffer)?;
                let chunk_data = buffer[..bytes_read].to_vec();

                let upload_fut = self.upload_part(&session_id, part_number, chunk_data, progress_tx.clone());
                tokio::pin!(upload_fut);

                // Drive the upload while draining progress messages;
                // `biased` polls the upload first so completion wins ties.
                let upload_result: Result<()> = loop {
                    tokio::select! {
                        biased;
                        result = &mut upload_fut => {
                            break result;
                        }
                        Some(bytes) = progress_rx.recv() => {
                            bytes_uploaded += bytes;
                            yield UploadEvent::Progress(ChunkedUploadProgress {
                                bytes_uploaded,
                                total_bytes: file_size,
                            });
                        }
                    }
                };
                // NOTE(review): progress messages still queued when the part
                // future completes are drained by the NEXT part's loop; after
                // the last part any leftovers are dropped, so the final
                // Progress event may undercount — confirm acceptable.

                if let Err(e) = upload_result {
                    // Best-effort cleanup before propagating the failure.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(e)?;
                }
            }

            let create_result = self
                .create_dataset_from_multipart(usecase, name, key, &session_id)
                .await;

            match create_result {
                Ok(response) => {
                    yield UploadEvent::Complete(response);
                }
                Err(e) => {
                    // Finalization failed: abort the session, then wrap the
                    // cause in a dataset-creation error.
                    let _ = self.abort_chunked_upload(&session_id).await;
                    Err(AdaptiveError::DatasetCreationFailed(e.to_string()))?;
                }
            }
        };

        Ok(Box::pin(stream))
    }
755}