// ringkernel_wgpu_codegen/handler.rs
/// Options for handler code generation.
///
/// Both flags are enabled by [`HandlerCodegenConfig::new`] and by
/// [`Default::default`]; the `without_*` builders switch them off one at a
/// time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandlerCodegenConfig {
    /// Flag requesting that context methods be inlined. The code that
    /// consumes this flag is outside this block.
    pub inline_context_methods: bool,
    /// Flag requesting bounds checking. The code that consumes this flag is
    /// outside this block.
    pub bounds_checking: bool,
}

impl Default for HandlerCodegenConfig {
    /// Returns the same value as [`HandlerCodegenConfig::new`].
    ///
    /// A derived `Default` would set both flags to `false`, disagreeing with
    /// `new()` and turning bounds checking off for anyone who writes
    /// `..Default::default()`. Delegating keeps the two constructors in sync.
    fn default() -> Self {
        Self::new()
    }
}

impl HandlerCodegenConfig {
    /// Creates a configuration with both `inline_context_methods` and
    /// `bounds_checking` set to `true`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inline_context_methods: true,
            bounds_checking: true,
        }
    }

    /// Returns the configuration with `inline_context_methods` set to `false`.
    ///
    /// Consumes and returns `self` so calls can be chained.
    #[must_use]
    pub fn without_inlining(mut self) -> Self {
        self.inline_context_methods = false;
        self
    }

    /// Returns the configuration with `bounds_checking` set to `false`.
    ///
    /// Consumes and returns `self` so calls can be chained.
    #[must_use]
    pub fn without_bounds_checking(mut self) -> Self {
        self.bounds_checking = false;
        self
    }
}
35
/// Description of a handler function's signature.
///
/// Derives `PartialEq`/`Eq` (together with the parameter and return types
/// below) so whole signatures can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandlerSignature {
    /// Name of the handler.
    pub name: String,
    /// The handler's parameters.
    pub params: Vec<HandlerParam>,
    /// What the handler returns.
    pub return_type: HandlerReturnType,
    /// Whether the handler has a context.
    // NOTE(review): presumably true when some entry of `params` has kind
    // `HandlerParamKind::Context` — confirm against the code that builds this.
    pub has_context: bool,
    /// Optional position associated with the message parameter.
    // NOTE(review): presumably an index into `params` — confirm.
    pub message_param: Option<usize>,
}

/// A single parameter of a handler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandlerParam {
    /// Parameter name.
    pub name: String,
    /// Role the parameter plays; see [`HandlerParamKind`].
    pub kind: HandlerParamKind,
    /// The parameter's WGSL type, stored as text.
    pub wgsl_type: String,
}

/// The role of a handler parameter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandlerParamKind {
    /// A context parameter.
    Context,
    /// A message parameter.
    Message,
    /// A buffer parameter, with a flag recording whether it is mutable.
    Buffer { mutable: bool },
    /// A scalar parameter.
    Scalar,
}

/// What a handler returns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandlerReturnType {
    /// No return value.
    Unit,
    /// A value; the payload is a string.
    // NOTE(review): presumably the type name of the value — confirm.
    Value(String),
    /// A message; the payload is a string.
    // NOTE(review): presumably the message type name — confirm.
    Message(String),
}
85
/// Context methods that have a WGSL rendering.
///
/// [`WgslContextMethod::to_wgsl`] gives the WGSL text for each variant and
/// [`WgslContextMethod::from_name`] maps source-level method names onto the
/// non-atomic variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WgslContextMethod {
    /// Rendered as `local_invocation_id.x`.
    LocalId,
    /// Rendered as `global_invocation_id.x`.
    GlobalId,
    /// Rendered as `workgroup_id.x`.
    WorkgroupId,
    /// Rendered as the call `workgroupBarrier()`.
    WorkgroupBarrier,
    /// Rendered as the call `storageBarrier()`.
    StorageBarrier,
    /// Rendered as the bare function name `atomicAdd`.
    AtomicAdd,
    /// Rendered as the bare function name `atomicLoad`.
    AtomicLoad,
    /// Rendered as the bare function name `atomicStore`.
    AtomicStore,
}

impl WgslContextMethod {
    /// Returns the WGSL text for this method.
    ///
    /// The id variants yield an expression, the barrier variants yield a
    /// complete call including `()`, and the atomic variants yield only the
    /// function name without parentheses or arguments.
    #[must_use]
    pub fn to_wgsl(&self) -> &'static str {
        match *self {
            Self::LocalId => "local_invocation_id.x",
            Self::GlobalId => "global_invocation_id.x",
            Self::WorkgroupId => "workgroup_id.x",
            Self::WorkgroupBarrier => "workgroupBarrier()",
            Self::StorageBarrier => "storageBarrier()",
            Self::AtomicAdd => "atomicAdd",
            Self::AtomicLoad => "atomicLoad",
            Self::AtomicStore => "atomicStore",
        }
    }

    /// Looks up a context method by its source-level name.
    ///
    /// Recognized names (two aliases per method):
    ///
    /// | names                              | variant            |
    /// |------------------------------------|--------------------|
    /// | `thread_id`, `local_id`            | `LocalId`          |
    /// | `global_thread_id`, `global_id`    | `GlobalId`         |
    /// | `workgroup_id`, `block_id`         | `WorkgroupId`      |
    /// | `sync_threads`, `barrier`          | `WorkgroupBarrier` |
    /// | `thread_fence`, `storage_fence`    | `StorageBarrier`   |
    ///
    /// Any other name returns `None`. Matching is exact and case-sensitive.
    /// No name maps to the atomic variants.
    // NOTE(review): the atomic variants are presumably resolved by a separate
    // path because they take arguments — confirm before adding aliases here.
    #[must_use]
    pub fn from_name(name: &str) -> Option<Self> {
        let method = match name {
            "thread_id" | "local_id" => Self::LocalId,
            "global_thread_id" | "global_id" => Self::GlobalId,
            "workgroup_id" | "block_id" => Self::WorkgroupId,
            "sync_threads" | "barrier" => Self::WorkgroupBarrier,
            "thread_fence" | "storage_fence" => Self::StorageBarrier,
            _ => return None,
        };
        Some(method)
    }
}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` enables both flags; chaining both `without_*` builders clears
    /// them again.
    #[test]
    fn test_handler_codegen_config() {
        let enabled = HandlerCodegenConfig::new();
        assert!(enabled.inline_context_methods);
        assert!(enabled.bounds_checking);

        let config = HandlerCodegenConfig::new()
            .without_inlining()
            .without_bounds_checking();

        assert!(!config.inline_context_methods);
        assert!(!config.bounds_checking);
    }

    /// Every alias accepted by `from_name` resolves to its variant, and an
    /// unrecognized name yields `None`.
    #[test]
    fn test_context_method_lookup() {
        let cases = [
            ("thread_id", WgslContextMethod::LocalId),
            ("local_id", WgslContextMethod::LocalId),
            ("global_thread_id", WgslContextMethod::GlobalId),
            ("global_id", WgslContextMethod::GlobalId),
            ("workgroup_id", WgslContextMethod::WorkgroupId),
            ("block_id", WgslContextMethod::WorkgroupId),
            ("sync_threads", WgslContextMethod::WorkgroupBarrier),
            ("barrier", WgslContextMethod::WorkgroupBarrier),
            ("thread_fence", WgslContextMethod::StorageBarrier),
            ("storage_fence", WgslContextMethod::StorageBarrier),
        ];
        for &(name, expected) in cases.iter() {
            assert_eq!(
                WgslContextMethod::from_name(name),
                Some(expected),
                "alias {:?}",
                name
            );
        }

        assert_eq!(WgslContextMethod::from_name("unknown"), None);
    }

    /// Every variant renders to its expected WGSL text.
    #[test]
    fn test_context_method_wgsl() {
        let cases = [
            (WgslContextMethod::LocalId, "local_invocation_id.x"),
            (WgslContextMethod::GlobalId, "global_invocation_id.x"),
            (WgslContextMethod::WorkgroupId, "workgroup_id.x"),
            (WgslContextMethod::WorkgroupBarrier, "workgroupBarrier()"),
            (WgslContextMethod::StorageBarrier, "storageBarrier()"),
            (WgslContextMethod::AtomicAdd, "atomicAdd"),
            (WgslContextMethod::AtomicLoad, "atomicLoad"),
            (WgslContextMethod::AtomicStore, "atomicStore"),
        ];
        for &(method, expected) in cases.iter() {
            assert_eq!(method.to_wgsl(), expected, "variant {:?}", method);
        }
    }
}