tenflowers-core 0.1.1

Core tensor operations and execution engine for TenfloweRS
Documentation
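// Linear Algebra: Matrix Transpose
//
// Transposes a row-major f32 matrix by staging 16x16 tiles in workgroup
// (shared) memory. Neighbouring invocations read consecutive input addresses
// and write consecutive output addresses, so both global reads and global
// writes are coalesced. This is the tiled shared-memory transpose technique
// common in CUDA libraries such as cuBLAS, adapted to WGSL for wgpu.
//
// Input:  Matrix A [rows_a, cols_a]
// Output: Matrix A^T [cols_a, rows_a]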

// Uniform parameter block for the linear-algebra kernels. This shader binds it
// as `metadata` and reads only rows_a and cols_a.
// All members are 4-byte scalars, so there is no padding between members and
// the struct occupies 32 bytes. Member order and types must stay in sync with
// the host-side struct that fills the uniform buffer.
struct LinalgMetadata {
    rows_a: u32,          // rows of input matrix A
    cols_a: u32,          // columns of input matrix A
    rows_b: u32,          // presumably rows of a second operand B; not read by the transpose kernel — confirm
    cols_b: u32,          // presumably columns of a second operand B; not read by the transpose kernel — confirm
    batch_size: u32,      // NOTE(review): not read by the transpose kernel, which handles one matrix per dispatch
    tolerance: f32,       // presumably a convergence tolerance for iterative kernels; not read here — confirm
    max_iterations: u32,  // presumably an iteration cap for iterative kernels; not read here — confirm
    padding: u32,         // presumably explicit padding to keep the size a 16-byte multiple — confirm
}

// Resource bindings for the transpose kernel. Group/binding numbers and access
// modes must match the bind group layout created on the host.
// Source matrix A, flattened row-major: element (r, c) is at r * cols_a + c.
@group(0) @binding(0) var<storage, read> input: array<f32>;
// Destination A^T, flattened row-major: element (r, c) is at r * rows_a + c.
// Presumably sized by the host to at least rows_a * cols_a elements — confirm.
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
// Kernel parameters; the transpose reads only rows_a and cols_a.
@group(0) @binding(2) var<uniform> metadata: LinalgMetadata;

// Edge length of the square tile handled by one workgroup. Each invocation of
// the 16x16 workgroup moves exactly one element, so this value must equal the
// @workgroup_size dimensions and the row count of `tile` below.
const TILE_DIM: u32 = 16u;

// Workgroup-memory staging tile: TILE_DIM rows x (TILE_DIM + 1) columns.
// The extra column is padding only and is never read or written. The store
// phase reads the tile column-wise (tile[local_id.x][local_id.y]). With an
// unpadded row stride of 16 floats, invocations with neighbouring local_id.x
// would hit the same few memory banks on typical GPU hardware. A stride of 17
// shifts each row by one bank and avoids those conflicts. Results are unchanged.
var<workgroup> tile: array<array<f32, 17>, 16>;

// Transposes `input` (rows_a x cols_a, row-major) into `output`
// (cols_a x rows_a, row-major).
//
// Each workgroup handles one TILE_DIM x TILE_DIM tile of the input:
// workgroup_id.x selects the tile column and workgroup_id.y the tile row.
// The dispatch must therefore cover at least ceil(cols_a / 16) x
// ceil(rows_a / 16) workgroups. Invocations that fall outside the matrix load
// zeros and store nothing.
@compute @workgroup_size(16, 16, 1)
fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
        @builtin(workgroup_id) workgroup_id: vec3<u32>) {

    let rows = metadata.rows_a;
    let cols = metadata.cols_a;

    // Top-left corner of this workgroup's tile in the input matrix.
    let tile_row = workgroup_id.y * TILE_DIM;
    let tile_col = workgroup_id.x * TILE_DIM;

    // Phase 1: cooperative load. Invocation (x, y) copies
    // input[tile_row + y][tile_col + x] into tile[y][x]. Invocations with
    // neighbouring local_id.x read neighbouring input addresses. Out-of-range
    // slots are zero-filled, so every slot read in phase 2 has been written.
    let in_row = tile_row + local_id.y;
    let in_col = tile_col + local_id.x;
    if (in_row < rows && in_col < cols) {
        tile[local_id.y][local_id.x] = input[in_row * cols + in_col];
    } else {
        tile[local_id.y][local_id.x] = 0.0;
    }

    // Phase 2 reads slots written by other invocations, so all loads must
    // complete first. The barrier sits outside any branch so that every
    // invocation reaches it in uniform control flow.
    workgroupBarrier();

    // Phase 2: cooperative store. The input tile at (tile_row, tile_col) maps
    // to the output tile at (tile_col, tile_row). Invocation (x, y) writes
    // output[tile_col + y][tile_row + x] = input[tile_row + x][tile_col + y],
    // which is tile[x][y]. Neighbouring local_id.x again write neighbouring
    // output addresses, so the global stores stay coalesced.
    let out_row = tile_col + local_id.y;
    let out_col = tile_row + local_id.x;
    if (out_row < cols && out_col < rows) {
        output[out_row * rows + out_col] = tile[local_id.x][local_id.y];
    }
}