Skip to main content

proof_engine/wgpu_backend/
shader_translate.rs

1//! Shader translation between GLSL, WGSL, SPIRV (text), HLSL, and MSL.
2//! Includes reflection (extracting bindings, inputs, outputs) and validation.
3
4use std::fmt;
5
6// ---------------------------------------------------------------------------
7// Error types
8// ---------------------------------------------------------------------------
9
/// How serious a [`ShaderError`] diagnostic is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Severity {
    /// Displayed as `error`.
    Error,
    /// Displayed as `warning`.
    Warning,
    /// Displayed as `info`.
    Info,
}

/// A shader diagnostic tied to a source position.
///
/// Displays as `severity:line:col: message`, e.g. `warning:3:1: something`.
#[derive(Debug, Clone)]
pub struct ShaderError {
    /// Source line the diagnostic refers to.
    pub line: usize,
    /// Source column the diagnostic refers to.
    pub col: usize,
    /// Human-readable description.
    pub message: String,
    /// How serious the diagnostic is.
    pub severity: Severity,
}

impl fmt::Display for ShaderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self.severity {
            Severity::Info => "info",
            Severity::Warning => "warning",
            Severity::Error => "error",
        })?;
        write!(f, ":{}:{}: {}", self.line, self.col, self.message)
    }
}

/// Error returned by the translators: a summary message plus any located
/// diagnostics attached to it.
#[derive(Debug, Clone)]
pub struct TranslateError {
    /// Summary of what went wrong.
    pub message: String,
    /// Located diagnostics, in the order they were attached.
    pub errors: Vec<ShaderError>,
}

impl fmt::Display for TranslateError {
    /// Writes `TranslateError: <message>`, then each diagnostic on its own
    /// line indented by two spaces.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("TranslateError: ")?;
        f.write_str(&self.message)?;
        self.errors
            .iter()
            .try_for_each(|diag| write!(f, "\n  {}", diag))
    }
}

impl std::error::Error for TranslateError {}

impl TranslateError {
    /// Creates an error with the given summary and no diagnostics.
    pub fn new(msg: impl Into<String>) -> Self {
        Self {
            errors: Vec::new(),
            message: msg.into(),
        }
    }

    /// Appends one located diagnostic and returns the error, so calls can be chained.
    pub fn with_error(self, err: ShaderError) -> Self {
        let mut updated = self;
        updated.errors.push(err);
        updated
    }
}
66
67// ---------------------------------------------------------------------------
68// Shader language enum
69// ---------------------------------------------------------------------------
70
/// Shader source languages known to this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderLanguage {
    /// OpenGL Shading Language.
    GLSL,
    /// WebGPU Shading Language.
    WGSL,
    /// SPIR-V (this module works with a textual representation).
    SPIRV,
    /// High-Level Shading Language.
    HLSL,
    /// Metal Shading Language.
    MSL,
}

impl fmt::Display for ShaderLanguage {
    /// Writes the conventional name; `SPIRV` is shown as `SPIR-V`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            ShaderLanguage::GLSL => "GLSL",
            ShaderLanguage::WGSL => "WGSL",
            ShaderLanguage::SPIRV => "SPIR-V",
            ShaderLanguage::HLSL => "HLSL",
            ShaderLanguage::MSL => "MSL",
        };
        f.write_str(label)
    }
}
91
92// ---------------------------------------------------------------------------
93// Token types (for the mini-lexer)
94// ---------------------------------------------------------------------------
95
/// One classified line of GLSL produced by the line-oriented scanner.
#[derive(Debug, Clone, PartialEq)]
pub enum GlslToken {
    /// `#version N`.
    Version(u32),
    /// `layout(location = N) in TYPE NAME;`
    In {
        location: u32,
        ty: String,
        name: String,
    },
    /// `layout(location = N) out TYPE NAME;`
    Out {
        location: u32,
        ty: String,
        name: String,
    },
    /// `uniform TYPE NAME;` where TYPE is not a sampler.
    Uniform { ty: String, name: String },
    /// A uniform block declaration.
    /// NOTE(review): the GLSL scanner in this file does not currently emit this variant.
    UniformBlock { name: String, binding: Option<u32> },
    /// `uniform samplerXXX NAME;`
    Sampler { name: String },
    /// The `void main(...)` header line.
    MainBegin,
    /// The brace that closes `main`.
    MainEnd,
    /// Any other line, kept verbatim.
    Line(String),
}
108
/// One classified line of WGSL for a line-oriented scanner.
#[derive(Debug, Clone, PartialEq)]
pub enum WgslToken {
    /// A `struct` declaration with its fields as string pairs.
    /// NOTE(review): pair order is presumably `(name, type)` — confirm with the producer.
    Struct {
        name: String,
        fields: Vec<(String, String)>,
    },
    /// `@group(G) @binding(B) var ... NAME: TYPE;`
    Binding {
        group: u32,
        binding: u32,
        name: String,
        ty: String,
    },
    /// Presumably the vertex-stage output struct name — TODO confirm.
    VertexOutput { name: String },
    /// Presumably the fragment-stage output name — TODO confirm.
    FragmentOutput { name: String },
    /// An entry point: its stage attribute and function name.
    EntryPoint { stage: String, name: String },
    /// Any other line, kept verbatim.
    Line(String),
}
118
119// ---------------------------------------------------------------------------
120// GLSL parser helpers
121// ---------------------------------------------------------------------------
122
123/// Parse a GLSL source and extract tokens for the translator.
124fn parse_glsl_tokens(source: &str) -> Vec<GlslToken> {
125    let mut tokens = Vec::new();
126    let mut in_main = false;
127
128    for line in source.lines() {
129        let trimmed = line.trim();
130
131        // #version
132        if trimmed.starts_with("#version") {
133            if let Some(ver_str) = trimmed.strip_prefix("#version") {
134                let ver_str = ver_str.trim().split_whitespace().next().unwrap_or("330");
135                if let Ok(v) = ver_str.parse::<u32>() {
136                    tokens.push(GlslToken::Version(v));
137                }
138            }
139            continue;
140        }
141
142        // layout(location = N) in TYPE NAME;
143        if let Some(rest) = try_parse_layout_in_out(trimmed) {
144            tokens.push(rest);
145            continue;
146        }
147
148        // uniform TYPE NAME;
149        if trimmed.starts_with("uniform ") && !trimmed.contains('{') {
150            let parts: Vec<&str> = trimmed.trim_end_matches(';')
151                .split_whitespace().collect();
152            if parts.len() >= 3 {
153                let ty = parts[1].to_string();
154                let name = parts[2].trim_end_matches(';').to_string();
155                if ty.starts_with("sampler") {
156                    tokens.push(GlslToken::Sampler { name });
157                } else {
158                    tokens.push(GlslToken::Uniform { ty, name });
159                }
160            }
161            continue;
162        }
163
164        // void main()
165        if trimmed.contains("void main") && trimmed.contains('(') {
166            tokens.push(GlslToken::MainBegin);
167            in_main = true;
168            continue;
169        }
170
171        // closing brace of main (simplistic)
172        if in_main && trimmed == "}" {
173            tokens.push(GlslToken::MainEnd);
174            in_main = false;
175            continue;
176        }
177
178        tokens.push(GlslToken::Line(line.to_string()));
179    }
180    tokens
181}
182
183/// Try to parse `layout(location = N) in/out TYPE NAME;`
184fn try_parse_layout_in_out(line: &str) -> Option<GlslToken> {
185    if !line.starts_with("layout") { return None; }
186    let loc = extract_location(line)?;
187    let after_paren = line.find(')')? ;
188    let rest = &line[after_paren + 1..].trim();
189
190    if rest.starts_with("in ") {
191        let parts: Vec<&str> = rest[3..].trim_end_matches(';').split_whitespace().collect();
192        if parts.len() >= 2 {
193            return Some(GlslToken::In {
194                location: loc,
195                ty: parts[0].to_string(),
196                name: parts[1].trim_end_matches(';').to_string(),
197            });
198        }
199    } else if rest.starts_with("out ") {
200        let parts: Vec<&str> = rest[4..].trim_end_matches(';').split_whitespace().collect();
201        if parts.len() >= 2 {
202            return Some(GlslToken::Out {
203                location: loc,
204                ty: parts[0].to_string(),
205                name: parts[1].trim_end_matches(';').to_string(),
206            });
207        }
208    }
209    None
210}
211
/// Extract `N` from a `layout(... location = N ...)` qualifier.
///
/// The value ends at the next `,` or `)`. Qualifiers with several entries,
/// e.g. `layout(location = 1, component = 2)`, are therefore supported.
/// Returns `None` when there is no `location = <u32>` entry.
fn extract_location(line: &str) -> Option<u32> {
    let (_, after_key) = line.split_once("location")?;
    let (_, value) = after_key.split_once('=')?;
    let end = value.find(|c: char| c == ',' || c == ')')?;
    value[..end].trim().parse().ok()
}
220
221// ---------------------------------------------------------------------------
222// Type translation helpers
223// ---------------------------------------------------------------------------
224
/// Map a GLSL type name to its WGSL spelling.
///
/// Names not in the table (user structs, unsupported types) are returned
/// unchanged.
fn glsl_type_to_wgsl(ty: &str) -> String {
    const GLSL_TO_WGSL: [(&str, &str); 17] = [
        ("float", "f32"),
        ("int", "i32"),
        ("uint", "u32"),
        ("bool", "bool"),
        ("vec2", "vec2<f32>"),
        ("vec3", "vec3<f32>"),
        ("vec4", "vec4<f32>"),
        ("ivec2", "vec2<i32>"),
        ("ivec3", "vec3<i32>"),
        ("ivec4", "vec4<i32>"),
        ("uvec2", "vec2<u32>"),
        ("uvec3", "vec3<u32>"),
        ("uvec4", "vec4<u32>"),
        ("mat2", "mat2x2<f32>"),
        ("mat3", "mat3x3<f32>"),
        ("mat4", "mat4x4<f32>"),
        ("sampler2D", "texture_2d<f32>"),
    ];

    GLSL_TO_WGSL
        .iter()
        .find(|entry| entry.0 == ty)
        .map_or_else(|| ty.to_string(), |entry| entry.1.to_string())
}
247
/// Map a WGSL type name to its GLSL spelling.
///
/// Both spellings are accepted: templated forms (`vec3<f32>`, `mat4x4<f32>`)
/// and WGSL's predeclared aliases (`vec3f`, `vec2i`, `vec4u`, `mat4x4f`, ...).
/// Whitespace inside the name is ignored for matching, so `vec3< f32 >`
/// maps like `vec3<f32>`. Unknown names are returned unchanged.
fn wgsl_type_to_glsl(ty: &str) -> String {
    let compact: String = ty.chars().filter(|c| !c.is_whitespace()).collect();
    let glsl = match compact.as_str() {
        "f32" => "float",
        "i32" => "int",
        "u32" => "uint",
        "vec2<f32>" | "vec2f" => "vec2",
        "vec3<f32>" | "vec3f" => "vec3",
        "vec4<f32>" | "vec4f" => "vec4",
        "vec2<i32>" | "vec2i" => "ivec2",
        "vec3<i32>" | "vec3i" => "ivec3",
        "vec4<i32>" | "vec4i" => "ivec4",
        "vec2<u32>" | "vec2u" => "uvec2",
        "vec3<u32>" | "vec3u" => "uvec3",
        "vec4<u32>" | "vec4u" => "uvec4",
        "vec2<bool>" => "bvec2",
        "vec3<bool>" => "bvec3",
        "vec4<bool>" => "bvec4",
        "mat2x2<f32>" | "mat2x2f" => "mat2",
        "mat3x3<f32>" | "mat3x3f" => "mat3",
        "mat4x4<f32>" | "mat4x4f" => "mat4",
        "texture_2d<f32>" => "sampler2D",
        // Unknown (e.g. user structs): keep the caller's exact spelling.
        _ => return ty.to_string(),
    };
    glsl.to_string()
}
269
/// Translate GLSL built-in calls and variables on one line to WGSL.
///
/// * `texture2D(tex, uv)` / `texture(tex, uv)` become
///   `textureSample(tex, tex_sampler, uv)`. WGSL needs an explicit sampler
///   argument, and the generated bindings name it `<texture>_sampler`. When
///   the first argument is not a plain identifier followed by `,`, the call
///   is only renamed.
/// * `gl_Position` becomes `output.position`.
/// * `gl_FragColor` becomes `output.color`.
///
/// Calls are matched only at identifier boundaries, so e.g. `my_texture(` is
/// left alone. `smoothstep`, `mix` and `clamp` share names in both languages
/// and need no rewriting.
fn translate_glsl_builtins_to_wgsl(line: &str) -> String {
    fn is_ident_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    /// Returns the identifier at the start of `s` (after optional whitespace)
    /// and the number of bytes consumed, including that whitespace.
    fn leading_ident(s: &str) -> (&str, usize) {
        let ws = s.len() - s.trim_start().len();
        let body = &s[ws..];
        let len = body.find(|c: char| !is_ident_char(c)).unwrap_or(body.len());
        (&body[..len], ws + len)
    }

    const TEXTURE_CALLS: [&str; 2] = ["texture2D(", "texture("];

    let mut out = String::with_capacity(line.len() + 16);
    let mut i = 0;
    while i < line.len() {
        let rest = &line[i..];
        let at_boundary = !matches!(line[..i].chars().next_back(), Some(c) if is_ident_char(c));
        let call = if at_boundary {
            TEXTURE_CALLS.iter().find(|pat| rest.starts_with(**pat))
        } else {
            None
        };

        if let Some(pat) = call {
            out.push_str("textureSample(");
            i += pat.len();
            let (tex, used) = leading_ident(&line[i..]);
            if !tex.is_empty() && line[i + used..].trim_start().starts_with(',') {
                // Insert the matching sampler right after the texture argument.
                out.push_str(&line[i..i + used]);
                out.push_str(", ");
                out.push_str(tex);
                out.push_str("_sampler");
                i += used;
            }
            continue;
        }

        match rest.chars().next() {
            Some(ch) => {
                out.push(ch);
                i += ch.len_utf8();
            }
            None => break,
        }
    }

    out.replace("gl_Position", "output.position")
        .replace("gl_FragColor", "output.color")
}
288
/// Translate WGSL built-in calls and output fields on one line back to GLSL.
///
/// * `textureSample(TEX, SAMPLER, rest...)` becomes `texture(BASE, rest...)`.
///   GLSL uses one combined sampler, and `wgsl_to_glsl` declares it under the
///   sampler's name with any `_sampler` suffix removed, so BASE is derived
///   the same way. When the first two arguments are not plain identifiers,
///   the call is only renamed.
/// * `output.position` becomes `gl_Position`.
/// * `output.color` becomes `gl_FragColor`.
///
/// Calls are matched only at identifier boundaries.
fn translate_wgsl_builtins_to_glsl(line: &str) -> String {
    fn is_ident_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    /// Returns the identifier at the start of `s` (after optional whitespace)
    /// and the number of bytes consumed, including that whitespace.
    fn leading_ident(s: &str) -> (&str, usize) {
        let ws = s.len() - s.trim_start().len();
        let body = &s[ws..];
        let len = body.find(|c: char| !is_ident_char(c)).unwrap_or(body.len());
        (&body[..len], ws + len)
    }

    const CALL: &str = "textureSample(";

    let mut out = String::with_capacity(line.len());
    let mut i = 0;
    while i < line.len() {
        let rest = &line[i..];
        let at_boundary = !matches!(line[..i].chars().next_back(), Some(c) if is_ident_char(c));

        if at_boundary && rest.starts_with(CALL) {
            out.push_str("texture(");
            i += CALL.len();
            let args = &line[i..];
            let (tex, tex_len) = leading_ident(args);
            if let Some(after_tex) = args[tex_len..].trim_start().strip_prefix(',') {
                let (smp, smp_len) = leading_ident(after_tex);
                if let Some(tail) = after_tex[smp_len..].trim_start().strip_prefix(',') {
                    if !tex.is_empty() && !smp.is_empty() {
                        // Replace `TEX, SAMPLER,` with the combined sampler name.
                        out.push_str(smp.strip_suffix("_sampler").unwrap_or(smp));
                        out.push(',');
                        i = line.len() - tail.len();
                    }
                }
            }
            continue;
        }

        match rest.chars().next() {
            Some(ch) => {
                out.push(ch);
                i += ch.len_utf8();
            }
            None => break,
        }
    }

    out.replace("output.position", "gl_Position")
        .replace("output.color", "gl_FragColor")
}
297
298// ---------------------------------------------------------------------------
299// GLSL -> WGSL
300// ---------------------------------------------------------------------------
301
302/// Translate a GLSL shader to WGSL.
303pub fn glsl_to_wgsl(glsl_source: &str) -> Result<String, TranslateError> {
304    let tokens = parse_glsl_tokens(glsl_source);
305    let mut wgsl = String::new();
306    let mut inputs: Vec<(u32, String, String)> = Vec::new();
307    let mut outputs: Vec<(u32, String, String)> = Vec::new();
308    let mut uniforms: Vec<(String, String)> = Vec::new();
309    let mut samplers: Vec<String> = Vec::new();
310    let mut body_lines: Vec<String> = Vec::new();
311    let mut in_body = false;
312    let mut binding_counter = 0u32;
313
314    for token in &tokens {
315        match token {
316            GlslToken::Version(_) => {}
317            GlslToken::In { location, ty, name } => {
318                inputs.push((*location, ty.clone(), name.clone()));
319            }
320            GlslToken::Out { location, ty, name } => {
321                outputs.push((*location, ty.clone(), name.clone()));
322            }
323            GlslToken::Uniform { ty, name } => {
324                uniforms.push((ty.clone(), name.clone()));
325            }
326            GlslToken::Sampler { name } => {
327                samplers.push(name.clone());
328            }
329            GlslToken::UniformBlock { .. } => {}
330            GlslToken::MainBegin => {
331                in_body = true;
332            }
333            GlslToken::MainEnd => {
334                in_body = false;
335            }
336            GlslToken::Line(l) => {
337                if in_body {
338                    body_lines.push(l.clone());
339                }
340            }
341        }
342    }
343
344    // Emit input struct
345    if !inputs.is_empty() {
346        wgsl.push_str("struct VertexInput {\n");
347        for (loc, ty, name) in &inputs {
348            wgsl.push_str(&format!(
349                "    @location({}) {}: {},\n",
350                loc, name, glsl_type_to_wgsl(ty)
351            ));
352        }
353        wgsl.push_str("};\n\n");
354    }
355
356    // Emit output struct
357    if !outputs.is_empty() {
358        wgsl.push_str("struct VertexOutput {\n");
359        wgsl.push_str("    @builtin(position) position: vec4<f32>,\n");
360        for (loc, ty, name) in &outputs {
361            wgsl.push_str(&format!(
362                "    @location({}) {}: {},\n",
363                loc, name, glsl_type_to_wgsl(ty)
364            ));
365        }
366        wgsl.push_str("};\n\n");
367    }
368
369    // Emit uniforms as @group(0) @binding(N)
370    for (ty, name) in &uniforms {
371        wgsl.push_str(&format!(
372            "@group(0) @binding({}) var<uniform> {}: {};\n",
373            binding_counter, name, glsl_type_to_wgsl(ty)
374        ));
375        binding_counter += 1;
376    }
377
378    // Emit samplers
379    for name in &samplers {
380        wgsl.push_str(&format!(
381            "@group(0) @binding({}) var {}: texture_2d<f32>;\n",
382            binding_counter, name,
383        ));
384        binding_counter += 1;
385        wgsl.push_str(&format!(
386            "@group(0) @binding({}) var {}_sampler: sampler;\n",
387            binding_counter, name,
388        ));
389        binding_counter += 1;
390    }
391
392    if !uniforms.is_empty() || !samplers.is_empty() {
393        wgsl.push('\n');
394    }
395
396    // Emit entry point
397    wgsl.push_str("@vertex\n");
398    wgsl.push_str("fn vs_main(input: VertexInput) -> VertexOutput {\n");
399    wgsl.push_str("    var output: VertexOutput;\n");
400    for line in &body_lines {
401        let translated = translate_glsl_builtins_to_wgsl(line);
402        let translated = translated.trim();
403        if !translated.is_empty() {
404            wgsl.push_str(&format!("    {}\n", translated));
405        }
406    }
407    wgsl.push_str("    return output;\n");
408    wgsl.push_str("}\n");
409
410    Ok(wgsl)
411}
412
413// ---------------------------------------------------------------------------
414// GLSL -> SPIRV text representation
415// ---------------------------------------------------------------------------
416
417/// Produce a SPIR-V text representation from GLSL source.
418/// This is a simplified textual output, not actual binary SPIR-V.
419pub fn glsl_to_spirv_text(glsl_source: &str) -> Result<String, TranslateError> {
420    let tokens = parse_glsl_tokens(glsl_source);
421    let mut spirv = String::new();
422    spirv.push_str("; SPIR-V text representation (generated by proof-engine shader_translate)\n");
423    spirv.push_str("; Magic:     0x07230203\n");
424    spirv.push_str("; Version:   1.0\n");
425    spirv.push_str("; Generator: proof-engine\n\n");
426
427    let mut id_counter = 1u32;
428    let mut next_id = || { let id = id_counter; id_counter += 1; id };
429
430    // Capabilities
431    spirv.push_str("               OpCapability Shader\n");
432    let ext_id = next_id();
433    spirv.push_str(&format!("          %{}  = OpExtInstImport \"GLSL.std.450\"\n", ext_id));
434    spirv.push_str("               OpMemoryModel Logical GLSL450\n");
435
436    // Entry point
437    let main_id = next_id();
438    let mut interface_ids = Vec::new();
439
440    for token in &tokens {
441        match token {
442            GlslToken::In { location, ty, name } => {
443                let var_id = next_id();
444                interface_ids.push(var_id);
445                spirv.push_str(&format!(
446                    "               OpDecorate %{} Location {}\n",
447                    var_id, location
448                ));
449                let type_id = next_id();
450                spirv.push_str(&format!(
451                    "       %{}  = OpTypePointer Input %{} ; {}: {}\n",
452                    var_id, type_id, name, ty
453                ));
454            }
455            GlslToken::Out { location, ty, name } => {
456                let var_id = next_id();
457                interface_ids.push(var_id);
458                spirv.push_str(&format!(
459                    "               OpDecorate %{} Location {}\n",
460                    var_id, location
461                ));
462                let type_id = next_id();
463                spirv.push_str(&format!(
464                    "       %{}  = OpTypePointer Output %{} ; {}: {}\n",
465                    var_id, type_id, name, ty
466                ));
467            }
468            GlslToken::Uniform { ty, name } => {
469                let var_id = next_id();
470                let type_id = next_id();
471                spirv.push_str(&format!(
472                    "       %{}  = OpTypePointer Uniform %{} ; uniform {}: {}\n",
473                    var_id, type_id, name, ty
474                ));
475            }
476            _ => {}
477        }
478    }
479
480    let iface_str: String = interface_ids.iter().map(|id| format!("%{}", id)).collect::<Vec<_>>().join(" ");
481    spirv.push_str(&format!(
482        "               OpEntryPoint Vertex %{} \"main\" {}\n",
483        main_id, iface_str
484    ));
485
486    // Main function
487    let void_id = next_id();
488    let func_type_id = next_id();
489    spirv.push_str(&format!("       %{}  = OpTypeVoid\n", void_id));
490    spirv.push_str(&format!("       %{}  = OpTypeFunction %{}\n", func_type_id, void_id));
491    spirv.push_str(&format!("       %{}  = OpFunction %{} None %{}\n", main_id, void_id, func_type_id));
492    let label_id = next_id();
493    spirv.push_str(&format!("       %{}  = OpLabel\n", label_id));
494    spirv.push_str("               OpReturn\n");
495    spirv.push_str("               OpFunctionEnd\n");
496
497    Ok(spirv)
498}
499
500// ---------------------------------------------------------------------------
501// WGSL -> GLSL
502// ---------------------------------------------------------------------------
503
504/// Translate a simple WGSL shader back to GLSL 330 core.
505pub fn wgsl_to_glsl(wgsl_source: &str) -> Result<String, TranslateError> {
506    let mut glsl = String::from("#version 330 core\n\n");
507    let mut in_struct = false;
508    let mut current_struct_name = String::new();
509    let mut in_fn = false;
510    let mut fn_body_lines: Vec<String> = Vec::new();
511
512    for line in wgsl_source.lines() {
513        let trimmed = line.trim();
514
515        // Parse @group(G) @binding(B) var<uniform> NAME: TYPE;
516        if trimmed.starts_with("@group") && trimmed.contains("var<uniform>") {
517            if let Some(rest) = trimmed.split("var<uniform>").nth(1) {
518                let rest = rest.trim().trim_end_matches(';');
519                let parts: Vec<&str> = rest.splitn(2, ':').collect();
520                if parts.len() == 2 {
521                    let name = parts[0].trim();
522                    let ty = parts[1].trim();
523                    glsl.push_str(&format!("uniform {} {};\n", wgsl_type_to_glsl(ty), name));
524                }
525            }
526            continue;
527        }
528
529        // Parse @group(G) @binding(B) var NAME: texture_2d<f32>;
530        if trimmed.starts_with("@group") && trimmed.contains("var ") && trimmed.contains("texture") {
531            // Skip texture bindings (handled as sampler2D in GLSL)
532            continue;
533        }
534
535        // Parse @group(G) @binding(B) var NAME: sampler;
536        if trimmed.starts_with("@group") && trimmed.contains("sampler") && !trimmed.contains("texture") {
537            // The corresponding texture was already skipped; emit a sampler2D
538            if let Some(rest) = trimmed.split("var ").nth(1) {
539                let name = rest.split(':').next().unwrap_or("").trim();
540                // Strip _sampler suffix
541                let base = name.strip_suffix("_sampler").unwrap_or(name);
542                glsl.push_str(&format!("uniform sampler2D {};\n", base));
543            }
544            continue;
545        }
546
547        // Struct
548        if trimmed.starts_with("struct ") {
549            in_struct = true;
550            current_struct_name = trimmed
551                .strip_prefix("struct ")
552                .unwrap_or("")
553                .trim_end_matches('{')
554                .trim()
555                .to_string();
556            continue;
557        }
558
559        if in_struct {
560            if trimmed.starts_with('}') {
561                in_struct = false;
562                continue;
563            }
564            // @location(N) name: type,  or  @builtin(position) ...
565            if trimmed.contains("@location") {
566                if let Some(loc) = extract_wgsl_location(trimmed) {
567                    let rest = trimmed.split(')').last().unwrap_or("").trim();
568                    let parts: Vec<&str> = rest.splitn(2, ':').collect();
569                    if parts.len() == 2 {
570                        let name = parts[0].trim();
571                        let ty = parts[1].trim().trim_end_matches(',');
572                        let is_input = current_struct_name.contains("Input");
573                        let qualifier = if is_input { "in" } else { "out" };
574                        glsl.push_str(&format!(
575                            "layout(location = {}) {} {} {};\n",
576                            loc, qualifier, wgsl_type_to_glsl(ty), name
577                        ));
578                    }
579                }
580            }
581            continue;
582        }
583
584        // @vertex fn ...
585        if trimmed.starts_with("@vertex") || trimmed.starts_with("@fragment") || trimmed.starts_with("@compute") {
586            in_fn = true;
587            fn_body_lines.clear();
588            continue;
589        }
590
591        if trimmed.starts_with("fn ") && in_fn {
592            // Skip the fn signature line
593            continue;
594        }
595
596        if in_fn {
597            if trimmed == "}" {
598                // Emit main function
599                glsl.push_str("\nvoid main() {\n");
600                for bl in &fn_body_lines {
601                    let translated = translate_wgsl_builtins_to_glsl(bl);
602                    let translated = translated.trim();
603                    if !translated.is_empty()
604                        && !translated.starts_with("var output")
605                        && !translated.starts_with("return")
606                    {
607                        glsl.push_str(&format!("    {}\n", translated));
608                    }
609                }
610                glsl.push_str("}\n");
611                in_fn = false;
612                continue;
613            }
614            fn_body_lines.push(line.to_string());
615        }
616    }
617
618    Ok(glsl)
619}
620
/// Parse `N` from the first `@location(N)` attribute on `line`.
///
/// Returns `None` when the attribute is missing, unterminated or non-numeric.
fn extract_wgsl_location(line: &str) -> Option<u32> {
    let (_, after) = line.split_once("@location(")?;
    let (inner, _) = after.split_once(')')?;
    inner.trim().parse::<u32>().ok()
}
627
628// ---------------------------------------------------------------------------
629// Shader reflection
630// ---------------------------------------------------------------------------
631
/// Interface information extracted from a shader module by
/// [`reflect_glsl`] or [`reflect_wgsl`].
#[derive(Debug, Clone, Default)]
pub struct ShaderReflection {
    /// Stage inputs (`layout(location) in` / WGSL `@location` fields).
    pub inputs: Vec<ReflectedBinding>,
    /// Stage outputs (`layout(location) out`).
    pub outputs: Vec<ReflectedBinding>,
    /// Uniform variables / uniform buffers.
    pub uniforms: Vec<ReflectedBinding>,
    /// Storage buffers (filled by the WGSL reflector).
    pub storage_buffers: Vec<ReflectedBinding>,
    /// Texture bindings (filled by the WGSL reflector).
    pub textures: Vec<ReflectedBinding>,
    /// Samplers (GLSL combined samplers or WGSL `sampler` bindings).
    pub samplers: Vec<ReflectedBinding>,
    /// Compute workgroup size `[x, y, z]` if declared; unspecified axes are 1.
    pub workgroup_size: Option<[u32; 3]>,
}

/// One named interface item found by reflection.
#[derive(Debug, Clone)]
pub struct ReflectedBinding {
    /// Variable or struct-field name.
    pub name: String,
    /// Type name, generally as spelled in the source language.
    pub ty: String,
    /// Location for stage I/O, binding index for resources; 0 when the
    /// source does not provide one.
    pub location_or_binding: u32,
    /// `@group` index for WGSL resources; `None` otherwise.
    pub group: Option<u32>,
}
652
653/// Reflect a GLSL shader.
654pub fn reflect_glsl(source: &str) -> ShaderReflection {
655    let tokens = parse_glsl_tokens(source);
656    let mut refl = ShaderReflection::default();
657
658    for token in &tokens {
659        match token {
660            GlslToken::In { location, ty, name } => {
661                refl.inputs.push(ReflectedBinding {
662                    name: name.clone(),
663                    ty: ty.clone(),
664                    location_or_binding: *location,
665                    group: None,
666                });
667            }
668            GlslToken::Out { location, ty, name } => {
669                refl.outputs.push(ReflectedBinding {
670                    name: name.clone(),
671                    ty: ty.clone(),
672                    location_or_binding: *location,
673                    group: None,
674                });
675            }
676            GlslToken::Uniform { ty, name } => {
677                refl.uniforms.push(ReflectedBinding {
678                    name: name.clone(),
679                    ty: ty.clone(),
680                    location_or_binding: 0,
681                    group: None,
682                });
683            }
684            GlslToken::Sampler { name } => {
685                refl.samplers.push(ReflectedBinding {
686                    name: name.clone(),
687                    ty: "sampler2D".into(),
688                    location_or_binding: 0,
689                    group: None,
690                });
691            }
692            _ => {}
693        }
694    }
695
696    // Check for compute workgroup size: layout(local_size_x=X, ...)
697    for line in source.lines() {
698        let trimmed = line.trim();
699        if trimmed.contains("local_size_x") {
700            let mut ws = [1u32, 1, 1];
701            for dim in ["local_size_x", "local_size_y", "local_size_z"].iter().enumerate() {
702                if let Some(pos) = trimmed.find(dim.1) {
703                    let rest = &trimmed[pos + dim.1.len()..];
704                    if let Some(eq) = rest.find('=') {
705                        let after = &rest[eq + 1..];
706                        let num_str: String = after.chars().take_while(|c| c.is_ascii_digit()).collect();
707                        if let Ok(n) = num_str.parse::<u32>() {
708                            ws[dim.0] = n;
709                        }
710                    }
711                }
712            }
713            refl.workgroup_size = Some(ws);
714        }
715    }
716
717    refl
718}
719
720/// Reflect a WGSL shader.
721pub fn reflect_wgsl(source: &str) -> ShaderReflection {
722    let mut refl = ShaderReflection::default();
723
724    for line in source.lines() {
725        let trimmed = line.trim();
726
727        // @group(G) @binding(B) var<uniform> ...
728        if trimmed.starts_with("@group") && trimmed.contains("@binding") {
729            let group = extract_wgsl_group(trimmed);
730            let binding = extract_wgsl_binding_num(trimmed);
731
732            if trimmed.contains("var<uniform>") {
733                if let Some(rest) = trimmed.split("var<uniform>").nth(1) {
734                    let rest = rest.trim().trim_end_matches(';');
735                    let parts: Vec<&str> = rest.splitn(2, ':').collect();
736                    if parts.len() == 2 {
737                        refl.uniforms.push(ReflectedBinding {
738                            name: parts[0].trim().to_string(),
739                            ty: parts[1].trim().to_string(),
740                            location_or_binding: binding.unwrap_or(0),
741                            group,
742                        });
743                    }
744                }
745            } else if trimmed.contains("var<storage") {
746                if let Some(rest) = trimmed.split('>').last() {
747                    let rest = rest.trim().trim_end_matches(';');
748                    let parts: Vec<&str> = rest.splitn(2, ':').collect();
749                    if parts.len() == 2 {
750                        refl.storage_buffers.push(ReflectedBinding {
751                            name: parts[0].trim().to_string(),
752                            ty: parts[1].trim().to_string(),
753                            location_or_binding: binding.unwrap_or(0),
754                            group,
755                        });
756                    }
757                }
758            } else if trimmed.contains("texture") {
759                if let Some(rest) = trimmed.split("var ").nth(1) {
760                    let parts: Vec<&str> = rest.splitn(2, ':').collect();
761                    if parts.len() == 2 {
762                        refl.textures.push(ReflectedBinding {
763                            name: parts[0].trim().to_string(),
764                            ty: parts[1].trim().trim_end_matches(';').to_string(),
765                            location_or_binding: binding.unwrap_or(0),
766                            group,
767                        });
768                    }
769                }
770            } else if trimmed.contains("sampler") {
771                if let Some(rest) = trimmed.split("var ").nth(1) {
772                    let name = rest.split(':').next().unwrap_or("").trim().to_string();
773                    refl.samplers.push(ReflectedBinding {
774                        name,
775                        ty: "sampler".into(),
776                        location_or_binding: binding.unwrap_or(0),
777                        group,
778                    });
779                }
780            }
781        }
782
783        // @location(N) name: type  (inside a struct)
784        if trimmed.contains("@location(") && !trimmed.starts_with("@group") {
785            if let Some(loc) = extract_wgsl_location(trimmed) {
786                let rest = trimmed.split(')').last().unwrap_or("").trim();
787                let parts: Vec<&str> = rest.splitn(2, ':').collect();
788                if parts.len() == 2 {
789                    let binding = ReflectedBinding {
790                        name: parts[0].trim().to_string(),
791                        ty: parts[1].trim().trim_end_matches(',').to_string(),
792                        location_or_binding: loc,
793                        group: None,
794                    };
795                    // Heuristic: if we haven't seen outputs yet, treat as input
796                    refl.inputs.push(binding);
797                }
798            }
799        }
800
801        // @workgroup_size(X, Y, Z)
802        if trimmed.contains("@workgroup_size(") {
803            let start = trimmed.find("@workgroup_size(").unwrap() + "@workgroup_size(".len();
804            let rest = &trimmed[start..];
805            if let Some(end) = rest.find(')') {
806                let nums: Vec<u32> = rest[..end]
807                    .split(',')
808                    .filter_map(|s| s.trim().parse::<u32>().ok())
809                    .collect();
810                let mut ws = [1u32, 1, 1];
811                for (i, &n) in nums.iter().enumerate().take(3) {
812                    ws[i] = n;
813                }
814                refl.workgroup_size = Some(ws);
815            }
816        }
817    }
818
819    refl
820}
821
/// Parse `G` from the first `@group(G)` attribute on `line`.
///
/// Returns `None` when the attribute is missing, unterminated or non-numeric.
fn extract_wgsl_group(line: &str) -> Option<u32> {
    const ATTR: &str = "@group(";
    let tail = &line[line.find(ATTR)? + ATTR.len()..];
    tail[..tail.find(')')?].trim().parse().ok()
}
828
/// Parse `B` from the first `@binding(B)` attribute on `line`.
///
/// Returns `None` when the attribute is missing, unterminated or non-numeric.
fn extract_wgsl_binding_num(line: &str) -> Option<u32> {
    line.split_once("@binding(")
        .and_then(|(_, after)| after.split_once(')'))
        .and_then(|(inner, _)| inner.trim().parse::<u32>().ok())
}
835
836// ---------------------------------------------------------------------------
837// Shader validation
838// ---------------------------------------------------------------------------
839
840/// Validate a shader source in the given language.
841pub fn validate_shader(source: &str, language: ShaderLanguage) -> Vec<ShaderError> {
842    let mut errors = Vec::new();
843
844    match language {
845        ShaderLanguage::GLSL => validate_glsl(source, &mut errors),
846        ShaderLanguage::WGSL => validate_wgsl(source, &mut errors),
847        _ => {
848            // Minimal validation for other languages
849            if source.trim().is_empty() {
850                errors.push(ShaderError {
851                    line: 1, col: 1,
852                    message: "Empty shader source".into(),
853                    severity: Severity::Error,
854                });
855            }
856        }
857    }
858
859    errors
860}
861
862fn validate_glsl(source: &str, errors: &mut Vec<ShaderError>) {
863    let mut has_version = false;
864    let mut has_main = false;
865    let mut brace_depth: i32 = 0;
866
867    for (i, line) in source.lines().enumerate() {
868        let ln = i + 1;
869        let trimmed = line.trim();
870
871        if trimmed.starts_with("#version") {
872            has_version = true;
873            if ln != 1 {
874                errors.push(ShaderError {
875                    line: ln, col: 1,
876                    message: "#version must be on the first line".into(),
877                    severity: Severity::Warning,
878                });
879            }
880        }
881
882        if trimmed.contains("void main") {
883            has_main = true;
884        }
885
886        for ch in trimmed.chars() {
887            if ch == '{' { brace_depth += 1; }
888            if ch == '}' { brace_depth -= 1; }
889        }
890
891        if brace_depth < 0 {
892            errors.push(ShaderError {
893                line: ln, col: 1,
894                message: "Unmatched closing brace".into(),
895                severity: Severity::Error,
896            });
897        }
898    }
899
900    if !has_version {
901        errors.push(ShaderError {
902            line: 1, col: 1,
903            message: "Missing #version directive".into(),
904            severity: Severity::Warning,
905        });
906    }
907
908    if !has_main {
909        errors.push(ShaderError {
910            line: 1, col: 1,
911            message: "Missing void main() entry point".into(),
912            severity: Severity::Error,
913        });
914    }
915
916    if brace_depth != 0 {
917        errors.push(ShaderError {
918            line: source.lines().count(), col: 1,
919            message: format!("Unbalanced braces (depth {})", brace_depth),
920            severity: Severity::Error,
921        });
922    }
923}
924
925fn validate_wgsl(source: &str, errors: &mut Vec<ShaderError>) {
926    let mut has_entry = false;
927    let mut brace_depth: i32 = 0;
928
929    for (i, line) in source.lines().enumerate() {
930        let ln = i + 1;
931        let trimmed = line.trim();
932
933        if trimmed.starts_with("@vertex") || trimmed.starts_with("@fragment") || trimmed.starts_with("@compute") {
934            has_entry = true;
935        }
936
937        for ch in trimmed.chars() {
938            if ch == '{' { brace_depth += 1; }
939            if ch == '}' { brace_depth -= 1; }
940        }
941
942        if brace_depth < 0 {
943            errors.push(ShaderError {
944                line: ln, col: 1,
945                message: "Unmatched closing brace".into(),
946                severity: Severity::Error,
947            });
948        }
949
950        // Check for GLSL-isms that are wrong in WGSL
951        if trimmed.starts_with("#version") {
952            errors.push(ShaderError {
953                line: ln, col: 1,
954                message: "#version is not valid WGSL".into(),
955                severity: Severity::Error,
956            });
957        }
958
959        if trimmed.contains("void main") {
960            errors.push(ShaderError {
961                line: ln, col: 1,
962                message: "WGSL does not use 'void main()'; use @vertex/@fragment fn".into(),
963                severity: Severity::Error,
964            });
965        }
966    }
967
968    if !has_entry {
969        errors.push(ShaderError {
970            line: 1, col: 1,
971            message: "Missing entry point (@vertex, @fragment, or @compute)".into(),
972            severity: Severity::Warning,
973        });
974    }
975
976    if brace_depth != 0 {
977        errors.push(ShaderError {
978            line: source.lines().count(), col: 1,
979            message: format!("Unbalanced braces (depth {})", brace_depth),
980            severity: Severity::Error,
981        });
982    }
983}
984
985// ---------------------------------------------------------------------------
986// Tests
987// ---------------------------------------------------------------------------
988
// Unit tests for tokenizing, GLSL <-> WGSL / SPIR-V text translation,
// type-name mapping, reflection, validation, and diagnostic formatting.
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal GLSL 330 vertex shader shared by several tests: two `in`
    // attributes (locations 0 and 1), one `out` varying (location 0), one
    // `mat4` uniform, and a `main` that writes `gl_Position`.
    const SIMPLE_VERT_GLSL: &str = r#"#version 330 core
layout(location = 0) in vec3 aPos;
layout(location = 1) in vec2 aUV;
layout(location = 0) out vec2 vUV;
uniform mat4 uMVP;
void main() {
    gl_Position = uMVP * vec4(aPos, 1.0);
    vUV = aUV;
}
"#;

    // The tokenizer emits tokens for the version, an input, an output, the
    // uniform, and both the start and the end of `main`.
    #[test]
    fn parse_glsl_tokens_simple() {
        let tokens = parse_glsl_tokens(SIMPLE_VERT_GLSL);
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::Version(330))));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::In { location: 0, .. })));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::Out { location: 0, .. })));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::Uniform { .. })));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::MainBegin)));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::MainEnd)));
    }

    // GLSL -> WGSL produces input/output structs with located fields, the
    // uniform at group 0 / binding 0, and a `vs_main` vertex entry point.
    #[test]
    fn glsl_to_wgsl_simple() {
        let wgsl = glsl_to_wgsl(SIMPLE_VERT_GLSL).unwrap();
        assert!(wgsl.contains("struct VertexInput"));
        assert!(wgsl.contains("@location(0) aPos: vec3<f32>"));
        assert!(wgsl.contains("@location(1) aUV: vec2<f32>"));
        assert!(wgsl.contains("struct VertexOutput"));
        assert!(wgsl.contains("@group(0) @binding(0) var<uniform> uMVP: mat4x4<f32>"));
        assert!(wgsl.contains("@vertex"));
        assert!(wgsl.contains("fn vs_main"));
    }

    // `gl_Position` must be rewritten to `output.position` and must not
    // survive in the WGSL output.
    #[test]
    fn glsl_to_wgsl_translates_builtins() {
        let wgsl = glsl_to_wgsl(SIMPLE_VERT_GLSL).unwrap();
        assert!(wgsl.contains("output.position"));
        assert!(!wgsl.contains("gl_Position"));
    }

    // GLSL -> WGSL -> GLSL keeps the version line, the uniform (spelled
    // either `mat4x4` or `mat4`), and `main`.
    #[test]
    fn wgsl_to_glsl_roundtrip() {
        let wgsl = glsl_to_wgsl(SIMPLE_VERT_GLSL).unwrap();
        let glsl_back = wgsl_to_glsl(&wgsl).unwrap();
        // The round-tripped GLSL should have the key elements
        assert!(glsl_back.contains("#version 330 core"));
        assert!(glsl_back.contains("uniform mat4x4 uMVP") || glsl_back.contains("uniform mat4 uMVP"));
        assert!(glsl_back.contains("void main()"));
    }

    // SPIR-V text output contains the capability / memory-model header, a
    // vertex entry point, and a returning, terminated function.
    #[test]
    fn glsl_to_spirv_text_has_structure() {
        let spirv = glsl_to_spirv_text(SIMPLE_VERT_GLSL).unwrap();
        assert!(spirv.contains("OpCapability Shader"));
        assert!(spirv.contains("OpMemoryModel Logical GLSL450"));
        assert!(spirv.contains("OpEntryPoint Vertex"));
        assert!(spirv.contains("OpReturn"));
        assert!(spirv.contains("OpFunctionEnd"));
    }

    // GLSL scalar, vector, matrix, and sampler type names map to WGSL spellings.
    #[test]
    fn type_translation_glsl_to_wgsl() {
        assert_eq!(glsl_type_to_wgsl("float"), "f32");
        assert_eq!(glsl_type_to_wgsl("vec3"), "vec3<f32>");
        assert_eq!(glsl_type_to_wgsl("mat4"), "mat4x4<f32>");
        assert_eq!(glsl_type_to_wgsl("sampler2D"), "texture_2d<f32>");
    }

    // Reverse mapping for scalar, vector, and matrix type names.
    #[test]
    fn type_translation_wgsl_to_glsl() {
        assert_eq!(wgsl_type_to_glsl("f32"), "float");
        assert_eq!(wgsl_type_to_glsl("vec3<f32>"), "vec3");
        assert_eq!(wgsl_type_to_glsl("mat4x4<f32>"), "mat4");
    }

    // Reflecting the fixture finds 2 inputs, 1 output, and 1 uniform (`uMVP`).
    // The first reported input is `aPos` at location 0.
    #[test]
    fn reflect_glsl_shader() {
        let refl = reflect_glsl(SIMPLE_VERT_GLSL);
        assert_eq!(refl.inputs.len(), 2);
        assert_eq!(refl.outputs.len(), 1);
        assert_eq!(refl.uniforms.len(), 1);
        assert_eq!(refl.uniforms[0].name, "uMVP");
        assert_eq!(refl.inputs[0].name, "aPos");
        assert_eq!(refl.inputs[0].location_or_binding, 0);
    }

    // `local_size_x/y/z` layout qualifiers are reported as the workgroup size.
    #[test]
    fn reflect_glsl_compute_workgroup() {
        let src = r#"#version 430
layout(local_size_x=64, local_size_y=1, local_size_z=1) in;
void main() {}
"#;
        let refl = reflect_glsl(src);
        assert_eq!(refl.workgroup_size, Some([64, 1, 1]));
    }

    // A `sampler2D` uniform is listed under `samplers` by name.
    #[test]
    fn reflect_glsl_sampler() {
        let src = r#"#version 330 core
uniform sampler2D uTexture;
void main() {}
"#;
        let refl = reflect_glsl(src);
        assert_eq!(refl.samplers.len(), 1);
        assert_eq!(refl.samplers[0].name, "uTexture");
    }

    // A uniform, a texture, and a sampler binding are each counted once in
    // their respective reflection lists.
    #[test]
    fn reflect_wgsl_shader() {
        let src = r#"
@group(0) @binding(0) var<uniform> uMVP: mat4x4<f32>;
@group(0) @binding(1) var myTex: texture_2d<f32>;
@group(0) @binding(2) var mySampler: sampler;
@vertex
fn vs_main() -> vec4<f32> {
    return vec4<f32>(0.0);
}
"#;
        let refl = reflect_wgsl(src);
        assert_eq!(refl.uniforms.len(), 1);
        assert_eq!(refl.textures.len(), 1);
        assert_eq!(refl.samplers.len(), 1);
    }

    // `@workgroup_size(256, 1, 1)` is reported as the workgroup size.
    #[test]
    fn reflect_wgsl_compute_workgroup() {
        let src = r#"
@compute @workgroup_size(256, 1, 1)
fn main() {}
"#;
        let refl = reflect_wgsl(src);
        assert_eq!(refl.workgroup_size, Some([256, 1, 1]));
    }

    // The fixture validates as GLSL with no Error-severity diagnostics.
    #[test]
    fn validate_valid_glsl() {
        let errs = validate_shader(SIMPLE_VERT_GLSL, ShaderLanguage::GLSL);
        // Should have no errors (may have warnings)
        let real_errors: Vec<_> = errs.iter().filter(|e| e.severity == Severity::Error).collect();
        assert!(real_errors.is_empty(), "Unexpected errors: {:?}", real_errors);
    }

    // GLSL without `main` produces a diagnostic that mentions it.
    #[test]
    fn validate_glsl_missing_main() {
        let src = "#version 330 core\nuniform float x;\n";
        let errs = validate_shader(src, ShaderLanguage::GLSL);
        assert!(errs.iter().any(|e| e.message.contains("main")));
    }

    // An unclosed `{` produces a brace diagnostic.
    #[test]
    fn validate_glsl_unbalanced_braces() {
        let src = "#version 330 core\nvoid main() {\n";
        let errs = validate_shader(src, ShaderLanguage::GLSL);
        assert!(errs.iter().any(|e| e.message.contains("brace")));
    }

    // WGSL produced by `glsl_to_wgsl` passes WGSL validation without
    // Error-severity diagnostics.
    #[test]
    fn validate_valid_wgsl() {
        let wgsl = glsl_to_wgsl(SIMPLE_VERT_GLSL).unwrap();
        let errs = validate_shader(&wgsl, ShaderLanguage::WGSL);
        let real_errors: Vec<_> = errs.iter().filter(|e| e.severity == Severity::Error).collect();
        assert!(real_errors.is_empty(), "Unexpected errors: {:?}", real_errors);
    }

    // GLSL source validated as WGSL gets its `#version` directive flagged.
    #[test]
    fn validate_wgsl_with_glsl_isms() {
        let src = "#version 330\nvoid main() {}\n";
        let errs = validate_shader(src, ShaderLanguage::WGSL);
        assert!(errs.iter().any(|e| e.message.contains("#version")));
    }

    // HLSL uses the generic path and still rejects empty source.
    #[test]
    fn validate_empty_shader() {
        let errs = validate_shader("", ShaderLanguage::HLSL);
        assert!(errs.iter().any(|e| e.message.contains("Empty")));
    }

    // `Display` output for languages (note the hyphen in "SPIR-V").
    #[test]
    fn shader_language_display() {
        assert_eq!(format!("{}", ShaderLanguage::GLSL), "GLSL");
        assert_eq!(format!("{}", ShaderLanguage::SPIRV), "SPIR-V");
    }

    // `TranslateError`'s `Display` includes both its message and each
    // attached diagnostic.
    #[test]
    fn translate_error_display() {
        let err = TranslateError::new("test error")
            .with_error(ShaderError {
                line: 5, col: 10,
                message: "bad token".into(),
                severity: Severity::Error,
            });
        let s = format!("{}", err);
        assert!(s.contains("test error"));
        assert!(s.contains("bad token"));
    }

    // GLSL `texture2D(...)` is rewritten to WGSL `textureSample(...)`.
    #[test]
    fn glsl_texture_builtin_translation() {
        let line = "vec4 c = texture2D(myTex, uv);";
        let translated = translate_glsl_builtins_to_wgsl(line);
        assert!(translated.contains("textureSample("));
    }

    // WGSL `textureSample(...)` is rewritten to GLSL `texture(...)`.
    #[test]
    fn wgsl_texture_builtin_translation() {
        let line = "let c = textureSample(myTex, mySampler, uv);";
        let translated = translate_wgsl_builtins_to_glsl(line);
        assert!(translated.contains("texture("));
    }
}