Skip to main content

cubecl_common/
format.rs

1use alloc::format;
2use alloc::string::String;
3
/// Print string without quotes
///
/// Wrapper whose `Debug` implementation writes the wrapped string verbatim: no surrounding
/// quotes and no escaping. Formatter options such as width or alignment are not applied.
pub struct DebugRaw<'a>(pub &'a str);

impl core::fmt::Debug for DebugRaw<'_> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(self.0)
    }
}
12
/// Format strings for use in identifiers and types.
///
/// Pretty-prints a compact, `Debug`-style string:
/// - Every opening marker from `markers` starts a new line, indented by four spaces per
///   nesting level.
/// - Every `,` inside a group starts a new line at the current level.
/// - Every closing marker gets a trailing `,` and its own line.
/// - Empty groups such as `[]` stay on one line.
///
/// Spaces in the input are dropped, except inside `"…"` string literals. The contents of those
/// literals, including escaped quotes, markers, commas and colons, are copied verbatim.
///
/// When `include_space` is `true`, a space separates a token from a following opening marker
/// (`Foo {`). A space also follows every `:` that is not part of a `::` path separator.
///
/// An unmatched closing marker does not panic; the nesting depth saturates at zero.
pub fn format_str(string: &str, markers: &[(char, char)], include_space: bool) -> String {
    /// Spaces added per nesting level.
    const INDENTATION: usize = 4;

    /// Ends the current line and indents the next one by `width` spaces.
    fn line_break(out: &mut String, width: usize) {
        out.push('\n');
        out.push_str(&" ".repeat(width));
    }

    let mut result = String::with_capacity(string.len());
    let mut depth: usize = 0;
    // Last character emitted outside string literals. It is `' '` right after a `": "`
    // separator so that a following comma or closing marker can drop that trailing space.
    let mut prev = ' ';
    let mut in_string = false;
    // Whether the previous character inside a string literal was an unescaped backslash.
    let mut escaped = false;
    let mut chars = string.chars().peekable();

    while let Some(c) = chars.next() {
        if in_string {
            // String literal contents are never restructured.
            result.push(c);
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
                prev = c;
            }
            continue;
        }

        match c {
            ' ' => continue,
            '"' => {
                in_string = true;
                result.push(c);
                prev = c;
                continue;
            }
            _ => {}
        }

        let mut found_marker = false;

        for &(start, end) in markers {
            if c == start {
                depth += 1;
                // Separate the marker from a preceding token, but never from indentation,
                // a `": "` separator, or the beginning of the output.
                if include_space && !result.is_empty() && !result.ends_with(' ') {
                    result.push(' ');
                }
                result.push(start);
                line_break(&mut result, INDENTATION * depth);
                found_marker = true;
            } else if c == end {
                depth = depth.saturating_sub(1);
                if prev == start {
                    // Empty group: drop the line break and indentation emitted by the opening
                    // marker so it renders as e.g. `[]`.
                    for _ in 0..=INDENTATION * (depth + 1) {
                        result.pop();
                    }
                } else {
                    if prev == ' ' {
                        result.pop();
                    }
                    result.push(',');
                    line_break(&mut result, INDENTATION * depth);
                }
                result.push(end);
                found_marker = true;
            }
        }

        if found_marker {
            prev = c;
            continue;
        }

        if c == ',' && depth > 0 {
            if prev == ' ' {
                result.pop();
            }
            result.push(',');
            line_break(&mut result, INDENTATION * depth);
            continue;
        }

        result.push(c);
        // A lone `:` is a field/key separator; `::` is a path separator and stays compact.
        if c == ':' && include_space && prev != ':' && chars.peek() != Some(&':') {
            result.push(' ');
            prev = ' ';
        } else {
            prev = c;
        }
    }

    result
}
94
95/// Format a debug type.
96pub fn format_debug<F: core::fmt::Debug>(string: &F) -> String {
97    let string = format!("{string:?}");
98    format_str(&string, &[('(', ')'), ('[', ']'), ('{', '}')], true)
99}
100
101/// Provide a short, sanitized type name for use in function or file names
102pub fn type_name_short_sanitized<T>() -> String {
103    let name = tynm::type_name::<T>();
104    name.replace(|c: char| !c.is_alphanumeric() && c != '_', "_")
105}
106
#[cfg(test)]
#[cfg(feature = "std")]
mod tests {
    use alloc::string::ToString;
    use hashbrown::HashMap;

    use super::*;

    /// Fixture whose derived `Debug` output nests a map inside a struct.
    #[derive(Debug)]
    #[allow(unused)] // the field is only read through the derived `Debug` impl
    struct Test {
        map: HashMap<String, u32>,
    }

    #[test_log::test]
    fn test_format_debug() {
        let expected = r#"Test {
    map: {
        "Hey with space": 8,
    },
}"#;

        let map = HashMap::from_iter([("Hey with space".to_string(), 8)]);
        let actual = format_debug(&Test { map });

        assert_eq!(expected, actual);
    }
}