Skip to main content

burn_central_core/experiment/base.rs

1use super::socket::ExperimentSocket;
2use crate::artifacts::{ArtifactKind, ExperimentArtifactClient};
3use crate::bundle::{BundleDecode, BundleEncode, InMemoryBundleReader};
4use crate::experiment::CancelToken;
5use crate::experiment::error::ExperimentTrackerError;
6use crate::experiment::log_store::TempLogStore;
7use crate::experiment::socket::ThreadError;
8use crate::schemas::ExperimentPath;
9use burn_central_client::Client;
10pub use burn_central_client::websocket::MetricLog;
11use burn_central_client::websocket::{ExperimentCompletion, ExperimentMessage, InputUsed};
12use crossbeam::channel::Sender;
13use serde::Serialize;
14use std::ops::Deref;
15use std::sync::{Arc, Weak};
16
/// Final outcome reported when an experiment run is closed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EndExperimentStatus {
    /// The run completed normally.
    Success,
    /// The run failed; the payload is a human-readable reason.
    Fail(String),
}
21
22/// Represents a handle to an experiment, allowing logging of artifacts, metrics, and messages.
23#[derive(Clone, Debug)]
24pub struct ExperimentRunHandle {
25    recorder: Weak<ExperimentRunInner>,
26}
27
28impl ExperimentRunHandle {
29    fn try_upgrade(&self) -> Result<Arc<ExperimentRunInner>, ExperimentTrackerError> {
30        self.recorder
31            .upgrade()
32            .ok_or(ExperimentTrackerError::InactiveExperiment)
33    }
34
35    /// Log arguments used to launch this experiment
36    pub fn log_args<A: Serialize>(&self, args: &A) -> Result<(), ExperimentTrackerError> {
37        self.try_upgrade()?.log_args(args)
38    }
39
40    /// Log an artifact with the given name, kind and settings.
41    pub fn log_artifact<E: BundleEncode>(
42        &self,
43        name: impl Into<String>,
44        kind: ArtifactKind,
45        sources: E,
46        settings: &E::Settings,
47    ) -> Result<(), ExperimentTrackerError> {
48        self.try_upgrade()?
49            .log_artifact(name, kind, sources, settings)
50    }
51
52    /// Loads an artifact with the given name and settings.
53    pub fn load_artifact<D: BundleDecode>(
54        &self,
55        name: impl AsRef<str>,
56        settings: &D::Settings,
57    ) -> Result<D, ExperimentTrackerError> {
58        self.try_upgrade()?.load_artifact(name, settings)
59    }
60
61    /// Loads a raw artifact with the given name.
62    pub fn load_artifact_raw(
63        &self,
64        name: impl AsRef<str>,
65    ) -> Result<InMemoryBundleReader, ExperimentTrackerError> {
66        self.try_upgrade()?.load_artifact_raw(name)
67    }
68
69    /// Logs a metric with the given name, epoch, iteration, value, and group.
70    pub fn log_metric(
71        &self,
72        epoch: usize,
73        split: impl Into<String>,
74        iteration: usize,
75        items: Vec<MetricLog>,
76    ) {
77        self.try_log_metric(epoch, split, iteration, items)
78            .expect("Failed to log metric, experiment may have been closed or inactive");
79    }
80
81    /// Attempts to log a metric with the given name, epoch, iteration, value, and group.
82    pub fn try_log_metric(
83        &self,
84        epoch: usize,
85        split: impl Into<String>,
86        iteration: usize,
87        items: Vec<MetricLog>,
88    ) -> Result<(), ExperimentTrackerError> {
89        self.try_upgrade()?
90            .log_metric(epoch, split, iteration, items)
91    }
92
93    pub fn log_metric_definition(
94        &self,
95        name: impl Into<String>,
96        description: Option<String>,
97        unit: Option<String>,
98        higher_is_better: bool,
99    ) -> Result<(), ExperimentTrackerError> {
100        self.try_upgrade()?
101            .log_metric_definition(name, description, unit, higher_is_better)
102    }
103
104    pub fn log_epoch_summary(
105        &self,
106        epoch: usize,
107        split: String,
108        best_metric_values: Vec<MetricLog>,
109    ) -> Result<(), ExperimentTrackerError> {
110        self.try_upgrade()?
111            .log_epoch_summary(epoch, split, best_metric_values)
112    }
113
114    /// Logs an info message.
115    pub fn log_info(&self, message: impl Into<String>) {
116        self.try_log_info(message)
117            .expect("Failed to log info, experiment may have been closed or inactive");
118    }
119
120    /// Attempts to log an info message.
121    pub fn try_log_info(&self, message: impl Into<String>) -> Result<(), ExperimentTrackerError> {
122        self.try_upgrade()?.log_info(message)
123    }
124
125    /// Logs an error message.
126    pub fn log_error(&self, error: impl Into<String>) {
127        self.try_log_error(error)
128            .expect("Failed to log error, experiment may have been closed or inactive");
129    }
130
131    /// Attempts to log an error message.
132    pub fn try_log_error(&self, error: impl Into<String>) -> Result<(), ExperimentTrackerError> {
133        self.try_upgrade()?.log_error(error)
134    }
135
136    pub fn log_config<C: Serialize>(
137        &self,
138        name: impl Into<String>,
139        config: &C,
140    ) -> Result<(), ExperimentTrackerError> {
141        self.try_upgrade()?.log_config(name.into(), config)
142    }
143
144    /// Check whether the experiment has been cancelled (either locally or via server request).
145    /// Returns an error if the experiment has already become inactive.
146    pub fn is_cancelled(&self) -> Result<bool, ExperimentTrackerError> {
147        Ok(self.try_upgrade()?.is_cancelled())
148    }
149
150    /// Returns the experiment cancel token.
151    pub fn cancel_token(&self) -> Result<CancelToken, ExperimentTrackerError> {
152        Ok(self.try_upgrade()?.cancel_token())
153    }
154}
155
156/// Represents a recorder for an experiment, allowing logging of artifacts, metrics, and messages.
157/// It is used internally by the [Experiment](ExperimentRun) struct to handle logging operations.
158struct ExperimentRunInner {
159    id: ExperimentPath,
160    http_client: Client,
161    sender: Sender<ExperimentMessage>,
162    cancel_token: CancelToken,
163}
164
165impl ExperimentRunInner {
166    fn send(&self, message: ExperimentMessage) -> Result<(), ExperimentTrackerError> {
167        self.sender
168            .send(message)
169            .map_err(|_| ExperimentTrackerError::SocketClosed)
170    }
171
172    pub fn log_args<A: Serialize>(&self, args: &A) -> Result<(), ExperimentTrackerError> {
173        let message = ExperimentMessage::Arguments(serde_json::to_value(args).map_err(|e| {
174            ExperimentTrackerError::InternalError(format!("Failed to serialize arguments: {}", e))
175        })?);
176        self.send(message)
177    }
178
179    pub fn log_config<C: Serialize>(
180        &self,
181        name: String,
182        config: &C,
183    ) -> Result<(), ExperimentTrackerError> {
184        let message = ExperimentMessage::Config {
185            value: serde_json::to_value(config).map_err(|e| {
186                ExperimentTrackerError::InternalError(format!("Failed to serialize config: {}", e))
187            })?,
188            name,
189        };
190        self.send(message)
191    }
192
193    pub fn log_artifact<E: BundleEncode>(
194        &self,
195        name: impl Into<String>,
196        kind: ArtifactKind,
197        artifact: E,
198        settings: &E::Settings,
199    ) -> Result<(), ExperimentTrackerError> {
200        ExperimentArtifactClient::new(self.http_client.clone(), self.id.clone())
201            .upload(name, kind, artifact, settings)
202            .map_err(Into::into)
203            .map(|_| ())
204    }
205
206    pub fn load_artifact_raw(
207        &self,
208        name: impl AsRef<str>,
209    ) -> Result<InMemoryBundleReader, ExperimentTrackerError> {
210        let scope = ExperimentArtifactClient::new(self.http_client.clone(), self.id.clone());
211        let artifact = scope.fetch(&name)?;
212        self.send(ExperimentMessage::InputUsed(InputUsed::Artifact {
213            artifact_id: artifact.id.to_string(),
214        }))?;
215        scope.download_raw(name).map_err(Into::into)
216    }
217
218    pub fn load_artifact<D: BundleDecode>(
219        &self,
220        name: impl AsRef<str>,
221        settings: &D::Settings,
222    ) -> Result<D, ExperimentTrackerError> {
223        let scope = ExperimentArtifactClient::new(self.http_client.clone(), self.id.clone());
224        let artifact = scope.fetch(&name)?;
225        self.send(ExperimentMessage::InputUsed(InputUsed::Artifact {
226            artifact_id: artifact.id.to_string(),
227        }))?;
228        scope.download(name, settings).map_err(Into::into)
229    }
230
231    pub fn log_metric(
232        &self,
233        epoch: usize,
234        split: impl Into<String>,
235        iteration: usize,
236        items: Vec<MetricLog>,
237    ) -> Result<(), ExperimentTrackerError> {
238        let message = ExperimentMessage::MetricsLog {
239            epoch,
240            split: split.into(),
241            iteration,
242            items,
243        };
244        self.send(message)
245    }
246
247    pub fn log_metric_definition(
248        &self,
249        name: impl Into<String>,
250        description: Option<String>,
251        unit: Option<String>,
252        higher_is_better: bool,
253    ) -> Result<(), ExperimentTrackerError> {
254        let message = ExperimentMessage::MetricDefinitionLog {
255            name: name.into(),
256            description,
257            unit,
258            higher_is_better,
259        };
260        self.send(message)
261    }
262
263    pub fn log_epoch_summary(
264        &self,
265        epoch: usize,
266        split: String,
267        best_metric_values: Vec<MetricLog>,
268    ) -> Result<(), ExperimentTrackerError> {
269        let message = ExperimentMessage::EpochSummaryLog {
270            epoch,
271            split,
272            best_metric_values,
273        };
274        self.send(message)
275    }
276
277    pub fn log_info(&self, message: impl Into<String>) -> Result<(), ExperimentTrackerError> {
278        self.send(ExperimentMessage::Log(message.into()))
279    }
280
281    pub fn log_error(&self, error: impl Into<String>) -> Result<(), ExperimentTrackerError> {
282        self.send(ExperimentMessage::Error(error.into()))
283    }
284
285    pub fn is_cancelled(&self) -> bool {
286        self.cancel_token.is_cancelled()
287    }
288
289    pub fn cancel_token(&self) -> CancelToken {
290        self.cancel_token.clone()
291    }
292}
293
294/// Represents an experiment in Burn Central, which is a run of a machine learning model or process.
295pub struct ExperimentRun {
296    inner: Option<Arc<ExperimentRunInner>>,
297    socket: Option<ExperimentSocket>,
298    // temporary field to allow dereferencing to handle
299    _handle: ExperimentRunHandle,
300}
301
302impl ExperimentRun {
303    pub fn new(
304        burn_client: Client,
305        namespace: &str,
306        project_name: &str,
307        experiment_num: i32,
308    ) -> Result<Self, ExperimentTrackerError> {
309        let cancel_token = CancelToken::new();
310
311        let ws_client = burn_client
312            .create_experiment_run_websocket(namespace, project_name, experiment_num)
313            .map_err(|e| {
314                ExperimentTrackerError::ConnectionFailed(format!(
315                    "Failed to create WebSocket client: {}",
316                    e
317                ))
318            })?;
319
320        let experiment_path = ExperimentPath::new(namespace, project_name, experiment_num);
321
322        let log_store = TempLogStore::new(burn_client.clone(), experiment_path.clone());
323        let (sender, receiver) = crossbeam::channel::unbounded();
324        let socket = ExperimentSocket::new(ws_client, log_store, receiver, cancel_token.clone());
325
326        let inner = Arc::new(ExperimentRunInner {
327            id: experiment_path.clone(),
328            http_client: burn_client.clone(),
329            sender,
330            cancel_token: cancel_token.clone(),
331        });
332
333        let _handle = ExperimentRunHandle {
334            recorder: Arc::downgrade(&inner),
335        };
336
337        Ok(ExperimentRun {
338            inner: Some(inner),
339            socket: Some(socket),
340            _handle,
341        })
342    }
343
344    /// Returns a handle to the experiment, allowing logging of artifacts, metrics, and messages.
345    pub fn handle(&self) -> ExperimentRunHandle {
346        ExperimentRunHandle {
347            recorder: Arc::downgrade(self.inner.as_ref().expect("Experiment already finished")),
348        }
349    }
350
351    pub fn id(&self) -> &ExperimentPath {
352        &self.inner.as_ref().expect("Experiment already finished").id
353    }
354
355    fn finish_internal(
356        &mut self,
357        end_status: EndExperimentStatus,
358    ) -> Result<(), ExperimentTrackerError> {
359        let socket = self
360            .socket
361            .take()
362            .ok_or(ExperimentTrackerError::AlreadyFinished)?;
363
364        let inner = self
365            .inner
366            .take()
367            .ok_or(ExperimentTrackerError::AlreadyFinished)?;
368
369        let completion = match end_status {
370            EndExperimentStatus::Success => ExperimentCompletion::Success,
371            EndExperimentStatus::Fail(reason) => ExperimentCompletion::Fail { reason },
372        };
373
374        inner
375            .sender
376            .send(ExperimentMessage::ExperimentComplete(completion))
377            .map_err(|_| ExperimentTrackerError::SocketClosed)?;
378
379        drop(inner);
380
381        let thread_result = socket.join();
382
383        match thread_result {
384            Ok(_thread) => {}
385            Err(ThreadError::WebSocket(msg)) => {
386                eprintln!("Warning: WebSocket failure during experiment finish: {msg}");
387            }
388            Err(ThreadError::LogFlushError(msg)) => {
389                eprintln!("Warning: Log artifact creation failed: {msg}");
390            }
391            Err(ThreadError::Panic) => {
392                eprintln!("Warning: Experiment thread panicked");
393                return Err(ExperimentTrackerError::InternalError(
394                    "Experiment thread panicked".into(),
395                ));
396            }
397        }
398
399        Ok(())
400    }
401
402    pub fn finish(mut self) -> Result<(), ExperimentTrackerError> {
403        self.finish_internal(EndExperimentStatus::Success)
404    }
405
406    pub fn fail(mut self, reason: impl Into<String>) -> Result<(), ExperimentTrackerError> {
407        self.finish_internal(EndExperimentStatus::Fail(reason.into()))
408    }
409}
410
411impl Drop for ExperimentRun {
412    fn drop(&mut self) {
413        if self.socket.is_some() {
414            let _ = self.finish_internal(EndExperimentStatus::Fail(
415                "Experiment dropped without finishing".to_string(),
416            ));
417        }
418    }
419}
420
421/// Temporary implementation to allow dereferencing the Experiment to its recorder
422/// This will be removed once the experiment logging api is completed
423impl Deref for ExperimentRun {
424    type Target = ExperimentRunHandle;
425
426    fn deref(&self) -> &Self::Target {
427        &self._handle
428    }
429}