burn_central_runtime/params/artifact_loader.rs
1use crate::executor::ExecutionContext;
2use crate::params::RoutineParam;
3use burn::prelude::Backend;
4use burn_central_core::BurnCentral;
5use burn_central_core::artifacts::ArtifactError;
6use burn_central_core::bundle::BundleDecode;
7
8/// Artifact loader for loading artifacts from Burn Central. It allow to fecth for instance other
9/// experiment endpoint to be able to restart from a certain point your experiment.
10///
11/// You can build it yourself by using the [ArtifactLoader::new] function with your namespace (in
12/// slug format (e.g. "my-team")), project name and a [burn_central_core::BurnCentral]. However, it
13/// is also possible to request it directly in your routine by using declaring the param like so:
14///
15/// ```ignore
16/// # use burn_central_runtime::ArtifactLoader;
17/// # use burn_central_core::bundle::BundleDecode;
18/// # use burn_central::register;
19/// # use burn_central_runtime::Model;
20/// # use burn_central_runtime::MultiDevice;
21/// # use serde::*;
22/// #[derive(Deserialize, Serialize, Default)]
23/// pub struct LaunchArgs {
24/// pub experiment_num: Option<i32>,
25/// }
26///
27/// #[register(training, name = "mnist")]
28/// pub fn training<B: AutodiffBackend>(
29/// config: Args<LaunchArgs>,
30/// MultiDevice(devices): MultiDevice<B>,
31/// loader: ArtifactLoader<ModelArtifact<B>>,
32/// ) -> Result<Model<ModelArtifact<B::InnerBackend>>, String> {
33/// // Load a pretrained model if an experiment number is provided.
34/// if let Some(experiment_num) = config.experiment_num {
35/// let pretrained_model = loader
36/// .load(experiment_num, "train_artifacts")
37/// .expect("To be able to fetch artifacts");
38/// }
39/// }
40/// ```
41///
42/// As you can see in the example above, you can use the loader to dynamically request experiment
43/// artifacts when requested through your routine configuration.
44pub struct ArtifactLoader<T: BundleDecode> {
45 namespace: String,
46 project_name: String,
47 client: BurnCentral,
48 _artifact: std::marker::PhantomData<T>,
49}
50
51impl<T: BundleDecode> ArtifactLoader<T> {
52 pub fn new(namespace: String, project_name: String, client: BurnCentral) -> Self {
53 Self {
54 namespace,
55 project_name,
56 client,
57 _artifact: std::marker::PhantomData,
58 }
59 }
60
61 /// Load an artifact by name with specific settings.
62 pub fn load_with(
63 &self,
64 experiment_num: i32,
65 name: impl AsRef<str>,
66 settings: &T::Settings,
67 ) -> Result<T, ArtifactError> {
68 let scope = self
69 .client
70 .artifacts(&self.namespace, &self.project_name, experiment_num)
71 .map_err(|e| {
72 ArtifactError::Internal(format!("Failed to create artifact scope: {}", e))
73 })?;
74
75 scope.download(name, settings)
76 }
77
78 /// Load an artifact by name with default settings.
79 pub fn load(&self, experiment_num: i32, name: impl AsRef<str>) -> Result<T, ArtifactError> {
80 let scope = self
81 .client
82 .artifacts(&self.namespace, &self.project_name, experiment_num)
83 .map_err(|e| {
84 ArtifactError::Internal(format!("Failed to create artifact scope: {}", e))
85 })?;
86
87 scope.download(name, &Default::default())
88 }
89}
90
91impl<B: Backend, T: BundleDecode> RoutineParam<ExecutionContext<B>> for ArtifactLoader<T> {
92 type Item<'new>
93 = ArtifactLoader<T>
94 where
95 ExecutionContext<B>: 'new;
96
97 fn try_retrieve(ctx: &ExecutionContext<B>) -> anyhow::Result<Self::Item<'_>> {
98 let client = ctx.client().ok_or_else(|| {
99 anyhow::anyhow!("Burn Central client is not configured in the execution context")
100 })?;
101
102 Ok(ArtifactLoader::new(
103 ctx.namespace().to_string(),
104 ctx.project().to_string(),
105 client.clone(),
106 ))
107 }
108}