// tenflowers-core 0.1.1
//
// Core tensor operations and execution engine for TenfloweRS.
// Documentation
// Linear Algebra: Matrix Determinant Computation
//
// Computes the determinant of a square matrix using Gaussian elimination
// with partial pivoting. This is optimized for GPU execution with proper
// synchronization using multiple kernel dispatches.
//
// Input:  Matrix A [n, n]
// Output: Determinant value (scalar)

// Parameter block bound as the `metadata` uniform.
//
// The kernels in this file read only `rows_a` (used as the matrix dimension n)
// and `tolerance` (the singularity threshold in elimination_step). The other
// fields are not referenced here.
// NOTE(review): the field names suggest this struct is shared with other
// linear-algebra shaders; confirm against the host code before changing the
// field order or types.
struct LinalgMetadata {
    // Used as n, the dimension of the square matrix, by every kernel here.
    rows_a: u32,
    // Not read in this file.
    cols_a: u32,
    // Not read in this file.
    rows_b: u32,
    // Not read in this file.
    cols_b: u32,
    // Not read in this file.
    batch_size: u32,
    // Pivots whose absolute value is below this are treated as singular.
    tolerance: f32,
    // Not read in this file.
    max_iterations: u32,
    // Not read in this file.
    padding: u32,
}

// Working matrix, indexed as matrix[row * n + col] with n = metadata.rows_a.
// Modified in place by swap_rows and elimination_step.
@group(0) @binding(0) var<storage, read_write> matrix: array<f32>;
// Only element 0 is used: it carries the sign from row swaps during
// elimination and receives the final determinant from compute_determinant.
@group(0) @binding(1) var<storage, read_write> determinant: array<f32>;
// pivot_info[0]: current column index k (only read by the kernels in this file).
// pivot_info[1]: pivot row index, written by find_pivot and read by swap_rows.
@group(0) @binding(2) var<storage, read_write> pivot_info: array<u32>;
// Matrix dimension and singularity tolerance (see LinalgMetadata).
@group(0) @binding(3) var<uniform> metadata: LinalgMetadata;

// Placeholder kernel for copying the input matrix to the working matrix.
//
// The elimination kernels operate on `matrix` in place, so there is nothing
// to copy: this entry point only performs the bounds check and returns. It
// does not read or write any storage buffer.
@compute @workgroup_size(16, 16, 1)
fn copy_matrix(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let n = metadata.rows_a;
    let row = global_id.y;
    let col = global_id.x;

    // Invocations outside the n x n matrix have nothing to do.
    if (row >= n || col >= n) {
        return;
    }

    // Intentionally empty: the matrix is already in place, and this kernel
    // does not initialize determinant[0].
}

// Kernel for finding the pivot row for column k (partial pivoting).
//
// Reads:  pivot_info[0] - current column index k
// Writes: pivot_info[1] - the row in [k, n) whose entry in column k has the
//         largest absolute value (the first such row on ties)
//
// The search is a serial scan, so only the invocation with global_id.x == 0
// does any work; all others return immediately. A parallel reduction would
// be needed to make use of more invocations.
@compute @workgroup_size(256, 1, 1)
fn find_pivot(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let n = metadata.rows_a;
    let k = pivot_info[0]; // Current column index

    // Nothing to do when k is out of range (this also covers n == 0).
    if (global_id.x != 0u || k >= n) {
        return;
    }

    // Start from the diagonal entry and scan the rows below it.
    var max_val = abs(matrix[k * n + k]);
    var max_row = k;

    for (var i = k + 1u; i < n; i = i + 1u) {
        let val = abs(matrix[i * n + k]);
        // Strict comparison keeps the earliest row when magnitudes are equal.
        if (val > max_val) {
            max_val = val;
            max_row = i;
        }
    }

    // Store pivot row index for swap_rows
    pivot_info[1] = max_row;
}

// Kernel that exchanges row k with the selected pivot row.
//
// One invocation per column: invocation `global_id.x` swaps the two entries
// of its own column. Reads pivot_info[0] (current column k) and pivot_info[1]
// (pivot row). When a swap happens, the invocation for column 0 also negates
// determinant[0], since exchanging two rows flips the determinant's sign.
@compute @workgroup_size(256, 1, 1)
fn swap_rows(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let dim = metadata.rows_a;
    let column = global_id.x;
    let current_row = pivot_info[0];
    let chosen_row = pivot_info[1];

    // Outside the matrix, or no valid elimination step in progress.
    if (column >= dim || current_row >= dim) {
        return;
    }

    // The pivot is already on the diagonal: no swap and no sign change.
    if (current_row == chosen_row) {
        return;
    }

    let upper = current_row * dim + column;
    let lower = chosen_row * dim + column;

    let saved = matrix[upper];
    matrix[upper] = matrix[lower];
    matrix[lower] = saved;

    // Exactly one invocation records the sign flip.
    if (column == 0u) {
        determinant[0] = -determinant[0];
    }
}

// Kernel for one Gaussian elimination step on column k = pivot_info[0].
//
// One invocation per matrix element (global_id.y = row, global_id.x = col).
// For every element strictly below and strictly to the right of the pivot:
//
//     matrix[row][col] -= (matrix[row][k] / matrix[k][k]) * matrix[k][col]
//
// The entries below the pivot in column k are read as elimination factors by
// every invocation of their row, so this kernel must not write them: there
// is no ordering between invocations of one dispatch, and overwriting a
// factor while other invocations still read it is a data race. Those entries
// therefore keep their pre-elimination values. Nothing in this file reads
// them afterwards: find_pivot, swap-independent elimination steps for later
// columns, and compute_determinant only touch columns >= their own k or the
// diagonal.
//
// If |pivot| < metadata.tolerance the step is skipped and determinant[0] is
// set to 0.0 by the invocation at (row k + 1, column k).
// NOTE(review): for k == n - 1 no such invocation exists, so a last pivot
// below tolerance is not flagged - confirm whether that is intended.
@compute @workgroup_size(16, 16, 1)
fn elimination_step(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let n = metadata.rows_a;
    let row = global_id.y;
    let col = global_id.x;
    let k = pivot_info[0]; // Current column index

    if (row >= n || col >= n || k >= n) {
        return;
    }

    // Only rows below the pivot and columns from the pivot column rightwards
    // take part in this step.
    if (row <= k || col < k) {
        return;
    }

    let pivot_val = matrix[k * n + k];

    // Check for singularity
    if (abs(pivot_val) < metadata.tolerance) {
        // A single invocation records the singular result.
        if (row == k + 1u && col == k) {
            determinant[0] = 0.0;
        }
        return;
    }

    // The factor column itself is left untouched (see the header comment):
    // it is shared read-only input for the rest of the row in this dispatch.
    if (col == k) {
        return;
    }

    // Compute elimination factor; matrix[row][k] is not written by any
    // invocation of this dispatch, so every invocation of the row sees the
    // same value.
    let factor = matrix[row * n + k] / pivot_val;

    // Eliminate element. Row k is read-only in this step (only rows > k are
    // written), and each invocation writes only its own element.
    let curr_idx = row * n + col;
    let pivot_col_idx = k * n + col;
    matrix[curr_idx] = matrix[curr_idx] - factor * matrix[pivot_col_idx];
}

// Kernel that folds the diagonal into the final determinant.
//
// Runs serially on the invocation with global_id.x == 0; every other
// invocation returns without touching memory. The running product starts
// from determinant[0] (the sign accumulated by row swaps, or 0.0 if a
// singular pivot was flagged) and is multiplied by matrix[i][i] for
// i = 0 .. n-1 in increasing order. The result is written to determinant[0].
@compute @workgroup_size(256, 1, 1)
fn compute_determinant(@builtin(global_invocation_id) global_id: vec3<u32>) {
    if (global_id.x != 0u) {
        return;
    }

    let dim = metadata.rows_a;
    // Consecutive diagonal entries are dim + 1 elements apart:
    // i * dim + i == i * (dim + 1).
    let diag_stride = dim + 1u;

    var product = determinant[0];
    for (var d = 0u; d < dim; d++) {
        product *= matrix[d * diag_stride];
    }

    determinant[0] = product;
}

// Initialization kernel: resets the determinant accumulator to 1.0.
//
// This entry point performs no elimination itself. The pivot search, row
// swap and elimination steps need synchronization between them, so the host
// issues them as separate kernel dispatches.
@compute @workgroup_size(1, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // The value is unused, but the read keeps `metadata` among the bindings
    // this entry point statically accesses.
    // NOTE(review): this matters only if the pipeline layout is derived from
    // the shader - confirm before removing.
    let dim = metadata.rows_a;

    if (global_id.x != 0u) {
        return;
    }

    // Neutral element for the product; swap_rows flips its sign per swap.
    determinant[0] = 1.0;
}