Skip to main content

burn_central_client/
client.rs

1use reqwest::Url;
2use reqwest::header::COOKIE;
3use serde::{Deserialize, Serialize};
4
5use crate::credentials::BurnCentralCredentials;
6use crate::error::{ApiErrorBody, ApiErrorCode, ClientError};
7
8impl From<reqwest::Error> for ClientError {
9    fn from(error: reqwest::Error) -> Self {
10        match error.status() {
11            Some(status) => ClientError::ApiError {
12                status,
13                body: ApiErrorBody {
14                    code: ApiErrorCode::Unknown,
15                    message: error.to_string(),
16                },
17            },
18            None => ClientError::UnknownError(error.to_string()),
19        }
20    }
21}
22
23pub(crate) trait ResponseExt {
24    fn map_to_burn_central_err(self) -> Result<reqwest::blocking::Response, ClientError>;
25}
26
27impl ResponseExt for reqwest::blocking::Response {
28    fn map_to_burn_central_err(self) -> Result<reqwest::blocking::Response, ClientError> {
29        if self.status().is_success() {
30            Ok(self)
31        } else {
32            match self.status() {
33                reqwest::StatusCode::NOT_FOUND => Err(ClientError::NotFound),
34                reqwest::StatusCode::UNAUTHORIZED => Err(ClientError::Unauthorized),
35                reqwest::StatusCode::INTERNAL_SERVER_ERROR => Err(ClientError::InternalServerError),
36                _ => Err(ClientError::ApiError {
37                    status: self.status(),
38                    body: self
39                        .text()
40                        .map_err(|e| ClientError::UnknownError(e.to_string()))?
41                        .parse::<serde_json::Value>()
42                        .and_then(serde_json::from_value::<ApiErrorBody>)
43                        .unwrap_or_else(|e| ApiErrorBody {
44                            code: ApiErrorCode::Unknown,
45                            message: e.to_string(),
46                        }),
47                }),
48            }
49        }
50    }
51}
52
53/// A client for making HTTP requests to the Burn Central API.
54///
55/// The client can be used to interact with the Burn Central server, such as creating and starting experiments, saving and loading checkpoints, and uploading logs.
56#[derive(Debug, Clone)]
57pub struct Client {
58    pub(crate) http_client: reqwest::blocking::Client,
59    pub(crate) base_url: Url,
60    pub(crate) session_cookie: Option<String>,
61    pub(crate) env: Env,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub enum Env {
66    Production,
67    Staging(u8),
68    Development,
69}
70
71impl Env {
72    pub fn get_url(&self) -> Url {
73        match self {
74            Env::Production => Url::parse("https://central.burn.dev/api/").unwrap(),
75            Env::Staging(version) => {
76                Url::parse(&format!("https://s{}-central.burn.dev/api/", version)).unwrap()
77            }
78            Env::Development => Url::parse("http://localhost:9001/").unwrap(),
79        }
80    }
81}
82
83impl Client {
84    /// Create a new HttpClient with the given base URL and API key.
85    pub fn new(env: Env, credentials: &BurnCentralCredentials) -> Result<Self, ClientError> {
86        let mut client = Client {
87            http_client: reqwest::blocking::Client::new(),
88            base_url: env.get_url(),
89            session_cookie: None,
90            env,
91        };
92
93        let cookie = client.login(credentials)?;
94        client.session_cookie = Some(cookie);
95        Ok(client)
96    }
97
98    #[deprecated]
99    /// Please use environment based constructor
100    pub fn from_url(url: Url, credentials: &BurnCentralCredentials) -> Result<Self, ClientError> {
101        let mut client = Client {
102            http_client: reqwest::blocking::Client::new(),
103            base_url: url,
104            session_cookie: None,
105            env: Env::Production,
106        };
107
108        let cookie = client.login(credentials)?;
109        client.session_cookie = Some(cookie);
110        Ok(client)
111    }
112
113    #[deprecated]
114    /// Please use environment instead of url
115    pub fn get_endpoint(&self) -> &Url {
116        &self.base_url
117    }
118
119    pub fn get_env(&self) -> &Env {
120        &self.env
121    }
122
123    pub(crate) fn get_json<R>(&self, path: impl AsRef<str>) -> Result<R, ClientError>
124    where
125        R: for<'de> serde::Deserialize<'de>,
126    {
127        let response = self.req(reqwest::Method::GET, path, None::<serde_json::Value>)?;
128        let bytes = response.bytes()?;
129        let json = serde_json::from_slice::<R>(&bytes)?;
130        Ok(json)
131    }
132
133    pub(crate) fn post_json<T, R>(
134        &self,
135        path: impl AsRef<str>,
136        body: Option<T>,
137    ) -> Result<R, ClientError>
138    where
139        T: serde::Serialize,
140        R: for<'de> serde::Deserialize<'de>,
141    {
142        let response = self.req(reqwest::Method::POST, path, body)?;
143        let bytes = response.bytes()?;
144        let json = serde_json::from_slice::<R>(&bytes)?;
145        Ok(json)
146    }
147
148    pub(crate) fn post<T>(&self, path: impl AsRef<str>, body: Option<T>) -> Result<(), ClientError>
149    where
150        T: serde::Serialize,
151    {
152        self.req(reqwest::Method::POST, path, body).map(|_| ())
153    }
154
155    pub(crate) fn req<T: serde::Serialize>(
156        &self,
157        method: reqwest::Method,
158        path: impl AsRef<str>,
159        body: Option<T>,
160    ) -> Result<reqwest::blocking::Response, ClientError> {
161        let url = self.join(path.as_ref());
162        let request_builder = self.http_client.request(method, url);
163
164        let mut request_builder = if let Some(body) = body {
165            request_builder
166                .body(serde_json::to_vec(&body)?)
167                .header(reqwest::header::CONTENT_TYPE, "application/json")
168        } else {
169            request_builder
170        };
171
172        if let Some(cookie) = self.session_cookie.as_ref() {
173            request_builder = request_builder.header(COOKIE, cookie);
174        }
175        request_builder = request_builder.header("X-SDK-Version", env!("CARGO_PKG_VERSION"));
176
177        tracing::debug!("Sending request to Burn Central: {:?}", request_builder);
178
179        let response = request_builder.send()?.map_to_burn_central_err()?;
180
181        tracing::debug!("Received response from Burn Central: {:?}", response);
182
183        Ok(response)
184    }
185
186    // Todo update to support multiple versions
187    pub(crate) fn join(&self, path: &str) -> Url {
188        self.join_versioned(path, 1)
189    }
190
191    fn join_versioned(&self, path: &str, version: u8) -> Url {
192        self.base_url
193            .join(&format!("v{version}/"))
194            .unwrap()
195            .join(path)
196            .expect("Should be able to join url")
197    }
198
199    /// Generic method to upload bytes to the given URL.
200    pub fn upload_bytes_to_url(&self, url: &str, bytes: Vec<u8>) -> Result<(), ClientError> {
201        self.http_client
202            .put(url)
203            .body(bytes)
204            .send()?
205            .map_to_burn_central_err()?;
206
207        Ok(())
208    }
209
210    /// Generic method to download bytes from the given URL.
211    pub fn download_bytes_from_url(&self, url: &str) -> Result<Vec<u8>, ClientError> {
212        let data = self
213            .http_client
214            .get(url)
215            .send()?
216            .map_to_burn_central_err()?
217            .bytes()?
218            .to_vec();
219
220        Ok(data)
221    }
222}