// bevy_march 0.1.0
//
// SDF ray marching for bevy
// Documentation
#import bevy_march::{get_initial_settings, settings, NearestSdf, get_ray_dir_invz, nodes, get_node_min, instances, get_aabb_hit, EPSILON_PER_DIST, EPSILON_MIN, EPSILON_MAX};
#import bevy_prototype_sdf::sdf;

// Write-only single-channel float target. `march` stores the distance it computed for
// each invocation in the red channel of the texel at that invocation's xy id.
@group(1) @binding(0) var cone_texture: texture_storage_2d<r32float, write>;
// Factor that `march` divides by the cone texture size to turn a texel coordinate into
// the 2D coordinate handed to `get_initial_settings` and `get_ray_dir_invz`.
@group(1) @binding(1) var<storage, read> uv_scale: vec2<f32>;

// Surface material description.
// NOTE(review): nothing in this section of the file reads or writes this struct —
// confirm whether another part of the shader needs it before removing it.
struct Material {
    // Base surface color (presumably RGB — confirm color space with the consumer).
    base_color: vec3<f32>,
    // Emitted color (presumably RGB — confirm units with the consumer).
    emissive: vec3<f32>,
    // Reflectivity factor; valid range is not established here — TODO confirm.
    reflective: f32,
}

// Cone-march pre-pass: one invocation per texel of `cone_texture`.
//
// Each invocation builds a frustum from the four corner rays of its texel, walks the
// hierarchy stored in `nodes` with an explicit stack, and marches the SDF of every
// instance whose padded bounds overlap that frustum. Marching an instance stops as soon
// as the SDF distance drops below the cone radius plus epsilon. The smallest distance at
// which that happened, over all visited instances, is written to the red channel of
// `cone_texture` (1e9 when no instance triggered it).
@compute @workgroup_size(8, 8, 1)
fn march(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
    let size = textureDimensions(cone_texture);
    let pixel_factor = uv_scale / vec2<f32>(size);
    let position = invocation_id.xy;

    // Footprint of this texel and the ray through its centre.
    let cluster_start = vec2<f32>(position) * pixel_factor;
    let cluster_end = vec2<f32>(position + 1) * pixel_factor;
    let cluster_center = (cluster_start + cluster_end) * 0.5;
    var march = get_initial_settings(cluster_center, 0.);

    // Rays through the four corners of the footprint.
    let tl = get_ray_dir_invz(cluster_start);
    let tr = get_ray_dir_invz(vec2<f32>(cluster_end.x, cluster_start.y));
    let bl = get_ray_dir_invz(vec2<f32>(cluster_start.x, cluster_end.y));
    let br = get_ray_dir_invz(cluster_end);

    // The largest offset between a corner direction and the centre direction is the
    // cone radius gained per unit of distance travelled along the centre ray.
    let corner_spread_sq = max(
        max(len_sq(tl.xyz - march.direction), len_sq(tr.xyz - march.direction)),
        max(len_sq(bl.xyz - march.direction), len_sq(br.xyz - march.direction)),
    );
    let radius_per_unit = sqrt(corner_spread_sq);

    let frustum = get_cone_frustum(tl, tr, bl, br, march.direction);

    let epsilon_per_dist = march.scale * EPSILON_PER_DIST;
    let min_epsilon = march.scale * EPSILON_MIN;
    let max_epsilon = march.scale * EPSILON_MAX;

    let ray_positive = sign(march.direction) == vec3<f32>(1.);

    // Closest accepted distance so far; also used to clip later candidates.
    var traveled = 1e9;

    // Explicit traversal stack of indices into `nodes`, seeded with node 0.
    // NOTE(review): the capacity of 16 bounds the hierarchy depth that can be traversed;
    // nothing here checks for overflow.
    var stack: array<u32, 16>;
    stack[0] = 0u;
    var stack_location = 1u;

    while stack_location > 0u {
        stack_location -= 1u;
        let node = nodes[stack[stack_location]];

        // Skip everything whose epsilon-padded bounds miss this texel's frustum.
        if !check_aabb_frustum(frustum, node.min - max_epsilon, node.max + max_epsilon) {
            continue;
        }

        if node.count == 0 {
            // Interior node: its two children live at `index` and `index + 1`.
            let a = nodes[node.index];
            let hit_a = project_node(a.min - max_epsilon, a.max + max_epsilon, march.origin, march.direction, ray_positive);
            let b = nodes[node.index + 1];
            let hit_b = project_node(b.min - max_epsilon, b.max + max_epsilon, march.origin, march.direction, ray_positive);

            // Push the child with the larger entry distance first so that the one with
            // the smaller entry distance is popped (and marched) first.
            if hit_a.x > hit_b.x {
                stack[stack_location] = node.index;
                stack[stack_location + 1u] = node.index + 1;
            } else {
                stack[stack_location] = node.index + 1;
                stack[stack_location + 1u] = node.index;
            }
            stack_location += 2u;

            continue;
        }

        // Leaf node: `index` refers to the single instance to march.
        let instance = instances[node.index];

        // These locals must not be called `min`/`max`: a local declaration shadows the
        // builtin of the same name, which would make the calls below invalid WGSL.
        let node_min = node.min - max_epsilon;
        let node_max = node.max + max_epsilon;
        var hit = project_node(node_min, node_max, march.origin, march.direction, ray_positive);
        hit = vec2<f32>(max(hit.x, march.start), min(hit.y, traveled));
        if hit.x > hit.y {
            continue;
        }

        // `start` is a distance along the ray in the units of `march`; everything named
        // "local" below is divided by `instance.scale`, and multiplying by
        // `instance.scale` converts it back (see `cur_travel`).
        let start = max(hit.x, 0.);
        let end = (min(hit.y, march.limit) - start) / instance.scale;

        let start_pos = march.origin + march.direction * start;

        let relative_pos = instance.matrix * (start_pos - instance.translation);
        let relative_dir = instance.matrix * march.direction * instance.scale;

        let start_epsilon = clamp(epsilon_per_dist * hit.x, min_epsilon, max_epsilon) / instance.scale;
        let local_max_epsilon = max_epsilon / instance.scale;

        // Cone radius at `start`, expressed in instance-local units. One local unit of
        // travel is `instance.scale` units along the ray, which widens the cone by
        // `radius_per_unit * instance.scale`, i.e. by `radius_per_unit` local units.
        // The growth per local unit is therefore `radius_per_unit` itself; multiplying
        // it by `instance.scale` again would mis-size the cone for scaled instances.
        let start_radius = start * radius_per_unit / instance.scale;

        var dist = 0.;
        var local_traveled = 0.;
        var epsilon = start_epsilon;
        var cluster_size = 0.;

        for (var i = 0u; i < 512u; i++) {
            let pos = relative_pos + relative_dir * local_traveled;
            dist = sdf(pos, instance.order_start, instance.data_start);

            epsilon = min(start_epsilon + local_traveled * epsilon_per_dist, local_max_epsilon);
            cluster_size = start_radius + radius_per_unit * local_traveled;

            // Stop when leaving the clipped range or when the surface is within the
            // cone radius plus epsilon.
            if local_traveled > end || dist < cluster_size + epsilon {
                break;
            }

            // Step by the distance that is still free of the cone, but never by less
            // than epsilon so the loop keeps advancing.
            local_traveled += max(dist - cluster_size, epsilon);
        }

        let cur_travel = start + local_traveled * instance.scale;
        if dist < cluster_size + epsilon && cur_travel < traveled {
            traveled = cur_travel;
        }
    }

    textureStore(cone_texture, position, vec4<f32>(traveled, 0., 0., 0.));
}

// Projects an axis-aligned box onto a ray direction.
//
// `ray_sign` selects, per axis, which box face is picked for the lower bound: where it
// is true the `aabb_min` component is used for the lower bound and the `aabb_max`
// component for the upper bound, and the other way round where it is false.
// Returns `vec2(lower, upper)`: the dot products of `direction` with the two selected
// corners, taken relative to `origin`.
fn project_node(aabb_min: vec3<f32>, aabb_max: vec3<f32>, origin: vec3<f32>, direction: vec3<f32>, ray_sign: vec3<bool>) -> vec2<f32> {
    // Box corners relative to the ray origin.
    let rel_lo = aabb_min - origin;
    let rel_hi = aabb_max - origin;

    let entry_corner = select(rel_hi, rel_lo, ray_sign);
    let exit_corner = select(rel_lo, rel_hi, ray_sign);

    return vec2<f32>(dot(direction, entry_corner), dot(direction, exit_corner));
}

// Precomputed data for testing axis-aligned boxes against one texel's frustum.
// Filled in by `get_cone_frustum` and consumed by `check_aabb_frustum`.
struct Frustum {
    // Plane equations as (normal.xyz, -dot(normal, point_on_plane)).
    // [0] is built from the centre direction and a near corner, [1] from the negated
    // centre direction and a far corner, [2..5] are the four side planes.
    planes: array<vec4<f32>, 6>,
    // Corner points: [0..3] near tl, tr, br, bl; [4..7] far tl, tr, br, bl.
    points: array<vec3<f32>, 8>,
    // Edge directions: two edges of the far face followed by the four corner rays.
    axes: array<vec3<f32>, 6>,
    // mins[i][j] / maxs[i][j]: smallest / largest projection of the eight corner points
    // onto the cross product of coordinate axis j with axes[i].
    mins: array<vec3<f32>, 6>,
    maxs: array<vec3<f32>, 6>,
}

// Builds the `Frustum` spanned by four corner rays between `settings.near` and
// `settings.far`, starting at `settings.origin`.
//
// `tl`, `tr`, `bl`, `br` carry a direction in xyz and a factor in w that scales the
// near/far distance along that direction. `center` is used as the normal of the near
// plane and, negated, of the far plane.
fn get_cone_frustum(tl: vec4<f32>, tr: vec4<f32>, bl: vec4<f32>, br: vec4<f32>, center: vec3<f32>) -> Frustum{
    var frustum: Frustum;

    // Corners on the near plane.
    let near_tl = settings.origin + tl.xyz * settings.near * tl.w;
    let near_tr = settings.origin + tr.xyz * settings.near * tr.w;
    let near_bl = settings.origin + bl.xyz * settings.near * bl.w;
    let near_br = settings.origin + br.xyz * settings.near * br.w;

    // Corners on the far plane.
    let far_tl = settings.origin + tl.xyz * settings.far * tl.w;
    let far_tr = settings.origin + tr.xyz * settings.far * tr.w;
    let far_bl = settings.origin + bl.xyz * settings.far * bl.w;
    let far_br = settings.origin + br.xyz * settings.far * br.w;

    frustum.points[0] = near_tl;
    frustum.points[1] = near_tr;
    frustum.points[2] = near_br;
    frustum.points[3] = near_bl;
    frustum.points[4] = far_tl;
    frustum.points[5] = far_tr;
    frustum.points[6] = far_br;
    frustum.points[7] = far_bl;

    // Near and far planes share the centre direction as their normal (with opposite
    // signs); the four side planes each go through three corners.
    frustum.planes[0] = vec4<f32>(center, -dot(center, near_tl));
    frustum.planes[1] = vec4<f32>(-center, -dot(-center, far_tl));
    frustum.planes[2] = get_plane(near_bl, far_bl, far_tl);
    frustum.planes[3] = get_plane(near_tl, far_tl, far_tr);
    frustum.planes[4] = get_plane(near_tr, far_tr, far_br);
    frustum.planes[5] = get_plane(near_br, far_br, far_bl);

    // Unique edge directions: two edges of the far face plus the four corner rays.
    let edge_dirs = array(
        normalize(far_tr - far_tl),
        normalize(far_tl - far_bl),
        tl.xyz,
        tr.xyz,
        bl.xyz,
        br.xyz,
    );
    for (var i = 0u; i < 6; i++) {
        let edge = edge_dirs[i];
        frustum.axes[i] = edge;

        // Cross products of each coordinate axis with the current frustum edge.
        // These axes don't need to be normalized, since the min/max on both sides of
        // the SAT check are scaled by the same factor.
        let sat_axes = array(
            vec3<f32>(0., -edge.z, edge.y),
            vec3<f32>(edge.z, 0., -edge.x),
            vec3<f32>(-edge.y, edge.x, 0.),
        );
        for (var j = 0u; j < 3; j++) {
            let sat_axis = sat_axes[j];

            // Range covered by the eight corners when projected onto this axis.
            var lowest = dot(sat_axis, frustum.points[0]);
            var highest = lowest;
            for (var k = 1u; k < 8u; k++) {
                let projected = dot(sat_axis, frustum.points[k]);
                lowest = min(lowest, projected);
                highest = max(highest, projected);
            }

            frustum.mins[i][j] = lowest;
            frustum.maxs[i][j] = highest;
        }
    }

    return frustum;
}

// Returns the plane through `p0`, `p1` and `p2` as (unit normal, -dot(normal, p0)).
// The normal is the normalized cross product of (p1 - p0) and (p2 - p1), so its
// direction depends on the order in which the points are given.
fn get_plane(p0: vec3<f32>, p1: vec3<f32>, p2: vec3<f32>) -> vec4<f32> {
    let first_edge = p1 - p0;
    let second_edge = p2 - p1;
    let unit_normal = normalize(cross(first_edge, second_edge));
    return vec4<f32>(unit_normal, -dot(unit_normal, p0));
}

// Squared Euclidean length of `v` (the dot product of `v` with itself).
fn len_sq(v: vec3<f32>) -> f32 {
    let squared_length = dot(v, v);
    return squared_length;
}

// Tests whether the axis-aligned box [aabb_min, aabb_max] may overlap `frustum`.
//
// Returns false as soon as one of three tests finds a separation:
//   1. the box lies completely on the negative side of one frustum plane,
//   2. all eight frustum corners lie beyond one face of the box,
//   3. the projections of box and frustum are disjoint on one of the cross-product
//      axes (coordinate axis x frustum edge direction).
// Returns true otherwise.
//
// The box parameters are deliberately not called `min`/`max`, and no local is called
// `sign`: such declarations shadow the builtins of the same name for the rest of the
// function, which makes later calls to them invalid WGSL.
fn check_aabb_frustum(frustum: Frustum, aabb_min: vec3<f32>, aabb_max: vec3<f32>) -> bool {
    // Twice the box centre and twice its half extent; the plane test below works on
    // these doubled quantities, which does not change the sign of the result.
    let doubled_center = aabb_max + aabb_min;
    let doubled_extent = aabb_max - aabb_min;

    // 1. Box completely outside one of the frustum planes.
    for (var i = 0u; i < 6u; i++) {
        let plane = frustum.planes[i];
        if dot(plane, vec4<f32>(doubled_center, 2.)) + dot(abs(plane.xyz), doubled_extent) < 0. {
            return false;
        }
    }

    // 2. All frustum corners beyond one face of the box. Tracked per axis with boolean
    //    comparisons, so no negative float is ever converted to an unsigned integer.
    var all_below = vec3<bool>(true);
    var all_above = vec3<bool>(true);
    for (var i = 0u; i < 8u; i++) {
        let corner = frustum.points[i];
        all_below = all_below & (corner < aabb_min);
        all_above = all_above & (corner > aabb_max);
    }
    if any(all_below | all_above) {
        return false;
    }

    // 3. Separating Axis Theorem for the cross products between each coordinate axis
    //    and each unique frustum edge direction.
    for (var i = 0u; i < 6u; i++) {
        let edge = frustum.axes[i];
        let frustum_mins = frustum.mins[i];
        let frustum_maxs = frustum.maxs[i];

        let sat_axes = array(
            vec3<f32>(0., -edge.z, edge.y),
            vec3<f32>(edge.z, 0., -edge.x),
            vec3<f32>(-edge.y, edge.x, 0.),
        );
        for (var j = 0u; j < 3u; j++) {
            let sat_axis = sat_axes[j];

            // Per component, pick the box face that minimizes / maximizes the projection.
            let is_positive = sign(sat_axis) == vec3<f32>(1.);
            let low_corner = select(aabb_max, aabb_min, is_positive);
            let high_corner = select(aabb_min, aabb_max, is_positive);

            let box_min = dot(sat_axis, low_corner);
            let box_max = dot(sat_axis, high_corner);

            if frustum_mins[j] > box_max || box_min > frustum_maxs[j] {
                return false;
            }
        }
    }

    return true;
}