burn_central_core/
client.rs

1//! This module provides the [BurnCentral] struct, which is used to interact with the Burn Central service.
2
3use crate::artifacts::ExperimentArtifactClient;
4use crate::experiment::{ExperimentRun, ExperimentTrackerError};
5use crate::models::ModelRegistry;
6use crate::schemas::{ExperimentPath, User};
7use burn_central_client::{BurnCentralCredentials, Client, ClientError};
8use reqwest::Url;
9
/// Errors that can occur during the initialization of the [BurnCentral] client.
///
/// This is the error type of [BurnCentralBuilder::build], [BurnCentral::login] and
/// [BurnCentral::from_env].
#[derive(Debug, thiserror::Error)]
pub enum InitError {
    /// Represents an error related to the client.
    ///
    /// The `#[from]` attribute generates `From<ClientError>`, which is what lets
    /// client construction be propagated with `?`.
    #[error("Client error: {0}")]
    Client(#[from] ClientError),
    /// Represents an error when the endpoint URL is invalid.
    ///
    /// The payload is a free-form description that is interpolated into the message as-is.
    #[error("Failed to parse endpoint URL: {0}")]
    InvalidEndpointUrl(String),
    /// Represents an error when an environment variable is not set.
    ///
    /// The payload is interpolated into the message as-is; [BurnCentral::from_env] passes
    /// the name of the variable.
    #[error("Environment variable not set: {0}")]
    EnvNotSet(String),
}
23
/// Errors returned by the operations of an initialized [BurnCentral] client.
///
/// The `String` payloads of the tuple variants are interpolated verbatim into the
/// corresponding error message.
#[derive(Debug, thiserror::Error)]
pub enum BurnCentralError {
    // Input validation errors
    /// An experiment path was rejected as invalid.
    #[error("Invalid experiment path: {0}")]
    InvalidExperimentPath(String),
    /// A project path was rejected as invalid.
    #[error("Invalid project path: {0}")]
    InvalidProjectPath(String),
    /// An experiment number was rejected as invalid.
    #[error("Invalid experiment number: {0}")]
    InvalidExperimentNumber(String),
    /// A model path was rejected as invalid.
    #[error("Invalid model path: {0}")]
    InvalidModelPath(String),

    /// Represents an error related to client operations.
    ///
    /// This error variant is used to encapsulate client-specific errors along with additional context
    /// and the underlying source error for more detailed debugging.
    ///
    /// # Fields
    /// - `context` (String): A description or additional information about the client error context.
    /// - `source` (ClientError): The underlying source of the client error, providing more details about the cause.
    ///
    /// Note that the rendered message already embeds the source's own message on a second line.
    #[error("Client error: {context}\nSource: {source}")]
    Client {
        /// What the client was trying to do when the error happened.
        context: String,
        /// The underlying client error; `thiserror` treats a field named `source` as the
        /// value returned by `std::error::Error::source`.
        source: ClientError,
    },
    /// Represents an error related to the experiment tracker.
    ///
    /// `#[from]` generates `From<ExperimentTrackerError>`, so tracker errors propagate with `?`.
    #[error("Experiment error: {0}")]
    ExperimentTracker(#[from] ExperimentTrackerError),

    /// Error that should be used when the user is not logged in but tries to perform an operation that requires authentication.
    #[error("The user is not authenticated.")]
    Unauthenticated,

    /// Error that should be used when the client performs operations that can fail due to IO issues.
    ///
    /// `transparent` forwards both the message and the source of the wrapped `std::io::Error`.
    #[error(transparent)]
    Io(#[from] std::io::Error),

    /// Error that should be used when the client encounters an error that is not specifically handled.
    #[error("Internal error: {0}")]
    Internal(String),
}
65
/// This builder struct is used to create a [BurnCentral] client.
///
/// Create one with [BurnCentralBuilder::new] (or [BurnCentral::builder]), optionally override
/// the endpoint with [BurnCentralBuilder::with_endpoint], then call [BurnCentralBuilder::build].
pub struct BurnCentralBuilder {
    /// Endpoint exactly as supplied by the caller, not yet parsed. `None` makes
    /// [BurnCentralBuilder::build] fall back to the default endpoint.
    endpoint: Option<String>,
    /// Credentials passed by reference to the underlying client when it is constructed.
    credentials: BurnCentralCredentials,
}
71
72impl BurnCentralBuilder {
73    /// Creates a new [BurnCentralBuilder] with the given credentials.
74    pub fn new(credentials: impl Into<BurnCentralCredentials>) -> Self {
75        BurnCentralBuilder {
76            endpoint: None,
77            credentials: credentials.into(),
78        }
79    }
80
81    /// Sets the endpoint for the [BurnCentral] client.
82    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
83        self.endpoint = Some(endpoint.into());
84        self
85    }
86
87    /// Builds the [BurnCentral] client.
88    pub fn build(self) -> Result<BurnCentral, InitError> {
89        let url = match self.endpoint {
90            Some(s) => s
91                .parse::<Url>()
92                .map_err(|e| InitError::InvalidEndpointUrl(e.to_string()))?,
93            None => {
94                Url::parse("https://central.burn.dev/api/").expect("Default URL should be valid")
95            }
96        };
97        let client = Client::new(url, &self.credentials)?;
98        Ok(BurnCentral::new(client))
99    }
100}
101
/// This struct provides the main interface to interact with Burn Central.
///
/// `Clone` is derived, so cloning this handle clones the wrapped [Client].
#[derive(Clone)]
pub struct BurnCentral {
    /// Underlying API client. A clone of it is handed to every experiment run, artifact
    /// client and model registry created from this handle.
    client: Client,
}
107
108impl BurnCentral {
109    /// Creates a new [BurnCentral] instance with the given credentials.
110    pub fn login(credentials: impl Into<BurnCentralCredentials>) -> Result<Self, InitError> {
111        let credentials = credentials.into();
112        BurnCentralBuilder::new(credentials).build()
113    }
114
115    /// Creates a new [BurnCentralBuilder] to configure the client.
116    pub fn builder(credentials: impl Into<BurnCentralCredentials>) -> BurnCentralBuilder {
117        BurnCentralBuilder::new(credentials)
118    }
119
120    /// Creates a new [BurnCentral] instance from environment variables.
121    ///
122    /// This function reads the `BURN_CENTRAL_ENDPOINT` and `BURN_CENTRAL_API_KEY` environment variables.
123    /// If the `BURN_CENTRAL_ENDPOINT` is not set, it defaults to `https://central.burn.dev/api/`.
124    pub fn from_env() -> Result<Self, InitError> {
125        let endpoint = std::env::var("BURN_CENTRAL_ENDPOINT")
126            .unwrap_or_else(|_| "https://central.burn.dev/api/".to_string())
127            .parse::<Url>()
128            .map_err(|_| InitError::InvalidEndpointUrl("BURN_CENTRAL_ENDPOINT".to_string()))?;
129        let credentials = BurnCentralCredentials::from_env()
130            .map_err(|_| InitError::EnvNotSet("BURN_CENTRAL_API_KEY".to_string()))?;
131
132        BurnCentralBuilder::new(credentials)
133            .with_endpoint(endpoint.as_str())
134            .build()
135    }
136
137    /// Creates a new instance of [BurnCentral] with the given [Client].
138    fn new(client: Client) -> Self {
139        BurnCentral { client }
140    }
141
142    /// Returns the current user information.
143    pub fn me(&self) -> Result<User, BurnCentralError> {
144        let user = self.client.get_current_user().map_err(|e| {
145            if matches!(e, ClientError::Unauthorized) {
146                BurnCentralError::Unauthenticated
147            } else {
148                BurnCentralError::Client {
149                    context: "Failed to get current user".to_string(),
150                    source: e,
151                }
152            }
153        })?;
154
155        Ok(User {
156            username: user.username,
157            email: user.email,
158            namespace: user.namespace,
159        })
160    }
161
162    /// Start a new experiment. This will create a new experiment on the Burn Central backend and start it.
163    pub fn start_experiment(
164        &self,
165        namespace: &str,
166        project_name: &str,
167        digest: String,
168        routine: String,
169    ) -> Result<ExperimentRun, BurnCentralError> {
170        let experiment = self
171            .client
172            .create_experiment(namespace, project_name, None, digest, routine)
173            .map_err(|e| BurnCentralError::Client {
174                context: format!("Failed to create experiment for {namespace}/{project_name}"),
175                source: e,
176            })?;
177        let experiment_path = ExperimentPath::try_from(format!(
178            "{}/{}/{}",
179            namespace, project_name, experiment.experiment_num
180        ))?;
181
182        println!("Experiment num: {}", experiment.experiment_num);
183
184        ExperimentRun::new(self.client.clone(), experiment_path)
185            .map_err(BurnCentralError::ExperimentTracker)
186    }
187
188    pub fn artifacts(
189        &self,
190        owner: &str,
191        project: &str,
192        exp_num: i32,
193    ) -> Result<ExperimentArtifactClient, BurnCentralError> {
194        let exp_path = ExperimentPath::try_from(format!("{}/{}/{}", owner, project, exp_num))?;
195        Ok(ExperimentArtifactClient::new(self.client.clone(), exp_path))
196    }
197
198    /// Create a model registry for downloading models from Burn Central.
199    /// Models are project-scoped and identified by namespace/project/model_name.
200    pub fn models(&self) -> ModelRegistry {
201        ModelRegistry::new(self.client.clone())
202    }
203}