use crate::gpu_profiler::global_profiler;
#[cfg(feature = "gpu")]
use crate::{buffer::TensorBuffer, Device, Result, TensorError};
use scirs2_core::ndarray::ArrayD;
use std::sync::Arc;
use std::time::Instant;
/// Embeds a WGSL shader source as a `&'static str`, selected by logical name.
///
/// Paths are relative to this source file. Arms are only expanded when
/// invoked, so an arm whose file does not exist is harmless until used.
/// The name set is kept in parity with `gpu_include_shader!` below (the
/// `binary_ops_u32` / `binary_ops_u64` arms were previously missing here).
macro_rules! include_shader {
    ("activation_ops") => {
        include_str!("../shaders/activation_ops.wgsl")
    };
    ("manipulation_ops") => {
        include_str!("../shaders/manipulation_ops.wgsl")
    };
    ("comparison_ops") => {
        include_str!("../shaders/comparison_ops.wgsl")
    };
    ("logical_ops") => {
        include_str!("../shaders/logical_ops.wgsl")
    };
    ("random_ops") => {
        include_str!("../shaders/random_ops.wgsl")
    };
    ("reduction_ops") => {
        include_str!("../shaders/reduction_ops.wgsl")
    };
    ("einsum_ops") => {
        include_str!("../shaders/einsum_ops.wgsl")
    };
    ("binary_ops") => {
        include_str!("../shaders/binary_ops.wgsl")
    };
    ("conv_ops") => {
        include_str!("../shaders/conv_ops.wgsl")
    };
    ("matmul_ops") => {
        include_str!("../shaders/matmul_ops.wgsl")
    };
    ("attention_ops") => {
        include_str!("../shaders/attention_ops.wgsl")
    };
    ("embedding_ops") => {
        include_str!("../shaders/embedding_ops.wgsl")
    };
    ("normalization_ops") => {
        include_str!("../shaders/normalization_ops.wgsl")
    };
    ("pooling_ops") => {
        include_str!("../shaders/pooling_ops.wgsl")
    };
    ("scan_ops") => {
        include_str!("../shaders/scan_ops.wgsl")
    };
    ("segmented_ops") => {
        include_str!("../shaders/segmented_ops.wgsl")
    };
    ("strided_ops") => {
        include_str!("../shaders/strided_ops.wgsl")
    };
    ("unary_ops") => {
        include_str!("../shaders/unary_ops.wgsl")
    };
    ("unary_ops_f64") => {
        include_str!("../shaders/unary_ops_f64.wgsl")
    };
    ("unary_ops_i32") => {
        include_str!("../shaders/unary_ops_i32.wgsl")
    };
    ("unary_ops_i64") => {
        include_str!("../shaders/unary_ops_i64.wgsl")
    };
    ("unary_ops_u32") => {
        include_str!("../shaders/unary_ops_u32.wgsl")
    };
    ("unary_ops_u64") => {
        include_str!("../shaders/unary_ops_u64.wgsl")
    };
    ("binary_ops_f64") => {
        include_str!("../shaders/binary_ops_f64.wgsl")
    };
    ("binary_ops_i32") => {
        include_str!("../shaders/binary_ops_i32.wgsl")
    };
    ("binary_ops_i64") => {
        include_str!("../shaders/binary_ops_i64.wgsl")
    };
    // Added for parity with `gpu_include_shader!`, which already exposes
    // the unsigned binary-op shaders.
    ("binary_ops_u32") => {
        include_str!("../shaders/binary_ops_u32.wgsl")
    };
    ("binary_ops_u64") => {
        include_str!("../shaders/binary_ops_u64.wgsl")
    };
    ("topk_ops") => {
        include_str!("../shaders/topk_ops.wgsl")
    };
    ("manipulation_ops2") => {
        include_str!("../shaders/manipulation_ops2.wgsl")
    };
    ("fused_ops") => {
        include_str!("../shaders/fused_ops.wgsl")
    };
    ("fft_ops") => {
        include_str!("../shaders/fft_ops.wgsl")
    };
}
/// Shared handles to a wgpu logical device and its submission queue.
///
/// Both fields are `Arc`-wrapped so the handles can be shared cheaply
/// across kernels and threads.
pub struct GpuContext {
    /// The wgpu device used to create buffers, shaders, and pipelines.
    pub device: Arc<wgpu::Device>,
    /// The queue on which command buffers are submitted to `device`.
    pub queue: Arc<wgpu::Queue>,
}
/// The binary operation to apply between a tensor and a scalar operand.
///
/// Derives were added so the op can be debugged, copied freely, compared,
/// and used as a map key — standard practice for small public enums.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryScalarOp {
    Add,
    Sub,
    Mul,
    Div,
    Pow,
}
impl GpuContext {
    /// Creates a new GPU context by requesting a high-performance adapter
    /// and then a device/queue pair from wgpu.
    ///
    /// Blocks the calling thread (`pollster::block_on`) until the async
    /// wgpu setup completes.
    ///
    /// # Errors
    /// Returns a `TensorError::gpu_error` when no suitable adapter is
    /// found or when device creation fails.
    pub fn new() -> Result<Self> {
        pollster::block_on(async {
            let instance =
                wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle());
            let adapter = instance
                .request_adapter(&wgpu::RequestAdapterOptions {
                    power_preference: wgpu::PowerPreference::HighPerformance,
                    compatible_surface: None,
                    force_fallback_adapter: false,
                })
                .await
                .map_err(|_e| {
                    TensorError::gpu_error(
                        "GpuContext::new",
                        "Failed to find suitable GPU adapter",
                        None,
                        false,
                    )
                })?;
            let (device, queue) = adapter
                .request_device(&wgpu::DeviceDescriptor {
                    required_features: wgpu::Features::empty(),
                    // WebGL2 targets only support a reduced limit set.
                    required_limits: if cfg!(target_arch = "wasm32") {
                        wgpu::Limits::downlevel_webgl2_defaults()
                    } else {
                        wgpu::Limits::default()
                    },
                    label: Some("TenfloweRS GPU Device"),
                    memory_hints: Default::default(),
                    experimental_features: wgpu::ExperimentalFeatures::default(),
                    trace: wgpu::Trace::default(),
                })
                .await
                .map_err(|e| {
                    TensorError::gpu_error(
                        "GpuContext::new",
                        &format!("Failed to create GPU device: {}", e),
                        None,
                        false,
                    )
                })?;
            Ok(Self {
                device: Arc::new(device),
                queue: Arc::new(queue),
            })
        })
    }

    /// Returns the process-wide shared context, initializing it on first use.
    ///
    /// The initialization outcome (success or failure) is cached in a
    /// `OnceLock`: a failed initialization is never retried, and the stored
    /// error is cloned on every subsequent call.
    pub fn global() -> Result<&'static Self> {
        use std::sync::OnceLock;
        static GLOBAL_CONTEXT: OnceLock<Result<GpuContext>> = OnceLock::new();
        GLOBAL_CONTEXT
            // Function reference instead of `|| GpuContext::new()`
            // (clippy::redundant_closure).
            .get_or_init(GpuContext::new)
            .as_ref()
            .map_err(|e| e.clone())
    }
}
/// Embeds a WGSL shader source as a `&'static str`, selected by logical name.
///
/// Expands to `include_str!` of
/// `$CARGO_MANIFEST_DIR/src/gpu/shaders/<name>.wgsl`.
///
/// The single generic arm replaces ~30 hand-written arms that differed only
/// in the file name; every previously supported name expands to exactly the
/// same path as before. An unknown name now fails at compile time with a
/// "couldn't read file" error from `include_str!` rather than a
/// "no rules expected this token" macro error — both remain hard
/// compile-time failures.
#[macro_export]
macro_rules! gpu_include_shader {
    ($name:literal) => {
        include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/src/gpu/shaders/",
            $name,
            ".wgsl"
        ))
    };
}
pub use gpu_include_shader;
pub mod binary_ops;
pub mod buffer;
pub mod logical_ops;
pub mod random_ops;
pub mod unary_ops;
#[cfg(feature = "gpu")]
pub mod async_kernel;
#[cfg(feature = "gpu")]
pub mod linalg;
#[cfg(feature = "gpu")]
pub mod memory_coalescing;
#[cfg(feature = "gpu")]
pub mod multi_stream_executor;
#[cfg(feature = "gpu")]
pub mod rnn_ops;
#[cfg(feature = "gpu")]
pub mod attention_ops;
#[cfg(feature = "gpu")]
pub mod kernel_fusion;
#[cfg(feature = "gpu")]
pub mod ultra_fusion_integration;
#[cfg(feature = "gpu")]
pub mod memory_pool;
#[cfg(feature = "gpu")]
pub mod memory_tracing;
#[cfg(feature = "gpu")]
pub mod memory_diagnostics;
#[cfg(feature = "gpu")]
pub mod reduction_kernels;
#[cfg(feature = "gpu")]
pub mod performance_optimizer;
#[cfg(feature = "gpu")]
pub mod advanced_kernel_manager;
#[cfg(feature = "cudnn")]
pub mod cudnn;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod metal_kernels;
#[cfg(feature = "rocm")]
pub mod rocm_kernels;
#[cfg(feature = "cuda")]
pub mod cuda_kernels;
#[cfg(feature = "nccl")]
pub mod nccl_integration;
pub mod ops;
pub use binary_ops::{gpu_binary_op, BinaryOpKernel};
pub use buffer::{BufferManager, GpuBuffer, GpuBufferOps};
pub use unary_ops::{gpu_unary_op, UnaryOpKernel};
#[cfg(feature = "gpu")]
pub use attention_ops::*;
#[cfg(feature = "gpu")]
pub use kernel_fusion::*;
#[cfg(feature = "gpu")]
pub use linalg::*;
#[cfg(feature = "gpu")]
pub use ultra_fusion_integration::*;
pub use ops::ReductionOp;
#[cfg(feature = "gpu")]
pub use memory_tracing::{
current_gpu_memory_usage, generate_gpu_memory_report, peak_gpu_memory_usage,
print_gpu_memory_report, MemoryReport, MemoryStats,
};
/// Element-wise binary arithmetic executed on the GPU.
///
/// All methods take the operands by shared reference and return a new
/// value wrapped in `crate::Result`; the receivers are not mutated through
/// these signatures.
pub trait GpuOps {
    /// Element-wise addition of `self` and `other` on the GPU.
    fn gpu_add(&self, other: &Self) -> crate::Result<Self>
    where
        Self: Sized;
    /// Element-wise multiplication of `self` and `other` on the GPU.
    fn gpu_mul(&self, other: &Self) -> crate::Result<Self>
    where
        Self: Sized;
    /// Element-wise subtraction of `other` from `self` on the GPU.
    fn gpu_sub(&self, other: &Self) -> crate::Result<Self>
    where
        Self: Sized;
    /// Element-wise division of `self` by `other` on the GPU.
    fn gpu_div(&self, other: &Self) -> crate::Result<Self>
    where
        Self: Sized;
}
/// Best-effort conversion of a POD scalar of any supported numeric type to
/// `f32`, dispatching on the concrete type at runtime via `TypeId`.
///
/// Fixes the previous stub, which ignored `value` entirely and returned the
/// placeholder `42.0` for every input. Types outside the supported set fall
/// back to `0.0`.
fn cast_to_f32<T>(value: T) -> f32
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    use std::any::TypeId;
    use std::mem::transmute_copy;

    let id = TypeId::of::<T>();
    // SAFETY (all branches below): the `TypeId` equality check proves that
    // `T` is exactly the type named in that branch, so `transmute_copy`
    // reads a valid value of that type from `&value`.
    unsafe {
        if id == TypeId::of::<f32>() {
            transmute_copy::<T, f32>(&value)
        } else if id == TypeId::of::<f64>() {
            transmute_copy::<T, f64>(&value) as f32
        } else if id == TypeId::of::<i8>() {
            transmute_copy::<T, i8>(&value) as f32
        } else if id == TypeId::of::<u8>() {
            transmute_copy::<T, u8>(&value) as f32
        } else if id == TypeId::of::<i16>() {
            transmute_copy::<T, i16>(&value) as f32
        } else if id == TypeId::of::<u16>() {
            transmute_copy::<T, u16>(&value) as f32
        } else if id == TypeId::of::<i32>() {
            transmute_copy::<T, i32>(&value) as f32
        } else if id == TypeId::of::<u32>() {
            transmute_copy::<T, u32>(&value) as f32
        } else if id == TypeId::of::<i64>() {
            transmute_copy::<T, i64>(&value) as f32
        } else if id == TypeId::of::<u64>() {
            transmute_copy::<T, u64>(&value) as f32
        } else {
            // Unknown POD type: no meaningful numeric conversion exists.
            0.0
        }
    }
}
/// Dispatches an element-wise comparison between two GPU buffers,
/// producing a `u8` mask buffer.
///
/// NOTE(review): the actual WGSL kernel dispatch is not implemented yet —
/// after validating the inputs this returns a placeholder buffer of all
/// `1u8` and ignores `operation`. TODO: wire up the `comparison_ops`
/// shader and select the kernel from `operation`.
///
/// # Errors
/// Returns `TensorError::DeviceMismatch` when either input is not on a GPU
/// device or the two inputs live on different GPU devices, and a GPU error
/// when the element counts differ.
pub fn gpu_comparison_op_dispatch<T>(
    input_a: &GpuBuffer<T>,
    input_b: &GpuBuffer<T>,
    operation: self::ops::ComparisonOp,
) -> Result<GpuBuffer<u8>>
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    // The left operand fixes the device the comparison runs on.
    let device_id = match input_a.device_enum() {
        Device::Gpu(id) => id,
        _ => {
            return Err(TensorError::DeviceMismatch {
                operation: "comparison".to_string(),
                device1: format!("{:?}", input_a.device_enum()),
                device2: "GPU".to_string(),
                context: None,
            })
        }
    };
    // The right operand must live on the same GPU device.
    match input_b.device_enum() {
        Device::Gpu(id) if id == device_id => {}
        other => {
            return Err(TensorError::DeviceMismatch {
                operation: "comparison".to_string(),
                device1: format!("{:?}", input_a.device_enum()),
                device2: format!("{:?}", other),
                context: None,
            })
        }
    }
    // Element-wise comparison requires equal element counts.
    if input_a.len() != input_b.len() {
        return Err(TensorError::gpu_error(
            "comparison",
            &format!(
                "length mismatch: lhs has {} elements, rhs has {}",
                input_a.len(),
                input_b.len()
            ),
            None,
            false,
        ));
    }
    // TODO: select and dispatch the kernel based on `operation`.
    let _ = &operation;
    // Placeholder result until the comparison kernel is dispatched.
    let result_data = vec![1u8; input_a.len()];
    GpuBuffer::from_slice(&result_data, &Device::Gpu(device_id))
}
/// Looks up embedding rows on the GPU for a batch of indices.
///
/// NOTE(review): the GPU gather kernel is not wired up yet — after
/// validating the inputs this returns a `T::default()`-filled buffer of
/// `total_indices * embedding_dim` elements. TODO: dispatch the
/// `embedding_ops` shader.
///
/// # Errors
/// Returns `TensorError::DeviceMismatch` when `indices` is not on a GPU
/// device or `weights` lives on a different device, and a GPU error when
/// the weight buffer is smaller than `num_embeddings * embedding_dim`
/// elements (or that product overflows `usize`).
pub fn execute_embedding_lookup<T>(
    indices: &GpuBuffer<T>,
    weights: &GpuBuffer<T>,
    num_embeddings: usize,
    embedding_dim: usize,
    total_indices: usize,
) -> Result<GpuBuffer<T>>
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static + Default,
{
    // The index buffer fixes the device the lookup runs on.
    let device_id = match indices.device_enum() {
        Device::Gpu(id) => id,
        _ => {
            return Err(TensorError::DeviceMismatch {
                operation: "embedding_lookup".to_string(),
                device1: format!("{:?}", indices.device_enum()),
                device2: "GPU".to_string(),
                context: None,
            })
        }
    };
    // The weight table must live on the same device as the indices.
    if !matches!(weights.device_enum(), Device::Gpu(id) if id == device_id) {
        return Err(TensorError::DeviceMismatch {
            operation: "embedding_lookup".to_string(),
            device1: format!("{:?}", indices.device_enum()),
            device2: format!("{:?}", weights.device_enum()),
            context: None,
        });
    }
    // Checked multiply so absurd dimensions fail cleanly instead of
    // panicking (debug) or wrapping (release).
    let expected_weights = num_embeddings.checked_mul(embedding_dim).ok_or_else(|| {
        TensorError::gpu_error(
            "embedding_lookup",
            "num_embeddings * embedding_dim overflows usize",
            None,
            false,
        )
    })?;
    if weights.len() < expected_weights {
        return Err(TensorError::gpu_error(
            "embedding_lookup",
            &format!(
                "weight buffer has {} elements, expected at least {} ({} x {})",
                weights.len(),
                expected_weights,
                num_embeddings,
                embedding_dim
            ),
            None,
            false,
        ));
    }
    let output_size = total_indices * embedding_dim;
    // Placeholder: default-filled output until the gather kernel exists.
    let result_data = vec![T::default(); output_size];
    GpuBuffer::from_slice(&result_data, &Device::Gpu(device_id))
}