// kernel_abi_check/torch_stable_abi/mod.rs

use std::collections::{BTreeSet, HashMap};
use std::str::FromStr;

use cpp_demangle::Symbol as CppSymbol;
use eyre::Result;
use object::{BinaryFormat, ObjectSymbol, Symbol};
use once_cell::sync::Lazy;

use crate::version::Version;

11// https://raw.githubusercontent.com/pytorch/pytorch/refs/heads/main/torch/csrc/stable/c/shim_function_versions.txt
12static SHIM_FUNCTION_VERSIONS_RAW: &str = include_str!("shim_function_versions.txt");
13
14/// Maps shim function names to the minimum Torch version that introduced them.
15/// Functions absent from this map were available before 2.10.0.
16pub static TORCH_SHIM_VERSIONS: Lazy<HashMap<String, Version>> = Lazy::new(|| {
17    let mut map = HashMap::new();
18    for line in SHIM_FUNCTION_VERSIONS_RAW.lines() {
19        // Skip blank lines and comments.
20        let line = line.trim();
21        if line.is_empty() || line.starts_with('#') {
22            continue;
23        }
24        if let Some((name, version_token)) = line.split_once(':') {
25            let name = name.trim().to_owned();
26            // TORCH_VERSION_2_10_0 -> "2.10.0"
27            let version_str = version_token
28                .trim()
29                .strip_prefix("TORCH_VERSION_")
30                .expect("unexpected version token format")
31                .replace('_', ".");
32            let version = Version::from_str(&version_str)
33                .expect("invalid version in shim_function_versions.txt");
34            map.insert(name, version);
35        }
36    }
37    map
38});
39
40/// Torch stable ABI violation.
41#[derive(Debug, Clone, Eq, Ord, PartialEq, PartialOrd)]
42pub enum TorchStableAbiViolation {
43    /// Symbol is newer than the specified Torch Stable ABI version.
44    IncompatibleStableAbiSymbol { name: String, added: Version },
45
46    /// Symbol is not part of ABI3.
47    NonStableAbiSymbol { name: String },
48}
49
50/// Check for violations of the Python ABI policy.
51pub fn check_torch_stable_abi<'a>(
52    torch_stable_abi: &Version,
53    binary_format: BinaryFormat,
54    symbols: impl IntoIterator<Item = Symbol<'a, 'a>>,
55) -> Result<BTreeSet<TorchStableAbiViolation>> {
56    let mut violations = BTreeSet::new();
57
58    for symbol in symbols {
59        if !symbol.is_undefined() {
60            continue;
61        }
62
63        let mut symbol_name = symbol.name()?;
64        if matches!(binary_format, BinaryFormat::MachO) {
65            // Mach-O C symbol mangling adds an underscore.
66            symbol_name = symbol_name.strip_prefix("_").unwrap_or(symbol_name);
67        }
68
69        // If this is a C shim symbol, check if it is valid for this version.
70        if let Some(symbol_version) = TORCH_SHIM_VERSIONS.get(symbol_name) {
71            if symbol_version > torch_stable_abi {
72                violations.insert(TorchStableAbiViolation::IncompatibleStableAbiSymbol {
73                    name: symbol_name.to_owned(),
74                    added: symbol_version.clone(),
75                });
76            }
77            continue;
78        }
79
80        // Try to demangle the symbol as a C++ symbol. If that fails, it's probably an
81        // unrelated C symbol.
82        let cpp_symbol = match CppSymbol::new(symbol_name) {
83            Ok(cpp_symbol) => cpp_symbol,
84            Err(_) => {
85                continue;
86            }
87        };
88        let demangled = cpp_symbol.demangle()?;
89
90        // Check if Torch symbols are from the stable ABI.
91        if demangled.starts_with("torch::stable::") {
92            // This branch fulfills to purposes: (1) avoid that stable ABI
93            // C++ symbols get reported by the filter below. (2) Once a
94            // versioned list of symbols is available, check versions.
95        } else if demangled.starts_with("c10::")
96            || demangled.starts_with("at::")
97            || demangled.starts_with("torch::")
98        {
99            violations.insert(TorchStableAbiViolation::NonStableAbiSymbol { name: demangled });
100        }
101    }
102
103    Ok(violations)
104}