use super::BrickProfiler;
use crate::brick::profiler::checksum::{fnv1a_f32_checksum, DivergenceInfo, KernelChecksum};
impl BrickProfiler {
    /// Records a checksum of one kernel's output buffer for later divergence analysis.
    ///
    /// Does nothing when the profiler is disabled. The checksum is computed with
    /// [`fnv1a_f32_checksum`] and stored together with the kernel `name`, `layer_idx`
    /// and `position`, appended after all previously recorded entries.
    pub fn record_checksum(&mut self, name: &str, layer_idx: usize, position: u32, output: &[f32]) {
        if !self.enabled {
            return;
        }
        let checksum = fnv1a_f32_checksum(output);
        let trace = KernelChecksum { name: name.to_string(), layer_idx, position, checksum };
        self.kernel_checksums.push(trace);
    }

    /// Returns all recorded kernel checksums, in the order they were recorded.
    #[must_use]
    pub fn get_checksums(&self) -> &[KernelChecksum] {
        &self.kernel_checksums
    }

    /// Finds the first recorded kernel whose checksum differs from `reference`.
    ///
    /// Entries are matched on `(name, layer_idx, position)`. When the same key was
    /// recorded several times (e.g. a kernel name reused within one layer/position),
    /// the n-th occurrence in `self` is compared with the n-th occurrence in
    /// `reference`, so repeated keys are not all checked against a single value.
    /// Entries in `self` with no counterpart in `reference` are skipped.
    ///
    /// Returns `None` if every matched entry agrees, otherwise the first mismatch in
    /// `self`'s recording order.
    #[must_use]
    pub fn find_divergence(&self, reference: &BrickProfiler) -> Option<DivergenceInfo> {
        use std::collections::HashMap;

        // Reference checksums per key, kept in recording order so repeated keys
        // can be paired occurrence-by-occurrence.
        let mut ref_index: HashMap<(&str, usize, u32), Vec<u64>> = HashMap::new();
        for t in &reference.kernel_checksums {
            ref_index
                .entry((t.name.as_str(), t.layer_idx, t.position))
                .or_default()
                .push(t.checksum);
        }

        // How many times each key has been seen so far in `self`.
        let mut seen: HashMap<(&str, usize, u32), usize> = HashMap::new();
        for trace in &self.kernel_checksums {
            let key = (trace.name.as_str(), trace.layer_idx, trace.position);
            let counter = seen.entry(key).or_insert(0);
            let occurrence = *counter;
            *counter += 1;

            let expected = ref_index.get(&key).and_then(|sums| sums.get(occurrence));
            if let Some(&expected) = expected {
                if trace.checksum != expected {
                    return Some(DivergenceInfo {
                        kernel_name: trace.name.clone(),
                        layer_idx: trace.layer_idx,
                        position: trace.position,
                        expected_checksum: expected,
                        actual_checksum: trace.checksum,
                    });
                }
            }
        }
        None
    }

    /// Discards all recorded kernel checksums. Other profiler state is left untouched.
    pub fn reset_checksums(&mut self) {
        self.kernel_checksums.clear();
    }
}