Skip to main content

dbg_cli/session_db/canonicalizer/
cuda.rs

1//! CUDA kernel symbol canonicalization.
2//!
3//! Kernels come from nsys / ncu in two shapes:
4//!   * Demangled C++:
5//!     `void sgemm_128x128<float>(float const*, float const*, float*, int, int, int)`
6//!     →  canonical `sgemm_128x128<float>`
7//!     (drop leading return type, drop parenthesized parameter list,
8//!      KEEP template parameters — they distinguish instantiations).
9//!   * Raw mangled (`_Z...`) — delegate to `c++filt` like cxx does, then
10//!     apply the same normalization.
11//!
12//! Template parameters are MANDATORY to preserve: agents correlate
13//! `sgemm<float>` and `sgemm<half>` as different rows and a canonical
14//! form that drops them would merge and lose the distinction.
15
16use std::process::{Command, Stdio};
17use std::sync::OnceLock;
18
19use super::{CanonicalSymbol, Canonicalizer};
20
21pub struct CudaCanonicalizer;
22
23impl Canonicalizer for CudaCanonicalizer {
24    fn lang(&self) -> &'static str {
25        "cuda"
26    }
27
28    fn canonicalize(&self, raw: &str) -> CanonicalSymbol {
29        let (demangled_s, used_demangler) = maybe_demangle(raw);
30        let fqn = normalize(&demangled_s);
31
32        CanonicalSymbol {
33            lang: "cuda",
34            fqn,
35            file: None,
36            line: None,
37            demangled: if used_demangler {
38                Some(demangled_s.clone())
39            } else {
40                None
41            },
42            raw: raw.to_string(),
43            is_synthetic: false,
44        }
45    }
46}
47
48fn maybe_demangle(raw: &str) -> (String, bool) {
49    if !raw.starts_with("_Z") {
50        return (raw.to_string(), false);
51    }
52    static AVAILABLE: OnceLock<bool> = OnceLock::new();
53    let available = *AVAILABLE.get_or_init(|| which::which("c++filt").is_ok());
54    if !available {
55        return (raw.to_string(), false);
56    }
57    let out = Command::new("c++filt")
58        .arg(raw)
59        .stdout(Stdio::piped())
60        .stderr(Stdio::null())
61        .output();
62    match out {
63        Ok(o) if o.status.success() => {
64            let s = String::from_utf8_lossy(&o.stdout).trim().to_string();
65            if s.is_empty() || s == raw {
66                (raw.to_string(), false)
67            } else {
68                (s, true)
69            }
70        }
71        _ => (raw.to_string(), false),
72    }
73}
74
/// Drop the leading return type and the trailing parenthesized parameter
/// list while preserving template parameters.
///
/// Input shapes we handle:
///   * `T fn<Args>(Params)`          — typed demangled
///   * `fn<Args>(Params)`            — no return type
///   * `fn<Args>`                    — neither
///
/// The demangler spells unnamed namespaces as `(anonymous namespace)`.
/// That component is part of the qualified name, not a parameter list, so
/// it is kept: `void (anonymous namespace)::k<float>(float*)` becomes
/// `(anonymous namespace)::k<float>`.
///
/// Returns the canonical name as an owned, trimmed `String`.
fn normalize(s: &str) -> String {
    // Name component emitted by demanglers for unnamed namespaces. Its
    // opening paren must never be mistaken for the parameter list.
    const ANON_NS: &str = "(anonymous namespace)";

    let s = s.trim();

    // 1. Drop a leading return-type token if one is present.
    //    Heuristic: the first space that is outside every `<...>` and
    //    `(...)` group ends the return type, provided a `<` or `(` still
    //    follows it. Otherwise the space belongs to an unusual symbol name
    //    and nothing is dropped.
    let mut depth_angle = 0i32;
    let mut depth_paren = 0i32;
    let mut first_space: Option<usize> = None;
    for (i, ch) in s.char_indices() {
        match ch {
            '<' => depth_angle += 1,
            '>' => depth_angle -= 1,
            '(' => depth_paren += 1,
            ')' => depth_paren -= 1,
            ' ' if depth_angle <= 0 && depth_paren <= 0 => {
                first_space = Some(i);
                break;
            }
            _ => {}
        }
    }
    let body_start = match first_space {
        Some(i) if s[i + 1..].contains(|c: char| matches!(c, '<' | '(')) => i + 1,
        _ => 0,
    };
    let s = &s[body_start..];

    // 2. Drop the parenthesized parameter list: the first `(` outside any
    //    template argument list that does not open `(anonymous namespace)`.
    let mut depth_angle = 0i32;
    let mut paren_start: Option<usize> = None;
    for (i, ch) in s.char_indices() {
        match ch {
            '<' => depth_angle += 1,
            '>' => depth_angle -= 1,
            // `i` comes from `char_indices`, so slicing at it is always on
            // a character boundary.
            '(' if depth_angle <= 0 && !s[i..].starts_with(ANON_NS) => {
                paren_start = Some(i);
                break;
            }
            _ => {}
        }
    }
    match paren_start {
        Some(i) => s[..i].trim().to_string(),
        None => s.trim().to_string(),
    }
}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonicalize `raw` and return only the resulting `fqn`.
    fn fqn_of(raw: &str) -> String {
        CudaCanonicalizer.canonicalize(raw).fqn
    }

    // A bare identifier passes through untouched and is tagged as CUDA.
    #[test]
    fn simple_kernel_preserved() {
        let sym = CudaCanonicalizer.canonicalize("vector_add");
        assert_eq!(sym.fqn, "vector_add");
        assert_eq!(sym.lang, "cuda");
    }

    // The leading `void` is removed.
    #[test]
    fn return_type_stripped() {
        assert_eq!(fqn_of("void sgemm<float>(float const*, int)"), "sgemm<float>");
    }

    // The parameter list is removed even without a return type.
    #[test]
    fn param_list_stripped() {
        assert_eq!(fqn_of("sgemm<float>(float const*, int)"), "sgemm<float>");
    }

    // Different instantiations must stay distinct after canonicalization.
    #[test]
    fn template_params_preserved_and_distinguishing() {
        let float_fqn = fqn_of("void sgemm<float>(float const*, int)");
        let half_fqn = fqn_of("void sgemm<half>(half const*, int)");
        assert_ne!(float_fqn, half_fqn);
        assert_eq!(float_fqn, "sgemm<float>");
        assert_eq!(half_fqn, "sgemm<half>");
    }

    // Spaces inside the template argument list are not return-type breaks.
    #[test]
    fn multi_template_params_preserved() {
        assert_eq!(
            fqn_of("void gemm<float, 128, 128, 16>(float const*, int)"),
            "gemm<float, 128, 128, 16>"
        );
    }

    // Namespace qualification survives.
    #[test]
    fn qualified_name_preserved() {
        assert_eq!(fqn_of("void ns::kernel<int>(int*)"), "ns::kernel<int>");
    }

    // Nothing to strip: input equals output.
    #[test]
    fn no_parens_no_return_left_alone() {
        assert_eq!(fqn_of("sgemm<float>"), "sgemm<float>");
    }

    // The lookup key pairs the language tag with the canonical name.
    #[test]
    fn key_is_lang_plus_fqn() {
        let sym = CudaCanonicalizer.canonicalize("sgemm<float>");
        assert_eq!(sym.key(), ("cuda", "sgemm<float>"));
    }

    // The original text is stored unmodified alongside the canonical form.
    #[test]
    fn raw_is_preserved_verbatim() {
        let input = "void sgemm<float>(float const*, int)";
        let sym = CudaCanonicalizer.canonicalize(input);
        assert_eq!(sym.raw, input);
    }
}