rukako_shader/
lib.rs

1#![cfg_attr(
2    target_arch = "spirv",
3    no_std,
4    feature(register_attr, lang_items),
5    register_attr(spirv)
6)]
7
8use crate::rand::DefaultRng;
9use camera::Camera;
10use hittable::HitRecord;
11use material::{Material, Scatter};
12use ray::Ray;
13use spirv_std::glam::{vec3, UVec3, Vec3, Vec4};
14#[cfg(not(target_arch = "spirv"))]
15use spirv_std::macros::spirv;
16#[allow(unused_imports)]
17use spirv_std::num_traits::Float;
18use spirv_std::num_traits::FloatConst;
19
20use bytemuck::{Pod, Zeroable};
21
22pub mod aabb;
23pub mod bool;
24pub mod bvh;
25pub mod camera;
26pub mod hittable;
27pub mod material;
28pub mod math;
29pub mod pod;
30pub mod rand;
31pub mod ray;
32pub mod sphere;
33
34#[derive(Copy, Clone, Pod, Zeroable)]
35#[repr(C)]
36pub struct ShaderConstants {
37    pub width: u32,
38    pub height: u32,
39    pub seed: u32,
40}
41
42/*
43fn hit(
44    ray: &Ray,
45    world: &[sphere::Sphere],
46    len: usize,
47    t_min: f32,
48    t_max: f32,
49    hit_record: &mut HitRecord,
50) -> Bool32 {
51    let mut closest_so_far = t_max;
52    let mut hit = Bool32::FALSE;
53
54    for i in 0..len {
55        if world[i].hit(ray, t_min, closest_so_far, hit_record).into() {
56            closest_so_far = hit_record.t;
57            hit = Bool32::TRUE;
58        }
59    }
60
61    hit
62}
63*/
64
65fn ray_color(
66    mut ray: Ray,
67    world: &[sphere::Sphere],
68    bvh: &[bvh::BVHNode],
69    rng: &mut DefaultRng,
70) -> Vec3 {
71    let mut color = vec3(1.0, 1.0, 1.0);
72    let mut hit_record = HitRecord::default();
73    let mut scatter = Scatter::default();
74
75    for _ in 0..50 {
76        if (bvh::BVH { nodes: bvh })
77            .hit(&ray, 0.001, f32::INFINITY, &mut hit_record, world)
78            .into()
79        {
80            let material = hit_record.material;
81
82            if material
83                .scatter(&ray, &hit_record, rng, &mut scatter)
84                .into()
85            {
86                color *= scatter.color;
87                ray = scatter.ray;
88            } else {
89                break;
90            }
91        } else {
92            let unit_direction = ray.direction.normalize();
93            let t = 0.5 * (unit_direction.y + 1.0);
94            color *= vec3(1.0, 1.0, 1.0).lerp(vec3(0.5, 0.7, 1.0), t);
95            break;
96        };
97    }
98
99    color
100}
101
102pub const NUM_THREADS_X: u32 = 8;
103pub const NUM_THREADS_Y: u32 = 8;
104
105#[spirv(compute(threads(/* NUM_THREADS_X */ 8, /* NUM_THREADS_Y */ 8, 1)))]
106pub fn main_cs(
107    #[spirv(global_invocation_id)] id: UVec3,
108    #[spirv(push_constant)] constants: &ShaderConstants,
109    #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] world: &[sphere::Sphere],
110    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] bvh: &[bvh::BVHNode],
111    #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] out: &mut [Vec4],
112) {
113    let x = id.x;
114    let y = id.y;
115
116    if x >= constants.width {
117        return;
118    }
119
120    if y >= constants.height {
121        return;
122    }
123
124    let seed = constants.seed ^ (constants.width * y + x);
125    let mut rng = DefaultRng::new(seed);
126
127    let camera = Camera::new(
128        vec3(13.0, 2.0, 3.0),
129        vec3(0.0, 0.0, 0.0),
130        vec3(0.0, 1.0, 0.0),
131        20.0 / 180.0 * f32::PI(),
132        constants.width as f32 / constants.height as f32,
133        0.1,
134        10.0,
135        0.0,
136        1.0,
137    );
138
139    let u = (x as f32 + rng.next_f32()) / (constants.width - 1) as f32;
140    let v = (y as f32 + rng.next_f32()) / (constants.height - 1) as f32;
141
142    let ray = camera.get_ray(u, v, &mut rng);
143    let color = ray_color(ray, world, bvh, &mut rng);
144
145    out[((constants.height - y - 1) * constants.width + x) as usize] += color.extend(1.0);
146}