polyscope-render 0.5.10

Rendering backend for polyscope-rs: wgpu engine, shaders, and materials
Documentation
// Vector arrow shader using instanced rendering
// Each instance is an arrow (cylinder shaft + cone head) from a base point in a direction

struct CameraUniforms {
    view: mat4x4<f32>,
    proj: mat4x4<f32>,
    view_proj: mat4x4<f32>,
    inv_proj: mat4x4<f32>,
    camera_pos: vec3<f32>,
    _padding: f32,
}

// Slice plane uniforms for fragment-level slicing
struct SlicePlaneUniforms {
    origin: vec3<f32>,
    enabled: f32,
    normal: vec3<f32>,
    _padding: f32,
}

struct SlicePlanesArray {
    planes: array<SlicePlaneUniforms, 4>,
}

struct VectorUniforms {
    model: mat4x4<f32>,
    length_scale: f32,
    radius: f32,
    _padding: vec2<f32>,
    color: vec4<f32>,
}

@group(0) @binding(0) var<uniform> camera: CameraUniforms;
@group(0) @binding(1) var<uniform> vector_uniforms: VectorUniforms;
@group(0) @binding(2) var<storage, read> base_positions: array<vec3<f32>>;
@group(0) @binding(3) var<storage, read> vectors: array<vec3<f32>>;

@group(1) @binding(0) var<uniform> slice_planes: SlicePlanesArray;

// Matcap textures (Group 2)
@group(2) @binding(0) var matcap_r: texture_2d<f32>;
@group(2) @binding(1) var matcap_g: texture_2d<f32>;
@group(2) @binding(2) var matcap_b: texture_2d<f32>;
@group(2) @binding(3) var matcap_k: texture_2d<f32>;
@group(2) @binding(4) var matcap_sampler: sampler;

fn light_surface_matcap(normal: vec3<f32>, color: vec3<f32>) -> vec3<f32> {
    var n = normalize(normal);
    n.y = -n.y;
    n = n * 0.98;
    let uv = n.xy * 0.5 + vec2<f32>(0.5);
    let mat_r = textureSample(matcap_r, matcap_sampler, uv).rgb;
    let mat_g = textureSample(matcap_g, matcap_sampler, uv).rgb;
    let mat_b = textureSample(matcap_b, matcap_sampler, uv).rgb;
    let mat_k = textureSample(matcap_k, matcap_sampler, uv).rgb;
    return color.r * mat_r + color.g * mat_g
         + color.b * mat_b + (1.0 - color.r - color.g - color.b) * mat_k;
}

struct VertexOutput {
    @builtin(position) clip_position: vec4<f32>,
    @location(0) normal: vec3<f32>,
    @location(1) color: vec3<f32>,
    @location(2) world_position: vec3<f32>,
}

// Arrow geometry: cylinder shaft + cone sides + cone cap + shaft cap
// Shaft sides: 8 segments × 6 verts = 48 verts (indices 0..47)
// Cone sides:  8 segments × 3 verts = 24 verts (indices 48..71)
// Cone cap:    8 segments × 3 verts = 24 verts (indices 72..95)
// Shaft cap:   8 segments × 3 verts = 24 verts (indices 96..119)
// Total: 120 vertices per arrow instance
const SEGMENTS: u32 = 8u;
const SHAFT_VERTS: u32 = 48u; // SEGMENTS * 6
const CONE_SIDE_VERTS: u32 = 24u; // SEGMENTS * 3
const CONE_CAP_VERTS: u32 = 24u; // SEGMENTS * 3

// Arrow proportions
const CONE_HEIGHT_FRAC: f32 = 0.3;     // Cone takes 30% of total arrow length
const CONE_RADIUS_MULT: f32 = 2.0;     // Cone base radius = 2× shaft radius

@vertex
fn vs_main(
    @builtin(vertex_index) vertex_index: u32,
    @builtin(instance_index) instance_index: u32,
) -> VertexOutput {
    var out: VertexOutput;

    // Transform base position and vector direction by model matrix
    let base_pos = (vector_uniforms.model * vec4<f32>(base_positions[instance_index], 1.0)).xyz;
    let raw_vec = vectors[instance_index];
    let vec = (vector_uniforms.model * vec4<f32>(raw_vec, 0.0)).xyz;
    let vec_length = length(vec);

    if (vec_length < 0.0001) {
        // Zero vector - place vertex at origin (will be degenerate)
        out.clip_position = camera.view_proj * vec4<f32>(base_pos, 1.0);
        out.normal = vec3<f32>(0.0, 1.0, 0.0);
        out.color = vector_uniforms.color.rgb;
        out.world_position = base_pos;
        return out;
    }

    let vec_dir = vec / vec_length;
    let scaled_length = vec_length * vector_uniforms.length_scale;
    let shaft_radius = vector_uniforms.radius;
    let cone_base_radius = shaft_radius * CONE_RADIUS_MULT;
    let shaft_height = scaled_length * (1.0 - CONE_HEIGHT_FRAC);
    let cone_height = scaled_length * CONE_HEIGHT_FRAC;

    // Build right-handed orthonormal basis (right × forward = vec_dir)
    var up = vec3<f32>(0.0, 1.0, 0.0);
    if (abs(dot(vec_dir, up)) > 0.99) {
        up = vec3<f32>(1.0, 0.0, 0.0);
    }
    let right = normalize(cross(up, vec_dir));
    let forward = cross(vec_dir, right);

    var local_pos: vec3<f32>;
    var local_normal: vec3<f32>;

    if (vertex_index < SHAFT_VERTS) {
        // === Cylinder shaft ===
        let segment = vertex_index / 6u;
        let tri_vert = vertex_index % 6u;

        let angle0 = f32(segment) / f32(SEGMENTS) * 6.283185;
        let angle1 = f32(segment + 1u) / f32(SEGMENTS) * 6.283185;

        // Two triangles per segment forming a quad
        if (tri_vert == 0u) {
            local_pos = vec3<f32>(cos(angle0) * shaft_radius, sin(angle0) * shaft_radius, 0.0);
            local_normal = vec3<f32>(cos(angle0), sin(angle0), 0.0);
        } else if (tri_vert == 1u) {
            local_pos = vec3<f32>(cos(angle1) * shaft_radius, sin(angle1) * shaft_radius, 0.0);
            local_normal = vec3<f32>(cos(angle1), sin(angle1), 0.0);
        } else if (tri_vert == 2u) {
            local_pos = vec3<f32>(cos(angle0) * shaft_radius, sin(angle0) * shaft_radius, shaft_height);
            local_normal = vec3<f32>(cos(angle0), sin(angle0), 0.0);
        } else if (tri_vert == 3u) {
            local_pos = vec3<f32>(cos(angle1) * shaft_radius, sin(angle1) * shaft_radius, 0.0);
            local_normal = vec3<f32>(cos(angle1), sin(angle1), 0.0);
        } else if (tri_vert == 4u) {
            local_pos = vec3<f32>(cos(angle1) * shaft_radius, sin(angle1) * shaft_radius, shaft_height);
            local_normal = vec3<f32>(cos(angle1), sin(angle1), 0.0);
        } else {
            local_pos = vec3<f32>(cos(angle0) * shaft_radius, sin(angle0) * shaft_radius, shaft_height);
            local_normal = vec3<f32>(cos(angle0), sin(angle0), 0.0);
        }
    } else if (vertex_index < SHAFT_VERTS + CONE_SIDE_VERTS) {
        // === Cone arrowhead sides ===
        let cone_index = vertex_index - SHAFT_VERTS;
        let segment = cone_index / 3u;
        let tri_vert = cone_index % 3u;

        let angle0 = f32(segment) / f32(SEGMENTS) * 6.283185;
        let angle1 = f32(segment + 1u) / f32(SEGMENTS) * 6.283185;

        // Cone normal: for a cone with base radius R and height h,
        // the outward normal has radial component h and axial component R,
        // normalized to unit length.
        let cone_norm_len = sqrt(cone_height * cone_height + cone_base_radius * cone_base_radius);
        let n_radial = cone_height / cone_norm_len;
        let n_axial = cone_base_radius / cone_norm_len;

        // One triangle per segment: two base verts + tip
        if (tri_vert == 0u) {
            // Base vertex at angle0
            local_pos = vec3<f32>(cos(angle0) * cone_base_radius, sin(angle0) * cone_base_radius, shaft_height);
            local_normal = vec3<f32>(cos(angle0) * n_radial, sin(angle0) * n_radial, n_axial);
        } else if (tri_vert == 1u) {
            // Base vertex at angle1
            local_pos = vec3<f32>(cos(angle1) * cone_base_radius, sin(angle1) * cone_base_radius, shaft_height);
            local_normal = vec3<f32>(cos(angle1) * n_radial, sin(angle1) * n_radial, n_axial);
        } else {
            // Tip vertex
            local_pos = vec3<f32>(0.0, 0.0, scaled_length);
            // Average normal for tip (use midpoint angle)
            let mid_angle = (angle0 + angle1) * 0.5;
            local_normal = vec3<f32>(cos(mid_angle) * n_radial, sin(mid_angle) * n_radial, n_axial);
        }
    } else if (vertex_index < SHAFT_VERTS + CONE_SIDE_VERTS + CONE_CAP_VERTS) {
        // === Cone bottom cap (disc) ===
        let cap_index = vertex_index - SHAFT_VERTS - CONE_SIDE_VERTS;
        let segment = cap_index / 3u;
        let tri_vert = cap_index % 3u;

        let angle0 = f32(segment) / f32(SEGMENTS) * 6.283185;
        let angle1 = f32(segment + 1u) / f32(SEGMENTS) * 6.283185;

        // Normal points back toward shaft base (negative Z in local space)
        local_normal = vec3<f32>(0.0, 0.0, -1.0);

        // Fan triangle: center, edge at angle1, edge at angle0 (CW winding for outward -Z normal)
        if (tri_vert == 0u) {
            local_pos = vec3<f32>(0.0, 0.0, shaft_height);
        } else if (tri_vert == 1u) {
            local_pos = vec3<f32>(cos(angle1) * cone_base_radius, sin(angle1) * cone_base_radius, shaft_height);
        } else {
            local_pos = vec3<f32>(cos(angle0) * cone_base_radius, sin(angle0) * cone_base_radius, shaft_height);
        }
    } else {
        // === Shaft bottom cap (disc) ===
        let cap_index = vertex_index - SHAFT_VERTS - CONE_SIDE_VERTS - CONE_CAP_VERTS;
        let segment = cap_index / 3u;
        let tri_vert = cap_index % 3u;

        let angle0 = f32(segment) / f32(SEGMENTS) * 6.283185;
        let angle1 = f32(segment + 1u) / f32(SEGMENTS) * 6.283185;

        // Normal points away from arrow tip (negative Z in local space)
        local_normal = vec3<f32>(0.0, 0.0, -1.0);

        // Fan triangle: center, edge at angle1, edge at angle0 (CW winding for outward -Z normal)
        if (tri_vert == 0u) {
            local_pos = vec3<f32>(0.0, 0.0, 0.0);
        } else if (tri_vert == 1u) {
            local_pos = vec3<f32>(cos(angle1) * shaft_radius, sin(angle1) * shaft_radius, 0.0);
        } else {
            local_pos = vec3<f32>(cos(angle0) * shaft_radius, sin(angle0) * shaft_radius, 0.0);
        }
    }

    // Transform to world space
    let world_pos = base_pos
        + right * local_pos.x
        + forward * local_pos.y
        + vec_dir * local_pos.z;

    let world_normal = normalize(right * local_normal.x + forward * local_normal.y + vec_dir * local_normal.z);

    out.clip_position = camera.view_proj * vec4<f32>(world_pos, 1.0);
    out.normal = world_normal;
    out.color = vector_uniforms.color.rgb;
    out.world_position = world_pos;

    return out;
}

@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
    // Slice plane culling
    for (var i = 0u; i < 4u; i = i + 1u) {
        let plane = slice_planes.planes[i];
        if (plane.enabled > 0.5) {
            let dist = dot(in.world_position - plane.origin, plane.normal);
            if (dist < 0.0) {
                discard;
            }
        }
    }

    // Matcap lighting: transform world-space normal to view space
    let view_normal = normalize((camera.view * vec4<f32>(normalize(in.normal), 0.0)).xyz);
    let lit_color = light_surface_matcap(view_normal, in.color);

    return vec4<f32>(lit_color, 1.0);
}