kernel_abi_check/torch_stable_abi/
mod.rs1use std::collections::{BTreeSet, HashMap};
2use std::str::FromStr;
3
4use cpp_demangle::Symbol as CppSymbol;
5use eyre::Result;
6use object::{BinaryFormat, ObjectSymbol, Symbol};
7use once_cell::sync::Lazy;
8
9use crate::version::Version;
10
11static SHIM_FUNCTION_VERSIONS_RAW: &str = include_str!("shim_function_versions.txt");
13
14pub static TORCH_SHIM_VERSIONS: Lazy<HashMap<String, Version>> = Lazy::new(|| {
17 let mut map = HashMap::new();
18 for line in SHIM_FUNCTION_VERSIONS_RAW.lines() {
19 let line = line.trim();
21 if line.is_empty() || line.starts_with('#') {
22 continue;
23 }
24 if let Some((name, version_token)) = line.split_once(':') {
25 let name = name.trim().to_owned();
26 let version_str = version_token
28 .trim()
29 .strip_prefix("TORCH_VERSION_")
30 .expect("unexpected version token format")
31 .replace('_', ".");
32 let version = Version::from_str(&version_str)
33 .expect("invalid version in shim_function_versions.txt");
34 map.insert(name, version);
35 }
36 }
37 map
38});
39
40#[derive(Debug, Clone, Eq, Ord, PartialEq, PartialOrd)]
42pub enum TorchStableAbiViolation {
43 IncompatibleStableAbiSymbol { name: String, added: Version },
45
46 NonStableAbiSymbol { name: String },
48}
49
50pub fn check_torch_stable_abi<'a>(
52 torch_stable_abi: &Version,
53 binary_format: BinaryFormat,
54 symbols: impl IntoIterator<Item = Symbol<'a, 'a>>,
55) -> Result<BTreeSet<TorchStableAbiViolation>> {
56 let mut violations = BTreeSet::new();
57
58 for symbol in symbols {
59 if !symbol.is_undefined() {
60 continue;
61 }
62
63 let mut symbol_name = symbol.name()?;
64 if matches!(binary_format, BinaryFormat::MachO) {
65 symbol_name = symbol_name.strip_prefix("_").unwrap_or(symbol_name);
67 }
68
69 if let Some(symbol_version) = TORCH_SHIM_VERSIONS.get(symbol_name) {
71 if symbol_version > torch_stable_abi {
72 violations.insert(TorchStableAbiViolation::IncompatibleStableAbiSymbol {
73 name: symbol_name.to_owned(),
74 added: symbol_version.clone(),
75 });
76 }
77 continue;
78 }
79
80 let cpp_symbol = match CppSymbol::new(symbol_name) {
83 Ok(cpp_symbol) => cpp_symbol,
84 Err(_) => {
85 continue;
86 }
87 };
88 let demangled = cpp_symbol.demangle()?;
89
90 if demangled.starts_with("torch::stable::") {
92 } else if demangled.starts_with("c10::")
96 || demangled.starts_with("at::")
97 || demangled.starts_with("torch::")
98 {
99 violations.insert(TorchStableAbiViolation::NonStableAbiSymbol { name: demangled });
100 }
101 }
102
103 Ok(violations)
104}