use std::collections::HashMap;
use std::time::Instant;
/// Aggregated timing statistics for one profiled operation.
///
/// All durations are in microseconds; `per_layer` retains every raw
/// sample in the order it was recorded.
#[derive(Debug, Clone, Default)]
pub struct OpStats {
pub min_us: f64,
pub max_us: f64,
pub avg_us: f64,
pub total_us: f64,
pub count: usize,
pub per_layer: Vec<f64>,
}
impl OpStats {
/// Seeds the statistics from the first recorded sample.
fn new(sample_us: f64) -> Self {
Self {
per_layer: vec![sample_us],
count: 1,
total_us: sample_us,
avg_us: sample_us,
max_us: sample_us,
min_us: sample_us,
}
}
/// Folds one more sample into the running aggregates.
fn add(&mut self, sample_us: f64) {
self.per_layer.push(sample_us);
self.count += 1;
self.total_us += sample_us;
self.avg_us = self.total_us / self.count as f64;
// f64::min / f64::max return the non-NaN operand when one side is NaN.
self.min_us = self.min_us.min(sample_us);
self.max_us = self.max_us.max(sample_us);
}
}
/// Snapshot of a completed profiling session.
#[derive(Debug, Clone)]
pub struct ProfileReport {
pub operations: HashMap<String, OpStats>,
pub total_inference_us: f64,
pub tokens_processed: usize,
pub num_layers: usize,
pub throughput_tok_s: f64,
pub is_real_data: bool,
}
impl ProfileReport {
/// Returns the operation with the largest `total_us`, if any exist.
pub fn hottest(&self) -> Option<(&str, &OpStats)> {
let winner = self.operations.iter().max_by(|(_, lhs), (_, rhs)| {
// NaN totals compare as equal rather than poisoning the search.
lhs.total_us
.partial_cmp(&rhs.total_us)
.unwrap_or(std::cmp::Ordering::Equal)
});
winner.map(|(name, stats)| (name.as_str(), stats))
}
/// Returns all operations, slowest first (by `total_us`).
pub fn sorted_by_time(&self) -> Vec<(&str, &OpStats)> {
let mut entries: Vec<(&str, &OpStats)> = self
.operations
.iter()
.map(|(name, stats)| (name.as_str(), stats))
.collect();
// Stable sort, descending; NaN totals compare as equal.
entries.sort_by(|(_, lhs), (_, rhs)| {
rhs.total_us
.partial_cmp(&lhs.total_us)
.unwrap_or(std::cmp::Ordering::Equal)
});
entries
}
/// Returns each operation's share of the total inference time, in percent.
/// Empty when the total is zero or negative (nothing to normalize by).
pub fn percentage_breakdown(&self) -> HashMap<String, f64> {
if self.total_inference_us <= 0.0 {
HashMap::new()
} else {
let total = self.total_inference_us;
self.operations
.iter()
.map(|(name, stats)| (name.clone(), stats.total_us / total * 100.0))
.collect()
}
}
}
include!("profiler_contracts.rs");
/// Severity of a profiling-contract violation returned by
/// `validate_contracts` (defined in the included `profiler_contracts.rs`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContractSeverity {
/// Hard violation; logged as `[CONTRACT ERROR]` by `BrickProfiler::report`.
Error,
/// Soft violation; logged as `[CONTRACT WARN]` by `BrickProfiler::report`.
Warning,
}
/// An in-flight measurement: exists between a `start` and its matching `stop`.
#[derive(Debug)]
struct ActiveTimer {
/// Instant captured when the operation's timer was started.
start: Instant,
}
/// Wall-clock profiler that aggregates named operation timings and can
/// emit a [`ProfileReport`]. When `enabled` is false, every instrumentation
/// method is a no-op.
#[derive(Debug)]
pub struct BrickProfiler {
/// Completed measurements, aggregated per operation name.
stats: HashMap<String, OpStats>,
/// Timers started but not yet stopped, keyed by operation name.
active: HashMap<String, ActiveTimer>,
/// Stamp set by `start_inference`, if it was called.
inference_start: Option<Instant>,
/// Stamp set by `stop_inference`, if it was called.
inference_end: Option<Instant>,
/// Token count supplied via `set_tokens`; used to derive throughput.
tokens_count: usize,
/// Layer count supplied via `set_num_layers`; copied into reports.
num_layers: usize,
/// Layer index supplied via `set_current_layer`.
current_layer: usize,
/// Master switch: false makes all profiling calls no-ops.
enabled: bool,
}
impl Default for BrickProfiler {
fn default() -> Self {
Self::new()
}
}
impl BrickProfiler {
pub fn new() -> Self {
Self {
stats: HashMap::new(),
active: HashMap::new(),
inference_start: None,
inference_end: None,
tokens_count: 0,
num_layers: 0,
current_layer: 0,
enabled: true,
}
}
pub fn disabled() -> Self {
let mut p = Self::new();
p.enabled = false;
p
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
pub fn set_tokens(&mut self, count: usize) {
self.tokens_count = count;
}
pub fn set_num_layers(&mut self, num_layers: usize) {
self.num_layers = num_layers;
}
pub fn set_current_layer(&mut self, layer: usize) {
self.current_layer = layer;
}
pub fn start(&mut self, operation: &str) {
if !self.enabled {
return;
}
self.active.insert(
operation.to_string(),
ActiveTimer {
start: Instant::now(),
},
);
}
pub fn start_inference(&mut self) {
if !self.enabled {
return;
}
self.inference_start = Some(Instant::now());
}
pub fn stop_inference(&mut self) {
if !self.enabled {
return;
}
self.inference_end = Some(Instant::now());
}
pub fn stop(&mut self, operation: &str) {
if !self.enabled {
return;
}
if let Some(timer) = self.active.remove(operation) {
let elapsed_us = timer.start.elapsed().as_secs_f64() * 1_000_000.0;
if let Some(stats) = self.stats.get_mut(operation) {
stats.add(elapsed_us);
} else {
self.stats
.insert(operation.to_string(), OpStats::new(elapsed_us));
}
}
}
pub fn record(&mut self, operation: &str, time_us: f64) {
if !self.enabled {
return;
}
if let Some(stats) = self.stats.get_mut(operation) {
stats.add(time_us);
} else {
self.stats
.insert(operation.to_string(), OpStats::new(time_us));
}
}
pub fn measure<F, T>(&mut self, operation: &str, f: F) -> T
where
F: FnOnce() -> T,
{
if !self.enabled {
return f();
}
self.start(operation);
let result = f();
self.stop(operation);
result
}
pub fn clear(&mut self) {
self.stats.clear();
self.active.clear();
self.inference_start = None;
self.inference_end = None;
self.tokens_count = 0;
self.current_layer = 0;
}
pub fn report(&self) -> ProfileReport {
let total_inference_us = match (self.inference_start, self.inference_end) {
(Some(start), Some(end)) => end.duration_since(start).as_secs_f64() * 1_000_000.0,
_ => self.stats.values().map(|s| s.total_us).sum(),
};
let throughput_tok_s = if total_inference_us > 0.0 && self.tokens_count > 0 {
(self.tokens_count as f64 / total_inference_us) * 1_000_000.0
} else {
0.0
};
let report = ProfileReport {
operations: self.stats.clone(),
total_inference_us,
tokens_processed: self.tokens_count,
num_layers: self.num_layers,
throughput_tok_s,
is_real_data: self.enabled && !self.stats.is_empty(),
};
let violations = report.validate_contracts();
for (severity, msg) in &violations {
match severity {
ContractSeverity::Error => eprintln!("[CONTRACT ERROR] {}", msg),
ContractSeverity::Warning => eprintln!("[CONTRACT WARN] {}", msg),
}
}
report
}
pub fn stats(&self) -> &HashMap<String, OpStats> {
&self.stats
}
}
// One profiler instance per thread, shared through the `profile_*` macros below.
thread_local! {
pub static PROFILER: std::cell::RefCell<BrickProfiler> = std::cell::RefCell::new(BrickProfiler::new());
}
/// Starts a named timer on the thread-local [`PROFILER`].
#[macro_export]
macro_rules! profile_start {
($op:expr) => {
$crate::brick::profiler::PROFILER.with(|p| {
p.borrow_mut().start($op);
});
};
}
/// Stops a named timer on the thread-local [`PROFILER`], recording its elapsed time.
#[macro_export]
macro_rules! profile_stop {
($op:expr) => {
$crate::brick::profiler::PROFILER.with(|p| {
p.borrow_mut().stop($op);
});
};
}
/// Builds and returns a [`ProfileReport`] from the thread-local [`PROFILER`].
#[macro_export]
macro_rules! profile_report {
() => {
$crate::brick::profiler::PROFILER.with(|p| p.borrow().report())
};
}
/// Clears all recorded data on the thread-local [`PROFILER`].
#[macro_export]
macro_rules! profile_clear {
() => {
$crate::brick::profiler::PROFILER.with(|p| p.borrow_mut().clear());
};
}
include!("profiler_basic_profiling.rs");