1use crate::types::{TypedProgram, TypedItem, TypedShader};
4use crate::error::{KoreResult, KoreError};
5use crate::ast::{Type, ShaderStage, Expr, Stmt, Block, BinaryOp};
6use rspirv::binary::Assemble;
7use rspirv::dr::{Builder, Operand};
8use rspirv::spirv::{Capability, AddressingModel, MemoryModel, ExecutionModel, ExecutionMode, StorageClass, Decoration};
9use std::collections::HashMap;
10
11pub fn generate(program: &TypedProgram) -> KoreResult<Vec<u8>> {
12 let mut builder = Builder::new();
13
14 builder.capability(Capability::Shader);
16 builder.memory_model(AddressingModel::Logical, MemoryModel::GLSL450);
18
19 for item in &program.items {
20 if let TypedItem::Shader(shader) = item {
21 emit_shader(&mut builder, shader)?;
22 }
23 }
24
25 let module = builder.module();
26 let bytes: Vec<u8> = module.assemble().iter().flat_map(|w| w.to_le_bytes()).collect();
27 Ok(bytes)
28}
29
30struct ShaderContext<'a> {
31 b: &'a mut Builder,
32 vars: HashMap<String, (u32, Type, bool)>,
34 output_var: Option<u32>,
35 struct_uniforms: std::collections::HashSet<String>,
37 glsl_ext: Option<u32>,
39}
40
41fn emit_shader(b: &mut Builder, shader: &TypedShader) -> KoreResult<()> {
42 let exec_model = match shader.ast.stage {
43 ShaderStage::Vertex => ExecutionModel::Vertex,
44 ShaderStage::Fragment => ExecutionModel::Fragment,
45 ShaderStage::Compute => ExecutionModel::GLCompute,
46 };
47
48 let void = b.type_void();
50
51 let fn_void_void = b.type_function(void, vec![]);
53
54 let mut interface_vars = vec![];
56 let mut ctx_vars = HashMap::new();
57 let mut struct_uniforms = std::collections::HashSet::new();
58
59 for (i, param) in shader.ast.inputs.iter().enumerate() {
61 let ty = map_ast_type(b, ¶m.ty);
62 let ptr_ty = b.type_pointer(None, StorageClass::Input, ty);
63 let var = b.variable(ptr_ty, None, StorageClass::Input, None);
64 b.decorate(var, Decoration::Location, vec![Operand::LiteralBit32(i as u32)]);
65 interface_vars.push(var);
66 ctx_vars.insert(param.name.clone(), (var, param.ty.clone(), true));
67 }
68
69 let output_var = if !is_void(&shader.ast.outputs) {
71 let output_ty = map_ast_type(b, &shader.ast.outputs);
72 let ptr_ty = b.type_pointer(None, StorageClass::Output, output_ty);
73 let var = b.variable(ptr_ty, None, StorageClass::Output, None);
74
75 if exec_model == ExecutionModel::Vertex && is_vec4(&shader.ast.outputs) {
77 b.decorate(var, Decoration::BuiltIn, vec![Operand::BuiltIn(rspirv::spirv::BuiltIn::Position)]);
78 } else {
79 b.decorate(var, Decoration::Location, vec![Operand::LiteralBit32(0)]);
80 }
81
82 interface_vars.push(var);
83 Some(var)
84 } else {
85 None
86 };
87
88 for uniform in &shader.ast.uniforms {
90 let inner_ty = map_ast_type(b, &uniform.ty);
91
92 let is_sampler = matches!(&uniform.ty, Type::Named { name, .. } if name == "Sampler2D");
94
95 if is_sampler {
96 let ptr_ty = b.type_pointer(None, StorageClass::UniformConstant, inner_ty);
98 let var = b.variable(ptr_ty, None, StorageClass::UniformConstant, None);
99 b.decorate(var, Decoration::DescriptorSet, vec![Operand::LiteralBit32(0)]);
100 b.decorate(var, Decoration::Binding, vec![Operand::LiteralBit32(uniform.binding)]);
101 ctx_vars.insert(uniform.name.clone(), (var, uniform.ty.clone(), true));
102 } else {
103 let struct_ty = b.type_struct(vec![inner_ty]);
105 b.decorate(struct_ty, Decoration::Block, vec![]);
106 b.member_decorate(struct_ty, 0, Decoration::Offset, vec![Operand::LiteralBit32(0)]);
108
109 if matches!(&uniform.ty, Type::Named { name, .. } if name == "Mat4") {
111 b.member_decorate(struct_ty, 0, Decoration::ColMajor, vec![]);
112 b.member_decorate(struct_ty, 0, Decoration::MatrixStride, vec![Operand::LiteralBit32(16)]);
113 }
114
115 let ptr_ty = b.type_pointer(None, StorageClass::Uniform, struct_ty);
116 let var = b.variable(ptr_ty, None, StorageClass::Uniform, None);
117 b.decorate(var, Decoration::DescriptorSet, vec![Operand::LiteralBit32(0)]);
118 b.decorate(var, Decoration::Binding, vec![Operand::LiteralBit32(uniform.binding)]);
119 ctx_vars.insert(uniform.name.clone(), (var, uniform.ty.clone(), true));
120 struct_uniforms.insert(uniform.name.clone());
121 }
122 }
123
124 let main_fn = b.begin_function(void, None, rspirv::spirv::FunctionControl::NONE, fn_void_void).unwrap();
126 b.begin_block(None).unwrap();
127
128 let mut ctx = ShaderContext {
129 b,
130 vars: ctx_vars,
131 output_var,
132 struct_uniforms,
133 glsl_ext: None,
134 };
135
136 emit_block(&mut ctx, &shader.ast.body)?;
137
138 if shader.ast.body.stmts.last().map_or(true, |s| !matches!(s, Stmt::Return(_, _))) {
140 ctx.b.ret().unwrap();
141 }
142
143 ctx.b.end_function().unwrap();
144
145 b.entry_point(exec_model, main_fn, &shader.ast.name, interface_vars);
147
148 if exec_model == ExecutionModel::Fragment {
149 b.execution_mode(main_fn, ExecutionMode::OriginUpperLeft, vec![]);
150 }
151
152 Ok(())
153}
154
155impl<'a> ShaderContext<'a> {
156 fn get_glsl_ext(&mut self) -> u32 {
157 if let Some(ext) = self.glsl_ext {
158 ext
159 } else {
160 let ext = self.b.ext_inst_import("GLSL.std.450");
161 self.glsl_ext = Some(ext);
162 ext
163 }
164 }
165}
166
167fn emit_block(ctx: &mut ShaderContext, block: &Block) -> KoreResult<()> {
168 for stmt in &block.stmts {
169 match stmt {
170 Stmt::Return(expr, _) => {
171 if let Some(expr) = expr {
172 if let Some(out_var) = ctx.output_var {
173 let (val, _) = emit_expr(ctx, expr)?;
174 ctx.b.store(out_var, val, None, vec![]).unwrap();
175 }
176 }
177 ctx.b.ret().unwrap();
178 },
179 Stmt::Let { pattern, value, .. } => {
180 if let Some(value) = value {
181 let (val, ty) = emit_expr(ctx, value)?;
182 if let crate::ast::Pattern::Binding { name, .. } = pattern {
184 ctx.vars.insert(name.clone(), (val, ty, false));
187 }
188 }
189 },
190 Stmt::Expr(expr) => {
191 emit_expr(ctx, expr)?;
192 },
193 _ => {} }
195 }
196 Ok(())
197}
198
199fn emit_expr(ctx: &mut ShaderContext, expr: &Expr) -> KoreResult<(u32, Type)> {
200 match expr {
201 Expr::Ident(name, span) => {
202 if let Some((id, ty, is_ptr)) = ctx.vars.get(name).cloned() {
203 if is_ptr {
204 let type_id = map_ast_type(ctx.b, &ty);
206
207 if ctx.struct_uniforms.contains(name) {
209 let ptr_ty = ctx.b.type_pointer(None, StorageClass::Uniform, type_id);
211 let int_ty = ctx.b.type_int(32, 0);
212 let zero = ctx.b.constant_bit32(int_ty, 0);
213 let member_ptr = ctx.b.access_chain(ptr_ty, None, id, vec![zero]).unwrap();
214 let val_id = ctx.b.load(type_id, None, member_ptr, None, std::iter::empty()).unwrap();
215 Ok((val_id, ty))
216 } else {
217 let val_id = ctx.b.load(type_id, None, id, None, std::iter::empty()).unwrap();
219 Ok((val_id, ty))
220 }
221 } else {
222 Ok((id, ty))
223 }
224 } else {
225 Err(KoreError::codegen(format!("Unknown variable: {}", name), *span))
226 }
227 },
228 Expr::Binary { left, op, right, .. } => {
229 let (lhs, lhs_ty) = emit_expr(ctx, left)?;
230 let (rhs, rhs_ty) = emit_expr(ctx, right)?;
231
232 let res_ty_id = map_ast_type(ctx.b, &lhs_ty); let res_id = match op {
236 BinaryOp::Mul => {
237 if is_mat4(&lhs_ty) && is_mat4(&rhs_ty) {
238 ctx.b.matrix_times_matrix(res_ty_id, None, lhs, rhs).unwrap()
239 } else if is_mat4(&lhs_ty) && is_vec4(&rhs_ty) {
240 let vec4_ty = map_ast_type(ctx.b, &rhs_ty);
242 ctx.b.matrix_times_vector(vec4_ty, None, lhs, rhs).unwrap()
243 } else if is_vec4(&lhs_ty) && is_mat4(&rhs_ty) {
244 let vec4_ty = map_ast_type(ctx.b, &lhs_ty);
246 ctx.b.vector_times_matrix(vec4_ty, None, lhs, rhs).unwrap()
247 } else if is_float(&lhs_ty) && is_float(&rhs_ty) {
248 ctx.b.f_mul(res_ty_id, None, lhs, rhs).unwrap()
249 } else {
250 ctx.b.f_mul(res_ty_id, None, lhs, rhs).unwrap()
252 }
253 },
254 BinaryOp::Add => ctx.b.f_add(res_ty_id, None, lhs, rhs).unwrap(),
255 BinaryOp::Sub => ctx.b.f_sub(res_ty_id, None, lhs, rhs).unwrap(),
256 BinaryOp::Div => ctx.b.f_div(res_ty_id, None, lhs, rhs).unwrap(),
257 _ => return Err(KoreError::codegen("Unsupported binary op in shader", expr.span())),
258 };
259
260 let res_ty = if is_mat4(&lhs_ty) && is_vec4(&rhs_ty) {
262 rhs_ty
263 } else {
264 lhs_ty
265 };
266
267 Ok((res_id, res_ty))
268 },
269 Expr::Call { callee, args, .. } => {
270 if let Expr::Ident(name, _) = &**callee {
271 let float = ctx.b.type_float(32);
272
273 match name.as_str() {
275 "vec2" | "Vec2" if args.len() == 2 => {
276 let vec2 = ctx.b.type_vector(float, 2);
277 let mut components = vec![];
278 for arg in args {
279 let (val, _) = emit_expr(ctx, &arg.value)?;
280 components.push(val);
281 }
282 let res_id = ctx.b.composite_construct(vec2, None, components).unwrap();
283 return Ok((res_id, Type::Named { name: "Vec2".into(), generics: vec![], span: expr.span() }));
284 },
285 "vec3" | "Vec3" if args.len() == 3 => {
286 let vec3 = ctx.b.type_vector(float, 3);
287 let mut components = vec![];
288 for arg in args {
289 let (val, _) = emit_expr(ctx, &arg.value)?;
290 components.push(val);
291 }
292 let res_id = ctx.b.composite_construct(vec3, None, components).unwrap();
293 return Ok((res_id, Type::Named { name: "Vec3".into(), generics: vec![], span: expr.span() }));
294 },
295 "vec4" | "Vec4" if args.len() == 4 => {
296 let vec4 = ctx.b.type_vector(float, 4);
297 let mut components = vec![];
298 for arg in args {
299 let (val, _) = emit_expr(ctx, &arg.value)?;
300 components.push(val);
301 }
302 let res_id = ctx.b.composite_construct(vec4, None, components).unwrap();
303 return Ok((res_id, Type::Named { name: "Vec4".into(), generics: vec![], span: expr.span() }));
304 },
305
306 "sin" if args.len() == 1 => {
308 let (val, ty) = emit_expr(ctx, &args[0].value)?;
309 let res_ty = map_ast_type(ctx.b, &ty);
310 let glsl = ctx.get_glsl_ext();
311 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 13, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, ty));
313 },
314 "cos" if args.len() == 1 => {
315 let (val, ty) = emit_expr(ctx, &args[0].value)?;
316 let res_ty = map_ast_type(ctx.b, &ty);
317 let glsl = ctx.get_glsl_ext();
318 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 14, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, ty));
320 },
321 "tan" if args.len() == 1 => {
322 let (val, ty) = emit_expr(ctx, &args[0].value)?;
323 let res_ty = map_ast_type(ctx.b, &ty);
324 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
325 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 15, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, ty));
327 },
328 "pow" if args.len() == 2 => {
329 let (base, ty) = emit_expr(ctx, &args[0].value)?;
330 let (exp, _) = emit_expr(ctx, &args[1].value)?;
331 let res_ty = map_ast_type(ctx.b, &ty);
332 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
333 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 26, vec![Operand::IdRef(base), Operand::IdRef(exp)]).unwrap(); return Ok((res_id, ty));
335 },
336 "sqrt" if args.len() == 1 => {
337 let (val, ty) = emit_expr(ctx, &args[0].value)?;
338 let res_ty = map_ast_type(ctx.b, &ty);
339 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
340 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 31, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, ty));
342 },
343 "abs" if args.len() == 1 => {
344 let (val, ty) = emit_expr(ctx, &args[0].value)?;
345 let res_ty = map_ast_type(ctx.b, &ty);
346 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
347 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 4, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, ty));
349 },
350 "floor" if args.len() == 1 => {
351 let (val, ty) = emit_expr(ctx, &args[0].value)?;
352 let res_ty = map_ast_type(ctx.b, &ty);
353 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
354 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 8, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, ty));
356 },
357 "ceil" if args.len() == 1 => {
358 let (val, ty) = emit_expr(ctx, &args[0].value)?;
359 let res_ty = map_ast_type(ctx.b, &ty);
360 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
361 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 9, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, ty));
363 },
364 "fract" if args.len() == 1 => {
365 let (val, ty) = emit_expr(ctx, &args[0].value)?;
366 let res_ty = map_ast_type(ctx.b, &ty);
367 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
368 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 10, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, ty));
370 },
371 "min" if args.len() == 2 => {
372 let (a, ty) = emit_expr(ctx, &args[0].value)?;
373 let (b, _) = emit_expr(ctx, &args[1].value)?;
374 let res_ty = map_ast_type(ctx.b, &ty);
375 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
376 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 37, vec![Operand::IdRef(a), Operand::IdRef(b)]).unwrap(); return Ok((res_id, ty));
378 },
379 "max" if args.len() == 2 => {
380 let (a, ty) = emit_expr(ctx, &args[0].value)?;
381 let (b, _) = emit_expr(ctx, &args[1].value)?;
382 let res_ty = map_ast_type(ctx.b, &ty);
383 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
384 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 40, vec![Operand::IdRef(a), Operand::IdRef(b)]).unwrap(); return Ok((res_id, ty));
386 },
387 "clamp" if args.len() == 3 => {
388 let (val, ty) = emit_expr(ctx, &args[0].value)?;
389 let (min_val, _) = emit_expr(ctx, &args[1].value)?;
390 let (max_val, _) = emit_expr(ctx, &args[2].value)?;
391 let res_ty = map_ast_type(ctx.b, &ty);
392 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
393 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 43, vec![Operand::IdRef(val), Operand::IdRef(min_val), Operand::IdRef(max_val)]).unwrap(); return Ok((res_id, ty));
395 },
396 "mix" if args.len() == 3 => {
397 let (a, ty) = emit_expr(ctx, &args[0].value)?;
398 let (b, _) = emit_expr(ctx, &args[1].value)?;
399 let (t, _) = emit_expr(ctx, &args[2].value)?;
400 let res_ty = map_ast_type(ctx.b, &ty);
401 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
402 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 46, vec![Operand::IdRef(a), Operand::IdRef(b), Operand::IdRef(t)]).unwrap(); return Ok((res_id, ty));
404 },
405 "step" if args.len() == 2 => {
406 let (edge, ty) = emit_expr(ctx, &args[0].value)?;
407 let (x, _) = emit_expr(ctx, &args[1].value)?;
408 let res_ty = map_ast_type(ctx.b, &ty);
409 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
410 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 48, vec![Operand::IdRef(edge), Operand::IdRef(x)]).unwrap(); return Ok((res_id, ty));
412 },
413 "smoothstep" if args.len() == 3 => {
414 let (edge0, ty) = emit_expr(ctx, &args[0].value)?;
415 let (edge1, _) = emit_expr(ctx, &args[1].value)?;
416 let (x, _) = emit_expr(ctx, &args[2].value)?;
417 let res_ty = map_ast_type(ctx.b, &ty);
418 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
419 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 49, vec![Operand::IdRef(edge0), Operand::IdRef(edge1), Operand::IdRef(x)]).unwrap(); return Ok((res_id, ty));
421 },
422 "length" if args.len() == 1 => {
423 let (val, _) = emit_expr(ctx, &args[0].value)?;
424 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
425 let res_id = ctx.b.ext_inst(float, None, glsl, 66, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: expr.span() }));
427 },
428 "normalize" if args.len() == 1 => {
429 let (val, ty) = emit_expr(ctx, &args[0].value)?;
430 let res_ty = map_ast_type(ctx.b, &ty);
431 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
432 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 69, vec![Operand::IdRef(val)]).unwrap(); return Ok((res_id, ty));
434 },
435 "dot" if args.len() == 2 => {
436 let (a, _) = emit_expr(ctx, &args[0].value)?;
437 let (b, _) = emit_expr(ctx, &args[1].value)?;
438 let res_id = ctx.b.dot(float, None, a, b).unwrap();
439 return Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: expr.span() }));
440 },
441 "cross" if args.len() == 2 => {
442 let (a, ty) = emit_expr(ctx, &args[0].value)?;
443 let (b, _) = emit_expr(ctx, &args[1].value)?;
444 let res_ty = map_ast_type(ctx.b, &ty);
445 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
446 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 68, vec![Operand::IdRef(a), Operand::IdRef(b)]).unwrap(); return Ok((res_id, ty));
448 },
449 "reflect" if args.len() == 2 => {
450 let (i, ty) = emit_expr(ctx, &args[0].value)?;
451 let (n, _) = emit_expr(ctx, &args[1].value)?;
452 let res_ty = map_ast_type(ctx.b, &ty);
453 let glsl = ctx.b.ext_inst_import("GLSL.std.450");
454 let res_id = ctx.b.ext_inst(res_ty, None, glsl, 71, vec![Operand::IdRef(i), Operand::IdRef(n)]).unwrap(); return Ok((res_id, ty));
456 },
457
458 "sample" if args.len() == 2 => {
460 let (sampler, _) = emit_expr(ctx, &args[0].value)?;
461 let (coords, _) = emit_expr(ctx, &args[1].value)?;
462 let vec4 = ctx.b.type_vector(float, 4);
463 let res_id = ctx.b.image_sample_implicit_lod(vec4, None, sampler, coords, None, std::iter::empty()).unwrap();
464 return Ok((res_id, Type::Named { name: "Vec4".into(), generics: vec![], span: expr.span() }));
465 },
466 "sample_lod" if args.len() == 3 => {
467 let (sampler, _) = emit_expr(ctx, &args[0].value)?;
468 let (coords, _) = emit_expr(ctx, &args[1].value)?;
469 let (lod, _) = emit_expr(ctx, &args[2].value)?;
470 let vec4 = ctx.b.type_vector(float, 4);
471 let res_id = ctx.b.image_sample_explicit_lod(vec4, None, sampler, coords, rspirv::spirv::ImageOperands::LOD, vec![Operand::IdRef(lod)]).unwrap();
472 return Ok((res_id, Type::Named { name: "Vec4".into(), generics: vec![], span: expr.span() }));
473 },
474
475 _ => {}
476 }
477 }
478 Err(KoreError::codegen(format!("Unsupported function call in shader: {:?}", callee), expr.span()))
479 },
480 Expr::Float(f, span) => {
481 let float = ctx.b.type_float(32);
482 let val = ctx.b.constant_bit32(float, (*f as f32).to_bits());
483 Ok((val, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
484 },
485 Expr::Field { object, field, span } => {
486 let (obj_id, _obj_ty) = emit_expr(ctx, object)?;
487
488 let float = ctx.b.type_float(32);
490 match field.as_str() {
491 "x" | "r" => {
493 let res_id = ctx.b.composite_extract(float, None, obj_id, vec![0]).unwrap();
494 Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
495 },
496 "y" | "g" => {
497 let res_id = ctx.b.composite_extract(float, None, obj_id, vec![1]).unwrap();
498 Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
499 },
500 "z" | "b" => {
501 let res_id = ctx.b.composite_extract(float, None, obj_id, vec![2]).unwrap();
502 Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
503 },
504 "w" | "a" => {
505 let res_id = ctx.b.composite_extract(float, None, obj_id, vec![3]).unwrap();
506 Ok((res_id, Type::Named { name: "Float".into(), generics: vec![], span: *span }))
507 },
508 "xy" | "rg" => {
510 let vec2 = ctx.b.type_vector(float, 2);
511 let res_id = ctx.b.vector_shuffle(vec2, None, obj_id, obj_id, vec![0, 1]).unwrap();
512 Ok((res_id, Type::Named { name: "Vec2".into(), generics: vec![], span: *span }))
513 },
514 "xz" | "rb" => {
515 let vec2 = ctx.b.type_vector(float, 2);
516 let res_id = ctx.b.vector_shuffle(vec2, None, obj_id, obj_id, vec![0, 2]).unwrap();
517 Ok((res_id, Type::Named { name: "Vec2".into(), generics: vec![], span: *span }))
518 },
519 "yz" | "gb" => {
520 let vec2 = ctx.b.type_vector(float, 2);
521 let res_id = ctx.b.vector_shuffle(vec2, None, obj_id, obj_id, vec![1, 2]).unwrap();
522 Ok((res_id, Type::Named { name: "Vec2".into(), generics: vec![], span: *span }))
523 },
524 "xyz" | "rgb" => {
526 let vec3 = ctx.b.type_vector(float, 3);
527 let res_id = ctx.b.vector_shuffle(vec3, None, obj_id, obj_id, vec![0, 1, 2]).unwrap();
528 Ok((res_id, Type::Named { name: "Vec3".into(), generics: vec![], span: *span }))
529 },
530 _ => Err(KoreError::codegen(format!("Unsupported field access: {}", field), *span))
531 }
532 },
533 _ => Err(KoreError::codegen("Unsupported expression in shader", expr.span())),
534 }
535}
536
537fn map_ast_type(b: &mut Builder, ty: &Type) -> u32 {
538 let float = b.type_float(32);
539 match ty {
540 Type::Named { name, .. } => match name.as_str() {
541 "Float" | "f32" => float,
542 "Int" | "i32" => b.type_int(32, 1),
543 "Bool" => b.type_bool(),
544 "Vec2" => b.type_vector(float, 2),
545 "Vec3" => b.type_vector(float, 3),
546 "Vec4" => b.type_vector(float, 4),
547 "Mat4" => {
548 let v4 = b.type_vector(float, 4);
549 b.type_matrix(v4, 4)
550 },
551 "Sampler2D" => {
552 let image = b.type_image(float, rspirv::spirv::Dim::Dim2D, 0, 0, 0, 1, rspirv::spirv::ImageFormat::Unknown, None);
554 b.type_sampled_image(image)
555 },
556 "StorageBuffer" => {
557 let rt_array = b.type_runtime_array(float);
560 let struct_ty = b.type_struct(vec![rt_array]);
561 b.decorate(struct_ty, Decoration::Block, vec![]);
562 struct_ty
563 },
564 "Void" => b.type_void(),
565 _ => b.type_void(),
566 },
567 _ => b.type_void(),
568 }
569}
570
571fn is_void(ty: &Type) -> bool {
572 matches!(ty, Type::Named { name, .. } if name == "Void")
573}
574
575fn is_vec4(ty: &Type) -> bool {
576 matches!(ty, Type::Named { name, .. } if name == "Vec4")
577}
578
579fn is_mat4(ty: &Type) -> bool {
580 matches!(ty, Type::Named { name, .. } if name == "Mat4")
581}
582
583fn is_float(ty: &Type) -> bool {
584 matches!(ty, Type::Named { name, .. } if name == "Float" || name == "f32")
585}
586