1use std::fmt;
5
/// Severity level attached to a [`ShaderError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Severity {
    /// Rendered as `error` when a `ShaderError` is displayed.
    Error,
    /// Rendered as `warning` when a `ShaderError` is displayed.
    Warning,
    /// Rendered as `info` when a `ShaderError` is displayed.
    Info,
}
17
/// A single diagnostic tied to a position in shader source text.
#[derive(Debug, Clone)]
pub struct ShaderError {
    /// Line the diagnostic refers to (the validators in this file count lines from 1).
    pub line: usize,
    /// Column the diagnostic refers to.
    pub col: usize,
    /// Human-readable description of the problem.
    pub message: String,
    /// How serious the diagnostic is.
    pub severity: Severity,
}
26
27impl fmt::Display for ShaderError {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 let sev = match self.severity {
30 Severity::Error => "error",
31 Severity::Warning => "warning",
32 Severity::Info => "info",
33 };
34 write!(f, "{}:{}:{}: {}", sev, self.line, self.col, self.message)
35 }
36}
37
/// Error type returned by the translation functions in this file.
#[derive(Debug, Clone)]
pub struct TranslateError {
    /// Summary of what went wrong.
    pub message: String,
    /// Individual diagnostics attached to this error (may be empty).
    pub errors: Vec<ShaderError>,
}
44
45impl fmt::Display for TranslateError {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 write!(f, "TranslateError: {}", self.message)?;
48 for e in &self.errors {
49 write!(f, "\n {}", e)?;
50 }
51 Ok(())
52 }
53}
54
// Empty impl: every `std::error::Error` method keeps its default, so this only
// marks `TranslateError` as usable wherever a standard error type is expected.
impl std::error::Error for TranslateError {}
56
57impl TranslateError {
58 pub fn new(msg: impl Into<String>) -> Self {
59 Self { message: msg.into(), errors: Vec::new() }
60 }
61 pub fn with_error(mut self, err: ShaderError) -> Self {
62 self.errors.push(err);
63 self
64 }
65}
66
/// Shader languages this module can name.
///
/// `validate_shader` has dedicated checks for `GLSL` and `WGSL` only; the
/// remaining variants get an empty-source check.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderLanguage {
    /// Displayed as `GLSL`.
    GLSL,
    /// Displayed as `WGSL`.
    WGSL,
    /// Displayed as `SPIR-V`.
    SPIRV,
    /// Displayed as `HLSL`.
    HLSL,
    /// Displayed as `MSL`.
    MSL,
}
79
80impl fmt::Display for ShaderLanguage {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 match self {
83 Self::GLSL => write!(f, "GLSL"),
84 Self::WGSL => write!(f, "WGSL"),
85 Self::SPIRV => write!(f, "SPIR-V"),
86 Self::HLSL => write!(f, "HLSL"),
87 Self::MSL => write!(f, "MSL"),
88 }
89 }
90}
91
/// Line-level token produced by `parse_glsl_tokens`.
#[derive(Debug, Clone, PartialEq)]
pub enum GlslToken {
    /// `#version <n>` directive with its numeric version.
    Version(u32),
    /// `layout(location = N) in <ty> <name>;`
    In { location: u32, ty: String, name: String },
    /// `layout(location = N) out <ty> <name>;`
    Out { location: u32, ty: String, name: String },
    /// `uniform <ty> <name>;` whose type does not start with `sampler`.
    Uniform { ty: String, name: String },
    /// Uniform block declaration. `parse_glsl_tokens` does not emit this
    /// variant; consumers in this file ignore it.
    UniformBlock { name: String, binding: Option<u32> },
    /// `uniform sampler… <name>;`
    Sampler { name: String },
    /// The line that declares `void main(...)`.
    MainBegin,
    /// The closing brace of `main`.
    MainEnd,
    /// Any other source line, stored verbatim (untrimmed).
    Line(String),
}
108
/// Token shapes for WGSL source.
///
/// NOTE(review): no function visible in this file constructs or consumes this
/// enum; confirm whether it is used elsewhere before relying on it.
#[derive(Debug, Clone, PartialEq)]
pub enum WgslToken {
    /// A struct with `(name, type)` field pairs — field order in the tuple is
    /// presumed from naming; TODO confirm.
    Struct { name: String, fields: Vec<(String, String)> },
    /// A resource binding with its group/binding indices, name and type.
    Binding { group: u32, binding: u32, name: String, ty: String },
    /// Names a vertex output.
    VertexOutput { name: String },
    /// Names a fragment output.
    FragmentOutput { name: String },
    /// An entry point with its stage and function name.
    EntryPoint { stage: String, name: String },
    /// Any other source line.
    Line(String),
}
118
119fn parse_glsl_tokens(source: &str) -> Vec<GlslToken> {
125 let mut tokens = Vec::new();
126 let mut in_main = false;
127
128 for line in source.lines() {
129 let trimmed = line.trim();
130
131 if trimmed.starts_with("#version") {
133 if let Some(ver_str) = trimmed.strip_prefix("#version") {
134 let ver_str = ver_str.trim().split_whitespace().next().unwrap_or("330");
135 if let Ok(v) = ver_str.parse::<u32>() {
136 tokens.push(GlslToken::Version(v));
137 }
138 }
139 continue;
140 }
141
142 if let Some(rest) = try_parse_layout_in_out(trimmed) {
144 tokens.push(rest);
145 continue;
146 }
147
148 if trimmed.starts_with("uniform ") && !trimmed.contains('{') {
150 let parts: Vec<&str> = trimmed.trim_end_matches(';')
151 .split_whitespace().collect();
152 if parts.len() >= 3 {
153 let ty = parts[1].to_string();
154 let name = parts[2].trim_end_matches(';').to_string();
155 if ty.starts_with("sampler") {
156 tokens.push(GlslToken::Sampler { name });
157 } else {
158 tokens.push(GlslToken::Uniform { ty, name });
159 }
160 }
161 continue;
162 }
163
164 if trimmed.contains("void main") && trimmed.contains('(') {
166 tokens.push(GlslToken::MainBegin);
167 in_main = true;
168 continue;
169 }
170
171 if in_main && trimmed == "}" {
173 tokens.push(GlslToken::MainEnd);
174 in_main = false;
175 continue;
176 }
177
178 tokens.push(GlslToken::Line(line.to_string()));
179 }
180 tokens
181}
182
183fn try_parse_layout_in_out(line: &str) -> Option<GlslToken> {
185 if !line.starts_with("layout") { return None; }
186 let loc = extract_location(line)?;
187 let after_paren = line.find(')')? ;
188 let rest = &line[after_paren + 1..].trim();
189
190 if rest.starts_with("in ") {
191 let parts: Vec<&str> = rest[3..].trim_end_matches(';').split_whitespace().collect();
192 if parts.len() >= 2 {
193 return Some(GlslToken::In {
194 location: loc,
195 ty: parts[0].to_string(),
196 name: parts[1].trim_end_matches(';').to_string(),
197 });
198 }
199 } else if rest.starts_with("out ") {
200 let parts: Vec<&str> = rest[4..].trim_end_matches(';').split_whitespace().collect();
201 if parts.len() >= 2 {
202 return Some(GlslToken::Out {
203 location: loc,
204 ty: parts[0].to_string(),
205 name: parts[1].trim_end_matches(';').to_string(),
206 });
207 }
208 }
209 None
210}
211
/// Extracts `N` from a `location = N` layout qualifier.
///
/// The value ends at the first `,` or `)` after the `=`, so qualifier lists
/// such as `layout(location = 0, component = 1)` are handled. Returns `None`
/// when there is no `location`, no `=`, no terminator, or the value is not an
/// unsigned integer.
fn extract_location(line: &str) -> Option<u32> {
    let start = line.find("location")? + "location".len();
    let rest = &line[start..];
    let after_eq = &rest[rest.find('=')? + 1..];
    let end = after_eq.find(|c: char| c == ',' || c == ')')?;
    after_eq[..end].trim().parse::<u32>().ok()
}
220
/// Maps a GLSL type name to its WGSL spelling.
///
/// Names without an entry in the table are returned unchanged.
fn glsl_type_to_wgsl(ty: &str) -> String {
    const GLSL_TO_WGSL: &[(&str, &str)] = &[
        ("float", "f32"),
        ("int", "i32"),
        ("uint", "u32"),
        ("bool", "bool"),
        ("vec2", "vec2<f32>"),
        ("vec3", "vec3<f32>"),
        ("vec4", "vec4<f32>"),
        ("ivec2", "vec2<i32>"),
        ("ivec3", "vec3<i32>"),
        ("ivec4", "vec4<i32>"),
        ("uvec2", "vec2<u32>"),
        ("uvec3", "vec3<u32>"),
        ("uvec4", "vec4<u32>"),
        ("mat2", "mat2x2<f32>"),
        ("mat3", "mat3x3<f32>"),
        ("mat4", "mat4x4<f32>"),
        ("sampler2D", "texture_2d<f32>"),
    ];

    GLSL_TO_WGSL
        .iter()
        .find(|(glsl, _)| *glsl == ty)
        .map_or(ty, |(_, wgsl)| *wgsl)
        .to_string()
}
247
/// Maps a WGSL type name to its GLSL spelling.
///
/// Names that are not listed pass through unchanged.
fn wgsl_type_to_glsl(ty: &str) -> String {
    let glsl = match ty {
        // Scalars.
        "f32" => "float",
        "i32" => "int",
        "u32" => "uint",
        // Float vectors.
        "vec2<f32>" => "vec2",
        "vec3<f32>" => "vec3",
        "vec4<f32>" => "vec4",
        // Signed integer vectors.
        "vec2<i32>" => "ivec2",
        "vec3<i32>" => "ivec3",
        "vec4<i32>" => "ivec4",
        // Unsigned integer vectors.
        "vec2<u32>" => "uvec2",
        "vec3<u32>" => "uvec3",
        "vec4<u32>" => "uvec4",
        // Square float matrices.
        "mat2x2<f32>" => "mat2",
        "mat3x3<f32>" => "mat3",
        "mat4x4<f32>" => "mat4",
        // Textures.
        "texture_2d<f32>" => "sampler2D",
        unmapped => unmapped,
    };
    glsl.to_owned()
}
269
/// Rewrites GLSL built-ins on one line into their WGSL counterparts.
///
/// * `texture2D(` / `texture(` become `textureSample(`. WGSL's `textureSample`
///   takes the sampler as a separate second argument, so when the first
///   argument is a plain identifier `tex`, `tex_sampler` is inserted after it
///   — the same name `glsl_to_wgsl` declares for each GLSL sampler. Calls
///   whose first argument is not a plain identifier are only renamed.
/// * `gl_Position` becomes `output.position`.
/// * `gl_FragColor` becomes `output.color`.
///
/// Replacement is purely textual; identifiers that merely end in `texture(`
/// are rewritten too.
fn translate_glsl_builtins_to_wgsl(line: &str) -> String {
    const CALL: &str = "textureSample(";

    let is_identifier =
        |s: &str| !s.is_empty() && s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_');

    let renamed = line.replace("texture2D(", CALL).replace("texture(", CALL);

    let mut out = String::with_capacity(renamed.len() + 16);
    let mut rest = renamed.as_str();
    while let Some(pos) = rest.find(CALL) {
        let args_start = pos + CALL.len();
        out.push_str(&rest[..args_start]);
        rest = &rest[args_start..];

        // Insert `<tex>_sampler` right after the texture argument.
        if let Some(comma) = rest.find(',') {
            let texture = rest[..comma].trim();
            if is_identifier(texture) {
                out.push_str(&rest[..comma]);
                out.push_str(", ");
                out.push_str(texture);
                out.push_str("_sampler");
                rest = &rest[comma..];
            }
        }
    }
    out.push_str(rest);

    out.replace("gl_Position", "output.position")
        .replace("gl_FragColor", "output.color")
}
288
/// Rewrites WGSL built-ins on one line into their GLSL counterparts.
///
/// * `textureSample(tex, smp, coords…)` becomes `texture(name, coords…)`.
///   GLSL uses one combined sampler, so the separate texture and sampler
///   arguments collapse into a single name: the sampler argument with any
///   `_sampler` suffix removed, which is exactly the `uniform sampler2D`
///   name `wgsl_to_glsl` declares. Calls whose first two arguments are not
///   plain identifiers are only renamed.
/// * `output.position` becomes `gl_Position`.
/// * `output.color` becomes `gl_FragColor`.
fn translate_wgsl_builtins_to_glsl(line: &str) -> String {
    const WGSL_CALL: &str = "textureSample(";

    let is_identifier =
        |s: &str| !s.is_empty() && s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_');

    let mut out = String::with_capacity(line.len());
    let mut rest = line;
    while let Some(pos) = rest.find(WGSL_CALL) {
        out.push_str(&rest[..pos]);
        out.push_str("texture(");
        rest = &rest[pos + WGSL_CALL.len()..];

        // Arguments are `texture, sampler, coords...`; fold the first two into one.
        let mut commas = rest.match_indices(',').map(|(at, _)| at);
        if let (Some(first), Some(second)) = (commas.next(), commas.next()) {
            let texture = rest[..first].trim();
            let sampler = rest[first + 1..second].trim();
            if is_identifier(texture) && is_identifier(sampler) {
                out.push_str(sampler.strip_suffix("_sampler").unwrap_or(sampler));
                rest = &rest[second..];
            }
        }
    }
    out.push_str(rest);

    out.replace("output.position", "gl_Position")
        .replace("output.color", "gl_FragColor")
}
297
298pub fn glsl_to_wgsl(glsl_source: &str) -> Result<String, TranslateError> {
304 let tokens = parse_glsl_tokens(glsl_source);
305 let mut wgsl = String::new();
306 let mut inputs: Vec<(u32, String, String)> = Vec::new();
307 let mut outputs: Vec<(u32, String, String)> = Vec::new();
308 let mut uniforms: Vec<(String, String)> = Vec::new();
309 let mut samplers: Vec<String> = Vec::new();
310 let mut body_lines: Vec<String> = Vec::new();
311 let mut in_body = false;
312 let mut binding_counter = 0u32;
313
314 for token in &tokens {
315 match token {
316 GlslToken::Version(_) => {}
317 GlslToken::In { location, ty, name } => {
318 inputs.push((*location, ty.clone(), name.clone()));
319 }
320 GlslToken::Out { location, ty, name } => {
321 outputs.push((*location, ty.clone(), name.clone()));
322 }
323 GlslToken::Uniform { ty, name } => {
324 uniforms.push((ty.clone(), name.clone()));
325 }
326 GlslToken::Sampler { name } => {
327 samplers.push(name.clone());
328 }
329 GlslToken::UniformBlock { .. } => {}
330 GlslToken::MainBegin => {
331 in_body = true;
332 }
333 GlslToken::MainEnd => {
334 in_body = false;
335 }
336 GlslToken::Line(l) => {
337 if in_body {
338 body_lines.push(l.clone());
339 }
340 }
341 }
342 }
343
344 if !inputs.is_empty() {
346 wgsl.push_str("struct VertexInput {\n");
347 for (loc, ty, name) in &inputs {
348 wgsl.push_str(&format!(
349 " @location({}) {}: {},\n",
350 loc, name, glsl_type_to_wgsl(ty)
351 ));
352 }
353 wgsl.push_str("};\n\n");
354 }
355
356 if !outputs.is_empty() {
358 wgsl.push_str("struct VertexOutput {\n");
359 wgsl.push_str(" @builtin(position) position: vec4<f32>,\n");
360 for (loc, ty, name) in &outputs {
361 wgsl.push_str(&format!(
362 " @location({}) {}: {},\n",
363 loc, name, glsl_type_to_wgsl(ty)
364 ));
365 }
366 wgsl.push_str("};\n\n");
367 }
368
369 for (ty, name) in &uniforms {
371 wgsl.push_str(&format!(
372 "@group(0) @binding({}) var<uniform> {}: {};\n",
373 binding_counter, name, glsl_type_to_wgsl(ty)
374 ));
375 binding_counter += 1;
376 }
377
378 for name in &samplers {
380 wgsl.push_str(&format!(
381 "@group(0) @binding({}) var {}: texture_2d<f32>;\n",
382 binding_counter, name,
383 ));
384 binding_counter += 1;
385 wgsl.push_str(&format!(
386 "@group(0) @binding({}) var {}_sampler: sampler;\n",
387 binding_counter, name,
388 ));
389 binding_counter += 1;
390 }
391
392 if !uniforms.is_empty() || !samplers.is_empty() {
393 wgsl.push('\n');
394 }
395
396 wgsl.push_str("@vertex\n");
398 wgsl.push_str("fn vs_main(input: VertexInput) -> VertexOutput {\n");
399 wgsl.push_str(" var output: VertexOutput;\n");
400 for line in &body_lines {
401 let translated = translate_glsl_builtins_to_wgsl(line);
402 let translated = translated.trim();
403 if !translated.is_empty() {
404 wgsl.push_str(&format!(" {}\n", translated));
405 }
406 }
407 wgsl.push_str(" return output;\n");
408 wgsl.push_str("}\n");
409
410 Ok(wgsl)
411}
412
413pub fn glsl_to_spirv_text(glsl_source: &str) -> Result<String, TranslateError> {
420 let tokens = parse_glsl_tokens(glsl_source);
421 let mut spirv = String::new();
422 spirv.push_str("; SPIR-V text representation (generated by proof-engine shader_translate)\n");
423 spirv.push_str("; Magic: 0x07230203\n");
424 spirv.push_str("; Version: 1.0\n");
425 spirv.push_str("; Generator: proof-engine\n\n");
426
427 let mut id_counter = 1u32;
428 let mut next_id = || { let id = id_counter; id_counter += 1; id };
429
430 spirv.push_str(" OpCapability Shader\n");
432 let ext_id = next_id();
433 spirv.push_str(&format!(" %{} = OpExtInstImport \"GLSL.std.450\"\n", ext_id));
434 spirv.push_str(" OpMemoryModel Logical GLSL450\n");
435
436 let main_id = next_id();
438 let mut interface_ids = Vec::new();
439
440 for token in &tokens {
441 match token {
442 GlslToken::In { location, ty, name } => {
443 let var_id = next_id();
444 interface_ids.push(var_id);
445 spirv.push_str(&format!(
446 " OpDecorate %{} Location {}\n",
447 var_id, location
448 ));
449 let type_id = next_id();
450 spirv.push_str(&format!(
451 " %{} = OpTypePointer Input %{} ; {}: {}\n",
452 var_id, type_id, name, ty
453 ));
454 }
455 GlslToken::Out { location, ty, name } => {
456 let var_id = next_id();
457 interface_ids.push(var_id);
458 spirv.push_str(&format!(
459 " OpDecorate %{} Location {}\n",
460 var_id, location
461 ));
462 let type_id = next_id();
463 spirv.push_str(&format!(
464 " %{} = OpTypePointer Output %{} ; {}: {}\n",
465 var_id, type_id, name, ty
466 ));
467 }
468 GlslToken::Uniform { ty, name } => {
469 let var_id = next_id();
470 let type_id = next_id();
471 spirv.push_str(&format!(
472 " %{} = OpTypePointer Uniform %{} ; uniform {}: {}\n",
473 var_id, type_id, name, ty
474 ));
475 }
476 _ => {}
477 }
478 }
479
480 let iface_str: String = interface_ids.iter().map(|id| format!("%{}", id)).collect::<Vec<_>>().join(" ");
481 spirv.push_str(&format!(
482 " OpEntryPoint Vertex %{} \"main\" {}\n",
483 main_id, iface_str
484 ));
485
486 let void_id = next_id();
488 let func_type_id = next_id();
489 spirv.push_str(&format!(" %{} = OpTypeVoid\n", void_id));
490 spirv.push_str(&format!(" %{} = OpTypeFunction %{}\n", func_type_id, void_id));
491 spirv.push_str(&format!(" %{} = OpFunction %{} None %{}\n", main_id, void_id, func_type_id));
492 let label_id = next_id();
493 spirv.push_str(&format!(" %{} = OpLabel\n", label_id));
494 spirv.push_str(" OpReturn\n");
495 spirv.push_str(" OpFunctionEnd\n");
496
497 Ok(spirv)
498}
499
500pub fn wgsl_to_glsl(wgsl_source: &str) -> Result<String, TranslateError> {
506 let mut glsl = String::from("#version 330 core\n\n");
507 let mut in_struct = false;
508 let mut current_struct_name = String::new();
509 let mut in_fn = false;
510 let mut fn_body_lines: Vec<String> = Vec::new();
511
512 for line in wgsl_source.lines() {
513 let trimmed = line.trim();
514
515 if trimmed.starts_with("@group") && trimmed.contains("var<uniform>") {
517 if let Some(rest) = trimmed.split("var<uniform>").nth(1) {
518 let rest = rest.trim().trim_end_matches(';');
519 let parts: Vec<&str> = rest.splitn(2, ':').collect();
520 if parts.len() == 2 {
521 let name = parts[0].trim();
522 let ty = parts[1].trim();
523 glsl.push_str(&format!("uniform {} {};\n", wgsl_type_to_glsl(ty), name));
524 }
525 }
526 continue;
527 }
528
529 if trimmed.starts_with("@group") && trimmed.contains("var ") && trimmed.contains("texture") {
531 continue;
533 }
534
535 if trimmed.starts_with("@group") && trimmed.contains("sampler") && !trimmed.contains("texture") {
537 if let Some(rest) = trimmed.split("var ").nth(1) {
539 let name = rest.split(':').next().unwrap_or("").trim();
540 let base = name.strip_suffix("_sampler").unwrap_or(name);
542 glsl.push_str(&format!("uniform sampler2D {};\n", base));
543 }
544 continue;
545 }
546
547 if trimmed.starts_with("struct ") {
549 in_struct = true;
550 current_struct_name = trimmed
551 .strip_prefix("struct ")
552 .unwrap_or("")
553 .trim_end_matches('{')
554 .trim()
555 .to_string();
556 continue;
557 }
558
559 if in_struct {
560 if trimmed.starts_with('}') {
561 in_struct = false;
562 continue;
563 }
564 if trimmed.contains("@location") {
566 if let Some(loc) = extract_wgsl_location(trimmed) {
567 let rest = trimmed.split(')').last().unwrap_or("").trim();
568 let parts: Vec<&str> = rest.splitn(2, ':').collect();
569 if parts.len() == 2 {
570 let name = parts[0].trim();
571 let ty = parts[1].trim().trim_end_matches(',');
572 let is_input = current_struct_name.contains("Input");
573 let qualifier = if is_input { "in" } else { "out" };
574 glsl.push_str(&format!(
575 "layout(location = {}) {} {} {};\n",
576 loc, qualifier, wgsl_type_to_glsl(ty), name
577 ));
578 }
579 }
580 }
581 continue;
582 }
583
584 if trimmed.starts_with("@vertex") || trimmed.starts_with("@fragment") || trimmed.starts_with("@compute") {
586 in_fn = true;
587 fn_body_lines.clear();
588 continue;
589 }
590
591 if trimmed.starts_with("fn ") && in_fn {
592 continue;
594 }
595
596 if in_fn {
597 if trimmed == "}" {
598 glsl.push_str("\nvoid main() {\n");
600 for bl in &fn_body_lines {
601 let translated = translate_wgsl_builtins_to_glsl(bl);
602 let translated = translated.trim();
603 if !translated.is_empty()
604 && !translated.starts_with("var output")
605 && !translated.starts_with("return")
606 {
607 glsl.push_str(&format!(" {}\n", translated));
608 }
609 }
610 glsl.push_str("}\n");
611 in_fn = false;
612 continue;
613 }
614 fn_body_lines.push(line.to_string());
615 }
616 }
617
618 Ok(glsl)
619}
620
/// Returns `n` from the first `@location(n)` attribute on `line`, or `None`
/// when the attribute is absent, unterminated, or not an unsigned integer.
fn extract_wgsl_location(line: &str) -> Option<u32> {
    let (_, after_open) = line.split_once("@location(")?;
    let (digits, _) = after_open.split_once(')')?;
    digits.trim().parse().ok()
}
627
/// Resources discovered in a shader by `reflect_glsl` / `reflect_wgsl`.
///
/// `Default` yields empty lists and no workgroup size.
#[derive(Debug, Clone, Default)]
pub struct ShaderReflection {
    /// Stage inputs (`in` declarations / `@location` items).
    pub inputs: Vec<ReflectedBinding>,
    /// Stage outputs (`out` declarations).
    pub outputs: Vec<ReflectedBinding>,
    /// Uniform variables.
    pub uniforms: Vec<ReflectedBinding>,
    /// Storage buffers (filled by `reflect_wgsl` only).
    pub storage_buffers: Vec<ReflectedBinding>,
    /// Texture bindings (filled by `reflect_wgsl` only).
    pub textures: Vec<ReflectedBinding>,
    /// Sampler bindings.
    pub samplers: Vec<ReflectedBinding>,
    /// Compute workgroup size as `[x, y, z]`, when the shader declares one.
    pub workgroup_size: Option<[u32; 3]>,
}
643
/// One reflected shader resource or interface variable.
#[derive(Debug, Clone)]
pub struct ReflectedBinding {
    /// Variable name as written in the shader.
    pub name: String,
    /// Type text as written in the shader (not translated between languages).
    pub ty: String,
    /// `location` for inputs/outputs, `binding` for resources; `reflect_glsl`
    /// stores 0 here for uniforms and samplers.
    pub location_or_binding: u32,
    /// Bind group index; `None` where the source has no `@group`.
    pub group: Option<u32>,
}
652
653pub fn reflect_glsl(source: &str) -> ShaderReflection {
655 let tokens = parse_glsl_tokens(source);
656 let mut refl = ShaderReflection::default();
657
658 for token in &tokens {
659 match token {
660 GlslToken::In { location, ty, name } => {
661 refl.inputs.push(ReflectedBinding {
662 name: name.clone(),
663 ty: ty.clone(),
664 location_or_binding: *location,
665 group: None,
666 });
667 }
668 GlslToken::Out { location, ty, name } => {
669 refl.outputs.push(ReflectedBinding {
670 name: name.clone(),
671 ty: ty.clone(),
672 location_or_binding: *location,
673 group: None,
674 });
675 }
676 GlslToken::Uniform { ty, name } => {
677 refl.uniforms.push(ReflectedBinding {
678 name: name.clone(),
679 ty: ty.clone(),
680 location_or_binding: 0,
681 group: None,
682 });
683 }
684 GlslToken::Sampler { name } => {
685 refl.samplers.push(ReflectedBinding {
686 name: name.clone(),
687 ty: "sampler2D".into(),
688 location_or_binding: 0,
689 group: None,
690 });
691 }
692 _ => {}
693 }
694 }
695
696 for line in source.lines() {
698 let trimmed = line.trim();
699 if trimmed.contains("local_size_x") {
700 let mut ws = [1u32, 1, 1];
701 for dim in ["local_size_x", "local_size_y", "local_size_z"].iter().enumerate() {
702 if let Some(pos) = trimmed.find(dim.1) {
703 let rest = &trimmed[pos + dim.1.len()..];
704 if let Some(eq) = rest.find('=') {
705 let after = &rest[eq + 1..];
706 let num_str: String = after.chars().take_while(|c| c.is_ascii_digit()).collect();
707 if let Ok(n) = num_str.parse::<u32>() {
708 ws[dim.0] = n;
709 }
710 }
711 }
712 }
713 refl.workgroup_size = Some(ws);
714 }
715 }
716
717 refl
718}
719
720pub fn reflect_wgsl(source: &str) -> ShaderReflection {
722 let mut refl = ShaderReflection::default();
723
724 for line in source.lines() {
725 let trimmed = line.trim();
726
727 if trimmed.starts_with("@group") && trimmed.contains("@binding") {
729 let group = extract_wgsl_group(trimmed);
730 let binding = extract_wgsl_binding_num(trimmed);
731
732 if trimmed.contains("var<uniform>") {
733 if let Some(rest) = trimmed.split("var<uniform>").nth(1) {
734 let rest = rest.trim().trim_end_matches(';');
735 let parts: Vec<&str> = rest.splitn(2, ':').collect();
736 if parts.len() == 2 {
737 refl.uniforms.push(ReflectedBinding {
738 name: parts[0].trim().to_string(),
739 ty: parts[1].trim().to_string(),
740 location_or_binding: binding.unwrap_or(0),
741 group,
742 });
743 }
744 }
745 } else if trimmed.contains("var<storage") {
746 if let Some(rest) = trimmed.split('>').last() {
747 let rest = rest.trim().trim_end_matches(';');
748 let parts: Vec<&str> = rest.splitn(2, ':').collect();
749 if parts.len() == 2 {
750 refl.storage_buffers.push(ReflectedBinding {
751 name: parts[0].trim().to_string(),
752 ty: parts[1].trim().to_string(),
753 location_or_binding: binding.unwrap_or(0),
754 group,
755 });
756 }
757 }
758 } else if trimmed.contains("texture") {
759 if let Some(rest) = trimmed.split("var ").nth(1) {
760 let parts: Vec<&str> = rest.splitn(2, ':').collect();
761 if parts.len() == 2 {
762 refl.textures.push(ReflectedBinding {
763 name: parts[0].trim().to_string(),
764 ty: parts[1].trim().trim_end_matches(';').to_string(),
765 location_or_binding: binding.unwrap_or(0),
766 group,
767 });
768 }
769 }
770 } else if trimmed.contains("sampler") {
771 if let Some(rest) = trimmed.split("var ").nth(1) {
772 let name = rest.split(':').next().unwrap_or("").trim().to_string();
773 refl.samplers.push(ReflectedBinding {
774 name,
775 ty: "sampler".into(),
776 location_or_binding: binding.unwrap_or(0),
777 group,
778 });
779 }
780 }
781 }
782
783 if trimmed.contains("@location(") && !trimmed.starts_with("@group") {
785 if let Some(loc) = extract_wgsl_location(trimmed) {
786 let rest = trimmed.split(')').last().unwrap_or("").trim();
787 let parts: Vec<&str> = rest.splitn(2, ':').collect();
788 if parts.len() == 2 {
789 let binding = ReflectedBinding {
790 name: parts[0].trim().to_string(),
791 ty: parts[1].trim().trim_end_matches(',').to_string(),
792 location_or_binding: loc,
793 group: None,
794 };
795 refl.inputs.push(binding);
797 }
798 }
799 }
800
801 if trimmed.contains("@workgroup_size(") {
803 let start = trimmed.find("@workgroup_size(").unwrap() + "@workgroup_size(".len();
804 let rest = &trimmed[start..];
805 if let Some(end) = rest.find(')') {
806 let nums: Vec<u32> = rest[..end]
807 .split(',')
808 .filter_map(|s| s.trim().parse::<u32>().ok())
809 .collect();
810 let mut ws = [1u32, 1, 1];
811 for (i, &n) in nums.iter().enumerate().take(3) {
812 ws[i] = n;
813 }
814 refl.workgroup_size = Some(ws);
815 }
816 }
817 }
818
819 refl
820}
821
/// Returns `n` from the first `@group(n)` attribute on `line`, or `None` when
/// the attribute is absent, unterminated, or not an unsigned integer.
fn extract_wgsl_group(line: &str) -> Option<u32> {
    const MARKER: &str = "@group(";
    let args_at = line.find(MARKER).map(|at| at + MARKER.len())?;
    let args = &line[args_at..];
    args.find(')')
        .and_then(|close| args[..close].trim().parse::<u32>().ok())
}
828
/// Returns `n` from the first `@binding(n)` attribute on `line`, or `None`
/// when the attribute is absent, unterminated, or not an unsigned integer.
fn extract_wgsl_binding_num(line: &str) -> Option<u32> {
    let tail = line.splitn(2, "@binding(").nth(1)?;
    match tail.find(')') {
        Some(close) => tail[..close].trim().parse::<u32>().ok(),
        None => None,
    }
}
835
836pub fn validate_shader(source: &str, language: ShaderLanguage) -> Vec<ShaderError> {
842 let mut errors = Vec::new();
843
844 match language {
845 ShaderLanguage::GLSL => validate_glsl(source, &mut errors),
846 ShaderLanguage::WGSL => validate_wgsl(source, &mut errors),
847 _ => {
848 if source.trim().is_empty() {
850 errors.push(ShaderError {
851 line: 1, col: 1,
852 message: "Empty shader source".into(),
853 severity: Severity::Error,
854 });
855 }
856 }
857 }
858
859 errors
860}
861
862fn validate_glsl(source: &str, errors: &mut Vec<ShaderError>) {
863 let mut has_version = false;
864 let mut has_main = false;
865 let mut brace_depth: i32 = 0;
866
867 for (i, line) in source.lines().enumerate() {
868 let ln = i + 1;
869 let trimmed = line.trim();
870
871 if trimmed.starts_with("#version") {
872 has_version = true;
873 if ln != 1 {
874 errors.push(ShaderError {
875 line: ln, col: 1,
876 message: "#version must be on the first line".into(),
877 severity: Severity::Warning,
878 });
879 }
880 }
881
882 if trimmed.contains("void main") {
883 has_main = true;
884 }
885
886 for ch in trimmed.chars() {
887 if ch == '{' { brace_depth += 1; }
888 if ch == '}' { brace_depth -= 1; }
889 }
890
891 if brace_depth < 0 {
892 errors.push(ShaderError {
893 line: ln, col: 1,
894 message: "Unmatched closing brace".into(),
895 severity: Severity::Error,
896 });
897 }
898 }
899
900 if !has_version {
901 errors.push(ShaderError {
902 line: 1, col: 1,
903 message: "Missing #version directive".into(),
904 severity: Severity::Warning,
905 });
906 }
907
908 if !has_main {
909 errors.push(ShaderError {
910 line: 1, col: 1,
911 message: "Missing void main() entry point".into(),
912 severity: Severity::Error,
913 });
914 }
915
916 if brace_depth != 0 {
917 errors.push(ShaderError {
918 line: source.lines().count(), col: 1,
919 message: format!("Unbalanced braces (depth {})", brace_depth),
920 severity: Severity::Error,
921 });
922 }
923}
924
925fn validate_wgsl(source: &str, errors: &mut Vec<ShaderError>) {
926 let mut has_entry = false;
927 let mut brace_depth: i32 = 0;
928
929 for (i, line) in source.lines().enumerate() {
930 let ln = i + 1;
931 let trimmed = line.trim();
932
933 if trimmed.starts_with("@vertex") || trimmed.starts_with("@fragment") || trimmed.starts_with("@compute") {
934 has_entry = true;
935 }
936
937 for ch in trimmed.chars() {
938 if ch == '{' { brace_depth += 1; }
939 if ch == '}' { brace_depth -= 1; }
940 }
941
942 if brace_depth < 0 {
943 errors.push(ShaderError {
944 line: ln, col: 1,
945 message: "Unmatched closing brace".into(),
946 severity: Severity::Error,
947 });
948 }
949
950 if trimmed.starts_with("#version") {
952 errors.push(ShaderError {
953 line: ln, col: 1,
954 message: "#version is not valid WGSL".into(),
955 severity: Severity::Error,
956 });
957 }
958
959 if trimmed.contains("void main") {
960 errors.push(ShaderError {
961 line: ln, col: 1,
962 message: "WGSL does not use 'void main()'; use @vertex/@fragment fn".into(),
963 severity: Severity::Error,
964 });
965 }
966 }
967
968 if !has_entry {
969 errors.push(ShaderError {
970 line: 1, col: 1,
971 message: "Missing entry point (@vertex, @fragment, or @compute)".into(),
972 severity: Severity::Warning,
973 });
974 }
975
976 if brace_depth != 0 {
977 errors.push(ShaderError {
978 line: source.lines().count(), col: 1,
979 message: format!("Unbalanced braces (depth {})", brace_depth),
980 severity: Severity::Error,
981 });
982 }
983}
984
// Unit tests covering tokenising, translation in both directions, reflection,
// validation and the `Display` impls of this file.
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal vertex shader shared by most tests: two inputs, one output,
    // one uniform, and a two-statement `main`.
    const SIMPLE_VERT_GLSL: &str = r#"#version 330 core
layout(location = 0) in vec3 aPos;
layout(location = 1) in vec2 aUV;
layout(location = 0) out vec2 vUV;
uniform mat4 uMVP;
void main() {
 gl_Position = uMVP * vec4(aPos, 1.0);
 vUV = aUV;
}
"#;

    // Every token kind the sample shader should yield is present.
    #[test]
    fn parse_glsl_tokens_simple() {
        let tokens = parse_glsl_tokens(SIMPLE_VERT_GLSL);
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::Version(330))));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::In { location: 0, .. })));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::Out { location: 0, .. })));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::Uniform { .. })));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::MainBegin)));
        assert!(tokens.iter().any(|t| matches!(t, GlslToken::MainEnd)));
    }

    // Structs, bindings and the entry point appear in the WGSL output.
    #[test]
    fn glsl_to_wgsl_simple() {
        let wgsl = glsl_to_wgsl(SIMPLE_VERT_GLSL).unwrap();
        assert!(wgsl.contains("struct VertexInput"));
        assert!(wgsl.contains("@location(0) aPos: vec3<f32>"));
        assert!(wgsl.contains("@location(1) aUV: vec2<f32>"));
        assert!(wgsl.contains("struct VertexOutput"));
        assert!(wgsl.contains("@group(0) @binding(0) var<uniform> uMVP: mat4x4<f32>"));
        assert!(wgsl.contains("@vertex"));
        assert!(wgsl.contains("fn vs_main"));
    }

    // `gl_Position` must be rewritten, not copied through.
    #[test]
    fn glsl_to_wgsl_translates_builtins() {
        let wgsl = glsl_to_wgsl(SIMPLE_VERT_GLSL).unwrap();
        assert!(wgsl.contains("output.position"));
        assert!(!wgsl.contains("gl_Position"));
    }

    // GLSL -> WGSL -> GLSL keeps the version header, the uniform and `main`.
    #[test]
    fn wgsl_to_glsl_roundtrip() {
        let wgsl = glsl_to_wgsl(SIMPLE_VERT_GLSL).unwrap();
        let glsl_back = wgsl_to_glsl(&wgsl).unwrap();
        assert!(glsl_back.contains("#version 330 core"));
        assert!(glsl_back.contains("uniform mat4x4 uMVP") || glsl_back.contains("uniform mat4 uMVP"));
        assert!(glsl_back.contains("void main()"));
    }

    // The SPIR-V text listing contains the expected skeleton instructions.
    #[test]
    fn glsl_to_spirv_text_has_structure() {
        let spirv = glsl_to_spirv_text(SIMPLE_VERT_GLSL).unwrap();
        assert!(spirv.contains("OpCapability Shader"));
        assert!(spirv.contains("OpMemoryModel Logical GLSL450"));
        assert!(spirv.contains("OpEntryPoint Vertex"));
        assert!(spirv.contains("OpReturn"));
        assert!(spirv.contains("OpFunctionEnd"));
    }

    #[test]
    fn type_translation_glsl_to_wgsl() {
        assert_eq!(glsl_type_to_wgsl("float"), "f32");
        assert_eq!(glsl_type_to_wgsl("vec3"), "vec3<f32>");
        assert_eq!(glsl_type_to_wgsl("mat4"), "mat4x4<f32>");
        assert_eq!(glsl_type_to_wgsl("sampler2D"), "texture_2d<f32>");
    }

    #[test]
    fn type_translation_wgsl_to_glsl() {
        assert_eq!(wgsl_type_to_glsl("f32"), "float");
        assert_eq!(wgsl_type_to_glsl("vec3<f32>"), "vec3");
        assert_eq!(wgsl_type_to_glsl("mat4x4<f32>"), "mat4");
    }

    // Reflection counts and names match the sample shader's declarations.
    #[test]
    fn reflect_glsl_shader() {
        let refl = reflect_glsl(SIMPLE_VERT_GLSL);
        assert_eq!(refl.inputs.len(), 2);
        assert_eq!(refl.outputs.len(), 1);
        assert_eq!(refl.uniforms.len(), 1);
        assert_eq!(refl.uniforms[0].name, "uMVP");
        assert_eq!(refl.inputs[0].name, "aPos");
        assert_eq!(refl.inputs[0].location_or_binding, 0);
    }

    // `local_size_*` layout qualifiers are read into the workgroup size.
    #[test]
    fn reflect_glsl_compute_workgroup() {
        let src = r#"#version 430
layout(local_size_x=64, local_size_y=1, local_size_z=1) in;
void main() {}
"#;
        let refl = reflect_glsl(src);
        assert_eq!(refl.workgroup_size, Some([64, 1, 1]));
    }

    // A `sampler2D` uniform lands in `samplers`, not `uniforms`.
    #[test]
    fn reflect_glsl_sampler() {
        let src = r#"#version 330 core
uniform sampler2D uTexture;
void main() {}
"#;
        let refl = reflect_glsl(src);
        assert_eq!(refl.samplers.len(), 1);
        assert_eq!(refl.samplers[0].name, "uTexture");
    }

    // Uniform, texture and sampler bindings are each classified once.
    #[test]
    fn reflect_wgsl_shader() {
        let src = r#"
@group(0) @binding(0) var<uniform> uMVP: mat4x4<f32>;
@group(0) @binding(1) var myTex: texture_2d<f32>;
@group(0) @binding(2) var mySampler: sampler;
@vertex
fn vs_main() -> vec4<f32> {
 return vec4<f32>(0.0);
}
"#;
        let refl = reflect_wgsl(src);
        assert_eq!(refl.uniforms.len(), 1);
        assert_eq!(refl.textures.len(), 1);
        assert_eq!(refl.samplers.len(), 1);
    }

    // `@workgroup_size(...)` is read into the workgroup size.
    #[test]
    fn reflect_wgsl_compute_workgroup() {
        let src = r#"
@compute @workgroup_size(256, 1, 1)
fn main() {}
"#;
        let refl = reflect_wgsl(src);
        assert_eq!(refl.workgroup_size, Some([256, 1, 1]));
    }

    // Warnings are allowed for the sample shader; errors are not.
    #[test]
    fn validate_valid_glsl() {
        let errs = validate_shader(SIMPLE_VERT_GLSL, ShaderLanguage::GLSL);
        let real_errors: Vec<_> = errs.iter().filter(|e| e.severity == Severity::Error).collect();
        assert!(real_errors.is_empty(), "Unexpected errors: {:?}", real_errors);
    }

    #[test]
    fn validate_glsl_missing_main() {
        let src = "#version 330 core\nuniform float x;\n";
        let errs = validate_shader(src, ShaderLanguage::GLSL);
        assert!(errs.iter().any(|e| e.message.contains("main")));
    }

    #[test]
    fn validate_glsl_unbalanced_braces() {
        let src = "#version 330 core\nvoid main() {\n";
        let errs = validate_shader(src, ShaderLanguage::GLSL);
        assert!(errs.iter().any(|e| e.message.contains("brace")));
    }

    // WGSL generated by `glsl_to_wgsl` must itself validate without errors.
    #[test]
    fn validate_valid_wgsl() {
        let wgsl = glsl_to_wgsl(SIMPLE_VERT_GLSL).unwrap();
        let errs = validate_shader(&wgsl, ShaderLanguage::WGSL);
        let real_errors: Vec<_> = errs.iter().filter(|e| e.severity == Severity::Error).collect();
        assert!(real_errors.is_empty(), "Unexpected errors: {:?}", real_errors);
    }

    // GLSL constructs fed to the WGSL validator are flagged.
    #[test]
    fn validate_wgsl_with_glsl_isms() {
        let src = "#version 330\nvoid main() {}\n";
        let errs = validate_shader(src, ShaderLanguage::WGSL);
        assert!(errs.iter().any(|e| e.message.contains("#version")));
    }

    // Languages without a dedicated validator still reject empty source.
    #[test]
    fn validate_empty_shader() {
        let errs = validate_shader("", ShaderLanguage::HLSL);
        assert!(errs.iter().any(|e| e.message.contains("Empty")));
    }

    #[test]
    fn shader_language_display() {
        assert_eq!(format!("{}", ShaderLanguage::GLSL), "GLSL");
        assert_eq!(format!("{}", ShaderLanguage::SPIRV), "SPIR-V");
    }

    // The rendered error includes both the summary and the attached diagnostic.
    #[test]
    fn translate_error_display() {
        let err = TranslateError::new("test error")
            .with_error(ShaderError {
                line: 5, col: 10,
                message: "bad token".into(),
                severity: Severity::Error,
            });
        let s = format!("{}", err);
        assert!(s.contains("test error"));
        assert!(s.contains("bad token"));
    }

    #[test]
    fn glsl_texture_builtin_translation() {
        let line = "vec4 c = texture2D(myTex, uv);";
        let translated = translate_glsl_builtins_to_wgsl(line);
        assert!(translated.contains("textureSample("));
    }

    #[test]
    fn wgsl_texture_builtin_translation() {
        let line = "let c = textureSample(myTex, mySampler, uv);";
        let translated = translate_wgsl_builtins_to_glsl(line);
        assert!(translated.contains("texture("));
    }
}