compute_shader_game_of_life/
compute_shader_game_of_life.rs1use bevy::{
7 asset::RenderAssetUsages,
8 prelude::*,
9 render::{
10 extract_resource::{ExtractResource, ExtractResourcePlugin},
11 render_asset::RenderAssets,
12 render_graph::{self, RenderGraph, RenderLabel},
13 render_resource::{
14 binding_types::{texture_storage_2d, uniform_buffer},
15 *,
16 },
17 renderer::{RenderContext, RenderDevice, RenderQueue},
18 texture::GpuImage,
19 Render, RenderApp, RenderStartup, RenderSystems,
20 },
21 shader::PipelineCacheError,
22};
23use std::borrow::Cow;
24
25const SHADER_ASSET_PATH: &str = "shaders/game_of_life.wgsl";
27
28const DISPLAY_FACTOR: u32 = 4;
29const SIZE: UVec2 = UVec2::new(1280 / DISPLAY_FACTOR, 720 / DISPLAY_FACTOR);
30const WORKGROUP_SIZE: u32 = 8;
31
32fn main() {
33 App::new()
34 .insert_resource(ClearColor(Color::BLACK))
35 .add_plugins((
36 DefaultPlugins
37 .set(WindowPlugin {
38 primary_window: Some(Window {
39 resolution: (SIZE * DISPLAY_FACTOR).into(),
40 ..default()
43 }),
44 ..default()
45 })
46 .set(ImagePlugin::default_nearest()),
47 GameOfLifeComputePlugin,
48 ))
49 .add_systems(Startup, setup)
50 .add_systems(Update, switch_textures)
51 .run();
52}
53
54fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
55 let mut image = Image::new_target_texture(SIZE.x, SIZE.y, TextureFormat::Rgba32Float);
56 image.asset_usage = RenderAssetUsages::RENDER_WORLD;
57 image.texture_descriptor.usage =
58 TextureUsages::COPY_DST | TextureUsages::STORAGE_BINDING | TextureUsages::TEXTURE_BINDING;
59 let image0 = images.add(image.clone());
60 let image1 = images.add(image);
61
62 commands.spawn((
63 Sprite {
64 image: image0.clone(),
65 custom_size: Some(SIZE.as_vec2()),
66 ..default()
67 },
68 Transform::from_scale(Vec3::splat(DISPLAY_FACTOR as f32)),
69 ));
70 commands.spawn(Camera2d);
71
72 commands.insert_resource(GameOfLifeImages {
73 texture_a: image0,
74 texture_b: image1,
75 });
76
77 commands.insert_resource(GameOfLifeUniforms {
78 alive_color: LinearRgba::RED,
79 });
80}
81
82fn switch_textures(images: Res<GameOfLifeImages>, mut sprite: Single<&mut Sprite>) {
84 if sprite.image == images.texture_a {
85 sprite.image = images.texture_b.clone();
86 } else {
87 sprite.image = images.texture_a.clone();
88 }
89}
90
91struct GameOfLifeComputePlugin;
92
93#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
94struct GameOfLifeLabel;
95
96impl Plugin for GameOfLifeComputePlugin {
97 fn build(&self, app: &mut App) {
98 app.add_plugins((
101 ExtractResourcePlugin::<GameOfLifeImages>::default(),
102 ExtractResourcePlugin::<GameOfLifeUniforms>::default(),
103 ));
104 let render_app = app.sub_app_mut(RenderApp);
105 render_app
106 .add_systems(RenderStartup, init_game_of_life_pipeline)
107 .add_systems(
108 Render,
109 prepare_bind_group.in_set(RenderSystems::PrepareBindGroups),
110 );
111
112 let mut render_graph = render_app.world_mut().resource_mut::<RenderGraph>();
113 render_graph.add_node(GameOfLifeLabel, GameOfLifeNode::default());
114 render_graph.add_node_edge(GameOfLifeLabel, bevy::render::graph::CameraDriverLabel);
115 }
116}
117
118#[derive(Resource, Clone, ExtractResource)]
119struct GameOfLifeImages {
120 texture_a: Handle<Image>,
121 texture_b: Handle<Image>,
122}
123
124#[derive(Resource, Clone, ExtractResource, ShaderType)]
125struct GameOfLifeUniforms {
126 alive_color: LinearRgba,
127}
128
129#[derive(Resource)]
130struct GameOfLifeImageBindGroups([BindGroup; 2]);
131
132fn prepare_bind_group(
133 mut commands: Commands,
134 pipeline: Res<GameOfLifePipeline>,
135 gpu_images: Res<RenderAssets<GpuImage>>,
136 game_of_life_images: Res<GameOfLifeImages>,
137 game_of_life_uniforms: Res<GameOfLifeUniforms>,
138 render_device: Res<RenderDevice>,
139 queue: Res<RenderQueue>,
140) {
141 let view_a = gpu_images.get(&game_of_life_images.texture_a).unwrap();
142 let view_b = gpu_images.get(&game_of_life_images.texture_b).unwrap();
143
144 let mut uniform_buffer = UniformBuffer::from(game_of_life_uniforms.into_inner());
147 uniform_buffer.write_buffer(&render_device, &queue);
148
149 let bind_group_0 = render_device.create_bind_group(
150 None,
151 &pipeline.texture_bind_group_layout,
152 &BindGroupEntries::sequential((
153 &view_a.texture_view,
154 &view_b.texture_view,
155 &uniform_buffer,
156 )),
157 );
158 let bind_group_1 = render_device.create_bind_group(
159 None,
160 &pipeline.texture_bind_group_layout,
161 &BindGroupEntries::sequential((
162 &view_b.texture_view,
163 &view_a.texture_view,
164 &uniform_buffer,
165 )),
166 );
167 commands.insert_resource(GameOfLifeImageBindGroups([bind_group_0, bind_group_1]));
168}
169
170#[derive(Resource)]
171struct GameOfLifePipeline {
172 texture_bind_group_layout: BindGroupLayout,
173 init_pipeline: CachedComputePipelineId,
174 update_pipeline: CachedComputePipelineId,
175}
176
177fn init_game_of_life_pipeline(
178 mut commands: Commands,
179 render_device: Res<RenderDevice>,
180 asset_server: Res<AssetServer>,
181 pipeline_cache: Res<PipelineCache>,
182) {
183 let texture_bind_group_layout = render_device.create_bind_group_layout(
184 "GameOfLifeImages",
185 &BindGroupLayoutEntries::sequential(
186 ShaderStages::COMPUTE,
187 (
188 texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::ReadOnly),
189 texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::WriteOnly),
190 uniform_buffer::<GameOfLifeUniforms>(false),
191 ),
192 ),
193 );
194 let shader = asset_server.load(SHADER_ASSET_PATH);
195 let init_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
196 layout: vec![texture_bind_group_layout.clone()],
197 shader: shader.clone(),
198 entry_point: Some(Cow::from("init")),
199 ..default()
200 });
201 let update_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
202 layout: vec![texture_bind_group_layout.clone()],
203 shader,
204 entry_point: Some(Cow::from("update")),
205 ..default()
206 });
207
208 commands.insert_resource(GameOfLifePipeline {
209 texture_bind_group_layout,
210 init_pipeline,
211 update_pipeline,
212 });
213}
214
/// Progress of the compute node's per-frame state machine.
enum GameOfLifeState {
    /// Waiting for the `init` pipeline to become ready.
    Loading,
    /// The `init` pass is dispatched (bind group 0) until the `update` pipeline is ready.
    Init,
    /// The `update` pass is dispatched with ping-pong bind group `0` or `1`.
    Update(usize),
}
220
221struct GameOfLifeNode {
222 state: GameOfLifeState,
223}
224
225impl Default for GameOfLifeNode {
226 fn default() -> Self {
227 Self {
228 state: GameOfLifeState::Loading,
229 }
230 }
231}
232
233impl render_graph::Node for GameOfLifeNode {
234 fn update(&mut self, world: &mut World) {
235 let pipeline = world.resource::<GameOfLifePipeline>();
236 let pipeline_cache = world.resource::<PipelineCache>();
237
238 match self.state {
240 GameOfLifeState::Loading => {
241 match pipeline_cache.get_compute_pipeline_state(pipeline.init_pipeline) {
242 CachedPipelineState::Ok(_) => {
243 self.state = GameOfLifeState::Init;
244 }
245 CachedPipelineState::Err(PipelineCacheError::ShaderNotLoaded(_)) => {}
247 CachedPipelineState::Err(err) => {
248 panic!("Initializing assets/{SHADER_ASSET_PATH}:\n{err}")
249 }
250 _ => {}
251 }
252 }
253 GameOfLifeState::Init => {
254 if let CachedPipelineState::Ok(_) =
255 pipeline_cache.get_compute_pipeline_state(pipeline.update_pipeline)
256 {
257 self.state = GameOfLifeState::Update(1);
258 }
259 }
260 GameOfLifeState::Update(0) => {
261 self.state = GameOfLifeState::Update(1);
262 }
263 GameOfLifeState::Update(1) => {
264 self.state = GameOfLifeState::Update(0);
265 }
266 GameOfLifeState::Update(_) => unreachable!(),
267 }
268 }
269
270 fn run(
271 &self,
272 _graph: &mut render_graph::RenderGraphContext,
273 render_context: &mut RenderContext,
274 world: &World,
275 ) -> Result<(), render_graph::NodeRunError> {
276 let bind_groups = &world.resource::<GameOfLifeImageBindGroups>().0;
277 let pipeline_cache = world.resource::<PipelineCache>();
278 let pipeline = world.resource::<GameOfLifePipeline>();
279
280 let mut pass = render_context
281 .command_encoder()
282 .begin_compute_pass(&ComputePassDescriptor::default());
283
284 match self.state {
286 GameOfLifeState::Loading => {}
287 GameOfLifeState::Init => {
288 let init_pipeline = pipeline_cache
289 .get_compute_pipeline(pipeline.init_pipeline)
290 .unwrap();
291 pass.set_bind_group(0, &bind_groups[0], &[]);
292 pass.set_pipeline(init_pipeline);
293 pass.dispatch_workgroups(SIZE.x / WORKGROUP_SIZE, SIZE.y / WORKGROUP_SIZE, 1);
294 }
295 GameOfLifeState::Update(index) => {
296 let update_pipeline = pipeline_cache
297 .get_compute_pipeline(pipeline.update_pipeline)
298 .unwrap();
299 pass.set_bind_group(0, &bind_groups[index], &[]);
300 pass.set_pipeline(update_pipeline);
301 pass.dispatch_workgroups(SIZE.x / WORKGROUP_SIZE, SIZE.y / WORKGROUP_SIZE, 1);
302 }
303 }
304
305 Ok(())
306 }
307}