#![deny(missing_docs)]
use std::collections::HashMap;
use std::fmt::Write;
use t_ree::{
declaration::{Declaration, FunctionDefinition, Module, Newtype},
expression::{Binding, Block, Expression, ExpressionKind, Literal, Statement},
operator::{
ArithmeticOperator, BinaryOperator, BitwiseOperator, ComparisonOperator, LogicalOperator,
UnaryOperator,
},
types::{FloatWidth, IntWidth, Signedness, Type},
};
/// Translates a `t_ree` [`Module`] into WGSL source text.
///
/// The module is expected to have been through `t_ree::resolve::resolve_module`
/// first, because the emitter reads resolved expression types (for `let`
/// annotations and array literals).
///
/// Newtypes over non-empty tuples become WGSL `struct`s and all other newtypes
/// become `alias`es. Constants become `const` declarations and functions become
/// `fn` declarations. Extern and import declarations produce no output, and a
/// function named `main` is skipped.
///
/// # Errors
///
/// Returns a human-readable message when the module uses a construct with no
/// WGSL equivalent, for example `f64`, pointers, strings or `defer`.
pub fn compile(module: &Module) -> Result<String, String> {
    let mut emitter = Emitter::new();
    emitter.emit_module(module)?;
    Ok(emitter.output)
}
/// Accumulates WGSL text while walking a `t_ree` module.
struct Emitter {
    /// The WGSL source generated so far.
    output: String,
    /// Current nesting depth; `indent` writes one indentation unit per level.
    indent_level: usize,
    /// Inner type of every newtype declared in the module, keyed by newtype
    /// name. It is filled before any declaration is emitted, so lookups work
    /// regardless of declaration order.
    newtypes: HashMap<String, Type>,
}
impl Emitter {
/// Creates an emitter with no output yet, zero indentation and an empty
/// newtype table (`emit_module` fills the table).
fn new() -> Self {
    let output = String::new();
    let newtypes = HashMap::new();
    Self {
        newtypes,
        indent_level: 0,
        output,
    }
}
/// Appends `string` verbatim to the generated output.
fn push(&mut self, string: &str) {
    self.output += string;
}
/// Writes the indentation for the current nesting depth: one `" "` unit per
/// level of `indent_level`.
fn indent(&mut self) {
    let padding = " ".repeat(self.indent_level);
    self.push(&padding);
}
/// Emits every declaration of `module` in source order.
///
/// All newtypes are registered in a first pass, so a type can be referenced
/// before the declaration that introduces it.
fn emit_module(&mut self, module: &Module) -> Result<(), String> {
    // Pass 1: remember each newtype's inner type.
    for declaration in module {
        let Declaration::Type(newtype) = declaration else {
            continue;
        };
        self.newtypes
            .insert(newtype.name.clone(), newtype.inner_type.clone());
    }
    // Pass 2: emit the declarations themselves.
    for declaration in module {
        self.emit_declaration(declaration)?;
    }
    Ok(())
}
/// Emits a single top-level declaration.
///
/// Newtypes and functions go to their dedicated emitters, and constants become
/// `const NAME: TYPE = VALUE;`. Extern and import declarations produce no
/// output.
fn emit_declaration(&mut self, declaration: &Declaration) -> Result<(), String> {
    let constant = match declaration {
        Declaration::Type(newtype) => return self.emit_newtype(newtype),
        Declaration::Function(function) => return self.emit_function(function),
        Declaration::Extern(_) | Declaration::Import(_) => return Ok(()),
        Declaration::Constant(constant) => constant,
    };
    self.push("const ");
    self.push(&constant.name);
    self.push(": ");
    self.emit_type(&constant.constant_type)?;
    self.push(" = ");
    self.emit_expression(&constant.value)?;
    self.push(";\n");
    Ok(())
}
/// Emits the WGSL declaration for a newtype.
///
/// A newtype over a non-empty tuple becomes a `struct` with one member per
/// tuple element. Each member is named after its element's (named) type. Any
/// other newtype becomes `alias NAME = INNER;`.
///
/// # Errors
///
/// Fails if a struct element is not a named type (WGSL members need names),
/// or if any referenced type cannot be expressed in WGSL.
fn emit_newtype(&mut self, newtype: &Newtype) -> Result<(), String> {
    match &newtype.inner_type {
        Type::Tuple(fields) if !fields.is_empty() => {
            writeln!(self.output, "struct {} {{", newtype.name).unwrap();
            self.indent_level += 1;
            for field in fields {
                // Dropping an unnamed element would leave the struct out of
                // step with positional constructors (or empty, which WGSL
                // rejects), so refuse it instead.
                let Type::Named(field_name) = field else {
                    return Err(format!(
                        "WGSL struct `{}` needs named members, found element type {field:?}",
                        newtype.name
                    ));
                };
                self.indent();
                self.push(field_name);
                self.push(": ");
                // `emit_type` expands aliases but keeps struct newtypes by
                // name, so a member whose type is another struct is valid.
                self.emit_type(field)?;
                self.push(",\n");
            }
            self.indent_level -= 1;
            self.push("}\n\n");
        }
        _ => {
            write!(self.output, "alias {} = ", newtype.name).unwrap();
            self.emit_type(&newtype.inner_type)?;
            self.push(";\n");
        }
    }
    Ok(())
}
/// Emits the WGSL spelling of `ty`.
///
/// - Integers of every width map to `i32`/`u32`, and `f32` maps to itself.
/// - Arrays and vectors recurse into their element type.
/// - A named struct newtype (non-empty tuple) is emitted by name, an alias
///   newtype is expanded to its inner type, and an unknown name is emitted
///   verbatim.
///
/// # Errors
///
/// Fails for `f64`, pointers, the never type, and any other type not listed
/// above.
fn emit_type(&mut self, ty: &Type) -> Result<(), String> {
    match ty {
        Type::Bool => self.push("bool"),
        Type::Int(IntWidth::W32, Signedness::Signed) => self.push("i32"),
        Type::Int(IntWidth::W32, Signedness::Unsigned) => self.push("u32"),
        // NOTE(review): other integer widths are silently mapped to 32 bits,
        // while `f64` is rejected below. Confirm this asymmetry is intended.
        Type::Int(_, Signedness::Signed) => self.push("i32"),
        Type::Int(_, Signedness::Unsigned) => self.push("u32"),
        Type::Float(FloatWidth::W32) => self.push("f32"),
        Type::Float(FloatWidth::W64) => {
            return Err("WGSL does not support f64 (use f32 instead)".into());
        }
        Type::Array(element, length) => {
            self.push("array<");
            self.emit_type(element)?;
            write!(self.output, ", {length}>").unwrap();
        }
        Type::Vector(element, count) => {
            write!(self.output, "vec{count}<").unwrap();
            self.emit_type(element)?;
            self.push(">");
        }
        Type::Named(name) => {
            if let Some(inner) = self.newtypes.get(name).cloned() {
                if matches!(inner, Type::Tuple(ref fields) if !fields.is_empty()) {
                    self.push(name);
                } else {
                    self.emit_type(&inner)?;
                }
            } else {
                self.push(name);
            }
        }
        // WGSL has no `void` or bottom type. A function that returns nothing
        // just omits `-> T`, so emitting "void" only produced invalid WGSL.
        Type::Never => {
            return Err(
                "WGSL has no never/void type (functions without a result omit `-> T`)".into(),
            );
        }
        Type::Pointer(..) => {
            return Err("WGSL does not support pointers".into());
        }
        _ => {
            return Err(format!("unsupported WGSL type: {ty:?}"));
        }
    }
    Ok(())
}
fn emit_function(&mut self, function: &FunctionDefinition) -> Result<(), String> {
if function.name == "main" {
return Ok(());
}
self.push("fn ");
self.push(&function.name);
self.push("(");
for (index, parameter) in function.parameters.iter().enumerate() {
if index > 0 {
self.push(", ");
}
self.push(¶meter.name);
self.push(": ");
if let Some(ref ty) = parameter.parameter_type {
self.emit_type(ty)?;
}
}
self.push(")");
if !function.return_type.is_unit() {
self.push(" -> ");
self.emit_type(&function.return_type)?;
}
self.push(" {\n");
self.indent_level += 1;
self.emit_block(&function.body)?;
self.indent_level -= 1;
self.push("}\n\n");
Ok(())
}
/// Emits the statements of `block`, then `return <result>;` when the block has
/// a trailing result expression.
///
/// The first `label` statement hands itself and every later statement to
/// `emit_loop`. Statements before it are emitted one by one.
fn emit_block(&mut self, block: &Block) -> Result<(), String> {
    let statements = &block.statements;
    for (position, statement) in statements.iter().enumerate() {
        match statement {
            Statement::Label {
                name,
                parameters,
                initial_arguments,
            } => {
                // Everything after the label belongs to (or follows) its loop.
                self.emit_loop(
                    name,
                    parameters,
                    initial_arguments,
                    &statements[position + 1..],
                )?;
                break;
            }
            other => self.emit_statement(other)?,
        }
    }
    let Some(result) = &block.result else {
        return Ok(());
    };
    self.indent();
    self.push("return ");
    self.emit_expression(result)?;
    self.push(";\n");
    Ok(())
}
/// Lowers a `label` statement, together with the statements that follow it,
/// into a WGSL `loop`.
///
/// - The label's parameters become `var`s initialised from
///   `initial_arguments`.
/// - The loop body runs through the last statement that jumps back to the
///   label, either directly or inside `if` branches.
/// - Such a jump reassigns the parameters and `continue`s.
/// - Reaching the end of the body without jumping leaves the loop.
/// - The remaining statements are emitted after the loop; a later `label`
///   among them starts another loop.
/// - With no jump back at all, no `loop` is emitted.
///
/// # Errors
///
/// Fails if the initial arguments do not match the label's parameters one to
/// one, or if any nested type, statement or expression cannot be expressed in
/// WGSL.
fn emit_loop(
    &mut self,
    label_name: &str,
    parameters: &[t_ree::declaration::Parameter],
    initial_arguments: &[Expression],
    body_statements: &[Statement],
) -> Result<(), String> {
    Self::check_label_arity(label_name, parameters, initial_arguments)?;
    for (parameter, init) in parameters.iter().zip(initial_arguments) {
        self.indent();
        self.push("var ");
        self.push(&parameter.name);
        if let Some(ty) = &parameter.parameter_type {
            self.push(": ");
            self.emit_type(ty)?;
        }
        self.push(" = ");
        self.emit_expression(init)?;
        self.push(";\n");
    }
    // The loop body extends through the last statement that can jump back.
    let loop_end = body_statements
        .iter()
        .rposition(|statement| Self::contains_jump(statement, label_name))
        .map_or(0, |last_jump| last_jump + 1);
    let (loop_body, rest) = body_statements.split_at(loop_end);
    if !loop_body.is_empty() {
        self.indent();
        self.push("loop {\n");
        self.indent_level += 1;
        for statement in loop_body {
            self.emit_loop_statement(statement, label_name, parameters)?;
        }
        // Falling off the end of the body means no jump was taken, so control
        // must leave the loop rather than repeat it. Skip the `break` when the
        // final statement already transfers control, so no unreachable
        // statement is emitted.
        if !matches!(
            loop_body.last(),
            Some(Statement::Jump { .. } | Statement::Return(_))
        ) {
            self.indent();
            self.push("break;\n");
        }
        self.indent_level -= 1;
        self.indent();
        self.push("}\n");
    }
    for (offset, statement) in rest.iter().enumerate() {
        if let Statement::Label {
            name,
            parameters,
            initial_arguments,
        } = statement
        {
            return self.emit_loop(name, parameters, initial_arguments, &rest[offset + 1..]);
        }
        self.emit_statement(statement)?;
    }
    Ok(())
}

/// Reports whether `statement` is, or contains within `if`/`else` branches, a
/// `jump` to `label`.
fn contains_jump(statement: &Statement, label: &str) -> bool {
    match statement {
        Statement::Jump { label: target, .. } => target == label,
        Statement::Expression(expression) => match &expression.kind {
            ExpressionKind::If {
                then_branch,
                else_branch,
                ..
            } => {
                then_branch
                    .statements
                    .iter()
                    .any(|s| Self::contains_jump(s, label))
                    || else_branch.as_ref().is_some_and(|b| {
                        b.statements.iter().any(|s| Self::contains_jump(s, label))
                    })
            }
            _ => false,
        },
        _ => false,
    }
}

/// Emits one statement of the loop body for label `label_name`.
///
/// A `jump` to the label becomes parameter reassignment plus `continue`. `if`
/// statements are re-emitted with their branches lowered recursively, so
/// nested jumps are found. Everything else goes through `emit_statement`.
///
/// NOTE(review): trailing result expressions of `if` branches are not emitted
/// here. Confirm that statement-position branches never carry one.
fn emit_loop_statement(
    &mut self,
    statement: &Statement,
    label_name: &str,
    parameters: &[t_ree::declaration::Parameter],
) -> Result<(), String> {
    match statement {
        Statement::Jump { label, arguments } if label == label_name => {
            self.emit_jump(label_name, parameters, arguments)
        }
        Statement::Expression(expression) => {
            let ExpressionKind::If {
                condition,
                then_branch,
                else_branch,
            } = &expression.kind
            else {
                return self.emit_statement(statement);
            };
            self.indent();
            self.push("if (");
            self.emit_expression(condition)?;
            self.push(") {\n");
            self.emit_loop_branch(&then_branch.statements, label_name, parameters)?;
            self.indent();
            if let Some(else_block) = else_branch {
                self.push("} else {\n");
                self.emit_loop_branch(&else_block.statements, label_name, parameters)?;
                self.indent();
            }
            self.push("}\n");
            Ok(())
        }
        _ => self.emit_statement(statement),
    }
}

/// Emits `statements` one nesting level deeper, as part of the loop body.
fn emit_loop_branch(
    &mut self,
    statements: &[Statement],
    label_name: &str,
    parameters: &[t_ree::declaration::Parameter],
) -> Result<(), String> {
    self.indent_level += 1;
    for statement in statements {
        self.emit_loop_statement(statement, label_name, parameters)?;
    }
    self.indent_level -= 1;
    Ok(())
}

/// Emits a `jump` back to the loop's label: rebinds every label parameter
/// from `arguments`, then `continue`s.
///
/// All arguments are evaluated into temporaries before any parameter is
/// assigned. This way `jump loop(i + 1, total + i)` sees the current
/// iteration's `i` in both arguments. A jump with no arguments leaves the
/// parameters unchanged.
fn emit_jump(
    &mut self,
    label_name: &str,
    parameters: &[t_ree::declaration::Parameter],
    arguments: &[Expression],
) -> Result<(), String> {
    if !arguments.is_empty() {
        // A partial argument list would reference `_tN` temporaries that
        // were never declared.
        Self::check_label_arity(label_name, parameters, arguments)?;
        for (index, argument) in arguments.iter().enumerate() {
            self.indent();
            write!(self.output, "let _t{index} = ").unwrap();
            self.emit_expression(argument)?;
            self.push(";\n");
        }
        for (index, parameter) in parameters.iter().enumerate() {
            self.indent();
            self.push(&parameter.name);
            writeln!(self.output, " = _t{index};").unwrap();
        }
    }
    self.indent();
    self.push("continue;\n");
    Ok(())
}

/// Checks that `arguments` supplies exactly one value per parameter of label
/// `label_name`.
fn check_label_arity(
    label_name: &str,
    parameters: &[t_ree::declaration::Parameter],
    arguments: &[Expression],
) -> Result<(), String> {
    if parameters.len() == arguments.len() {
        Ok(())
    } else {
        Err(format!(
            "label `{label_name}` has {} parameter(s) but {} argument(s) were supplied",
            parameters.len(),
            arguments.len()
        ))
    }
}
fn emit_statement(&mut self, statement: &Statement) -> Result<(), String> {
self.indent();
match statement {
Statement::Expression(expression) => {
self.emit_expression(expression)?;
self.push(";\n");
}
Statement::Let {
name,
binding,
value,
..
} => {
match binding {
Binding::Value | Binding::Reference => self.push("let "),
Binding::Variable => self.push("var "),
}
self.push(name);
if let Some(ref ty) = value.resolved_type {
self.push(": ");
self.emit_type(ty)?;
}
self.push(" = ");
self.emit_expression(value)?;
self.push(";\n");
}
Statement::Assign(target, value) => {
self.emit_expression(target)?;
self.push(" = ");
self.emit_expression(value)?;
self.push(";\n");
}
Statement::Return(Some(value)) => {
self.push("return ");
self.emit_expression(value)?;
self.push(";\n");
}
Statement::Return(None) => {
self.push("return;\n");
}
Statement::Label { .. } => {}
Statement::Jump { label, .. } => {
self.push("break; // jump ");
self.push(label);
self.push("\n");
}
Statement::MultiReplace {
targets, values, ..
} => {
for (index, (target, value)) in targets.iter().zip(values).enumerate() {
if index > 0 {
self.indent();
}
self.emit_expression(target)?;
self.push(" = ");
self.emit_expression(value)?;
self.push(";\n");
}
}
Statement::Defer(_) => {
return Err("defer not supported in WGSL".into());
}
}
Ok(())
}
fn emit_expression(&mut self, expression: &Expression) -> Result<(), String> {
match &expression.kind {
ExpressionKind::Literal(literal) => self.emit_literal(literal),
ExpressionKind::Variable(name) => {
self.push(name);
Ok(())
}
ExpressionKind::BinaryOperation(operator, left, right) => {
self.push("(");
self.emit_expression(left)?;
self.push(" ");
self.emit_binary_operator(operator);
self.push(" ");
self.emit_expression(right)?;
self.push(")");
Ok(())
}
ExpressionKind::UnaryOperation(operator, operand) => {
match operator {
UnaryOperator::Negate => self.push("-("),
UnaryOperator::LogicalNot => self.push("!("),
UnaryOperator::BitwiseNot => self.push("~("),
}
self.emit_expression(operand)?;
self.push(")");
Ok(())
}
ExpressionKind::Call(callee, arguments) => {
self.emit_expression(callee)?;
self.push("(");
for (index, argument) in arguments.iter().enumerate() {
if index > 0 {
self.push(", ");
}
self.emit_expression(argument)?;
}
self.push(")");
Ok(())
}
ExpressionKind::Field(object, field) => {
self.emit_expression(object)?;
self.push(".");
self.push(field);
Ok(())
}
ExpressionKind::Index(array, index) => {
self.emit_expression(array)?;
self.push("[");
self.emit_expression(index)?;
self.push("]");
Ok(())
}
ExpressionKind::If {
condition,
then_branch,
else_branch,
} => {
self.push("if (");
self.emit_expression(condition)?;
self.push(") {\n");
self.indent_level += 1;
self.emit_block(then_branch)?;
self.indent_level -= 1;
self.indent();
if let Some(else_block) = else_branch {
self.push("} else {\n");
self.indent_level += 1;
self.emit_block(else_block)?;
self.indent_level -= 1;
self.indent();
}
self.push("}");
Ok(())
}
ExpressionKind::Convert(operand, target_type)
| ExpressionKind::Transmute(operand, target_type) => {
self.emit_type(target_type)?;
self.push("(");
self.emit_expression(operand)?;
self.push(")");
Ok(())
}
ExpressionKind::TypeConstruction(name, fields) => {
self.push(name);
self.push("(");
for (index, (_, value)) in fields.iter().enumerate() {
if index > 0 {
self.push(", ");
}
self.emit_expression(value)?;
}
self.push(")");
Ok(())
}
ExpressionKind::ArrayLiteral(elements) => {
if let Some(ref ty) = expression.resolved_type {
self.emit_type(ty)?;
}
self.push("(");
for (index, element) in elements.iter().enumerate() {
if index > 0 {
self.push(", ");
}
self.emit_expression(element)?;
}
self.push(")");
Ok(())
}
ExpressionKind::Block(block) => {
self.push("{\n");
self.indent_level += 1;
self.emit_block(block)?;
self.indent_level -= 1;
self.indent();
self.push("}");
Ok(())
}
ExpressionKind::Replace(target, value) => {
self.emit_expression(target)?;
self.push(" = ");
self.emit_expression(value)?;
Ok(())
}
ExpressionKind::OpAssign(operator, target, value) => {
self.emit_expression(target)?;
let symbol = match operator {
ArithmeticOperator::Add => " += ",
ArithmeticOperator::Subtract => " -= ",
ArithmeticOperator::Multiply => " *= ",
ArithmeticOperator::Divide => " /= ",
ArithmeticOperator::Remainder => " %= ",
};
self.push(symbol);
self.emit_expression(value)?;
Ok(())
}
ExpressionKind::Dereference(inner) => self.emit_expression(inner),
_ => Err(format!(
"unsupported WGSL expression: {:?}",
expression.kind
)),
}
}
/// Emits a literal in WGSL syntax.
///
/// Integers and booleans use their `Display` form. Floats also use `Display`,
/// plus a `.0` suffix when that form has no decimal point, so WGSL reads them
/// as floating-point literals.
///
/// # Errors
///
/// Strings and null literals have no WGSL representation.
fn emit_literal(&mut self, literal: &Literal) -> Result<(), String> {
    match literal {
        Literal::String(_) => Err("strings not supported in WGSL".into()),
        Literal::Null => Err("null pointers not supported in WGSL".into()),
        Literal::Integer(value) => {
            write!(self.output, "{value}").unwrap();
            Ok(())
        }
        Literal::Bool(value) => {
            write!(self.output, "{value}").unwrap();
            Ok(())
        }
        Literal::Float(value) => {
            let text = value.to_string();
            let needs_fraction = !text.contains('.');
            self.push(&text);
            if needs_fraction {
                self.push(".0");
            }
            Ok(())
        }
    }
}
fn emit_binary_operator(&mut self, operator: &BinaryOperator) {
let symbol = match operator {
BinaryOperator::Arithmetic(ArithmeticOperator::Add) => "+",
BinaryOperator::Arithmetic(ArithmeticOperator::Subtract) => "-",
BinaryOperator::Arithmetic(ArithmeticOperator::Multiply) => "*",
BinaryOperator::Arithmetic(ArithmeticOperator::Divide) => "/",
BinaryOperator::Arithmetic(ArithmeticOperator::Remainder) => "%",
BinaryOperator::Comparison(ComparisonOperator::Equal) => "==",
BinaryOperator::Comparison(ComparisonOperator::NotEqual) => "!=",
BinaryOperator::Comparison(ComparisonOperator::Less) => "<",
BinaryOperator::Comparison(ComparisonOperator::LessEqual) => "<=",
BinaryOperator::Comparison(ComparisonOperator::Greater) => ">",
BinaryOperator::Comparison(ComparisonOperator::GreaterEqual) => ">=",
BinaryOperator::Logical(LogicalOperator::And) => "&&",
BinaryOperator::Logical(LogicalOperator::Or) => "||",
BinaryOperator::Bitwise(BitwiseOperator::And) => "&",
BinaryOperator::Bitwise(BitwiseOperator::Or) => "|",
BinaryOperator::Bitwise(BitwiseOperator::Xor) => "^",
BinaryOperator::Bitwise(BitwiseOperator::ShiftLeft) => "<<",
BinaryOperator::Bitwise(BitwiseOperator::ShiftRight) => ">>",
};
self.push(symbol);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use naga::valid::{Capabilities, ValidationFlags, Validator};

    /// Parses C-flavoured `t_ree` source, resolves it, and compiles it to WGSL.
    fn compile_c_to_wgsl(source: &str) -> Result<String, String> {
        let mut module = t_parser_c::parse(source).map_err(|e| e.to_string())?;
        let errors = t_ree::resolve::resolve_module(&mut module, true);
        if !errors.is_empty() {
            return Err(errors.join("\n"));
        }
        compile(&module)
    }

    /// Compiles `source` to WGSL and checks the result with naga.
    ///
    /// Parsing alone only proves the syntax is valid, so the parsed module is
    /// also run through naga's validator. This catches type mismatches and
    /// invalid control flow. Validation errors are shown with `Debug`, which
    /// includes the nested cause chain.
    fn validate_wgsl(source: &str) -> Result<(), String> {
        let wgsl = compile_c_to_wgsl(source)?;
        let module = naga::front::wgsl::parse_str(&wgsl).map_err(|error| {
            format!("WGSL validation failed:\n{error}\n\nGenerated WGSL:\n{wgsl}")
        })?;
        Validator::new(ValidationFlags::all(), Capabilities::all())
            .validate(&module)
            .map_err(|error| {
                format!("WGSL validation failed:\n{error:?}\n\nGenerated WGSL:\n{wgsl}")
            })?;
        Ok(())
    }

    #[test]
    fn simple_function() {
        validate_wgsl("type a = f32; type b = f32; fn add(a, b) -> f32 { (f32: a) + (f32: b) }")
            .unwrap();
    }

    #[test]
    fn newtype_alias() {
        validate_wgsl("type meters = f32; fn test(meters) -> f32 { (f32: meters) }").unwrap();
    }

    #[test]
    fn struct_type() {
        validate_wgsl(
            "type x = f32; type y = f32; type z = f32; type point3 = x & y & z;\n\
             fn test() -> f32 { let v = point3 { x: 1.0, y: 2.0, z: 3.0 }; v.x }",
        )
        .unwrap();
    }

    #[test]
    fn constant() {
        validate_wgsl("pub const PI: f32 = 3.14159; fn test() -> f32 { PI }").unwrap();
    }

    #[test]
    fn if_else() {
        validate_wgsl(
            "fn test(x: f32) -> f32 { if x > 0.0 { return x; } else { return 0.0 - x; } 0.0 }",
        )
        .unwrap();
    }

    #[test]
    fn variables() {
        validate_wgsl(
            "type a = f32; type b = f32;\n\
             fn test(a, b) -> f32 {\n\
             var result: f32 = (f32: a);\n\
             result := result + (f32: b);\n\
             return result;\n\
             }",
        )
        .unwrap();
    }

    #[test]
    fn array_type() {
        validate_wgsl(
            "fn sum(arr: [f32]3) -> f32 {\n\
             return arr[0] + arr[1] + arr[2];\n\
             }",
        )
        .unwrap();
    }

    #[test]
    fn arithmetic() {
        validate_wgsl("type a = f32; type b = f32; fn test(a, b) -> f32 { ((f32: a) + (f32: b)) * ((f32: a) - (f32: b)) / ((f32: a) + 1.0) }").unwrap();
    }

    #[test]
    fn comparison() {
        validate_wgsl("type a = f32; type b = f32; fn test(a, b) -> bool { (f32: a) < (f32: b) }")
            .unwrap();
    }

    #[test]
    fn integer_ops() {
        validate_wgsl(
            "type a = i32; type b = i32; fn test(a, b) -> i32 { ((i32: a) + (i32: b)) * 2 }",
        )
        .unwrap();
    }

    #[test]
    fn sdf_functions() {
        validate_wgsl(
            "extern fn sqrt(x: f32) -> f32;\n\
             extern fn abs(x: f32) -> f32;\n\
             extern fn min(a: f32, b: f32) -> f32;\n\
             extern fn max(a: f32, b: f32) -> f32;\n\
             extern fn clamp(x: f32, lo: f32, hi: f32) -> f32;\n\
             extern fn mix(a: f32, b: f32, t: f32) -> f32;\n\
             type px = f32; type py = f32; type pz = f32;\n\
             type point = px & py & pz;\n\
             type radius = f32;\n\
             fn sphere_sdf(point, radius) -> f32 {\n\
             let d = sqrt(point.px * point.px + point.py * point.py + point.pz * point.pz);\n\
             return d - (f32: radius);\n\
             }\n\
             type distance_a = f32; type distance_b = f32; type smoothness = f32;\n\
             fn smooth_union(distance_a, distance_b, smoothness) -> f32 {\n\
             let h = clamp(0.5 + 0.5 * ((f32: distance_b) - (f32: distance_a)) / (f32: smoothness), 0.0, 1.0);\n\
             return mix((f32: distance_b), (f32: distance_a), h) - (f32: smoothness) * h * (1.0 - h);\n\
             }",
        )
        .unwrap();
    }

    #[test]
    fn label_jump_loop() {
        validate_wgsl(
            "fn sum_to(n: i32) -> i32 {\n\
             label loop(i: i32 = 0, total: i32 = 0);\n\
             if i < n {\n\
             jump loop(i + 1, total + i);\n\
             }\n\
             return total;\n\
             }",
        )
        .unwrap();
    }

    #[test]
    fn factorial_loop() {
        validate_wgsl(
            "fn factorial(n: i32) -> i32 {\n\
             label loop(i: i32 = 1, result: i32 = 1);\n\
             if i <= n {\n\
             jump loop(i + 1, result * i);\n\
             }\n\
             return result;\n\
             }",
        )
        .unwrap();
    }

    #[test]
    fn f64_rejected() {
        assert!(compile_c_to_wgsl("fn test(x: f64) -> f64 { x }").is_err());
    }

    #[test]
    fn pointer_rejected() {
        assert!(compile_c_to_wgsl("fn test(x: |i32) -> i32 { x }").is_err());
    }
}