burn_central_client/
client.rs

1use reqwest::Url;
2use reqwest::header::COOKIE;
3use serde::{Deserialize, Serialize};
4
5use crate::credentials::BurnCentralCredentials;
6use crate::error::{ApiErrorBody, ApiErrorCode, ClientError};
7
8impl From<reqwest::Error> for ClientError {
9    fn from(error: reqwest::Error) -> Self {
10        match error.status() {
11            Some(status) => ClientError::ApiError {
12                status,
13                body: ApiErrorBody {
14                    code: ApiErrorCode::Unknown,
15                    message: error.to_string(),
16                },
17            },
18            None => ClientError::UnknownError(error.to_string()),
19        }
20    }
21}
22
23pub(crate) trait ResponseExt {
24    fn map_to_burn_central_err(self) -> Result<reqwest::blocking::Response, ClientError>;
25}
26
27impl ResponseExt for reqwest::blocking::Response {
28    fn map_to_burn_central_err(self) -> Result<reqwest::blocking::Response, ClientError> {
29        if self.status().is_success() {
30            Ok(self)
31        } else {
32            match self.status() {
33                reqwest::StatusCode::NOT_FOUND => Err(ClientError::NotFound),
34                reqwest::StatusCode::UNAUTHORIZED => Err(ClientError::Unauthorized),
35                reqwest::StatusCode::FORBIDDEN => Err(ClientError::Forbidden),
36                reqwest::StatusCode::INTERNAL_SERVER_ERROR => Err(ClientError::InternalServerError),
37                _ => Err(ClientError::ApiError {
38                    status: self.status(),
39                    body: self
40                        .text()
41                        .map_err(|e| ClientError::UnknownError(e.to_string()))?
42                        .parse::<serde_json::Value>()
43                        .and_then(serde_json::from_value::<ApiErrorBody>)
44                        .unwrap_or_else(|e| ApiErrorBody {
45                            code: ApiErrorCode::Unknown,
46                            message: e.to_string(),
47                        }),
48                }),
49            }
50        }
51    }
52}
53
/// A client for making HTTP requests to the Burn Central API.
///
/// The client can be used to interact with the Burn Central server, such as creating and
/// starting experiments, saving and loading checkpoints, and uploading logs.
#[derive(Debug, Clone)]
pub struct Client {
    /// Blocking HTTP client used to send every request.
    pub(crate) http_client: reqwest::blocking::Client,
    /// Root URL onto which versioned endpoint paths (`v{n}/...`) are joined.
    pub(crate) base_url: Url,
    /// Value sent in the `Cookie` header of API requests; filled in after login.
    pub(crate) session_cookie: Option<String>,
    /// Environment this client targets. NOTE: clients built with the deprecated
    /// `from_url` constructor always store `Env::Production` here.
    pub(crate) env: Env,
}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub enum Env {
67    Production,
68    Staging(u8),
69    Development,
70}
71
72impl Env {
73    pub fn get_url(&self) -> Url {
74        match self {
75            Env::Production => Url::parse("https://central.burn.dev/api/").unwrap(),
76            Env::Staging(version) => {
77                Url::parse(&format!("https://s{}-central.burn.dev/api/", version)).unwrap()
78            }
79            Env::Development => Url::parse("http://localhost:9001/").unwrap(),
80        }
81    }
82}
83
84impl Client {
85    /// Create a new HttpClient with the given base URL and API key.
86    pub fn new(env: Env, credentials: &BurnCentralCredentials) -> Result<Self, ClientError> {
87        let mut client = Client {
88            http_client: reqwest::blocking::Client::new(),
89            base_url: env.get_url(),
90            session_cookie: None,
91            env,
92        };
93
94        let cookie = client.login(credentials)?;
95        client.session_cookie = Some(cookie);
96        Ok(client)
97    }
98
99    #[deprecated]
100    /// Please use environment based constructor
101    pub fn from_url(url: Url, credentials: &BurnCentralCredentials) -> Result<Self, ClientError> {
102        let mut client = Client {
103            http_client: reqwest::blocking::Client::new(),
104            base_url: url,
105            session_cookie: None,
106            env: Env::Production,
107        };
108
109        let cookie = client.login(credentials)?;
110        client.session_cookie = Some(cookie);
111        Ok(client)
112    }
113
114    #[deprecated]
115    /// Please use environment instead of url
116    pub fn get_endpoint(&self) -> &Url {
117        &self.base_url
118    }
119
120    pub fn get_env(&self) -> &Env {
121        &self.env
122    }
123
124    pub(crate) fn get_json<R>(&self, path: impl AsRef<str>) -> Result<R, ClientError>
125    where
126        R: for<'de> serde::Deserialize<'de>,
127    {
128        let response = self.req(reqwest::Method::GET, path, None::<serde_json::Value>)?;
129        let bytes = response.bytes()?;
130        let json = serde_json::from_slice::<R>(&bytes)?;
131        Ok(json)
132    }
133
134    pub(crate) fn post_json<T, R>(
135        &self,
136        path: impl AsRef<str>,
137        body: Option<T>,
138    ) -> Result<R, ClientError>
139    where
140        T: serde::Serialize,
141        R: for<'de> serde::Deserialize<'de>,
142    {
143        let response = self.req(reqwest::Method::POST, path, body)?;
144        let bytes = response.bytes()?;
145        let json = serde_json::from_slice::<R>(&bytes)?;
146        Ok(json)
147    }
148
149    pub(crate) fn post<T>(&self, path: impl AsRef<str>, body: Option<T>) -> Result<(), ClientError>
150    where
151        T: serde::Serialize,
152    {
153        self.req(reqwest::Method::POST, path, body).map(|_| ())
154    }
155
156    pub(crate) fn req<T: serde::Serialize>(
157        &self,
158        method: reqwest::Method,
159        path: impl AsRef<str>,
160        body: Option<T>,
161    ) -> Result<reqwest::blocking::Response, ClientError> {
162        let url = self.join(path.as_ref());
163        let request_builder = self.http_client.request(method, url);
164
165        let mut request_builder = if let Some(body) = body {
166            request_builder
167                .body(serde_json::to_vec(&body)?)
168                .header(reqwest::header::CONTENT_TYPE, "application/json")
169        } else {
170            request_builder
171        };
172
173        if let Some(cookie) = self.session_cookie.as_ref() {
174            request_builder = request_builder.header(COOKIE, cookie);
175        }
176        request_builder = request_builder.header("X-SDK-Version", env!("CARGO_PKG_VERSION"));
177
178        let response = request_builder.send()?.map_to_burn_central_err()?;
179
180        Ok(response)
181    }
182
183    // Todo update to support multiple versions
184    pub(crate) fn join(&self, path: &str) -> Url {
185        self.join_versioned(path, 1)
186    }
187
188    fn join_versioned(&self, path: &str, version: u8) -> Url {
189        self.base_url
190            .join(&format!("v{version}/"))
191            .unwrap()
192            .join(path)
193            .expect("Should be able to join url")
194    }
195
196    /// Generic method to upload bytes to the given URL.
197    pub fn upload_bytes_to_url(&self, url: &str, bytes: Vec<u8>) -> Result<(), ClientError> {
198        self.http_client
199            .put(url)
200            .body(bytes)
201            .send()?
202            .map_to_burn_central_err()?;
203
204        Ok(())
205    }
206
207    /// Generic method to download bytes from the given URL.
208    pub fn download_bytes_from_url(&self, url: &str) -> Result<Vec<u8>, ClientError> {
209        let data = self
210            .http_client
211            .get(url)
212            .send()?
213            .map_to_burn_central_err()?
214            .bytes()?
215            .to_vec();
216
217        Ok(data)
218    }
219}