Skip to main content

oxihuman_export/
compute_shader_export.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
5//! Generic compute shader stub export.
6
/// Target graphics/compute API that a compute shader export is written for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeApi {
    /// WebGPU.
    WebGpu,
    /// Vulkan.
    Vulkan,
    /// Apple Metal.
    Metal,
    /// Microsoft DirectX 12.
    DirectX12,
}

impl ComputeApi {
    /// Human-readable display name of the API (e.g. `"DirectX 12"`).
    #[must_use]
    pub fn name(&self) -> &'static str {
        match self {
            ComputeApi::WebGpu => "WebGPU",
            ComputeApi::Vulkan => "Vulkan",
            ComputeApi::Metal => "Metal",
            ComputeApi::DirectX12 => "DirectX 12",
        }
    }
}

/// Formats the API using its human-readable [`ComputeApi::name`].
impl std::fmt::Display for ComputeApi {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
26
/// Group dimensions used when dispatching a compute shader.
///
/// The default is `64 x 1 x 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DispatchConfig {
    /// Group size along the X axis.
    pub group_size_x: u32,
    /// Group size along the Y axis.
    pub group_size_y: u32,
    /// Group size along the Z axis.
    pub group_size_z: u32,
}

impl Default for DispatchConfig {
    /// Returns a `64 x 1 x 1` group configuration.
    fn default() -> Self {
        Self {
            group_size_x: 64,
            group_size_y: 1,
            group_size_z: 1,
        }
    }
}
43
44/// A generic compute shader export.
45pub struct ComputeShaderExport {
46    pub api: ComputeApi,
47    pub source: String,
48    pub entry_point: String,
49    pub dispatch: DispatchConfig,
50    pub bindings: Vec<String>,
51}
52
53/// Create a new compute shader export.
54pub fn new_compute_shader_export(api: ComputeApi, entry: &str) -> ComputeShaderExport {
55    ComputeShaderExport {
56        api,
57        source: String::new(),
58        entry_point: entry.to_string(),
59        dispatch: DispatchConfig::default(),
60        bindings: Vec::new(),
61    }
62}
63
64/// Set the shader source.
65pub fn set_compute_source(exp: &mut ComputeShaderExport, src: &str) {
66    exp.source = src.to_string();
67}
68
69/// Add a binding declaration.
70pub fn add_compute_binding(exp: &mut ComputeShaderExport, binding: &str) {
71    exp.bindings.push(binding.to_string());
72}
73
74/// Binding count.
75pub fn compute_binding_count(exp: &ComputeShaderExport) -> usize {
76    exp.bindings.len()
77}
78
79/// Compute number of groups for a given element count.
80pub fn compute_group_count(exp: &ComputeShaderExport, element_count: u32) -> u32 {
81    let gs = exp.dispatch.group_size_x.max(1);
82    element_count.div_ceil(gs)
83}
84
85/// Validate (non-empty source and entry point).
86pub fn validate_compute_shader(exp: &ComputeShaderExport) -> bool {
87    !exp.source.is_empty() && !exp.entry_point.is_empty()
88}
89
90/// Render a summary string.
91pub fn render_compute_summary(exp: &ComputeShaderExport) -> String {
92    format!(
93        "API:{} Entry:{} Bindings:{} GroupSize:{}x{}x{}",
94        exp.api.name(),
95        exp.entry_point,
96        exp.bindings.len(),
97        exp.dispatch.group_size_x,
98        exp.dispatch.group_size_y,
99        exp.dispatch.group_size_z,
100    )
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_export_empty_source() {
        let exp = new_compute_shader_export(ComputeApi::Vulkan, "cs_main");
        assert!(exp.source.is_empty(), "new export should have no source");
        assert!(exp.bindings.is_empty(), "new export should have no bindings");
        assert_eq!(exp.entry_point, "cs_main");
    }

    #[test]
    fn set_source_updates() {
        let mut exp = new_compute_shader_export(ComputeApi::Metal, "main");
        set_compute_source(&mut exp, "kernel void main(){}");
        assert_eq!(exp.source, "kernel void main(){}");
    }

    #[test]
    fn set_source_replaces_previous() {
        let mut exp = new_compute_shader_export(ComputeApi::Metal, "main");
        set_compute_source(&mut exp, "a considerably longer first version");
        set_compute_source(&mut exp, "v2");
        assert_eq!(exp.source, "v2", "setting source must replace, not append");
    }

    #[test]
    fn add_binding_increments() {
        let mut exp = new_compute_shader_export(ComputeApi::WebGpu, "cs");
        add_compute_binding(
            &mut exp,
            "@group(0) @binding(0) var<storage> buf: array<f32>",
        );
        assert_eq!(compute_binding_count(&exp), 1, "one binding expected");
    }

    #[test]
    fn bindings_keep_insertion_order() {
        let mut exp = new_compute_shader_export(ComputeApi::WebGpu, "cs");
        add_compute_binding(&mut exp, "first");
        add_compute_binding(&mut exp, "second");
        assert_eq!(exp.bindings, vec!["first".to_string(), "second".to_string()]);
        assert_eq!(compute_binding_count(&exp), 2);
    }

    #[test]
    fn api_name_correct() {
        assert_eq!(ComputeApi::DirectX12.name(), "DirectX 12");
    }

    #[test]
    fn compute_group_count_correct() {
        let exp = new_compute_shader_export(ComputeApi::Vulkan, "cs");
        let groups = compute_group_count(&exp, 128);
        assert_eq!(groups, 2, "128 / 64 = 2");
    }

    #[test]
    fn compute_group_count_ceiling() {
        let exp = new_compute_shader_export(ComputeApi::Vulkan, "cs");
        let groups = compute_group_count(&exp, 65);
        assert_eq!(groups, 2, "ceil(65 / 64) = 2");
    }

    #[test]
    fn compute_group_count_zero_elements() {
        let exp = new_compute_shader_export(ComputeApi::Vulkan, "cs");
        assert_eq!(compute_group_count(&exp, 0), 0);
    }

    #[test]
    fn compute_group_count_zero_group_size_treated_as_one() {
        let mut exp = new_compute_shader_export(ComputeApi::Vulkan, "cs");
        exp.dispatch.group_size_x = 0;
        assert_eq!(compute_group_count(&exp, 5), 5, "zero group size must not divide by zero");
    }

    #[test]
    fn compute_group_count_max_elements_does_not_overflow() {
        let exp = new_compute_shader_export(ComputeApi::Vulkan, "cs");
        assert_eq!(compute_group_count(&exp, u32::MAX), u32::MAX / 64 + 1);
    }

    #[test]
    fn validate_needs_source_and_entry() {
        let mut exp = new_compute_shader_export(ComputeApi::Metal, "main");
        assert!(!validate_compute_shader(&exp), "no source yet");
        set_compute_source(&mut exp, "// code");
        assert!(validate_compute_shader(&exp), "source and entry now present");
    }

    #[test]
    fn validate_rejects_empty_entry_point() {
        let mut exp = new_compute_shader_export(ComputeApi::Vulkan, "");
        set_compute_source(&mut exp, "// code");
        assert!(!validate_compute_shader(&exp), "empty entry point must fail");
    }

    #[test]
    fn render_summary_contains_api() {
        let exp = new_compute_shader_export(ComputeApi::WebGpu, "main");
        let s = render_compute_summary(&exp);
        assert!(s.contains("WebGPU"), "API name should appear in summary");
    }

    #[test]
    fn render_summary_exact_format() {
        let mut exp = new_compute_shader_export(ComputeApi::DirectX12, "CSMain");
        add_compute_binding(&mut exp, "RWStructuredBuffer<float> buf : register(u0);");
        exp.dispatch.group_size_y = 2;
        assert_eq!(
            render_compute_summary(&exp),
            "API:DirectX 12 Entry:CSMain Bindings:1 GroupSize:64x2x1"
        );
    }

    #[test]
    fn default_dispatch_group_size() {
        let exp = new_compute_shader_export(ComputeApi::Vulkan, "cs");
        assert_eq!(exp.dispatch.group_size_x, 64, "default X group size is 64");
        assert_eq!(exp.dispatch.group_size_y, 1);
        assert_eq!(exp.dispatch.group_size_z, 1);
    }
}