// tenflowers-core 0.1.1
//
// Core tensor operations and execution engine for TenfloweRS.
// Documentation
// Linear Algebra: Singular Value Decomposition (SVD)
//
// Performs SVD using Jacobi rotations adapted for GPU execution.
// For a matrix A [m, n], computes: A = U * S * V^T
// Where: U [m, m] (orthogonal), S [min(m,n)] (diagonal), V^T [n, n] (orthogonal)
//
// This implementation uses the Jacobi method which is suitable for GPU parallelization
// and provides good numerical stability for small to medium matrices.

// Uniform parameter block read by the entry points in this shader.
// Eight 4-byte scalars; the field order is the layout seen by the host.
struct LinalgMetadata {
    // Row count m of the input matrix.
    rows_a: u32,
    // Column count n of the input matrix; also the row stride used when
    // indexing working_matrix and vt_matrix.
    cols_a: u32,
    // Reused by apply_givens_rotation as the pivot row index i.
    rows_b: u32,
    // Reused by apply_givens_rotation as the pivot column index j.
    cols_b: u32,
    // Not read by any entry point visible in this file.
    batch_size: u32,
    // Magnitude threshold below which a pivot / denominator is treated as zero.
    tolerance: f32,
    // Not read by any entry point visible in this file.
    max_iterations: u32,
    // Never read in this file; presumably layout padding - confirm against
    // the host-side struct.
    padding: u32,
}

// Source matrix, indexed as row * cols_a + col by initialize_svd.
@group(0) @binding(0) var<storage, read> input_matrix: array<f32>;
// U factor, indexed as row * rows_a + col (rows_a x rows_a entries).
@group(0) @binding(1) var<storage, read_write> u_matrix: array<f32>;
// Singular values; the first min(rows_a, cols_a) entries are written.
@group(0) @binding(2) var<storage, read_write> s_values: array<f32>;
// Accumulated right rotations, indexed as row * cols_a + col
// (cols_a x cols_a entries).
@group(0) @binding(3) var<storage, read_write> vt_matrix: array<f32>;
// Scratch copy of the matrix that the rotation passes modify in place,
// indexed with row stride cols_a.
@group(0) @binding(4) var<storage, read_write> working_matrix: array<f32>;
// Dimensions, pivot indices and tolerance for the current dispatch.
@group(0) @binding(5) var<uniform> metadata: LinalgMetadata;

// Seeds every SVD buffer before the rotation passes run.
//
// Each invocation owns one (row, col) coordinate taken from
// global_invocation_id (y = row, x = col) and performs up to four
// independent, individually bounds-checked writes:
//   working_matrix <- input_matrix   (m x n, row stride n)
//   u_matrix       <- identity       (m x m, row stride m)
//   vt_matrix      <- identity       (n x n, row stride n)
//   s_values       <- 0.0            (first min(m, n) entries, row 0 only)
// where m = metadata.rows_a and n = metadata.cols_a.
@compute @workgroup_size(16, 16, 1)
fn initialize_svd(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let num_rows = metadata.rows_a;
    let num_cols = metadata.cols_a;
    let r = global_id.y;
    let c = global_id.x;

    // 1.0 on the main diagonal, 0.0 everywhere else.
    let identity_entry: f32 = select(0.0, 1.0, r == c);

    let row_in_m = r < num_rows;
    let col_in_m = c < num_rows;
    let row_in_n = r < num_cols;
    let col_in_n = c < num_cols;

    // Copy the input into the scratch matrix.
    if (row_in_m && col_in_n) {
        let flat = r * num_cols + c;
        working_matrix[flat] = input_matrix[flat];
    }

    // U starts as the m x m identity.
    if (row_in_m && col_in_m) {
        u_matrix[r * num_rows + c] = identity_entry;
    }

    // V^T starts as the n x n identity.
    if (row_in_n && col_in_n) {
        vt_matrix[r * num_cols + c] = identity_entry;
    }

    // Only the first row of invocations clears the singular values.
    if (r == 0u && c < min(num_rows, num_cols)) {
        s_values[c] = 0.0;
    }
}

// Computes the (cos, sin) pair of a plane rotation from the ratio a / b.
//
// With tau = a / b the rotation tangent is
//     t = sign(tau) / (|tau| + sqrt(1 + tau^2))      (tau == 0 uses +1),
// the smaller-magnitude root of t^2 + 2*tau*t - 1 = 0.
//
// Returns vec2(cos, sin) with cos = 1 / sqrt(1 + t^2) and sin = t * cos, so
// cos^2 + sin^2 == 1. When |b| < metadata.tolerance the identity rotation
// (1, 0) is returned, which is also the limit of the formula as b -> 0.
//
// Fix: the previous version returned cos = t / sqrt(1 + t^2), which gives
// cos^2 + sin^2 == t^2 (not a rotation) and disagrees with both the
// early-return value and the cos/sin construction in apply_givens_rotation.
fn compute_givens_rotation(a: f32, b: f32) -> vec2<f32> {
    if (abs(b) < metadata.tolerance) {
        return vec2<f32>(1.0, 0.0); // cos, sin
    }

    let tau = a / b;
    let t_abs = 1.0 / (abs(tau) + sqrt(1.0 + tau * tau));

    // Carry the sign of tau onto the tangent; tau == 0 keeps the positive root.
    let t = select(t_abs, -t_abs, tau < 0.0);

    let cos_val = 1.0 / sqrt(1.0 + t * t);
    let sin_val = t * cos_val;
    return vec2<f32>(cos_val, sin_val);
}

// Applies one Jacobi rotation in the (i, j) plane to working_matrix and
// accumulates the same rotation into vt_matrix.
//
// Inputs (all from `metadata`):
//   cols_a          n, the dimension / row stride of working_matrix and vt_matrix
//   rows_b, cols_b  pivot indices i and j; the pass is a no-op unless i < j < n
//   tolerance       pivots with |a_ij| below this are treated as converged
//
// Work split: invocation k (global_invocation_id.x) updates column k of rows
// i and j (mirroring the result into rows k, columns i and j) and row k of
// vt_matrix; invocation 0 additionally rewrites the 2x2 pivot block.
//
// The update treats working_matrix as symmetric (only a_ij is read and the
// mirrored entries are overwritten) - presumably the host supplies a
// symmetric matrix such as A^T A; confirm against the caller.
//
// Changes relative to the previous version:
//   * `let t = if (...) {...} else {...};` is not valid WGSL (there are no
//     if-expressions); replaced with a `var` assigned in an if-statement.
//   * The tangent formula lacked the factors 2 and 4. The rotation that
//     annihilates a_ij satisfies t^2 + 2*theta*t - 1 = 0 with
//     theta = (a_jj - a_ii) / (2*a_ij), whose stable root is
//     t = 2*a_ij / (diff + sgn(diff) * sqrt(diff^2 + 4*a_ij^2)).
//   * sign(diff) returns 0 for diff == 0, which produced a zero denominator
//     when tolerance == 0; a +/-1 select is used instead, and an exactly
//     zero pivot is skipped.
//   * Unused read of a_ji removed.
//   * storageBarrier() separates the pivot reads from the pivot writes.
//
// NOTE(review): the barrier only orders invocations inside one workgroup.
// With n > 256 (more than one workgroup) invocation 0 may still overwrite the
// pivot block before another workgroup has read it; a fully race-free
// version needs the rotation parameters computed in a separate dispatch.
@compute @workgroup_size(256, 1, 1)
fn apply_givens_rotation(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let n = metadata.cols_a;
    let thread_id = global_id.x;

    // Get the target (i,j) pair from metadata
    let i = metadata.rows_b; // Reusing metadata fields for i,j indices
    let j = metadata.cols_b;

    // Uniform across the dispatch, so the barrier below stays in uniform
    // control flow.
    if (i >= j || i >= n || j >= n) {
        return;
    }

    // Snapshot the pivot block before any invocation may modify it.
    let a_ii = working_matrix[i * n + i];
    let a_jj = working_matrix[j * n + j];
    let off_diag = working_matrix[i * n + j]; // symmetric: a_ij == a_ji

    // Every invocation of this workgroup has read the pivot block once it
    // passes this point.
    storageBarrier();

    if (abs(off_diag) < metadata.tolerance || off_diag == 0.0) {
        return; // Already converged
    }

    // Tangent of the rotation angle.
    let diff = a_jj - a_ii;
    var t: f32;
    if (abs(diff) < metadata.tolerance) {
        // Equal diagonal entries: 45-degree rotation, signed by the pivot.
        t = select(-1.0, 1.0, off_diag > 0.0);
    } else {
        // +/-1 only, never 0, so the denominator cannot vanish.
        let sign_diff = select(-1.0, 1.0, diff > 0.0);
        t = 2.0 * off_diag
            / (diff + sign_diff * sqrt(diff * diff + 4.0 * off_diag * off_diag));
    }

    let cos_theta = 1.0 / sqrt(1.0 + t * t);
    let sin_theta = t * cos_theta;

    // Rotate rows i and j at column k and mirror into column i / j of row k.
    if (thread_id < n) {
        let k = thread_id;

        if (k != i && k != j) {
            let a_ik = working_matrix[i * n + k];
            let a_jk = working_matrix[j * n + k];

            let new_ik = cos_theta * a_ik - sin_theta * a_jk;
            let new_jk = sin_theta * a_ik + cos_theta * a_jk;

            working_matrix[i * n + k] = new_ik;
            working_matrix[j * n + k] = new_jk;

            // Keep the matrix symmetric.
            working_matrix[k * n + i] = new_ik;
            working_matrix[k * n + j] = new_jk;
        }
    }

    // A single invocation rewrites the 2x2 pivot block.
    if (thread_id == 0u) {
        let a_ii_new = cos_theta * cos_theta * a_ii - 2.0 * sin_theta * cos_theta * off_diag + sin_theta * sin_theta * a_jj;
        let a_jj_new = sin_theta * sin_theta * a_ii + 2.0 * sin_theta * cos_theta * off_diag + cos_theta * cos_theta * a_jj;

        working_matrix[i * n + i] = a_ii_new;
        working_matrix[j * n + j] = a_jj_new;
        working_matrix[i * n + j] = 0.0;
        working_matrix[j * n + i] = 0.0;
    }

    // Accumulate the rotation into columns i and j of vt_matrix, one row per
    // invocation.
    if (thread_id < n) {
        let k = thread_id;
        let v_ki = vt_matrix[k * n + i];
        let v_kj = vt_matrix[k * n + j];

        vt_matrix[k * n + i] = cos_theta * v_ki - sin_theta * v_kj;
        vt_matrix[k * n + j] = sin_theta * v_ki + cos_theta * v_kj;
    }
}

// Writes s_values[k] = sqrt(|d_k|), where d_k is the k-th diagonal entry of
// working_matrix (row stride metadata.cols_a).
//
// One invocation per k (global_invocation_id.x); invocations with
// k >= min(rows_a, cols_a) do nothing.
// NOTE(review): taking the square root presumes the diagonal holds squared
// singular values (e.g. eigenvalues of A^T A) - confirm against the host.
@compute @workgroup_size(256, 1, 1)
fn extract_singular_values(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let k = global_id.x;
    let count = min(metadata.rows_a, metadata.cols_a);

    if (k < count) {
        let stride = metadata.cols_a;
        let diag_entry = working_matrix[k * stride + k];
        s_values[k] = sqrt(abs(diag_entry));
    }
}

// Placeholder for the bidiagonal reduction of rectangular inputs.
//
// Currently a no-op: after the bounds check nothing is read from or written
// to any buffer. A real implementation would apply Householder reflections;
// until then the pipeline relies on the input being square (m == n) or
// already in a suitable form.
@compute @workgroup_size(16, 16, 1)
fn bidiagonalize(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row_count = metadata.rows_a;
    let col_count = metadata.cols_a;

    let inside = global_id.y < row_count && global_id.x < col_count;
    if (!inside) {
        return;
    }

    // Intentionally empty - see the header comment.
}

// Fills u_matrix from vt_matrix, one element per invocation
// (global_invocation_id.y = row, .x = col).
//
// Only the square case (rows_a == cols_a) is handled: the element at
// (row, col) of vt_matrix is copied to the same position of u_matrix.
// For rectangular inputs nothing is written; a complete implementation would
// need the intermediate transformations from the bidiagonalization step.
@compute @workgroup_size(16, 16, 1)
fn compute_u_matrix(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let num_rows = metadata.rows_a;
    let num_cols = metadata.cols_a;
    let r = global_id.y;
    let c = global_id.x;

    // Outside the m x m extent of U.
    if (r >= num_rows || c >= num_rows) {
        return;
    }

    let is_square = num_rows == num_cols;
    let col_has_singular_value = c < min(num_rows, num_cols);

    if (col_has_singular_value && is_square) {
        u_matrix[r * num_rows + c] = vt_matrix[r * num_cols + c];
    }
}