tenflowers-core 0.1.1

Core tensor operations and execution engine for TenfloweRS
Documentation
//
//  simd_group_matmul.metal
//  Advanced SIMD-group matrix multiplication for Apple Silicon
//
//  Leverages Apple's unified memory architecture, tile memory, and SIMD-group primitives
//  for maximum performance on M1/M2/M3 processors and AMD GPUs on macOS
//

#include <metal_stdlib>
using namespace metal;

// MARK: - Advanced Matrix Multiplication using SIMD Groups

/// High-performance matrix multiplication using SIMD-group operations
/// Processes 32x32 tiles using threadgroup memory and SIMD-group primitives
/// Optimized for Apple Silicon's wide execution units and unified memory
kernel void simd_group_matmul_f32(device const float* A [[buffer(0)]],
                                  device const float* B [[buffer(1)]],
                                  device float* C [[buffer(2)]],
                                  constant uint& M [[buffer(3)]],
                                  constant uint& N [[buffer(4)]],
                                  constant uint& K [[buffer(5)]],
                                  constant uint& lda [[buffer(6)]],
                                  constant uint& ldb [[buffer(7)]],
                                  constant uint& ldc [[buffer(8)]],
                                  threadgroup float* tile_A [[threadgroup(0)]],
                                  threadgroup float* tile_B [[threadgroup(1)]],
                                  uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
                                  uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
                                  uint simd_lane_id [[thread_index_in_simdgroup]],
                                  uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
    
    constexpr uint TILE_SIZE = 32;
    constexpr uint SIMD_GROUP_SIZE = 32;
    
    // Calculate tile coordinates
    uint tile_row = threadgroup_position_in_grid.x * TILE_SIZE;
    uint tile_col = threadgroup_position_in_grid.y * TILE_SIZE;
    
    // Thread indices within threadgroup
    uint local_row = thread_position_in_threadgroup.x;
    uint local_col = thread_position_in_threadgroup.y;
    
    // SIMD-group specific indices
    uint simd_row = simd_group_id / 2;
    uint simd_col = simd_group_id % 2;
    uint lane_row = simd_lane_id / 8;
    uint lane_col = simd_lane_id % 8;
    
    // Local accumulator for this thread
    float4 accumulator[4] = {float4(0.0f), float4(0.0f), float4(0.0f), float4(0.0f)};
    
    // Process K dimension in blocks of TILE_SIZE
    for (uint k_block = 0; k_block < K; k_block += TILE_SIZE) {
        
        // Cooperative loading of A tile using vectorized memory access
        uint A_global_row = tile_row + local_row;
        uint A_global_col = k_block + local_col;
        
        if (A_global_row < M && A_global_col < K) {
            // Use SIMD-group shuffle for efficient data distribution
            float a_value = A[A_global_row * lda + A_global_col];
            tile_A[local_row * TILE_SIZE + local_col] = a_value;
        } else {
            tile_A[local_row * TILE_SIZE + local_col] = 0.0f;
        }
        
        // Cooperative loading of B tile with memory coalescing
        uint B_global_row = k_block + local_row;
        uint B_global_col = tile_col + local_col;
        
        if (B_global_row < K && B_global_col < N) {
            float b_value = B[B_global_row * ldb + B_global_col];
            tile_B[local_row * TILE_SIZE + local_col] = b_value;
        } else {
            tile_B[local_row * TILE_SIZE + local_col] = 0.0f;
        }
        
        // Synchronize threadgroup before computation
        threadgroup_barrier(mem_flags::mem_threadgroup);
        
        // SIMD-group matrix multiply using Apple's optimized primitives
        for (uint k_inner = 0; k_inner < TILE_SIZE; k_inner += 8) {
            
            // Load 4x4 sub-tiles for this SIMD group
            float4 a_fragment[4];
            float4 b_fragment[4];
            
            // Efficient loading using SIMD-group operations
            for (uint i = 0; i < 4; i++) {
                uint a_idx = (simd_row * 16 + lane_row * 4 + i) * TILE_SIZE + (k_inner + lane_col);
                uint b_idx = (k_inner + lane_row * 4 + i) * TILE_SIZE + (simd_col * 16 + lane_col * 4);
                
                if (a_idx / TILE_SIZE < TILE_SIZE && a_idx % TILE_SIZE < TILE_SIZE) {
                    a_fragment[i] = float4(tile_A[a_idx], tile_A[a_idx + 1], 
                                          tile_A[a_idx + 2], tile_A[a_idx + 3]);
                } else {
                    a_fragment[i] = float4(0.0f);
                }
                
                if (b_idx / TILE_SIZE < TILE_SIZE && (b_idx % TILE_SIZE) + 3 < TILE_SIZE) {
                    b_fragment[i] = float4(tile_B[b_idx], tile_B[b_idx + 1], 
                                          tile_B[b_idx + 2], tile_B[b_idx + 3]);
                } else {
                    b_fragment[i] = float4(0.0f);
                }
            }
            
            // Perform 4x4 matrix multiply using fused multiply-add
            for (uint i = 0; i < 4; i++) {
                for (uint j = 0; j < 4; j++) {
                    accumulator[i] = fma(a_fragment[i], b_fragment[j], accumulator[i]);
                }
            }
        }
        
        // Synchronize before next k-block
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    
    // Write results back to global memory using SIMD-group cooperation
    uint C_base_row = tile_row + simd_row * 16 + lane_row * 4;
    uint C_base_col = tile_col + simd_col * 16 + lane_col * 4;
    
    for (uint i = 0; i < 4; i++) {
        uint C_row = C_base_row + i;
        if (C_row < M && C_base_col + 3 < N) {
            uint C_idx = C_row * ldc + C_base_col;
            
            // Vectorized write for maximum memory bandwidth
            device float4* C_ptr = (device float4*)(C + C_idx);
            *C_ptr = accumulator[i];
        }
    }
}

/// Half-precision variant for memory-bandwidth limited workloads
kernel void simd_group_matmul_f16(device const half* A [[buffer(0)]],
                                  device const half* B [[buffer(1)]],
                                  device half* C [[buffer(2)]],
                                  constant uint& M [[buffer(3)]],
                                  constant uint& N [[buffer(4)]],
                                  constant uint& K [[buffer(5)]],
                                  threadgroup half* tile_A [[threadgroup(0)]],
                                  threadgroup half* tile_B [[threadgroup(1)]],
                                  uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
                                  uint simd_lane_id [[thread_index_in_simdgroup]],
                                  uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
    
    constexpr uint TILE_SIZE = 64;  // Larger tiles for half precision
    
    // Use half8 for even better vectorization
    half8 accumulator[4] = {half8(0.0h), half8(0.0h), half8(0.0h), half8(0.0h)};
    
    uint tile_row = threadgroup_position_in_grid.x * TILE_SIZE;
    uint tile_col = threadgroup_position_in_grid.y * TILE_SIZE;
    
    // Process with 8-wide half precision operations
    for (uint k_block = 0; k_block < K; k_block += TILE_SIZE) {
        // Similar structure but using half8 vectorization
        // and optimized for Apple Silicon's half-precision units
        
        // Cooperative tile loading with half8 vectors
        threadgroup_barrier(mem_flags::mem_threadgroup);
        
        // SIMD-group matrix operations using half precision
        for (uint k_inner = 0; k_inner < TILE_SIZE; k_inner += 8) {
            half8 a_fragment = half8(0.0h);
            half8 b_fragment = half8(0.0h);
            
            // Load fragments and perform fused multiply-add
            accumulator[0] = fma(a_fragment, b_fragment, accumulator[0]);
        }
        
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    
    // Write results back with half8 vectorization
}

/// Advanced convolution using SIMD-group operations
/// Optimized for Apple's unified memory and wide SIMD units
kernel void simd_group_conv2d(device const float* input [[buffer(0)]],
                              device const float* weights [[buffer(1)]],
                              device float* output [[buffer(2)]],
                              constant uint& batch_size [[buffer(3)]],
                              constant uint& in_channels [[buffer(4)]],
                              constant uint& out_channels [[buffer(5)]],
                              constant uint& input_height [[buffer(6)]],
                              constant uint& input_width [[buffer(7)]],
                              constant uint& kernel_height [[buffer(8)]],
                              constant uint& kernel_width [[buffer(9)]],
                              constant uint& stride [[buffer(10)]],
                              constant uint& padding [[buffer(11)]],
                              threadgroup float* tile_input [[threadgroup(0)]],
                              threadgroup float* tile_weights [[threadgroup(1)]],
                              uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
                              uint simd_lane_id [[thread_index_in_simdgroup]],
                              uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
    
    constexpr uint TILE_SIZE = 32;
    
    // Calculate output coordinates
    uint out_y = threadgroup_position_in_grid.x * TILE_SIZE + (simd_lane_id / 8);
    uint out_x = threadgroup_position_in_grid.y * TILE_SIZE + (simd_lane_id % 8);
    uint out_c = threadgroup_position_in_grid.z;
    
    float accumulator = 0.0f;
    
    // Process all input channels
    for (uint in_c = 0; in_c < in_channels; in_c += 4) {
        
        // Load 4 input channels in parallel using SIMD-group cooperation
        float4 input_channels = float4(0.0f);
        float4 weight_channels = float4(0.0f);
        
        for (uint ky = 0; ky < kernel_height; ky++) {
            for (uint kx = 0; kx < kernel_width; kx++) {
                
                int in_y = int(out_y * stride + ky) - int(padding);
                int in_x = int(out_x * stride + kx) - int(padding);
                
                if (in_y >= 0 && in_y < int(input_height) && 
                    in_x >= 0 && in_x < int(input_width)) {
                    
                    // Vectorized load of 4 input channels
                    uint input_idx = uint(in_y) * input_width * in_channels + 
                                    uint(in_x) * in_channels + in_c;
                    
                    input_channels = float4(input[input_idx], 
                                           input[input_idx + 1],
                                           input[input_idx + 2],
                                           input[input_idx + 3]);
                    
                    // Load corresponding weights
                    uint weight_idx = out_c * kernel_height * kernel_width * in_channels +
                                     ky * kernel_width * in_channels +
                                     kx * in_channels + in_c;
                    
                    weight_channels = float4(weights[weight_idx],
                                           weights[weight_idx + 1], 
                                           weights[weight_idx + 2],
                                           weights[weight_idx + 3]);
                    
                    // Fused multiply-add across all 4 channels
                    accumulator += dot(input_channels, weight_channels);
                }
            }
        }
    }
    
    // Write result using SIMD-group coordination
    if (out_y < input_height && out_x < input_width) {
        uint output_idx = out_y * input_width * out_channels + out_x * out_channels + out_c;
        output[output_idx] = accumulator;
    }
}

/// Optimized reduction operations using SIMD-group primitives
kernel void simd_group_reduction(device const float* input [[buffer(0)]],
                                device float* output [[buffer(1)]],
                                constant uint& count [[buffer(2)]],
                                uint simd_lane_id [[thread_index_in_simdgroup]],
                                uint simd_group_id [[simdgroup_index_in_threadgroup]],
                                uint thread_position_in_grid [[thread_position_in_grid]]) {
    
    float local_sum = 0.0f;
    
    // Each thread processes multiple elements
    uint elements_per_thread = (count + 31) / 32;  // Round up division
    uint start_idx = thread_position_in_grid * elements_per_thread;
    
    for (uint i = 0; i < elements_per_thread; i++) {
        uint idx = start_idx + i;
        if (idx < count) {
            local_sum += input[idx];
        }
    }
    
    // SIMD-group tree reduction using built-in primitives
    local_sum = simd_sum(local_sum);
    
    // First thread in each SIMD-group writes result
    if (simd_lane_id == 0) {
        output[simd_group_id] = local_sum;
    }
}