1use super::socket::ExperimentSocket;
2use crate::artifacts::{ArtifactKind, ExperimentArtifactClient};
3use crate::bundle::{BundleDecode, BundleEncode, InMemoryBundleReader};
4use crate::experiment::CancelToken;
5use crate::experiment::error::ExperimentTrackerError;
6use crate::experiment::log_store::TempLogStore;
7use crate::experiment::socket::ThreadError;
8use crate::schemas::ExperimentPath;
9use burn_central_client::Client;
10pub use burn_central_client::websocket::MetricLog;
11use burn_central_client::websocket::{ExperimentCompletion, ExperimentMessage, InputUsed};
12use crossbeam::channel::Sender;
13use serde::Serialize;
14use std::ops::Deref;
15use std::sync::{Arc, Weak};
16
/// Final outcome reported when an experiment run is closed.
///
/// The variant is translated into the completion message that is sent as the
/// last message of a run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EndExperimentStatus {
    /// The experiment ended normally.
    Success,
    /// The experiment failed; the payload is the human-readable failure reason.
    Fail(String),
}
21
22#[derive(Clone, Debug)]
24pub struct ExperimentRunHandle {
25 recorder: Weak<ExperimentRunInner>,
26}
27
28impl ExperimentRunHandle {
29 fn try_upgrade(&self) -> Result<Arc<ExperimentRunInner>, ExperimentTrackerError> {
30 self.recorder
31 .upgrade()
32 .ok_or(ExperimentTrackerError::InactiveExperiment)
33 }
34
35 pub fn log_args<A: Serialize>(&self, args: &A) -> Result<(), ExperimentTrackerError> {
37 self.try_upgrade()?.log_args(args)
38 }
39
40 pub fn log_artifact<E: BundleEncode>(
42 &self,
43 name: impl Into<String>,
44 kind: ArtifactKind,
45 sources: E,
46 settings: &E::Settings,
47 ) -> Result<(), ExperimentTrackerError> {
48 self.try_upgrade()?
49 .log_artifact(name, kind, sources, settings)
50 }
51
52 pub fn load_artifact<D: BundleDecode>(
54 &self,
55 name: impl AsRef<str>,
56 settings: &D::Settings,
57 ) -> Result<D, ExperimentTrackerError> {
58 self.try_upgrade()?.load_artifact(name, settings)
59 }
60
61 pub fn load_artifact_raw(
63 &self,
64 name: impl AsRef<str>,
65 ) -> Result<InMemoryBundleReader, ExperimentTrackerError> {
66 self.try_upgrade()?.load_artifact_raw(name)
67 }
68
69 pub fn log_metric(
71 &self,
72 epoch: usize,
73 split: impl Into<String>,
74 iteration: usize,
75 items: Vec<MetricLog>,
76 ) {
77 self.try_log_metric(epoch, split, iteration, items)
78 .expect("Failed to log metric, experiment may have been closed or inactive");
79 }
80
81 pub fn try_log_metric(
83 &self,
84 epoch: usize,
85 split: impl Into<String>,
86 iteration: usize,
87 items: Vec<MetricLog>,
88 ) -> Result<(), ExperimentTrackerError> {
89 self.try_upgrade()?
90 .log_metric(epoch, split, iteration, items)
91 }
92
93 pub fn log_metric_definition(
94 &self,
95 name: impl Into<String>,
96 description: Option<String>,
97 unit: Option<String>,
98 higher_is_better: bool,
99 ) -> Result<(), ExperimentTrackerError> {
100 self.try_upgrade()?
101 .log_metric_definition(name, description, unit, higher_is_better)
102 }
103
104 pub fn log_epoch_summary(
105 &self,
106 epoch: usize,
107 split: String,
108 best_metric_values: Vec<MetricLog>,
109 ) -> Result<(), ExperimentTrackerError> {
110 self.try_upgrade()?
111 .log_epoch_summary(epoch, split, best_metric_values)
112 }
113
114 pub fn log_info(&self, message: impl Into<String>) {
116 self.try_log_info(message)
117 .expect("Failed to log info, experiment may have been closed or inactive");
118 }
119
120 pub fn try_log_info(&self, message: impl Into<String>) -> Result<(), ExperimentTrackerError> {
122 self.try_upgrade()?.log_info(message)
123 }
124
125 pub fn log_error(&self, error: impl Into<String>) {
127 self.try_log_error(error)
128 .expect("Failed to log error, experiment may have been closed or inactive");
129 }
130
131 pub fn try_log_error(&self, error: impl Into<String>) -> Result<(), ExperimentTrackerError> {
133 self.try_upgrade()?.log_error(error)
134 }
135
136 pub fn log_config<C: Serialize>(
137 &self,
138 name: impl Into<String>,
139 config: &C,
140 ) -> Result<(), ExperimentTrackerError> {
141 self.try_upgrade()?.log_config(name.into(), config)
142 }
143
144 pub fn is_cancelled(&self) -> Result<bool, ExperimentTrackerError> {
147 Ok(self.try_upgrade()?.is_cancelled())
148 }
149
150 pub fn cancel_token(&self) -> Result<CancelToken, ExperimentTrackerError> {
152 Ok(self.try_upgrade()?.cancel_token())
153 }
154}
155
156struct ExperimentRunInner {
159 id: ExperimentPath,
160 http_client: Client,
161 sender: Sender<ExperimentMessage>,
162 cancel_token: CancelToken,
163}
164
165impl ExperimentRunInner {
166 fn send(&self, message: ExperimentMessage) -> Result<(), ExperimentTrackerError> {
167 self.sender
168 .send(message)
169 .map_err(|_| ExperimentTrackerError::SocketClosed)
170 }
171
172 pub fn log_args<A: Serialize>(&self, args: &A) -> Result<(), ExperimentTrackerError> {
173 let message = ExperimentMessage::Arguments(serde_json::to_value(args).map_err(|e| {
174 ExperimentTrackerError::InternalError(format!("Failed to serialize arguments: {}", e))
175 })?);
176 self.send(message)
177 }
178
179 pub fn log_config<C: Serialize>(
180 &self,
181 name: String,
182 config: &C,
183 ) -> Result<(), ExperimentTrackerError> {
184 let message = ExperimentMessage::Config {
185 value: serde_json::to_value(config).map_err(|e| {
186 ExperimentTrackerError::InternalError(format!("Failed to serialize config: {}", e))
187 })?,
188 name,
189 };
190 self.send(message)
191 }
192
193 pub fn log_artifact<E: BundleEncode>(
194 &self,
195 name: impl Into<String>,
196 kind: ArtifactKind,
197 artifact: E,
198 settings: &E::Settings,
199 ) -> Result<(), ExperimentTrackerError> {
200 ExperimentArtifactClient::new(self.http_client.clone(), self.id.clone())
201 .upload(name, kind, artifact, settings)
202 .map_err(Into::into)
203 .map(|_| ())
204 }
205
206 pub fn load_artifact_raw(
207 &self,
208 name: impl AsRef<str>,
209 ) -> Result<InMemoryBundleReader, ExperimentTrackerError> {
210 let scope = ExperimentArtifactClient::new(self.http_client.clone(), self.id.clone());
211 let artifact = scope.fetch(&name)?;
212 self.send(ExperimentMessage::InputUsed(InputUsed::Artifact {
213 artifact_id: artifact.id.to_string(),
214 }))?;
215 scope.download_raw(name).map_err(Into::into)
216 }
217
218 pub fn load_artifact<D: BundleDecode>(
219 &self,
220 name: impl AsRef<str>,
221 settings: &D::Settings,
222 ) -> Result<D, ExperimentTrackerError> {
223 let scope = ExperimentArtifactClient::new(self.http_client.clone(), self.id.clone());
224 let artifact = scope.fetch(&name)?;
225 self.send(ExperimentMessage::InputUsed(InputUsed::Artifact {
226 artifact_id: artifact.id.to_string(),
227 }))?;
228 scope.download(name, settings).map_err(Into::into)
229 }
230
231 pub fn log_metric(
232 &self,
233 epoch: usize,
234 split: impl Into<String>,
235 iteration: usize,
236 items: Vec<MetricLog>,
237 ) -> Result<(), ExperimentTrackerError> {
238 let message = ExperimentMessage::MetricsLog {
239 epoch,
240 split: split.into(),
241 iteration,
242 items,
243 };
244 self.send(message)
245 }
246
247 pub fn log_metric_definition(
248 &self,
249 name: impl Into<String>,
250 description: Option<String>,
251 unit: Option<String>,
252 higher_is_better: bool,
253 ) -> Result<(), ExperimentTrackerError> {
254 let message = ExperimentMessage::MetricDefinitionLog {
255 name: name.into(),
256 description,
257 unit,
258 higher_is_better,
259 };
260 self.send(message)
261 }
262
263 pub fn log_epoch_summary(
264 &self,
265 epoch: usize,
266 split: String,
267 best_metric_values: Vec<MetricLog>,
268 ) -> Result<(), ExperimentTrackerError> {
269 let message = ExperimentMessage::EpochSummaryLog {
270 epoch,
271 split,
272 best_metric_values,
273 };
274 self.send(message)
275 }
276
277 pub fn log_info(&self, message: impl Into<String>) -> Result<(), ExperimentTrackerError> {
278 self.send(ExperimentMessage::Log(message.into()))
279 }
280
281 pub fn log_error(&self, error: impl Into<String>) -> Result<(), ExperimentTrackerError> {
282 self.send(ExperimentMessage::Error(error.into()))
283 }
284
285 pub fn is_cancelled(&self) -> bool {
286 self.cancel_token.is_cancelled()
287 }
288
289 pub fn cancel_token(&self) -> CancelToken {
290 self.cancel_token.clone()
291 }
292}
293
294pub struct ExperimentRun {
296 inner: Option<Arc<ExperimentRunInner>>,
297 socket: Option<ExperimentSocket>,
298 _handle: ExperimentRunHandle,
300}
301
302impl ExperimentRun {
303 pub fn new(
304 burn_client: Client,
305 namespace: &str,
306 project_name: &str,
307 experiment_num: i32,
308 ) -> Result<Self, ExperimentTrackerError> {
309 let cancel_token = CancelToken::new();
310
311 let ws_client = burn_client
312 .create_experiment_run_websocket(namespace, project_name, experiment_num)
313 .map_err(|e| {
314 ExperimentTrackerError::ConnectionFailed(format!(
315 "Failed to create WebSocket client: {}",
316 e
317 ))
318 })?;
319
320 let experiment_path = ExperimentPath::new(namespace, project_name, experiment_num);
321
322 let log_store = TempLogStore::new(burn_client.clone(), experiment_path.clone());
323 let (sender, receiver) = crossbeam::channel::unbounded();
324 let socket = ExperimentSocket::new(ws_client, log_store, receiver, cancel_token.clone());
325
326 let inner = Arc::new(ExperimentRunInner {
327 id: experiment_path.clone(),
328 http_client: burn_client.clone(),
329 sender,
330 cancel_token: cancel_token.clone(),
331 });
332
333 let _handle = ExperimentRunHandle {
334 recorder: Arc::downgrade(&inner),
335 };
336
337 Ok(ExperimentRun {
338 inner: Some(inner),
339 socket: Some(socket),
340 _handle,
341 })
342 }
343
344 pub fn handle(&self) -> ExperimentRunHandle {
346 ExperimentRunHandle {
347 recorder: Arc::downgrade(self.inner.as_ref().expect("Experiment already finished")),
348 }
349 }
350
351 pub fn id(&self) -> &ExperimentPath {
352 &self.inner.as_ref().expect("Experiment already finished").id
353 }
354
355 fn finish_internal(
356 &mut self,
357 end_status: EndExperimentStatus,
358 ) -> Result<(), ExperimentTrackerError> {
359 let socket = self
360 .socket
361 .take()
362 .ok_or(ExperimentTrackerError::AlreadyFinished)?;
363
364 let inner = self
365 .inner
366 .take()
367 .ok_or(ExperimentTrackerError::AlreadyFinished)?;
368
369 let completion = match end_status {
370 EndExperimentStatus::Success => ExperimentCompletion::Success,
371 EndExperimentStatus::Fail(reason) => ExperimentCompletion::Fail { reason },
372 };
373
374 inner
375 .sender
376 .send(ExperimentMessage::ExperimentComplete(completion))
377 .map_err(|_| ExperimentTrackerError::SocketClosed)?;
378
379 drop(inner);
380
381 let thread_result = socket.join();
382
383 match thread_result {
384 Ok(_thread) => {}
385 Err(ThreadError::WebSocket(msg)) => {
386 eprintln!("Warning: WebSocket failure during experiment finish: {msg}");
387 }
388 Err(ThreadError::LogFlushError(msg)) => {
389 eprintln!("Warning: Log artifact creation failed: {msg}");
390 }
391 Err(ThreadError::Panic) => {
392 eprintln!("Warning: Experiment thread panicked");
393 return Err(ExperimentTrackerError::InternalError(
394 "Experiment thread panicked".into(),
395 ));
396 }
397 }
398
399 Ok(())
400 }
401
402 pub fn finish(mut self) -> Result<(), ExperimentTrackerError> {
403 self.finish_internal(EndExperimentStatus::Success)
404 }
405
406 pub fn fail(mut self, reason: impl Into<String>) -> Result<(), ExperimentTrackerError> {
407 self.finish_internal(EndExperimentStatus::Fail(reason.into()))
408 }
409}
410
411impl Drop for ExperimentRun {
412 fn drop(&mut self) {
413 if self.socket.is_some() {
414 let _ = self.finish_internal(EndExperimentStatus::Fail(
415 "Experiment dropped without finishing".to_string(),
416 ));
417 }
418 }
419}
420
421impl Deref for ExperimentRun {
424 type Target = ExperimentRunHandle;
425
426 fn deref(&self) -> &Self::Target {
427 &self._handle
428 }
429}