burn_central_core/
client.rs

1//! This module provides the [BurnCentral] struct, which is used to interact with the Burn Central service.
2
3use crate::artifacts::ExperimentArtifactClient;
4use crate::experiment::{ExperimentRun, ExperimentTrackerError};
5use crate::models::ModelRegistry;
6use crate::schemas::{ExperimentPath, User};
7use burn_central_client::{BurnCentralCredentials, Client, ClientError};
8use reqwest::Url;
9
10/// Errors that can occur during the initialization of the [BurnCentral] client.
11#[derive(Debug, thiserror::Error)]
12pub enum InitError {
13    /// Represents an error related to the client.
14    #[error("Client error: {0}")]
15    Client(#[from] ClientError),
16    /// Represents an error when the endpoint URL is invalid.
17    #[error("Failed to parse endpoint URL: {0}")]
18    InvalidEndpointUrl(String),
19    /// Represents an error when an environment variable is not set.
20    #[error("Environment variable not set: {0}")]
21    EnvNotSet(String),
22}
23
24#[derive(Debug, thiserror::Error)]
25pub enum BurnCentralError {
26    // Input validation errors
27    #[error("Invalid experiment path: {0}")]
28    InvalidExperimentPath(String),
29    #[error("Invalid project path: {0}")]
30    InvalidProjectPath(String),
31    #[error("Invalid experiment number: {0}")]
32    InvalidExperimentNumber(String),
33    #[error("Invalid model path: {0}")]
34    InvalidModelPath(String),
35
36    /// Represents an error related to client operations.
37    ///
38    /// This error variant is used to encapsulate client-specific errors along with additional context
39    /// and the underlying source error for more detailed debugging.
40    ///
41    /// # Fields
42    /// - `context` (String): A description or additional information about the client error context.
43    /// - `source` (ClientError): The underlying source of the client error, providing more details about the cause.
44    #[error("Client error: {context}\nSource: {source}")]
45    Client {
46        context: String,
47        source: ClientError,
48    },
49    /// Represents an error related to the experiment tracker.
50    #[error("Experiment error: {0}")]
51    ExperimentTracker(#[from] ExperimentTrackerError),
52
53    /// Error that should be used when the user is not logged in but tries to perform an operation that requires authentication.
54    #[error("The user is not authenticated.")]
55    Unauthenticated,
56
57    /// Error that should be used when the client performs operations that can fail due to IO issues.
58    #[error(transparent)]
59    Io(#[from] std::io::Error),
60
61    /// Error that should be used when the client encounters an error that is not specifically handled.
62    #[error("Internal error: {0}")]
63    Internal(String),
64}
65
66/// This builder struct is used to create a [BurnCentral] client.
67pub struct BurnCentralBuilder {
68    endpoint: Option<String>,
69    credentials: BurnCentralCredentials,
70}
71
72impl BurnCentralBuilder {
73    /// Creates a new [BurnCentralBuilder] with the given credentials.
74    pub fn new(credentials: impl Into<BurnCentralCredentials>) -> Self {
75        BurnCentralBuilder {
76            endpoint: None,
77            credentials: credentials.into(),
78        }
79    }
80
81    /// Sets the endpoint for the [BurnCentral] client.
82    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
83        self.endpoint = Some(endpoint.into());
84        self
85    }
86
87    /// Builds the [BurnCentral] client.
88    pub fn build(self) -> Result<BurnCentral, InitError> {
89        let url = match self.endpoint {
90            Some(s) => s
91                .parse::<Url>()
92                .map_err(|e| InitError::InvalidEndpointUrl(e.to_string()))?,
93            None => {
94                Url::parse("https://central.burn.dev/api/").expect("Default URL should be valid")
95            }
96        };
97        #[allow(deprecated)]
98        let client = Client::from_url(url, &self.credentials)?;
99        Ok(BurnCentral::new(client))
100    }
101}
102
103/// This struct provides the main interface to interact with Burn Central.
104#[derive(Clone)]
105pub struct BurnCentral {
106    client: Client,
107}
108
109impl BurnCentral {
110    /// Creates a new [BurnCentral] instance with the given credentials.
111    pub fn login(credentials: impl Into<BurnCentralCredentials>) -> Result<Self, InitError> {
112        let credentials = credentials.into();
113        BurnCentralBuilder::new(credentials).build()
114    }
115
116    /// Creates a new [BurnCentralBuilder] to configure the client.
117    pub fn builder(credentials: impl Into<BurnCentralCredentials>) -> BurnCentralBuilder {
118        BurnCentralBuilder::new(credentials)
119    }
120
121    /// Creates a new instance of [BurnCentral] with the given [Client].
122    fn new(client: Client) -> Self {
123        BurnCentral { client }
124    }
125
126    /// Returns the current user information.
127    pub fn me(&self) -> Result<User, BurnCentralError> {
128        let user = self.client.get_current_user().map_err(|e| {
129            if matches!(e, ClientError::Unauthorized) {
130                BurnCentralError::Unauthenticated
131            } else {
132                BurnCentralError::Client {
133                    context: "Failed to get current user".to_string(),
134                    source: e,
135                }
136            }
137        })?;
138
139        Ok(User {
140            username: user.username,
141            email: user.email,
142            namespace: user.namespace,
143        })
144    }
145
146    /// Start a new experiment. This will create a new experiment on the Burn Central backend and start it.
147    pub fn start_experiment(
148        &self,
149        namespace: &str,
150        project_name: &str,
151        digest: String,
152        routine: String,
153    ) -> Result<ExperimentRun, BurnCentralError> {
154        let experiment = self
155            .client
156            .create_experiment(namespace, project_name, None, digest, routine)
157            .map_err(|e| BurnCentralError::Client {
158                context: format!("Failed to create experiment for {namespace}/{project_name}"),
159                source: e,
160            })?;
161        let experiment_path = ExperimentPath::try_from(format!(
162            "{}/{}/{}",
163            namespace, project_name, experiment.experiment_num
164        ))?;
165
166        println!("Experiment num: {}", experiment.experiment_num);
167
168        ExperimentRun::new(self.client.clone(), experiment_path)
169            .map_err(BurnCentralError::ExperimentTracker)
170    }
171
172    pub fn artifacts(
173        &self,
174        owner: &str,
175        project: &str,
176        exp_num: i32,
177    ) -> Result<ExperimentArtifactClient, BurnCentralError> {
178        let exp_path = ExperimentPath::try_from(format!("{}/{}/{}", owner, project, exp_num))?;
179        Ok(ExperimentArtifactClient::new(self.client.clone(), exp_path))
180    }
181
182    /// Create a model registry for downloading models from Burn Central.
183    /// Models are project-scoped and identified by namespace/project/model_name.
184    pub fn models(&self) -> ModelRegistry {
185        ModelRegistry::new(self.client.clone())
186    }
187}