comfy_core/
shaders.rs

1use std::sync::atomic::{AtomicU32, Ordering};
2
3use crate::*;
4
5#[derive(Debug)]
6pub struct ShaderMap {
7    pub shaders: HashMap<ShaderId, Shader>,
8    pub watched_paths: HashMap<String, ShaderId>,
9}
10
11impl ShaderMap {
12    pub fn new() -> Self {
13        Self { shaders: Default::default(), watched_paths: Default::default() }
14    }
15
16    pub fn get(&self, id: ShaderId) -> Option<&Shader> {
17        self.shaders.get(&id)
18    }
19
20    pub fn insert_shader(&mut self, id: ShaderId, shader: Shader) {
21        self.shaders.insert(id, shader);
22    }
23
24    pub fn exists(&self, id: ShaderId) -> bool {
25        self.shaders.contains_key(&id)
26    }
27}
28
29pub type UniformDefs = HashMap<String, UniformDef>;
30
31#[derive(Clone, Debug)]
32pub struct Shader {
33    pub id: ShaderId,
34    pub name: String,
35    pub source: String,
36    pub uniform_defs: UniformDefs,
37    pub bindings: HashMap<String, u32>,
38}
39
/// Opaque handle to a shader. The ID is exposed for debugging purposes.
/// If you set it manually, you're on your own :)
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct ShaderId(pub u64);

impl std::fmt::Display for ShaderId {
    /// Renders the handle as `ShaderId(<raw value>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ShaderId(raw) = *self;
        f.write_fmt(format_args!("ShaderId({})", raw))
    }
}
50
51#[derive(Clone, Debug, PartialEq, Eq)]
52pub struct ShaderInstance {
53    pub id: ShaderId,
54    pub uniforms: HashMap<String, Uniform>,
55}
56
/// Declaration of a single uniform: its WGSL type plus an optional default value.
#[derive(Debug, Clone)]
pub enum UniformDef {
    /// A scalar `f32` uniform with an optional default.
    F32(Option<f32>),
    /// Any other uniform: `wgsl_decl` is the WGSL type written into the declaration, and
    /// `default_data` optional raw default bytes.
    Custom { default_data: Option<Vec<u8>>, wgsl_decl: String },
}

impl UniformDef {
    /// Returns the WGSL type used when declaring this uniform.
    pub fn to_wgsl(&self) -> &str {
        match self {
            Self::Custom { wgsl_decl, .. } => wgsl_decl.as_str(),
            Self::F32(_) => "f32",
        }
    }
}
71
72#[derive(Clone, Debug, PartialEq, Eq, Hash)]
73pub enum Uniform {
74    F32(OrderedFloat<f32>),
75    Custom(Vec<u8>),
76}
77
78// static CURRENT_RENDER_TARGET: Lazy<AtomicRefCell<Option<RenderTargetId>>> =
79//     Lazy::new(|| AtomicRefCell::new(None));
80
81static CURRENT_RENDER_TARGET: AtomicU32 = AtomicU32::new(0);
82
83pub fn use_render_target(id: RenderTargetId) {
84    CURRENT_RENDER_TARGET.store(id.0, Ordering::SeqCst);
85    // *CURRENT_RENDER_TARGET.borrow_mut() = Some(id);
86}
87
88pub fn use_default_render_target() {
89    CURRENT_RENDER_TARGET.store(0, Ordering::SeqCst);
90}
91
92pub fn get_current_render_target() -> RenderTargetId {
93    RenderTargetId(CURRENT_RENDER_TARGET.load(Ordering::SeqCst))
94}
95
96/// Sets a `f32` uniform value by name. The uniform must exist in the shader.
97pub fn set_uniform_f32(name: impl Into<String>, value: f32) {
98    set_uniform(name, Uniform::F32(OrderedFloat(value)));
99}
100
101/// Creates a new shader and returns its ID. The `source` parameter should only contain the
102/// fragment function, as the rest of the shader is automatically generated.
103///
104/// `uniform_defs` specifies the uniforms that the shader will use. The keys are the uniform names
105/// that will be also automatically generated and can be directly used in the shader. Meaning users
106/// don't have to care about WGPU bindings/groups.
107///
108/// For example, if you have a uniform named `time`, you simply use it as `time` in the shader.
109///
110/// `ShaderMap` can be obtained from `EngineContext` as `c.renderer.shaders.borrow_mut()`
111pub fn create_shader(
112    shaders: &mut ShaderMap,
113    name: &str,
114    source: &str,
115    uniform_defs: UniformDefs,
116) -> Result<ShaderId> {
117    let id = gen_shader_id();
118
119    if !source.contains("@vertex") {
120        panic!(
121            "Missing @vertex function in shader passed to `create_shader`.
122
123             Did you forget to call `sprite_shader_from_fragment`?"
124        );
125    }
126
127    if shaders.exists(id) {
128        bail!("Shader with name '{}' already exists", name);
129    }
130
131    let bindings = uniform_defs_to_bindings(&uniform_defs);
132
133    shaders.insert_shader(id, Shader {
134        id,
135        name: format!("{} Shader", name),
136        source: build_shader_source(source, &bindings, &uniform_defs),
137        uniform_defs,
138        bindings,
139    });
140
141    Ok(id)
142}
143
144pub fn uniform_defs_to_bindings(
145    uniform_defs: &UniformDefs,
146) -> HashMap<String, u32> {
147    uniform_defs
148        .iter()
149        .sorted_by_key(|x| x.0)
150        .enumerate()
151        .map(|(i, (name, _))| (name.clone(), i as u32))
152        .collect::<HashMap<String, u32>>()
153}
154
/// Stores both a static source code for a shader as well as the path to its file in
/// development. This is used for automatic shader hot reloading.
///
/// Derives `Debug`/`Clone` like the other public types in this module.
#[derive(Clone, Debug)]
pub struct ReloadableShaderSource {
    /// Shader source embedded at build time.
    pub static_source: String,
    /// Path of the shader file on disk.
    pub path: String,
}
161
162pub fn build_shader_source(
163    fragment_source: &str,
164    bindings: &HashMap<String, u32>,
165    uniform_defs: &UniformDefs,
166) -> String {
167    let mut uniforms_src = String::new();
168
169    for (name, binding) in bindings.iter() {
170        let typ = uniform_defs.get(name).unwrap();
171
172        uniforms_src.push_str(&format!(
173            "@group(2) @binding({})
174            var<uniform> {}: {};",
175            binding,
176            name,
177            typ.to_wgsl()
178        ));
179    }
180
181    format!("{}\n{}", uniforms_src, fragment_source)
182}
183
/// Identifier of a render target, wrapping a raw `u32`.
///
/// `RenderTargetId::default()` is `RenderTargetId(0)`, the same value `use_default_render_target`
/// installs as the current target.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct RenderTargetId(pub u32);