burn_core/module/display.rs
1use alloc::{
2 borrow::ToOwned,
3 format,
4 string::{String, ToString},
5 vec::Vec,
6};
7use core::any;
8use core::fmt::{Display, Write};
9
10/// Default display settings for a module.
11pub trait ModuleDisplayDefault {
12 /// Attributes of the module used for display purposes.
13 ///
14 /// # Arguments
15 ///
16 /// * `_content` - The content object that contains display settings and attributes.
17 ///
18 /// # Returns
19 ///
20 /// An optional content object containing the display attributes.
21 fn content(&self, _content: Content) -> Option<Content>;
22
23 /// Gets the number of the parameters of the module.
24 fn num_params(&self) -> usize {
25 0
26 }
27}
28
29/// Trait to implement custom display settings for a module.
30///
31/// In order to implement custom display settings for a module,
32/// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)]
33/// 2. Implement ModuleDisplay trait for the module
34pub trait ModuleDisplay: ModuleDisplayDefault {
35 /// Formats the module with provided display settings.
36 ///
37 /// # Arguments
38 ///
39 /// * `passed_settings` - Display settings passed to the module.
40 ///
41 /// # Returns
42 ///
43 /// A string representation of the formatted module.
44 fn format(&self, passed_settings: DisplaySettings) -> String {
45 let settings = if let Some(custom_settings) = self.custom_settings() {
46 custom_settings.inherit(passed_settings)
47 } else {
48 passed_settings
49 };
50
51 let indent = " ".repeat(settings.level * settings.indentation_size());
52 let indent_close_braces = " ".repeat((settings.level - 1) * settings.indentation_size());
53
54 let settings = settings.level_up();
55
56 let self_type = extract_type_name::<Self>();
57
58 // Use custom content if it is implemented and show_all_attributes is false,
59 // otherwise use default content
60 let content = if !settings.show_all_attributes() {
61 self.custom_content(Content::new(settings.clone()))
62 .unwrap_or_else(|| {
63 self.content(Content::new(settings.clone()))
64 .unwrap_or_else(|| {
65 panic!("Default content should be implemented for {self_type}.")
66 })
67 })
68 } else {
69 self.content(Content::new(settings.clone()))
70 .unwrap_or_else(|| panic!("Default content should be implemented for {self_type}."))
71 };
72
73 let top_level_type = if let Some(top_level_type) = content.top_level_type {
74 top_level_type.to_owned()
75 } else {
76 self_type.to_owned()
77 };
78
79 // If there is only one item in the content, return it or no attributes
80 if let Some(item) = content.single_item {
81 return item;
82 } else if content.attributes.is_empty() {
83 return top_level_type.to_string();
84 }
85
86 let mut result = String::new();
87
88 // Print the struct name
89 if settings.new_line_after_attribute() {
90 writeln!(result, "{top_level_type} {{").unwrap();
91 } else {
92 write!(result, "{top_level_type} {{").unwrap();
93 }
94
95 for (i, attribute) in content.attributes.iter().enumerate() {
96 if settings.new_line_after_attribute() {
97 writeln!(result, "{indent}{}: {}", attribute.name, attribute.value).unwrap();
98 } else if i == 0 {
99 write!(result, "{}: {}", attribute.name, attribute.value).unwrap();
100 } else {
101 write!(result, ", {}: {}", attribute.name, attribute.value).unwrap();
102 }
103 }
104
105 if settings.show_num_parameters() {
106 let num_params = self.num_params();
107 if num_params > 0 {
108 if settings.new_line_after_attribute() {
109 writeln!(result, "{indent}params: {num_params}").unwrap();
110 } else {
111 write!(result, ", params: {num_params}").unwrap();
112 }
113 }
114 }
115
116 if settings.new_line_after_attribute() {
117 write!(result, "{indent_close_braces}}}").unwrap();
118 } else {
119 write!(result, "}}").unwrap();
120 }
121
122 result
123 }
124
125 /// Custom display settings for the module.
126 ///
127 /// # Returns
128 ///
129 /// An optional display settings object.
130 fn custom_settings(&self) -> Option<DisplaySettings> {
131 None
132 }
133
134 /// Custom attributes for the module.
135 ///
136 /// # Arguments
137 ///
138 /// * `_content` - The content object that contains display settings and attributes.
139 ///
140 /// # Returns
141 ///
142 /// An optional content object containing the custom attributes.
143 fn custom_content(&self, _content: Content) -> Option<Content> {
144 None
145 }
146}
147
/// Custom module display settings.
///
/// Each option is an `Option`: `None` means "not set". An unset option makes the matching
/// getter fall back to its built-in default, and it does not override anything when settings
/// are merged with `inherit`.
#[derive(Debug, Clone)]
pub struct DisplaySettings {
    /// Whether to print the module parameter ids.
    show_param_id: Option<bool>,

    /// Whether to print all module attributes, ignoring any custom display content.
    show_all_attributes: Option<bool>,

    /// Whether to print the module number of parameters.
    show_num_parameters: Option<bool>,

    /// Print new line after an attribute (otherwise attributes are printed inline).
    new_line_after_attribute: Option<bool>,

    /// Indentation size: number of spaces per indentation level.
    indentation_size: Option<usize>,

    /// Level of indentation (nesting depth); the default is 1 and `level_up` increments it.
    level: usize,
}
169
170impl Default for DisplaySettings {
171 fn default() -> Self {
172 DisplaySettings {
173 show_param_id: None,
174 show_all_attributes: None,
175 show_num_parameters: None,
176 new_line_after_attribute: None,
177 indentation_size: None,
178 level: 1,
179 }
180 }
181}
182
183impl DisplaySettings {
184 /// Create a new format settings.
185 ///
186 /// # Returns
187 ///
188 /// A new instance of `DisplaySettings`.
189 pub fn new() -> Self {
190 Default::default()
191 }
192
193 /// Sets a flag to show module parameters.
194 ///
195 /// # Arguments
196 ///
197 /// * `flag` - Boolean flag to show module parameters.
198 ///
199 /// # Returns
200 ///
201 /// Updated `DisplaySettings` instance.
202 pub fn with_show_param_id(mut self, flag: bool) -> Self {
203 self.show_param_id = Some(flag);
204 self
205 }
206
207 /// Sets a flag to show module attributes.
208 ///
209 /// # Arguments
210 ///
211 /// * `flag` - Boolean flag to show all module attributes.
212 ///
213 /// # Returns
214 ///
215 /// Updated `DisplaySettings` instance.
216 pub fn with_show_all_attributes(mut self, flag: bool) -> Self {
217 self.show_all_attributes = Some(flag);
218 self
219 }
220
221 /// Sets a flag to show the number of module parameters.
222 ///
223 /// # Arguments
224 ///
225 /// * `flag` - Boolean flag to show the number of module parameters.
226 ///
227 /// # Returns
228 ///
229 /// Updated `DisplaySettings` instance.
230 pub fn with_show_num_parameters(mut self, flag: bool) -> Self {
231 self.show_num_parameters = Some(flag);
232 self
233 }
234
235 /// Sets a flag to print a new line after an attribute.
236 ///
237 /// # Arguments
238 ///
239 /// * `flag` - Boolean flag to print a new line after an attribute.
240 ///
241 /// # Returns
242 ///
243 /// Updated `DisplaySettings` instance.
244 pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self {
245 self.new_line_after_attribute = Some(flag);
246 self
247 }
248
249 /// Sets the indentation size.
250 ///
251 /// # Arguments
252 ///
253 /// * `size` - The size of the indentation.
254 ///
255 /// # Returns
256 ///
257 /// Updated `DisplaySettings` instance.
258 pub fn with_indentation_size(mut self, size: usize) -> Self {
259 self.indentation_size = Some(size);
260 self
261 }
262
263 /// Inherits settings from the provided settings and return a new settings object.
264 ///
265 /// # Arguments
266 ///
267 /// * `top` - The top level `DisplaySettings` to inherit from.
268 ///
269 /// # Returns
270 ///
271 /// Updated `DisplaySettings` instance.
272 pub fn inherit(self, top: Self) -> Self {
273 let mut updated = self.clone();
274
275 if let Some(show_param_id) = top.show_param_id {
276 updated.show_param_id = Some(show_param_id);
277 };
278
279 if let Some(show_all_attributes) = top.show_all_attributes {
280 updated.show_all_attributes = Some(show_all_attributes);
281 }
282
283 if let Some(show_num_parameters) = top.show_num_parameters {
284 updated.show_num_parameters = Some(show_num_parameters);
285 }
286
287 if let Some(new_line_after_attribute) = top.new_line_after_attribute {
288 updated.new_line_after_attribute = Some(new_line_after_attribute);
289 }
290
291 if let Some(indentation_size) = top.indentation_size {
292 updated.indentation_size = Some(indentation_size);
293 }
294
295 updated.level = top.level;
296
297 updated
298 }
299
300 /// A convenience method to wrap the DisplaySettings struct in an option.
301 ///
302 /// # Returns
303 ///
304 /// An optional `DisplaySettings`.
305 pub fn optional(self) -> Option<Self> {
306 Some(self)
307 }
308
309 /// Increases the level of indentation.
310 ///
311 /// # Returns
312 ///
313 /// Updated `DisplaySettings` instance with increased indentation level.
314 pub fn level_up(mut self) -> Self {
315 self.level += 1;
316 self
317 }
318
319 /// Gets `show_param_id` flag, substitutes false if not set.
320 ///
321 /// This flag is used to print the module parameter ids.
322 ///
323 /// # Returns
324 ///
325 /// A boolean value indicating whether to show parameter ids.
326 pub fn show_param_id(&self) -> bool {
327 self.show_param_id.unwrap_or(false)
328 }
329
330 /// Gets `show_all_attributes`, substitutes false if not set.
331 ///
332 /// This flag is used to force to print all module attributes, overriding custom attributes.
333 ///
334 /// # Returns
335 ///
336 /// A boolean value indicating whether to show all attributes.
337 pub fn show_all_attributes(&self) -> bool {
338 self.show_all_attributes.unwrap_or(false)
339 }
340
341 /// Gets `show_num_parameters`, substitutes true if not set.
342 ///
343 /// This flag is used to print the number of module parameters.
344 ///
345 /// # Returns
346 ///
347 /// A boolean value indicating whether to show the number of parameters.
348 pub fn show_num_parameters(&self) -> bool {
349 self.show_num_parameters.unwrap_or(true)
350 }
351
352 /// Gets `new_line_after_attribute`, substitutes true if not set.
353 ///
354 /// This flag is used to print a new line after an attribute.
355 ///
356 /// # Returns
357 ///
358 /// A boolean value indicating whether to print a new line after an attribute.
359 pub fn new_line_after_attribute(&self) -> bool {
360 self.new_line_after_attribute.unwrap_or(true)
361 }
362
363 /// Gets `indentation_size`, substitutes 2 if not set.
364 ///
365 /// This flag is used to set the size of indentation.
366 ///
367 /// # Returns
368 ///
369 /// An integer value indicating the size of indentation.
370 pub fn indentation_size(&self) -> usize {
371 self.indentation_size.unwrap_or(2)
372 }
373}
374
/// Struct to store the attributes of a module for formatting.
///
/// Holds either a list of named attributes or a single pre-rendered item. The two are
/// mutually exclusive: the `add*` methods panic if they are mixed.
#[derive(Clone, Debug)]
pub struct Content {
    /// List of named attributes, rendered as `name: value` pairs.
    pub attributes: Vec<Attribute>,

    /// Single item content; when set, this string is used as the whole rendering.
    pub single_item: Option<String>,

    /// Display settings used to format values added to this content.
    pub display_settings: DisplaySettings,

    /// Top level type name; when `None`, the module's short type name is used instead.
    pub top_level_type: Option<String>,
}
390
391impl Content {
392 /// Creates a new attributes struct.
393 ///
394 /// # Arguments
395 ///
396 /// * `display_settings` - Display settings for the content.
397 ///
398 /// # Returns
399 ///
400 /// A new instance of `Content`.
401 pub fn new(display_settings: DisplaySettings) -> Self {
402 Content {
403 attributes: Vec::new(),
404 single_item: None,
405 display_settings,
406 top_level_type: None,
407 }
408 }
409
410 /// Adds an attribute to the format settings. The value will be formatted and stored as a string.
411 ///
412 /// # Arguments
413 ///
414 /// * `name` - Name of the attribute.
415 /// * `value` - Value of the attribute.
416 ///
417 /// # Returns
418 ///
419 /// Updated `Content` instance with the new attribute added.
420 pub fn add<T: ModuleDisplay + ?Sized>(mut self, name: &str, value: &T) -> Self {
421 if self.single_item.is_some() {
422 panic!("Cannot add multiple attributes when single item is set.");
423 }
424
425 let attribute = Attribute {
426 name: name.to_owned(),
427 value: value.format(self.display_settings.clone()), // TODO level + 1
428 ty: any::type_name::<T>().to_string(),
429 };
430 self.attributes.push(attribute);
431 self
432 }
433
434 /// Adds a single item.
435 ///
436 /// # Arguments
437 ///
438 /// * `value` - Rendered string of the single item.
439 ///
440 /// # Returns
441 ///
442 /// Updated `Content` instance with the single item added.
443 pub fn add_single<T: ModuleDisplay + ?Sized>(mut self, value: &T) -> Self {
444 if !self.attributes.is_empty() {
445 panic!("Cannot add single item when attributes are set.");
446 }
447
448 self.single_item = Some(value.format(self.display_settings.clone()));
449
450 self
451 }
452
453 /// Adds a single item.
454 ///
455 /// # Arguments
456 ///
457 /// * `value` - Formatted display value.
458 ///
459 /// # Returns
460 ///
461 /// Updated `Content` instance with the formatted single item added.
462 pub fn add_formatted<T: Display>(mut self, value: &T) -> Self {
463 if !self.attributes.is_empty() {
464 panic!("Cannot add single item when attributes are set.");
465 }
466
467 self.single_item = Some(format!("{value}"));
468 self
469 }
470
471 /// A convenience method to wrap the Attributes struct in an option
472 /// because it is often used as an optional field.
473 ///
474 /// # Returns
475 ///
476 /// An optional `Content`.
477 pub fn optional(self) -> Option<Self> {
478 if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none()
479 {
480 None
481 } else {
482 Some(self)
483 }
484 }
485
486 /// Sets the top level type name.
487 ///
488 /// # Arguments
489 ///
490 /// * `ty` - The type name to set.
491 ///
492 /// # Returns
493 ///
494 /// Updated `Content` instance with the top level type name set.
495 pub fn set_top_level_type(mut self, ty: &str) -> Self {
496 self.top_level_type = Some(ty.to_owned());
497 self
498 }
499}
500
/// Attribute to print in the display method, as a `name: value` pair.
#[derive(Clone, Debug)]
pub struct Attribute {
    /// Name of the attribute.
    pub name: String,

    /// Value of the attribute, already formatted as a string.
    pub value: String,

    /// Type of the attribute, as reported by `any::type_name` for the value's type.
    pub ty: String,
}
513
/// Extracts the short name of a type `T`: its name without module path or generic arguments.
///
/// For example, `alloc::vec::Vec<u32>` becomes `Vec`. The result is derived from
/// `any::type_name`, whose exact output is not guaranteed to be stable across compiler
/// versions, so it is only suitable for display purposes.
///
/// # Returns
///
/// A string slice representing the short name of the type.
pub fn extract_type_name<T: ?Sized>() -> &'static str {
    let full = any::type_name::<T>();

    // Drop generic arguments: everything from the first '<' onwards.
    let base = full.find('<').map_or(full, |idx| &full[..idx]);

    // Drop the module path: everything up to and including the last "::".
    // If nothing would remain after the "::", fall back to the whole path.
    match base.rfind("::") {
        Some(idx) if idx + 2 < base.len() => &base[idx + 2..],
        _ => base,
    }
}