use derive_more::Display;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
cell::RefCell,
collections::{HashMap, HashSet},
panic::Location,
path::{Path, PathBuf},
sync::OnceLock,
};
/// A source position recorded as provenance for a UOp.
///
/// `Display` renders it as `{file}:{line}:{column}`. The file path is a `Cow`
/// so locations built from a `std::panic::Location` can borrow its `'static`
/// path, while other constructors may store an owned `String`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Display)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[display("{file}:{line}:{column}")]
pub struct SourceLocation<'i> {
    /// Source file path; `SourceLocation::from_caller` makes it workspace-relative when possible.
    pub file: Cow<'i, str>,
    /// Line number, as reported by whoever built the location (e.g. `Location::line`).
    pub line: u32,
    /// Column number, as reported by whoever built the location (e.g. `Location::column`).
    pub column: u32,
}
impl SourceLocation<'static> {
    /// Builds a location that owns its file path (test-only convenience constructor).
    #[cfg(test)]
    pub(crate) fn new<F: Into<String>>(file: F, line: u32, column: u32) -> Self {
        let owned_path: String = file.into();
        Self { line, column, file: Cow::from(owned_path) }
    }

    /// Converts a `std::panic::Location` into a `SourceLocation`.
    ///
    /// The file path is borrowed (no allocation) and stripped of the workspace
    /// root prefix when it lies under it; line and column are copied verbatim.
    pub fn from_caller(loc: &'static Location<'static>) -> Self {
        let relative_file = get_relative_location(loc);
        let (line, column) = (loc.line(), loc.column());
        Self { file: Cow::Borrowed(relative_file), line, column }
    }
}
/// Identifies the ONNX graph node a UOp originates from.
///
/// `Display` renders `ONNX:{domain}/{op_type} v{version}`, followed by
/// ` ({name})` when the node carries a name.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OnnxNodeInfo {
    /// Node name, if the node has one.
    pub name: Option<String>,
    /// Operator type name.
    pub op_type: String,
    /// Operator domain.
    pub domain: String,
    /// Operator version — presumably the opset version; confirm with the importer.
    pub version: i64,
}

impl std::fmt::Display for OnnxNodeInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { name, op_type, domain, version } = self;
        write!(f, "ONNX:{domain}/{op_type} v{version}")?;
        // The node name is optional; only append the parenthesised suffix when present.
        match name {
            Some(label) => write!(f, " ({label})"),
            None => Ok(()),
        }
    }
}
/// Names the compiler pass recorded in a `ProvenanceEvent::Transformed` event.
///
/// `Display` prints the snake_case name given in each variant's `#[display]`
/// attribute (e.g. `ShiftTo` displays as `shift_to`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Display)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum PassName {
    #[display("substitute")]
    Substitute,
    #[display("shift_to")]
    ShiftTo,
    #[display("convert_loop_to_global")]
    ConvertLoopToGlobal,
    #[display("convert_outer_to_loop")]
    ConvertOuterToLoop,
    #[display("rewrite_pattern")]
    RewritePattern,
}
/// A single entry in a UOp's provenance history.
///
/// The `Display` text of each variant is given by its `#[display]` attribute.
#[derive(Debug, Clone, PartialEq, Display)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ProvenanceEvent {
    /// The UOp was created at `location` (recorded by `ProvenanceTracker::capture`).
    #[display("Created at {location}")]
    Created { location: SourceLocation<'static> },
    /// The UOp comes from the ONNX node `node` (recorded by `ProvenanceTracker::attach_onnx_node`).
    #[display("From {node}")]
    FromOnnx { node: OnnxNodeInfo },
    /// The UOp was derived from another UOp by a compiler pass
    /// (recorded by `ProvenanceTracker::record_transform`).
    #[display("Transformed from UOp {from_id} by {pass_name}")]
    Transformed {
        /// Id of the UOp this one was derived from; `ProvenanceTracker::get_chain`
        /// follows it to include that UOp's history.
        from_id: u64,
        /// The pass that performed the derivation.
        pass_name: PassName,
    },
}
/// An ordered list of provenance events, as produced by `ProvenanceTracker::get_chain`
/// (ancestors' events precede the events of the UOp that references them).
pub type ProvenanceChain = Vec<ProvenanceEvent>;
/// Records provenance events keyed by UOp id.
///
/// Events for an id are appended in call order; nothing is deduplicated.
#[derive(Default)]
pub struct ProvenanceTracker {
    /// UOp id -> events recorded directly on that id, in insertion order.
    events: HashMap<u64, Vec<ProvenanceEvent>>,
}
impl ProvenanceTracker {
    /// Records that UOp `uop_id` was created at `location`.
    ///
    /// The location's file path is stored workspace-relative when possible.
    pub fn capture(&mut self, uop_id: u64, location: &'static Location<'static>) {
        let event = ProvenanceEvent::Created { location: SourceLocation::from_caller(location) };
        self.push_event(uop_id, event);
    }

    /// Records that UOp `new_id` was derived from UOp `old_id` by `pass_name`.
    pub fn record_transform(&mut self, new_id: u64, old_id: u64, pass_name: PassName) {
        self.push_event(new_id, ProvenanceEvent::Transformed { from_id: old_id, pass_name });
    }

    /// Records that UOp `uop_id` originates from the ONNX node `node`.
    pub fn attach_onnx_node(&mut self, uop_id: u64, node: OnnxNodeInfo) {
        self.push_event(uop_id, ProvenanceEvent::FromOnnx { node });
    }

    /// Returns the events recorded directly on `uop_id`, in insertion order,
    /// or `None` if nothing was recorded for it.
    pub fn get_events(&self, uop_id: u64) -> Option<&[ProvenanceEvent]> {
        self.events.get(&uop_id).map(Vec::as_slice)
    }

    /// Returns the full provenance of `uop_id`, including its ancestors.
    ///
    /// Ancestors are found by transitively following the `from_id` of
    /// `ProvenanceEvent::Transformed` events. Each UOp's events appear at most
    /// once, and an ancestor's events are emitted before the events of the UOp
    /// that references it (depth-first, in event order); cycles are broken at
    /// the first repeated id. The result is empty when no events are recorded
    /// for `uop_id`.
    pub fn get_chain(&self, uop_id: u64) -> ProvenanceChain {
        let mut chain = Vec::new();
        self.collect_chain_recursive(uop_id, &mut chain, &mut HashSet::new());
        chain
    }

    /// Appends to `chain` the events of `uop_id` and of its not-yet-`visited`
    /// ancestors, ancestors first, marking every reached id in `visited`.
    ///
    /// The ancestry is walked with an explicit stack instead of call-stack
    /// recursion, so arbitrarily long transformation lineages cannot overflow
    /// the thread's stack. The emitted order matches a recursive post-order walk.
    fn collect_chain_recursive(&self, uop_id: u64, chain: &mut ProvenanceChain, visited: &mut HashSet<u64>) {
        // `insert` returns false when the id was already present.
        if !visited.insert(uop_id) {
            return;
        }
        let Some(root_events) = self.events.get(&uop_id) else {
            return;
        };

        // Each frame is one UOp's event list plus the index of the next event to inspect.
        let mut stack: Vec<(&[ProvenanceEvent], usize)> = vec![(root_events.as_slice(), 0)];
        while let Some((events, next)) = stack.pop() {
            let Some(event) = events.get(next) else {
                // All ancestors referenced by this UOp are done; emit its own events.
                chain.extend(events.iter().cloned());
                continue;
            };
            // Resume this UOp after the current event once any ancestor pushed above it finishes.
            stack.push((events, next + 1));
            if let ProvenanceEvent::Transformed { from_id, .. } = event {
                if visited.insert(*from_id) {
                    if let Some(parent_events) = self.events.get(from_id) {
                        stack.push((parent_events.as_slice(), 0));
                    }
                }
            }
        }
    }

    /// Drops the recorded events of every UOp id not contained in `live_uops`.
    ///
    /// This also drops the history of dead ancestors: a later `get_chain` on a
    /// live UOp only reaches ancestors through ids that are themselves in
    /// `live_uops`.
    pub fn cleanup_with_live_set(&mut self, live_uops: &HashSet<u64>) {
        self.events.retain(|id, _| live_uops.contains(id));
    }

    /// Removes all recorded events.
    pub fn clear(&mut self) {
        self.events.clear();
    }

    /// Number of UOp ids that currently have recorded events (not the total event count).
    pub fn len(&self) -> usize {
        self.events.len()
    }

    /// Returns `true` when no UOp has any recorded event.
    pub fn is_empty(&self) -> bool {
        self.events.is_empty()
    }

    /// Appends `event` to the history of `uop_id`, creating the entry if needed.
    fn push_event(&mut self, uop_id: u64, event: ProvenanceEvent) {
        self.events.entry(uop_id).or_default().push(event);
    }
}
thread_local! {
pub static PROVENANCE_TRACKER: RefCell<ProvenanceTracker> = RefCell::default();
}
/// Returns the parent directory of this crate's `CARGO_MANIFEST_DIR`
/// (captured at compile time), or the manifest directory itself when it has
/// no parent. The value is computed once and cached for the process lifetime.
///
/// NOTE(review): this is the Cargo workspace root only if the crate sits
/// exactly one level below it — confirm against the repository layout.
fn workspace_root() -> &'static Path {
    static ROOT: OnceLock<PathBuf> = OnceLock::new();
    let cached = ROOT.get_or_init(|| {
        let manifest = Path::new(env!("CARGO_MANIFEST_DIR"));
        match manifest.parent() {
            Some(parent) => parent.to_path_buf(),
            None => manifest.to_path_buf(),
        }
    });
    cached.as_path()
}
/// Returns `loc`'s file path relative to the workspace root, or the path
/// unchanged when it does not lie under that root.
///
/// The root is only stripped on a path-component boundary: a sibling
/// directory that merely shares a textual prefix with the root (e.g.
/// `/ws/proj-fork/src/a.rs` when the root is `/ws/proj`) is returned as-is
/// instead of being mangled into `-fork/src/a.rs`. One leading `/` or `\`
/// separator is removed from the relative remainder.
pub(crate) fn get_relative_location(loc: &'static Location<'static>) -> &'static str {
    let file = loc.file();
    // `workspace_root` is built from the `CARGO_MANIFEST_DIR` string, so this cannot fail.
    let root = workspace_root().to_str().expect("workspace root must be valid UTF-8");
    let is_separator = |c: char| c == '/' || c == '\\';

    let Some(rest) = file.strip_prefix(root) else {
        return file;
    };
    if let Some(relative) = rest.strip_prefix(is_separator) {
        relative
    } else if rest.is_empty() || root.ends_with(is_separator) {
        // Exact match, or the root itself already ended at a separator.
        rest
    } else {
        // Only a textual prefix matched; `file` is not inside the root.
        file
    }
}
pub fn format_chain(chain: &ProvenanceChain) -> String {
let mut output = String::new();
for (i, event) in chain.iter().enumerate() {
output.push_str(&format!("\n [{}] {}", i, event));
}
output
}