Skip to main content

kore/codegen/
spirv.rs

1//! SPIR-V Code Generation for GPU shaders
2
3use crate::types::{TypedProgram, TypedItem, TypedShader};
4use crate::error::{KoreResult, KoreError};
5use crate::ast::{Type, ShaderStage, Expr, Stmt, Block, BinaryOp};
6use rspirv::binary::Assemble;
7use rspirv::dr::{Builder, Operand};
8use rspirv::spirv::{Capability, AddressingModel, MemoryModel, ExecutionModel, ExecutionMode, StorageClass, Decoration};
9use std::collections::HashMap;
10
11pub fn generate(program: &TypedProgram) -> KoreResult<Vec<u8>> {
12    let mut builder = Builder::new();
13    
14    // Set capabilities and memory model
15    builder.capability(Capability::Shader);
16    // Add VulkanMemoryModel if targeting Vulkan, but GLSL450 is standard for now
17    builder.memory_model(AddressingModel::Logical, MemoryModel::GLSL450);
18    
19    for item in &program.items {
20        if let TypedItem::Shader(shader) = item {
21            emit_shader(&mut builder, shader)?;
22        }
23    }
24    
25    let module = builder.module();
26    let bytes: Vec<u8> = module.assemble().iter().flat_map(|w| w.to_le_bytes()).collect();
27    Ok(bytes)
28}
29
30struct ShaderContext<'a> {
31    b: &'a mut Builder,
32    // Name -> (SPIR-V ID, AST Type, IsPointer)
33    vars: HashMap<String, (u32, Type, bool)>,
34    output_var: Option<u32>,
35    // Track which variables are struct-wrapped uniforms (need AccessChain)
36    struct_uniforms: std::collections::HashSet<String>,
37}
38
39fn emit_shader(b: &mut Builder, shader: &TypedShader) -> KoreResult<()> {
40    let exec_model = match shader.ast.stage {
41        ShaderStage::Vertex => ExecutionModel::Vertex,
42        ShaderStage::Fragment => ExecutionModel::Fragment,
43        ShaderStage::Compute => ExecutionModel::GLCompute,
44    };
45    
46    // 1. Define Basic Types
47    let void = b.type_void();
48    
49    // 2. Define Entry Point Function Type
50    let fn_void_void = b.type_function(void, vec![]);
51    
52    // 3. Declare Variables (Global Interface)
53    let mut interface_vars = vec![];
54    let mut ctx_vars = HashMap::new();
55    let mut struct_uniforms = std::collections::HashSet::new();
56
57    // Inputs
58    for (i, param) in shader.ast.inputs.iter().enumerate() {
59        let ty = map_ast_type(b, &param.ty);
60        let ptr_ty = b.type_pointer(None, StorageClass::Input, ty);
61        let var = b.variable(ptr_ty, None, StorageClass::Input, None);
62        b.decorate(var, Decoration::Location, vec![Operand::LiteralBit32(i as u32)]);
63        interface_vars.push(var);
64        ctx_vars.insert(param.name.clone(), (var, param.ty.clone(), true));
65    }
66
67    // Outputs
68    let output_var = if !is_void(&shader.ast.outputs) {
69         let output_ty = map_ast_type(b, &shader.ast.outputs);
70         let ptr_ty = b.type_pointer(None, StorageClass::Output, output_ty);
71         let var = b.variable(ptr_ty, None, StorageClass::Output, None);
72         
73         // Vertex shader output is @builtin(position) for Vec4, otherwise use Location
74         if exec_model == ExecutionModel::Vertex && is_vec4(&shader.ast.outputs) {
75             b.decorate(var, Decoration::BuiltIn, vec![Operand::BuiltIn(rspirv::spirv::BuiltIn::Position)]);
76         } else {
77             b.decorate(var, Decoration::Location, vec![Operand::LiteralBit32(0)]);
78         }
79         
80         interface_vars.push(var);
81         Some(var)
82    } else {
83        None
84    };
85
86    // Uniforms
87    for uniform in &shader.ast.uniforms {
88        let inner_ty = map_ast_type(b, &uniform.ty);
89        
90        // Check if this is a sampler type (uses UniformConstant) or data type (uses Uniform with struct)
91        let is_sampler = matches!(&uniform.ty, Type::Named { name, .. } if name == "Sampler2D");
92        
93        if is_sampler {
94            // Samplers use UniformConstant storage class directly
95            let ptr_ty = b.type_pointer(None, StorageClass::UniformConstant, inner_ty);
96            let var = b.variable(ptr_ty, None, StorageClass::UniformConstant, None);
97            b.decorate(var, Decoration::DescriptorSet, vec![Operand::LiteralBit32(0)]);
98            b.decorate(var, Decoration::Binding, vec![Operand::LiteralBit32(uniform.binding)]);
99            ctx_vars.insert(uniform.name.clone(), (var, uniform.ty.clone(), true));
100        } else {
101            // Data uniforms (matrices, vectors, etc.) need a struct wrapper with Block decoration
102            let struct_ty = b.type_struct(vec![inner_ty]);
103            b.decorate(struct_ty, Decoration::Block, vec![]);
104            // Offset decoration for the first (and only) member
105            b.member_decorate(struct_ty, 0, Decoration::Offset, vec![Operand::LiteralBit32(0)]);
106            
107            // For matrices, we need ColMajor and MatrixStride decorations
108            if matches!(&uniform.ty, Type::Named { name, .. } if name == "Mat4") {
109                b.member_decorate(struct_ty, 0, Decoration::ColMajor, vec![]);
110                b.member_decorate(struct_ty, 0, Decoration::MatrixStride, vec![Operand::LiteralBit32(16)]);
111            }
112            
113            let ptr_ty = b.type_pointer(None, StorageClass::Uniform, struct_ty);
114            let var = b.variable(ptr_ty, None, StorageClass::Uniform, None);
115            b.decorate(var, Decoration::DescriptorSet, vec![Operand::LiteralBit32(0)]);
116            b.decorate(var, Decoration::Binding, vec![Operand::LiteralBit32(uniform.binding)]);
117            ctx_vars.insert(uniform.name.clone(), (var, uniform.ty.clone(), true));
118            struct_uniforms.insert(uniform.name.clone());
119        }
120    }
121
122    // 4. Function Body
123    let main_fn = b.begin_function(void, None, rspirv::spirv::FunctionControl::NONE, fn_void_void).unwrap();
124    b.begin_block(None).unwrap();
125
126    let mut ctx = ShaderContext {
127        b,
128        vars: ctx_vars,
129        output_var,
130        struct_uniforms,
131    };
132
133    emit_block(&mut ctx, &shader.ast.body)?;
134
135    // Ensure we always have a return
136    if shader.ast.body.stmts.last().map_or(true, |s| !matches!(s, Stmt::Return(_, _))) {
137        ctx.b.ret().unwrap();
138    }
139    
140    ctx.b.end_function().unwrap();
141
142    // 5. Entry Point
143    b.entry_point(exec_model, main_fn, &shader.ast.name, interface_vars);
144    
145    if exec_model == ExecutionModel::Fragment {
146        b.execution_mode(main_fn, ExecutionMode::OriginUpperLeft, vec![]);
147    }
148    
149    Ok(())
150}
151
152fn emit_block(ctx: &mut ShaderContext, block: &Block) -> KoreResult<()> {
153    for stmt in &block.stmts {
154        match stmt {
155            Stmt::Return(expr, _) => {
156                if let Some(expr) = expr {
157                    if let Some(out_var) = ctx.output_var {
158                        let (val, _) = emit_expr(ctx, expr)?;
159                        ctx.b.store(out_var, val, None, vec![]).unwrap();
160                    }
161                }
162                ctx.b.ret().unwrap();
163            },
164            Stmt::Let { pattern, value, .. } => {
165                if let Some(value) = value {
166                    let (val, ty) = emit_expr(ctx, value)?;
167                    // For now, only simple bindings
168                    if let crate::ast::Pattern::Binding { name, .. } = pattern {
169                        // In SSA, we just map name -> value ID
170                        // We don't support mutation of locals yet (need OpVariable + Store/Load)
171                        ctx.vars.insert(name.clone(), (val, ty, false));
172                    }
173                }
174            },
175            Stmt::Expr(expr) => {
176                emit_expr(ctx, expr)?;
177            },
178            _ => {} // Ignore others for now
179        }
180    }
181    Ok(())
182}
183
184fn emit_expr(ctx: &mut ShaderContext, expr: &Expr) -> KoreResult<(u32, Type)> {
185    match expr {
186        Expr::Ident(name, span) => {
187            if let Some((id, ty, is_ptr)) = ctx.vars.get(name).cloned() {
188                if is_ptr {
189                    // Need to load from pointer
190                    let type_id = map_ast_type(ctx.b, &ty);
191                    
192                    // Check if this is a struct-wrapped uniform
193                    if ctx.struct_uniforms.contains(name) {
194                        // Use AccessChain to get pointer to member 0 of the struct
195                        let ptr_ty = ctx.b.type_pointer(None, StorageClass::Uniform, type_id);
196                        let int_ty = ctx.b.type_int(32, 0);
197                        let zero = ctx.b.constant_bit32(int_ty, 0);
198                        let member_ptr = ctx.b.access_chain(ptr_ty, None, id, vec![zero]).unwrap();
199                        let val_id = ctx.b.load(type_id, None, member_ptr, None, std::iter::empty()).unwrap();
200                        Ok((val_id, ty))
201                    } else {
202                        // Direct load for inputs and non-wrapped uniforms
203                        let val_id = ctx.b.load(type_id, None, id, None, std::iter::empty()).unwrap();
204                        Ok((val_id, ty))
205                    }
206                } else {
207                    Ok((id, ty))
208                }
209            } else {
210                 Err(KoreError::codegen(format!("Unknown variable: {}", name), *span))
211            }
212        },
213        Expr::Binary { left, op, right, .. } => {
214            let (lhs, lhs_ty) = emit_expr(ctx, left)?;
215            let (rhs, rhs_ty) = emit_expr(ctx, right)?;
216            
217            // Map types to SPIR-V types
218            let res_ty_id = map_ast_type(ctx.b, &lhs_ty); // Assume result type matches lhs for now
219            
220            let res_id = match op {
221                BinaryOp::Mul => {
222                    if is_mat4(&lhs_ty) && is_mat4(&rhs_ty) {
223                        ctx.b.matrix_times_matrix(res_ty_id, None, lhs, rhs).unwrap()
224                    } else if is_mat4(&lhs_ty) && is_vec4(&rhs_ty) {
225                        // Mat4 * Vec4 -> Vec4
226                         let vec4_ty = map_ast_type(ctx.b, &rhs_ty);
227                         ctx.b.matrix_times_vector(vec4_ty, None, lhs, rhs).unwrap()
228                    } else if is_vec4(&lhs_ty) && is_mat4(&rhs_ty) {
229                        // Vec4 * Mat4 -> Vec4
230                         let vec4_ty = map_ast_type(ctx.b, &lhs_ty);
231                         ctx.b.vector_times_matrix(vec4_ty, None, lhs, rhs).unwrap()
232                    } else if is_float(&lhs_ty) && is_float(&rhs_ty) {
233                        ctx.b.f_mul(res_ty_id, None, lhs, rhs).unwrap()
234                    } else {
235                         // Fallback to FMul (vector * scalar, etc - simplified)
236                        ctx.b.f_mul(res_ty_id, None, lhs, rhs).unwrap()
237                    }
238                },
239                BinaryOp::Add => ctx.b.f_add(res_ty_id, None, lhs, rhs).unwrap(),
240                BinaryOp::Sub => ctx.b.f_sub(res_ty_id, None, lhs, rhs).unwrap(),
241                BinaryOp::Div => ctx.b.f_div(res_ty_id, None, lhs, rhs).unwrap(),
242                _ => return Err(KoreError::codegen("Unsupported binary op in shader", expr.span())),
243            };
244            
245            // Result type inference (simplified)
246            let res_ty = if is_mat4(&lhs_ty) && is_vec4(&rhs_ty) {
247                rhs_ty
248            } else {
249                lhs_ty
250            };
251            
252            Ok((res_id, res_ty))
253        },
254        Expr::Call { callee, args, .. } => {
255            if let Expr::Ident(name, _) = &**callee {
256                if name == "Vec4" && args.len() == 4 {
257                    // Constructor
258                    let float = ctx.b.type_float(32);
259                    let vec4 = ctx.b.type_vector(float, 4);
260                    let mut components = vec![];
261                    for arg in args {
262                        let (val, _) = emit_expr(ctx, &arg.value)?;
263                        components.push(val);
264                    }
265                    let res_id = ctx.b.composite_construct(vec4, None, components).unwrap();
266                    return Ok((res_id, Type::Named { name: "Vec4".into(), generics: vec![], span: expr.span() }));
267                }
268            }
269            Err(KoreError::codegen("Unsupported function call in shader", expr.span()))
270        },
271        Expr::Float(f, span) => {
272            let float = ctx.b.type_float(32);
273            let val = ctx.b.constant_bit32(float, (*f as f32).to_bits());
274            Ok((val, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
275        },
276        _ => Err(KoreError::codegen("Unsupported expression in shader", expr.span())),
277    }
278}
279
280fn map_ast_type(b: &mut Builder, ty: &Type) -> u32 {
281    let float = b.type_float(32);
282    match ty {
283        Type::Named { name, .. } => match name.as_str() {
284            "Float" | "f32" => float,
285            "Int" | "i32" => b.type_int(32, 1),
286            "Bool" => b.type_bool(),
287            "Vec2" => b.type_vector(float, 2),
288            "Vec3" => b.type_vector(float, 3),
289            "Vec4" => b.type_vector(float, 4),
290            "Mat4" => {
291                let v4 = b.type_vector(float, 4);
292                b.type_matrix(v4, 4)
293            },
294            "Sampler2D" => {
295                // Dim2D, NotDepth, Arrayed=False, MS=False, Sampled=1, Format=Unknown
296                let image = b.type_image(float, rspirv::spirv::Dim::Dim2D, 0, 0, 0, 1, rspirv::spirv::ImageFormat::Unknown, None);
297                b.type_sampled_image(image)
298            },
299            "StorageBuffer" => {
300                // Struct wrapper needed for buffer block
301                // Simplified: just array of floats for now
302                let rt_array = b.type_runtime_array(float);
303                let struct_ty = b.type_struct(vec![rt_array]);
304                b.decorate(struct_ty, Decoration::Block, vec![]);
305                struct_ty
306            },
307            "Void" => b.type_void(),
308            _ => b.type_void(),
309        },
310        _ => b.type_void(),
311    }
312}
313
314fn is_void(ty: &Type) -> bool {
315    matches!(ty, Type::Named { name, .. } if name == "Void")
316}
317
318fn is_vec4(ty: &Type) -> bool {
319    matches!(ty, Type::Named { name, .. } if name == "Vec4")
320}
321
322fn is_mat4(ty: &Type) -> bool {
323    matches!(ty, Type::Named { name, .. } if name == "Mat4")
324}
325
326fn is_float(ty: &Type) -> bool {
327    matches!(ty, Type::Named { name, .. } if name == "Float" || name == "f32")
328}
329