burn_central_core/
client.rs1use crate::artifacts::ExperimentArtifactClient;
4use crate::experiment::{ExperimentRun, ExperimentTrackerError};
5use crate::models::ModelRegistry;
6use crate::schemas::{ExperimentPath, User};
7use burn_central_client::{BurnCentralCredentials, Client, ClientError};
8use reqwest::Url;
9
10#[derive(Debug, thiserror::Error)]
12pub enum InitError {
13 #[error("Client error: {0}")]
15 Client(#[from] ClientError),
16 #[error("Failed to parse endpoint URL: {0}")]
18 InvalidEndpointUrl(String),
19 #[error("Environment variable not set: {0}")]
21 EnvNotSet(String),
22}
23
24#[derive(Debug, thiserror::Error)]
25pub enum BurnCentralError {
26 #[error("Invalid experiment path: {0}")]
28 InvalidExperimentPath(String),
29 #[error("Invalid project path: {0}")]
30 InvalidProjectPath(String),
31 #[error("Invalid experiment number: {0}")]
32 InvalidExperimentNumber(String),
33 #[error("Invalid model path: {0}")]
34 InvalidModelPath(String),
35
36 #[error("Client error: {context}\nSource: {source}")]
45 Client {
46 context: String,
47 source: ClientError,
48 },
49 #[error("Experiment error: {0}")]
51 ExperimentTracker(#[from] ExperimentTrackerError),
52
53 #[error("The user is not authenticated.")]
55 Unauthenticated,
56
57 #[error(transparent)]
59 Io(#[from] std::io::Error),
60
61 #[error("Internal error: {0}")]
63 Internal(String),
64}
65
66pub struct BurnCentralBuilder {
68 endpoint: Option<String>,
69 credentials: BurnCentralCredentials,
70}
71
72impl BurnCentralBuilder {
73 pub fn new(credentials: impl Into<BurnCentralCredentials>) -> Self {
75 BurnCentralBuilder {
76 endpoint: None,
77 credentials: credentials.into(),
78 }
79 }
80
81 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
83 self.endpoint = Some(endpoint.into());
84 self
85 }
86
87 pub fn build(self) -> Result<BurnCentral, InitError> {
89 let url = match self.endpoint {
90 Some(s) => s
91 .parse::<Url>()
92 .map_err(|e| InitError::InvalidEndpointUrl(e.to_string()))?,
93 None => {
94 Url::parse("https://central.burn.dev/api/").expect("Default URL should be valid")
95 }
96 };
97 let client = Client::new(url, &self.credentials)?;
98 Ok(BurnCentral::new(client))
99 }
100}
101
102#[derive(Clone)]
104pub struct BurnCentral {
105 client: Client,
106}
107
108impl BurnCentral {
109 pub fn login(credentials: impl Into<BurnCentralCredentials>) -> Result<Self, InitError> {
111 let credentials = credentials.into();
112 BurnCentralBuilder::new(credentials).build()
113 }
114
115 pub fn builder(credentials: impl Into<BurnCentralCredentials>) -> BurnCentralBuilder {
117 BurnCentralBuilder::new(credentials)
118 }
119
120 pub fn from_env() -> Result<Self, InitError> {
125 let endpoint = std::env::var("BURN_CENTRAL_ENDPOINT")
126 .unwrap_or_else(|_| "https://central.burn.dev/api/".to_string())
127 .parse::<Url>()
128 .map_err(|_| InitError::InvalidEndpointUrl("BURN_CENTRAL_ENDPOINT".to_string()))?;
129 let credentials = BurnCentralCredentials::from_env()
130 .map_err(|_| InitError::EnvNotSet("BURN_CENTRAL_API_KEY".to_string()))?;
131
132 BurnCentralBuilder::new(credentials)
133 .with_endpoint(endpoint.as_str())
134 .build()
135 }
136
137 fn new(client: Client) -> Self {
139 BurnCentral { client }
140 }
141
142 pub fn me(&self) -> Result<User, BurnCentralError> {
144 let user = self.client.get_current_user().map_err(|e| {
145 if matches!(e, ClientError::Unauthorized) {
146 BurnCentralError::Unauthenticated
147 } else {
148 BurnCentralError::Client {
149 context: "Failed to get current user".to_string(),
150 source: e,
151 }
152 }
153 })?;
154
155 Ok(User {
156 username: user.username,
157 email: user.email,
158 namespace: user.namespace,
159 })
160 }
161
162 pub fn start_experiment(
164 &self,
165 namespace: &str,
166 project_name: &str,
167 digest: String,
168 routine: String,
169 ) -> Result<ExperimentRun, BurnCentralError> {
170 let experiment = self
171 .client
172 .create_experiment(namespace, project_name, None, digest, routine)
173 .map_err(|e| BurnCentralError::Client {
174 context: format!("Failed to create experiment for {namespace}/{project_name}"),
175 source: e,
176 })?;
177 let experiment_path = ExperimentPath::try_from(format!(
178 "{}/{}/{}",
179 namespace, project_name, experiment.experiment_num
180 ))?;
181
182 println!("Experiment num: {}", experiment.experiment_num);
183
184 ExperimentRun::new(self.client.clone(), experiment_path)
185 .map_err(BurnCentralError::ExperimentTracker)
186 }
187
188 pub fn artifacts(
189 &self,
190 owner: &str,
191 project: &str,
192 exp_num: i32,
193 ) -> Result<ExperimentArtifactClient, BurnCentralError> {
194 let exp_path = ExperimentPath::try_from(format!("{}/{}/{}", owner, project, exp_num))?;
195 Ok(ExperimentArtifactClient::new(self.client.clone(), exp_path))
196 }
197
198 pub fn models(&self) -> ModelRegistry {
201 ModelRegistry::new(self.client.clone())
202 }
203}