Skip to main content

ringkernel_wgpu_codegen/
handler.rs

//! Handler function parsing and codegen configuration.
//!
//! Parses Rust handler functions and prepares them for WGSL transpilation.

/// Configuration for handler code generation.
///
/// Both options are enabled by default: [`HandlerCodegenConfig::new`] and
/// [`Default::default`] return the same configuration. Use the `without_*`
/// builder methods to switch individual features off.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandlerCodegenConfig {
    /// Whether to inline context method calls.
    pub inline_context_methods: bool,
    /// Whether to generate bounds checking.
    pub bounds_checking: bool,
}

impl Default for HandlerCodegenConfig {
    /// Same as [`HandlerCodegenConfig::new`]: inlining and bounds checking enabled.
    ///
    /// Implemented by hand because a derived `Default` would set every flag to
    /// `false`, contradicting the documented defaults of `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl HandlerCodegenConfig {
    /// Create a new configuration with defaults (inlining and bounds checking enabled).
    #[must_use]
    pub fn new() -> Self {
        Self {
            inline_context_methods: true,
            bounds_checking: true,
        }
    }

    /// Disable context method inlining.
    ///
    /// Consumes `self` and returns the updated configuration; discarding the
    /// return value leaves the caller's configuration unchanged.
    #[must_use]
    pub fn without_inlining(mut self) -> Self {
        self.inline_context_methods = false;
        self
    }

    /// Disable bounds checking.
    ///
    /// Consumes `self` and returns the updated configuration; discarding the
    /// return value leaves the caller's configuration unchanged.
    #[must_use]
    pub fn without_bounds_checking(mut self) -> Self {
        self.bounds_checking = false;
        self
    }
}
/// Parsed handler signature.
///
/// Derives `PartialEq`/`Eq` (as do the nested parameter and return types) so
/// signatures can be compared structurally, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandlerSignature {
    /// Function name.
    pub name: String,
    /// Parameters.
    pub params: Vec<HandlerParam>,
    /// Return type.
    pub return_type: HandlerReturnType,
    /// Whether the handler takes a context parameter.
    pub has_context: bool,
    /// Index of the message parameter (if any).
    ///
    /// NOTE(review): presumably an index into `params`; confirm against the parser.
    pub message_param: Option<usize>,
}

/// Handler parameter description.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandlerParam {
    /// Parameter name.
    pub name: String,
    /// Parameter kind.
    pub kind: HandlerParamKind,
    /// WGSL type string.
    pub wgsl_type: String,
}

/// Kind of handler parameter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandlerParamKind {
    /// RingContext parameter.
    Context,
    /// Message parameter.
    Message,
    /// Buffer parameter (slice).
    Buffer {
        /// Whether the buffer is mutable.
        mutable: bool,
    },
    /// Scalar value.
    Scalar,
}

/// Handler return type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandlerReturnType {
    /// No return value.
    Unit,
    /// Returns a value type.
    Value(String),
    /// Returns a message type.
    Message(String),
}
/// Context method that has a direct WGSL equivalent and can therefore be inlined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WgslContextMethod {
    /// Thread ID within the workgroup (`local_invocation_id.x`).
    LocalId,
    /// Global thread ID (`global_invocation_id.x`).
    GlobalId,
    /// Workgroup ID (`workgroup_id.x`).
    WorkgroupId,
    /// Workgroup barrier (`workgroupBarrier()`).
    WorkgroupBarrier,
    /// Storage barrier (`storageBarrier()`).
    StorageBarrier,
    /// Atomic add (`atomicAdd`).
    AtomicAdd,
    /// Atomic load (`atomicLoad`).
    AtomicLoad,
    /// Atomic store (`atomicStore`).
    AtomicStore,
}

impl WgslContextMethod {
    /// Returns the WGSL text emitted in place of this context method.
    ///
    /// ID accessors yield the `.x` component of the matching WGSL built-in.
    /// Barriers yield a complete call expression. Atomics yield only the bare
    /// WGSL function name, with no argument list.
    pub fn to_wgsl(&self) -> &'static str {
        match *self {
            Self::LocalId => "local_invocation_id.x",
            Self::GlobalId => "global_invocation_id.x",
            Self::WorkgroupId => "workgroup_id.x",
            Self::WorkgroupBarrier => "workgroupBarrier()",
            Self::StorageBarrier => "storageBarrier()",
            Self::AtomicAdd => "atomicAdd",
            Self::AtomicLoad => "atomicLoad",
            Self::AtomicStore => "atomicStore",
        }
    }

    /// Resolves a context method name to its variant; returns `None` for unknown names.
    ///
    /// Each supported method accepts two spellings (e.g. `thread_id` / `local_id`).
    /// Only the ID accessors and barriers are listed; the atomic variants are
    /// never produced by this lookup.
    pub fn from_name(name: &str) -> Option<Self> {
        const ALIASES: &[(&str, WgslContextMethod)] = &[
            ("thread_id", WgslContextMethod::LocalId),
            ("local_id", WgslContextMethod::LocalId),
            ("global_thread_id", WgslContextMethod::GlobalId),
            ("global_id", WgslContextMethod::GlobalId),
            ("workgroup_id", WgslContextMethod::WorkgroupId),
            ("block_id", WgslContextMethod::WorkgroupId),
            ("sync_threads", WgslContextMethod::WorkgroupBarrier),
            ("barrier", WgslContextMethod::WorkgroupBarrier),
            ("thread_fence", WgslContextMethod::StorageBarrier),
            ("storage_fence", WgslContextMethod::StorageBarrier),
        ];

        ALIASES
            .iter()
            .find(|(alias, _)| *alias == name)
            .map(|&(_, method)| method)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chaining both `without_*` builders switches every option off.
    #[test]
    fn test_handler_codegen_config() {
        let cfg = HandlerCodegenConfig::new()
            .without_inlining()
            .without_bounds_checking();

        assert!(!cfg.inline_context_methods, "inlining should be disabled");
        assert!(!cfg.bounds_checking, "bounds checking should be disabled");
    }

    /// Known names resolve to their variant; unknown names resolve to `None`.
    #[test]
    fn test_context_method_lookup() {
        let cases = [
            ("thread_id", Some(WgslContextMethod::LocalId)),
            ("sync_threads", Some(WgslContextMethod::WorkgroupBarrier)),
            ("unknown", None),
        ];

        for (name, expected) in cases {
            assert_eq!(
                WgslContextMethod::from_name(name),
                expected,
                "lookup of {:?}",
                name
            );
        }
    }

    /// Spot-checks the WGSL text emitted for an ID accessor and a barrier.
    #[test]
    fn test_context_method_wgsl() {
        let cases = [
            (WgslContextMethod::LocalId, "local_invocation_id.x"),
            (WgslContextMethod::WorkgroupBarrier, "workgroupBarrier()"),
        ];

        for (method, wgsl) in cases {
            assert_eq!(method.to_wgsl(), wgsl, "WGSL for {:?}", method);
        }
    }
}