Skip to main content

kore/codegen/
spirv.rs

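//! SPIR-V code generation for GPU shaders.
//!
//! [`generate`] lowers every shader item of a typed program into one SPIR-V module and
//! returns it as little-endian bytes.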
2
3use crate::types::{TypedProgram, TypedItem, TypedShader};
4use crate::error::{KoreResult, KoreError};
5use crate::ast::{Type, ShaderStage, Expr, Stmt, Block, BinaryOp};
6use rspirv::binary::Assemble;
7use rspirv::dr::{Builder, Operand};
8use rspirv::spirv::{Capability, AddressingModel, MemoryModel, ExecutionModel, ExecutionMode, StorageClass, Decoration};
9use std::collections::HashMap;
10
11pub fn generate(program: &TypedProgram) -> KoreResult<Vec<u8>> {
12    let mut builder = Builder::new();
13    
14    // Set capabilities and memory model
15    builder.capability(Capability::Shader);
16    // Add VulkanMemoryModel if targeting Vulkan, but GLSL450 is standard for now
17    builder.memory_model(AddressingModel::Logical, MemoryModel::GLSL450);
18    
19    for item in &program.items {
20        if let TypedItem::Shader(shader) = item {
21            emit_shader(&mut builder, shader)?;
22        }
23    }
24    
25    let module = builder.module();
26    let bytes: Vec<u8> = module.assemble().iter().flat_map(|w| w.to_le_bytes()).collect();
27    Ok(bytes)
28}
29
30struct ShaderContext<'a> {
31    b: &'a mut Builder,
32    // Name -> (SPIR-V ID, AST Type, IsPointer)
33    vars: HashMap<String, (u32, Type, bool)>,
34    output_var: Option<u32>,
35    // Track which variables are struct-wrapped uniforms (need AccessChain)
36    struct_uniforms: std::collections::HashSet<String>,
37    // Cache GLSL extension import
38    glsl_ext: Option<u32>,
39}
40
41fn emit_shader(b: &mut Builder, shader: &TypedShader) -> KoreResult<()> {
42    let exec_model = match shader.ast.stage {
43        ShaderStage::Vertex => ExecutionModel::Vertex,
44        ShaderStage::Fragment => ExecutionModel::Fragment,
45        ShaderStage::Compute => ExecutionModel::GLCompute,
46    };
47    
48    // 1. Define Basic Types
49    let void = b.type_void();
50    
51    // 2. Define Entry Point Function Type
52    let fn_void_void = b.type_function(void, vec![]);
53    
54    // 3. Declare Variables (Global Interface)
55    let mut interface_vars = vec![];
56    let mut ctx_vars = HashMap::new();
57    let mut struct_uniforms = std::collections::HashSet::new();
58
59    // Inputs
60    for (i, param) in shader.ast.inputs.iter().enumerate() {
61        let ty = map_ast_type(b, &param.ty);
62        let ptr_ty = b.type_pointer(None, StorageClass::Input, ty);
63        let var = b.variable(ptr_ty, None, StorageClass::Input, None);
64        b.decorate(var, Decoration::Location, vec![Operand::LiteralBit32(i as u32)]);
65        interface_vars.push(var);
66        ctx_vars.insert(param.name.clone(), (var, param.ty.clone(), true));
67    }
68
69    // Outputs
70    let output_var = if !is_void(&shader.ast.outputs) {
71         let output_ty = map_ast_type(b, &shader.ast.outputs);
72         let ptr_ty = b.type_pointer(None, StorageClass::Output, output_ty);
73         let var = b.variable(ptr_ty, None, StorageClass::Output, None);
74         
75         // Vertex shader output is @builtin(position) for Vec4, otherwise use Location
76         if exec_model == ExecutionModel::Vertex && is_vec4(&shader.ast.outputs) {
77             b.decorate(var, Decoration::BuiltIn, vec![Operand::BuiltIn(rspirv::spirv::BuiltIn::Position)]);
78         } else {
79             b.decorate(var, Decoration::Location, vec![Operand::LiteralBit32(0)]);
80         }
81         
82         interface_vars.push(var);
83         Some(var)
84    } else {
85        None
86    };
87
88    // Uniforms
89    for uniform in &shader.ast.uniforms {
90        let inner_ty = map_ast_type(b, &uniform.ty);
91        
92        // Check if this is a sampler type (uses UniformConstant) or data type (uses Uniform with struct)
93        let is_sampler = matches!(&uniform.ty, Type::Named { name, .. } if name == "Sampler2D");
94        
95        if is_sampler {
96            // Samplers use UniformConstant storage class directly
97            let ptr_ty = b.type_pointer(None, StorageClass::UniformConstant, inner_ty);
98            let var = b.variable(ptr_ty, None, StorageClass::UniformConstant, None);
99            b.decorate(var, Decoration::DescriptorSet, vec![Operand::LiteralBit32(0)]);
100            b.decorate(var, Decoration::Binding, vec![Operand::LiteralBit32(uniform.binding)]);
101            ctx_vars.insert(uniform.name.clone(), (var, uniform.ty.clone(), true));
102        } else {
103            // Data uniforms (matrices, vectors, etc.) need a struct wrapper with Block decoration
104            let struct_ty = b.type_struct(vec![inner_ty]);
105            b.decorate(struct_ty, Decoration::Block, vec![]);
106            // Offset decoration for the first (and only) member
107            b.member_decorate(struct_ty, 0, Decoration::Offset, vec![Operand::LiteralBit32(0)]);
108            
109            // For matrices, we need ColMajor and MatrixStride decorations
110            if matches!(&uniform.ty, Type::Named { name, .. } if name == "Mat4") {
111                b.member_decorate(struct_ty, 0, Decoration::ColMajor, vec![]);
112                b.member_decorate(struct_ty, 0, Decoration::MatrixStride, vec![Operand::LiteralBit32(16)]);
113            }
114            
115            let ptr_ty = b.type_pointer(None, StorageClass::Uniform, struct_ty);
116            let var = b.variable(ptr_ty, None, StorageClass::Uniform, None);
117            b.decorate(var, Decoration::DescriptorSet, vec![Operand::LiteralBit32(0)]);
118            b.decorate(var, Decoration::Binding, vec![Operand::LiteralBit32(uniform.binding)]);
119            ctx_vars.insert(uniform.name.clone(), (var, uniform.ty.clone(), true));
120            struct_uniforms.insert(uniform.name.clone());
121        }
122    }
123
124    // 4. Function Body
125    let main_fn = b.begin_function(void, None, rspirv::spirv::FunctionControl::NONE, fn_void_void).unwrap();
126    b.begin_block(None).unwrap();
127
128    let mut ctx = ShaderContext {
129        b,
130        vars: ctx_vars,
131        output_var,
132        struct_uniforms,
133        glsl_ext: None,
134    };
135
136    emit_block(&mut ctx, &shader.ast.body)?;
137
138    // Ensure we always have a return
139    if shader.ast.body.stmts.last().map_or(true, |s| !matches!(s, Stmt::Return(_, _))) {
140        ctx.b.ret().unwrap();
141    }
142    
143    ctx.b.end_function().unwrap();
144
145    // 5. Entry Point
146    b.entry_point(exec_model, main_fn, &shader.ast.name, interface_vars);
147    
148    if exec_model == ExecutionModel::Fragment {
149        b.execution_mode(main_fn, ExecutionMode::OriginUpperLeft, vec![]);
150    }
151    
152    Ok(())
153}
154
155impl<'a> ShaderContext<'a> {
156    fn get_glsl_ext(&mut self) -> u32 {
157        if let Some(ext) = self.glsl_ext {
158            ext
159        } else {
160            let ext = self.b.ext_inst_import("GLSL.std.450");
161            self.glsl_ext = Some(ext);
162            ext
163        }
164    }
165}
166
167fn emit_block(ctx: &mut ShaderContext, block: &Block) -> KoreResult<()> {
168    for stmt in &block.stmts {
169        match stmt {
170            Stmt::Return(expr, _) => {
171                if let Some(expr) = expr {
172                    if let Some(out_var) = ctx.output_var {
173                        let (val, _) = emit_expr(ctx, expr)?;
174                        ctx.b.store(out_var, val, None, vec![]).unwrap();
175                    }
176                }
177                ctx.b.ret().unwrap();
178            },
179            Stmt::Let { pattern, value, .. } => {
180                if let Some(value) = value {
181                    let (val, ty) = emit_expr(ctx, value)?;
182                    // For now, only simple bindings
183                    if let crate::ast::Pattern::Binding { name, .. } = pattern {
184                        // In SSA, we just map name -> value ID
185                        // We don't support mutation of locals yet (need OpVariable + Store/Load)
186                        ctx.vars.insert(name.clone(), (val, ty, false));
187                    }
188                }
189            },
190            Stmt::Expr(expr) => {
191                emit_expr(ctx, expr)?;
192            },
193            _ => {} // Ignore others for now
194        }
195    }
196    Ok(())
197}
198
199fn emit_expr(ctx: &mut ShaderContext, expr: &Expr) -> KoreResult<(u32, Type)> {
200    match expr {
201        Expr::Ident(name, span) => {
202            if let Some((id, ty, is_ptr)) = ctx.vars.get(name).cloned() {
203                if is_ptr {
204                    // Need to load from pointer
205                    let type_id = map_ast_type(ctx.b, &ty);
206                    
207                    // Check if this is a struct-wrapped uniform
208                    if ctx.struct_uniforms.contains(name) {
209                        // Use AccessChain to get pointer to member 0 of the struct
210                        let ptr_ty = ctx.b.type_pointer(None, StorageClass::Uniform, type_id);
211                        let int_ty = ctx.b.type_int(32, 0);
212                        let zero = ctx.b.constant_bit32(int_ty, 0);
213                        let member_ptr = ctx.b.access_chain(ptr_ty, None, id, vec![zero]).unwrap();
214                        let val_id = ctx.b.load(type_id, None, member_ptr, None, std::iter::empty()).unwrap();
215                        Ok((val_id, ty))
216                    } else {
217                        // Direct load for inputs and non-wrapped uniforms
218                        let val_id = ctx.b.load(type_id, None, id, None, std::iter::empty()).unwrap();
219                        Ok((val_id, ty))
220                    }
221                } else {
222                    Ok((id, ty))
223                }
224            } else {
225                 Err(KoreError::codegen(format!("Unknown variable: {}", name), *span))
226            }
227        },
228        Expr::Binary { left, op, right, .. } => {
229            let (lhs, lhs_ty) = emit_expr(ctx, left)?;
230            let (rhs, rhs_ty) = emit_expr(ctx, right)?;
231            
232            // Map types to SPIR-V types
233            let res_ty_id = map_ast_type(ctx.b, &lhs_ty); // Assume result type matches lhs for now
234            
235            let res_id = match op {
236                BinaryOp::Mul => {
237                    if is_mat4(&lhs_ty) && is_mat4(&rhs_ty) {
238                        ctx.b.matrix_times_matrix(res_ty_id, None, lhs, rhs).unwrap()
239                    } else if is_mat4(&lhs_ty) && is_vec4(&rhs_ty) {
240                        // Mat4 * Vec4 -> Vec4
241                         let vec4_ty = map_ast_type(ctx.b, &rhs_ty);
242                         ctx.b.matrix_times_vector(vec4_ty, None, lhs, rhs).unwrap()
243                    } else if is_vec4(&lhs_ty) && is_mat4(&rhs_ty) {
244                        // Vec4 * Mat4 -> Vec4
245                         let vec4_ty = map_ast_type(ctx.b, &lhs_ty);
246                         ctx.b.vector_times_matrix(vec4_ty, None, lhs, rhs).unwrap()
247                    } else if is_float(&lhs_ty) && is_float(&rhs_ty) {
248                        ctx.b.f_mul(res_ty_id, None, lhs, rhs).unwrap()
249                    } else {
250                         // Fallback to FMul (vector * scalar, etc - simplified)
251                        ctx.b.f_mul(res_ty_id, None, lhs, rhs).unwrap()
252                    }
253                },
254                BinaryOp::Add => ctx.b.f_add(res_ty_id, None, lhs, rhs).unwrap(),
255                BinaryOp::Sub => ctx.b.f_sub(res_ty_id, None, lhs, rhs).unwrap(),
256                BinaryOp::Div => ctx.b.f_div(res_ty_id, None, lhs, rhs).unwrap(),
257                _ => return Err(KoreError::codegen("Unsupported binary op in shader", expr.span())),
258            };
259            
260            // Result type inference (simplified)
261            let res_ty = if is_mat4(&lhs_ty) && is_vec4(&rhs_ty) {
262                rhs_ty
263            } else {
264                lhs_ty
265            };
266            
267            Ok((res_id, res_ty))
268        },
269        Expr::Call { callee, args, .. } => {
270            if let Expr::Ident(name, _) = &**callee {
271                let float = ctx.b.type_float(32);
272                
273                // Vector constructors
274                match name.as_str() {
275                    "vec2" | "Vec2" if args.len() == 2 => {
276                        let vec2 = ctx.b.type_vector(float, 2);
277                        let mut components = vec![];
278                        for arg in args {
279                            let (val, _) = emit_expr(ctx, &arg.value)?;
280                            components.push(val);
281                        }
282                        let res_id = ctx.b.composite_construct(vec2, None, components).unwrap();
283                        return Ok((res_id, Type::Named { name: "Vec2".into(), generics: vec![], span: expr.span() }));
284                    },
285                    "vec3" | "Vec3" if args.len() == 3 => {
286                        let vec3 = ctx.b.type_vector(float, 3);
287                        let mut components = vec![];
288                        for arg in args {
289                            let (val, _) = emit_expr(ctx, &arg.value)?;
290                            components.push(val);
291                        }
292                        let res_id = ctx.b.composite_construct(vec3, None, components).unwrap();
293                        return Ok((res_id, Type::Named { name: "Vec3".into(), generics: vec![], span: expr.span() }));
294                    },
295                    "vec4" | "Vec4" if args.len() == 4 => {
296                        let vec4 = ctx.b.type_vector(float, 4);
297                        let mut components = vec![];
298                        for arg in args {
299                            let (val, _) = emit_expr(ctx, &arg.value)?;
300                            components.push(val);
301                        }
302                        let res_id = ctx.b.composite_construct(vec4, None, components).unwrap();
303                        return Ok((res_id, Type::Named { name: "Vec4".into(), generics: vec![], span: expr.span() }));
304                    },
305                    
306                    // Math functions (GLSL extended instructions)
307                    "sin" if args.len() == 1 => {
308                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
309                        let res_ty = map_ast_type(ctx.b, &ty);
310                        let glsl = ctx.get_glsl_ext();
311                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 13, vec![Operand::IdRef(val)]).unwrap(); // Sin = 13
312                        return Ok((res_id, ty));
313                    },
314                    "cos" if args.len() == 1 => {
315                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
316                        let res_ty = map_ast_type(ctx.b, &ty);
317                        let glsl = ctx.get_glsl_ext();
318                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 14, vec![Operand::IdRef(val)]).unwrap(); // Cos = 14
319                        return Ok((res_id, ty));
320                    },
321                    "tan" if args.len() == 1 => {
322                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
323                        let res_ty = map_ast_type(ctx.b, &ty);
324                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
325                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 15, vec![Operand::IdRef(val)]).unwrap(); // Tan = 15
326                        return Ok((res_id, ty));
327                    },
328                    "pow" if args.len() == 2 => {
329                        let (base, ty) = emit_expr(ctx, &args[0].value)?;
330                        let (exp, _) = emit_expr(ctx, &args[1].value)?;
331                        let res_ty = map_ast_type(ctx.b, &ty);
332                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
333                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 26, vec![Operand::IdRef(base), Operand::IdRef(exp)]).unwrap(); // Pow = 26
334                        return Ok((res_id, ty));
335                    },
336                    "sqrt" if args.len() == 1 => {
337                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
338                        let res_ty = map_ast_type(ctx.b, &ty);
339                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
340                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 31, vec![Operand::IdRef(val)]).unwrap(); // Sqrt = 31
341                        return Ok((res_id, ty));
342                    },
343                    "abs" if args.len() == 1 => {
344                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
345                        let res_ty = map_ast_type(ctx.b, &ty);
346                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
347                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 4, vec![Operand::IdRef(val)]).unwrap(); // FAbs = 4
348                        return Ok((res_id, ty));
349                    },
350                    "floor" if args.len() == 1 => {
351                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
352                        let res_ty = map_ast_type(ctx.b, &ty);
353                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
354                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 8, vec![Operand::IdRef(val)]).unwrap(); // Floor = 8
355                        return Ok((res_id, ty));
356                    },
357                    "ceil" if args.len() == 1 => {
358                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
359                        let res_ty = map_ast_type(ctx.b, &ty);
360                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
361                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 9, vec![Operand::IdRef(val)]).unwrap(); // Ceil = 9
362                        return Ok((res_id, ty));
363                    },
364                    "fract" if args.len() == 1 => {
365                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
366                        let res_ty = map_ast_type(ctx.b, &ty);
367                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
368                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 10, vec![Operand::IdRef(val)]).unwrap(); // Fract = 10
369                        return Ok((res_id, ty));
370                    },
371                    "min" if args.len() == 2 => {
372                        let (a, ty) = emit_expr(ctx, &args[0].value)?;
373                        let (b, _) = emit_expr(ctx, &args[1].value)?;
374                        let res_ty = map_ast_type(ctx.b, &ty);
375                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
376                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 37, vec![Operand::IdRef(a), Operand::IdRef(b)]).unwrap(); // FMin = 37
377                        return Ok((res_id, ty));
378                    },
379                    "max" if args.len() == 2 => {
380                        let (a, ty) = emit_expr(ctx, &args[0].value)?;
381                        let (b, _) = emit_expr(ctx, &args[1].value)?;
382                        let res_ty = map_ast_type(ctx.b, &ty);
383                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
384                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 40, vec![Operand::IdRef(a), Operand::IdRef(b)]).unwrap(); // FMax = 40
385                        return Ok((res_id, ty));
386                    },
387                    "clamp" if args.len() == 3 => {
388                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
389                        let (min_val, _) = emit_expr(ctx, &args[1].value)?;
390                        let (max_val, _) = emit_expr(ctx, &args[2].value)?;
391                        let res_ty = map_ast_type(ctx.b, &ty);
392                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
393                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 43, vec![Operand::IdRef(val), Operand::IdRef(min_val), Operand::IdRef(max_val)]).unwrap(); // FClamp = 43
394                        return Ok((res_id, ty));
395                    },
396                    "mix" if args.len() == 3 => {
397                        let (a, ty) = emit_expr(ctx, &args[0].value)?;
398                        let (b, _) = emit_expr(ctx, &args[1].value)?;
399                        let (t, _) = emit_expr(ctx, &args[2].value)?;
400                        let res_ty = map_ast_type(ctx.b, &ty);
401                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
402                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 46, vec![Operand::IdRef(a), Operand::IdRef(b), Operand::IdRef(t)]).unwrap(); // FMix = 46
403                        return Ok((res_id, ty));
404                    },
405                    "step" if args.len() == 2 => {
406                        let (edge, ty) = emit_expr(ctx, &args[0].value)?;
407                        let (x, _) = emit_expr(ctx, &args[1].value)?;
408                        let res_ty = map_ast_type(ctx.b, &ty);
409                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
410                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 48, vec![Operand::IdRef(edge), Operand::IdRef(x)]).unwrap(); // Step = 48
411                        return Ok((res_id, ty));
412                    },
413                    "smoothstep" if args.len() == 3 => {
414                        let (edge0, ty) = emit_expr(ctx, &args[0].value)?;
415                        let (edge1, _) = emit_expr(ctx, &args[1].value)?;
416                        let (x, _) = emit_expr(ctx, &args[2].value)?;
417                        let res_ty = map_ast_type(ctx.b, &ty);
418                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
419                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 49, vec![Operand::IdRef(edge0), Operand::IdRef(edge1), Operand::IdRef(x)]).unwrap(); // SmoothStep = 49
420                        return Ok((res_id, ty));
421                    },
422                    "length" if args.len() == 1 => {
423                        let (val, _) = emit_expr(ctx, &args[0].value)?;
424                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
425                        let res_id = ctx.b.ext_inst(float, None, glsl, 66, vec![Operand::IdRef(val)]).unwrap(); // Length = 66
426                        return Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: expr.span() }));
427                    },
428                    "normalize" if args.len() == 1 => {
429                        let (val, ty) = emit_expr(ctx, &args[0].value)?;
430                        let res_ty = map_ast_type(ctx.b, &ty);
431                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
432                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 69, vec![Operand::IdRef(val)]).unwrap(); // Normalize = 69
433                        return Ok((res_id, ty));
434                    },
435                    "dot" if args.len() == 2 => {
436                        let (a, _) = emit_expr(ctx, &args[0].value)?;
437                        let (b, _) = emit_expr(ctx, &args[1].value)?;
438                        let res_id = ctx.b.dot(float, None, a, b).unwrap();
439                        return Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: expr.span() }));
440                    },
441                    "cross" if args.len() == 2 => {
442                        let (a, ty) = emit_expr(ctx, &args[0].value)?;
443                        let (b, _) = emit_expr(ctx, &args[1].value)?;
444                        let res_ty = map_ast_type(ctx.b, &ty);
445                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
446                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 68, vec![Operand::IdRef(a), Operand::IdRef(b)]).unwrap(); // Cross = 68
447                        return Ok((res_id, ty));
448                    },
449                    "reflect" if args.len() == 2 => {
450                        let (i, ty) = emit_expr(ctx, &args[0].value)?;
451                        let (n, _) = emit_expr(ctx, &args[1].value)?;
452                        let res_ty = map_ast_type(ctx.b, &ty);
453                        let glsl = ctx.b.ext_inst_import("GLSL.std.450");
454                        let res_id = ctx.b.ext_inst(res_ty, None, glsl, 71, vec![Operand::IdRef(i), Operand::IdRef(n)]).unwrap(); // Reflect = 71
455                        return Ok((res_id, ty));
456                    },
457                    
458                    // Texture sampling
459                    "sample" if args.len() == 2 => {
460                        let (sampler, _) = emit_expr(ctx, &args[0].value)?;
461                        let (coords, _) = emit_expr(ctx, &args[1].value)?;
462                        let vec4 = ctx.b.type_vector(float, 4);
463                        let res_id = ctx.b.image_sample_implicit_lod(vec4, None, sampler, coords, None, std::iter::empty()).unwrap();
464                        return Ok((res_id, Type::Named { name: "Vec4".into(), generics: vec![], span: expr.span() }));
465                    },
466                    "sample_lod" if args.len() == 3 => {
467                        let (sampler, _) = emit_expr(ctx, &args[0].value)?;
468                        let (coords, _) = emit_expr(ctx, &args[1].value)?;
469                        let (lod, _) = emit_expr(ctx, &args[2].value)?;
470                        let vec4 = ctx.b.type_vector(float, 4);
471                        let res_id = ctx.b.image_sample_explicit_lod(vec4, None, sampler, coords, rspirv::spirv::ImageOperands::LOD, vec![Operand::IdRef(lod)]).unwrap();
472                        return Ok((res_id, Type::Named { name: "Vec4".into(), generics: vec![], span: expr.span() }));
473                    },
474                    
475                    _ => {}
476                }
477            }
478            Err(KoreError::codegen(format!("Unsupported function call in shader: {:?}", callee), expr.span()))
479        },
480        Expr::Float(f, span) => {
481            let float = ctx.b.type_float(32);
482            let val = ctx.b.constant_bit32(float, (*f as f32).to_bits());
483            Ok((val, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
484        },
485        Expr::Field { object, field, span } => {
486            let (obj_id, _obj_ty) = emit_expr(ctx, object)?;
487            
488            // Swizzle/component access
489            let float = ctx.b.type_float(32);
490            match field.as_str() {
491                // Single component access
492                "x" | "r" => {
493                    let res_id = ctx.b.composite_extract(float, None, obj_id, vec![0]).unwrap();
494                    Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
495                },
496                "y" | "g" => {
497                    let res_id = ctx.b.composite_extract(float, None, obj_id, vec![1]).unwrap();
498                    Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
499                },
500                "z" | "b" => {
501                    let res_id = ctx.b.composite_extract(float, None, obj_id, vec![2]).unwrap();
502                    Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
503                },
504                "w" | "a" => {
505                    let res_id = ctx.b.composite_extract(float, None, obj_id, vec![3]).unwrap();
506                    Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
507                },
508                // Vec2 swizzles
509                "xy" | "rg" => {
510                    let vec2 = ctx.b.type_vector(float, 2);
511                    let res_id = ctx.b.vector_shuffle(vec2, None, obj_id, obj_id, vec![0, 1]).unwrap();
512                    Ok((res_id, Type::Named { name: "Vec2".into(), generics: vec![], span: *span }))
513                },
514                "xz" | "rb" => {
515                    let vec2 = ctx.b.type_vector(float, 2);
516                    let res_id = ctx.b.vector_shuffle(vec2, None, obj_id, obj_id, vec![0, 2]).unwrap();
517                    Ok((res_id, Type::Named { name: "Vec2".into(), generics: vec![], span: *span }))
518                },
519                "yz" | "gb" => {
520                    let vec2 = ctx.b.type_vector(float, 2);
521                    let res_id = ctx.b.vector_shuffle(vec2, None, obj_id, obj_id, vec![1, 2]).unwrap();
522                    Ok((res_id, Type::Named { name: "Vec2".into(), generics: vec![], span: *span }))
523                },
524                // Vec3 swizzles
525                "xyz" | "rgb" => {
526                    let vec3 = ctx.b.type_vector(float, 3);
527                    let res_id = ctx.b.vector_shuffle(vec3, None, obj_id, obj_id, vec![0, 1, 2]).unwrap();
528                    Ok((res_id, Type::Named { name: "Vec3".into(), generics: vec![], span: *span }))
529                },
530                _ => Err(KoreError::codegen(format!("Unsupported field access: {}", field), *span))
531            }
532        },
533        _ => Err(KoreError::codegen("Unsupported expression in shader", expr.span())),
534    }
535}
536
537fn map_ast_type(b: &mut Builder, ty: &Type) -> u32 {
538    let float = b.type_float(32);
539    match ty {
540        Type::Named { name, .. } => match name.as_str() {
541            "Float" | "f32" => float,
542            "Int" | "i32" => b.type_int(32, 1),
543            "Bool" => b.type_bool(),
544            "Vec2" => b.type_vector(float, 2),
545            "Vec3" => b.type_vector(float, 3),
546            "Vec4" => b.type_vector(float, 4),
547            "Mat4" => {
548                let v4 = b.type_vector(float, 4);
549                b.type_matrix(v4, 4)
550            },
551            "Sampler2D" => {
552                // Dim2D, NotDepth, Arrayed=False, MS=False, Sampled=1, Format=Unknown
553                let image = b.type_image(float, rspirv::spirv::Dim::Dim2D, 0, 0, 0, 1, rspirv::spirv::ImageFormat::Unknown, None);
554                b.type_sampled_image(image)
555            },
556            "StorageBuffer" => {
557                // Struct wrapper needed for buffer block
558                // Simplified: just array of floats for now
559                let rt_array = b.type_runtime_array(float);
560                let struct_ty = b.type_struct(vec![rt_array]);
561                b.decorate(struct_ty, Decoration::Block, vec![]);
562                struct_ty
563            },
564            "Void" => b.type_void(),
565            _ => b.type_void(),
566        },
567        _ => b.type_void(),
568    }
569}
570
571fn is_void(ty: &Type) -> bool {
572    matches!(ty, Type::Named { name, .. } if name == "Void")
573}
574
575fn is_vec4(ty: &Type) -> bool {
576    matches!(ty, Type::Named { name, .. } if name == "Vec4")
577}
578
579fn is_mat4(ty: &Type) -> bool {
580    matches!(ty, Type::Named { name, .. } if name == "Mat4")
581}
582
583fn is_float(ty: &Type) -> bool {
584    matches!(ty, Type::Named { name, .. } if name == "Float" || name == "f32")
585}
586