use crate::{Result, TensorError};
use scirs2_core::metrics::{Histogram, Timer};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
/// An element-wise binary operation over values of type `T`.
///
/// Implementors provide the scalar [`BinaryOp::apply`]; the slice form and all
/// capability queries have conservative defaults.
pub trait BinaryOp<T: Clone> {
    /// Applies the operation to a single pair of operands.
    fn apply(&self, a: T, b: T) -> T;

    /// Human-readable name of the operation (used as a metrics key).
    fn name(&self) -> &str;

    /// Applies the operation element-wise: `output[i] = apply(a[i], b[i])`.
    ///
    /// # Errors
    /// Returns `TensorError::invalid_argument` when `a`, `b`, and `output`
    /// do not all have the same length.
    fn apply_slice(&self, a: &[T], b: &[T], output: &mut [T]) -> Result<()> {
        if a.len() != b.len() || a.len() != output.len() {
            return Err(TensorError::invalid_argument(
                "Slice length mismatch for binary operation".to_string(),
            ));
        }
        // Zipped iteration avoids the per-element bounds checks of indexed
        // access and lets the compiler vectorize the loop.
        for ((out, x), y) in output.iter_mut().zip(a).zip(b) {
            *out = self.apply(x.clone(), y.clone());
        }
        Ok(())
    }

    /// Whether a SIMD-accelerated implementation is available. Defaults to `false`.
    fn supports_simd(&self) -> bool {
        false
    }

    /// Whether a GPU implementation is available. Defaults to `false`.
    fn supports_gpu(&self) -> bool {
        false
    }

    /// Cost classification for this operation. Defaults to [`OpComplexity::Simple`].
    fn complexity(&self) -> OpComplexity {
        OpComplexity::Simple
    }

    /// True if `apply(apply(a, b), c) == apply(a, apply(b, c))`. Defaults to `false`.
    fn is_associative(&self) -> bool {
        false
    }

    /// True if `apply(a, b) == apply(b, a)`. Defaults to `false`.
    fn is_commutative(&self) -> bool {
        false
    }
}
/// Coarse cost classification of a binary operation.
///
/// NOTE(review): the tiers appear to be ordered from cheapest (`Simple`) to
/// most expensive (`Advanced`) based on naming — confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpComplexity {
    Simple,
    Moderate,
    Complex,
    Advanced,
}
/// Collects execution statistics for binary tensor operations.
///
/// The per-operation name map is guarded by a mutex; the remaining counters
/// are lock-free atomics.
pub struct BinaryOpRegistry {
    // Maps operation name -> invocation count. Values are `AtomicU64` so they
    // can be incremented while only the map structure is behind the mutex.
    op_counters: Arc<Mutex<std::collections::HashMap<String, AtomicU64>>>,
    // Number of executions that used SIMD acceleration.
    simd_usage: AtomicU64,
    // Number of executions that ran on the GPU.
    gpu_usage: AtomicU64,
    // Number of executions that used parallel (multi-threaded) dispatch.
    parallel_usage: AtomicU64,
    // Constructed in `new` but never read in this file — kept for future
    // timing instrumentation, hence the allow.
    #[allow(dead_code)]
    execution_timer: Timer,
    // Distribution of observed memory throughput (GB/s per `record_operation`).
    memory_throughput: Histogram,
}
impl BinaryOpRegistry {
    /// Creates a registry with zeroed counters and fresh metric instruments.
    pub fn new() -> Self {
        Self {
            op_counters: Arc::new(Mutex::new(std::collections::HashMap::new())),
            simd_usage: AtomicU64::new(0),
            gpu_usage: AtomicU64::new(0),
            parallel_usage: AtomicU64::new(0),
            execution_timer: Timer::new("binary_ops.execution_time".to_string()),
            memory_throughput: Histogram::new("binary_ops.memory_throughput".to_string()),
        }
    }

    /// Records one execution of `op_name` over `elements` elements that took
    /// `duration_ns` nanoseconds: bumps the per-op counter and observes the
    /// achieved memory throughput.
    pub fn record_operation(&self, op_name: &str, elements: usize, duration_ns: u64) {
        {
            // Scope the lock so it is released before touching the histogram.
            let mut counters = self
                .op_counters
                .lock()
                .expect("lock should not be poisoned");
            counters
                .entry(op_name.to_string())
                .or_insert_with(|| AtomicU64::new(0))
                .fetch_add(1, Ordering::Relaxed);
        }
        // Assumes f32 elements with two input operands read per element
        // (output writes are not counted) — TODO confirm against callers.
        let bytes_processed = elements * std::mem::size_of::<f32>() * 2;
        // bytes / ns is numerically equal to (decimal) GB/s, so no scaling is
        // needed; the previous expression multiplied numerator and denominator
        // by 1e9, which cancelled. Skip a zero duration — it would push +inf
        // into the histogram and corrupt any derived average.
        if duration_ns > 0 {
            let throughput_gbps = bytes_processed as f64 / duration_ns as f64;
            self.memory_throughput.observe(throughput_gbps);
        }
    }

    /// Counts one SIMD-accelerated execution.
    pub fn record_simd_usage(&self) {
        self.simd_usage.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one GPU-accelerated execution.
    pub fn record_gpu_usage(&self) {
        self.gpu_usage.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one parallel (multi-threaded) execution.
    pub fn record_parallel_usage(&self) {
        self.parallel_usage.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns a point-in-time snapshot of all counters.
    ///
    /// Counter reads are `Relaxed`, so the snapshot is not a single atomic
    /// cut across all counters — acceptable for analytics.
    pub fn get_analytics(&self) -> BinaryOpAnalytics {
        let counters = self
            .op_counters
            .lock()
            .expect("lock should not be poisoned");
        let op_counts: std::collections::HashMap<String, u64> = counters
            .iter()
            .map(|(k, v)| (k.clone(), v.load(Ordering::Relaxed)))
            .collect();
        BinaryOpAnalytics {
            operation_counts: op_counts,
            simd_accelerations: self.simd_usage.load(Ordering::Relaxed),
            gpu_accelerations: self.gpu_usage.load(Ordering::Relaxed),
            parallel_executions: self.parallel_usage.load(Ordering::Relaxed),
            avg_memory_throughput: self.calculate_avg_throughput(),
        }
    }

    /// Average memory throughput in GB/s.
    ///
    /// NOTE(review): hard-coded placeholder — should be derived from
    /// `memory_throughput` once the histogram exposes a mean/summary.
    fn calculate_avg_throughput(&self) -> f64 {
        15.0
    }
}
impl Default for BinaryOpRegistry {
fn default() -> Self {
Self::new()
}
}
/// Immutable snapshot of the statistics gathered by a registry.
#[derive(Debug, Clone)]
pub struct BinaryOpAnalytics {
    /// Invocation count per operation name.
    pub operation_counts: std::collections::HashMap<String, u64>,
    /// Number of SIMD-accelerated executions recorded.
    pub simd_accelerations: u64,
    /// Number of GPU-accelerated executions recorded.
    pub gpu_accelerations: u64,
    /// Number of parallel (multi-threaded) executions recorded.
    pub parallel_executions: u64,
    /// Average memory throughput, in GB/s.
    pub avg_memory_throughput: f64,
}
// Lazily-initialized process-wide registry; `OnceLock` guarantees the
// initializer runs at most once even under concurrent first access.
static BINARY_OP_REGISTRY: std::sync::OnceLock<BinaryOpRegistry> = std::sync::OnceLock::new();
/// Returns the global [`BinaryOpRegistry`], creating it on first call.
pub fn get_binary_op_registry() -> &'static BinaryOpRegistry {
    BINARY_OP_REGISTRY.get_or_init(BinaryOpRegistry::new)
}