Skip to main content

bunsen_cache/
prefabs.rs

1//! # Config Prefabs for Well-Known Model Configurations
2
3use alloc::{
4    collections::BTreeMap,
5    format,
6    string::{
7        String,
8        ToString,
9    },
10    sync::Arc,
11};
12use core::fmt::Debug;
13
14use anyhow::bail;
15use burn::config::Config;
16
17use crate::{
18    PretrainedWeightsDescriptor,
19    PretrainedWeightsMap,
20    StaticPretrainedWeightsMap,
21};
22
23/// Static builder for a [`PreFabConfig`]
24pub struct StaticPreFabConfig<C>
25where
26    C: 'static + Config + Debug + Clone,
27{
28    /// Name of the model config pre-fab.
29    pub name: &'static str,
30
31    /// Description of the model config pre-fab.
32    pub description: &'static str,
33
34    /// Builder function for the config.
35    pub builder: fn() -> C,
36
37    /// Pretrained weights map.
38    pub weights: Option<&'static StaticPretrainedWeightsMap<'static>>,
39}
40
41impl<C> StaticPreFabConfig<C>
42where
43    C: 'static + Config + Debug + Clone,
44{
45    /// Convert to a [`PreFabConfig<C>`].
46    pub fn to_prefab(&self) -> PreFabConfig<C> {
47        let builder = self.builder;
48        PreFabConfig {
49            name: self.name.to_string(),
50            description: self.description.to_string(),
51            builder: Arc::new(builder),
52            weights: self.weights.map(|w| w.to_directory()),
53        }
54    }
55
56    /// Build a new config.
57    pub fn to_config(&self) -> C {
58        (self.builder)()
59    }
60}
61
62impl<C> From<&StaticPreFabConfig<C>> for PreFabConfig<C>
63where
64    C: 'static + Config + Debug + Clone,
65{
66    fn from(config: &StaticPreFabConfig<C>) -> Self {
67        config.to_prefab()
68    }
69}
70
71impl<C> Debug for StaticPreFabConfig<C>
72where
73    C: 'static + Config + Debug + Clone,
74{
75    fn fmt(
76        &self,
77        f: &mut core::fmt::Formatter<'_>,
78    ) -> core::fmt::Result {
79        self.to_prefab().fmt(f)
80    }
81}
82
83/// A [`Config`] Well-Known Pre-Fab.
84#[derive(Clone)]
85pub struct PreFabConfig<C>
86where
87    C: 'static + Config + Debug + Clone,
88{
89    /// Name of the model config pre-fab.
90    pub name: String,
91
92    /// Description of the model config pre-fab.
93    pub description: String,
94
95    /// Builder function for the config.
96    pub builder: Arc<dyn Fn() -> C + Send + Sync>,
97
98    /// Pretrained weights map.
99    pub weights: Option<PretrainedWeightsMap>,
100}
101
102impl<C> Debug for PreFabConfig<C>
103where
104    C: 'static + Config + Debug + Clone,
105{
106    fn fmt(
107        &self,
108        f: &mut core::fmt::Formatter<'_>,
109    ) -> core::fmt::Result {
110        let pretty = f.alternate();
111
112        let type_name = core::any::type_name::<C>();
113        let mut handle = f.debug_struct(&format!("PreFabConfig<{}>", type_name));
114
115        handle
116            .field("name", &self.name)
117            .field("description", &self.description);
118
119        if pretty {
120            handle.field("config", &self.to_config());
121        }
122
123        handle.finish()
124    }
125}
126
127impl<C> PreFabConfig<C>
128where
129    C: 'static + Config + Debug + Clone,
130{
131    /// Build a new config.
132    pub fn to_config(&self) -> C {
133        (self.builder)()
134    }
135
136    /// Lookup a descriptor.
137    pub fn lookup_pretrained_weights(
138        &self,
139        name: &str,
140    ) -> Option<PretrainedWeightsDescriptor> {
141        match &self.weights {
142            None => None,
143            Some(m) => m.lookup_by_name(name),
144        }
145    }
146
147    /// Lookup a descriptor.
148    pub fn try_lookup_pretrained_weights(
149        &self,
150        name: &str,
151    ) -> anyhow::Result<PretrainedWeightsDescriptor> {
152        match self.lookup_pretrained_weights(name) {
153            Some(d) => Ok(d),
154            None => bail!("Descriptor not found: {}", name),
155        }
156    }
157
158    /// Lookup a descriptor.
159    pub fn expect_lookup_pretrained_weights(
160        &self,
161        name: &str,
162    ) -> PretrainedWeightsDescriptor {
163        match self.try_lookup_pretrained_weights(name) {
164            Ok(p) => p,
165            Err(e) => panic!("{}", e),
166        }
167    }
168}
169
170/// Static builder for a [`PreFabMap`].
171#[derive(Debug)]
172pub struct StaticPreFabMap<C>
173where
174    C: 'static + Config + Debug + Clone,
175{
176    /// Name of the prefab map.
177    pub name: &'static str,
178
179    /// Description of the prefab map.
180    pub description: &'static str,
181
182    /// List of prefabs.
183    pub items: &'static [&'static StaticPreFabConfig<C>],
184}
185
186impl<C> StaticPreFabMap<C>
187where
188    C: 'static + Config + Debug + Clone,
189{
190    /// Convert to a [`PreFabMap`].
191    pub fn to_prefab_map(&self) -> PreFabMap<C> {
192        PreFabMap {
193            name: self.name.to_string(),
194            description: self.description.to_string(),
195            items: self
196                .items
197                .iter()
198                .map(|c| (c.name.to_string(), c.to_prefab()))
199                .collect(),
200        }
201    }
202
203    /// Lookup a prefab.
204    pub fn lookup_prefab(
205        &self,
206        name: &str,
207    ) -> Option<PreFabConfig<C>> {
208        self.items
209            .iter()
210            .find(|c| c.name == name)
211            .map(|c| c.to_prefab())
212    }
213
214    /// Lookup a prefab.
215    pub fn try_lookup_prefab(
216        &self,
217        name: &str,
218    ) -> anyhow::Result<PreFabConfig<C>> {
219        match self.lookup_prefab(name) {
220            Some(d) => Ok(d),
221            None => bail!("PreFab not found: {}", name),
222        }
223    }
224
225    /// Lookup a prefab.
226    pub fn expect_lookup_prefab(
227        &self,
228        name: &str,
229    ) -> PreFabConfig<C> {
230        match self.try_lookup_prefab(name) {
231            Ok(p) => p,
232            Err(e) => panic!("{}", e),
233        }
234    }
235}
236
237/// A map of [`PreFabConfig`]s.
238#[derive(Debug, Clone)]
239pub struct PreFabMap<C>
240where
241    C: 'static + Config + Debug + Clone,
242{
243    /// Name of the prefab map.
244    pub name: String,
245
246    /// Description of the prefab map.
247    pub description: String,
248
249    /// Map of prefabs.
250    pub items: BTreeMap<String, PreFabConfig<C>>,
251}
252
253impl<C> PreFabMap<C>
254where
255    C: 'static + Config + Debug + Clone,
256{
257    /// Lookup a prefab.
258    pub fn lookup_prefab(
259        &self,
260        name: &str,
261    ) -> Option<PreFabConfig<C>> {
262        self.items.get(name).cloned()
263    }
264
265    /// Lookup a prefab.
266    pub fn try_lookup_prefab(
267        &self,
268        name: &str,
269    ) -> anyhow::Result<PreFabConfig<C>> {
270        match self.lookup_prefab(name) {
271            Some(d) => Ok(d),
272            None => bail!("PreFab not found: {}", name),
273        }
274    }
275
276    /// Lookup a prefab.
277    pub fn expect_lookup_prefab(
278        &self,
279        name: &str,
280    ) -> PreFabConfig<C> {
281        match self.try_lookup_prefab(name) {
282            Ok(p) => p,
283            Err(e) => panic!("{}", e),
284        }
285    }
286}