dbg_cli/session_db/canonicalizer/
cuda.rs1use std::process::{Command, Stdio};
17use std::sync::OnceLock;
18
19use super::{CanonicalSymbol, Canonicalizer};
20
21pub struct CudaCanonicalizer;
22
23impl Canonicalizer for CudaCanonicalizer {
24 fn lang(&self) -> &'static str {
25 "cuda"
26 }
27
28 fn canonicalize(&self, raw: &str) -> CanonicalSymbol {
29 let (demangled_s, used_demangler) = maybe_demangle(raw);
30 let fqn = normalize(&demangled_s);
31
32 CanonicalSymbol {
33 lang: "cuda",
34 fqn,
35 file: None,
36 line: None,
37 demangled: if used_demangler {
38 Some(demangled_s.clone())
39 } else {
40 None
41 },
42 raw: raw.to_string(),
43 is_synthetic: false,
44 }
45 }
46}
47
48fn maybe_demangle(raw: &str) -> (String, bool) {
49 if !raw.starts_with("_Z") {
50 return (raw.to_string(), false);
51 }
52 static AVAILABLE: OnceLock<bool> = OnceLock::new();
53 let available = *AVAILABLE.get_or_init(|| which::which("c++filt").is_ok());
54 if !available {
55 return (raw.to_string(), false);
56 }
57 let out = Command::new("c++filt")
58 .arg(raw)
59 .stdout(Stdio::piped())
60 .stderr(Stdio::null())
61 .output();
62 match out {
63 Ok(o) if o.status.success() => {
64 let s = String::from_utf8_lossy(&o.stdout).trim().to_string();
65 if s.is_empty() || s == raw {
66 (raw.to_string(), false)
67 } else {
68 (s, true)
69 }
70 }
71 _ => (raw.to_string(), false),
72 }
73}
74
/// Reduce a (possibly demangled) C++/CUDA signature to its qualified name.
///
/// The parameter list and everything after it (e.g. a trailing ` const`) are
/// dropped, and so is a leading return type. Namespace qualifiers and template
/// arguments are kept, so distinct instantiations stay distinct:
///
/// - `void sgemm<float>(float const*, int)` -> `sgemm<float>`
/// - `unsigned int pick<int>(int)` -> `pick<int>`
/// - `void at::native::(anonymous namespace)::k<4>(int)`
///   -> `at::native::(anonymous namespace)::k<4>`
/// - `vector_add` -> `vector_add`
///
/// A parenthesised group directly followed by `::` belongs to the scope, not
/// the parameter list. Examples are `(anonymous namespace)::` and an enclosing
/// function such as `foo()::`. Cutting at the first `(` instead would collapse
/// every kernel in such a scope onto one key, or onto an empty string.
///
/// Without a parameter list, the leading word is stripped only when the
/// remainder contains `<` or `(`. Plain space-separated labels are therefore
/// left alone. Operator names such as `operator()` are not special-cased.
fn normalize(s: &str) -> String {
    let s = s.trim();

    // Everything before the parameter list (or all of `s` if there is none).
    let param_start = param_list_start(s);
    let head = s[..param_start.unwrap_or(s.len())].trim_end();

    // The name follows the *last* top-level space, which also handles
    // multi-word return types like `unsigned int`.
    let name = match last_top_level_space(head) {
        Some(i) => {
            let tail = &head[i + 1..];
            if param_start.is_some() || tail.contains('<') || tail.contains('(') {
                tail
            } else {
                head
            }
        }
        None => head,
    };
    name.trim().to_string()
}

/// Nesting depth of `<>`, `()` and `{}` while scanning a demangled name.
#[derive(Default)]
struct Nesting {
    angle: usize,
    paren: usize,
    brace: usize,
}

impl Nesting {
    /// Update the depths for one character; stray closers never go below zero.
    fn step(&mut self, ch: char) {
        match ch {
            '<' => self.angle += 1,
            '>' => self.angle = self.angle.saturating_sub(1),
            '(' => self.paren += 1,
            ')' => self.paren = self.paren.saturating_sub(1),
            '{' => self.brace += 1,
            '}' => self.brace = self.brace.saturating_sub(1),
            _ => {}
        }
    }

    /// True when the scan is outside every bracket pair.
    fn is_top_level(&self) -> bool {
        self.angle == 0 && self.paren == 0 && self.brace == 0
    }
}

/// Byte offset of the `(` that opens the parameter list, if any.
///
/// Only a top-level `(` qualifies. A group whose matching `)` is directly
/// followed by `::` is skipped as a scope component. An unterminated top-level
/// `(` is still reported as the parameter list.
fn param_list_start(s: &str) -> Option<usize> {
    let mut depth = Nesting::default();
    let mut open = None;
    for (i, ch) in s.char_indices() {
        if ch == '(' && depth.is_top_level() {
            open = Some(i);
        }
        depth.step(ch);
        // Paren depth back at zero means the group recorded in `open` closed here.
        if ch == ')' && depth.paren == 0 {
            if let Some(start) = open.take() {
                if !s[i + 1..].starts_with("::") {
                    return Some(start);
                }
            }
        }
    }
    open
}

/// Byte offset of the last space outside `<>`, `()` and `{}`, if any.
fn last_top_level_space(s: &str) -> Option<usize> {
    let mut depth = Nesting::default();
    let mut last = None;
    for (i, ch) in s.char_indices() {
        if ch == ' ' && depth.is_top_level() {
            last = Some(i);
        }
        depth.step(ch);
    }
    last
}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Run the CUDA canonicalizer on `raw`.
    fn canon(raw: &str) -> CanonicalSymbol {
        CudaCanonicalizer.canonicalize(raw)
    }

    #[test]
    fn simple_kernel_preserved() {
        let sym = canon("vector_add");
        assert_eq!(sym.lang, "cuda");
        assert_eq!(sym.fqn, "vector_add");
    }

    #[test]
    fn return_type_stripped() {
        assert_eq!(canon("void sgemm<float>(float const*, int)").fqn, "sgemm<float>");
    }

    #[test]
    fn param_list_stripped() {
        assert_eq!(canon("sgemm<float>(float const*, int)").fqn, "sgemm<float>");
    }

    #[test]
    fn template_params_preserved_and_distinguishing() {
        let as_float = canon("void sgemm<float>(float const*, int)");
        let as_half = canon("void sgemm<half>(half const*, int)");
        assert_eq!(as_float.fqn, "sgemm<float>");
        assert_eq!(as_half.fqn, "sgemm<half>");
        assert_ne!(as_float.fqn, as_half.fqn);
    }

    #[test]
    fn multi_template_params_preserved() {
        let sym = canon("void gemm<float, 128, 128, 16>(float const*, int)");
        assert_eq!(sym.fqn, "gemm<float, 128, 128, 16>");
    }

    #[test]
    fn qualified_name_preserved() {
        assert_eq!(canon("void ns::kernel<int>(int*)").fqn, "ns::kernel<int>");
    }

    #[test]
    fn no_parens_no_return_left_alone() {
        assert_eq!(canon("sgemm<float>").fqn, "sgemm<float>");
    }

    #[test]
    fn key_is_lang_plus_fqn() {
        let sym = canon("sgemm<float>");
        assert_eq!(sym.key(), ("cuda", "sgemm<float>"));
    }

    #[test]
    fn raw_is_preserved_verbatim() {
        let signature = "void sgemm<float>(float const*, int)";
        assert_eq!(canon(signature).raw, signature);
    }
}