makepad_shader_ast_impl/
lib.rs

// This proc_macro transforms a block of Rust-like shader code
// of the following form (the input must be a plain block, not a closure):
// shader_ast!({
//      // var def:
//      let x:float<Uniform> = 10.0;
//      // fn def:
//      fn pixel()->vec4{
//          return vec4(1.);
//      }
// })
// into a nested tree of shader AST structs.
// These are defined in shader.rs in the root project;
// the result looks something like the following:
// ShAst{
//      vars:vec![ShVar{name:"x".to_string(), ty:"float".to_string()}]
// }
// The subset of Rust syntax we support maps directly onto GLSL:
// types have to be simple names like float or vec4, and
// for loops are only supported over integer ranges.
// Think of the subset as GLSL written with Rust syntax,
// not as arbitrary Rust that has no direct
// word-for-word match in GLSL.
24
25extern crate proc_macro;
26extern crate proc_macro2;
27use proc_macro_hack::proc_macro_hack;
28use proc_macro2::TokenStream;
29use proc_macro2::Span;
30use syn::{
31    Expr, Type, Pat, Stmt, PathArguments, GenericArgument, 
32    Item, Local, ItemFn, ItemConst, ItemStruct,
33    Lit, Block, FnArg, BinOp, UnOp, Ident, ReturnType, Member
34};
35use quote::quote;
36use quote::quote_spanned;
37use syn::spanned::Spanned;
38
39fn error(span:Span, msg: &str)->TokenStream{
40    let fmsg = format!("shader_ast: {}", msg);
41    quote_spanned!(span=>compile_error!(#fmsg))
42}
43
44// generate the ShVar definitions from a let statement
45fn generate_shvar_defs(stmt:Local)->TokenStream{
46    // lets define a local with storage specified
47    if let Pat::Type(pat) = &stmt.pat{
48        let name =  if let Pat::Ident(ident) = &*pat.pat{
49            ident.ident.to_string()
50        }
51        else{
52            return error(stmt.span(), "Please only use simple identifiers such as x or var_iable");
53        };
54        let found_type;
55        let store;
56        if let Type::Path(typath) = &*pat.ty{
57            if typath.path.segments.len() != 1{
58                return quote!{sh_var(#name, &#typath.shader_type(), #typath.var_store())}
59            }
60
61            if typath.path.segments.len() != 1{
62                return error(typath.span(), "Only simple typenames such as float or vec4 are supported");
63            }
64            let seg = &typath.path.segments[0];
65            found_type = seg.ident.to_string();
66            // lets read the path args
67            if let PathArguments::AngleBracketed(angle) = &seg.arguments{
68                if angle.args.len() != 1{
69                    return error(angle.span(), "Please pass one storage arg like float<Local>");
70                }
71                let arg = &angle.args[0];
72                if let GenericArgument::Type(ty) = arg{
73                    if let Type::Path(typath) = ty{
74                        if typath.path.segments.len() != 1{
75                            return error(typath.span(), "Only simple typenames such as float or vec4 are supported");
76                        }
77                        let seg = &typath.path.segments[0];
78                        store = seg.ident.clone();
79                    }
80                    else{
81                        return error(arg.span(), "Only simple typenames such as float or vec4 are supported");
82                    }
83                }
84                else{
85                    return error(arg.span(), "Please pass one storage arg like float<Local>");
86                }
87            }
88            else{
89               return error(stmt.span(), "Please pass one storage arg like float<Local>");
90            }
91        }
92        else{
93            return error(stmt.span(), "Please give the variable a type of the form float<Local> ");
94        }
95        return quote!{sh_var(#name, #found_type, ShVarStore::#store)}
96    }
97    else{
98        return error(stmt.span(), "Please only use simple identifiers such as x or var_iable {:?}")
99    }
100}
101
102// generate the ShFn definitions from a rust fn statement
103fn generate_fn_def(item:ItemFn)->TokenStream{
104    // alright lets do a function
105    // and then incrementally add all supported ast nodes
106    let name = item.sig.ident.to_string();
107       let mut args = Vec::new();
108    // lets process the fnargs
109    for arg in &item.sig.inputs{
110        if let FnArg::Typed(arg) = arg{
111            // lets look at pat and ty
112            if let Pat::Ident(pat) = &*arg.pat{
113                let name =  pat.ident.to_string();
114                let found_type;
115                if let Type::Path(typath) = &*arg.ty{
116                    if typath.path.segments.len() != 1{
117                        return error(typath.span(), "arg type not simple");
118                    }
119                    let seg = &typath.path.segments[0];
120                    found_type = seg.ident.to_string();
121                }
122                else{
123                    return error(arg.span(), "arg type not simple");
124                }
125                args.push(quote!{sh_fnarg(#name, #found_type)})
126            }
127            else{
128                return error(arg.span(), "arg pattern not simple identifier")
129            }
130        }
131        else{
132             return error(arg.span(), "arg pattern not simple identifier")
133        }
134    }
135    let return_type;
136    if let ReturnType::Type(_, ty) = item.sig.output{
137        if let Type::Path(typath) = *ty{
138            if typath.path.segments.len() != 1{
139                return error(typath.span(), "return type not simple");
140            }
141            let seg = &typath.path.segments[0];
142            return_type = seg.ident.to_string();
143        }
144        else{
145            return error(ty.span(), "return type not simple");
146        }
147    }   
148    else{
149        return_type = "void".to_string();
150        //return error(item.span(), "function needs to specify return type")
151    }
152    let block = generate_block(*item.block);
153    quote!{sh_fn(#name, &[#(#args,)*], #return_type, Some(#block))}
154}
155
156// generate a let statement inside a function
157fn generate_let(local:Local)->TokenStream{
158    // lets define a local with storage specified
159    if let Pat::Ident(ident) = &local.pat{
160        let name = ident.ident.to_string();
161        let init = if let Some((_,local_init)) = local.init{
162            generate_expr(*local_init)
163        }
164        else{
165            return error(local.span(), "let pattern misses initializer");
166        };
167
168        return quote!{sh_let(#name, "", #init)}
169    }
170    else if let Pat::Type(pat) = &local.pat{
171        let name =  if let Pat::Ident(ident) = &*pat.pat{
172            ident.ident.to_string()
173        }
174        else{
175            return error(local.span(), "Please only use simple identifiers such as x or var_iable");
176        };
177        
178        let ty = if let Type::Path(typath) = &*pat.ty{
179            if typath.path.segments.len() != 1{
180                return error(typath.span(), "Only simple typenames such as float or vec4 are supported");
181            }
182            let seg = &typath.path.segments[0];
183            seg.ident.to_string()
184        }
185        else{
186           return error(local.span(), "Only simple typenames such as float or vec4 are supported");
187        };
188
189        let init = if let Some((_,local_init)) = local.init{
190            generate_expr(*local_init)
191        }
192        else{
193            return error(local.span(), "let pattern misses initializer");
194        };
195        
196        return quote!{sh_let(#name, #ty, #init)}
197    }
198    else{
199        return error(local.span(), "let pattern doesn't need type");
200    }
201}
202
203// generate a { } block AST 
204fn generate_block(block:Block)->TokenStream{
205    let mut stmts = Vec::new();
206    for stmt in block.stmts{
207        match stmt{
208            Stmt::Local(stmt)=>{
209                let letstmt = generate_let(stmt);
210                stmts.push(letstmt)
211            }
212            Stmt::Item(stmt)=>{
213                return error(stmt.span(), "Shader functions don't support items");
214            }
215            Stmt::Expr(stmt)=>{
216                let expr = generate_expr(stmt);
217                stmts.push(quote!{sh_exps(#expr)})
218            }
219            Stmt::Semi(stmt, _tok)=>{
220                let expr = generate_expr(stmt);
221                stmts.push(quote!{sh_sems(#expr)})
222            }
223        }
224    }
225    return quote!{sh_block(&[#(#stmts,)*])}
226}
227
228// return the string name of a BinOp enum 
229fn get_binop(op:BinOp)->&'static str{
230    match op{
231        BinOp::Add(_)=>"Add",
232        BinOp::Sub(_)=>"Sub",
233        BinOp::Mul(_)=>"Mul",
234        BinOp::Div(_)=>"Div",
235        BinOp::Rem(_)=>"Rem",
236        BinOp::And(_)=>"And",
237        BinOp::Or(_)=>"Or",
238        BinOp::BitXor(_)=>"BitXor",
239        BinOp::BitAnd(_)=>"BitAnd",
240        BinOp::BitOr(_)=>"BitOr",
241        BinOp::Shl(_)=>"Shl",
242        BinOp::Shr(_)=>"Shr",
243        BinOp::Eq(_)=>"Eq",
244        BinOp::Lt(_)=>"Lt",
245        BinOp::Le(_)=>"Le",
246        BinOp::Ne(_)=>"Ne",
247        BinOp::Ge(_)=>"Ge",
248        BinOp::Gt(_)=>"Gt",
249        BinOp::AddEq(_)=>"AddEq",
250        BinOp::SubEq(_)=>"SubEq",
251        BinOp::MulEq(_)=>"MulEq",
252        BinOp::DivEq(_)=>"DivEq",
253        BinOp::RemEq(_)=>"RemEq",
254        BinOp::BitXorEq(_)=>"BitXorEq",
255        BinOp::BitAndEq(_)=>"BitAndEq",
256        BinOp::BitOrEq(_)=>"BitOrEq",
257        BinOp::ShlEq(_)=>"ShlEq",
258        BinOp::ShrEq(_)=>"ShrEq",
259    }
260}
261
262// generate the AST from an expression
263fn generate_expr(expr:Expr)->TokenStream{
264    match expr{
265        Expr::Call(expr)=>{
266            if let Expr::Path(func) = *expr.func{
267                if func.path.segments.len() != 1{
268                    return error(func.span(), "call identifier not simple");
269                }
270                let seg = &func.path.segments[0].ident.to_string();
271                // lets get all fn args
272                let mut args = Vec::new();
273                for arg in expr.args{
274                    args.push(generate_expr(arg));
275                }
276                
277                //return quote!{ShExpr::ShCall(ShCall{call:#seg.to_string(), args:{let mut v=Vec::new();#(v.push(#args);)*v}})}
278                return quote!{sh_call(#seg, &[#(#args,)*])}
279            }
280            else{
281                 return error(expr.span(), "call identifier not simple");
282            }
283        }
284        Expr::Binary(expr)=>{
285            let left = generate_expr(*expr.left);
286            let right = generate_expr(*expr.right);
287            let op = Ident::new(get_binop(expr.op), Span::call_site());
288            return quote!{sh_bin(#left, #right, ShOp::#op)}
289        }
290        Expr::Unary(expr)=>{
291            let op;
292            if let UnOp::Not(_) = &expr.op{
293                op = Ident::new("Not", Span::call_site());
294            }
295            else if let UnOp::Neg(_) = &expr.op{
296                op = Ident::new("Neg", Span::call_site());
297            }
298            else {
299                return error(expr.span(), "Deref not implemented");
300            }
301            let right = generate_expr(*expr.expr);
302            return quote!{sh_unary(#right, ShUnaryOp::#op)}
303        }
304        Expr::Lit(expr)=>{
305            match expr.lit{
306                Lit::Str(lit)=>{
307                    let value = lit.value();
308                    return quote!{sh_str(#value)}
309                }
310                Lit::Int(lit)=>{
311                    let value = lit.base10_parse::<i64>().unwrap();
312                    return quote!{sh_int(#value)}
313                }
314                Lit::Float(lit)=>{
315                    let value = lit.base10_parse::<f64>().unwrap();
316                    return quote!{sh_fl(#value)}
317                }
318                Lit::Bool(lit)=>{
319                    let value = lit.value;
320                    return quote!{sh_bool(#value)}
321                }
322                _=>{
323                    return error(expr.span(), "Unsupported literal for shader")
324                }
325            }
326        }
327        Expr::Let(expr)=>{
328            return error(expr.span(), "Not implemented Expr::Let")
329        }
330        Expr::If(expr)=>{
331            let cond = generate_expr(*expr.cond);
332            let then_branch = generate_block(expr.then_branch);
333
334            if let Some((_,else_branch)) = expr.else_branch{
335                let else_branch = generate_expr(*else_branch);
336                return quote!{sh_if_else(#cond, #then_branch, #else_branch)}
337            }
338            return quote!{sh_if(#cond, #then_branch)}
339        }
340        Expr::While(expr)=>{
341            let cond = generate_expr(*expr.cond);
342            let block = generate_block(expr.body);
343            return quote!{sh_while(#cond, #block)}
344       }
345        Expr::ForLoop(expr)=>{
346              // lets define a local with storage specified
347            let span = expr.span();
348            if let Pat::Ident(pat) = expr.pat{
349                let name =  pat.ident.to_string();
350                let body = generate_block(expr.body);
351                let from_ts;
352                let to_ts;
353                if let Expr::Range(range) = *expr.expr{
354                    if let Some(from) = range.from {
355                        from_ts = generate_expr(*from);
356                    }
357                    else{
358                        return error(span, "Must provide from range expression")
359                    }
360                    if let Some(to) = range.to {
361                        to_ts = generate_expr(*to);
362                    }
363                    else{
364                        return error(span, "Must provide to range expression")
365                    }
366                }
367                else{
368                    return error(span, "Must provide range expression")
369                }
370                return quote!{sh_for(#name, #from_ts, #to_ts, #body)}
371            }
372            else{
373                return error(expr.span(), "Use simple identifier for for loop")
374            }
375        }
376        Expr::Assign(expr)=>{
377            let left = generate_expr(*expr.left);
378            let right = generate_expr(*expr.right);
379            return quote!{sh_asn(#left, #right)};//ShExpr::ShAssign(ShAssign{left:Box::new(#left),right:Box::new(#right)})}
380        }
381        Expr::AssignOp(expr)=>{
382            let left = generate_expr(*expr.left);
383            let right = generate_expr(*expr.right);
384            let op = Ident::new(get_binop(expr.op), Span::call_site());
385            return quote!{sh_asn_op(#left, #right, ShOp::#op)}
386            // return quote!{ShExpr::ShAssignOp(ShAssignOp{left:Box::new(#left),op:ShBinOp::#op,right:Box::new(#right)})}
387        }
388        Expr::Field(expr)=>{
389            let member;
390            if let Member::Named(ident) = expr.member{
391                member = ident.to_string();
392            }
393            else{
394                return error(expr.span(), "No unnamed members supported")
395            }
396            let base = generate_expr(*expr.base);
397            return quote!{sh_fd(#base, #member)}//ShExpr::ShField(ShField{base:Box::new(#base),member:#member.to_string()})}
398        }
399        Expr::Index(expr)=>{
400            let base = generate_expr(*expr.expr);
401            let index = generate_expr(*expr.index);
402            return quote!{sh_idx(#base, #index)}//ShExpr::ShIndex(ShIndex{base:Box::new(#base),index:Box::new(#index)})}
403        }
404        Expr::Path(expr)=>{
405            if expr.path.segments.len() != 1{
406                return error(expr.span(), "type not simple");
407            }
408            let seg = &expr.path.segments[0].ident.to_string();
409            return quote!{sh_id(#seg)}//ShExpr::ShId(ShId{name:#seg.to_string()})}
410        }
411        Expr::Paren(expr)=>{
412            let expr = generate_expr(*expr.expr);
413            return quote!{sh_par(#expr)}//ShExpr::ShParen(ShParen{expr:Box::new(#expr)})}
414        }
415        Expr::Block(expr)=>{ // process a block expression
416            let block = generate_block(expr.block); 
417            return quote!{ShExpr::ShBlock(#block)}
418        }
419        Expr::Return(expr)=>{
420            if let Some(expr) = expr.expr{
421                let expr = generate_expr(*expr);
422                return quote!{sh_ret(#expr)}
423            }
424            return quote!{sh_retn()}
425        }
426        Expr::Break(_)=>{
427            return quote!{ShExpr::ShBreak(ShBreak{})}
428
429        }
430        Expr::Continue(_)=>{
431            return quote!{ShExpr::ShContinue(ShContinue{})}
432        }
433        _=>{
434            return error(expr.span(), "Unsupported syntax for shader")
435        }
436    }
437}
438
439// generate the ShConst defs
440fn generate_const_def(item:ItemConst)->TokenStream{
441    let name = item.ident.to_string();
442    let ty;
443
444    if let Type::Path(typath) = *item.ty{
445        if typath.path.segments.len() != 1{
446            return error(typath.span(), "const type not a basic identifie");
447        }
448        let seg = &typath.path.segments[0];
449        ty = seg.ident.to_string();
450    }
451    else{
452        return error(item.ty.span(), "const type not a basic identifier");
453    }
454
455    let expr = generate_expr(*item.expr);
456    quote!{
457        ShConst{
458            name:#name.to_string(),
459            ty:#ty.to_string(),
460            value:#expr
461        }
462    }
463}
464
465// generate the ShStruct defs
466fn generate_struct_def(_item:ItemStruct)->TokenStream{
467    TokenStream::new()
468}
469
470// Generate the ShAst rootnode
471fn generate_root(expr:Expr)->TokenStream{
472    let mut vars = Vec::new();
473    let mut fns = Vec::new();
474    let mut consts = Vec::new();
475    let mut structs = Vec::new();
476    match expr {
477        Expr::Block(expr)=>{
478            for stmt in expr.block.stmts{
479                match stmt{
480                    Stmt::Local(stmt)=>{
481                        vars.push(generate_shvar_defs(stmt));
482                    }
483                    Stmt::Item(stmt)=>{
484                        match stmt{
485                            Item::Struct(item)=>{
486                                structs.push(generate_struct_def(item));
487                            }
488                            Item::Const(item)=>{
489                                consts.push(generate_const_def(item));
490                            }
491                            Item::Fn(item)=>{
492                                fns.push(generate_fn_def(item));
493                            }
494                            _=>{
495                                return error(stmt.span(), "Unexpected statement")
496                            }
497                        }
498                    }
499                    Stmt::Expr(stmt)=>{
500                            return error(stmt.span(), "Expression not expected here")
501                    }
502                    Stmt::Semi(stmt, _tok)=>{
503                            return error(stmt.span(), "Statement not expected here")
504                    }
505                }
506            }
507        },
508        _=>{
509            return error(expr.span(), "Expecting block")
510        }
511    };
512    quote!{ 
513        ShAst{
514            types:Vec::new(),//{let mut v=Vec::new();#(v.push(#types);)*v},
515            vars:{let mut v=Vec::new();#(v.push(#vars);)*v},
516            consts:{let mut v=Vec::new();#(v.push(#consts);)*v},
517            fns:{let mut v=Vec::new();#(v.push(#fns);)*v} 
518        }
519    }
520
521}
522
523// The actual macro
524#[proc_macro_hack]
525pub fn shader_ast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
526    
527    let parsed = syn::parse_macro_input!(input as syn::Expr);
528
529    let ts = generate_root(parsed);
530    proc_macro::TokenStream::from(ts)
531}
532