burn_central_runtime/params/default.rs
1use crate::{executor::ExecutionContext, params::RoutineParam};
2use burn::prelude::Backend;
3use derive_more::{Deref, From};
4
5/// Wrapper around multiple devices.
6///
7/// Since Burn Central CLI support selecting different backend on the fly. We handle the device
8/// selection in the generated crate. This structure is simply a marker for us to know where to
9/// inject the devices selected by the CLI.
10///
11/// We are planning to support multi device training in the future, however we currently only
12/// support one so this vector will always contains one device for now.
13#[derive(Clone, Debug, Deref, From)]
14pub struct MultiDevice<B: Backend>(pub Vec<B::Device>);
15
16/// Wrapper around the model returned by a routine.
17///
18/// This is used to differentiate the model from other return types.
19/// Right now the macro force you to return a Model as we expect to be able to log it as a model
20/// artifact.
21#[derive(Clone, From, Deref)]
22pub struct Model<M>(pub M);
23
24impl<B: Backend> RoutineParam<ExecutionContext<B>> for MultiDevice<B> {
25 type Item<'new> = MultiDevice<B>;
26
27 fn try_retrieve(ctx: &ExecutionContext<B>) -> anyhow::Result<Self::Item<'_>> {
28 Ok(MultiDevice(ctx.devices().into()))
29 }
30}