Skip to main content

burn_central_runtime/params/
artifact_loader.rs

1use crate::executor::ExecutionContext;
2use crate::params::RoutineParam;
3use burn::prelude::Backend;
4use burn_central_core::BurnCentral;
5use burn_central_core::artifacts::ArtifactError;
6use burn_central_core::bundle::BundleDecode;
7
8/// Artifact loader for loading artifacts from Burn Central. It allow to fecth for instance other
9/// experiment endpoint to be able to restart from a certain point your experiment.
10///
11/// You can build it yourself by using the [ArtifactLoader::new] function with your namespace (in
12/// slug format (e.g. "my-team")), project name and a [burn_central_core::BurnCentral]. However, it
13/// is also possible to request it directly in your routine by using declaring the param like so:
14///
15/// ```ignore
16/// # use burn_central_runtime::ArtifactLoader;
17/// # use burn_central_core::bundle::BundleDecode;
18/// # use burn_central::register;
19/// # use burn_central_runtime::Model;
20/// # use burn_central_runtime::MultiDevice;
21/// # use serde::*;
22/// #[derive(Deserialize, Serialize, Default)]
23/// pub struct LaunchArgs {
24///     pub experiment_num: Option<i32>,
25/// }
26///
27/// #[register(training, name = "mnist")]
28/// pub fn training<B: AutodiffBackend>(
29///     config: Args<LaunchArgs>,
30///     MultiDevice(devices): MultiDevice<B>,
31///     loader: ArtifactLoader<ModelArtifact<B>>,
32/// ) -> Result<Model<ModelArtifact<B::InnerBackend>>, String> {
33///     // Load a pretrained model if an experiment number is provided.
34///     if let Some(experiment_num) = config.experiment_num {
35///         let pretrained_model = loader
36///             .load(experiment_num, "train_artifacts")
37///             .expect("To be able to fetch artifacts");
38///     }
39/// }
40/// ```
41///
42/// As you can see in the example above, you can use the loader to dynamically request experiment
43/// artifacts when requested through your routine configuration.
44pub struct ArtifactLoader<T: BundleDecode> {
45    namespace: String,
46    project_name: String,
47    client: BurnCentral,
48    _artifact: std::marker::PhantomData<T>,
49}
50
51impl<T: BundleDecode> ArtifactLoader<T> {
52    pub fn new(namespace: String, project_name: String, client: BurnCentral) -> Self {
53        Self {
54            namespace,
55            project_name,
56            client,
57            _artifact: std::marker::PhantomData,
58        }
59    }
60
61    /// Load an artifact by name with specific settings.
62    pub fn load_with(
63        &self,
64        experiment_num: i32,
65        name: impl AsRef<str>,
66        settings: &T::Settings,
67    ) -> Result<T, ArtifactError> {
68        let scope = self
69            .client
70            .artifacts(&self.namespace, &self.project_name, experiment_num)
71            .map_err(|e| {
72                ArtifactError::Internal(format!("Failed to create artifact scope: {}", e))
73            })?;
74
75        scope.download(name, settings)
76    }
77
78    /// Load an artifact by name with default settings.
79    pub fn load(&self, experiment_num: i32, name: impl AsRef<str>) -> Result<T, ArtifactError> {
80        let scope = self
81            .client
82            .artifacts(&self.namespace, &self.project_name, experiment_num)
83            .map_err(|e| {
84                ArtifactError::Internal(format!("Failed to create artifact scope: {}", e))
85            })?;
86
87        scope.download(name, &Default::default())
88    }
89}
90
91impl<B: Backend, T: BundleDecode> RoutineParam<ExecutionContext<B>> for ArtifactLoader<T> {
92    type Item<'new>
93        = ArtifactLoader<T>
94    where
95        ExecutionContext<B>: 'new;
96
97    fn try_retrieve(ctx: &ExecutionContext<B>) -> anyhow::Result<Self::Item<'_>> {
98        let client = ctx.client().ok_or_else(|| {
99            anyhow::anyhow!("Burn Central client is not configured in the execution context")
100        })?;
101
102        Ok(ArtifactLoader::new(
103            ctx.namespace().to_string(),
104            ctx.project().to_string(),
105            client.clone(),
106        ))
107    }
108}