use crate::{Result, TensorError};
/// Reduction operations supported by the stub GPU dispatcher.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StubReductionOp {
    Sum,
    Mean,
    Max,
    Min,
}

impl StubReductionOp {
    /// Returns the lowercase identifier of this reduction,
    /// e.g. for labeling command encoders and log lines.
    pub fn name(self) -> &'static str {
        match self {
            StubReductionOp::Sum => "sum",
            StubReductionOp::Mean => "mean",
            StubReductionOp::Max => "max",
            StubReductionOp::Min => "min",
        }
    }
}
/// Outcome of one reduction dispatch.
#[derive(Debug, Clone)]
pub struct StubReductionResult {
    /// Scalar result of the reduction, accumulated in f64.
    pub value: f64,
    /// Wall-clock time spent inside the dispatch path, in nanoseconds.
    pub encoder_latency_ns: u64,
    /// Whether the dispatch went through the real GPU path
    /// (always false for `GpuStub`; true for `GpuKernelDispatcher`).
    pub used_real_gpu: bool,
}
/// CPU-backed stand-in for a GPU device: reductions execute on the host.
#[derive(Debug, Default)]
pub struct GpuStub {
    // Human-readable identifier for the simulated device.
    device_label: String,
}
impl GpuStub {
    /// Creates a stub identified by `device_label`.
    pub fn new(device_label: impl Into<String>) -> Self {
        Self {
            device_label: device_label.into(),
        }
    }

    /// Returns the label supplied at construction.
    pub fn device_label(&self) -> &str {
        self.device_label.as_str()
    }

    /// Runs `op` over `data` on the CPU, timing the whole dispatch.
    ///
    /// Sum and Mean of an empty slice yield 0.0; Max and Min of an empty
    /// slice return an invalid-argument error, since they have no
    /// identity element.
    pub fn dispatch_reduction(
        &self,
        op: StubReductionOp,
        data: &[f32],
    ) -> Result<StubReductionResult> {
        let timer = std::time::Instant::now();
        // Widen to f64 lazily so every arm accumulates at full precision.
        let widened = || data.iter().map(|&x| f64::from(x));
        let value = match op {
            StubReductionOp::Sum => widened().sum::<f64>(),
            StubReductionOp::Mean => match data.len() {
                0 => 0.0,
                n => widened().sum::<f64>() / n as f64,
            },
            StubReductionOp::Max => {
                if data.is_empty() {
                    return Err(TensorError::invalid_argument(
                        "Max reduction requires at least one element".to_string(),
                    ));
                }
                widened().fold(f64::NEG_INFINITY, f64::max)
            }
            StubReductionOp::Min => {
                if data.is_empty() {
                    return Err(TensorError::invalid_argument(
                        "Min reduction requires at least one element".to_string(),
                    ));
                }
                widened().fold(f64::INFINITY, f64::min)
            }
        };
        Ok(StubReductionResult {
            value,
            encoder_latency_ns: timer.elapsed().as_nanos() as u64,
            used_real_gpu: false,
        })
    }
}
/// Dispatcher holding real wgpu handles; only compiled with the `gpu` feature.
#[cfg(feature = "gpu")]
pub struct GpuKernelDispatcher {
    // wgpu handles owned for the lifetime of the dispatcher.
    device: wgpu::Device,
    queue: wgpu::Queue,
}
#[cfg(feature = "gpu")]
impl GpuKernelDispatcher {
    /// Wraps an existing wgpu device/queue pair.
    pub fn new(device: wgpu::Device, queue: wgpu::Queue) -> Self {
        Self { device, queue }
    }

    /// Dispatches a reduction via the real device.
    ///
    /// A command encoder is created on the GPU, but the arithmetic itself
    /// currently falls back to `GpuStub` on the CPU; the result is still
    /// reported with `used_real_gpu: true`.
    pub fn dispatch_reduction(
        &self,
        op: StubReductionOp,
        data: &[f32],
    ) -> Result<StubReductionResult> {
        // Reject empty input up front for ops without an identity element.
        let needs_elements = matches!(op, StubReductionOp::Max | StubReductionOp::Min);
        if needs_elements && data.is_empty() {
            return Err(TensorError::invalid_argument(
                "Max/Min reduction requires at least one element".to_string(),
            ));
        }
        let timer = std::time::Instant::now();
        let label = format!("reduction_{}_encoder", op.name());
        let _encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some(&label),
            });
        let fallback = GpuStub::new("gpu_fallback").dispatch_reduction(op, data)?;
        Ok(StubReductionResult {
            value: fallback.value,
            encoder_latency_ns: timer.elapsed().as_nanos() as u64,
            used_real_gpu: true,
        })
    }
}
/// Runs `f` exactly once and reports how long the call took, in nanoseconds.
#[inline]
pub fn measure_overhead_ns<F>(mut f: F) -> u64
where
    F: FnMut(),
{
    let t0 = std::time::Instant::now();
    f();
    t0.elapsed().as_nanos() as u64
}
/// Upper bound on acceptable per-dispatch bookkeeping overhead (10 µs).
pub const MAX_DISPATCH_OVERHEAD_NS: u64 = 10_000;
/// Upper bound on acceptable end-to-end stub dispatch overhead (500 µs).
pub const MAX_GPU_STUB_OVERHEAD_NS: u64 = 500_000;
/// Checks a measured overhead against a threshold.
///
/// Returns `Ok(())` when `measured_ns <= threshold_ns` (the boundary value
/// passes); otherwise an invalid-argument error naming `label` and both values.
pub fn validate_overhead(label: &str, measured_ns: u64, threshold_ns: u64) -> Result<()> {
    if measured_ns > threshold_ns {
        return Err(TensorError::invalid_argument(format!(
            "Overhead validation failed for '{}': measured {}ns exceeds threshold {}ns",
            label, measured_ns, threshold_ns
        )));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: a stub with a generic label.
    fn make_stub() -> GpuStub {
        GpuStub::new("test_device")
    }

    #[test]
    fn test_stub_sum_reduction() {
        let out = make_stub()
            .dispatch_reduction(StubReductionOp::Sum, &[1.0_f32, 2.0, 3.0, 4.0, 5.0])
            .expect("dispatch_reduction should succeed");
        assert!((out.value - 15.0).abs() < 1e-6, "sum should be 15.0");
        assert!(!out.used_real_gpu);
    }

    #[test]
    fn test_stub_mean_reduction() {
        let out = make_stub()
            .dispatch_reduction(StubReductionOp::Mean, &[1.0_f32, 2.0, 3.0, 4.0, 5.0])
            .expect("dispatch_reduction should succeed");
        assert!((out.value - 3.0).abs() < 1e-6, "mean should be 3.0");
    }

    #[test]
    fn test_stub_max_reduction() {
        let out = make_stub()
            .dispatch_reduction(StubReductionOp::Max, &[3.0_f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0])
            .expect("dispatch_reduction should succeed");
        assert!((out.value - 9.0).abs() < 1e-6, "max should be 9.0");
    }

    #[test]
    fn test_stub_min_reduction() {
        let out = make_stub()
            .dispatch_reduction(StubReductionOp::Min, &[3.0_f32, 1.0, 4.0, 1.0, 5.0])
            .expect("dispatch_reduction should succeed");
        assert!((out.value - 1.0).abs() < 1e-6, "min should be 1.0");
    }

    #[test]
    fn test_stub_empty_max_returns_error() {
        let outcome = make_stub().dispatch_reduction(StubReductionOp::Max, &[]);
        assert!(outcome.is_err(), "Max of empty slice should be an error");
    }

    #[test]
    fn test_stub_empty_min_returns_error() {
        let outcome = make_stub().dispatch_reduction(StubReductionOp::Min, &[]);
        assert!(outcome.is_err(), "Min of empty slice should be an error");
    }

    #[test]
    fn test_stub_empty_sum_returns_zero() {
        let out = make_stub()
            .dispatch_reduction(StubReductionOp::Sum, &[])
            .expect("Sum of empty slice is defined (zero)");
        assert!(out.value.abs() < 1e-12);
    }

    #[test]
    fn test_stub_empty_mean_returns_zero() {
        let out = make_stub()
            .dispatch_reduction(StubReductionOp::Mean, &[])
            .expect("Mean of empty slice returns 0");
        assert!(out.value.abs() < 1e-12);
    }

    #[test]
    fn test_stub_device_label() {
        assert_eq!(GpuStub::new("my_test_gpu").device_label(), "my_test_gpu");
    }

    #[test]
    fn test_stub_encoder_latency_is_non_zero_or_zero() {
        let samples: Vec<f32> = (0..1024).map(|i| i as f32).collect();
        let out = make_stub()
            .dispatch_reduction(StubReductionOp::Sum, &samples)
            .expect("dispatch_reduction should succeed");
        // Latency is timing-dependent; only confirm the field is populated.
        let _ = out.encoder_latency_ns;
    }

    #[test]
    fn test_measure_overhead_ns_fast_closure() {
        let elapsed = measure_overhead_ns(|| {
            let _ = 1_u64.wrapping_add(1);
        });
        let _ = elapsed;
    }

    #[test]
    fn test_measure_overhead_ns_returns_u64() {
        let elapsed: u64 = measure_overhead_ns(|| {
            std::hint::black_box(42_u64);
        });
        let _ = elapsed;
    }

    #[test]
    fn test_threshold_constants_are_positive() {
        // Compile-time check: both thresholds must be strictly positive.
        const _: () = {
            assert!(MAX_DISPATCH_OVERHEAD_NS > 0);
            assert!(MAX_GPU_STUB_OVERHEAD_NS > 0);
        };
    }

    #[test]
    fn test_threshold_constants_ordering() {
        // Compile-time check: the stub budget must exceed the dispatch budget.
        const _: () = {
            assert!(MAX_GPU_STUB_OVERHEAD_NS > MAX_DISPATCH_OVERHEAD_NS);
        };
    }

    #[test]
    fn test_validate_overhead_passes_when_within_threshold() {
        assert!(
            validate_overhead("test_op", 500, 1_000).is_ok(),
            "500ns should be within 1000ns threshold"
        );
    }

    #[test]
    fn test_validate_overhead_passes_at_exact_threshold() {
        assert!(
            validate_overhead("test_op", 1_000, 1_000).is_ok(),
            "Exactly at threshold should pass"
        );
    }

    #[test]
    fn test_validate_overhead_fails_above_threshold() {
        assert!(
            validate_overhead("test_op", 1_001, 1_000).is_err(),
            "1001ns should exceed 1000ns threshold"
        );
    }

    #[test]
    fn test_validate_overhead_error_message_contains_label() {
        let err = validate_overhead("my_operation", 9_999_999, 1).expect_err("should be error");
        assert!(
            format!("{:?}", err).contains("my_operation"),
            "Error message should name the operation"
        );
    }

    #[test]
    fn test_stub_reduction_op_names() {
        let expected = [
            (StubReductionOp::Sum, "sum"),
            (StubReductionOp::Mean, "mean"),
            (StubReductionOp::Max, "max"),
            (StubReductionOp::Min, "min"),
        ];
        for (op, label) in expected {
            assert_eq!(op.name(), label);
        }
    }
}