tenflowers-core 0.1.1

Core tensor operations and execution engine for TenfloweRS
Documentation
// Linear Algebra: Symmetric Eigendecomposition using Jacobi (Givens) Rotations
//
// Computes eigenvalues and eigenvectors of a symmetric matrix with the Jacobi method.
// For a symmetric matrix A [n, n], computes eigenvalues λ and eigenvectors V such that:
// A * V = V * Λ (where Λ is the diagonal matrix of eigenvalues)
//
// Each step applies a Givens (plane) rotation J on both sides, A <- Jᵀ A J, to zero one
// off-diagonal pair (i, j), and accumulates it into the eigenvectors, V <- V J.
// The per-row rotation updates are suitable for GPU parallelization, and orthogonal
// rotations provide good numerical stability for symmetric matrices.

// Uniform parameters shared by the eigen kernels in this file (bound at
// @group(0) @binding(5) as `metadata`). Each field's use below is as read
// by the kernels visible here.
struct LinalgMetadata {
    // Matrix dimension n; every kernel treats the matrices as n x n.
    rows_a: u32,
    // Not read by the kernels in this file.
    cols_a: u32,
    // apply_givens_eigen: pivot row index i (field reused for this purpose).
    rows_b: u32,
    // apply_givens_eigen: pivot column index j; the rotation requires i < j < n.
    cols_b: u32,
    // Not read by the kernels in this file.
    batch_size: u32,
    // Threshold below which off-diagonal entries / column norms are treated as zero.
    tolerance: f32,
    // Not read by the kernels in this file; presumably a host-side iteration cap — confirm.
    max_iterations: u32,
    // NOTE(review): presumably pads the struct to a 16-byte multiple for uniform layout — confirm.
    padding: u32,
}

// Resources shared by every entry point in this file (bind group 0).
// Matrices are flat row-major arrays: element (r, c) of an n x n matrix is at
// index r * n + c, with n = metadata.rows_a.

// Symmetric input matrix A; declared read-only.
@group(0) @binding(0) var<storage, read> input_matrix: array<f32>;
// n eigenvalues: zeroed by initialize_eigen, filled by extract_eigenvalues, reordered by sort_eigenvalues.
@group(0) @binding(1) var<storage, read_write> eigenvalues: array<f32>;
// n x n eigenvector matrix V, one eigenvector per column; initialized to the identity.
@group(0) @binding(2) var<storage, read_write> eigenvectors: array<f32>;
// n x n copy of A that the rotations drive toward diagonal form.
@group(0) @binding(3) var<storage, read_write> working_matrix: array<f32>;
// Scratch: set to the identity by initialize_eigen; power_iteration writes its length-n result into the first n entries.
@group(0) @binding(4) var<storage, read_write> q_matrix: array<f32>;
// Dimensions, pivot indices and tolerance (see LinalgMetadata).
@group(0) @binding(5) var<uniform> metadata: LinalgMetadata;

// Seed the Jacobi state for an n x n problem (n = metadata.rows_a):
//   working_matrix <- input_matrix
//   eigenvectors   <- identity
//   q_matrix       <- identity
//   eigenvalues    <- 0
// One invocation per matrix entry: global_id.x selects the column and
// global_id.y the row; invocations outside the matrix do nothing.
@compute @workgroup_size(16, 16, 1)
fn initialize_eigen(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let dim = metadata.rows_a;
    let c = global_id.x;
    let r = global_id.y;

    if (r < dim && c < dim) {
        let flat_idx = r * dim + c;

        working_matrix[flat_idx] = input_matrix[flat_idx];

        // 1 on the diagonal, 0 elsewhere.
        let identity_entry = select(0.0, 1.0, r == c);
        eigenvectors[flat_idx] = identity_entry;
        q_matrix[flat_idx] = identity_entry;

        // The first row of invocations clears one eigenvalue slot each.
        if (r == 0u) {
            eigenvalues[c] = 0.0;
        }
    }
}

// Check convergence by examining one strictly-upper-triangular entry of
// working_matrix per invocation.
//
// Invocation t handles the t-th pair (i, j), i < j, counted row by row over
// the upper triangle of the n x n matrix (n = metadata.rows_a). Invocations
// past the last pair do nothing.
//
// The pair is located by walking whole rows (O(n) per invocation) instead of
// enumerating every preceding pair (O(n^2) per invocation). The walk also
// never forms n * (n - 1), which overflows u32 once n exceeds 65536.
//
// NOTE: no binding exists for a convergence flag, so the comparison against
// metadata.tolerance currently has no observable effect.
@compute @workgroup_size(256, 1, 1)
fn check_convergence(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let n = metadata.rows_a;
    var remaining = global_id.x;

    // Row i of the upper triangle holds n - 1 - i entries (columns i+1 .. n-1).
    // `i + 1u < n` (rather than `i < n - 1u`) also handles n == 0 without underflow.
    for (var i = 0u; i + 1u < n; i = i + 1u) {
        let row_len = n - 1u - i;
        if (remaining < row_len) {
            let j = i + 1u + remaining;
            let off_diag = abs(working_matrix[i * n + j]);
            if (off_diag < metadata.tolerance) {
                // TODO: record convergence once an output binding is available.
            }
            return;
        }
        remaining = remaining - row_len;
    }
}

// Apply one Jacobi (two-sided Givens) rotation that zeroes the off-diagonal
// pair (i, j) of the symmetric working matrix. Here i = metadata.rows_b and
// j = metadata.cols_b (reused fields; requires i < j < n).
//
// With c = cos(theta), s = sin(theta) and the rotation J given by
// J[i,i] = J[j,j] = c, J[i,j] = s, J[j,i] = -s, this computes
//   working_matrix <- J^T * working_matrix * J
//   eigenvectors   <- eigenvectors * J
// The new (i, j) entry is (c^2 - s^2) * a_ij + c * s * (a_ii - a_jj). It is
// zero exactly when t = s / c solves
//   t^2 + 2 * tau * t - 1 = 0,   tau = (a_jj - a_ii) / (2 * a_ij).
// The smaller-magnitude root is taken, in cancellation-free form:
//   t = 2 * a_ij / (d + sign(d) * sqrt(d^2 + 4 * a_ij^2)),   d = a_jj - a_ii.
// When d == 0 this gives t = sign(a_ij), i.e. a 45-degree rotation.
//
// Synchronization: every invocation must read the pivot (a_ii, a_jj, a_ij)
// before invocation 0 overwrites it, and WGSL can enforce that ordering only
// within a single workgroup. Workgroup 0 therefore performs the entire
// rotation, with its invocations striding over k; any additional workgroups
// in the host dispatch exit immediately.
@compute @workgroup_size(256, 1, 1)
fn apply_givens_eigen(@builtin(local_invocation_id) local_id: vec3<u32>, @builtin(workgroup_id) group_id: vec3<u32>) {
    let n = metadata.rows_a;
    let i = metadata.rows_b; // Reusing metadata fields for i,j indices
    let j = metadata.cols_b;

    // All operands are uniform (workgroup_id and a uniform buffer), so the
    // barrier below stays in uniform control flow as WGSL requires.
    // i < j < n also implies i < n.
    if (group_id.x != 0u || i >= j || j >= n) {
        return;
    }

    // Snapshot the pivot entries before anything is overwritten.
    let a_ii = working_matrix[i * n + i];
    let a_jj = working_matrix[j * n + j];
    let a_ij = working_matrix[i * n + j];

    // Every invocation holds its snapshot before invocation 0 writes the pivot below.
    storageBarrier();

    // Already (numerically) zero. The exact-zero test also avoids 0/0 when tolerance <= 0.
    if (abs(a_ij) < metadata.tolerance || a_ij == 0.0) {
        return;
    }

    // Compute rotation parameters for the symmetric case (see header).
    let diff = a_jj - a_ii;
    let diff_sign = select(-1.0, 1.0, diff >= 0.0);
    let t = (2.0 * a_ij) / (diff + diff_sign * sqrt(diff * diff + 4.0 * a_ij * a_ij));
    let cos_theta = 1.0 / sqrt(1.0 + t * t);
    let sin_theta = t * cos_theta;

    // The stride must equal the x workgroup size declared above.
    for (var k = local_id.x; k < n; k = k + 256u) {
        if (k != i && k != j) {
            // Rotate entries (k, i) and (k, j), then mirror them into
            // (i, k) and (j, k) to keep the matrix symmetric.
            let a_ki = working_matrix[k * n + i];
            let a_kj = working_matrix[k * n + j];
            let new_ki = cos_theta * a_ki - sin_theta * a_kj;
            let new_kj = sin_theta * a_ki + cos_theta * a_kj;

            working_matrix[k * n + i] = new_ki;
            working_matrix[k * n + j] = new_kj;
            working_matrix[i * n + k] = new_ki;
            working_matrix[j * n + k] = new_kj;
        }

        // Accumulate eigenvectors: V = V * J (columns i and j of row k).
        let v_ki = eigenvectors[k * n + i];
        let v_kj = eigenvectors[k * n + j];
        eigenvectors[k * n + i] = cos_theta * v_ki - sin_theta * v_kj;
        eigenvectors[k * n + j] = sin_theta * v_ki + cos_theta * v_kj;
    }

    // The 2x2 pivot block is written only by invocation 0, using the snapshot.
    if (local_id.x == 0u) {
        let a_ii_new = cos_theta * cos_theta * a_ii - 2.0 * sin_theta * cos_theta * a_ij + sin_theta * sin_theta * a_jj;
        let a_jj_new = sin_theta * sin_theta * a_ii + 2.0 * sin_theta * cos_theta * a_ij + cos_theta * cos_theta * a_jj;

        working_matrix[i * n + i] = a_ii_new;
        working_matrix[j * n + j] = a_jj_new;
        working_matrix[i * n + j] = 0.0;
        working_matrix[j * n + i] = 0.0;
    }
}

// Copy the diagonal of working_matrix into eigenvalues.
// Invocation k reads entry (k, k) of the n x n matrix (n = metadata.rows_a);
// invocations with k >= n do nothing.
@compute @workgroup_size(256, 1, 1)
fn extract_eigenvalues(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let k = global_id.x;
    let dim = metadata.rows_a;

    if (k < dim) {
        // Row-major (k, k) is k * dim + k == k * (dim + 1).
        let diag_pos = k * (dim + 1u);
        eigenvalues[k] = working_matrix[diag_pos];
    }
}

// Sort eigenvalues into descending order, permuting the eigenvector columns
// identically so that column c of eigenvectors stays paired with eigenvalues[c].
//
// A stable bubble sort runs serially in invocation 0, and every other
// invocation exits. This suits small matrices only; larger ones would need a
// parallel sort.
//  - n < 2 returns immediately. For n == 0, `n - 1u` would otherwise underflow
//    and the loop would run about 2^32 passes.
//  - A pass that makes no swap proves the array is sorted. The loop stops
//    there, which gives the same result as running every pass.
@compute @workgroup_size(256, 1, 1)
fn sort_eigenvalues(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let n = metadata.rows_a;

    if (global_id.x != 0u || n < 2u) {
        return;
    }

    for (var pass_idx = 0u; pass_idx < n - 1u; pass_idx = pass_idx + 1u) {
        var swapped = false;

        for (var j = 0u; j < n - pass_idx - 1u; j = j + 1u) {
            let left = eigenvalues[j];
            let right = eigenvalues[j + 1u];

            if (left < right) {
                // Swap eigenvalues
                eigenvalues[j] = right;
                eigenvalues[j + 1u] = left;

                // Swap corresponding eigenvector columns j and j + 1.
                for (var k = 0u; k < n; k = k + 1u) {
                    let pos = k * n + j;
                    let temp_vec = eigenvectors[pos];
                    eigenvectors[pos] = eigenvectors[pos + 1u];
                    eigenvectors[pos + 1u] = temp_vec;
                }

                swapped = true;
            }
        }

        if (!swapped) {
            break;
        }
    }
}

// Scale each eigenvector column to unit Euclidean length.
// Invocation c handles column c of the n x n eigenvector matrix
// (n = metadata.rows_a). A column whose norm is not above metadata.tolerance
// (including a NaN norm) is left unchanged.
@compute @workgroup_size(256, 1, 1)
fn normalize_eigenvectors(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let dim = metadata.rows_a;
    let c = global_id.x;

    if (c >= dim) {
        return;
    }

    // Sum of squares down column c.
    var sum_sq = 0.0;
    var r = 0u;
    loop {
        if (r >= dim) {
            break;
        }
        let e = eigenvectors[r * dim + c];
        sum_sq = sum_sq + e * e;
        r = r + 1u;
    }

    let col_len = sqrt(sum_sq);

    // Written as !(x > tol) so a NaN norm is skipped, just like a tiny one.
    if (!(col_len > metadata.tolerance)) {
        return;
    }

    for (var row = 0u; row < dim; row = row + 1u) {
        let pos = row * dim + c;
        eigenvectors[pos] = eigenvectors[pos] / col_len;
    }
}

// Compute the squared residual ||A*v - λ*v||^2 for eigenpair `col`, where A is
// the original input_matrix, v is column `col` of eigenvectors, and
// λ = eigenvalues[col].
// NOTE: the value is accumulated but never written to any binding, so this
// entry point currently has no observable effect.
@compute @workgroup_size(256, 1, 1)
fn compute_residual(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let dim = metadata.rows_a;
    let c = global_id.x;

    if (c >= dim) {
        return;
    }

    let lam = eigenvalues[c];
    var sq_err = 0.0;

    for (var r = 0u; r < dim; r = r + 1u) {
        // (A * v)[r]
        var dot_rv = 0.0;
        for (var m = 0u; m < dim; m = m + 1u) {
            dot_rv = dot_rv + input_matrix[r * dim + m] * eigenvectors[m * dim + c];
        }

        let delta = dot_rv - lam * eigenvectors[r * dim + c];
        sq_err = sq_err + delta * delta;
    }
}

// Alternative: a single power-iteration step toward the dominant eigenvector.
// Computes y = A * v, where A is input_matrix and v is column 0 of
// eigenvectors. The result is written unnormalized into q_matrix[0 .. n).
// Invocation r produces y[r]; a full power method would repeat this step
// with normalization between iterations.
@compute @workgroup_size(256, 1, 1)
fn power_iteration(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let dim = metadata.rows_a;
    let r = global_id.x;

    if (r < dim) {
        let row_base = r * dim;
        var acc = 0.0;

        // Row r of A dotted with column 0 of V (entry (m, 0) is at m * dim).
        for (var m = 0u; m < dim; m = m + 1u) {
            acc = acc + input_matrix[row_base + m] * eigenvectors[m * dim];
        }

        q_matrix[r] = acc;
    }
}