1use alloc::{
2 string::{String, ToString},
3 vec,
4};
5use arrayvec::ArrayVec;
6use core::fmt::Write;
7
8use naga::{
9 Expression, Handle, Module, Scalar, ShaderStage, TypeInner,
10 back::{self, INDENT},
11 proc::{self, ExpressionKindTracker, NameKey},
12 valid::ModuleInfo,
13};
14
15use crate::config::WriterFlags;
16use crate::conv::{self, BinOpClassified, KEYWORDS_2024, unwrap_to_rust};
17use crate::util::{Gensym, LevelNext};
18use crate::{Config, Error};
19
/// Result type returned by the writer's methods: `Ok(())` on success, otherwise an [`Error`].
type BackendResult = Result<(), Error>;
24
/// Attributes that `Writer::write_attributes` can emit in front of a generated item.
enum Attribute {
    /// Written as `#[allow(unused_parens, clippy::all, clippy::pedantic, clippy::nursery)]`.
    /// `Writer::write_function` puts this on every function it generates.
    AllowFunctionBody,

    /// Written as `#[<runtime_path>::vertex]`, `::fragment` or `::compute`,
    /// according to the entry point's shader stage.
    Stage(ShaderStage),
    /// Written as `#[<runtime_path>::workgroup_size(x, y, z)]`; emitted for compute
    /// entry points.
    WorkGroupSize([u32; 3]),
}
40
/// How the Rust text produced for an expression relates to that expression's value.
///
/// `Writer::plain_form_indirection` reports which of the two the "plain form" of an
/// expression is, and `Writer::write_expr_with_indirection` adds `&` or `*` when the
/// caller asked for the other one.
#[derive(Clone, Copy, Debug)]
enum Indirection {
    /// The text names a place (for example a local variable), not a pointer to it.
    /// When an [`Indirection::Ordinary`] form is requested, the text is wrapped as `(&...)`.
    Place,

    /// The text is the expression's value as it stands.
    /// When an [`Indirection::Place`] form is requested, the text is wrapped as `(*...)`.
    Ordinary,
}
67
/// Writes a Naga [`Module`] as Rust source text.
#[allow(missing_debug_implementations, reason = "TODO")]
pub struct Writer {
    /// Output options: flags, the runtime library path and the optional global struct name.
    config: Config,
    /// Names assigned to module items; cleared and refilled by `reset` for each module.
    names: naga::FastHashMap<NameKey, String>,
    /// Name generator used to fill `names` and to name user-named expressions.
    namer: proc::Namer,
    /// Expressions of the function currently being written that have been bound with `let`,
    /// mapped to the binding's name. Cleared after each function.
    named_expressions: naga::FastIndexMap<Handle<Expression>, String>,
}
80
/// Where an expression being written lives: in the module's global expression arena,
/// or inside a function.
enum ExpressionCtx<'a> {
    /// An expression outside any function (used for constant and global initializers).
    Global {
        /// The module being written; supplies the type arena for type resolution.
        module: &'a Module,
        /// Validation info, indexed by expression handle to resolve types.
        module_info: &'a ModuleInfo,
        /// The arena the expression handles refer to.
        expressions: &'a naga::Arena<Expression>,
    },
    /// An expression inside a function body.
    Function {
        /// Context of the enclosing function; supplies its expression arena and type info.
        func_ctx: &'a back::FunctionCtx<'a>,
    },
}
92
93impl<'a> ExpressionCtx<'a> {
94 #[track_caller]
95 fn expect_func_ctx(&self) -> &'a back::FunctionCtx<'a> {
96 match self {
97 ExpressionCtx::Function { func_ctx, .. } => func_ctx,
98 ExpressionCtx::Global { .. } => {
99 unreachable!("attempting to access the function context outside of a function")
100 }
101 }
102 }
103
104 fn expressions(&self) -> &'a naga::Arena<Expression> {
105 match self {
106 ExpressionCtx::Global { expressions, .. } => expressions,
107 ExpressionCtx::Function { func_ctx, .. } => func_ctx.expressions,
108 }
109 }
110
111 fn resolve_type(
112 &self,
113 handle: Handle<Expression>,
114 types: &'a naga::UniqueArena<naga::Type>,
115 ) -> &'a TypeInner {
116 match self {
117 ExpressionCtx::Global {
118 module_info,
119 module,
120 ..
121 } => module_info[handle].inner_with(&module.types),
122 ExpressionCtx::Function { func_ctx, .. } => func_ctx.resolve_type(handle, types),
123 }
124 }
125}
126
127impl Writer {
128 #[must_use]
130 pub fn new(config: Config) -> Self {
131 Writer {
132 config,
133 names: naga::FastHashMap::default(),
134 namer: proc::Namer::default(),
135 named_expressions: naga::FastIndexMap::default(),
136 }
137 }
138
139 fn reset(&mut self, module: &Module) {
140 let Self {
141 config,
142 names,
143 namer,
144 named_expressions,
145 } = self;
146 names.clear();
147 namer.reset(module, KEYWORDS_2024, &[], &[], &[], &mut self.names);
148 if let Some(g) = &config.global_struct {
149 namer.call(g);
152 }
153 named_expressions.clear();
154 }
155
156 #[expect(clippy::missing_panics_doc, reason = "TODO: unfinished")]
165 pub fn write(
166 &mut self,
167 out: &mut dyn Write,
168 module: &Module,
169 info: &ModuleInfo,
170 ) -> BackendResult {
171 if !module.overrides.is_empty() {
172 return Err(Error::Unimplemented("pipeline constants".into()));
173 }
174
175 self.reset(module);
176
177 for (handle, ty) in module.types.iter() {
179 if let TypeInner::Struct { ref members, .. } = ty.inner {
180 {
181 self.write_struct_definition(out, module, handle, members)?;
182 writeln!(out)?;
183 }
184 }
185 }
186
187 let mut constants = module
189 .constants
190 .iter()
191 .filter(|&(_, c)| c.name.is_some())
192 .peekable();
193 while let Some((handle, _)) = constants.next() {
194 self.write_global_constant(out, module, info, handle)?;
195 if constants.peek().is_none() {
197 writeln!(out)?;
198 }
199 }
200
201 if let Some(global_struct) = self.config.global_struct.clone() {
203 writeln!(out, "struct {global_struct} {{")?;
204 for (handle, global) in module.global_variables.iter() {
205 self.write_global_variable_as_struct_field(out, module, global, handle)?;
206 }
207 writeln!(
210 out,
211 "}}\n\
212 impl Default for {global_struct} {{\n\
213 {INDENT}fn default() -> Self {{ Self {{"
214 )?;
215 for (handle, global) in module.global_variables.iter() {
216 self.write_global_variable_as_field_initializer(out, module, info, global, handle)?;
217 }
218 writeln!(out, "{INDENT}}}}}\n}}")?;
219
220 writeln!(out, "impl {global_struct} {{")?;
222 } else if let Some((_, example)) = module.global_variables.iter().next() {
223 return Err(Error::GlobalVariablesNotEnabled {
224 example: example.name.clone().unwrap_or_default(),
225 });
226 }
227
228 for (handle, function) in module.functions.iter() {
230 let fun_info = &info[handle];
231
232 let func_ctx = back::FunctionCtx {
233 ty: back::FunctionType::Function(handle),
234 info: fun_info,
235 expressions: &function.expressions,
236 named_expressions: &function.named_expressions,
237 expr_kind_tracker: ExpressionKindTracker::from_arena(&function.expressions),
238 };
239
240 self.write_function(out, module, function, &func_ctx)?;
242
243 writeln!(out)?;
244 }
245
246 for (index, ep) in module.entry_points.iter().enumerate() {
248 let attributes = match ep.stage {
249 ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
250 ShaderStage::Compute => vec![
251 Attribute::Stage(ShaderStage::Compute),
252 Attribute::WorkGroupSize(ep.workgroup_size),
253 ],
254 };
255
256 self.write_attributes(out, back::Level(0), &attributes)?;
257
258 let func_ctx = back::FunctionCtx {
259 ty: back::FunctionType::EntryPoint(index.try_into().unwrap()),
260 info: info.get_entry_point(index),
261 expressions: &ep.function.expressions,
262 named_expressions: &ep.function.named_expressions,
263 expr_kind_tracker: ExpressionKindTracker::from_arena(&ep.function.expressions),
264 };
265 self.write_function(out, module, &ep.function, &func_ctx)?;
266
267 if index < module.entry_points.len() - 1 {
268 writeln!(out)?;
269 }
270 }
271
272 if self.config.use_global_struct() {
273 writeln!(out, "}}")?;
275 }
276
277 Ok(())
278 }
279
280 fn write_function(
283 &mut self,
284 out: &mut dyn Write,
285 module: &Module,
286 func: &naga::Function,
287 func_ctx: &back::FunctionCtx<'_>,
288 ) -> BackendResult {
289 self.write_attributes(out, back::Level(0), &[Attribute::AllowFunctionBody])?;
290
291 let func_name = match func_ctx.ty {
293 back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
294 back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
295 };
296 let visibility = self.visibility();
297 write!(out, "{visibility}fn {func_name}(")?;
298
299 if self.config.use_global_struct() {
300 write!(out, "&self, ")?;
302 } else if func_ctx.info.global_variable_count() > 0 {
303 unreachable!(
304 "function has globals but globals are not enabled; \
305 should have been rejected earlier"
306 );
307 }
308
309 for (index, arg) in func.arguments.iter().enumerate() {
311 let argument_name = &self.names[&func_ctx.argument_key(index.try_into().unwrap())];
317
318 write!(out, "{argument_name}: ")?;
319 self.write_type(out, module, arg.ty)?;
321 if index < func.arguments.len() - 1 {
322 write!(out, ", ")?;
324 }
325 }
326
327 write!(out, ")")?;
328
329 if let Some(ref result) = func.result {
331 write!(out, " -> ")?;
332 self.write_type(out, module, result.ty)?;
336 }
337
338 write!(out, " {{")?;
339 writeln!(out)?;
340
341 for (handle, local) in func.local_variables.iter() {
343 write!(out, "{INDENT}")?;
345
346 write!(out, "let mut {}: ", self.names[&func_ctx.name_key(handle)])?;
349
350 self.write_type(out, module, local.ty)?;
352
353 if let Some(init) = local.init {
355 write!(out, " = ")?;
356 self.write_expr(
357 out,
358 module,
359 init,
360 &ExpressionCtx::Function {
361 func_ctx,
362 },
364 )?;
365 }
366
367 writeln!(out, ";")?;
369 }
370
371 if !func.local_variables.is_empty() {
372 writeln!(out)?;
373 }
374
375 for sta in func.body.iter() {
377 self.write_stmt(out, module, sta, func_ctx, back::Level(1))?;
379 }
380
381 writeln!(out, "}}")?;
382
383 self.named_expressions.clear();
384
385 Ok(())
386 }
387
388 fn write_attributes(
390 &self,
391 out: &mut dyn Write,
392 level: back::Level,
393 attributes: &[Attribute],
394 ) -> BackendResult {
395 let runtime_path = &self.config.runtime_path;
396 for attribute in attributes {
397 write!(out, "{level}#[")?;
398 match *attribute {
399 Attribute::AllowFunctionBody => {
400 write!(
401 out,
402 "allow(unused_parens, clippy::all, clippy::pedantic, clippy::nursery)"
406 )?;
407 }
408 Attribute::Stage(shader_stage) => {
409 let stage_str = match shader_stage {
410 ShaderStage::Vertex => "vertex",
411 ShaderStage::Fragment => "fragment",
412 ShaderStage::Compute => "compute",
413 };
414 write!(out, "{runtime_path}::{stage_str}")?;
415 }
416 Attribute::WorkGroupSize(size) => {
417 write!(
418 out,
419 "{runtime_path}::workgroup_size({}, {}, {})",
420 size[0], size[1], size[2]
421 )?;
422 }
423 }
424 writeln!(out, "]")?;
425 }
426 Ok(())
427 }
428
429 fn write_struct_definition(
436 &self,
437 out: &mut dyn Write,
438 module: &Module,
439 handle: Handle<naga::Type>,
440 members: &[naga::StructMember],
441 ) -> BackendResult {
442 let visibility = self.visibility();
444 write!(
445 out,
446 "#[repr(C)]\n\
447 {visibility}struct {}",
448 self.names[&NameKey::Type(handle)]
449 )?;
450 write!(out, " {{")?;
451 writeln!(out)?;
452 for (index, member) in members.iter().enumerate() {
453 write!(out, "{INDENT}")?;
455 let member_name =
460 &self.names[&NameKey::StructMember(handle, index.try_into().unwrap())];
461 write!(out, "{visibility}{member_name}: ")?;
462 self.write_type(out, module, member.ty)?;
463 write!(out, ",")?;
464 writeln!(out)?;
465 }
466
467 writeln!(out, "}}")?;
468
469 Ok(())
470 }
471
472 fn write_stmt(
477 &mut self,
478 out: &mut dyn Write,
479 module: &Module,
480 stmt: &naga::Statement,
481 func_ctx: &back::FunctionCtx<'_>,
482 level: back::Level,
483 ) -> BackendResult {
484 use naga::{Expression, Statement};
485
486 let runtime_path = &self.config.runtime_path;
487 let expr_ctx = &ExpressionCtx::Function {
488 func_ctx,
489 };
491
492 match *stmt {
493 Statement::Emit(ref range) => {
494 for handle in range.clone() {
495 let expr_info = &func_ctx.info[handle];
496 let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
497 Some(self.namer.call(name))
502 } else {
503 let expr = &func_ctx.expressions[handle];
504 let min_ref_count = expr.bake_ref_count();
505 let required_baking_expr = matches!(
507 *expr,
508 Expression::ImageLoad { .. }
509 | Expression::ImageQuery { .. }
510 | Expression::ImageSample { .. }
511 );
512 if min_ref_count <= expr_info.ref_count || required_baking_expr {
513 Some(Gensym(handle).to_string())
514 } else {
515 None
516 }
517 };
518
519 if let Some(name) = expr_name {
520 write!(out, "{level}")?;
521 self.start_named_expr(out, module, handle, func_ctx, &name)?;
522 self.write_expr(out, module, handle, expr_ctx)?;
523 self.named_expressions.insert(handle, name);
524 writeln!(out, ";")?;
525 }
526 }
527 }
528 Statement::If {
529 condition,
530 ref accept,
531 ref reject,
532 } => {
533 let l2 = level.next();
534
535 write!(out, "{level}if ")?;
536 self.write_expr(out, module, condition, expr_ctx)?;
537 writeln!(out, " {{")?;
538 for s in accept {
539 self.write_stmt(out, module, s, func_ctx, l2)?;
540 }
541 if !reject.is_empty() {
542 writeln!(out, "{level}}} else {{")?;
543 for s in reject {
544 self.write_stmt(out, module, s, func_ctx, l2)?;
545 }
546 }
547 writeln!(out, "{level}}}")?
548 }
549 Statement::Return { value } => {
550 write!(out, "{level}return")?;
551 if let Some(return_value) = value {
552 write!(out, " ")?;
553 self.write_expr(out, module, return_value, expr_ctx)?;
554 }
555 writeln!(out, ";")?;
556 }
557 Statement::Kill => write!(out, "{level}{runtime_path}::discard();")?,
558 Statement::Store { pointer, value } => {
559 let is_atomic_pointer = func_ctx
560 .resolve_type(pointer, &module.types)
561 .is_atomic_pointer(&module.types);
562
563 if is_atomic_pointer {
564 return Err(Error::Unimplemented("atomic operations".into()));
565 }
566
567 write!(out, "{level}")?;
568 self.write_expr_with_indirection(
572 out,
573 module,
574 pointer,
575 expr_ctx,
576 Indirection::Place,
577 )?;
578 write!(out, " = ")?;
579 self.write_expr(out, module, value, expr_ctx)?;
580
581 writeln!(out, ";")?
582 }
583 Statement::Call {
584 function,
585 ref arguments,
586 result,
587 } => {
588 write!(out, "{level}")?;
589
590 if let Some(expr) = result {
592 let name = Gensym(expr).to_string();
593 self.start_named_expr(out, module, expr, func_ctx, &name)?;
594 self.named_expressions.insert(expr, name);
595 }
596
597 if self.config.use_global_struct() {
599 write!(out, "self.")?;
600 }
601
602 let func_name = &self.names[&NameKey::Function(function)];
603 write!(out, "{func_name}(")?;
604 for (index, &argument) in arguments.iter().enumerate() {
605 if index != 0 {
606 write!(out, ", ")?;
607 }
608 self.write_expr(out, module, argument, expr_ctx)?;
609 }
610 writeln!(out, ");")?
611 }
612 Statement::Atomic { .. } => {
613 return Err(Error::Unimplemented("atomic operations".into()));
614 }
615 Statement::ImageAtomic { .. } => {
616 return Err(Error::TexturesAreUnsupported {
617 found: "textureAtomic",
618 });
619 }
620 Statement::WorkGroupUniformLoad { .. } => {
621 todo!("Statement::WorkGroupUniformLoad");
622 }
623 Statement::ImageStore { .. } => {
624 return Err(Error::TexturesAreUnsupported {
625 found: "textureStore",
626 });
627 }
628 Statement::Block(ref block) => {
629 write!(out, "{level}")?;
630 writeln!(out, "{{")?;
631 for s in block.iter() {
632 self.write_stmt(out, module, s, func_ctx, level.next())?;
633 }
634 writeln!(out, "{level}}}")?;
635 }
636 Statement::Switch {
637 selector,
638 ref cases,
639 } => {
640 write!(out, "{level}")?;
642 write!(out, "match ")?;
643 self.write_expr(out, module, selector, expr_ctx)?;
644 writeln!(out, " {{")?;
645
646 let l2 = level.next();
648 let mut new_match_arm = true;
649 for case in cases {
650 if case.fall_through && !case.body.is_empty() {
651 return Err(Error::Unimplemented(
653 "fall-through switch case block".into(),
654 ));
655 }
656
657 if new_match_arm {
658 write!(out, "{l2}")?;
660 } else {
661 write!(out, " | ")?;
663 }
664 match case.value {
666 naga::SwitchValue::I32(value) => {
667 write!(out, "{value}i32")?;
668 }
669 naga::SwitchValue::U32(value) => {
670 write!(out, "{value}u32")?;
671 }
672 naga::SwitchValue::Default => {
673 write!(out, "_")?;
674 }
675 }
676
677 new_match_arm = !case.fall_through;
678
679 if new_match_arm {
682 writeln!(out, " => {{")?;
683 for sta in case.body.iter() {
684 self.write_stmt(out, module, sta, func_ctx, l2.next())?;
685 }
686 writeln!(out, "{l2}}}")?;
687 }
688 }
689
690 writeln!(out, "{level}}}")?;
691 }
692 Statement::Loop {
693 ref body,
694 ref continuing,
695 break_if,
696 } => {
697 write!(out, "{level}")?;
698 writeln!(out, "loop {{")?;
699
700 let l2 = level.next();
701 for sta in body.iter() {
702 self.write_stmt(out, module, sta, func_ctx, l2)?;
703 }
704
705 if !continuing.is_empty() {
706 return Err(Error::Unimplemented("continuing".into()));
707 }
708 if break_if.is_some() {
709 return Err(Error::Unimplemented("break_if".into()));
710 }
711
712 writeln!(out, "{level}}}")?;
713 }
714 Statement::Break => writeln!(out, "{level}break;")?,
715 Statement::Continue => writeln!(out, "{level}continue;")?,
716 Statement::Barrier(_) => {
717 return Err(Error::Unimplemented("barriers".into()));
718 }
719 Statement::RayQuery { .. } => {
720 return Err(Error::Unimplemented("raytracing".into()));
721 }
722 Statement::SubgroupBallot { .. }
723 | Statement::SubgroupCollectiveOperation { .. }
724 | Statement::SubgroupGather { .. } => {
725 return Err(Error::Unimplemented("workgroup operations".into()));
726 }
727 }
728
729 Ok(())
730 }
731
732 fn plain_form_indirection(
743 &self,
744 expr: Handle<Expression>,
745 module: &Module,
746 expr_ctx: &ExpressionCtx<'_>,
747 ) -> Indirection {
748 use naga::Expression as Ex;
749
750 if self.named_expressions.contains_key(&expr) {
754 return Indirection::Ordinary;
755 }
756
757 match expr_ctx.expressions()[expr] {
758 Ex::LocalVariable(_) => Indirection::Place,
762
763 Ex::GlobalVariable(handle) => {
769 let global = &module.global_variables[handle];
770 match global.space {
771 naga::AddressSpace::Handle => Indirection::Ordinary,
772 _ => Indirection::Place,
773 }
774 }
775
776 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
778 let base_ty = expr_ctx.resolve_type(base, &module.types);
779 match *base_ty {
780 TypeInner::Pointer { .. } | TypeInner::ValuePointer { .. } => {
781 Indirection::Place
782 }
783 _ => Indirection::Ordinary,
784 }
785 }
786 _ => Indirection::Ordinary,
787 }
788 }
789
790 fn start_named_expr(
791 &self,
792 out: &mut dyn Write,
793 module: &Module,
794 handle: Handle<Expression>,
795 func_ctx: &back::FunctionCtx<'_>,
796 name: &str,
797 ) -> BackendResult {
798 write!(out, "let {name}")?;
800 if self.config.flags.contains(WriterFlags::EXPLICIT_TYPES) {
801 write!(out, ": ")?;
802 let ty = &func_ctx.info[handle].ty;
803 match *ty {
805 proc::TypeResolution::Handle(ty_handle) => {
806 self.write_type(out, module, ty_handle)?;
807 }
808 proc::TypeResolution::Value(ref inner) => {
809 self.write_type_inner(out, module, inner)?;
810 }
811 }
812 }
813
814 write!(out, " = ")?;
815 Ok(())
816 }
817
818 fn write_expr(
822 &self,
823 out: &mut dyn Write,
824 module: &Module,
825 expr: Handle<Expression>,
826 expr_ctx: &ExpressionCtx<'_>,
827 ) -> BackendResult {
828 self.write_expr_with_indirection(out, module, expr, expr_ctx, Indirection::Ordinary)
829 }
830
831 fn write_expr_with_indirection(
840 &self,
841 out: &mut dyn Write,
842 module: &Module,
843 expr: Handle<Expression>,
844 expr_ctx: &ExpressionCtx<'_>,
845 requested: Indirection,
846 ) -> BackendResult {
847 let plain = self.plain_form_indirection(expr, module, expr_ctx);
850 match (requested, plain) {
851 (Indirection::Ordinary, Indirection::Place) => {
855 write!(out, "(&")?;
856 self.write_expr_plain_form(out, module, expr, expr_ctx, plain)?;
857 write!(out, ")")?;
858 }
859
860 (Indirection::Place, Indirection::Ordinary) => {
863 write!(out, "(*")?;
864 self.write_expr_plain_form(out, module, expr, expr_ctx, plain)?;
865 write!(out, ")")?;
866 }
867 (Indirection::Place, Indirection::Place)
869 | (Indirection::Ordinary, Indirection::Ordinary) => {
870 self.write_expr_plain_form(out, module, expr, expr_ctx, plain)?
871 }
872 }
873
874 Ok(())
875 }
876
877 fn write_expr_plain_form(
890 &self,
891 out: &mut dyn Write,
892 module: &Module,
893 expr: Handle<Expression>,
894 expr_ctx: &ExpressionCtx<'_>,
895 indirection: Indirection,
896 ) -> BackendResult {
897 if let Some(name) = self.named_expressions.get(&expr) {
898 write!(out, "{name}")?;
899 return Ok(());
900 }
901
902 let expression = &expr_ctx.expressions()[expr];
903 let runtime_path = &self.config.runtime_path;
904
905 match *expression {
906 Expression::Literal(literal) => match literal {
907 naga::Literal::F32(value) => write!(out, "{value}f32")?,
908 naga::Literal::U32(value) => write!(out, "{value}u32")?,
909 naga::Literal::I32(value) => {
910 write!(out, "{value}i32")?;
911 }
912 naga::Literal::Bool(value) => write!(out, "{value}")?,
913 naga::Literal::F64(value) => write!(out, "{value}f64")?,
914 naga::Literal::I64(value) => {
915 write!(out, "{value}i64")?;
916 }
917 naga::Literal::U64(value) => write!(out, "{value}u64")?,
918 naga::Literal::AbstractInt(_) | naga::Literal::AbstractFloat(_) => {
919 unreachable!("abstract types should not appear in IR presented to backends");
920 }
921 },
922 Expression::Constant(handle) => {
923 let constant = &module.constants[handle];
924 if constant.name.is_some() {
925 write!(out, "{}", self.names[&NameKey::Constant(handle)])?;
926 } else {
927 self.write_expr(out, module, constant.init, expr_ctx)?;
928 }
929 }
930 Expression::ZeroValue(ty) => {
931 write!(out, "{runtime_path}::zero::<")?;
932 self.write_type(out, module, ty)?;
933 write!(out, ">()")?;
934 }
935 Expression::Compose { ty, ref components } => {
936 self.write_constructor_expression(out, module, ty, components, expr_ctx)?;
937 }
938 Expression::Splat { size, value } => {
939 let size = conv::vector_size_str(size);
940 write!(out, "{runtime_path}::Vec{size}::splat(")?;
942 self.write_expr(out, module, value, expr_ctx)?;
943 write!(out, ")")?;
944 }
945 Expression::Override(_) => unreachable!(),
946 Expression::FunctionArgument(pos) => {
947 let name_key = expr_ctx.expect_func_ctx().argument_key(pos);
948 let name = &self.names[&name_key];
949 write!(out, "{name}")?;
950 }
951 Expression::Binary { op, left, right } => {
952 let inputs_are_scalar = matches!(
953 *expr_ctx.resolve_type(left, &module.types),
954 TypeInner::Scalar(_)
955 ) && matches!(
956 *expr_ctx.resolve_type(right, &module.types),
957 TypeInner::Scalar(_)
958 );
959 match (inputs_are_scalar, BinOpClassified::from(op)) {
960 (true, BinOpClassified::ScalarBool(_))
961 | (_, BinOpClassified::Vectorizable(_)) => {
962 write!(out, "(")?;
963 self.write_expr(out, module, left, expr_ctx)?;
964 write!(out, " {} ", back::binary_operation_str(op))?;
967 self.write_expr(out, module, right, expr_ctx)?;
968 write!(out, ")")?;
969 }
970 (_, BinOpClassified::ScalarBool(bop)) => {
971 self.write_expr(out, module, left, expr_ctx)?;
972 write!(out, ".{}(", bop.to_vector_method())?;
973 self.write_expr(out, module, right, expr_ctx)?;
974 write!(out, ")")?;
975 }
976 }
977 }
978 Expression::Access { base, index } => {
979 self.write_expr_with_indirection(out, module, base, expr_ctx, indirection)?;
980 write!(out, "[")?;
981 self.write_expr(out, module, index, expr_ctx)?;
982 write!(out, " as usize]")?
983 }
984 Expression::AccessIndex { base, index } => {
985 let base_ty_res = &expr_ctx.expect_func_ctx().info[base].ty;
986 let mut resolved = base_ty_res.inner_with(&module.types);
987
988 self.write_expr_with_indirection(out, module, base, expr_ctx, indirection)?;
989
990 let base_ty_handle = match *resolved {
991 TypeInner::Pointer { base, space: _ } => {
992 resolved = &module.types[base].inner;
993 Some(base)
994 }
995 _ => base_ty_res.handle(),
996 };
997
998 match *resolved {
999 TypeInner::Vector { .. } => {
1000 write!(out, ".{}", back::COMPONENTS[index as usize])?
1002 }
1003 TypeInner::Matrix { .. }
1004 | TypeInner::Array { .. }
1005 | TypeInner::BindingArray { .. }
1006 | TypeInner::ValuePointer { .. } => write!(out, "[{index} as usize]")?,
1007 TypeInner::Struct { .. } => {
1008 let ty = base_ty_handle.unwrap();
1011
1012 write!(out, ".{}", &self.names[&NameKey::StructMember(ty, index)])?
1013 }
1014 ref other => unreachable!("cannot index into a {other:?}"),
1015 }
1016 }
1017 Expression::ImageSample { .. } => {
1018 return Err(Error::TexturesAreUnsupported {
1019 found: "textureSample",
1020 });
1021 }
1022 Expression::ImageQuery { .. } => {
1023 return Err(Error::TexturesAreUnsupported {
1024 found: "texture queries",
1025 });
1026 }
1027 Expression::ImageLoad { .. } => {
1028 return Err(Error::TexturesAreUnsupported {
1029 found: "textureLoad",
1030 });
1031 }
1032 Expression::GlobalVariable(handle) => {
1033 let name = &self.names[&NameKey::GlobalVariable(handle)];
1034 write!(out, "self.{name}")?;
1035 }
1036
1037 Expression::As {
1038 expr,
1039 kind: to_kind,
1040 convert: to_width,
1041 } => {
1042 use naga::TypeInner as Ti;
1043
1044 let input_type = expr_ctx.resolve_type(expr, &module.types);
1045
1046 self.write_expr(out, module, expr, expr_ctx)?;
1047 match (input_type, to_kind, to_width) {
1048 (&Ti::Vector { size: _, scalar: _ }, to_kind, Some(to_width)) => {
1049 write!(
1051 out,
1052 ".cast_elem_as_{elem_ty}()",
1053 elem_ty = unwrap_to_rust(Scalar {
1054 kind: to_kind,
1055 width: to_width
1056 }),
1057 )?;
1058 }
1059 (&Ti::Scalar(_), to_kind, Some(to_width)) => {
1060 write!(
1063 out,
1064 " as {}",
1065 unwrap_to_rust(Scalar {
1066 kind: to_kind,
1067 width: to_width,
1068 })
1069 )?;
1070 }
1071 _ => {
1072 write!(
1074 out,
1075 " as _/* cast {input_type:?} to kind {to_kind:?} width {to_width:?} */"
1076 )?;
1077 }
1078 }
1079 }
1080 Expression::Load { pointer } => {
1081 self.write_expr_with_indirection(
1082 out,
1083 module,
1084 pointer,
1085 expr_ctx,
1086 Indirection::Place,
1087 )?;
1088 }
1089 Expression::LocalVariable(handle) => write!(
1090 out,
1091 "{}",
1092 self.names[&expr_ctx.expect_func_ctx().name_key(handle)]
1093 )?,
1094 Expression::ArrayLength(expr) => {
1095 self.write_expr(out, module, expr, expr_ctx)?;
1096 write!(out, ".len()")?;
1097 }
1098
1099 Expression::Math {
1100 fun,
1101 arg,
1102 arg1,
1103 arg2,
1104 arg3,
1105 } => {
1106 self.write_expr(out, module, arg, expr_ctx)?;
1107 write!(
1108 out,
1109 ".{method}(",
1110 method = conv::math_function_to_method(fun)
1111 )?;
1112 for arg in [arg1, arg2, arg3].into_iter().flatten() {
1113 self.write_expr(out, module, arg, expr_ctx)?;
1114 write!(out, ", ")?;
1115 }
1116 write!(out, ")")?
1117 }
1118
1119 Expression::Swizzle {
1120 size,
1121 vector,
1122 pattern,
1123 } => {
1124 self.write_expr(out, module, vector, expr_ctx)?;
1125 write!(out, ".")?;
1126 for &sc in pattern[..size as usize].iter() {
1127 out.write_char(back::COMPONENTS[sc as usize])?;
1128 }
1129 write!(out, "()")?;
1130 }
1131 Expression::Unary { op, expr } => {
1132 let unary = match op {
1133 naga::UnaryOperator::Negate => "-",
1134 naga::UnaryOperator::LogicalNot => "!",
1135 naga::UnaryOperator::BitwiseNot => "!",
1136 };
1137
1138 write!(out, "({unary}")?;
1142 self.write_expr(out, module, expr, expr_ctx)?;
1143
1144 write!(out, ")")?
1145 }
1146
1147 Expression::Select {
1148 condition,
1149 accept,
1150 reject,
1151 } => {
1152 let suffix = match *expr_ctx.resolve_type(condition, &module.types) {
1153 TypeInner::Scalar(Scalar::BOOL) => "",
1154 TypeInner::Vector {
1155 size,
1156 scalar: Scalar::BOOL,
1157 } => conv::vector_size_str(size),
1158 _ => unreachable!("validation should have rejected this"),
1159 };
1160 write!(out, "{runtime_path}::select{suffix}(")?;
1161 self.write_expr(out, module, reject, expr_ctx)?;
1162 write!(out, ", ")?;
1163 self.write_expr(out, module, accept, expr_ctx)?;
1164 write!(out, ", ")?;
1165 self.write_expr(out, module, condition, expr_ctx)?;
1166 write!(out, ")")?
1167 }
1168 Expression::Derivative { .. } => {
1169 return Err(Error::Unimplemented("derivatives".into()));
1170
1171 }
1187 Expression::Relational { fun, argument } => {
1188 use naga::RelationalFunction as Rf;
1189
1190 let fun_name = match fun {
1191 Rf::All => "all",
1192 Rf::Any => "any",
1193 Rf::IsNan => "is_nan",
1194 Rf::IsInf => "is_inf",
1195 };
1196 write!(out, "{runtime_path}::{fun_name}(")?;
1197 self.write_expr(out, module, argument, expr_ctx)?;
1198 write!(out, ")")?
1199 }
1200 Expression::RayQueryGetIntersection { .. } => unreachable!(),
1202 Expression::CallResult(_)
1204 | Expression::AtomicResult { .. }
1205 | Expression::RayQueryProceedResult
1206 | Expression::SubgroupBallotResult
1207 | Expression::SubgroupOperationResult { .. }
1208 | Expression::WorkGroupUniformLoadResult { .. } => {}
1209 }
1210
1211 Ok(())
1212 }
1213
1214 fn write_constructor_expression(
1219 &self,
1220 out: &mut dyn Write,
1221 module: &Module,
1222 ty: Handle<naga::Type>,
1223 components: &[Handle<Expression>],
1224 expr_ctx: &ExpressionCtx<'_>,
1225 ) -> BackendResult {
1226 use naga::VectorSize::{Bi, Quad, Tri};
1227
1228 let ctor_name = match module.types[ty].inner {
1229 TypeInner::Vector { size, scalar: _ } => {
1230 let arg_sizes: ArrayVec<u8, 4> = components
1234 .iter()
1235 .map(|&component_expr| match *expr_ctx.resolve_type(component_expr, &module.types) {
1236 TypeInner::Scalar(_) => 1,
1237 TypeInner::Vector { size, .. } => size as u8,
1238 ref t => unreachable!(
1239 "vector constructor argument should be a scalar or vector, not {t:?}"
1240 ),
1241 })
1242 .collect();
1243
1244 match (size, &*arg_sizes) {
1245 (Bi, [1, 1]) => "new",
1246 (Bi, [2]) => "from",
1247 (Tri, [1, 1, 1]) => "new",
1248 (Tri, [1, 2]) => "new_12",
1249 (Tri, [2, 1]) => "new_21",
1250 (Quad, [1, 1, 1, 1]) => "new",
1251 (Quad, [1, 1, 2]) => "new_112",
1252 (Quad, [1, 2, 1]) => "new_121",
1253 (Quad, [2, 1, 1]) => "new_211",
1254 (Quad, [2, 2]) => "new_22",
1255 (Quad, [1, 3]) => "new_13",
1256 (Quad, [3, 1]) => "new_31",
1257 (Quad, [4]) => "from",
1258 _ => unreachable!("vector constructor given too many components {arg_sizes:?}"),
1259 }
1260 }
1261
1262 TypeInner::Array {
1263 base: _,
1264 size,
1265 stride: _,
1266 } => {
1267 assert!(matches!(size, naga::ArraySize::Constant(_)));
1268
1269 write!(out, "[")?;
1271 for (index, component) in components.iter().enumerate() {
1272 if index > 0 {
1273 write!(out, ", ")?;
1274 }
1275 self.write_expr(out, module, *component, expr_ctx)?;
1276 }
1277 write!(out, "]")?;
1278
1279 return Ok(());
1280 }
1281
1282 _ => "new",
1285 };
1286
1287 write!(out, "<")?;
1288 self.write_type(out, module, ty)?;
1289 write!(out, ">::{ctor_name}(")?;
1290 for (index, component) in components.iter().enumerate() {
1291 if index > 0 {
1292 write!(out, ", ")?;
1293 }
1294 self.write_expr(out, module, *component, expr_ctx)?;
1295 }
1296 write!(out, ")")?;
1297
1298 Ok(())
1299 }
1300
1301 pub(super) fn write_type(
1302 &self,
1303 out: &mut dyn Write,
1304 module: &Module,
1305 handle: Handle<naga::Type>,
1306 ) -> BackendResult {
1307 let ty = &module.types[handle];
1308 match ty.inner {
1309 TypeInner::Struct { .. } => {
1310 out.write_str(self.names[&NameKey::Type(handle)].as_str())?
1311 }
1312 ref other => self.write_type_inner(out, module, other)?,
1313 }
1314
1315 Ok(())
1316 }
1317
1318 fn write_type_inner(
1319 &self,
1320 out: &mut dyn Write,
1321 module: &Module,
1322 inner: &TypeInner,
1323 ) -> BackendResult {
1324 let runtime_path = &self.config.runtime_path;
1325 match *inner {
1326 TypeInner::Vector { size, scalar } => write!(
1327 out,
1328 "{runtime_path}::Vec{}<{}>",
1329 conv::vector_size_str(size),
1330 unwrap_to_rust(scalar),
1331 )?,
1332 TypeInner::Sampler { comparison: false } => {
1333 write!(out, "{runtime_path}::Sampler")?;
1334 }
1335 TypeInner::Sampler { comparison: true } => {
1336 write!(out, "{runtime_path}::SamplerComparison")?;
1337 }
1338 TypeInner::Image { .. } => {
1339 write!(out, "{runtime_path}::Image")?;
1340 }
1341 TypeInner::Scalar(scalar) => {
1342 write!(out, "{}", unwrap_to_rust(scalar))?;
1343 }
1344 TypeInner::Atomic(scalar) => {
1345 write!(
1346 out,
1347 "::core::sync::atomic::{}",
1348 conv::atomic_type_name(scalar)?
1349 )?;
1350 }
1351 TypeInner::Array {
1352 base,
1353 size,
1354 stride: _,
1355 } => {
1356 write!(out, "[")?;
1357 match size {
1358 naga::ArraySize::Constant(len) => {
1359 self.write_type(out, module, base)?;
1360 write!(out, "; {len}")?;
1361 }
1362 naga::ArraySize::Pending(..) => {
1363 return Err(Error::Unimplemented("override array size".into()));
1364 }
1365 naga::ArraySize::Dynamic => {
1366 self.write_type(out, module, base)?;
1367 }
1368 }
1369 write!(out, "]")?;
1370 }
1371 TypeInner::BindingArray { .. } => {}
1372 TypeInner::Matrix { .. } => {
1373 return Err(Error::Unimplemented("matrices".into()));
1374 }
1375 TypeInner::Pointer { base, space: _ } => {
1376 if self.config.flags.contains(WriterFlags::RAW_POINTERS) {
1377 write!(out, "*mut ")?;
1378 } else {
1379 write!(out, "&mut ")?;
1380 }
1381 self.write_type(out, module, base)?;
1382 }
1383 TypeInner::ValuePointer {
1384 size: _,
1385 scalar: _,
1386 space: _,
1387 } => {
1388 if self.config.flags.contains(WriterFlags::RAW_POINTERS) {
1389 write!(out, "*mut ")?;
1390 } else {
1391 write!(out, "&mut ")?;
1392 }
1393 todo!()
1394 }
1395 TypeInner::Struct { .. } => {
1396 unreachable!("should only see a struct by name");
1397 }
1398 TypeInner::AccelerationStructure => {
1399 return Err(Error::Unimplemented("type AccelerationStructure".into()));
1400 }
1401 TypeInner::RayQuery => {
1402 return Err(Error::Unimplemented("type RayQuery".into()));
1403 }
1404 }
1405
1406 Ok(())
1407 }
1408
1409 fn write_global_variable_as_struct_field(
1411 &self,
1412 out: &mut dyn Write,
1413 module: &Module,
1414 global: &naga::GlobalVariable,
1415 handle: Handle<naga::GlobalVariable>,
1416 ) -> BackendResult {
1417 let &naga::GlobalVariable {
1419 name: _, space: _, binding: _, ty,
1423 init: _, } = global;
1425
1426 if let Some(naga::ResourceBinding { group, binding }) = global.binding {
1430 writeln!(out, "{INDENT}// group({group}) binding({binding})")?;
1431 }
1432
1433 write!(
1434 out,
1435 "{INDENT}{}: ",
1436 &self.names[&NameKey::GlobalVariable(handle)]
1437 )?;
1438 self.write_type(out, module, ty)?;
1439 writeln!(out, ",")?;
1440
1441 Ok(())
1442 }
1443 fn write_global_variable_as_field_initializer(
1444 &self,
1445 out: &mut dyn Write,
1446 module: &Module,
1447 info: &ModuleInfo,
1448 global: &naga::GlobalVariable,
1449 handle: Handle<naga::GlobalVariable>,
1450 ) -> BackendResult {
1451 write!(
1452 out,
1453 "{INDENT}{INDENT}{}: ",
1454 &self.names[&NameKey::GlobalVariable(handle)]
1455 )?;
1456
1457 if let Some(init) = global.init {
1458 self.write_expr(
1459 out,
1460 module,
1461 init,
1462 &ExpressionCtx::Global {
1463 expressions: &module.global_expressions,
1464 module,
1465 module_info: info,
1466 },
1467 )?;
1468 } else {
1469 write!(out, "Default::default()")?;
1471 }
1472
1473 writeln!(out, ",")?;
1475
1476 Ok(())
1477 }
1478
1479 fn write_global_constant(
1481 &self,
1482 out: &mut dyn Write,
1483 module: &Module,
1484 info: &ModuleInfo,
1485 handle: Handle<naga::Constant>,
1486 ) -> BackendResult {
1487 let name = &self.names[&NameKey::Constant(handle)];
1488 let visibility = self.visibility();
1489 let init = module.constants[handle].init;
1490
1491 write!(
1492 out,
1493 "#[allow(non_upper_case_globals)]\n{visibility}const {name}: "
1494 )?;
1495 self.write_type(out, module, module.constants[handle].ty)?;
1496 write!(out, " = ")?;
1497 self.write_expr(
1498 out,
1499 module,
1500 init,
1501 &ExpressionCtx::Global {
1502 expressions: &module.global_expressions,
1503 module,
1504 module_info: info,
1505 },
1506 )?;
1507 writeln!(out, ";")?;
1508
1509 Ok(())
1510 }
1511
1512 fn visibility(&self) -> &'static str {
1513 if self.config.flags.contains(WriterFlags::PUBLIC) {
1514 "pub "
1515 } else {
1516 ""
1517 }
1518 }
1519}