use super::context::GpuVisionContext;
use crate::error::{Result, VisionError};
use scirs2_core::gpu::GpuBackend;
use scirs2_core::ndarray::{Array2, ArrayView2};
/// Computes multi-head scaled dot-product attention, on the GPU when possible.
///
/// `hidden_dim` is split into `num_heads` contiguous column groups of
/// `head_dim = hidden_dim / num_heads` columns. Within each head every query row
/// is scored against every key row, the scores are scaled by
/// `1 / sqrt(head_dim)`, softmax-normalised, and used to blend the value rows.
///
/// # Arguments
/// * `ctx` - GPU context used for buffer creation, kernel compilation and dispatch.
/// * `queries`, `keys`, `values` - matrices of identical shape `(seq_len, hidden_dim)`.
/// * `num_heads` - number of attention heads; must be non-zero and divide `hidden_dim`.
///
/// # Errors
/// Returns [`VisionError::InvalidInput`] when the three matrices differ in shape,
/// when `num_heads` is zero, or when `hidden_dim` is not divisible by `num_heads`.
/// Errors raised by the CPU fallback are propagated unchanged.
///
/// # Fallback
/// The CPU implementation (`fallback_multi_head_attention`) is used when no GPU
/// is available, when `seq_len` exceeds the kernels' fixed per-thread scratch
/// size, when the backend is neither WGPU nor CUDA, and when kernel compilation
/// or result readback fails (the last two also log to stderr).
#[allow(dead_code)]
pub fn gpu_multi_head_attention(
    ctx: &GpuVisionContext,
    queries: &ArrayView2<f32>,
    keys: &ArrayView2<f32>,
    values: &ArrayView2<f32>,
    num_heads: usize,
) -> Result<Array2<f32>> {
    // Must stay in sync with the fixed-size per-thread `weights` array declared
    // in both kernels below.
    const MAX_GPU_SEQ_LEN: usize = 512;

    let (seq_len, hidden_dim) = queries.dim();
    if keys.dim() != (seq_len, hidden_dim) || values.dim() != (seq_len, hidden_dim) {
        return Err(VisionError::InvalidInput(
            "Query, key, value dimensions must match".to_string(),
        ));
    }
    // Checked before the modulo below, which would otherwise panic on zero.
    if num_heads == 0 {
        return Err(VisionError::InvalidInput(
            "Number of heads must be greater than zero".to_string(),
        ));
    }
    if hidden_dim % num_heads != 0 {
        return Err(VisionError::InvalidInput(
            "Hidden dimension must be divisible by number of heads".to_string(),
        ));
    }
    // Nothing to compute; also avoids creating zero-sized GPU buffers.
    if seq_len == 0 || hidden_dim == 0 {
        return Ok(Array2::zeros((seq_len, hidden_dim)));
    }
    if !ctx.is_gpu_available() || seq_len > MAX_GPU_SEQ_LEN {
        return fallback_multi_head_attention(queries, keys, values, num_heads);
    }

    // Pick the kernel first so unsupported backends fall back before any
    // host-to-device upload happens.
    let kernel_source = match ctx.backend() {
        GpuBackend::Wgpu => {
            r#"
@group(0) @binding(0) var<storage, read> queries: array<f32>;
@group(0) @binding(1) var<storage, read> keys: array<f32>;
@group(0) @binding(2) var<storage, read> values: array<f32>;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;
@group(0) @binding(4) var<storage, read> params: array<f32>;
@compute @workgroup_size(16, 16)
fn multi_head_attention(@builtin(global_invocation_id) global_id: vec3<u32>) {
let seq_idx = global_id.x;
let head_idx = global_id.y;
let seq_len = u32(params[0]);
let hidden_dim = u32(params[1]);
let num_heads = u32(params[2]);
let head_dim = u32(params[3]);
let scale = params[4];
if (seq_idx >= seq_len || head_idx >= num_heads) {
return;
}
let head_offset = head_idx * head_dim;
// Pass 1: numerically-stable max of the scaled scores.
var max_score = -3.0e38;
for (var k: u32 = 0u; k < seq_len; k = k + 1u) {
var score = 0.0;
for (var d: u32 = 0u; d < head_dim; d = d + 1u) {
let q_idx = seq_idx * hidden_dim + head_offset + d;
let k_idx = k * hidden_dim + head_offset + d;
score = score + queries[q_idx] * keys[k_idx];
}
score = score * scale;
max_score = max(max_score, score);
}
// Pass 2: exp/sum and value accumulation in a single pass over keys.
var sum_exp = 0.0;
var weights: array<f32, 512>;
for (var k: u32 = 0u; k < seq_len; k = k + 1u) {
var score = 0.0;
for (var d: u32 = 0u; d < head_dim; d = d + 1u) {
let q_idx = seq_idx * hidden_dim + head_offset + d;
let k_idx = k * hidden_dim + head_offset + d;
score = score + queries[q_idx] * keys[k_idx];
}
let w = exp(score * scale - max_score);
weights[k] = w;
sum_exp = sum_exp + w;
}
let inv_sum = 1.0 / sum_exp;
for (var d: u32 = 0u; d < head_dim; d = d + 1u) {
var result = 0.0;
for (var k: u32 = 0u; k < seq_len; k = k + 1u) {
let v_idx = k * hidden_dim + head_offset + d;
result = result + weights[k] * inv_sum * values[v_idx];
}
let out_idx = seq_idx * hidden_dim + head_offset + d;
output[out_idx] = result;
}
}
"#
        }
        GpuBackend::Cuda => {
            r#"
extern "C" __global__ void multi_head_attention(
const float* __restrict__ queries,
const float* __restrict__ keys,
const float* __restrict__ values,
float* __restrict__ output,
const float* __restrict__ params
) {
unsigned int seq_len = (unsigned int)params[0];
unsigned int hidden_dim = (unsigned int)params[1];
unsigned int num_heads = (unsigned int)params[2];
unsigned int head_dim = (unsigned int)params[3];
float scale = params[4];
unsigned int seq_idx = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int head_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (seq_idx >= seq_len || head_idx >= num_heads) {
return;
}
unsigned int head_offset = head_idx * head_dim;
float max_score = -3.0e38f;
for (unsigned int k = 0; k < seq_len; ++k) {
float score = 0.0f;
for (unsigned int d = 0; d < head_dim; ++d) {
score += queries[seq_idx * hidden_dim + head_offset + d]
* keys[k * hidden_dim + head_offset + d];
}
score *= scale;
max_score = fmaxf(max_score, score);
}
float sum_exp = 0.0f;
float weights[512];
for (unsigned int k = 0; k < seq_len; ++k) {
float score = 0.0f;
for (unsigned int d = 0; d < head_dim; ++d) {
score += queries[seq_idx * hidden_dim + head_offset + d]
* keys[k * hidden_dim + head_offset + d];
}
float w = expf(score * scale - max_score);
weights[k] = w;
sum_exp += w;
}
float inv_sum = 1.0f / sum_exp;
for (unsigned int d = 0; d < head_dim; ++d) {
float result = 0.0f;
for (unsigned int k = 0; k < seq_len; ++k) {
result += weights[k] * inv_sum
* values[k * hidden_dim + head_offset + d];
}
output[seq_idx * hidden_dim + head_offset + d] = result;
}
}
"#
        }
        _ => {
            return fallback_multi_head_attention(queries, keys, values, num_heads);
        }
    };

    let head_dim = hidden_dim / num_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();

    // `iter()` walks the views in logical row-major order, so the flattened
    // data matches the `row * hidden_dim + column` indexing used by the kernels
    // regardless of the views' memory layout.
    let q_flat: Vec<f32> = queries.iter().copied().collect();
    let k_flat: Vec<f32> = keys.iter().copied().collect();
    let v_flat: Vec<f32> = values.iter().copied().collect();
    // All kernel parameters travel in one f32 buffer; the kernels cast the
    // integral ones back to unsigned integers.
    let params: Vec<f32> = vec![
        seq_len as f32,
        hidden_dim as f32,
        num_heads as f32,
        head_dim as f32,
        scale,
    ];

    let q_buffer = ctx.context.create_buffer_from_slice(&q_flat);
    let k_buffer = ctx.context.create_buffer_from_slice(&k_flat);
    let v_buffer = ctx.context.create_buffer_from_slice(&v_flat);
    let params_buffer = ctx.context.create_buffer_from_slice(&params);
    let output_buffer = ctx.context.create_buffer::<f32>(seq_len * hidden_dim);

    ctx.context.execute(|compiler| match compiler.compile(kernel_source) {
        Ok(kernel_handle) => {
            kernel_handle.set_buffer("queries", &q_buffer);
            kernel_handle.set_buffer("keys", &k_buffer);
            kernel_handle.set_buffer("values", &v_buffer);
            kernel_handle.set_buffer("output", &output_buffer);
            kernel_handle.set_buffer("params", &params_buffer);

            // One invocation per (sequence position, head) pair, grouped 16x16.
            let work_groups_x = seq_len.div_ceil(16);
            let work_groups_y = num_heads.div_ceil(16);
            kernel_handle.dispatch([work_groups_x as u32, work_groups_y as u32, 1]);

            let mut result_flat = vec![0.0f32; seq_len * hidden_dim];
            match output_buffer.copy_to_host(&mut result_flat) {
                Ok(()) => Array2::from_shape_vec((seq_len, hidden_dim), result_flat).map_err(
                    |e| VisionError::Other(format!("Failed to reshape attention output: {e}")),
                ),
                Err(copy_error) => {
                    eprintln!(
                        "GPU attention readback failed: {copy_error}. Using CPU fallback."
                    );
                    fallback_multi_head_attention(queries, keys, values, num_heads)
                }
            }
        }
        Err(compile_error) => {
            eprintln!(
                "GPU multi-head attention kernel compilation failed for backend {:?}: {compile_error}. Using CPU fallback.",
                ctx.backend()
            );
            fallback_multi_head_attention(queries, keys, values, num_heads)
        }
    })
}
/// Multiplies `a` (`m x k`) by `b` (`k x n`) for transformer workloads.
///
/// # Arguments
/// * `ctx` - GPU context used to create buffers and dispatch the kernel.
/// * `a` - left operand of shape `(m, k)`.
/// * `b` - right operand of shape `(k, n)`.
///
/// # Returns
/// The product computed by `crate::simd_ops::simd_matmul_attention_advanced`.
/// When a GPU is available and `m * n * k` is at least `1024 * 1024`, a tiled
/// GLSL compute kernel is dispatched first, but its output buffer is not read
/// back here, so the returned matrix always comes from the SIMD CPU path.
/// NOTE(review): the readback was presumably deferred until `execute_kernel`
/// results can be trusted; confirm before wiring the GPU result in.
///
/// # Errors
/// Returns [`VisionError::InvalidInput`] when the inner dimensions differ, and
/// propagates any error from the SIMD matmul.
#[allow(dead_code)]
pub fn gpu_batch_matmul_transformer(
    ctx: &GpuVisionContext,
    a: &ArrayView2<f32>,
    b: &ArrayView2<f32>,
) -> Result<Array2<f32>> {
    // Below this many multiply-accumulates the GPU dispatch is skipped entirely.
    const MIN_GPU_WORK: usize = 1024 * 1024;

    let (m, k) = a.dim();
    let (k2, n) = b.dim();
    if k != k2 {
        return Err(VisionError::InvalidInput(
            "Matrix dimensions don't match for multiplication".to_string(),
        ));
    }
    // Saturating so that very large shapes cannot wrap around to a small
    // product and be misclassified as tiny workloads.
    let work = m.saturating_mul(n).saturating_mul(k);
    if !ctx.is_gpu_available() || work < MIN_GPU_WORK {
        return crate::simd_ops::simd_matmul_attention_advanced(a, b);
    }

    let a_flat: Vec<f32> = a.iter().copied().collect();
    let b_flat: Vec<f32> = b.iter().copied().collect();
    let a_buffer = ctx.context.create_buffer_from_slice(&a_flat);
    let b_buffer = ctx.context.create_buffer_from_slice(&b_flat);
    let c_buffer = ctx.context.create_buffer::<f32>(m * n);

    // Every invocation of a work-group must take part in the tile loads and
    // reach both `barrier()` calls, so out-of-range invocations are NOT
    // returned early; they load zeros and merely skip the final store.
    let matmul_kernel = r#"
#version 450
layout(local_size_x = 16, local_size_y = 16) in;
layout(set = 0, binding = 0) readonly buffer MatrixA {
float a[];
};
layout(set = 0, binding = 1) readonly buffer MatrixB {
float b[];
};
layout(set = 0, binding = 2) writeonly buffer MatrixC {
float c[];
};
layout(push_constant) uniform PushConstants {
uint M;
uint N;
uint K;
};
shared float a_tile[16][16];
shared float b_tile[16][16];
void main() {
uint row = gl_GlobalInvocationID.x;
uint col = gl_GlobalInvocationID.y;
uint local_row = gl_LocalInvocationID.x;
uint local_col = gl_LocalInvocationID.y;
float result = 0.0;
// Tile-based computation for better cache utilization
for (uint tile = 0; tile < (K + 15) / 16; tile++) {
// Load tile of A into shared memory
uint a_row = row;
uint a_col = tile * 16 + local_col;
if (a_row < M && a_col < K) {
a_tile[local_row][local_col] = a[a_row * K + a_col];
} else {
a_tile[local_row][local_col] = 0.0;
}
// Load tile of B into shared memory
uint b_row = tile * 16 + local_row;
uint b_col = col;
if (b_row < K && b_col < N) {
b_tile[local_row][local_col] = b[b_row * N + b_col];
} else {
b_tile[local_row][local_col] = 0.0;
}
barrier();
// Compute partial result for this tile
for (uint k = 0; k < 16; k++) {
result += a_tile[local_row][k] * b_tile[k][local_col];
}
barrier();
}
// Only in-range invocations own an output element
if (row < M && col < N) {
c[row * N + col] = result;
}
}
"#
    .to_string();

    // The dispatch outcome does not influence the returned value (see the
    // function documentation), so success and failure are handled identically.
    let _ = ctx.context.execute_kernel(
        &matmul_kernel,
        &[a_buffer, b_buffer, c_buffer],
        (
            (m.div_ceil(16) * 16) as u32,
            (n.div_ceil(16) * 16) as u32,
            1,
        ),
        &[m as u32, n as u32, k as u32],
        &[],
    );
    crate::simd_ops::simd_matmul_attention_advanced(a, b)
}
/// Matches each descriptor in `descriptors1` against `descriptors2`.
///
/// Descriptors are stored one per row, so both inputs must have the same number
/// of columns. When a GPU is available and both sets hold at least 100
/// descriptors, a brute-force L2 matching kernel is dispatched; its output
/// buffers are not read back here, so the returned matches always come from
/// `crate::simd_ops::simd_feature_matching_advanced`.
///
/// # Errors
/// Returns [`VisionError::InvalidInput`] when the descriptor widths differ, and
/// propagates any error from the SIMD matcher.
#[allow(dead_code)]
pub fn gpu_feature_matching_advanced(
    ctx: &GpuVisionContext,
    descriptors1: &ArrayView2<f32>,
    descriptors2: &ArrayView2<f32>,
    threshold: f32,
) -> Result<Vec<(usize, usize, f32)>> {
    let (count_a, width_a) = descriptors1.dim();
    let (count_b, width_b) = descriptors2.dim();
    if width_a != width_b {
        return Err(VisionError::InvalidInput(
            "Descriptor dimensions must match".to_string(),
        ));
    }

    // Small sets (or no GPU at all) go straight to the SIMD matcher.
    let gpu_worthwhile = ctx.is_gpu_available() && count_a >= 100 && count_b >= 100;
    if !gpu_worthwhile {
        return crate::simd_ops::simd_feature_matching_advanced(
            descriptors1,
            descriptors2,
            threshold,
        );
    }

    let flat_a: Vec<f32> = descriptors1.iter().cloned().collect();
    let flat_b: Vec<f32> = descriptors2.iter().cloned().collect();
    let buffer_a = ctx.context.create_buffer_from_slice(&flat_a);
    let buffer_b = ctx.context.create_buffer_from_slice(&flat_b);
    // Three slots per query descriptor: query index, best match index, valid flag.
    let match_slots = ctx.context.create_buffer::<f32>(count_a * 3);
    let best_distances = ctx.context.create_buffer::<f32>(count_a);

    let kernel_text = r#"
#version 450
layout(local_size_x = 256) in;
layout(set = 0, binding = 0) readonly buffer Descriptors1 {
float desc1[];
};
layout(set = 0, binding = 1) readonly buffer Descriptors2 {
float desc2[];
};
layout(set = 0, binding = 2) writeonly buffer Matches {
uint matches[];
};
layout(set = 0, binding = 3) writeonly buffer Distances {
float distances[];
};
layout(push_constant) uniform PushConstants {
uint n1;
uint n2;
uint dim;
float threshold;
};
void main() {
uint idx1 = gl_GlobalInvocationID.x;
if (idx1 >= n1) return;
float best_distance = 1e9;
uint best_match = 0;
bool found_match = false;
// Find best match for descriptor idx1
for (uint idx2 = 0; idx2 < n2; idx2++) {
float distance = 0.0;
// Compute L2 distance
for (uint d = 0; d < dim; d++) {
float diff = desc1[idx1 * dim + d] - desc2[idx2 * dim + d];
distance += diff * diff;
}
distance = sqrt(distance);
if (distance < best_distance && distance < threshold) {
best_distance = distance;
best_match = idx2;
found_match = true;
}
}
// Store result
if (found_match) {
matches[idx1 * 3 + 0] = idx1;
matches[idx1 * 3 + 1] = best_match;
matches[idx1 * 3 + 2] = 1; // valid flag
distances[idx1] = best_distance;
} else {
matches[idx1 * 3 + 2] = 0; // invalid flag
distances[idx1] = 1e9;
}
}
"#
    .to_string();

    // Whether the dispatch succeeds or fails, the SIMD matcher produces the
    // value that is returned.
    let _dispatch_outcome = ctx.context.execute_kernel(
        &kernel_text,
        &[buffer_a, buffer_b, match_slots, best_distances],
        ((count_a.div_ceil(256) * 256) as u32, 1, 1),
        &[count_a as u32, count_b as u32, width_a as u32],
        &[threshold],
    );
    crate::simd_ops::simd_feature_matching_advanced(descriptors1, descriptors2, threshold)
}
/// Runs `image` through a sequence of GPU layers and returns the final activations.
///
/// Layers are applied in order; the `i`-th entry of `weights` is paired with
/// the `i`-th entry of `layer_configs`. Pooling and ReLU layers do not read
/// their weight entry, but one must still be supplied so the two slices stay
/// aligned.
///
/// # Arguments
/// * `ctx` - GPU context; a GPU must be available.
/// * `image` - input of shape `(height, width)`.
/// * `weights` - one weight matrix per layer.
/// * `layer_configs` - one configuration per layer.
///
/// # Errors
/// * [`VisionError::Other`] when no GPU is available, when the final buffer
///   length differs from the tracked output shape, or when the output cannot be
///   reshaped.
/// * [`VisionError::InvalidInput`] when `weights` and `layer_configs` differ in
///   length (previously the surplus layers were silently skipped).
/// * Any error returned by a layer or by the buffer readback is propagated.
#[allow(dead_code)]
pub fn gpu_neural_feature_extraction(
    ctx: &GpuVisionContext,
    image: &ArrayView2<f32>,
    weights: &[Array2<f32>],
    layer_configs: &[LayerConfig],
) -> Result<Array2<f32>> {
    if !ctx.is_gpu_available() {
        return Err(VisionError::Other(
            "GPU neural inference requires GPU context".to_string(),
        ));
    }
    // `zip` below stops at the shorter slice, which would drop trailing layers
    // without any diagnostic; reject the mismatch up front instead.
    if weights.len() != layer_configs.len() {
        return Err(VisionError::InvalidInput(format!(
            "Expected one weight matrix per layer, got {} weight matrices for {} layers",
            weights.len(),
            layer_configs.len()
        )));
    }

    let image_flat: Vec<f32> = image.iter().copied().collect();
    let mut current_buffer = ctx.context.create_buffer_from_slice(&image_flat);
    // Spatial shape of the activations held in `current_buffer`.
    let mut current_shape = image.dim();

    for (layer_config, layer_weights) in layer_configs.iter().zip(weights.iter()) {
        match layer_config.layer_type {
            LayerType::Convolution => {
                current_buffer = gpu_conv_layer(
                    ctx,
                    &current_buffer,
                    layer_weights,
                    layer_config,
                    current_shape,
                )?;
                current_shape = compute_conv_outputshape(current_shape, layer_config);
            }
            LayerType::MaxPool => {
                current_buffer =
                    gpu_maxpool_layer(ctx, &current_buffer, layer_config, current_shape)?;
                current_shape = compute_pool_outputshape(current_shape, layer_config);
            }
            LayerType::Dense => {
                current_buffer =
                    gpu_dense_layer(ctx, &current_buffer, layer_weights, layer_config)?;
                current_shape = (layer_config.output_channels, 1);
            }
            LayerType::ReLU => {
                // Element-wise, so the tracked shape is unchanged.
                current_buffer = gpu_relu_layer(ctx, &current_buffer, current_shape)?;
            }
        }
    }

    let result_flat: Vec<f32> = ctx.context.read_buffer(&current_buffer)?;
    let output_size = current_shape.0 * current_shape.1;
    if result_flat.len() != output_size {
        return Err(VisionError::Other(
            "Neural network output size mismatch".to_string(),
        ));
    }
    Array2::from_shape_vec(current_shape, result_flat)
        .map_err(|e| VisionError::Other(format!("Failed to reshape neural output: {e}")))
}
/// Configuration of a single layer processed by `gpu_neural_feature_extraction`.
///
/// Not every field is consulted for every layer type; the per-field notes
/// below list where each value is read in this file.
#[derive(Clone, Debug)]
pub struct LayerConfig {
/// Which operation this layer performs.
pub layer_type: LayerType,
/// Kernel extent; read by `compute_conv_outputshape` for both axes.
pub kernel_size: usize,
/// Step between windows; read by the convolution and pooling shape helpers.
pub stride: usize,
/// Padding counted on both sides of each axis (`2 * padding`) by `compute_conv_outputshape`.
pub padding: usize,
/// Channel count used to size the max-pool output buffer.
pub input_channels: usize,
/// Channel count used to size the convolution and dense output buffers.
pub output_channels: usize,
}
/// Kind of layer selected by a `LayerConfig`.
#[derive(Clone, Debug)]
pub enum LayerType {
/// Handled by `gpu_conv_layer`.
Convolution,
/// Handled by `gpu_maxpool_layer`.
MaxPool,
/// Handled by `gpu_dense_layer`.
Dense,
/// Handled by `gpu_relu_layer`.
ReLU,
}
/// Allocates the output buffer for a convolution layer.
///
/// Only the allocation happens here: `_input` and `_weights` are not read and
/// no kernel is dispatched. The buffer holds
/// `out_h * out_w * config.output_channels` elements, where `(out_h, out_w)`
/// comes from `compute_conv_outputshape`.
#[allow(dead_code)]
fn gpu_conv_layer(
    ctx: &GpuVisionContext,
    _input: &scirs2_core::gpu::GpuBuffer<f32>,
    _weights: &Array2<f32>,
    config: &LayerConfig,
    inputshape: (usize, usize),
) -> Result<scirs2_core::gpu::GpuBuffer<f32>> {
    let (out_h, out_w) = compute_conv_outputshape(inputshape, config);
    let element_count = out_h * out_w * config.output_channels;
    Ok(ctx.context.create_buffer::<f32>(element_count))
}
/// Allocates the output buffer for a max-pooling layer.
///
/// Only the allocation happens here: `_input` is not read and no kernel is
/// dispatched. The buffer holds `out_h * out_w * config.input_channels`
/// elements, where `(out_h, out_w)` comes from `compute_pool_outputshape`.
#[allow(dead_code)]
fn gpu_maxpool_layer(
    ctx: &GpuVisionContext,
    _input: &scirs2_core::gpu::GpuBuffer<f32>,
    config: &LayerConfig,
    inputshape: (usize, usize),
) -> Result<scirs2_core::gpu::GpuBuffer<f32>> {
    let (out_h, out_w) = compute_pool_outputshape(inputshape, config);
    let element_count = out_h * out_w * config.input_channels;
    Ok(ctx.context.create_buffer::<f32>(element_count))
}
/// Allocates the output buffer for a dense (fully connected) layer.
///
/// Only the allocation happens here: `_input` and `_weights` are not read and
/// no kernel is dispatched. The buffer holds `config.output_channels` elements.
#[allow(dead_code)]
fn gpu_dense_layer(
    ctx: &GpuVisionContext,
    _input: &scirs2_core::gpu::GpuBuffer<f32>,
    _weights: &Array2<f32>,
    config: &LayerConfig,
) -> Result<scirs2_core::gpu::GpuBuffer<f32>> {
    let element_count = config.output_channels;
    Ok(ctx.context.create_buffer::<f32>(element_count))
}
/// Allocates the output buffer for a ReLU layer.
///
/// Only the allocation happens here: `_input` is not read and no kernel is
/// dispatched. The buffer holds `rows * cols` elements of `shape`.
#[allow(dead_code)]
fn gpu_relu_layer(
    ctx: &GpuVisionContext,
    _input: &scirs2_core::gpu::GpuBuffer<f32>,
    shape: (usize, usize),
) -> Result<scirs2_core::gpu::GpuBuffer<f32>> {
    let (rows, cols) = shape;
    Ok(ctx.context.create_buffer::<f32>(rows * cols))
}
/// Computes the spatial output shape of a convolution layer.
///
/// Each axis uses `(extent + 2 * padding - kernel_size) / stride + 1`.
///
/// Degenerate configurations yield an extent of `0` instead of panicking:
/// a kernel larger than the padded input (which used to underflow `usize`),
/// a `stride` of zero (which used to divide by zero), or an arithmetic overflow
/// in the padded extent.
///
/// # Arguments
/// * `inputshape` - `(height, width)` of the layer input.
/// * `config` - supplies `padding`, `kernel_size` and `stride`.
///
/// # Returns
/// `(out_height, out_width)`.
#[allow(dead_code)]
fn compute_conv_outputshape(inputshape: (usize, usize), config: &LayerConfig) -> (usize, usize) {
    let out_extent = |extent: usize| -> usize {
        config
            .padding
            .checked_mul(2)
            .and_then(|pad| extent.checked_add(pad))
            // `None` when the kernel does not fit into the padded input.
            .and_then(|padded| padded.checked_sub(config.kernel_size))
            // `None` when `stride` is zero.
            .and_then(|span| span.checked_div(config.stride))
            .and_then(|steps| steps.checked_add(1))
            .unwrap_or(0)
    };
    let (h, w) = inputshape;
    (out_extent(h), out_extent(w))
}
/// Computes the spatial output shape of a pooling layer.
///
/// Each axis is `extent / stride` (integer division); `kernel_size` and
/// `padding` are not consulted. A `stride` of zero yields an extent of `0`
/// instead of panicking with a division by zero.
///
/// # Arguments
/// * `inputshape` - `(height, width)` of the layer input.
/// * `config` - supplies `stride`.
///
/// # Returns
/// `(out_height, out_width)`.
#[allow(dead_code)]
fn compute_pool_outputshape(inputshape: (usize, usize), config: &LayerConfig) -> (usize, usize) {
    let pooled = |extent: usize| extent.checked_div(config.stride).unwrap_or(0);
    let (h, w) = inputshape;
    (pooled(h), pooled(w))
}
/// CPU reference implementation of multi-head scaled dot-product attention.
///
/// The `hidden_dim` columns are processed in `num_heads` contiguous groups of
/// `hidden_dim / num_heads` columns. For each group the query/key scores are
/// computed with `simd_matmul_attention_advanced`, scaled by
/// `1 / sqrt(head_dim)`, softmax-normalised row by row (subtracting the row
/// maximum first), and multiplied with the matching value columns.
///
/// Inputs are not validated here; errors from the SIMD matmul are propagated.
#[allow(dead_code)]
fn fallback_multi_head_attention(
    queries: &ArrayView2<f32>,
    keys: &ArrayView2<f32>,
    values: &ArrayView2<f32>,
    num_heads: usize,
) -> Result<Array2<f32>> {
    use scirs2_core::ndarray::{s, Zip};

    let (seq_len, hidden_dim) = queries.dim();
    let head_dim = hidden_dim / num_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut output = Array2::<f32>::zeros((seq_len, hidden_dim));

    for head in 0..num_heads {
        // Column window owned by this head.
        let lo = head * head_dim;
        let hi = lo + head_dim;
        let q_head = queries.slice(s![.., lo..hi]);
        let k_head = keys.slice(s![.., lo..hi]);
        let v_head = values.slice(s![.., lo..hi]);

        let logits = crate::simd_ops::simd_matmul_attention_advanced(&q_head, &k_head.t())?
            .mapv_into(|x| x * scale);

        // Row-wise softmax, written straight into the weight matrix.
        let mut attention_weights = Array2::<f32>::zeros(logits.dim());
        Zip::from(attention_weights.rows_mut())
            .and(logits.rows())
            .for_each(|mut weights_row, logits_row| {
                let row_max = logits_row
                    .iter()
                    .copied()
                    .fold(f32::NEG_INFINITY, f32::max);
                for (weight, &logit) in weights_row.iter_mut().zip(logits_row.iter()) {
                    *weight = (logit - row_max).exp();
                }
                let denom: f32 = weights_row.iter().sum();
                weights_row.mapv_inplace(|e| e / denom);
            });

        let head_output = crate::simd_ops::simd_matmul_attention_advanced(
            &attention_weights.view(),
            &v_head.view(),
        )?;
        output.slice_mut(s![.., lo..hi]).assign(&head_output);
    }

    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a deterministic `rows x cols` matrix whose entries vary with the
    /// row index, the column index and `seed`.
    fn filled(rows: usize, cols: usize, seed: f32) -> Array2<f32> {
        let sample = |(r, c): (usize, usize)| -> f32 {
            let x = (r as f32 * 0.37 + c as f32 * 0.11 + seed) * 1.3;
            (x.sin() * 0.5) + (x.cos() * 0.25)
        };
        Array2::from_shape_fn((rows, cols), sample)
    }

    /// The public entry point must agree with the CPU reference, whichever
    /// execution path it ends up taking.
    #[test]
    fn test_gpu_multi_head_attention_matches_cpu() {
        const SEQ_LEN: usize = 6;
        const HIDDEN_DIM: usize = 8;
        const NUM_HEADS: usize = 2;

        let Ok(ctx) = GpuVisionContext::new() else {
            eprintln!("Skipping: no GPU context (not even CPU backend) available");
            return;
        };

        let queries = filled(SEQ_LEN, HIDDEN_DIM, 0.0);
        let keys = filled(SEQ_LEN, HIDDEN_DIM, 1.0);
        let values = filled(SEQ_LEN, HIDDEN_DIM, 2.0);
        let (q, k, v) = (queries.view(), keys.view(), values.view());

        let gpu_result = gpu_multi_head_attention(&ctx, &q, &k, &v, NUM_HEADS)
            .expect("GPU multi-head attention should succeed (with fallback)");
        let cpu_reference = fallback_multi_head_attention(&q, &k, &v, NUM_HEADS)
            .expect("CPU reference multi-head attention should succeed");

        assert_eq!(gpu_result.dim(), (SEQ_LEN, HIDDEN_DIM));
        assert_eq!(gpu_result.dim(), cpu_reference.dim());
        gpu_result
            .iter()
            .zip(cpu_reference.iter())
            .for_each(|(g, c)| {
                assert!(
                    (g - c).abs() < 1e-3,
                    "GPU attention output {g} diverged from CPU reference {c} on backend {:?}",
                    ctx.backend()
                );
            });
    }

    /// Shape mismatches and non-divisible head counts must be reported as errors.
    #[test]
    fn test_gpu_multi_head_attention_rejects_bad_shapes() {
        let Ok(ctx) = GpuVisionContext::new() else {
            return;
        };

        // Keys carry one row more than queries and values.
        let queries = filled(4, 8, 0.0);
        let keys = filled(5, 8, 1.0);
        let values = filled(4, 8, 2.0);
        let mismatched =
            gpu_multi_head_attention(&ctx, &queries.view(), &keys.view(), &values.view(), 2);
        assert!(mismatched.is_err(), "mismatched K shape must be rejected");

        // Six hidden columns cannot be split evenly across four heads.
        let q = filled(4, 6, 0.0);
        let indivisible = gpu_multi_head_attention(&ctx, &q.view(), &q.view(), &q.view(), 4);
        assert!(
            indivisible.is_err(),
            "hidden_dim not divisible by num_heads must be rejected"
        );
    }
}