1use crate::types::{TypedProgram, TypedItem, TypedShader};
4use crate::error::{KoreResult, KoreError};
5use crate::ast::{Type, ShaderStage, Expr, Stmt, Block, BinaryOp};
6use rspirv::binary::Assemble;
7use rspirv::dr::{Builder, Operand};
8use rspirv::spirv::{Capability, AddressingModel, MemoryModel, ExecutionModel, ExecutionMode, StorageClass, Decoration};
9use std::collections::HashMap;
10
11pub fn generate(program: &TypedProgram) -> KoreResult<Vec<u8>> {
12 let mut builder = Builder::new();
13
14 builder.capability(Capability::Shader);
16 builder.memory_model(AddressingModel::Logical, MemoryModel::GLSL450);
18
19 for item in &program.items {
20 if let TypedItem::Shader(shader) = item {
21 emit_shader(&mut builder, shader)?;
22 }
23 }
24
25 let module = builder.module();
26 let bytes: Vec<u8> = module.assemble().iter().flat_map(|w| w.to_le_bytes()).collect();
27 Ok(bytes)
28}
29
30struct ShaderContext<'a> {
31 b: &'a mut Builder,
32 vars: HashMap<String, (u32, Type, bool)>,
34 output_var: Option<u32>,
35 struct_uniforms: std::collections::HashSet<String>,
37}
38
39fn emit_shader(b: &mut Builder, shader: &TypedShader) -> KoreResult<()> {
40 let exec_model = match shader.ast.stage {
41 ShaderStage::Vertex => ExecutionModel::Vertex,
42 ShaderStage::Fragment => ExecutionModel::Fragment,
43 ShaderStage::Compute => ExecutionModel::GLCompute,
44 };
45
46 let void = b.type_void();
48
49 let fn_void_void = b.type_function(void, vec![]);
51
52 let mut interface_vars = vec![];
54 let mut ctx_vars = HashMap::new();
55 let mut struct_uniforms = std::collections::HashSet::new();
56
57 for (i, param) in shader.ast.inputs.iter().enumerate() {
59 let ty = map_ast_type(b, ¶m.ty);
60 let ptr_ty = b.type_pointer(None, StorageClass::Input, ty);
61 let var = b.variable(ptr_ty, None, StorageClass::Input, None);
62 b.decorate(var, Decoration::Location, vec![Operand::LiteralBit32(i as u32)]);
63 interface_vars.push(var);
64 ctx_vars.insert(param.name.clone(), (var, param.ty.clone(), true));
65 }
66
67 let output_var = if !is_void(&shader.ast.outputs) {
69 let output_ty = map_ast_type(b, &shader.ast.outputs);
70 let ptr_ty = b.type_pointer(None, StorageClass::Output, output_ty);
71 let var = b.variable(ptr_ty, None, StorageClass::Output, None);
72
73 if exec_model == ExecutionModel::Vertex && is_vec4(&shader.ast.outputs) {
75 b.decorate(var, Decoration::BuiltIn, vec![Operand::BuiltIn(rspirv::spirv::BuiltIn::Position)]);
76 } else {
77 b.decorate(var, Decoration::Location, vec![Operand::LiteralBit32(0)]);
78 }
79
80 interface_vars.push(var);
81 Some(var)
82 } else {
83 None
84 };
85
86 for uniform in &shader.ast.uniforms {
88 let inner_ty = map_ast_type(b, &uniform.ty);
89
90 let is_sampler = matches!(&uniform.ty, Type::Named { name, .. } if name == "Sampler2D");
92
93 if is_sampler {
94 let ptr_ty = b.type_pointer(None, StorageClass::UniformConstant, inner_ty);
96 let var = b.variable(ptr_ty, None, StorageClass::UniformConstant, None);
97 b.decorate(var, Decoration::DescriptorSet, vec![Operand::LiteralBit32(0)]);
98 b.decorate(var, Decoration::Binding, vec![Operand::LiteralBit32(uniform.binding)]);
99 ctx_vars.insert(uniform.name.clone(), (var, uniform.ty.clone(), true));
100 } else {
101 let struct_ty = b.type_struct(vec![inner_ty]);
103 b.decorate(struct_ty, Decoration::Block, vec![]);
104 b.member_decorate(struct_ty, 0, Decoration::Offset, vec![Operand::LiteralBit32(0)]);
106
107 if matches!(&uniform.ty, Type::Named { name, .. } if name == "Mat4") {
109 b.member_decorate(struct_ty, 0, Decoration::ColMajor, vec![]);
110 b.member_decorate(struct_ty, 0, Decoration::MatrixStride, vec![Operand::LiteralBit32(16)]);
111 }
112
113 let ptr_ty = b.type_pointer(None, StorageClass::Uniform, struct_ty);
114 let var = b.variable(ptr_ty, None, StorageClass::Uniform, None);
115 b.decorate(var, Decoration::DescriptorSet, vec![Operand::LiteralBit32(0)]);
116 b.decorate(var, Decoration::Binding, vec![Operand::LiteralBit32(uniform.binding)]);
117 ctx_vars.insert(uniform.name.clone(), (var, uniform.ty.clone(), true));
118 struct_uniforms.insert(uniform.name.clone());
119 }
120 }
121
122 let main_fn = b.begin_function(void, None, rspirv::spirv::FunctionControl::NONE, fn_void_void).unwrap();
124 b.begin_block(None).unwrap();
125
126 let mut ctx = ShaderContext {
127 b,
128 vars: ctx_vars,
129 output_var,
130 struct_uniforms,
131 };
132
133 emit_block(&mut ctx, &shader.ast.body)?;
134
135 if shader.ast.body.stmts.last().map_or(true, |s| !matches!(s, Stmt::Return(_, _))) {
137 ctx.b.ret().unwrap();
138 }
139
140 ctx.b.end_function().unwrap();
141
142 b.entry_point(exec_model, main_fn, &shader.ast.name, interface_vars);
144
145 if exec_model == ExecutionModel::Fragment {
146 b.execution_mode(main_fn, ExecutionMode::OriginUpperLeft, vec![]);
147 }
148
149 Ok(())
150}
151
152fn emit_block(ctx: &mut ShaderContext, block: &Block) -> KoreResult<()> {
153 for stmt in &block.stmts {
154 match stmt {
155 Stmt::Return(expr, _) => {
156 if let Some(expr) = expr {
157 if let Some(out_var) = ctx.output_var {
158 let (val, _) = emit_expr(ctx, expr)?;
159 ctx.b.store(out_var, val, None, vec![]).unwrap();
160 }
161 }
162 ctx.b.ret().unwrap();
163 },
164 Stmt::Let { pattern, value, .. } => {
165 if let Some(value) = value {
166 let (val, ty) = emit_expr(ctx, value)?;
167 if let crate::ast::Pattern::Binding { name, .. } = pattern {
169 ctx.vars.insert(name.clone(), (val, ty, false));
172 }
173 }
174 },
175 Stmt::Expr(expr) => {
176 emit_expr(ctx, expr)?;
177 },
178 _ => {} }
180 }
181 Ok(())
182}
183
184fn emit_expr(ctx: &mut ShaderContext, expr: &Expr) -> KoreResult<(u32, Type)> {
185 match expr {
186 Expr::Ident(name, span) => {
187 if let Some((id, ty, is_ptr)) = ctx.vars.get(name).cloned() {
188 if is_ptr {
189 let type_id = map_ast_type(ctx.b, &ty);
191
192 if ctx.struct_uniforms.contains(name) {
194 let ptr_ty = ctx.b.type_pointer(None, StorageClass::Uniform, type_id);
196 let int_ty = ctx.b.type_int(32, 0);
197 let zero = ctx.b.constant_bit32(int_ty, 0);
198 let member_ptr = ctx.b.access_chain(ptr_ty, None, id, vec![zero]).unwrap();
199 let val_id = ctx.b.load(type_id, None, member_ptr, None, std::iter::empty()).unwrap();
200 Ok((val_id, ty))
201 } else {
202 let val_id = ctx.b.load(type_id, None, id, None, std::iter::empty()).unwrap();
204 Ok((val_id, ty))
205 }
206 } else {
207 Ok((id, ty))
208 }
209 } else {
210 Err(KoreError::codegen(format!("Unknown variable: {}", name), *span))
211 }
212 },
213 Expr::Binary { left, op, right, .. } => {
214 let (lhs, lhs_ty) = emit_expr(ctx, left)?;
215 let (rhs, rhs_ty) = emit_expr(ctx, right)?;
216
217 let res_ty_id = map_ast_type(ctx.b, &lhs_ty); let res_id = match op {
221 BinaryOp::Mul => {
222 if is_mat4(&lhs_ty) && is_mat4(&rhs_ty) {
223 ctx.b.matrix_times_matrix(res_ty_id, None, lhs, rhs).unwrap()
224 } else if is_mat4(&lhs_ty) && is_vec4(&rhs_ty) {
225 let vec4_ty = map_ast_type(ctx.b, &rhs_ty);
227 ctx.b.matrix_times_vector(vec4_ty, None, lhs, rhs).unwrap()
228 } else if is_vec4(&lhs_ty) && is_mat4(&rhs_ty) {
229 let vec4_ty = map_ast_type(ctx.b, &lhs_ty);
231 ctx.b.vector_times_matrix(vec4_ty, None, lhs, rhs).unwrap()
232 } else if is_float(&lhs_ty) && is_float(&rhs_ty) {
233 ctx.b.f_mul(res_ty_id, None, lhs, rhs).unwrap()
234 } else {
235 ctx.b.f_mul(res_ty_id, None, lhs, rhs).unwrap()
237 }
238 },
239 BinaryOp::Add => ctx.b.f_add(res_ty_id, None, lhs, rhs).unwrap(),
240 BinaryOp::Sub => ctx.b.f_sub(res_ty_id, None, lhs, rhs).unwrap(),
241 BinaryOp::Div => ctx.b.f_div(res_ty_id, None, lhs, rhs).unwrap(),
242 _ => return Err(KoreError::codegen("Unsupported binary op in shader", expr.span())),
243 };
244
245 let res_ty = if is_mat4(&lhs_ty) && is_vec4(&rhs_ty) {
247 rhs_ty
248 } else {
249 lhs_ty
250 };
251
252 Ok((res_id, res_ty))
253 },
254 Expr::Call { callee, args, .. } => {
255 if let Expr::Ident(name, _) = &**callee {
256 if name == "Vec4" && args.len() == 4 {
257 let float = ctx.b.type_float(32);
259 let vec4 = ctx.b.type_vector(float, 4);
260 let mut components = vec![];
261 for arg in args {
262 let (val, _) = emit_expr(ctx, &arg.value)?;
263 components.push(val);
264 }
265 let res_id = ctx.b.composite_construct(vec4, None, components).unwrap();
266 return Ok((res_id, Type::Named { name: "Vec4".into(), generics: vec![], span: expr.span() }));
267 }
268 }
269 Err(KoreError::codegen("Unsupported function call in shader", expr.span()))
270 },
271 Expr::Float(f, span) => {
272 let float = ctx.b.type_float(32);
273 let val = ctx.b.constant_bit32(float, (*f as f32).to_bits());
274 Ok((val, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
275 },
276 _ => Err(KoreError::codegen("Unsupported expression in shader", expr.span())),
277 }
278}
279
280fn map_ast_type(b: &mut Builder, ty: &Type) -> u32 {
281 let float = b.type_float(32);
282 match ty {
283 Type::Named { name, .. } => match name.as_str() {
284 "Float" | "f32" => float,
285 "Int" | "i32" => b.type_int(32, 1),
286 "Bool" => b.type_bool(),
287 "Vec2" => b.type_vector(float, 2),
288 "Vec3" => b.type_vector(float, 3),
289 "Vec4" => b.type_vector(float, 4),
290 "Mat4" => {
291 let v4 = b.type_vector(float, 4);
292 b.type_matrix(v4, 4)
293 },
294 "Sampler2D" => {
295 let image = b.type_image(float, rspirv::spirv::Dim::Dim2D, 0, 0, 0, 1, rspirv::spirv::ImageFormat::Unknown, None);
297 b.type_sampled_image(image)
298 },
299 "StorageBuffer" => {
300 let rt_array = b.type_runtime_array(float);
303 let struct_ty = b.type_struct(vec![rt_array]);
304 b.decorate(struct_ty, Decoration::Block, vec![]);
305 struct_ty
306 },
307 "Void" => b.type_void(),
308 _ => b.type_void(),
309 },
310 _ => b.type_void(),
311 }
312}
313
314fn is_void(ty: &Type) -> bool {
315 matches!(ty, Type::Named { name, .. } if name == "Void")
316}
317
318fn is_vec4(ty: &Type) -> bool {
319 matches!(ty, Type::Named { name, .. } if name == "Vec4")
320}
321
322fn is_mat4(ty: &Type) -> bool {
323 matches!(ty, Type::Named { name, .. } if name == "Mat4")
324}
325
326fn is_float(ty: &Type) -> bool {
327 matches!(ty, Type::Named { name, .. } if name == "Float" || name == "f32")
328}
329