1#![deny(missing_docs)]
16#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
17
18use self::responses::{
142 CreateSessRes, ImgUploadRes, ProofReq, Quotas, ReceiptDownload, SessionStatusRes, SnarkReq,
143 SnarkStatusRes, UploadRes, VersionInfo,
144};
145use duplicate::duplicate_item;
146use reqwest::header;
147use serde::{Deserialize, Serialize};
148use std::path::Path;
149use std::time::Duration;
150use thiserror::Error;
151
/// Name of the HTTP header that carries the Bonsai API key on every request.
pub const API_KEY_HEADER: &str = "x-api-key";
/// Name of the HTTP header that carries the caller's risc0 version on every request.
pub const VERSION_HEADER: &str = "x-risc0-version";
/// Environment variable that `Client::from_env` reads for the API base URL.
pub const API_URL_ENVVAR: &str = "BONSAI_API_URL";
/// Environment variable that `Client::from_env` reads for the API key.
pub const API_KEY_ENVVAR: &str = "BONSAI_API_KEY";
/// Environment variable holding the request timeout in milliseconds, or the literal `none`.
pub const TIMEOUT_ENVVAR: &str = "BONSAI_TIMEOUT_MS";
/// Request timeout in milliseconds, used when `BONSAI_TIMEOUT_MS` is unset or does not parse.
const DEFAULT_TIMEOUT: u64 = 30_000;
164
/// Errors returned by every fallible operation of the Bonsai SDK client.
#[derive(Debug, Error)]
pub enum SdkErr {
    /// The server replied with a non-success HTTP status; holds the response body text.
    #[error("server error `{0}`")]
    InternalServerErr(String),
    /// Transport failure, client construction failure, or response-body decoding
    /// failure reported by `reqwest` (e.g. from `send()` or `json()`).
    #[error("HTTP error from reqwest")]
    HttpErr(#[from] reqwest::Error),
    /// The API key or version string could not be turned into an HTTP header value.
    #[error("HTTP header failed to construct")]
    HttpHeaderErr(#[from] header::InvalidHeaderValue),
    /// `BONSAI_API_KEY` could not be read from the environment.
    #[error("missing BONSAI_API_KEY env var")]
    MissingApiKey,
    /// `BONSAI_API_URL` could not be read from the environment.
    #[error("missing BONSAI_API_URL env var")]
    MissingApiUrl,
    /// Any `std::io::Error` converted with `?` (e.g. opening a file to upload).
    /// Note: the message says "failed to find file" regardless of the io error kind.
    #[error("failed to find file on disk: {0:?}")]
    FileNotFound(#[from] std::io::Error),
    /// The receipt endpoint answered 404 for the requested session.
    #[error("Receipt not found")]
    ReceiptNotFound,
}
190
/// Outcome of asking the server for an image upload URL.
enum ImageExistsOpt {
    /// The server replied `204`: an image with this id is already stored.
    Exists,
    /// The image is not stored yet; carries the URL to upload it to.
    New(ImgUploadRes),
}
195
196pub mod responses {
198 use serde::{Deserialize, Serialize};
199
200 #[derive(Deserialize, Serialize)]
202 pub struct UploadRes {
203 pub url: String,
205 pub uuid: String,
207 }
208
209 #[derive(Deserialize, Serialize)]
211 pub struct ImgUploadRes {
212 pub url: String,
214 }
215
216 #[derive(Deserialize, Serialize)]
218 pub struct CreateSessRes {
219 pub uuid: String,
221 }
222
223 #[derive(Deserialize, Serialize)]
225 pub struct ProofReq {
226 pub img: String,
228 pub input: String,
230 pub assumptions: Vec<String>,
232 pub execute_only: bool,
234 pub exec_cycle_limit: Option<u64>,
236 }
237
238 #[derive(Serialize, Deserialize)]
240 pub struct SessionStats {
241 pub segments: usize,
243 pub total_cycles: u64,
245 pub cycles: u64,
247 }
248
249 #[derive(Deserialize, Serialize)]
251 pub struct SessionStatusRes {
252 pub status: String,
256 pub receipt_url: Option<String>,
260 pub error_msg: Option<String>,
265 pub state: Option<String>,
280 pub elapsed_time: Option<f64>,
284 pub stats: Option<SessionStats>,
290 }
291
292 #[derive(Deserialize, Serialize)]
294 pub struct ReceiptDownload {
295 pub url: String,
297 }
298
299 #[derive(Deserialize, Serialize)]
301 pub struct SnarkReq {
302 pub session_id: String,
304 }
305
306 #[derive(Deserialize, Serialize)]
308 pub struct SnarkStatusRes {
309 pub status: String,
313 pub output: Option<String>,
317 pub error_msg: Option<String>,
322 }
323
324 #[derive(Deserialize, Serialize)]
326 pub struct VersionInfo {
327 pub risc0_zkvm: Vec<String>,
329 }
330
331 #[derive(Deserialize, Serialize)]
333 pub struct Quotas {
334 pub exec_cycle_limit: i64,
336 pub concurrent_proofs: i64,
338 pub cycle_budget: i64,
340 pub cycle_usage: i64,
342 pub dedicated_executor: i32,
344 pub dedicated_gpu: i32,
346 }
347}
348
349#[cfg_attr(feature = "non_blocking",
350duplicate_item(
351 [
352 module_type [non_blocking]
353 maybe_async_attr [maybe_async::must_be_async]
354 File [tokio::fs::File]
355 HttpBody [reqwest::Body]
356 HttpClient [reqwest::Client]
357 ]
358 [
359 module_type [blocking]
360 maybe_async_attr [maybe_async::must_be_sync]
361 File [std::fs::File]
362 HttpBody [reqwest::blocking::Body]
363 HttpClient [reqwest::blocking::Client]
364 ]
365))]
366#[cfg_attr(not(feature = "non_blocking"),
367duplicate_item(
368 [
369 module_type [blocking]
370 maybe_async_attr [maybe_async::must_be_sync]
371 File [std::fs::File]
372 HttpBody [reqwest::blocking::Body]
373 HttpClient [reqwest::blocking::Client]
374 ]
375))]
376pub mod module_type {
378 use super::*;
379
380 #[derive(Clone)]
382 pub struct Client {
383 pub(crate) url: String,
384 pub(crate) client: HttpClient,
385 }
386
387 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
389 pub struct SessionId {
390 pub uuid: String,
392 }
393
394 impl SessionId {
395 pub fn new(uuid: String) -> Self {
397 Self { uuid }
398 }
399
400 #[maybe_async_attr]
402 pub async fn status(&self, client: &Client) -> Result<SessionStatusRes, SdkErr> {
403 let url = format!("{}/sessions/status/{}", client.url, self.uuid);
404 let res = client.client.get(url).send().await?;
405
406 if !res.status().is_success() {
407 let body = res.text().await?;
408 return Err(SdkErr::InternalServerErr(body));
409 }
410 Ok(res.json::<SessionStatusRes>().await?)
411 }
412
413 #[maybe_async_attr]
422 pub async fn logs(&self, client: &Client) -> Result<String, SdkErr> {
423 let url = format!("{}/sessions/logs/{}", client.url, self.uuid);
424 let res = client.client.get(url).send().await?;
425
426 if !res.status().is_success() {
427 let body = res.text().await?;
428 return Err(SdkErr::InternalServerErr(body));
429 }
430 Ok(res.text().await?)
431 }
432
433 #[maybe_async_attr]
435 pub async fn stop(&self, client: &Client) -> Result<(), SdkErr> {
436 let url = format!("{}/sessions/stop/{}", client.url, self.uuid);
437 let res = client.client.get(url).send().await?;
438 if !res.status().is_success() {
439 let body = res.text().await?;
440 return Err(SdkErr::InternalServerErr(body));
441 }
442 Ok(())
443 }
444
445 #[maybe_async_attr]
450 pub async fn exec_only_journal(&self, client: &Client) -> Result<Vec<u8>, SdkErr> {
451 let url = format!("{}/sessions/exec_only_journal/{}", client.url, self.uuid);
452 let res = client.client.get(url).send().await?;
453
454 if !res.status().is_success() {
455 let body = res.text().await?;
456 return Err(SdkErr::InternalServerErr(body));
457 }
458 Ok(res.bytes().await?.to_vec())
459 }
460 }
461
462 #[derive(Debug, Clone, PartialEq)]
464 pub struct SnarkId {
465 pub uuid: String,
467 }
468
469 impl SnarkId {
470 pub fn new(uuid: String) -> Self {
472 Self { uuid }
473 }
474
475 #[maybe_async_attr]
477 pub async fn status(&self, client: &Client) -> Result<SnarkStatusRes, SdkErr> {
478 let url = format!("{}/snark/status/{}", client.url, self.uuid);
479 let res = client.client.get(url).send().await?;
480
481 if !res.status().is_success() {
482 let body = res.text().await?;
483 return Err(SdkErr::InternalServerErr(body));
484 }
485 Ok(res.json::<SnarkStatusRes>().await?)
486 }
487 }
488
489 fn construct_req_client(api_key: &str, version: &str) -> Result<HttpClient, SdkErr> {
491 let mut headers = header::HeaderMap::new();
492 headers.insert(API_KEY_HEADER, header::HeaderValue::from_str(api_key)?);
493 headers.insert(VERSION_HEADER, header::HeaderValue::from_str(version)?);
494
495 let timeout = match std::env::var(TIMEOUT_ENVVAR).as_deref() {
496 Ok("none") => None,
497 Ok(val) => Some(Duration::from_millis(
498 val.parse().unwrap_or(DEFAULT_TIMEOUT),
499 )),
500 Err(_) => Some(Duration::from_millis(DEFAULT_TIMEOUT)),
501 };
502 #[cfg(feature = "non_blocking")]
503 {
504 Ok(HttpClient::builder()
505 .default_headers(headers)
506 .pool_max_idle_per_host(0)
507 .timeout(timeout.unwrap_or(Duration::from_millis(DEFAULT_TIMEOUT)))
508 .build()?)
509 }
510 #[cfg(not(feature = "non_blocking"))]
511 {
512 Ok(HttpClient::builder()
513 .default_headers(headers)
514 .pool_max_idle_per_host(0)
515 .timeout(timeout)
516 .build()?)
517 }
518 }
519
520 impl Client {
521 #[cfg_attr(
533 feature = "non_blocking",
534 doc = r##"
535# Example (non-blocking):
536
537```no_run
538use bonsai_sdk;
539let url = "http://api.bonsai.xyz".to_string();
540let api_key = "my_secret_key".to_string();
541bonsai_sdk::non_blocking::Client::from_parts(url, api_key, risc0_zkvm::VERSION)
542 .expect("Failed to construct sdk client");
543```
544"##
545 )]
546 pub fn from_parts(url: String, key: String, risc0_version: &str) -> Result<Self, SdkErr> {
547 let client = construct_req_client(&key, risc0_version)?;
548 let url = url.strip_suffix('/').unwrap_or(&url).to_string();
549 Ok(Self { url, client })
550 }
551
552 #[cfg_attr(
568 feature = "non_blocking",
569 doc = r##"
570# Example (non-blocking):
571
572```no_run
573use bonsai_sdk;
574bonsai_sdk::non_blocking::Client::from_env(risc0_zkvm::VERSION)
575 .expect("Failed to construct sdk client");
576```
577"##
578 )]
579 pub fn from_env(risc0_version: &str) -> Result<Self, SdkErr> {
580 let api_url = std::env::var(API_URL_ENVVAR).map_err(|_| SdkErr::MissingApiUrl)?;
581 let api_url = api_url.strip_suffix('/').unwrap_or(&api_url);
582 let api_key = std::env::var(API_KEY_ENVVAR).map_err(|_| SdkErr::MissingApiKey)?;
583
584 let client = construct_req_client(&api_key, risc0_version)?;
585
586 Ok(Self {
587 url: api_url.to_string(),
588 client,
589 })
590 }
591
592 #[maybe_async_attr]
593 async fn get_image_upload_url(&self, image_id: &str) -> Result<ImageExistsOpt, SdkErr> {
594 let res = self
595 .client
596 .get(format!("{}/images/upload/{}", self.url, image_id))
597 .send()
598 .await?;
599
600 if res.status() == 204 {
601 return Ok(ImageExistsOpt::Exists);
602 }
603
604 if !res.status().is_success() {
605 let body = res.text().await?;
606 return Err(SdkErr::InternalServerErr(body));
607 }
608
609 Ok(ImageExistsOpt::New(res.json::<ImgUploadRes>().await?))
610 }
611
612 #[maybe_async_attr]
614 async fn put_data<T: Into<HttpBody>>(&self, url: &str, body: T) -> Result<(), SdkErr> {
615 let res = self.client.put(url).body(body).send().await?;
616 if !res.status().is_success() {
617 let body = res.text().await?;
618 return Err(SdkErr::InternalServerErr(body));
619 }
620
621 Ok(())
622 }
623
624 #[maybe_async_attr]
628 pub async fn has_img(&self, image_id: &str) -> Result<bool, SdkErr> {
629 let res_or_exists = self.get_image_upload_url(image_id).await?;
630 match res_or_exists {
631 ImageExistsOpt::Exists => Ok(true),
632 ImageExistsOpt::New(_) => Ok(false),
633 }
634 }
635
636 #[maybe_async_attr]
644 pub async fn upload_img(&self, image_id: &str, buf: Vec<u8>) -> Result<bool, SdkErr> {
645 let res_or_exists = self.get_image_upload_url(image_id).await?;
646 match res_or_exists {
647 ImageExistsOpt::Exists => Ok(true),
648 ImageExistsOpt::New(upload_res) => {
649 self.put_data(&upload_res.url, buf).await?;
650 Ok(false)
651 }
652 }
653 }
654
655 #[maybe_async_attr]
663 pub async fn upload_img_file(&self, image_id: &str, path: &Path) -> Result<bool, SdkErr> {
664 let res_or_exists = self.get_image_upload_url(image_id).await?;
665 match res_or_exists {
666 ImageExistsOpt::Exists => Ok(true),
667 ImageExistsOpt::New(upload_res) => {
668 let fd = File::open(path).await?;
669 self.put_data(&upload_res.url, fd).await?;
670 Ok(false)
671 }
672 }
673 }
674
675 #[maybe_async_attr]
677 async fn get_upload_url(&self, route: &str) -> Result<UploadRes, SdkErr> {
678 let res = self
679 .client
680 .get(format!("{}/{}/upload", self.url, route))
681 .send()
682 .await?;
683
684 if !res.status().is_success() {
685 let body = res.text().await?;
686 return Err(SdkErr::InternalServerErr(body));
687 }
688
689 Ok(res.json::<UploadRes>().await?)
690 }
691
692 #[maybe_async_attr]
694 pub async fn upload_input(&self, buf: Vec<u8>) -> Result<String, SdkErr> {
695 let upload_data = self.get_upload_url("inputs").await?;
696 self.put_data(&upload_data.url, buf).await?;
697 Ok(upload_data.uuid)
698 }
699
700 #[maybe_async_attr]
702 pub async fn upload_input_file(&self, path: &Path) -> Result<String, SdkErr> {
703 let upload_data = self.get_upload_url("inputs").await?;
704
705 let fd = File::open(path).await?;
706 self.put_data(&upload_data.url, fd).await?;
707
708 Ok(upload_data.uuid)
709 }
710
711 #[maybe_async_attr]
713 pub async fn upload_receipt(&self, buf: Vec<u8>) -> Result<String, SdkErr> {
714 let upload_data = self.get_upload_url("receipts").await?;
715 self.put_data(&upload_data.url, buf).await?;
716 Ok(upload_data.uuid)
717 }
718
719 #[maybe_async_attr]
721 pub async fn upload_receipt_file(&self, path: &Path) -> Result<String, SdkErr> {
722 let upload_data = self.get_upload_url("receipts").await?;
723
724 let fd = File::open(path).await?;
725 self.put_data(&upload_data.url, fd).await?;
726
727 Ok(upload_data.uuid)
728 }
729
730 #[maybe_async_attr]
734 pub async fn receipt_download(&self, session_id: &SessionId) -> Result<Vec<u8>, SdkErr> {
735 let res = self
736 .client
737 .get(format!("{}/receipts/{}", self.url, session_id.uuid))
738 .send()
739 .await?;
740
741 if !res.status().is_success() {
742 if res.status() == reqwest::StatusCode::NOT_FOUND {
743 return Err(SdkErr::ReceiptNotFound);
744 }
745 let body = res.text().await?;
746 return Err(SdkErr::InternalServerErr(body));
747 }
748
749 let res: ReceiptDownload = res.json().await?;
750 self.download(&res.url).await
751 }
752
753 #[maybe_async_attr]
757 pub async fn image_delete(&self, image_id: &str) -> Result<(), SdkErr> {
758 let res = self
759 .client
760 .delete(format!("{}/images/{}", self.url, image_id))
761 .send()
762 .await?;
763
764 if !res.status().is_success() {
765 let body = res.text().await?;
766 return Err(SdkErr::InternalServerErr(body));
767 }
768
769 Ok(())
770 }
771
772 #[maybe_async_attr]
776 pub async fn input_delete(&self, input_uuid: &str) -> Result<(), SdkErr> {
777 let res = self
778 .client
779 .delete(format!("{}/inputs/{}", self.url, input_uuid))
780 .send()
781 .await?;
782
783 if !res.status().is_success() {
784 let body = res.text().await?;
785 return Err(SdkErr::InternalServerErr(body));
786 }
787
788 Ok(())
789 }
790
791 #[maybe_async_attr]
798 pub async fn create_session_with_limit(
799 &self,
800 img_id: String,
801 input_id: String,
802 assumptions: Vec<String>,
803 execute_only: bool,
804 exec_cycle_limit: Option<u64>,
805 ) -> Result<SessionId, SdkErr> {
806 let url = format!("{}/sessions/create", self.url);
807
808 let req = ProofReq {
809 img: img_id,
810 input: input_id,
811 assumptions,
812 execute_only,
813 exec_cycle_limit,
814 };
815
816 let res = self.client.post(url).json(&req).send().await?;
817
818 if !res.status().is_success() {
819 let body = res.text().await?;
820 return Err(SdkErr::InternalServerErr(body));
821 }
822
823 let res: CreateSessRes = res.json().await?;
824
825 Ok(SessionId::new(res.uuid))
826 }
827
828 #[maybe_async_attr]
833 pub async fn create_session(
834 &self,
835 img_id: String,
836 input_id: String,
837 assumptions: Vec<String>,
838 execute_only: bool,
839 ) -> Result<SessionId, SdkErr> {
840 self.create_session_with_limit(img_id, input_id, assumptions, execute_only, None)
841 .await
842 }
843
844 #[maybe_async_attr]
850 pub async fn download(&self, url: &str) -> Result<Vec<u8>, SdkErr> {
851 let data = self.client.get(url).send().await?.bytes().await?;
852
853 Ok(data.into())
854 }
855
856 #[maybe_async_attr]
863 pub async fn create_snark(&self, session_id: String) -> Result<SnarkId, SdkErr> {
864 let url = format!("{}/snark/create", self.url);
865
866 let snark_req = SnarkReq { session_id };
867
868 let res = self.client.post(url).json(&snark_req).send().await?;
869
870 if !res.status().is_success() {
871 let body = res.text().await?;
872 return Err(SdkErr::InternalServerErr(body));
873 }
874
875 let res: CreateSessRes = res.json().await?;
877
878 Ok(SnarkId::new(res.uuid))
879 }
880
881 #[maybe_async_attr]
888 pub async fn version(&self) -> Result<VersionInfo, SdkErr> {
889 Ok(self
890 .client
891 .get(format!("{}/version", self.url))
892 .send()
893 .await?
894 .json::<VersionInfo>()
895 .await?)
896 }
897
898 #[maybe_async_attr]
904 pub async fn quotas(&self) -> Result<Quotas, SdkErr> {
905 Ok(self
906 .client
907 .get(format!("{}/user/quotas", self.url))
908 .send()
909 .await?
910 .json::<Quotas>()
911 .await?)
912 }
913 }
914}
915
#[cfg(test)]
mod tests {
    use httpmock::prelude::*;
    use uuid::Uuid;

    use super::*;
    use blocking::{Client, SessionId, SnarkId};

    const TEST_KEY: &str = "TESTKEY";
    const TEST_ID: &str = "0x5891b5b522d5df086d0ff0b110fbd9d21bb4fc7163af34d08286a2e846f6be03";
    const TEST_VERSION: &str = "0.1.0";

    /// Builds a blocking client pointed at `server` with the test credentials.
    fn test_client(server: &MockServer) -> Client {
        let server_url = format!("http://{}", server.address());
        Client::from_parts(server_url, TEST_KEY.to_string(), TEST_VERSION)
            .expect("Failed to construct client")
    }

    #[test]
    fn client_from_parts() {
        let url = "http://127.0.0.1/stage".to_string();
        let apikey = TEST_KEY.to_string();
        let client = Client::from_parts(url.clone(), apikey, TEST_VERSION).unwrap();

        assert_eq!(client.url, url);
    }

    #[test]
    fn client_from_parts_slash_strip() {
        let client = Client::from_parts(
            "http://127.0.0.1/".to_string(),
            TEST_KEY.to_string(),
            TEST_VERSION,
        )
        .unwrap();

        assert_eq!(client.url, "http://127.0.0.1");
    }

    #[test]
    fn client_from_env() {
        let url = "http://127.0.0.1/stage".to_string();
        let apikey = TEST_KEY.to_string();
        temp_env::with_vars(
            vec![
                (API_URL_ENVVAR, Some(url.clone())),
                (API_KEY_ENVVAR, Some(apikey)),
            ],
            || {
                let client = Client::from_env(TEST_VERSION).unwrap();
                assert_eq!(client.url, url);
            },
        );
    }

    #[test]
    fn client_test_slash_strip() {
        let url = "http://127.0.0.1/".to_string();
        let apikey = TEST_KEY.to_string();
        temp_env::with_vars(
            vec![(API_URL_ENVVAR, Some(url)), (API_KEY_ENVVAR, Some(apikey))],
            || {
                let client = Client::from_env(TEST_VERSION).unwrap();
                assert_eq!(client.url, "http://127.0.0.1");
            },
        );
    }

    #[test]
    fn image_upload() {
        let data = vec![];

        let server = MockServer::start();

        let put_url = format!("http://{}/upload/{TEST_ID}", server.address());
        let response = ImgUploadRes { url: put_url };

        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/images/upload/{TEST_ID}"))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let put_mock = server.mock(|when, then| {
            when.method(PUT).path(format!("/upload/{TEST_ID}"));
            then.status(200);
        });

        let client = test_client(&server);
        let exists = client
            .upload_img(TEST_ID, data)
            .expect("Failed to upload image");
        assert!(!exists);
        get_mock.assert();
        put_mock.assert();
    }

    #[cfg(feature = "non_blocking")]
    #[tokio::test]
    async fn image_upload_async() {
        let data = vec![];

        // httpmock's blocking API is not meant to be driven from inside an async runtime.
        let server = MockServer::start_async().await;

        let put_url = format!("http://{}/upload/{TEST_ID}", server.address());
        let response = ImgUploadRes { url: put_url };

        let get_mock = server
            .mock_async(|when, then| {
                when.method(GET)
                    .path(format!("/images/upload/{TEST_ID}"))
                    .header(API_KEY_HEADER, TEST_KEY)
                    .header(VERSION_HEADER, TEST_VERSION);
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body_obj(&response);
            })
            .await;

        let put_mock = server
            .mock_async(|when, then| {
                when.method(PUT).path(format!("/upload/{TEST_ID}"));
                then.status(200);
            })
            .await;

        let server_url = format!("http://{}", server.address());
        let client =
            super::non_blocking::Client::from_parts(server_url, TEST_KEY.to_string(), TEST_VERSION)
                .expect("Failed to construct client");
        let exists = client
            .upload_img(TEST_ID, data)
            .await
            .expect("Failed to upload image");
        assert!(!exists);
        get_mock.assert_async().await;
        put_mock.assert_async().await;
    }

    #[test]
    fn image_upload_dup() {
        let data = vec![0x41];

        let server = MockServer::start();

        let put_url = format!("http://{}/upload/{TEST_ID}", server.address());
        let response = ImgUploadRes { url: put_url };

        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/images/upload/{TEST_ID}"))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(204).json_body_obj(&response);
        });

        let put_mock = server.mock(|when, then| {
            when.method(PUT).path(format!("/upload/{TEST_ID}"));
            then.status(200);
        });

        let client = test_client(&server);
        let exists = client.upload_img(TEST_ID, data).unwrap();
        assert!(exists);
        get_mock.assert();
        // An already-stored image must not be uploaded again.
        put_mock.assert_hits(0);
    }

    #[test]
    fn image_delete() {
        let server = MockServer::start();

        let del_mock = server.mock(|when, then| {
            when.method(DELETE)
                .path(format!("/images/{TEST_ID}"))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200);
        });

        let client = test_client(&server);
        client.image_delete(TEST_ID).unwrap();
        del_mock.assert();
    }

    #[test]
    fn input_upload() {
        let data = vec![];

        let server = MockServer::start();

        let input_uuid = Uuid::new_v4();
        let put_url = format!("http://{}/upload/{}", server.address(), input_uuid);
        let response = UploadRes {
            url: put_url,
            uuid: input_uuid.to_string(),
        };

        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/inputs/upload")
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let put_mock = server.mock(|when, then| {
            when.method(PUT).path(format!("/upload/{input_uuid}"));
            then.status(200);
        });

        let client = test_client(&server);
        let res = client.upload_input(data).expect("Failed to upload input");

        assert_eq!(res, response.uuid);

        get_mock.assert();
        put_mock.assert();
    }

    #[test]
    fn input_delete() {
        let server = MockServer::start();

        let del_mock = server.mock(|when, then| {
            when.method(DELETE)
                .path(format!("/inputs/{TEST_ID}"))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200);
        });

        let client = test_client(&server);
        client.input_delete(TEST_ID).unwrap();
        del_mock.assert();
    }

    #[test]
    fn receipt_upload() {
        let data = vec![];

        let server = MockServer::start();

        let receipt_uuid = Uuid::new_v4();
        let put_url = format!("http://{}/upload/{}", server.address(), receipt_uuid);
        let response = UploadRes {
            url: put_url,
            uuid: receipt_uuid.to_string(),
        };

        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/receipts/upload")
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let put_mock = server.mock(|when, then| {
            when.method(PUT).path(format!("/upload/{receipt_uuid}"));
            then.status(200);
        });

        let client = test_client(&server);
        let res = client
            .upload_receipt(data)
            .expect("Failed to upload receipt");

        assert_eq!(res, response.uuid);

        get_mock.assert();
        put_mock.assert();
    }

    #[test]
    fn receipt_download() {
        let server = MockServer::start();
        let receipt_uuid = Uuid::new_v4();

        let download_method = "download_path";
        let download_url = format!("http://{}/{download_method}", server.address());
        let response = ReceiptDownload { url: download_url };

        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/receipts/{receipt_uuid}"))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let receipt_data: Vec<u8> = vec![0x41, 0x41, 0x42, 0x42];
        let download_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/{download_method}"))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);

            then.body(&receipt_data).status(200);
        });

        let client = test_client(&server);
        let res = client
            .receipt_download(&SessionId::new(receipt_uuid.to_string()))
            .expect("Failed to download receipt");

        assert_eq!(res, receipt_data);

        get_mock.assert();
        download_mock.assert();
    }

    #[test]
    fn receipt_download_not_found() {
        let server = MockServer::start();
        let receipt_uuid = Uuid::new_v4();

        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/receipts/{receipt_uuid}"))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(404);
        });

        let client = test_client(&server);
        let err = client
            .receipt_download(&SessionId::new(receipt_uuid.to_string()))
            .expect_err("A missing receipt must be reported as an error");

        assert!(matches!(err, SdkErr::ReceiptNotFound));
        get_mock.assert();
    }

    #[test]
    fn session_create() {
        let server = MockServer::start();

        let request = ProofReq {
            img: TEST_ID.to_string(),
            input: Uuid::new_v4().to_string(),
            assumptions: vec![],
            execute_only: false,
            exec_cycle_limit: None,
        };
        let response = CreateSessRes {
            uuid: Uuid::new_v4().to_string(),
        };

        let create_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/sessions/create")
                .header("content-type", "application/json")
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION)
                .json_body_obj(&request);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let client = test_client(&server);

        let res = client
            .create_session_with_limit(
                request.img,
                request.input,
                request.assumptions,
                request.execute_only,
                request.exec_cycle_limit,
            )
            .unwrap();
        assert_eq!(res.uuid, response.uuid);

        create_mock.assert();
    }

    #[test]
    fn session_status() {
        let server = MockServer::start();

        let uuid = Uuid::new_v4().to_string();
        let session_id = SessionId::new(uuid);
        let response = SessionStatusRes {
            status: "RUNNING".to_string(),
            receipt_url: None,
            error_msg: None,
            state: None,
            elapsed_time: None,
            stats: None,
        };

        let status_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/sessions/status/{}", session_id.uuid))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let client = test_client(&server);

        let status = session_id.status(&client).unwrap();
        assert_eq!(status.status, response.status);
        assert_eq!(status.receipt_url, None);

        status_mock.assert();
    }

    #[test]
    fn session_logs() {
        let server = MockServer::start();

        let uuid = Uuid::new_v4().to_string();
        let session_id = SessionId::new(uuid);
        let response = "Hello\nWorld";

        let logs_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/sessions/logs/{}", session_id.uuid))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            // Logs are served as plain text, not as a JSON-encoded string.
            then.status(200)
                .header("content-type", "text/plain")
                .body(response);
        });

        let client = test_client(&server);

        let logs = session_id.logs(&client).unwrap();

        assert_eq!(logs, response);

        logs_mock.assert();
    }

    #[test]
    fn session_exec_only_journal() {
        let server = MockServer::start();

        let uuid = Uuid::new_v4().to_string();
        let session_id = SessionId::new(uuid);
        let response: Vec<u8> = vec![0x41, 0x41, 0x41, 0x41];

        let journal_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/sessions/exec_only_journal/{}", session_id.uuid))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200)
                .header("content-type", "text/plain")
                .body(&response);
        });

        let client = test_client(&server);

        let journal = session_id.exec_only_journal(&client).unwrap();

        assert_eq!(journal, response);

        journal_mock.assert();
    }

    #[test]
    fn session_stop() {
        let server = MockServer::start();

        let uuid = Uuid::new_v4().to_string();
        let session_id = SessionId::new(uuid);

        let stop_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/sessions/stop/{}", session_id.uuid))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200).header("content-type", "text/plain");
        });

        let client = test_client(&server);

        session_id.stop(&client).unwrap();
        stop_mock.assert();
    }

    #[test]
    fn snark_create() {
        let server = MockServer::start();

        let request = SnarkReq {
            session_id: Uuid::new_v4().to_string(),
        };
        let response = CreateSessRes {
            uuid: Uuid::new_v4().to_string(),
        };

        let create_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/snark/create")
                .header("content-type", "application/json")
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION)
                .json_body_obj(&request);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let client = test_client(&server);

        let res = client.create_snark(request.session_id).unwrap();
        assert_eq!(res.uuid, response.uuid);

        create_mock.assert();
    }

    #[test]
    fn snark_status() {
        let server = MockServer::start();

        let uuid = Uuid::new_v4().to_string();
        let snark_id = SnarkId::new(uuid);
        let response = SnarkStatusRes {
            status: "RUNNING".to_string(),
            output: None,
            error_msg: None,
        };

        let status_mock = server.mock(|when, then| {
            when.method(GET)
                .path(format!("/snark/status/{}", snark_id.uuid))
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let client = test_client(&server);

        let status = snark_id.status(&client).unwrap();
        assert_eq!(status.status, response.status);
        assert_eq!(status.output, None);

        status_mock.assert();
    }

    #[test]
    fn version() {
        let server = MockServer::start();

        let response = VersionInfo {
            risc0_zkvm: vec![TEST_VERSION.into()],
        };

        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/version")
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let client = test_client(&server);
        let info = client.version().expect("Failed to fetch version route");
        assert_eq!(&info.risc0_zkvm[0], TEST_VERSION);
        get_mock.assert();
    }

    #[test]
    fn quotas() {
        let server = MockServer::start();

        let response = Quotas {
            concurrent_proofs: 10,
            cycle_budget: 100000,
            cycle_usage: 1000000,
            exec_cycle_limit: 500,
            dedicated_executor: 0,
            dedicated_gpu: 0,
        };

        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/user/quotas")
                .header(API_KEY_HEADER, TEST_KEY)
                .header(VERSION_HEADER, TEST_VERSION);
            then.status(200)
                .header("content-type", "application/json")
                .json_body_obj(&response);
        });

        let client = test_client(&server);
        let quota = client.quotas().expect("Failed to fetch quotas route");
        assert_eq!(quota.concurrent_proofs, response.concurrent_proofs);
        assert_eq!(quota.cycle_budget, response.cycle_budget);
        assert_eq!(quota.cycle_usage, response.cycle_usage);
        assert_eq!(quota.exec_cycle_limit, response.exec_cycle_limit);
        assert_eq!(quota.dedicated_executor, response.dedicated_executor);
        assert_eq!(quota.dedicated_gpu, response.dedicated_gpu);

        get_mock.assert();
    }
}