burn_central_runtime/
executor.rs

1use anyhow::Result;
2use burn::prelude::Backend;
3use burn_central_core::artifacts::ArtifactError;
4use burn_central_core::bundle::BundleDecode;
5
6use crate::error::RuntimeError;
7use crate::output::{ExperimentOutput, TrainOutput};
8use crate::param::RoutineParam;
9use crate::routine::{BoxedRoutine, ExecutorRoutineWrapper, IntoRoutine, Routine};
10use burn::tensor::backend::AutodiffBackend;
11use burn_central_core::BurnCentral;
12use burn_central_core::experiment::{
13    ExperimentArgs, ExperimentRun, deserialize_and_merge_with_default,
14};
15use std::collections::HashMap;
16
17/// A loader for artifacts associated with a specific experiment in Burn Central.
18///
19/// It can be used as a parameter in experiment routines to load artifacts like models or checkpoints.
20pub struct ArtifactLoader<T: BundleDecode> {
21    namespace: String,
22    project_name: String,
23    client: BurnCentral,
24    _artifact: std::marker::PhantomData<T>,
25}
26
27impl<T: BundleDecode> ArtifactLoader<T> {
28    pub fn new(namespace: String, project_name: String, client: BurnCentral) -> Self {
29        Self {
30            namespace,
31            project_name,
32            client,
33            _artifact: std::marker::PhantomData,
34        }
35    }
36
37    /// Load an artifact by name with specific settings.
38    pub fn load_with(
39        &self,
40        experiment_num: i32,
41        name: impl AsRef<str>,
42        settings: &T::Settings,
43    ) -> Result<T, ArtifactError> {
44        let scope = self
45            .client
46            .artifacts(&self.namespace, &self.project_name, experiment_num)
47            .map_err(|e| {
48                ArtifactError::Internal(format!("Failed to create artifact scope: {}", e))
49            })?;
50
51        scope.download(name, settings)
52    }
53
54    /// Load an artifact by name with default settings.
55    pub fn load(&self, experiment_num: i32, name: impl AsRef<str>) -> Result<T, ArtifactError> {
56        let scope = self
57            .client
58            .artifacts(&self.namespace, &self.project_name, experiment_num)
59            .map_err(|e| {
60                ArtifactError::Internal(format!("Failed to create artifact scope: {}", e))
61            })?;
62
63        scope.download(name, &Default::default())
64    }
65}
66
67impl<B: Backend, T: BundleDecode> RoutineParam<ExecutionContext<B>> for ArtifactLoader<T> {
68    type Item<'new>
69        = ArtifactLoader<T>
70    where
71        ExecutionContext<B>: 'new;
72
73    fn try_retrieve(ctx: &ExecutionContext<B>) -> Result<Self::Item<'_>> {
74        let client = ctx.client.as_ref().ok_or_else(|| {
75            anyhow::anyhow!("Burn Central client is not configured in the execution context")
76        })?;
77
78        Ok(ArtifactLoader::new(
79            ctx.namespace.clone(),
80            ctx.project.clone(),
81            client.clone(),
82        ))
83    }
84}
85
86type ExecutorRoutine<B> = BoxedRoutine<ExecutionContext<B>, (), ()>;
87
88/// The execution context for a routine, containing the necessary information to run it.
89pub struct ExecutionContext<B: Backend> {
90    client: Option<BurnCentral>,
91    namespace: String,
92    project: String,
93    args_override: Option<serde_json::Value>,
94    devices: Vec<B::Device>,
95    experiment: Option<ExperimentRun>,
96}
97
98impl<B: Backend> ExecutionContext<B> {
99    pub fn use_merged_args<A: ExperimentArgs>(&self) -> A {
100        let args = match &self.args_override {
101            Some(json) => deserialize_and_merge_with_default(json).unwrap_or_default(),
102            None => A::default(),
103        };
104
105        if let Some(experiment) = &self.experiment {
106            experiment.log_args(&args).unwrap_or_else(|e| {
107                log::error!("Failed to log experiment arguments: {}", e);
108            });
109        }
110
111        args
112    }
113
114    pub fn experiment(&self) -> Option<&ExperimentRun> {
115        self.experiment.as_ref()
116    }
117
118    pub fn devices(&self) -> &[B::Device] {
119        &self.devices
120    }
121}
122
123/// The kind of action that can be executed by the executor.
124#[derive(Clone, Debug, PartialEq, Eq, Hash, strum::Display, strum::EnumString)]
125#[strum(serialize_all = "snake_case")]
126pub enum ActionKind {
127    Train,
128    // Infer,
129    // Eval,
130    // Test,
131    // #[strum(serialize = "custom({0})")]
132    // Custom(String),
133}
134
135/// The identifier for a target, which consists of an action kind and a name.
136#[derive(Clone, Debug, PartialEq, Eq, Hash)]
137pub struct TargetId {
138    kind: ActionKind,
139    name: String,
140}
141
142impl std::fmt::Display for TargetId {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        write!(f, "{}:{}", self.kind, self.name)
145    }
146}
147
148/// A builder for creating an `Executor` instance with registered routines.
149pub struct ExecutorBuilder<B: AutodiffBackend> {
150    executor: Executor<B>,
151}
152
153impl<B: AutodiffBackend> ExecutorBuilder<B> {
154    fn new() -> Self {
155        Self {
156            executor: Executor {
157                client: None,
158                namespace: None,
159                project: None,
160                handlers: HashMap::new(),
161            },
162        }
163    }
164
165    fn register<M, O: ExperimentOutput<B>>(
166        &mut self,
167        kind: ActionKind,
168        name: impl Into<String>,
169        handler: impl IntoRoutine<ExecutionContext<B>, (), O, M>,
170    ) -> &mut Self {
171        let wrapper = ExecutorRoutineWrapper::new(IntoRoutine::into_routine(handler));
172        let routine = Box::new(wrapper);
173        let routine_name = routine.name();
174
175        let target_id = TargetId {
176            kind,
177            name: name.into(),
178        };
179
180        log::debug!("Registering handler '{routine_name}' for target: {target_id}");
181
182        self.executor.handlers.insert(target_id, routine);
183        self
184    }
185
186    pub fn train<M, O: TrainOutput<B>>(
187        &mut self,
188        name: impl Into<String>,
189        handler: impl IntoRoutine<ExecutionContext<B>, (), O, M>,
190    ) -> &mut Self {
191        self.register(ActionKind::Train, name, handler);
192        self
193    }
194
195    pub fn build(
196        self,
197        client: BurnCentral,
198        namespace: impl Into<String>,
199        project: impl Into<String>,
200    ) -> Executor<B> {
201        let mut executor = self.executor;
202        executor.client = Some(client);
203        executor.namespace = Some(namespace.into());
204        executor.project = Some(project.into());
205        // Possibly do some validation or final setup here
206        executor
207    }
208}
209
210/// An executor that manages the execution of routines for different targets.
211pub struct Executor<B: Backend> {
212    client: Option<BurnCentral>,
213    namespace: Option<String>,
214    project: Option<String>,
215    handlers: HashMap<TargetId, ExecutorRoutine<B>>,
216}
217
218impl<B: AutodiffBackend> Executor<B> {
219    /// Creates a new `ExecutorBuilder` to configure and build an `Executor`.
220    pub fn builder() -> ExecutorBuilder<B> {
221        ExecutorBuilder::new()
222    }
223
224    /// Lists all registered targets in the executor.
225    pub fn targets(&self) -> Vec<TargetId> {
226        self.handlers.keys().cloned().collect()
227    }
228
229    /// Runs a routine for the specified target with the given devices and arguments override.
230    pub fn run(
231        &self,
232        kind: ActionKind,
233        name: impl AsRef<str>,
234        devices: impl IntoIterator<Item = B::Device>,
235        args_override: Option<String>,
236    ) -> Result<(), RuntimeError> {
237        let routine = name.as_ref();
238
239        let target_id = TargetId {
240            kind,
241            name: routine.to_string(),
242        };
243
244        let handler = self.handlers.get(&target_id).ok_or_else(|| {
245            log::error!("Handler not found for target: {routine}");
246            RuntimeError::HandlerNotFound(routine.to_string())
247        })?;
248
249        log::debug!("Starting Execution for Target: {routine}");
250
251        let args_override = args_override
252            .as_ref()
253            .map(|cfg_str| serde_json::from_str::<serde_json::Value>(cfg_str))
254            .transpose()
255            .map_err(|e| {
256                log::error!("Failed to parse experiment argument overrides: {}", e);
257                RuntimeError::InvalidArgs(e.to_string())
258            })?;
259
260        let mut ctx = ExecutionContext {
261            client: self.client.clone(),
262            namespace: self.namespace.clone().unwrap_or_default(),
263            project: self.project.clone().unwrap_or_default(),
264            args_override,
265            devices: devices.into_iter().collect(),
266            experiment: None,
267        };
268
269        if let Some(client) = &mut ctx.client {
270            let code_version = option_env!("BURN_CENTRAL_CODE_VERSION")
271                .unwrap_or("unknown")
272                .to_string();
273            log::debug!("Using Burn Central client with code version: {code_version}");
274
275            log::info!(
276                "Starting experiment for target: {} in namespace: {}, project: {}",
277                routine,
278                ctx.namespace,
279                ctx.project
280            );
281            let experiment = client.start_experiment(
282                &ctx.namespace,
283                &ctx.project,
284                code_version,
285                routine.to_string(),
286            )?;
287            ctx.experiment = Some(experiment);
288        }
289
290        let result = handler.run((), &mut ctx);
291
292        match result {
293            Ok(_) => {
294                if let Some(experiment) = ctx.experiment {
295                    experiment.finish()?;
296                    log::info!("Experiment run completed successfully.");
297                }
298                log::debug!("Handler {routine} executed successfully.");
299
300                Ok(())
301            }
302            Err(e) => {
303                log::error!("Error executing handler '{routine}': {e}");
304                if let Some(experiment) = ctx.experiment {
305                    experiment.fail(e.to_string())?;
306                    log::error!("Experiment run failed: {e}");
307                }
308                Err(e)
309            }
310        }
311    }
312}
313
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use crate::{Args, Model, MultiDevice};

    use super::*;
    use burn::backend::{Autodiff, NdArray};
    use burn::nn::{Linear, LinearConfig};
    use burn::prelude::*;
    use burn_central_core::bundle::{BundleEncode, BundleSink};
    use serde::{Deserialize, Serialize};

    impl<B: AutodiffBackend> ExecutorBuilder<B> {
        /// Builds the executor without a Burn Central client, so no experiment is tracked.
        pub fn build_offline(self) -> Executor<B> {
            self.executor
        }
    }

    // A backend stub for testing purposes.
    type TestBackend = Autodiff<NdArray<f32>>;
    type TestDevice = <NdArray<f32> as Backend>::Device;

    #[derive(Module, Debug)]
    struct TestModel<B: Backend> {
        linear: Linear<B>,
    }

    impl<B: Backend> BundleEncode for TestModel<B> {
        type Settings = ();
        type Error = Infallible;
        fn encode<E: BundleSink>(
            self,
            _sink: &mut E,
            _settings: &Self::Settings,
        ) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    impl<B: AutodiffBackend> TestModel<B> {
        fn new(device: &B::Device) -> Self {
            let linear = LinearConfig::new(10, 5).init(device);
            TestModel { linear }
        }
    }

    #[derive(Serialize, Deserialize, Debug, Default, Clone)]
    struct TestArgs {
        lr: f32,
        epochs: usize,
    }

    /// Fails the test with the runtime error's message, instead of a bare
    /// "assertion failed: result.is_ok()", when a run unexpectedly errors.
    fn assert_run_ok(target: &str, result: Result<(), RuntimeError>) {
        if let Err(e) = result {
            panic!("running target '{target}' failed: {e}");
        }
    }

    // --- Test Routines ---

    fn simple_train_step<B: AutodiffBackend>() -> Result<Model<TestModel<B>>, String> {
        let device = B::Device::default();
        let model = TestModel::new(&device);
        Ok(model.into())
    }

    fn train_with_params<B: AutodiffBackend>(
        args: Args<TestArgs>,
        devices: MultiDevice<B>,
    ) -> Model<TestModel<B>> {
        let model = TestModel::new(&devices[0]);
        assert_eq!(args.lr, 0.01);
        assert_eq!(args.epochs, 10);
        println!("Train step with config and model executed.");
        model.into()
    }

    fn failing_routine<B: AutodiffBackend>() -> Result<Model<TestModel<B>>> {
        anyhow::bail!("Failing routine");
    }

    // --- Tests ---

    #[test]
    fn should_run_simple_routine_successfully() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train("simple_task", simple_train_step::<TestBackend>);
        let executor = builder.build_offline();

        let result = executor.run(
            "train".parse().unwrap(),
            "simple_task",
            [TestDevice::default()],
            None,
        );
        assert_run_ok("simple_task", result);
    }

    #[test]
    fn should_inject_parameters_and_handle_output() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train("complex_task", train_with_params);
        let executor = builder.build_offline();

        let args_json = r#"{"lr": 0.01, "epochs": 10}"#.to_string();

        let result = executor.run(
            "train".parse().unwrap(),
            "complex_task",
            [TestDevice::default()],
            Some(args_json),
        );
        assert_run_ok("complex_task", result);
    }

    #[test]
    fn should_return_handler_not_found_error() {
        let builder = Executor::<TestBackend>::builder();
        let executor = builder.build_offline();

        let result = executor.run(
            "train".parse().unwrap(),
            "non_existent_task",
            [TestDevice::default()],
            None,
        );

        assert!(matches!(result, Err(RuntimeError::HandlerNotFound(_))));
    }

    #[test]
    fn should_handle_failing_routine() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train("failing_task", failing_routine::<TestBackend>);
        let executor = builder.build_offline();

        let result = executor.run(
            "train".parse().unwrap(),
            "failing_task",
            [TestDevice::default()],
            None,
        );

        assert!(matches!(result, Err(RuntimeError::HandlerFailed(_))));
    }

    #[test]
    fn should_support_named_routines() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train(
            "task1",
            simple_train_step::<TestBackend>.with_name("custom_name_1"),
        );
        builder.train("task2", ("custom_name_2", simple_train_step::<TestBackend>));
        let executor = builder.build_offline();

        let res1 = executor.run("train".parse().unwrap(), "task1", [], None);
        let res2 = executor.run("train".parse().unwrap(), "task2", [], None);

        assert_run_ok("task1", res1);
        assert_run_ok("task2", res2);
    }

    #[test]
    fn should_list_registered_targets() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train("simple_task", simple_train_step::<TestBackend>);
        let executor = builder.build_offline();

        let expected = TargetId {
            kind: ActionKind::Train,
            name: "simple_task".to_string(),
        };
        assert_eq!(executor.targets(), vec![expected]);
    }

    #[test]
    fn should_keep_last_handler_registered_for_a_target() {
        let mut builder = Executor::<TestBackend>::builder();
        builder.train("dup_task", simple_train_step::<TestBackend>);
        builder.train("dup_task", failing_routine::<TestBackend>);
        let executor = builder.build_offline();

        assert_eq!(executor.targets().len(), 1);

        // The later registration replaces the earlier one, so the failing routine runs.
        let result = executor.run(
            "train".parse().unwrap(),
            "dup_task",
            [TestDevice::default()],
            None,
        );
        assert!(matches!(result, Err(RuntimeError::HandlerFailed(_))));
    }
}
469        assert!(res2.is_ok());
470    }
471}