#![allow(dead_code)] #![allow(unused_variables)]
use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
// HIP status code for success (hipSuccess).
const HIP_SUCCESS: i32 = 0;
// hipMemcpyKind value for host-to-device transfers (hipMemcpyHostToDevice).
const HIP_MEMCPY_HOST_TO_DEVICE: i32 = 1;
// hipMemcpyKind value for device-to-host transfers (hipMemcpyDeviceToHost).
const HIP_MEMCPY_DEVICE_TO_HOST: i32 = 2;
/// ROCm (AMD GPU) backend built on a dynamically loaded HIP runtime.
///
/// Holds per-device state: a cache of compiled kernels, a simple device
/// memory pool, and cached device properties.
pub struct RocmImpl {
// Index of the HIP device this instance drives (0 = first GPU).
device_id: i32,
// Compiled kernels keyed by kernel/function name.
kernel_cache: Arc<Mutex<HashMap<String, RocmKernel>>>,
// Reusable device allocations plus allocation statistics.
memory_pool: Arc<Mutex<RocmMemoryPool>>,
// Properties reported for this device (see get_device_properties).
device_props: DeviceProperties,
// Resolved HIP runtime entry points; only present in Linux rocm builds.
#[cfg(all(feature = "rocm", target_os = "linux"))]
hip_lib: Arc<HipLibrary>,
}
#[cfg(all(feature = "rocm", target_os = "linux"))]
/// Raw function pointers resolved from `libamdhip64` at runtime.
///
/// `_library` must be kept alive for as long as any pointer below is used:
/// dropping the `libloading::Library` unmaps the shared object and would
/// invalidate every resolved symbol.
struct HipLibrary {
_library: libloading::Library,
// hipGetDeviceCount(out_count)
hip_get_device_count: unsafe extern "C" fn(*mut i32) -> i32,
// hipSetDevice(device_id)
hip_set_device: unsafe extern "C" fn(i32) -> i32,
// hipMalloc(out_ptr, size_bytes)
hip_malloc: unsafe extern "C" fn(*mut *mut std::ffi::c_void, usize) -> i32,
// hipFree(ptr)
hip_free: unsafe extern "C" fn(*mut std::ffi::c_void) -> i32,
// hipMemcpy(dst, src, size_bytes, kind)
hip_memcpy:
unsafe extern "C" fn(*mut std::ffi::c_void, *const std::ffi::c_void, usize, i32) -> i32,
// hipDeviceSynchronize()
hip_device_synchronize: unsafe extern "C" fn() -> i32,
// hipModuleLoadData(out_module, code_object)
hip_module_load_data:
unsafe extern "C" fn(*mut *mut std::ffi::c_void, *const std::ffi::c_void) -> i32,
// hipModuleGetFunction(out_function, module, name)
hip_module_get_function:
unsafe extern "C" fn(*mut *mut std::ffi::c_void, *mut std::ffi::c_void, *const i8) -> i32,
// hipModuleLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY,
// blockZ, shared_mem_bytes, stream, kernel_params, extra)
hip_module_launch_kernel: unsafe extern "C" fn(
*mut std::ffi::c_void,
u32,
u32,
u32,
u32,
u32,
u32,
u32,
*mut std::ffi::c_void,
*mut *mut std::ffi::c_void,
*mut *mut std::ffi::c_void,
) -> i32,
}
#[cfg(all(feature = "rocm", target_os = "linux"))]
impl HipLibrary {
/// Loads `libamdhip64` (trying the unversioned, `.5` and `.6` sonames in
/// order) and resolves every HIP symbol this backend uses.
///
/// # Errors
/// Returns a hardware error if the shared library cannot be opened or any
/// required symbol is missing.
fn load() -> Result<Self> {
// SAFETY: loading a shared library runs its initializers; we trust the
// system-installed HIP runtime.
let lib = unsafe {
libloading::Library::new("libamdhip64.so")
.or_else(|_| libloading::Library::new("libamdhip64.so.5"))
.or_else(|_| libloading::Library::new("libamdhip64.so.6"))
.map_err(|e| {
TrustformersError::hardware_error(
&format!(
"Failed to load ROCm HIP library: {}. Make sure ROCm is installed.",
e
),
"HipLibrary::load",
)
})?
};
// SAFETY: each `get` asserts the symbol's C signature. The `Symbol` is
// dereferenced to a plain fn pointer, which remains valid because
// `_library` keeps the shared object mapped for the lifetime of `Self`.
unsafe {
let hip_get_device_count = *lib
.get::<unsafe extern "C" fn(*mut i32) -> i32>(b"hipGetDeviceCount\0")
.map_err(|e| {
TrustformersError::hardware_error(
&format!("hipGetDeviceCount: {}", e),
"HipLibrary::load",
)
})?;
let hip_set_device =
*lib.get::<unsafe extern "C" fn(i32) -> i32>(b"hipSetDevice\0").map_err(|e| {
TrustformersError::hardware_error(
&format!("hipSetDevice: {}", e),
"HipLibrary::load",
)
})?;
let hip_malloc = *lib
.get::<unsafe extern "C" fn(*mut *mut std::ffi::c_void, usize) -> i32>(
b"hipMalloc\0",
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("hipMalloc: {}", e),
"HipLibrary::load",
)
})?;
let hip_free = *lib
.get::<unsafe extern "C" fn(*mut std::ffi::c_void) -> i32>(b"hipFree\0")
.map_err(|e| {
TrustformersError::hardware_error(
&format!("hipFree: {}", e),
"HipLibrary::load",
)
})?;
let hip_memcpy = *lib
.get::<unsafe extern "C" fn(
*mut std::ffi::c_void,
*const std::ffi::c_void,
usize,
i32,
) -> i32>(b"hipMemcpy\0")
.map_err(|e| {
TrustformersError::hardware_error(
&format!("hipMemcpy: {}", e),
"HipLibrary::load",
)
})?;
let hip_device_synchronize = *lib
.get::<unsafe extern "C" fn() -> i32>(b"hipDeviceSynchronize\0")
.map_err(|e| {
TrustformersError::hardware_error(
&format!("hipDeviceSynchronize: {}", e),
"HipLibrary::load",
)
})?;
let hip_module_load_data =
*lib.get::<unsafe extern "C" fn(
*mut *mut std::ffi::c_void,
*const std::ffi::c_void,
) -> i32>(b"hipModuleLoadData\0")
.map_err(|e| {
TrustformersError::hardware_error(
&format!("hipModuleLoadData: {}", e),
"HipLibrary::load",
)
})?;
let hip_module_get_function = *lib
.get::<unsafe extern "C" fn(
*mut *mut std::ffi::c_void,
*mut std::ffi::c_void,
*const i8,
) -> i32>(b"hipModuleGetFunction\0")
.map_err(|e| {
TrustformersError::hardware_error(
&format!("hipModuleGetFunction: {}", e),
"HipLibrary::load",
)
})?;
let hip_module_launch_kernel = *lib
.get::<unsafe extern "C" fn(
*mut std::ffi::c_void,
u32,
u32,
u32,
u32,
u32,
u32,
u32,
*mut std::ffi::c_void,
*mut *mut std::ffi::c_void,
*mut *mut std::ffi::c_void,
) -> i32>(b"hipModuleLaunchKernel\0")
.map_err(|e| {
TrustformersError::hardware_error(
&format!("hipModuleLaunchKernel: {}", e),
"HipLibrary::load",
)
})?;
Ok(Self {
_library: lib,
hip_get_device_count,
hip_set_device,
hip_malloc,
hip_free,
hip_memcpy,
hip_device_synchronize,
hip_module_load_data,
hip_module_get_function,
hip_module_launch_kernel,
})
}
}
}
/// Handle to a compiled, launch-ready HIP kernel.
#[derive(Clone)]
pub struct RocmKernel {
// Raw hipFunction_t obtained from hipModuleGetFunction.
function: *mut std::ffi::c_void,
// Exported kernel symbol name; also the cache key in RocmImpl.
name: String,
// Launch geometry: grid dimensions in blocks (x, y, z).
grid_config: (u32, u32, u32),
// Launch geometry: block dimensions in threads (x, y, z).
block_config: (u32, u32, u32),
// Dynamic shared-memory bytes requested at launch.
shared_memory: u32,
}
/// Simple free-list pool for device allocations plus usage statistics.
#[derive(Default)]
pub struct RocmMemoryPool {
// Blocks returned by callers, available for reuse.
available_blocks: Vec<RocmMemoryBlock>,
// Live blocks keyed by an allocation id.
allocated_blocks: HashMap<usize, RocmMemoryBlock>,
// Total bytes obtained from hipMalloc (never reduced on free).
total_allocated: usize,
// High-water mark of total_allocated.
peak_memory: usize,
}
/// A single device allocation tracked by the pool.
#[derive(Clone)]
pub struct RocmMemoryBlock {
// Device pointer returned by hipMalloc.
ptr: *mut std::ffi::c_void,
// Size of the allocation in bytes.
size: usize,
// Pool-assigned identifier for this block.
id: usize,
}
/// Static description of a ROCm device.
#[derive(Debug, Clone)]
pub struct DeviceProperties {
/// Marketing name of the GPU.
pub name: String,
/// GFX ISA identifier (e.g. "gfx1100").
pub gfx_version: String,
/// Total device memory in bytes.
pub total_memory: usize,
/// Memory assumed available for allocations, in bytes.
pub available_memory: usize,
/// Number of compute units.
pub compute_units: u32,
/// Threads per wavefront (64 on most AMD GPUs).
pub wavefront_size: u32,
/// Maximum threads per block.
pub max_threads_per_block: u32,
/// Maximum shared memory (LDS) per block, in bytes.
pub max_shared_memory: u32,
}
static ROCM_INSTANCE: OnceLock<Arc<RocmImpl>> = OnceLock::new();
#[cfg(all(feature = "rocm", target_os = "linux"))]
impl RocmImpl {
/// Initializes the ROCm backend: loads the HIP runtime, verifies at least
/// one device exists, selects device 0, and gathers its properties.
///
/// # Errors
/// Returns a hardware error if the HIP library cannot be loaded, no
/// devices are present, or the device cannot be selected.
pub fn new() -> Result<Self> {
let hip_lib = Arc::new(HipLibrary::load()?);
let mut device_count = 0;
// SAFETY: hipGetDeviceCount writes a single i32 through a valid pointer.
let result = unsafe { (hip_lib.hip_get_device_count)(&mut device_count) };
if result != HIP_SUCCESS || device_count == 0 {
return Err(TrustformersError::hardware_error(
"No ROCm devices found",
"RocmImpl::new",
));
}
// Always binds to the first device; multi-GPU selection is not supported.
let device_id = 0;
// SAFETY: device_id was validated against device_count above.
let result = unsafe { (hip_lib.hip_set_device)(device_id) };
if result != HIP_SUCCESS {
return Err(TrustformersError::hardware_error(
"Failed to set ROCm device",
"RocmImpl::new",
));
}
let device_props = Self::get_device_properties(device_id)?;
Ok(Self {
device_id,
kernel_cache: Arc::new(Mutex::new(HashMap::new())),
memory_pool: Arc::new(Mutex::new(RocmMemoryPool::new())),
device_props,
hip_lib,
})
}
/// Returns the lazily created process-wide ROCm instance.
///
/// A separate `OnceLock<bool>` records whether the one-and-only
/// initialization attempt succeeded, because `OnceLock` cannot cache a
/// failure directly: a failed init is then reported on every later call
/// instead of being retried.
pub fn global() -> Result<&'static Arc<RocmImpl>> {
static INIT_SUCCESS: OnceLock<bool> = OnceLock::new();
let success = *INIT_SUCCESS.get_or_init(|| match Self::new() {
Ok(instance) => {
// Publish the instance before reporting success.
let _ = ROCM_INSTANCE.set(Arc::new(instance));
true
},
Err(_) => false,
});
if success {
// Invariant: when INIT_SUCCESS is true, ROCM_INSTANCE was set above.
Ok(ROCM_INSTANCE.get().expect("ROCm instance should exist after initialization"))
} else {
Err(TrustformersError::hardware_error(
"ROCm not available on this system",
"RocmImpl::global",
))
}
}
/// Returns properties for the given device.
///
/// NOTE(review): these values are hardcoded placeholders and `device_id`
/// is ignored — wire this up to `hipGetDeviceProperties` before relying
/// on the numbers for scheduling or memory decisions.
fn get_device_properties(device_id: i32) -> Result<DeviceProperties> {
Ok(DeviceProperties {
name: "AMD Radeon RX 7900 XTX".to_string(),
gfx_version: "gfx1100".to_string(),
total_memory: 24 * 1024 * 1024 * 1024, available_memory: 22 * 1024 * 1024 * 1024, compute_units: 96,
wavefront_size: 64,
max_threads_per_block: 1024,
max_shared_memory: 65536,
})
}
/// Compiles a HIP kernel (or reuses a cached one) and returns a handle
/// configured with the launch geometry requested by this call.
///
/// `name` is both the cache key and the exported symbol resolved through
/// `hipModuleGetFunction`, so distinct kernel sources must use distinct
/// function names.
///
/// # Errors
/// Returns a hardware error when module loading or symbol lookup fails,
/// or when `name` contains an interior NUL byte.
pub fn compile_kernel(
    &self,
    name: &str,
    source: &str,
    grid: (u32, u32, u32),
    block: (u32, u32, u32),
) -> Result<RocmKernel> {
    // Cache hit: reuse the loaded function handle but apply the geometry
    // requested by *this* call. (Bug fix: the cached entry kept the
    // grid/block of whichever call compiled the kernel first, so later
    // calls with different problem sizes launched with a stale geometry.)
    {
        let cache = self.kernel_cache.lock().expect("Lock poisoned");
        if let Some(kernel) = cache.get(name) {
            let mut kernel = kernel.clone();
            kernel.grid_config = grid;
            kernel.block_config = block;
            return Ok(kernel);
        }
    }
    let code_object = self.compile_hip_source(source)?;
    // NOTE(review): the module handle is never unloaded; acceptable while
    // kernels stay cached for the process lifetime.
    let mut module = std::ptr::null_mut();
    let result = unsafe {
        (self.hip_lib.hip_module_load_data)(
            &mut module,
            code_object.as_ptr() as *const std::ffi::c_void,
        )
    };
    if result != HIP_SUCCESS {
        return Err(TrustformersError::hardware_error(
            "Failed to load ROCm module",
            "RocmImpl::compile_kernel",
        ));
    }
    let mut function = std::ptr::null_mut();
    let name_cstr = std::ffi::CString::new(name).map_err(|_| {
        TrustformersError::hardware_error(
            "Kernel name contains null byte",
            "RocmImpl::compile_kernel",
        )
    })?;
    let result = unsafe {
        (self.hip_lib.hip_module_get_function)(&mut function, module, name_cstr.as_ptr())
    };
    if result != HIP_SUCCESS {
        return Err(TrustformersError::hardware_error(
            "Failed to get ROCm function",
            "RocmImpl::compile_kernel",
        ));
    }
    let kernel = RocmKernel {
        function,
        name: name.to_string(),
        grid_config: grid,
        block_config: block,
        shared_memory: 0,
    };
    {
        let mut cache = self.kernel_cache.lock().expect("Lock poisoned");
        cache.insert(name.to_string(), kernel.clone());
    }
    Ok(kernel)
}
/// Compiles HIP C++ source into a GPU code object.
///
/// NOTE(review): placeholder — returns a 1 KiB zeroed buffer instead of
/// invoking hipRTC/hipcc, so `hip_module_load_data` cannot succeed on real
/// hardware until this is implemented.
fn compile_hip_source(&self, source: &str) -> Result<Vec<u8>> {
Ok(vec![0; 1024]) }
pub fn allocate_memory(&self, size: usize) -> Result<*mut std::ffi::c_void> {
{
let mut pool = self.memory_pool.lock().expect("Lock poisoned");
if let Some(block) = pool.get_block(size) {
return Ok(block.ptr);
}
}
let mut ptr = std::ptr::null_mut();
let result = unsafe { (self.hip_lib.hip_malloc)(&mut ptr, size) };
if result != HIP_SUCCESS {
return Err(TrustformersError::hardware_error(
"Failed to allocate GPU memory",
"RocmImpl::allocate_memory",
));
}
{
let mut pool = self.memory_pool.lock().expect("Lock poisoned");
pool.total_allocated += size;
pool.peak_memory = pool.peak_memory.max(pool.total_allocated);
}
Ok(ptr)
}
/// Uploads a tensor's f32 data into a freshly allocated device buffer and
/// returns the device pointer; the caller owns (and must free) the buffer.
///
/// # Errors
/// Returns a hardware error if allocation or the host-to-device memcpy
/// fails.
pub fn copy_to_gpu(&self, tensor: &Tensor) -> Result<*mut std::ffi::c_void> {
    let host_data = tensor.data_f32()?;
    let byte_len = host_data.len() * std::mem::size_of::<f32>();
    let device_ptr = self.allocate_memory(byte_len)?;
    // SAFETY: device_ptr holds at least byte_len bytes and host_data is a
    // valid readable buffer of the same length.
    let status = unsafe {
        (self.hip_lib.hip_memcpy)(
            device_ptr,
            host_data.as_ptr() as *const std::ffi::c_void,
            byte_len,
            HIP_MEMCPY_HOST_TO_DEVICE,
        )
    };
    if status == HIP_SUCCESS {
        Ok(device_ptr)
    } else {
        Err(TrustformersError::hardware_error(
            "Failed to copy data to GPU",
            "RocmImpl::copy_to_gpu",
        ))
    }
}
/// Copies `tensor.memory_usage()` bytes from `gpu_ptr` back into `tensor`.
///
/// # Safety
/// `gpu_ptr` must be a valid device allocation of at least
/// `tensor.memory_usage()` bytes obtained from this device.
///
/// # Errors
/// Returns a hardware error if the device-to-host memcpy fails.
pub unsafe fn copy_from_gpu(
&self,
gpu_ptr: *mut std::ffi::c_void,
tensor: &mut Tensor,
) -> Result<()> {
// Assumes the tensor stores f32 elements (size / 4 elements) — TODO
// confirm behaviour for non-f32 tensors.
let size = tensor.memory_usage();
let mut data = vec![0.0f32; size / std::mem::size_of::<f32>()];
let result = (self.hip_lib.hip_memcpy)(
data.as_mut_ptr() as *mut std::ffi::c_void,
gpu_ptr,
size,
HIP_MEMCPY_DEVICE_TO_HOST,
);
if result != HIP_SUCCESS {
return Err(TrustformersError::hardware_error(
"Failed to copy data from GPU",
"RocmImpl::copy_from_gpu",
));
}
tensor.set_data_f32(&data)?;
Ok(())
}
/// Computes `c = a × b` on the GPU with a 16x16-tiled HIP kernel.
///
/// All three tensors must be 2-D: `a` is (M, K), `b` is (K, N) and `c`
/// is (M, N).
///
/// # Errors
/// Returns a tensor-op error for non-2D or dimensionally incompatible
/// inputs, and a hardware error if any GPU step fails. Device buffers are
/// released on every path.
pub fn matmul(&self, a: &Tensor, b: &Tensor, c: &mut Tensor) -> Result<()> {
    let a_shape = a.shape();
    let b_shape = b.shape();
    let c_shape = c.shape();
    if a_shape.len() != 2 || b_shape.len() != 2 || c_shape.len() != 2 {
        return Err(TrustformersError::tensor_op_error(
            "Matrix multiplication requires 2D tensors",
            "RocmImpl::matmul",
        ));
    }
    // Bug fix: shape compatibility was previously unchecked; a mismatch
    // would have launched a kernel reading/writing out of bounds.
    if a_shape[1] != b_shape[0] || c_shape[0] != a_shape[0] || c_shape[1] != b_shape[1] {
        return Err(TrustformersError::tensor_op_error(
            "Incompatible tensor shapes for matrix multiplication",
            "RocmImpl::matmul",
        ));
    }
    let m = a_shape[0] as u32;
    let k = a_shape[1] as u32;
    let n = b_shape[1] as u32;
    let kernel_source = self.generate_rocm_matmul_kernel(m, k, n);
    // One thread per output element, in 16x16 tiles (must match TILE_SIZE
    // in the generated kernel source).
    let kernel = self.compile_kernel(
        "rocm_matmul",
        &kernel_source,
        (n.div_ceil(16), m.div_ceil(16), 1),
        (16, 16, 1),
    )?;
    let a_gpu = self.copy_to_gpu(a)?;
    let b_gpu = self.copy_to_gpu(b)?;
    let c_gpu = match self.allocate_memory(c_shape[0] * c_shape[1] * std::mem::size_of::<f32>()) {
        Ok(ptr) => ptr,
        Err(e) => {
            // Don't leak the input buffers if the output allocation fails.
            unsafe {
                (self.hip_lib.hip_free)(a_gpu);
                (self.hip_lib.hip_free)(b_gpu);
            }
            return Err(e);
        },
    };
    // Frees all device buffers; used on success and every error path (the
    // original leaked a_gpu/b_gpu/c_gpu when launch or sync failed).
    let release = |ptrs: [*mut std::ffi::c_void; 3]| unsafe {
        for ptr in ptrs {
            (self.hip_lib.hip_free)(ptr);
        }
    };
    // HIP takes kernel arguments as an array of pointers to the argument
    // values; the locals referenced here outlive the launch call.
    let mut kernel_args = vec![
        &a_gpu as *const _ as *mut std::ffi::c_void,
        &b_gpu as *const _ as *mut std::ffi::c_void,
        &c_gpu as *const _ as *mut std::ffi::c_void,
        &m as *const _ as *mut std::ffi::c_void,
        &k as *const _ as *mut std::ffi::c_void,
        &n as *const _ as *mut std::ffi::c_void,
    ];
    let result = unsafe {
        (self.hip_lib.hip_module_launch_kernel)(
            kernel.function,
            kernel.grid_config.0,
            kernel.grid_config.1,
            kernel.grid_config.2,
            kernel.block_config.0,
            kernel.block_config.1,
            kernel.block_config.2,
            kernel.shared_memory,
            std::ptr::null_mut(), // stream: default
            kernel_args.as_mut_ptr(),
            std::ptr::null_mut(), // extra: unused
        )
    };
    if result != HIP_SUCCESS {
        release([a_gpu, b_gpu, c_gpu]);
        return Err(TrustformersError::hardware_error(
            "Failed to launch ROCm kernel",
            "RocmImpl::matmul",
        ));
    }
    let sync_result = unsafe { (self.hip_lib.hip_device_synchronize)() };
    if sync_result != HIP_SUCCESS {
        release([a_gpu, b_gpu, c_gpu]);
        return Err(TrustformersError::hardware_error(
            "ROCm synchronization failed",
            "RocmImpl::matmul",
        ));
    }
    // SAFETY: c_gpu is a live allocation large enough for c's data.
    let copy_back = unsafe { self.copy_from_gpu(c_gpu, c) };
    release([a_gpu, b_gpu, c_gpu]);
    copy_back
}
/// Returns HIP C++ source for a 16x16-tiled f32 GEMM kernel named
/// `rocm_matmul`. The `_m/_k/_n` parameters are unused: the kernel reads
/// M, K and N at runtime, so one compiled module serves every shape.
fn generate_rocm_matmul_kernel(&self, _m: u32, _k: u32, _n: u32) -> String {
r#"
#include <hip/hip_runtime.h>
extern "C" __global__ void rocm_matmul(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
const unsigned int M,
const unsigned int K,
const unsigned int N
) {
// Optimized for AMD RDNA architecture with 64-wide wavefronts
const int TILE_SIZE = 16;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int bx = blockIdx.x;
const int by = blockIdx.y;
// Calculate global thread indices
const int row = by * TILE_SIZE + ty;
const int col = bx * TILE_SIZE + tx;
// Use LDS (Local Data Share) for tiling - AMD's equivalent to shared memory
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
float sum = 0.0f;
// Tile loop optimized for AMD GPU memory hierarchy
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
// Load tile into LDS with coalesced access
int a_col = t * TILE_SIZE + tx;
int b_row = t * TILE_SIZE + ty;
As[ty][tx] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
Bs[ty][tx] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;
// Synchronize wavefront
__syncthreads();
// Compute partial dot product with manual unrolling for AMD ALUs
#pragma unroll
for (int i = 0; i < TILE_SIZE; ++i) {
sum += As[ty][i] * Bs[i][tx];
}
__syncthreads();
}
// Write result with bounds checking
if (row < M && col < N) {
C[row * N + col] = sum;
}
}
"#
.to_string()
}
/// Runs single-head attention `O = softmax(Q·Kᵀ)·V` on the GPU.
///
/// `query`, `key`, `value` and `output` must all have the same 3-D shape
/// `(batch, seq_len, head_dim)`.
///
/// # Errors
/// Returns a tensor-op error on shape mismatches and a hardware error if
/// any GPU step fails. Device buffers are released on every path.
pub fn flash_attention(
    &self,
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    output: &mut Tensor,
) -> Result<()> {
    let q_shape = query.shape();
    // Bug fix: shapes were previously unchecked — indexing q_shape[2] on
    // a lower-rank tensor panicked, and mismatched K/V read out of bounds.
    if q_shape.len() != 3 || key.shape() != q_shape || value.shape() != q_shape {
        return Err(TrustformersError::tensor_op_error(
            "Flash attention requires Q, K and V with identical 3D shapes",
            "RocmImpl::flash_attention",
        ));
    }
    let batch_size = q_shape[0] as u32;
    let seq_len = q_shape[1] as u32;
    let head_dim = q_shape[2] as u32;
    let kernel_source =
        self.generate_rocm_flash_attention_kernel(batch_size, seq_len, head_dim);
    // Bug fix: the generated kernel strides over keys by 64 and combines
    // partial results with wavefront shuffles, i.e. it assumes exactly one
    // 64-lane wavefront per block. The previous (256, 1, 1) block launched
    // four wavefronts whose softmax statistics were never merged, producing
    // wrong results; (64, 1, 1) matches the kernel's design.
    let kernel = self.compile_kernel(
        "rocm_flash_attention",
        &kernel_source,
        (batch_size, seq_len, 1),
        (64, 1, 1),
    )?;
    let q_gpu = self.copy_to_gpu(query)?;
    let k_gpu = self.copy_to_gpu(key).map_err(|e| {
        unsafe { (self.hip_lib.hip_free)(q_gpu); }
        e
    })?;
    let v_gpu = self.copy_to_gpu(value).map_err(|e| {
        unsafe {
            (self.hip_lib.hip_free)(q_gpu);
            (self.hip_lib.hip_free)(k_gpu);
        }
        e
    })?;
    let o_gpu = self.allocate_memory(output.memory_usage()).map_err(|e| {
        unsafe {
            (self.hip_lib.hip_free)(q_gpu);
            (self.hip_lib.hip_free)(k_gpu);
            (self.hip_lib.hip_free)(v_gpu);
        }
        e
    })?;
    // Frees every device buffer; used on success and on all error paths
    // (the original leaked all four buffers when launch or sync failed).
    let release = || unsafe {
        (self.hip_lib.hip_free)(q_gpu);
        (self.hip_lib.hip_free)(k_gpu);
        (self.hip_lib.hip_free)(v_gpu);
        (self.hip_lib.hip_free)(o_gpu);
    };
    let mut kernel_args = vec![
        &q_gpu as *const _ as *mut std::ffi::c_void,
        &k_gpu as *const _ as *mut std::ffi::c_void,
        &v_gpu as *const _ as *mut std::ffi::c_void,
        &o_gpu as *const _ as *mut std::ffi::c_void,
        &batch_size as *const _ as *mut std::ffi::c_void,
        &seq_len as *const _ as *mut std::ffi::c_void,
        &head_dim as *const _ as *mut std::ffi::c_void,
    ];
    // Bug fix: the kernel partitions its dynamic LDS into two seq_len-long
    // float arrays, so it needs exactly 2 * seq_len * 4 bytes. The fixed
    // 48 KiB both over-reserved for short sequences and under-reserved for
    // seq_len > 6144.
    let shared_bytes = 2 * seq_len * std::mem::size_of::<f32>() as u32;
    let result = unsafe {
        (self.hip_lib.hip_module_launch_kernel)(
            kernel.function,
            kernel.grid_config.0,
            kernel.grid_config.1,
            kernel.grid_config.2,
            kernel.block_config.0,
            kernel.block_config.1,
            kernel.block_config.2,
            shared_bytes,
            std::ptr::null_mut(), // stream: default
            kernel_args.as_mut_ptr(),
            std::ptr::null_mut(), // extra: unused
        )
    };
    if result != HIP_SUCCESS {
        release();
        return Err(TrustformersError::hardware_error(
            "Failed to launch ROCm Flash Attention kernel",
            "RocmImpl::flash_attention",
        ));
    }
    let sync_result = unsafe { (self.hip_lib.hip_device_synchronize)() };
    if sync_result != HIP_SUCCESS {
        release();
        return Err(TrustformersError::hardware_error(
            "ROCm synchronization failed",
            "RocmImpl::flash_attention",
        ));
    }
    // SAFETY: o_gpu is a live allocation of output.memory_usage() bytes.
    let copied = unsafe { self.copy_from_gpu(o_gpu, output) };
    release();
    copied
}
/// Returns HIP C++ source for the attention kernel `rocm_flash_attention`.
///
/// The kernel expects one block per (batch, sequence-position) pair and
/// strides over keys in steps of 64 using wavefront shuffles — it is
/// written for a single 64-lane wavefront per block. It carves its
/// dynamic LDS into two `seq_len`-long float arrays, so a launch must
/// provide at least `2 * seq_len * 4` bytes of shared memory.
/// NOTE(review): `lds_values` is carved out but never read or written —
/// confirm whether it can be dropped from the LDS budget.
fn generate_rocm_flash_attention_kernel(
&self,
_batch_size: u32,
_seq_len: u32,
_head_dim: u32,
) -> String {
r#"
#include <hip/hip_runtime.h>
extern "C" __global__ void rocm_flash_attention(
const float* __restrict__ Q,
const float* __restrict__ K,
const float* __restrict__ V,
float* __restrict__ O,
const unsigned int batch_size,
const unsigned int seq_len,
const unsigned int head_dim
) {
// Flash Attention optimized for AMD RDNA architecture
const int batch_id = blockIdx.x;
const int seq_id = blockIdx.y;
const int lane_id = threadIdx.x;
if (batch_id >= batch_size || seq_id >= seq_len) return;
// Use LDS for computation - optimized for 64-wide wavefronts
extern __shared__ float lds_memory[];
float* lds_scores = lds_memory;
float* lds_values = lds_memory + seq_len;
// Compute QK^T scores with wavefront-optimized memory access
float max_score = -INFINITY;
for (int k = lane_id; k < seq_len; k += 64) { // 64-wide wavefront
float score = 0.0f;
for (int d = 0; d < head_dim; d++) {
int q_idx = batch_id * seq_len * head_dim + seq_id * head_dim + d;
int k_idx = batch_id * seq_len * head_dim + k * head_dim + d;
score += Q[q_idx] * K[k_idx];
}
lds_scores[k] = score;
max_score = fmaxf(max_score, score);
}
// Wavefront-level reduction for maximum
#pragma unroll
for (int offset = 32; offset > 0; offset >>= 1) {
max_score = fmaxf(max_score, __shfl_down(max_score, offset));
}
// Broadcast max to all lanes in wavefront
max_score = __shfl(max_score, 0);
// Compute softmax with numerical stability
float sum_exp = 0.0f;
for (int k = lane_id; k < seq_len; k += 64) {
float exp_score = expf(lds_scores[k] - max_score);
lds_scores[k] = exp_score;
sum_exp += exp_score;
}
// Wavefront-level reduction for sum
#pragma unroll
for (int offset = 32; offset > 0; offset >>= 1) {
sum_exp += __shfl_down(sum_exp, offset);
}
// Broadcast sum to all lanes
sum_exp = __shfl(sum_exp, 0);
// Normalize attention weights
for (int k = lane_id; k < seq_len; k += 64) {
lds_scores[k] /= sum_exp;
}
__syncthreads();
// Compute output with optimized memory access
for (int d = lane_id; d < head_dim; d += 64) {
float output_val = 0.0f;
for (int k = 0; k < seq_len; k++) {
int v_idx = batch_id * seq_len * head_dim + k * head_dim + d;
output_val += lds_scores[k] * V[v_idx];
}
int o_idx = batch_id * seq_len * head_dim + seq_id * head_dim + d;
O[o_idx] = output_val;
}
}
"#
.to_string()
}
/// Human-readable one-line summary of the active device.
pub fn device_info(&self) -> String {
    let props = &self.device_props;
    let memory_gib = props.total_memory as f64 / (1024.0 * 1024.0 * 1024.0);
    format!(
        "ROCm Device: {}, GFX: {}, Memory: {:.1} GB",
        props.name, props.gfx_version, memory_gib
    )
}
/// Returns `(total_bytes_allocated, peak_bytes)` from the pool statistics.
pub fn memory_stats(&self) -> (usize, usize) {
    let guard = self.memory_pool.lock().expect("Lock poisoned");
    let current = guard.total_allocated;
    let peak = guard.peak_memory;
    (current, peak)
}
/// Applies GELU elementwise on the GPU: `output = gelu(input)`.
///
/// `approximate` selects the tanh approximation over the exact erf form.
/// NOTE(review): compile_kernel caches by kernel name only, so the variant
/// baked into the cached binary is whichever one this process compiled
/// first — confirm whether both variants are ever needed concurrently.
///
/// # Errors
/// Returns a hardware error if any GPU step fails. Device buffers are
/// released on every path.
pub fn fused_gelu(&self, input: &Tensor, output: &mut Tensor, approximate: bool) -> Result<()> {
    let shape = input.shape();
    let total_elements = shape.iter().product::<usize>();
    let kernel_source = self.generate_fused_gelu_kernel(approximate);
    // One thread per element, 256 threads per block.
    let grid_size = total_elements.div_ceil(256) as u32;
    let kernel = self.compile_kernel(
        "fused_gelu_kernel",
        &kernel_source,
        (grid_size, 1, 1),
        (256, 1, 1),
    )?;
    let input_gpu = self.copy_to_gpu(input)?;
    let output_gpu = match self.allocate_memory(total_elements * std::mem::size_of::<f32>()) {
        Ok(ptr) => ptr,
        Err(e) => {
            // Don't leak the input buffer if the output allocation fails.
            unsafe { (self.hip_lib.hip_free)(input_gpu); }
            return Err(e);
        },
    };
    // Frees both device buffers; used on every exit path (previously the
    // buffers leaked when the launch failed).
    let release = || unsafe {
        (self.hip_lib.hip_free)(input_gpu);
        (self.hip_lib.hip_free)(output_gpu);
    };
    let n = total_elements as u32;
    let mut kernel_args = vec![
        &input_gpu as *const _ as *mut std::ffi::c_void,
        &output_gpu as *const _ as *mut std::ffi::c_void,
        &n as *const _ as *mut std::ffi::c_void,
    ];
    let result = unsafe {
        (self.hip_lib.hip_module_launch_kernel)(
            kernel.function,
            grid_size,
            1,
            1,
            256,
            1,
            1,
            kernel.shared_memory,
            std::ptr::null_mut(),
            kernel_args.as_mut_ptr(),
            std::ptr::null_mut(),
        )
    };
    if result != HIP_SUCCESS {
        release();
        return Err(TrustformersError::hardware_error(
            &format!("Failed to launch ROCm GELU kernel: {}", result),
            "rocm_gelu_launch",
        ));
    }
    // Bug fix: the synchronize result was previously ignored; a failure
    // means the kernel may not have completed and the output is invalid.
    let sync_result = unsafe { (self.hip_lib.hip_device_synchronize)() };
    if sync_result != HIP_SUCCESS {
        release();
        return Err(TrustformersError::hardware_error(
            "ROCm synchronization failed",
            "rocm_gelu_launch",
        ));
    }
    // SAFETY: output_gpu is a live allocation large enough for the output.
    let copied = unsafe { self.copy_from_gpu(output_gpu, output) };
    release();
    copied
}
/// Computes `output = activation(input + broadcast(bias))` on the GPU.
///
/// `activation` is one of "relu", "gelu", "silu", "tanh" or "none";
/// unknown names fall back to identity.
/// NOTE(review): compile_kernel caches by kernel name only, so the
/// activation baked into the cached binary is whichever one this process
/// compiled first — confirm whether several activations are used at once.
///
/// # Errors
/// Returns a tensor-op error for an empty bias and a hardware error if
/// any GPU step fails. Device buffers are released on every path.
pub fn fused_bias_activation(
    &self,
    input: &Tensor,
    bias: &Tensor,
    output: &mut Tensor,
    activation: &str,
) -> Result<()> {
    let shape = input.shape();
    let total_elements = shape.iter().product::<usize>();
    let bias_size = bias.shape().iter().product::<usize>();
    // Bug fix: a zero-sized bias would make the kernel evaluate `idx % 0`
    // (undefined behaviour on the GPU); reject it up front.
    if bias_size == 0 {
        return Err(TrustformersError::tensor_op_error(
            "Bias tensor must be non-empty",
            "RocmImpl::fused_bias_activation",
        ));
    }
    let kernel_source = self.generate_fused_bias_activation_kernel(activation);
    // One thread per element, 256 threads per block.
    let grid_size = total_elements.div_ceil(256) as u32;
    let kernel = self.compile_kernel(
        "fused_bias_activation_kernel",
        &kernel_source,
        (grid_size, 1, 1),
        (256, 1, 1),
    )?;
    let input_gpu = self.copy_to_gpu(input)?;
    let bias_gpu = self.copy_to_gpu(bias).map_err(|e| {
        unsafe { (self.hip_lib.hip_free)(input_gpu); }
        e
    })?;
    let output_gpu = self
        .allocate_memory(total_elements * std::mem::size_of::<f32>())
        .map_err(|e| {
            unsafe {
                (self.hip_lib.hip_free)(input_gpu);
                (self.hip_lib.hip_free)(bias_gpu);
            }
            e
        })?;
    // Frees all device buffers; used on every exit path (previously the
    // buffers leaked when the launch failed).
    let release = || unsafe {
        (self.hip_lib.hip_free)(input_gpu);
        (self.hip_lib.hip_free)(bias_gpu);
        (self.hip_lib.hip_free)(output_gpu);
    };
    let n = total_elements as u32;
    let b = bias_size as u32;
    let mut kernel_args = vec![
        &input_gpu as *const _ as *mut std::ffi::c_void,
        &bias_gpu as *const _ as *mut std::ffi::c_void,
        &output_gpu as *const _ as *mut std::ffi::c_void,
        &n as *const _ as *mut std::ffi::c_void,
        &b as *const _ as *mut std::ffi::c_void,
    ];
    let result = unsafe {
        (self.hip_lib.hip_module_launch_kernel)(
            kernel.function,
            grid_size,
            1,
            1,
            256,
            1,
            1,
            kernel.shared_memory,
            std::ptr::null_mut(),
            kernel_args.as_mut_ptr(),
            std::ptr::null_mut(),
        )
    };
    if result != HIP_SUCCESS {
        release();
        return Err(TrustformersError::hardware_error(
            &format!("Failed to launch ROCm bias activation kernel: {}", result),
            "rocm_bias_activation_launch",
        ));
    }
    // Bug fix: the synchronize result was previously ignored; a failure
    // means the kernel may not have completed and the output is invalid.
    let sync_result = unsafe { (self.hip_lib.hip_device_synchronize)() };
    if sync_result != HIP_SUCCESS {
        release();
        return Err(TrustformersError::hardware_error(
            "ROCm synchronization failed",
            "rocm_bias_activation_launch",
        ));
    }
    // SAFETY: output_gpu is a live allocation large enough for the output.
    let copied = unsafe { self.copy_from_gpu(output_gpu, output) };
    release();
    copied
}
/// Returns HIP C++ source for an elementwise GELU kernel named
/// `fused_gelu_kernel`.
///
/// `approximate` selects the tanh approximation; otherwise the exact
/// erf-based formulation is generated. Both variants export the same
/// symbol name.
fn generate_fused_gelu_kernel(&self, approximate: bool) -> String {
if approximate {
// Tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
r#"
extern "C" __global__ void fused_gelu_kernel(const float* input, float* output, unsigned int n) {
unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float x = input[idx];
float x3 = x * x * x;
float arg = 0.7978845608f * (x + 0.044715f * x3); // sqrt(2/π) ≈ 0.7978845608
float tanh_val = tanhf(arg);
output[idx] = 0.5f * x * (1.0f + tanh_val);
}
}
"#
.to_string()
} else {
// Exact formulation: 0.5*x*(1 + erf(x / sqrt(2))).
r#"
extern "C" __global__ void fused_gelu_kernel(const float* input, float* output, unsigned int n) {
unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float x = input[idx];
float scaled = x * 0.7071067812f; // 1/sqrt(2) ≈ 0.7071067812
float erf_val = erff(scaled);
output[idx] = 0.5f * x * (1.0f + erf_val);
}
}
"#
.to_string()
}
}
/// Returns HIP C++ source for a fused bias-add + activation kernel named
/// `fused_bias_activation_kernel`. The chosen activation expression is
/// baked into the generated source at Rust compile-string time.
fn generate_fused_bias_activation_kernel(&self, activation: &str) -> String {
// Map the activation name to a C expression over `value`. Unknown names
// intentionally fall back to identity (same as "none").
let activation_code = match activation {
"relu" => "fmaxf(value, 0.0f)",
"gelu" => "0.5f * value * (1.0f + tanhf(0.7978845608f * (value + 0.044715f * value * value * value)))",
"silu" => "value / (1.0f + expf(-value))", "tanh" => "tanhf(value)",
"none" => "value",
_ => "value", };
format!(
r#"
extern "C" __global__ void fused_bias_activation_kernel(
const float* input,
const float* bias,
float* output,
unsigned int n,
unsigned int bias_size
) {{
unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {{
float bias_val = bias[idx % bias_size]; // Broadcast bias
float value = input[idx] + bias_val;
output[idx] = {};
}}
}}
"#,
activation_code
)
}
}
#[cfg(not(all(feature = "rocm", target_os = "linux")))]
impl RocmImpl {
// NOTE(review): this fallback exposes only new/global/matmul/device_info;
// the Linux rocm build also provides flash_attention and the fused_*
// methods. Platform-independent callers should go through the `api`
// module, whose fallback covers the full surface.
/// Always fails: ROCm is unavailable on this platform.
pub fn new() -> Result<Self> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"RocmImpl::new",
))
}
/// Always fails: ROCm is unavailable on this platform.
pub fn global() -> Result<&'static Arc<RocmImpl>> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"RocmImpl::global",
))
}
/// Always fails: ROCm is unavailable on this platform.
pub fn matmul(&self, _a: &Tensor, _b: &Tensor, _c: &mut Tensor) -> Result<()> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"RocmImpl::matmul",
))
}
/// Static placeholder string for diagnostics on unsupported platforms.
pub fn device_info(&self) -> String {
"ROCm not available".to_string()
}
}
#[cfg(all(feature = "rocm", target_os = "linux"))]
impl RocmMemoryPool {
    /// Creates an empty pool with zeroed statistics.
    pub fn new() -> Self {
        Self {
            available_blocks: Vec::new(),
            allocated_blocks: HashMap::new(),
            total_allocated: 0,
            peak_memory: 0,
        }
    }
    /// Removes and returns a free block of at least `size` bytes, or
    /// `None` when no block is large enough.
    pub fn get_block(&mut self, size: usize) -> Option<RocmMemoryBlock> {
        // Best-fit: pick the smallest block that satisfies the request.
        // The previous first-fit scan could hand a very large block to a
        // tiny request, wasting most of its capacity.
        let pos = self
            .available_blocks
            .iter()
            .enumerate()
            .filter(|(_, block)| block.size >= size)
            .min_by_key(|(_, block)| block.size)
            .map(|(index, _)| index)?;
        Some(self.available_blocks.remove(pos))
    }
    /// Hands a no-longer-needed block back to the free list for reuse.
    pub fn return_block(&mut self, block: RocmMemoryBlock) {
        self.available_blocks.push(block);
    }
}
#[cfg(all(feature = "rocm", target_os = "linux"))]
/// Free-function facade over the process-wide ROCm instance.
pub mod api {
use super::*;
/// Eagerly initializes the global ROCm backend.
pub fn init_rocm() -> Result<()> {
RocmImpl::global().map(|_| ())
}
/// GPU matrix multiply: `c = a × b`.
pub fn rocm_matmul(a: &Tensor, b: &Tensor, c: &mut Tensor) -> Result<()> {
RocmImpl::global()?.matmul(a, b, c)
}
/// GPU flash attention over (batch, seq_len, head_dim) tensors.
pub fn rocm_flash_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
output: &mut Tensor,
) -> Result<()> {
RocmImpl::global()?.flash_attention(query, key, value, output)
}
/// One-line device description from the global instance.
pub fn rocm_device_info() -> Result<String> {
Ok(RocmImpl::global()?.device_info())
}
/// Memory pool statistics as `(total_allocated, peak)` bytes.
pub fn rocm_memory_stats() -> Result<(usize, usize)> {
Ok(RocmImpl::global()?.memory_stats())
}
/// GPU elementwise GELU; `approximate` selects the tanh variant.
pub fn rocm_fused_gelu(input: &Tensor, output: &mut Tensor, approximate: bool) -> Result<()> {
RocmImpl::global()?.fused_gelu(input, output, approximate)
}
/// GPU fused bias-add + named activation.
pub fn rocm_fused_bias_activation(
input: &Tensor,
bias: &Tensor,
output: &mut Tensor,
activation: &str,
) -> Result<()> {
RocmImpl::global()?.fused_bias_activation(input, bias, output, activation)
}
}
#[cfg(not(all(feature = "rocm", target_os = "linux")))]
/// Fallback facade: every entry point reports that ROCm is unavailable,
/// keeping the public API identical across platforms.
pub mod api {
use super::*;
/// Always fails on this platform.
pub fn init_rocm() -> Result<()> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"rocm_impl::api::init_rocm",
))
}
/// Always fails on this platform.
pub fn rocm_matmul(_a: &Tensor, _b: &Tensor, _c: &mut Tensor) -> Result<()> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"rocm_impl::api::rocm_matmul",
))
}
/// Always fails on this platform.
pub fn rocm_flash_attention(
_query: &Tensor,
_key: &Tensor,
_value: &Tensor,
_output: &mut Tensor,
) -> Result<()> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"rocm_impl::api::rocm_flash_attention",
))
}
/// Always fails on this platform.
pub fn rocm_device_info() -> Result<String> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"rocm_impl::api::rocm_device_info",
))
}
/// Always fails on this platform.
pub fn rocm_memory_stats() -> Result<(usize, usize)> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"rocm_impl::api::rocm_memory_stats",
))
}
/// Always fails on this platform.
pub fn rocm_fused_gelu(
_input: &Tensor,
_output: &mut Tensor,
_approximate: bool,
) -> Result<()> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"rocm_impl::api::rocm_fused_gelu",
))
}
/// Always fails on this platform.
pub fn rocm_fused_bias_activation(
_input: &Tensor,
_bias: &Tensor,
_output: &mut Tensor,
_activation: &str,
) -> Result<()> {
Err(TrustformersError::hardware_error(
"ROCm support is not available on this platform",
"rocm_impl::api::rocm_fused_bias_activation",
))
}
}
// SAFETY: RocmKernel stores only an opaque HIP function handle plus plain
// data; the handle is a process-wide identifier, not thread-local state.
// NOTE(review): this relies on the HIP runtime being thread-safe for
// concurrent launches — confirm before sharing kernels across threads.
unsafe impl Send for RocmKernel {}
unsafe impl Sync for RocmKernel {}
// SAFETY: RocmMemoryBlock holds a raw *device* pointer (never dereferenced
// on the host), so moving/sharing the pointer value across threads is fine.
// NOTE(review): concurrent use of the block itself must still be serialized
// externally (the pool is guarded by a Mutex in RocmImpl).
unsafe impl Send for RocmMemoryBlock {}
unsafe impl Sync for RocmMemoryBlock {}
#[cfg(test)]
mod tests {
use super::*;
#[allow(unused_imports)]
use crate::tensor::Tensor;
// Smoke test: construction must either succeed (ROCm present) or fail
// cleanly — it must never panic.
#[test]
fn test_rocm_initialization() {
match RocmImpl::new() {
Ok(_) => println!("ROCm initialized successfully"),
Err(_) => println!("ROCm not available, using fallback"),
}
}
// Exercises the full matmul path when a device is available; failure is
// tolerated because CI machines typically lack ROCm hardware.
#[test]
#[cfg(all(feature = "rocm", target_os = "linux"))]
fn test_rocm_matmul() {
if let Ok(rocm) = RocmImpl::global() {
let a = Tensor::ones(&[4, 4]).expect("Failed to create ones tensor");
let b = Tensor::ones(&[4, 4]).expect("Failed to create ones tensor");
let mut c = Tensor::zeros(&[4, 4]).expect("Failed to create zero tensor");
let result = rocm.matmul(&a, &b, &mut c);
if result.is_err() {
println!(
"ROCm matmul failed (expected if no ROCm hardware): {:?}",
result
);
}
} else {
println!("ROCm not available for testing");
}
}
// Checks the device_info formatting produced by the global instance.
#[test]
#[cfg(all(feature = "rocm", target_os = "linux"))]
fn test_device_properties() {
if let Ok(rocm) = RocmImpl::global() {
let info = rocm.device_info();
assert!(info.contains("ROCm Device"));
} else {
println!("ROCm not available for testing");
}
}
// A default pool must start with zeroed statistics.
#[test]
fn test_memory_pool() {
let pool = RocmMemoryPool::default();
assert_eq!(pool.total_allocated, 0);
assert_eq!(pool.peak_memory, 0);
}
}