1use alloc::{
4 collections::BTreeMap,
5 format,
6 string::{
7 String,
8 ToString,
9 },
10 sync::Arc,
11};
12use core::fmt::Debug;
13
14use anyhow::bail;
15use burn::config::Config;
16
17use crate::{
18 PretrainedWeightsDescriptor,
19 PretrainedWeightsMap,
20 StaticPretrainedWeightsMap,
21};
22
23pub struct StaticPreFabConfig<C>
25where
26 C: 'static + Config + Debug + Clone,
27{
28 pub name: &'static str,
30
31 pub description: &'static str,
33
34 pub builder: fn() -> C,
36
37 pub weights: Option<&'static StaticPretrainedWeightsMap<'static>>,
39}
40
41impl<C> StaticPreFabConfig<C>
42where
43 C: 'static + Config + Debug + Clone,
44{
45 pub fn to_prefab(&self) -> PreFabConfig<C> {
47 let builder = self.builder;
48 PreFabConfig {
49 name: self.name.to_string(),
50 description: self.description.to_string(),
51 builder: Arc::new(builder),
52 weights: self.weights.map(|w| w.to_directory()),
53 }
54 }
55
56 pub fn to_config(&self) -> C {
58 (self.builder)()
59 }
60}
61
62impl<C> From<&StaticPreFabConfig<C>> for PreFabConfig<C>
63where
64 C: 'static + Config + Debug + Clone,
65{
66 fn from(config: &StaticPreFabConfig<C>) -> Self {
67 config.to_prefab()
68 }
69}
70
71impl<C> Debug for StaticPreFabConfig<C>
72where
73 C: 'static + Config + Debug + Clone,
74{
75 fn fmt(
76 &self,
77 f: &mut core::fmt::Formatter<'_>,
78 ) -> core::fmt::Result {
79 self.to_prefab().fmt(f)
80 }
81}
82
83#[derive(Clone)]
85pub struct PreFabConfig<C>
86where
87 C: 'static + Config + Debug + Clone,
88{
89 pub name: String,
91
92 pub description: String,
94
95 pub builder: Arc<dyn Fn() -> C + Send + Sync>,
97
98 pub weights: Option<PretrainedWeightsMap>,
100}
101
102impl<C> Debug for PreFabConfig<C>
103where
104 C: 'static + Config + Debug + Clone,
105{
106 fn fmt(
107 &self,
108 f: &mut core::fmt::Formatter<'_>,
109 ) -> core::fmt::Result {
110 let pretty = f.alternate();
111
112 let type_name = core::any::type_name::<C>();
113 let mut handle = f.debug_struct(&format!("PreFabConfig<{}>", type_name));
114
115 handle
116 .field("name", &self.name)
117 .field("description", &self.description);
118
119 if pretty {
120 handle.field("config", &self.to_config());
121 }
122
123 handle.finish()
124 }
125}
126
127impl<C> PreFabConfig<C>
128where
129 C: 'static + Config + Debug + Clone,
130{
131 pub fn to_config(&self) -> C {
133 (self.builder)()
134 }
135
136 pub fn lookup_pretrained_weights(
138 &self,
139 name: &str,
140 ) -> Option<PretrainedWeightsDescriptor> {
141 match &self.weights {
142 None => None,
143 Some(m) => m.lookup_by_name(name),
144 }
145 }
146
147 pub fn try_lookup_pretrained_weights(
149 &self,
150 name: &str,
151 ) -> anyhow::Result<PretrainedWeightsDescriptor> {
152 match self.lookup_pretrained_weights(name) {
153 Some(d) => Ok(d),
154 None => bail!("Descriptor not found: {}", name),
155 }
156 }
157
158 pub fn expect_lookup_pretrained_weights(
160 &self,
161 name: &str,
162 ) -> PretrainedWeightsDescriptor {
163 match self.try_lookup_pretrained_weights(name) {
164 Ok(p) => p,
165 Err(e) => panic!("{}", e),
166 }
167 }
168}
169
170#[derive(Debug)]
172pub struct StaticPreFabMap<C>
173where
174 C: 'static + Config + Debug + Clone,
175{
176 pub name: &'static str,
178
179 pub description: &'static str,
181
182 pub items: &'static [&'static StaticPreFabConfig<C>],
184}
185
186impl<C> StaticPreFabMap<C>
187where
188 C: 'static + Config + Debug + Clone,
189{
190 pub fn to_prefab_map(&self) -> PreFabMap<C> {
192 PreFabMap {
193 name: self.name.to_string(),
194 description: self.description.to_string(),
195 items: self
196 .items
197 .iter()
198 .map(|c| (c.name.to_string(), c.to_prefab()))
199 .collect(),
200 }
201 }
202
203 pub fn lookup_prefab(
205 &self,
206 name: &str,
207 ) -> Option<PreFabConfig<C>> {
208 self.items
209 .iter()
210 .find(|c| c.name == name)
211 .map(|c| c.to_prefab())
212 }
213
214 pub fn try_lookup_prefab(
216 &self,
217 name: &str,
218 ) -> anyhow::Result<PreFabConfig<C>> {
219 match self.lookup_prefab(name) {
220 Some(d) => Ok(d),
221 None => bail!("PreFab not found: {}", name),
222 }
223 }
224
225 pub fn expect_lookup_prefab(
227 &self,
228 name: &str,
229 ) -> PreFabConfig<C> {
230 match self.try_lookup_prefab(name) {
231 Ok(p) => p,
232 Err(e) => panic!("{}", e),
233 }
234 }
235}
236
237#[derive(Debug, Clone)]
239pub struct PreFabMap<C>
240where
241 C: 'static + Config + Debug + Clone,
242{
243 pub name: String,
245
246 pub description: String,
248
249 pub items: BTreeMap<String, PreFabConfig<C>>,
251}
252
253impl<C> PreFabMap<C>
254where
255 C: 'static + Config + Debug + Clone,
256{
257 pub fn lookup_prefab(
259 &self,
260 name: &str,
261 ) -> Option<PreFabConfig<C>> {
262 self.items.get(name).cloned()
263 }
264
265 pub fn try_lookup_prefab(
267 &self,
268 name: &str,
269 ) -> anyhow::Result<PreFabConfig<C>> {
270 match self.lookup_prefab(name) {
271 Some(d) => Ok(d),
272 None => bail!("PreFab not found: {}", name),
273 }
274 }
275
276 pub fn expect_lookup_prefab(
278 &self,
279 name: &str,
280 ) -> PreFabConfig<C> {
281 match self.try_lookup_prefab(name) {
282 Ok(p) => p,
283 Err(e) => panic!("{}", e),
284 }
285 }
286}