oxihuman_export/
compute_shader_export.rs1#![allow(dead_code)]
4
/// Compute API associated with an exported compute shader.
///
/// Derives `Debug`, `Eq` and `Hash` in addition to `Clone`, `Copy` and
/// `PartialEq` so values can be printed in diagnostics, compared with
/// `assert_eq!`, and used as keys in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeApi {
    /// Reported as `"WebGPU"` by [`ComputeApi::name`].
    WebGpu,
    /// Reported as `"Vulkan"` by [`ComputeApi::name`].
    Vulkan,
    /// Reported as `"Metal"` by [`ComputeApi::name`].
    Metal,
    /// Reported as `"DirectX 12"` by [`ComputeApi::name`].
    DirectX12,
}

impl ComputeApi {
    /// Returns the human-readable name of this API.
    ///
    /// The returned string is a static literal; note that it is not always
    /// identical to the variant identifier (`DirectX12` yields `"DirectX 12"`).
    pub fn name(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a name to be chosen.
        match self {
            Self::WebGpu => "WebGPU",
            Self::Vulkan => "Vulkan",
            Self::Metal => "Metal",
            Self::DirectX12 => "DirectX 12",
        }
    }
}
26
/// Group sizes along the x, y and z axes used when dispatching a compute shader.
///
/// This is plain `u32` data, so it derives `Debug`, `Clone`, `Copy`,
/// `PartialEq` and `Eq` to allow printing, cheap copying and comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DispatchConfig {
    /// Group size along the x axis.
    pub group_size_x: u32,
    /// Group size along the y axis.
    pub group_size_y: u32,
    /// Group size along the z axis.
    pub group_size_z: u32,
}

impl Default for DispatchConfig {
    /// Returns a one-dimensional configuration of `64 x 1 x 1`.
    fn default() -> Self {
        Self {
            group_size_x: 64,
            group_size_y: 1,
            group_size_z: 1,
        }
    }
}
43
44pub struct ComputeShaderExport {
46 pub api: ComputeApi,
47 pub source: String,
48 pub entry_point: String,
49 pub dispatch: DispatchConfig,
50 pub bindings: Vec<String>,
51}
52
53pub fn new_compute_shader_export(api: ComputeApi, entry: &str) -> ComputeShaderExport {
55 ComputeShaderExport {
56 api,
57 source: String::new(),
58 entry_point: entry.to_string(),
59 dispatch: DispatchConfig::default(),
60 bindings: Vec::new(),
61 }
62}
63
64pub fn set_compute_source(exp: &mut ComputeShaderExport, src: &str) {
66 exp.source = src.to_string();
67}
68
69pub fn add_compute_binding(exp: &mut ComputeShaderExport, binding: &str) {
71 exp.bindings.push(binding.to_string());
72}
73
74pub fn compute_binding_count(exp: &ComputeShaderExport) -> usize {
76 exp.bindings.len()
77}
78
79pub fn compute_group_count(exp: &ComputeShaderExport, element_count: u32) -> u32 {
81 let gs = exp.dispatch.group_size_x.max(1);
82 element_count.div_ceil(gs)
83}
84
85pub fn validate_compute_shader(exp: &ComputeShaderExport) -> bool {
87 !exp.source.is_empty() && !exp.entry_point.is_empty()
88}
89
90pub fn render_compute_summary(exp: &ComputeShaderExport) -> String {
92 format!(
93 "API:{} Entry:{} Bindings:{} GroupSize:{}x{}x{}",
94 exp.api.name(),
95 exp.entry_point,
96 exp.bindings.len(),
97 exp.dispatch.group_size_x,
98 exp.dispatch.group_size_y,
99 exp.dispatch.group_size_z,
100 )
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created export has no source text.
    #[test]
    fn new_export_empty_source() {
        let export = new_compute_shader_export(ComputeApi::Vulkan, "cs_main");
        assert_eq!(export.source, "");
    }

    /// Setting a source makes the stored source non-empty.
    #[test]
    fn set_source_updates() {
        let mut export = new_compute_shader_export(ComputeApi::Metal, "main");
        set_compute_source(&mut export, "kernel void main(){}");
        assert_ne!(export.source, "");
    }

    /// Adding one binding yields a binding count of one.
    #[test]
    fn add_binding_increments() {
        let binding = "@group(0) @binding(0) var<storage> buf: array<f32>";
        let mut export = new_compute_shader_export(ComputeApi::WebGpu, "cs");
        add_compute_binding(&mut export, binding);
        let count = compute_binding_count(&export);
        assert_eq!(count, 1);
    }

    /// The DirectX 12 variant reports its name with a space.
    #[test]
    fn api_name_correct() {
        let name = ComputeApi::DirectX12.name();
        assert_eq!(name, "DirectX 12");
    }

    /// 128 elements at the default group size need exactly two groups.
    #[test]
    fn compute_group_count_correct() {
        let export = new_compute_shader_export(ComputeApi::Vulkan, "cs");
        assert_eq!(compute_group_count(&export, 128), 2);
    }

    /// 65 elements at the default group size round up to two groups.
    #[test]
    fn compute_group_count_ceiling() {
        let export = new_compute_shader_export(ComputeApi::Vulkan, "cs");
        assert_eq!(compute_group_count(&export, 65), 2);
    }

    /// Validation fails without a source and passes once one is set.
    #[test]
    fn validate_needs_source_and_entry() {
        let mut export = new_compute_shader_export(ComputeApi::Metal, "main");
        let valid_before = validate_compute_shader(&export);
        assert!(!valid_before);

        set_compute_source(&mut export, "// code");
        let valid_after = validate_compute_shader(&export);
        assert!(valid_after);
    }

    /// The rendered summary mentions the API name.
    #[test]
    fn render_summary_contains_api() {
        let export = new_compute_shader_export(ComputeApi::WebGpu, "main");
        let summary = render_compute_summary(&export);
        assert!(summary.contains("WebGPU"));
    }

    /// New exports use a default x group size of 64.
    #[test]
    fn default_dispatch_group_size() {
        let export = new_compute_shader_export(ComputeApi::Vulkan, "cs");
        let DispatchConfig { group_size_x, .. } = export.dispatch;
        assert_eq!(group_size_x, 64);
    }
}