compute_shader_game_of_life/
compute_shader_game_of_life.rs

1//! A compute shader that simulates Conway's Game of Life.
2//!
//! Compute shaders use the GPU to compute arbitrary data, which may be independent of what
//! is rendered to the screen.
5
6use bevy::{
7    asset::RenderAssetUsages,
8    prelude::*,
9    render::{
10        extract_resource::{ExtractResource, ExtractResourcePlugin},
11        render_asset::RenderAssets,
12        render_graph::{self, RenderGraph, RenderLabel},
13        render_resource::{
14            binding_types::{texture_storage_2d, uniform_buffer},
15            *,
16        },
17        renderer::{RenderContext, RenderDevice, RenderQueue},
18        texture::GpuImage,
19        Render, RenderApp, RenderStartup, RenderSystems,
20    },
21    shader::PipelineCacheError,
22};
23use std::borrow::Cow;
24
/// Path of the compute shader, relative to the `assets` directory.
const SHADER_ASSET_PATH: &str = "shaders/game_of_life.wgsl";

/// On-screen scale: each simulation cell is drawn as a `DISPLAY_FACTOR` x `DISPLAY_FACTOR`
/// block of pixels (the window is `SIZE * DISPLAY_FACTOR` and the sprite is scaled by it).
const DISPLAY_FACTOR: u32 = 4;
/// Simulation grid size in cells, which is also the size of both simulation textures:
/// a 1280x720 window divided by `DISPLAY_FACTOR`, i.e. 320x180.
///
/// NOTE: 180 is not a multiple of `WORKGROUP_SIZE`, so any workgroup count derived from
/// this must be rounded up to cover every row.
const SIZE: UVec2 = UVec2::new(1280 / DISPLAY_FACTOR, 720 / DISPLAY_FACTOR);
/// Edge length (in invocations) of one compute workgroup along x and y. Presumably must match
/// the `@workgroup_size` declared in the WGSL shader — confirm against the shader source.
const WORKGROUP_SIZE: u32 = 8;
31
32fn main() {
33    App::new()
34        .insert_resource(ClearColor(Color::BLACK))
35        .add_plugins((
36            DefaultPlugins
37                .set(WindowPlugin {
38                    primary_window: Some(Window {
39                        resolution: (SIZE * DISPLAY_FACTOR).into(),
40                        // uncomment for unthrottled FPS
41                        // present_mode: bevy::window::PresentMode::AutoNoVsync,
42                        ..default()
43                    }),
44                    ..default()
45                })
46                .set(ImagePlugin::default_nearest()),
47            GameOfLifeComputePlugin,
48        ))
49        .add_systems(Startup, setup)
50        .add_systems(Update, switch_textures)
51        .run();
52}
53
54fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
55    let mut image = Image::new_target_texture(SIZE.x, SIZE.y, TextureFormat::Rgba32Float);
56    image.asset_usage = RenderAssetUsages::RENDER_WORLD;
57    image.texture_descriptor.usage =
58        TextureUsages::COPY_DST | TextureUsages::STORAGE_BINDING | TextureUsages::TEXTURE_BINDING;
59    let image0 = images.add(image.clone());
60    let image1 = images.add(image);
61
62    commands.spawn((
63        Sprite {
64            image: image0.clone(),
65            custom_size: Some(SIZE.as_vec2()),
66            ..default()
67        },
68        Transform::from_scale(Vec3::splat(DISPLAY_FACTOR as f32)),
69    ));
70    commands.spawn(Camera2d);
71
72    commands.insert_resource(GameOfLifeImages {
73        texture_a: image0,
74        texture_b: image1,
75    });
76
77    commands.insert_resource(GameOfLifeUniforms {
78        alive_color: LinearRgba::RED,
79    });
80}
81
82// Switch texture to display every frame to show the one that was written to most recently.
83fn switch_textures(images: Res<GameOfLifeImages>, mut sprite: Single<&mut Sprite>) {
84    if sprite.image == images.texture_a {
85        sprite.image = images.texture_b.clone();
86    } else {
87        sprite.image = images.texture_a.clone();
88    }
89}
90
/// Plugin that wires the Game of Life compute simulation into the render app.
struct GameOfLifeComputePlugin;

/// Render-graph label identifying the Game of Life compute node.
#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
struct GameOfLifeLabel;
95
96impl Plugin for GameOfLifeComputePlugin {
97    fn build(&self, app: &mut App) {
98        // Extract the game of life image resource from the main world into the render world
99        // for operation on by the compute shader and display on the sprite.
100        app.add_plugins((
101            ExtractResourcePlugin::<GameOfLifeImages>::default(),
102            ExtractResourcePlugin::<GameOfLifeUniforms>::default(),
103        ));
104        let render_app = app.sub_app_mut(RenderApp);
105        render_app
106            .add_systems(RenderStartup, init_game_of_life_pipeline)
107            .add_systems(
108                Render,
109                prepare_bind_group.in_set(RenderSystems::PrepareBindGroups),
110            );
111
112        let mut render_graph = render_app.world_mut().resource_mut::<RenderGraph>();
113        render_graph.add_node(GameOfLifeLabel, GameOfLifeNode::default());
114        render_graph.add_node_edge(GameOfLifeLabel, bevy::render::graph::CameraDriverLabel);
115    }
116}
117
/// Handles to the two simulation textures. Each compute step reads one and writes the other,
/// and the two swap roles between steps. Extracted into the render world via `ExtractResource`.
#[derive(Resource, Clone, ExtractResource)]
struct GameOfLifeImages {
    texture_a: Handle<Image>,
    texture_b: Handle<Image>,
}

/// Uniform data passed to the compute shader; bound as the third entry of each bind group.
#[derive(Resource, Clone, ExtractResource, ShaderType)]
struct GameOfLifeUniforms {
    // Colour for live cells — presumably what the shader writes for alive cells; confirm
    // against the WGSL source.
    alive_color: LinearRgba,
}

/// The two per-frame ping-pong bind groups: index 0 reads `texture_a` and writes `texture_b`,
/// index 1 reads `texture_b` and writes `texture_a`.
#[derive(Resource)]
struct GameOfLifeImageBindGroups([BindGroup; 2]);
131
132fn prepare_bind_group(
133    mut commands: Commands,
134    pipeline: Res<GameOfLifePipeline>,
135    gpu_images: Res<RenderAssets<GpuImage>>,
136    game_of_life_images: Res<GameOfLifeImages>,
137    game_of_life_uniforms: Res<GameOfLifeUniforms>,
138    render_device: Res<RenderDevice>,
139    queue: Res<RenderQueue>,
140) {
141    let view_a = gpu_images.get(&game_of_life_images.texture_a).unwrap();
142    let view_b = gpu_images.get(&game_of_life_images.texture_b).unwrap();
143
144    // Uniform buffer is used here to demonstrate how to set up a uniform in a compute shader
145    // Alternatives such as storage buffers or push constants may be more suitable for your use case
146    let mut uniform_buffer = UniformBuffer::from(game_of_life_uniforms.into_inner());
147    uniform_buffer.write_buffer(&render_device, &queue);
148
149    let bind_group_0 = render_device.create_bind_group(
150        None,
151        &pipeline.texture_bind_group_layout,
152        &BindGroupEntries::sequential((
153            &view_a.texture_view,
154            &view_b.texture_view,
155            &uniform_buffer,
156        )),
157    );
158    let bind_group_1 = render_device.create_bind_group(
159        None,
160        &pipeline.texture_bind_group_layout,
161        &BindGroupEntries::sequential((
162            &view_b.texture_view,
163            &view_a.texture_view,
164            &uniform_buffer,
165        )),
166    );
167    commands.insert_resource(GameOfLifeImageBindGroups([bind_group_0, bind_group_1]));
168}
169
/// Render-world resource created by `init_game_of_life_pipeline`. Holds the bind group layout
/// shared by both compute pipelines and the pipeline-cache IDs of those pipelines.
#[derive(Resource)]
struct GameOfLifePipeline {
    /// Three compute-stage bindings: read-only RGBA32F storage texture, write-only RGBA32F
    /// storage texture, and the `GameOfLifeUniforms` uniform buffer.
    texture_bind_group_layout: BindGroupLayout,
    /// Pipeline for the shader's `init` entry point (presumably seeds the initial grid).
    init_pipeline: CachedComputePipelineId,
    /// Pipeline for the shader's `update` entry point (presumably advances one generation).
    update_pipeline: CachedComputePipelineId,
}
176
177fn init_game_of_life_pipeline(
178    mut commands: Commands,
179    render_device: Res<RenderDevice>,
180    asset_server: Res<AssetServer>,
181    pipeline_cache: Res<PipelineCache>,
182) {
183    let texture_bind_group_layout = render_device.create_bind_group_layout(
184        "GameOfLifeImages",
185        &BindGroupLayoutEntries::sequential(
186            ShaderStages::COMPUTE,
187            (
188                texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::ReadOnly),
189                texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::WriteOnly),
190                uniform_buffer::<GameOfLifeUniforms>(false),
191            ),
192        ),
193    );
194    let shader = asset_server.load(SHADER_ASSET_PATH);
195    let init_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
196        layout: vec![texture_bind_group_layout.clone()],
197        shader: shader.clone(),
198        entry_point: Some(Cow::from("init")),
199        ..default()
200    });
201    let update_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
202        layout: vec![texture_bind_group_layout.clone()],
203        shader,
204        entry_point: Some(Cow::from("update")),
205        ..default()
206    });
207
208    commands.insert_resource(GameOfLifePipeline {
209        texture_bind_group_layout,
210        init_pipeline,
211        update_pipeline,
212    });
213}
214
/// Stage of the compute node's state machine.
#[derive(Default)]
enum GameOfLifeState {
    /// Waiting for the `init` pipeline to become ready; nothing is dispatched.
    #[default]
    Loading,
    /// Dispatching the `init` pipeline while waiting for the `update` pipeline.
    Init,
    /// Dispatching the `update` pipeline with bind group `0` or `1`, flipping every frame.
    Update(usize),
}

/// Render-graph node that drives the simulation; starts in `GameOfLifeState::Loading`.
#[derive(Default)]
struct GameOfLifeNode {
    state: GameOfLifeState,
}
232
233impl render_graph::Node for GameOfLifeNode {
234    fn update(&mut self, world: &mut World) {
235        let pipeline = world.resource::<GameOfLifePipeline>();
236        let pipeline_cache = world.resource::<PipelineCache>();
237
238        // if the corresponding pipeline has loaded, transition to the next stage
239        match self.state {
240            GameOfLifeState::Loading => {
241                match pipeline_cache.get_compute_pipeline_state(pipeline.init_pipeline) {
242                    CachedPipelineState::Ok(_) => {
243                        self.state = GameOfLifeState::Init;
244                    }
245                    // If the shader hasn't loaded yet, just wait.
246                    CachedPipelineState::Err(PipelineCacheError::ShaderNotLoaded(_)) => {}
247                    CachedPipelineState::Err(err) => {
248                        panic!("Initializing assets/{SHADER_ASSET_PATH}:\n{err}")
249                    }
250                    _ => {}
251                }
252            }
253            GameOfLifeState::Init => {
254                if let CachedPipelineState::Ok(_) =
255                    pipeline_cache.get_compute_pipeline_state(pipeline.update_pipeline)
256                {
257                    self.state = GameOfLifeState::Update(1);
258                }
259            }
260            GameOfLifeState::Update(0) => {
261                self.state = GameOfLifeState::Update(1);
262            }
263            GameOfLifeState::Update(1) => {
264                self.state = GameOfLifeState::Update(0);
265            }
266            GameOfLifeState::Update(_) => unreachable!(),
267        }
268    }
269
270    fn run(
271        &self,
272        _graph: &mut render_graph::RenderGraphContext,
273        render_context: &mut RenderContext,
274        world: &World,
275    ) -> Result<(), render_graph::NodeRunError> {
276        let bind_groups = &world.resource::<GameOfLifeImageBindGroups>().0;
277        let pipeline_cache = world.resource::<PipelineCache>();
278        let pipeline = world.resource::<GameOfLifePipeline>();
279
280        let mut pass = render_context
281            .command_encoder()
282            .begin_compute_pass(&ComputePassDescriptor::default());
283
284        // select the pipeline based on the current state
285        match self.state {
286            GameOfLifeState::Loading => {}
287            GameOfLifeState::Init => {
288                let init_pipeline = pipeline_cache
289                    .get_compute_pipeline(pipeline.init_pipeline)
290                    .unwrap();
291                pass.set_bind_group(0, &bind_groups[0], &[]);
292                pass.set_pipeline(init_pipeline);
293                pass.dispatch_workgroups(SIZE.x / WORKGROUP_SIZE, SIZE.y / WORKGROUP_SIZE, 1);
294            }
295            GameOfLifeState::Update(index) => {
296                let update_pipeline = pipeline_cache
297                    .get_compute_pipeline(pipeline.update_pipeline)
298                    .unwrap();
299                pass.set_bind_group(0, &bind_groups[index], &[]);
300                pass.set_pipeline(update_pipeline);
301                pass.dispatch_workgroups(SIZE.x / WORKGROUP_SIZE, SIZE.y / WORKGROUP_SIZE, 1);
302            }
303        }
304
305        Ok(())
306    }
307}