1use std::path::{Path, PathBuf};
4use std::sync::Arc;
5use wgpu::{Device, ShaderModule};
6
/// Identifies one of the shaders embedded in this module.
///
/// Each variant is mapped to its WGSL text by [`builtin_shader_source`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BuiltinShader {
    /// Resolved to [`BASIC_SHADER`].
    Basic,
    /// Resolved to [`LIT_SHADER`].
    Lit,
    /// Resolved to [`UNLIT_SHADER`].
    Unlit,
    /// Resolved to [`SKY_GRADIENT_SHADER`].
    SkyGradient,
    /// Resolved to [`SPRITE_UI_SHADER`].
    SpriteUi,
    /// Resolved to [`FALLBACK_SHADER`].
    Fallback,
}
16
17#[derive(Debug, Clone)]
18pub enum ShaderSource {
19 Builtin(BuiltinShader),
20 Wgsl(&'static str),
21 WgslOwned(String),
22 File(PathBuf),
23}
24
25#[derive(thiserror::Error, Debug)]
26pub enum ShaderSourceError {
27 #[error("Failed to read shader file '{path}': {source}")]
28 Io {
29 path: String,
30 source: std::io::Error,
31 },
32}
33
34pub fn builtin_shader_source(shader: BuiltinShader) -> &'static str {
35 match shader {
36 BuiltinShader::Basic => BASIC_SHADER,
37 BuiltinShader::Lit => LIT_SHADER,
38 BuiltinShader::Unlit => UNLIT_SHADER,
39 BuiltinShader::SkyGradient => SKY_GRADIENT_SHADER,
40 BuiltinShader::SpriteUi => SPRITE_UI_SHADER,
41 BuiltinShader::Fallback => FALLBACK_SHADER,
42 }
43}
44
45pub fn load_shader_source(source: &ShaderSource) -> Result<String, ShaderSourceError> {
46 match source {
47 ShaderSource::Builtin(kind) => Ok(builtin_shader_source(*kind).to_string()),
48 ShaderSource::Wgsl(src) => Ok((*src).to_string()),
49 ShaderSource::WgslOwned(src) => Ok(src.clone()),
50 ShaderSource::File(path) => {
51 std::fs::read_to_string(path).map_err(|source| ShaderSourceError::Io {
52 path: path.display().to_string(),
53 source,
54 })
55 }
56 }
57}
58
59pub fn load_shader_source_from_path(path: impl AsRef<Path>) -> Result<String, ShaderSourceError> {
60 load_shader_source(&ShaderSource::File(path.as_ref().to_path_buf()))
61}
62
63pub fn load_wgsl(device: &Arc<Device>, source: &'static str) -> ShaderModule {
64 device.create_shader_module(wgpu::ShaderModuleDescriptor {
65 label: None,
66 source: wgpu::ShaderSource::Wgsl(source.into()),
67 })
68}
69
70pub const BASIC_SHADER: &str = include_str!("../shaders/basic.wgsl");
71pub const LIT_SHADER: &str = include_str!("../shaders/lit.wgsl");
72pub const UNLIT_SHADER: &str = include_str!("../shaders/unlit.wgsl");
73pub const SKY_GRADIENT_SHADER: &str = include_str!("../shaders/sky_gradient.wgsl");
74pub const SPRITE_UI_SHADER: &str = include_str!("../shaders/sprite_ui.wgsl");
75pub const FALLBACK_SHADER: &str = include_str!("../shaders/fallback.wgsl");