1use anyhow::Result;
2use burn::prelude::Backend;
3use burn_central_core::artifacts::ArtifactError;
4use burn_central_core::bundle::BundleDecode;
5
6use crate::error::RuntimeError;
7use crate::output::{ExperimentOutput, TrainOutput};
8use crate::param::RoutineParam;
9use crate::routine::{BoxedRoutine, ExecutorRoutineWrapper, IntoRoutine, Routine};
10use burn::tensor::backend::AutodiffBackend;
11use burn_central_core::BurnCentral;
12use burn_central_core::experiment::{
13 ExperimentArgs, ExperimentRun, deserialize_and_merge_with_default,
14};
15use std::collections::HashMap;
16
17pub struct ArtifactLoader<T: BundleDecode> {
21 namespace: String,
22 project_name: String,
23 client: BurnCentral,
24 _artifact: std::marker::PhantomData<T>,
25}
26
27impl<T: BundleDecode> ArtifactLoader<T> {
28 pub fn new(namespace: String, project_name: String, client: BurnCentral) -> Self {
29 Self {
30 namespace,
31 project_name,
32 client,
33 _artifact: std::marker::PhantomData,
34 }
35 }
36
37 pub fn load_with(
39 &self,
40 experiment_num: i32,
41 name: impl AsRef<str>,
42 settings: &T::Settings,
43 ) -> Result<T, ArtifactError> {
44 let scope = self
45 .client
46 .artifacts(&self.namespace, &self.project_name, experiment_num)
47 .map_err(|e| {
48 ArtifactError::Internal(format!("Failed to create artifact scope: {}", e))
49 })?;
50
51 scope.download(name, settings)
52 }
53
54 pub fn load(&self, experiment_num: i32, name: impl AsRef<str>) -> Result<T, ArtifactError> {
56 let scope = self
57 .client
58 .artifacts(&self.namespace, &self.project_name, experiment_num)
59 .map_err(|e| {
60 ArtifactError::Internal(format!("Failed to create artifact scope: {}", e))
61 })?;
62
63 scope.download(name, &Default::default())
64 }
65}
66
67impl<B: Backend, T: BundleDecode> RoutineParam<ExecutionContext<B>> for ArtifactLoader<T> {
68 type Item<'new>
69 = ArtifactLoader<T>
70 where
71 ExecutionContext<B>: 'new;
72
73 fn try_retrieve(ctx: &ExecutionContext<B>) -> Result<Self::Item<'_>> {
74 let client = ctx.client.as_ref().ok_or_else(|| {
75 anyhow::anyhow!("Burn Central client is not configured in the execution context")
76 })?;
77
78 Ok(ArtifactLoader::new(
79 ctx.namespace.clone(),
80 ctx.project.clone(),
81 client.clone(),
82 ))
83 }
84}
85
86type ExecutorRoutine<B> = BoxedRoutine<ExecutionContext<B>, (), ()>;
87
88pub struct ExecutionContext<B: Backend> {
90 client: Option<BurnCentral>,
91 namespace: String,
92 project: String,
93 args_override: Option<serde_json::Value>,
94 devices: Vec<B::Device>,
95 experiment: Option<ExperimentRun>,
96}
97
98impl<B: Backend> ExecutionContext<B> {
99 pub fn use_merged_args<A: ExperimentArgs>(&self) -> A {
100 let args = match &self.args_override {
101 Some(json) => deserialize_and_merge_with_default(json).unwrap_or_default(),
102 None => A::default(),
103 };
104
105 if let Some(experiment) = &self.experiment {
106 experiment.log_args(&args).unwrap_or_else(|e| {
107 log::error!("Failed to log experiment arguments: {}", e);
108 });
109 }
110
111 args
112 }
113
114 pub fn experiment(&self) -> Option<&ExperimentRun> {
115 self.experiment.as_ref()
116 }
117
118 pub fn devices(&self) -> &[B::Device] {
119 &self.devices
120 }
121}
122
123#[derive(Clone, Debug, PartialEq, Eq, Hash, strum::Display, strum::EnumString)]
125#[strum(serialize_all = "snake_case")]
126pub enum ActionKind {
127 Train,
128 }
134
135#[derive(Clone, Debug, PartialEq, Eq, Hash)]
137pub struct TargetId {
138 kind: ActionKind,
139 name: String,
140}
141
142impl std::fmt::Display for TargetId {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 write!(f, "{}:{}", self.kind, self.name)
145 }
146}
147
148pub struct ExecutorBuilder<B: AutodiffBackend> {
150 executor: Executor<B>,
151}
152
153impl<B: AutodiffBackend> ExecutorBuilder<B> {
154 fn new() -> Self {
155 Self {
156 executor: Executor {
157 client: None,
158 namespace: None,
159 project: None,
160 handlers: HashMap::new(),
161 },
162 }
163 }
164
165 fn register<M, O: ExperimentOutput<B>>(
166 &mut self,
167 kind: ActionKind,
168 name: impl Into<String>,
169 handler: impl IntoRoutine<ExecutionContext<B>, (), O, M>,
170 ) -> &mut Self {
171 let wrapper = ExecutorRoutineWrapper::new(IntoRoutine::into_routine(handler));
172 let routine = Box::new(wrapper);
173 let routine_name = routine.name();
174
175 let target_id = TargetId {
176 kind,
177 name: name.into(),
178 };
179
180 log::debug!("Registering handler '{routine_name}' for target: {target_id}");
181
182 self.executor.handlers.insert(target_id, routine);
183 self
184 }
185
186 pub fn train<M, O: TrainOutput<B>>(
187 &mut self,
188 name: impl Into<String>,
189 handler: impl IntoRoutine<ExecutionContext<B>, (), O, M>,
190 ) -> &mut Self {
191 self.register(ActionKind::Train, name, handler);
192 self
193 }
194
195 pub fn build(
196 self,
197 client: BurnCentral,
198 namespace: impl Into<String>,
199 project: impl Into<String>,
200 ) -> Executor<B> {
201 let mut executor = self.executor;
202 executor.client = Some(client);
203 executor.namespace = Some(namespace.into());
204 executor.project = Some(project.into());
205 executor
207 }
208}
209
210pub struct Executor<B: Backend> {
212 client: Option<BurnCentral>,
213 namespace: Option<String>,
214 project: Option<String>,
215 handlers: HashMap<TargetId, ExecutorRoutine<B>>,
216}
217
218impl<B: AutodiffBackend> Executor<B> {
219 pub fn builder() -> ExecutorBuilder<B> {
221 ExecutorBuilder::new()
222 }
223
224 pub fn targets(&self) -> Vec<TargetId> {
226 self.handlers.keys().cloned().collect()
227 }
228
229 pub fn run(
231 &self,
232 kind: ActionKind,
233 name: impl AsRef<str>,
234 devices: impl IntoIterator<Item = B::Device>,
235 args_override: Option<String>,
236 ) -> Result<(), RuntimeError> {
237 let routine = name.as_ref();
238
239 let target_id = TargetId {
240 kind,
241 name: routine.to_string(),
242 };
243
244 let handler = self.handlers.get(&target_id).ok_or_else(|| {
245 log::error!("Handler not found for target: {routine}");
246 RuntimeError::HandlerNotFound(routine.to_string())
247 })?;
248
249 log::debug!("Starting Execution for Target: {routine}");
250
251 let args_override = args_override
252 .as_ref()
253 .map(|cfg_str| serde_json::from_str::<serde_json::Value>(cfg_str))
254 .transpose()
255 .map_err(|e| {
256 log::error!("Failed to parse experiment argument overrides: {}", e);
257 RuntimeError::InvalidArgs(e.to_string())
258 })?;
259
260 let mut ctx = ExecutionContext {
261 client: self.client.clone(),
262 namespace: self.namespace.clone().unwrap_or_default(),
263 project: self.project.clone().unwrap_or_default(),
264 args_override,
265 devices: devices.into_iter().collect(),
266 experiment: None,
267 };
268
269 if let Some(client) = &mut ctx.client {
270 let code_version = option_env!("BURN_CENTRAL_CODE_VERSION")
271 .unwrap_or("unknown")
272 .to_string();
273 log::debug!("Using Burn Central client with code version: {code_version}");
274
275 log::info!(
276 "Starting experiment for target: {} in namespace: {}, project: {}",
277 routine,
278 ctx.namespace,
279 ctx.project
280 );
281 let experiment = client.start_experiment(
282 &ctx.namespace,
283 &ctx.project,
284 code_version,
285 routine.to_string(),
286 )?;
287 ctx.experiment = Some(experiment);
288 }
289
290 let result = handler.run((), &mut ctx);
291
292 match result {
293 Ok(_) => {
294 if let Some(experiment) = ctx.experiment {
295 experiment.finish()?;
296 log::info!("Experiment run completed successfully.");
297 }
298 log::debug!("Handler {routine} executed successfully.");
299
300 Ok(())
301 }
302 Err(e) => {
303 log::error!("Error executing handler '{routine}': {e}");
304 if let Some(experiment) = ctx.experiment {
305 experiment.fail(e.to_string())?;
306 log::error!("Experiment run failed: {e}");
307 }
308 Err(e)
309 }
310 }
311 }
312}
313
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use crate::{Args, Model, MultiDevice};

    use super::*;
    use burn::backend::{Autodiff, NdArray};
    use burn::nn::{Linear, LinearConfig};
    use burn::prelude::*;
    use burn_central_core::bundle::{BundleEncode, BundleSink};
    use serde::{Deserialize, Serialize};

    impl<B: AutodiffBackend> ExecutorBuilder<B> {
        /// Test-only finisher: yields the executor without attaching a client, namespace or
        /// project.
        pub fn build_offline(self) -> Executor<B> {
            let Self { executor } = self;
            executor
        }
    }

    type TestBackend = Autodiff<NdArray<f32>>;
    type TestDevice = <NdArray<f32> as Backend>::Device;

    /// Minimal module used as the output of the test routines.
    #[derive(Module, Debug)]
    struct TestModel<B: Backend> {
        linear: Linear<B>,
    }

    /// No-op encoding: the tests never inspect the encoded bundle.
    impl<B: Backend> BundleEncode for TestModel<B> {
        type Settings = ();
        type Error = Infallible;
        fn encode<E: BundleSink>(
            self,
            _sink: &mut E,
            _settings: &Self::Settings,
        ) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    impl<B: AutodiffBackend> TestModel<B> {
        /// Builds the model with a 10 -> 5 linear layer on `device`.
        fn new(device: &B::Device) -> Self {
            Self {
                linear: LinearConfig::new(10, 5).init(device),
            }
        }
    }

    /// Arguments injected into `train_with_params`.
    #[derive(Serialize, Deserialize, Debug, Default, Clone)]
    struct TestArgs {
        lr: f32,
        epochs: usize,
    }

    /// Action kind obtained by parsing, the same way every test addresses its target.
    fn train_kind() -> ActionKind {
        "train".parse().unwrap()
    }

    /// Routine without parameters that always succeeds.
    fn simple_train_step<B: AutodiffBackend>() -> Result<Model<TestModel<B>>, String> {
        let default_device = B::Device::default();
        let net = TestModel::<B>::new(&default_device);
        Ok(net.into())
    }

    /// Routine with injected parameters; checks the overrides reached it.
    fn train_with_params<B: AutodiffBackend>(
        args: Args<TestArgs>,
        devices: MultiDevice<B>,
    ) -> Model<TestModel<B>> {
        let net = TestModel::<B>::new(&devices[0]);
        assert_eq!(args.lr, 0.01);
        assert_eq!(args.epochs, 10);
        println!("Train step with config and model executed.");
        net.into()
    }

    /// Routine that always fails.
    fn failing_routine<B: AutodiffBackend>() -> Result<Model<TestModel<B>>> {
        Err(anyhow::anyhow!("Failing routine"))
    }

    #[test]
    fn should_run_simple_routine_successfully() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train("simple_task", simple_train_step::<TestBackend>);
        let executor = builder.build_offline();

        let outcome = executor.run(
            train_kind(),
            "simple_task",
            [TestDevice::default()],
            None,
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn should_inject_parameters_and_handle_output() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train("complex_task", train_with_params);
        let executor = builder.build_offline();

        let overrides = r#"{"lr": 0.01, "epochs": 10}"#.to_string();

        let outcome = executor.run(
            train_kind(),
            "complex_task",
            [TestDevice::default()],
            Some(overrides),
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn should_return_handler_not_found_error() {
        let executor = Executor::<TestBackend>::builder().build_offline();

        let outcome = executor.run(
            train_kind(),
            "non_existent_task",
            [TestDevice::default()],
            None,
        );

        assert!(matches!(outcome, Err(RuntimeError::HandlerNotFound(_))));
    }

    #[test]
    fn should_handle_failing_routine() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train("failing_task", failing_routine::<TestBackend>);
        let executor = builder.build_offline();

        let outcome = executor.run(
            train_kind(),
            "failing_task",
            [TestDevice::default()],
            None,
        );

        assert!(matches!(outcome, Err(RuntimeError::HandlerFailed(_))));
    }

    #[test]
    fn should_support_named_routines() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train(
            "task1",
            simple_train_step::<TestBackend>.with_name("custom_name_1"),
        );
        builder.train("task2", ("custom_name_2", simple_train_step::<TestBackend>));
        let executor = builder.build_offline();

        let first = executor.run(train_kind(), "task1", [], None);
        let second = executor.run(train_kind(), "task2", [], None);

        assert!(first.is_ok());
        assert!(second.is_ok());
    }
}