// tenflowers-core 0.1.1
//
// Core tensor operations and execution engine for TenfloweRS
// Documentation
// Logical operation compute shaders
// Operates on boolean tensors represented as u32 arrays

@group(0) @binding(0) var<storage, read> input_a: array<u32>; // First operand; the entry points below treat any non-zero word as true
@group(0) @binding(1) var<storage, read> input_b: array<u32>; // Second operand, same encoding (not read by not_op)
@group(0) @binding(2) var<storage, read_write> output: array<u32>; // Result; entry points write 0u for false and 1u for true

// Broadcasting metadata, as read by the *_broadcast entry points:
//   [0] = rank of a, [1] = rank of b, [2] = rank of output,
//   then the shape of a starting at index 3, followed by the shape of b,
//   followed by the shape of the output.
// The index helpers use fixed array<u32, 8> scratch arrays, so ranks are
// presumably expected to be at most 8 — confirm against the host-side code.
@group(0) @binding(3) var<storage, read> shape_metadata: array<u32>;

// Helper functions for broadcasting
// Collapses a per-dimension index into a single linear element offset.
//
// indices:      per-dimension coordinates; only the first `rank` entries are read
// shape_offset: position in shape_metadata where this tensor's shape begins
// rank:         number of dimensions to fold
//
// The last dimension is given stride 1; each earlier dimension's stride is
// the product of the extents of all dimensions after it.
fn compute_flat_index_broadcast(indices: ptr<function, array<u32, 8>>, shape_offset: u32, rank: u32) -> u32 {
    var offset = 0u;
    var step = 1u;

    // Count down so that `axis` visits rank-1, rank-2, ..., 0.
    for (var remaining_dims = rank; remaining_dims > 0u; remaining_dims--) {
        let axis = remaining_dims - 1u;
        offset += (*indices)[axis] * step;
        step *= shape_metadata[shape_offset + axis];
    }

    return offset;
}

// Expands a linear element offset into per-dimension coordinates.
//
// flat_idx:     linear offset to decompose
// indices:      receives the coordinates; only the first `rank` entries are written
// shape_offset: position in shape_metadata where the shape begins
// rank:         number of dimensions to produce
//
// The last dimension is peeled off first (remainder), then the quotient is
// carried on to the next dimension towards the front.
fn compute_multi_index_broadcast(flat_idx: u32, indices: ptr<function, array<u32, 8>>, shape_offset: u32, rank: u32) {
    var carry = flat_idx;

    for (var remaining_dims = rank; remaining_dims > 0u; remaining_dims--) {
        let axis = remaining_dims - 1u;
        let extent = shape_metadata[shape_offset + axis];
        (*indices)[axis] = carry % extent;
        carry /= extent;
    }
}

// Derives an input tensor's coordinates from the output coordinates.
//
// output_indices:     coordinates in the output tensor
// input_indices:      receives the coordinates in the input tensor (all 8 slots are reset)
// input_shape_offset: position in shape_metadata where the input's shape begins
// input_rank:         number of dimensions of the input
// output_rank:        number of dimensions of the output
//
// Input dimensions are aligned with the trailing output dimensions. An input
// dimension of extent 1 always maps to coordinate 0; any other dimension takes
// the coordinate of the output dimension it is aligned with.
fn broadcast_indices(output_indices: ptr<function, array<u32, 8>>, input_indices: ptr<function, array<u32, 8>>, 
                     input_shape_offset: u32, input_rank: u32, output_rank: u32) {
    // Reset every slot to zero in one assignment.
    *input_indices = array<u32, 8>();

    let lead = output_rank - input_rank;
    for (var d = 0u; d < input_rank; d++) {
        // Extent-1 dimensions keep the zero written above.
        if (shape_metadata[input_shape_offset + d] != 1u) {
            (*input_indices)[d] = (*output_indices)[d + lead];
        }
    }
}

// Basic logical operations

// Element-wise logical AND: output[i] = (a != 0) AND (b != 0), stored as 0u/1u.
// Invocations whose id lies past the end of `output` do nothing. Each input
// is indexed modulo its own length.
@compute @workgroup_size(64)
fn and_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i < arrayLength(&output)) {
        let lhs = input_a[i % arrayLength(&input_a)] != 0u;
        let rhs = input_b[i % arrayLength(&input_b)] != 0u;
        // u32(bool) yields 1u for true and 0u for false.
        output[i] = u32(lhs && rhs);
    }
}

// Element-wise logical OR: output[i] = (a != 0) OR (b != 0), stored as 0u/1u.
// Invocations whose id lies past the end of `output` do nothing. Each input
// is indexed modulo its own length.
@compute @workgroup_size(64)
fn or_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i >= arrayLength(&output)) {
        return;
    }

    let lhs = input_a[i % arrayLength(&input_a)] != 0u;
    let rhs = input_b[i % arrayLength(&input_b)] != 0u;

    var result = 0u;
    if (lhs || rhs) {
        result = 1u;
    }
    output[i] = result;
}

// Element-wise logical XOR: output[i] is 1u when exactly one operand is
// non-zero, otherwise 0u. Invocations whose id lies past the end of `output`
// do nothing. Each input is indexed modulo its own length.
@compute @workgroup_size(64)
fn xor_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i >= arrayLength(&output)) {
        return;
    }

    // Normalize both operands to 0u/1u first so a bitwise xor gives the logical result.
    let lhs_bit = u32(input_a[i % arrayLength(&input_a)] != 0u);
    let rhs_bit = u32(input_b[i % arrayLength(&input_b)] != 0u);
    output[i] = lhs_bit ^ rhs_bit;
}

// Unary logical operation (NOT)
// Element-wise logical NOT: output[i] is 1u when the operand is zero and 0u
// otherwise. Only input_a is read; it is indexed modulo its own length.
// Invocations whose id lies past the end of `output` do nothing.
@compute @workgroup_size(64)
fn not_op(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i < arrayLength(&output)) {
        let operand = input_a[i % arrayLength(&input_a)];
        output[i] = u32(operand == 0u);
    }
}

// Broadcasting versions

// Logical AND with shape broadcasting.
//
// Each invocation handles one output element: it converts its linear output
// position into output coordinates, maps those onto each input's coordinates,
// folds them back into linear input offsets, and writes 1u when both fetched
// values are non-zero (0u otherwise).
@compute @workgroup_size(64)
fn and_broadcast(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let out_pos = global_id.x;
    if (out_pos >= arrayLength(&output)) {
        return;
    }

    // Metadata header: three ranks, then the three shapes packed back to back.
    let rank_a = shape_metadata[0];
    let rank_b = shape_metadata[1];
    let rank_out = shape_metadata[2];
    let a_dims_at = 3u;
    let b_dims_at = a_dims_at + rank_a;
    let out_dims_at = b_dims_at + rank_b;

    var out_coords: array<u32, 8>;
    compute_multi_index_broadcast(out_pos, &out_coords, out_dims_at, rank_out);

    // Operand a
    var a_coords: array<u32, 8>;
    broadcast_indices(&out_coords, &a_coords, a_dims_at, rank_a, rank_out);
    let lhs = input_a[compute_flat_index_broadcast(&a_coords, a_dims_at, rank_a)] != 0u;

    // Operand b
    var b_coords: array<u32, 8>;
    broadcast_indices(&out_coords, &b_coords, b_dims_at, rank_b, rank_out);
    let rhs = input_b[compute_flat_index_broadcast(&b_coords, b_dims_at, rank_b)] != 0u;

    output[out_pos] = u32(lhs && rhs);
}

// Logical OR with shape broadcasting.
//
// Each invocation handles one output element: it converts its linear output
// position into output coordinates, maps those onto each input's coordinates,
// folds them back into linear input offsets, and writes 1u when at least one
// fetched value is non-zero (0u otherwise).
@compute @workgroup_size(64)
fn or_broadcast(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let out_pos = global_id.x;
    if (out_pos >= arrayLength(&output)) {
        return;
    }

    // Metadata header: three ranks, then the three shapes packed back to back.
    let rank_a = shape_metadata[0];
    let rank_b = shape_metadata[1];
    let rank_out = shape_metadata[2];
    let a_dims_at = 3u;
    let b_dims_at = a_dims_at + rank_a;
    let out_dims_at = b_dims_at + rank_b;

    var out_coords: array<u32, 8>;
    compute_multi_index_broadcast(out_pos, &out_coords, out_dims_at, rank_out);

    // Operand a
    var a_coords: array<u32, 8>;
    broadcast_indices(&out_coords, &a_coords, a_dims_at, rank_a, rank_out);
    let lhs = input_a[compute_flat_index_broadcast(&a_coords, a_dims_at, rank_a)] != 0u;

    // Operand b
    var b_coords: array<u32, 8>;
    broadcast_indices(&out_coords, &b_coords, b_dims_at, rank_b, rank_out);
    let rhs = input_b[compute_flat_index_broadcast(&b_coords, b_dims_at, rank_b)] != 0u;

    var result = 0u;
    if (lhs || rhs) {
        result = 1u;
    }
    output[out_pos] = result;
}

// Logical XOR with shape broadcasting.
//
// Each invocation handles one output element: it converts its linear output
// position into output coordinates, maps those onto each input's coordinates,
// folds them back into linear input offsets, and writes 1u when exactly one
// fetched value is non-zero (0u otherwise).
@compute @workgroup_size(64)
fn xor_broadcast(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let out_pos = global_id.x;
    if (out_pos >= arrayLength(&output)) {
        return;
    }

    // Metadata header: three ranks, then the three shapes packed back to back.
    let rank_a = shape_metadata[0];
    let rank_b = shape_metadata[1];
    let rank_out = shape_metadata[2];
    let a_dims_at = 3u;
    let b_dims_at = a_dims_at + rank_a;
    let out_dims_at = b_dims_at + rank_b;

    var out_coords: array<u32, 8>;
    compute_multi_index_broadcast(out_pos, &out_coords, out_dims_at, rank_out);

    // Operand a, normalized to 0u/1u
    var a_coords: array<u32, 8>;
    broadcast_indices(&out_coords, &a_coords, a_dims_at, rank_a, rank_out);
    let lhs_bit = u32(input_a[compute_flat_index_broadcast(&a_coords, a_dims_at, rank_a)] != 0u);

    // Operand b, normalized to 0u/1u
    var b_coords: array<u32, 8>;
    broadcast_indices(&out_coords, &b_coords, b_dims_at, rank_b, rank_out);
    let rhs_bit = u32(input_b[compute_flat_index_broadcast(&b_coords, b_dims_at, rank_b)] != 0u);

    output[out_pos] = lhs_bit ^ rhs_bit;
}

// Short-circuit optimized operations
// These operations can potentially skip computation when one operand guarantees the result

// Logical AND that reads the second operand only when the first is non-zero.
// Writes 0u as soon as the first operand is zero; otherwise the result is
// 1u when the second operand is non-zero and 0u when it is zero. Each input
// is indexed modulo its own length; ids past the end of `output` do nothing.
@compute @workgroup_size(64)
fn and_short_circuit(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i >= arrayLength(&output)) {
        return;
    }

    // Stays 0u unless both operands turn out to be non-zero.
    var result = 0u;
    if (input_a[i % arrayLength(&input_a)] != 0u) {
        if (input_b[i % arrayLength(&input_b)] != 0u) {
            result = 1u;
        }
    }
    output[i] = result;
}

// Logical OR that reads the second operand only when the first is zero.
// Writes 1u as soon as the first operand is non-zero; otherwise the result is
// 1u when the second operand is non-zero and 0u when it is zero. Each input
// is indexed modulo its own length; ids past the end of `output` do nothing.
@compute @workgroup_size(64)
fn or_short_circuit(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    if (i >= arrayLength(&output)) {
        return;
    }

    // Stays 1u unless both operands turn out to be zero.
    var result = 1u;
    if (input_a[i % arrayLength(&input_a)] == 0u) {
        if (input_b[i % arrayLength(&input_b)] == 0u) {
            result = 0u;
        }
    }
    output[i] = result;
}