1#![cfg_attr(
2 target_arch = "spirv",
3 no_std,
4 feature(register_attr, lang_items),
5 register_attr(spirv)
6)]
7
8use crate::rand::DefaultRng;
9use camera::Camera;
10use hittable::HitRecord;
11use material::{Material, Scatter};
12use ray::Ray;
13use spirv_std::glam::{vec3, UVec3, Vec3, Vec4};
14#[cfg(not(target_arch = "spirv"))]
15use spirv_std::macros::spirv;
16#[allow(unused_imports)]
17use spirv_std::num_traits::Float;
18use spirv_std::num_traits::FloatConst;
19
20use bytemuck::{Pod, Zeroable};
21
22pub mod aabb;
23pub mod bool;
24pub mod bvh;
25pub mod camera;
26pub mod hittable;
27pub mod material;
28pub mod math;
29pub mod pod;
30pub mod rand;
31pub mod ray;
32pub mod sphere;
33
34#[derive(Copy, Clone, Pod, Zeroable)]
35#[repr(C)]
36pub struct ShaderConstants {
37 pub width: u32,
38 pub height: u32,
39 pub seed: u32,
40}
41
42fn ray_color(
66 mut ray: Ray,
67 world: &[sphere::Sphere],
68 bvh: &[bvh::BVHNode],
69 rng: &mut DefaultRng,
70) -> Vec3 {
71 let mut color = vec3(1.0, 1.0, 1.0);
72 let mut hit_record = HitRecord::default();
73 let mut scatter = Scatter::default();
74
75 for _ in 0..50 {
76 if (bvh::BVH { nodes: bvh })
77 .hit(&ray, 0.001, f32::INFINITY, &mut hit_record, world)
78 .into()
79 {
80 let material = hit_record.material;
81
82 if material
83 .scatter(&ray, &hit_record, rng, &mut scatter)
84 .into()
85 {
86 color *= scatter.color;
87 ray = scatter.ray;
88 } else {
89 break;
90 }
91 } else {
92 let unit_direction = ray.direction.normalize();
93 let t = 0.5 * (unit_direction.y + 1.0);
94 color *= vec3(1.0, 1.0, 1.0).lerp(vec3(0.5, 0.7, 1.0), t);
95 break;
96 };
97 }
98
99 color
100}
101
/// Workgroup size along X.
///
/// NOTE(review): presumably mirrors the first value of `threads(8, 8, 1)` on
/// `main_cs` for use by the dispatching host code — confirm and keep in sync.
pub const NUM_THREADS_X: u32 = 8;
/// Workgroup size along Y.
///
/// NOTE(review): presumably mirrors the second value of `threads(8, 8, 1)` on
/// `main_cs` for use by the dispatching host code — confirm and keep in sync.
pub const NUM_THREADS_Y: u32 = 8;
104
105#[spirv(compute(threads(8, 8, 1)))]
106pub fn main_cs(
107 #[spirv(global_invocation_id)] id: UVec3,
108 #[spirv(push_constant)] constants: &ShaderConstants,
109 #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] world: &[sphere::Sphere],
110 #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] bvh: &[bvh::BVHNode],
111 #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] out: &mut [Vec4],
112) {
113 let x = id.x;
114 let y = id.y;
115
116 if x >= constants.width {
117 return;
118 }
119
120 if y >= constants.height {
121 return;
122 }
123
124 let seed = constants.seed ^ (constants.width * y + x);
125 let mut rng = DefaultRng::new(seed);
126
127 let camera = Camera::new(
128 vec3(13.0, 2.0, 3.0),
129 vec3(0.0, 0.0, 0.0),
130 vec3(0.0, 1.0, 0.0),
131 20.0 / 180.0 * f32::PI(),
132 constants.width as f32 / constants.height as f32,
133 0.1,
134 10.0,
135 0.0,
136 1.0,
137 );
138
139 let u = (x as f32 + rng.next_f32()) / (constants.width - 1) as f32;
140 let v = (y as f32 + rng.next_f32()) / (constants.height - 1) as f32;
141
142 let ray = camera.get_ray(u, v, &mut rng);
143 let color = ray_color(ray, world, bvh, &mut rng);
144
145 out[((constants.height - y - 1) * constants.width + x) as usize] += color.extend(1.0);
146}