1use crate::types::{TypedProgram, TypedItem, TypedShader};
5use crate::error::{KoreResult, KoreError};
6use crate::ast::{Type, ShaderStage, Expr, Stmt, Block, BinaryOp, Pattern};
7use std::collections::HashMap;
8
9pub fn generate(program: &TypedProgram) -> KoreResult<String> {
10 let mut output = String::new();
11
12 output.push_str("// Generated by KORE Compiler\n");
14 output.push_str("// Direct HLSL codegen - SUPERCHARGED\n\n");
15
16 for item in &program.items {
17 if let TypedItem::Shader(shader) = item {
18 output.push_str(&emit_shader(shader)?);
19 }
20 }
21
22 Ok(output)
23}
24
/// Mutable state threaded through HLSL emission of a single shader.
struct HLSLContext {
    /// Maps KORE identifiers to the HLSL text that names them
    /// (for example, stage inputs map to `input.<name>`).
    vars: HashMap<String, String>,
    /// Current nesting depth; each level renders as one space (see `indent`).
    indent_level: usize,
    /// `(name, mapped HLSL type, binding index)` for each declared uniform.
    uniform_bindings: Vec<(String, String, u32)>,
}

impl HLSLContext {
    /// Creates an empty context at nesting depth zero.
    fn new() -> Self {
        HLSLContext {
            uniform_bindings: Vec::new(),
            indent_level: 0,
            vars: HashMap::new(),
        }
    }

    /// Returns the whitespace prefix for the current nesting depth.
    fn indent(&self) -> String {
        let depth = self.indent_level;
        " ".repeat(depth)
    }

    /// Enters one more level of nesting.
    fn push_indent(&mut self) {
        self.indent_level += 1;
    }

    /// Leaves one level of nesting, never going below zero.
    fn pop_indent(&mut self) {
        self.indent_level = self.indent_level.saturating_sub(1);
    }
}
56
57fn emit_shader(shader: &TypedShader) -> KoreResult<String> {
58 let mut output = String::new();
59 let mut ctx = HLSLContext::new();
60
61 for uniform in &shader.ast.uniforms {
63 let hlsl_type = map_type_to_hlsl(&uniform.ty);
64 ctx.uniform_bindings.push((uniform.name.clone(), hlsl_type, uniform.binding));
65 }
66
67 let mut cbuffer_uniforms = Vec::new();
69 let mut texture_uniforms = Vec::new();
70 let mut buffer_uniforms = Vec::new();
71
72 for (name, ty, binding) in &ctx.uniform_bindings {
73 if ty.contains("Texture") || ty.contains("Sampler") {
74 texture_uniforms.push((name.clone(), ty.clone(), *binding));
75 } else if ty.contains("Buffer") || ty.contains("RWBuffer") || ty.contains("StructuredBuffer") {
76 buffer_uniforms.push((name.clone(), ty.clone(), *binding));
77 } else {
78 cbuffer_uniforms.push((name.clone(), ty.clone(), *binding));
79 }
80 }
81
82 if !cbuffer_uniforms.is_empty() {
84 output.push_str("cbuffer ShaderParams : register(b0)\n{\n");
85 for (name, ty, _) in &cbuffer_uniforms {
86 output.push_str(&format!(" {} {};\n", ty, name));
87 }
88 output.push_str("};\n\n");
89 }
90
91 for (name, ty, binding) in &texture_uniforms {
93 output.push_str(&format!("{} {} : register(t{});\n", ty, name, binding));
94 output.push_str(&format!("SamplerState {}_sampler : register(s{});\n", name, binding));
95 }
96 if !texture_uniforms.is_empty() {
97 output.push_str("\n");
98 }
99
100 for (name, ty, binding) in &buffer_uniforms {
102 output.push_str(&format!("{} {} : register(u{});\n", ty, name, binding));
103 }
104 if !buffer_uniforms.is_empty() {
105 output.push_str("\n");
106 }
107
108 match shader.ast.stage {
109 ShaderStage::Compute => {
110 output.push_str("[numthreads(8, 8, 1)]\n");
112 output.push_str("void CSMain(uint3 dispatchThreadID : SV_DispatchThreadID,\n");
113 output.push_str(" uint3 groupThreadID : SV_GroupThreadID,\n");
114 output.push_str(" uint3 groupID : SV_GroupID,\n");
115 output.push_str(" uint groupIndex : SV_GroupIndex)\n{\n");
116 ctx.push_indent();
117
118 ctx.vars.insert("dispatch_thread_id".to_string(), "dispatchThreadID".to_string());
120 ctx.vars.insert("group_thread_id".to_string(), "groupThreadID".to_string());
121 ctx.vars.insert("group_id".to_string(), "groupID".to_string());
122 ctx.vars.insert("group_index".to_string(), "groupIndex".to_string());
123
124 let body_code = emit_block(&mut ctx, &shader.ast.body)?;
126 output.push_str(&body_code);
127
128 ctx.pop_indent();
129 output.push_str("}\n");
130 },
131 ShaderStage::Vertex => {
132 output.push_str("struct VSInput\n{\n");
134 for (i, param) in shader.ast.inputs.iter().enumerate() {
135 let hlsl_type = map_type_to_hlsl(¶m.ty);
136 let semantic = match param.name.as_str() {
137 "position" => "POSITION",
138 "normal" => "NORMAL",
139 "tangent" => "TANGENT",
140 "color" => "COLOR",
141 _ => "TEXCOORD",
142 };
143 output.push_str(&format!(" {} {} : {}{};\n",
144 hlsl_type, param.name, semantic,
145 if semantic == "TEXCOORD" { i.to_string() } else { "".to_string() }
146 ));
147 }
148 output.push_str("};\n\n");
149
150 output.push_str("struct VSOutput\n{\n");
152 output.push_str(" float4 position : SV_Position;\n");
153 output.push_str("};\n\n");
155
156 output.push_str("VSOutput VSMain(VSInput input)\n{\n");
158 ctx.push_indent();
159
160 for param in &shader.ast.inputs {
162 ctx.vars.insert(param.name.clone(), format!("input.{}", param.name));
163 }
164
165 let body_code = emit_block(&mut ctx, &shader.ast.body)?;
167 output.push_str(&body_code);
168
169 ctx.pop_indent();
170 output.push_str("}\n");
171 },
172 ShaderStage::Fragment => {
173 output.push_str("struct VSInput\n{\n");
175 for (i, param) in shader.ast.inputs.iter().enumerate() {
176 let hlsl_type = map_type_to_hlsl(¶m.ty);
177 output.push_str(&format!(" {} {} : TEXCOORD{};\n", hlsl_type, param.name, i));
178 }
179 output.push_str("};\n\n");
180
181 output.push_str("struct PSOutput\n{\n");
183 let out_type = map_type_to_hlsl(&shader.ast.outputs);
184 output.push_str(&format!(" {} color : SV_Target0;\n", out_type));
185 output.push_str("};\n\n");
187
188 output.push_str("PSOutput PSMain(VSInput input)\n{\n");
190 ctx.push_indent();
191
192 for param in &shader.ast.inputs {
194 ctx.vars.insert(param.name.clone(), format!("input.{}", param.name));
195 }
196
197 let body_code = emit_block(&mut ctx, &shader.ast.body)?;
199 output.push_str(&body_code);
200
201 ctx.pop_indent();
202 output.push_str("}\n");
203 },
204 }
205
206 Ok(output)
207}
208
209fn emit_block(ctx: &mut HLSLContext, block: &Block) -> KoreResult<String> {
210 let mut output = String::new();
211
212 for stmt in &block.stmts {
213 output.push_str(&emit_stmt(ctx, stmt)?);
214 }
215
216 Ok(output)
217}
218
219fn emit_stmt(ctx: &mut HLSLContext, stmt: &Stmt) -> KoreResult<String> {
220 let mut output = String::new();
221
222 match stmt {
223 Stmt::Let { pattern, value, .. } => {
224 if let Some(value) = value {
225 if let Pattern::Binding { name, .. } = pattern {
226 let (expr_code, expr_type) = emit_expr(ctx, value)?;
227 output.push_str(&format!("{}{} {} = {};\n", ctx.indent(), expr_type, name, expr_code));
228 ctx.vars.insert(name.clone(), name.clone());
229 }
230 }
231 },
232 Stmt::Return(Some(expr), _) => {
233 let (expr_code, _) = emit_expr(ctx, expr)?;
234 output.push_str(&format!("{}PSOutput _result;\n", ctx.indent()));
235 output.push_str(&format!("{}_result.color = {};\n", ctx.indent(), expr_code));
236 output.push_str(&format!("{}return _result;\n", ctx.indent()));
237 },
238 Stmt::Return(None, _) => {
239 output.push_str(&format!("{}return;\n", ctx.indent()));
240 },
241 Stmt::Expr(expr) => {
242 let (expr_code, _) = emit_expr(ctx, expr)?;
243 output.push_str(&format!("{}{};\n", ctx.indent(), expr_code));
244 },
245 Stmt::While { condition, body, .. } => {
247 let (cond_code, _) = emit_expr(ctx, condition)?;
248 output.push_str(&format!("{}while ({})\n", ctx.indent(), cond_code));
249 output.push_str(&format!("{}{{\n", ctx.indent()));
250 ctx.push_indent();
251 output.push_str(&emit_block(ctx, body)?);
252 ctx.pop_indent();
253 output.push_str(&format!("{}}}\n", ctx.indent()));
254 },
255 Stmt::For { binding, iter: _, body, .. } => {
256 if let Pattern::Binding { name, .. } = binding {
258 output.push_str(&format!("{}for (int {} = 0; {} < 10; {}++)\n",
259 ctx.indent(), name, name, name));
260 output.push_str(&format!("{}{{\n", ctx.indent()));
261 ctx.push_indent();
262 ctx.vars.insert(name.clone(), name.clone());
263 output.push_str(&emit_block(ctx, body)?);
264 ctx.pop_indent();
265 output.push_str(&format!("{}}}\n", ctx.indent()));
266 }
267 },
268 Stmt::Break(_, _) => {
269 output.push_str(&format!("{}break;\n", ctx.indent()));
270 },
271 Stmt::Continue(_) => {
272 output.push_str(&format!("{}continue;\n", ctx.indent()));
273 },
274 _ => {}
275 }
276
277 Ok(output)
278}
279
280fn emit_expr(ctx: &mut HLSLContext, expr: &Expr) -> KoreResult<(String, String)> {
281 match expr {
282 Expr::Ident(name, _) => {
283 if let Some(mapped) = ctx.vars.get(name) {
284 Ok((mapped.clone(), "float4".to_string()))
285 } else {
286 Ok((name.clone(), "float4".to_string()))
287 }
288 },
289 Expr::Float(f, _) => {
290 Ok((format!("{:.6}", f), "float".to_string()))
291 },
292 Expr::Int(i, _) => {
293 Ok((format!("{}", i), "int".to_string()))
294 },
295 Expr::Bool(b, _) => {
296 Ok((format!("{}", b), "bool".to_string()))
297 },
298 Expr::String(s, _) => {
299 Ok((format!("\"{}\"", s), "string".to_string()))
301 },
302 Expr::Binary { left, op, right, .. } => {
303 let (left_code, left_ty) = emit_expr(ctx, left)?;
304 let (right_code, _) = emit_expr(ctx, right)?;
305
306 let op_str = match op {
307 BinaryOp::Add => "+",
308 BinaryOp::Sub => "-",
309 BinaryOp::Mul => "*",
310 BinaryOp::Div => "/",
311 BinaryOp::Mod => "%",
312 BinaryOp::Eq => "==",
313 BinaryOp::Ne => "!=",
314 BinaryOp::Lt => "<",
315 BinaryOp::Le => "<=",
316 BinaryOp::Gt => ">",
317 BinaryOp::Ge => ">=",
318 BinaryOp::And => "&&",
319 BinaryOp::Or => "||",
320 BinaryOp::BitAnd => "&",
321 BinaryOp::BitOr => "|",
322 BinaryOp::BitXor => "^",
323 BinaryOp::Shl => "<<",
324 BinaryOp::Shr => ">>",
325 _ => return Err(KoreError::codegen("Unsupported binary op", expr.span())),
326 };
327
328 let result_ty = match op {
330 BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Lt | BinaryOp::Le |
331 BinaryOp::Gt | BinaryOp::Ge | BinaryOp::And | BinaryOp::Or => "bool".to_string(),
332 _ => left_ty,
333 };
334
335 Ok((format!("({} {} {})", left_code, op_str, right_code), result_ty))
336 },
337 Expr::Unary { op, operand, .. } => {
338 let (operand_code, ty) = emit_expr(ctx, operand)?;
339 let op_str = match op {
340 crate::ast::UnaryOp::Neg => "-",
341 crate::ast::UnaryOp::Not => "!",
342 crate::ast::UnaryOp::BitNot => "~",
343 crate::ast::UnaryOp::Ref | crate::ast::UnaryOp::RefMut => {
344 return Ok((operand_code, ty));
346 },
347 crate::ast::UnaryOp::Deref => {
348 return Ok((operand_code, ty));
350 },
351 };
352 Ok((format!("({}{})", op_str, operand_code), ty))
353 },
354 Expr::Call { callee, args, .. } => {
355 if let Expr::Ident(name, _) = &**callee {
356 emit_function_call(ctx, name, args)
357 } else {
358 Err(KoreError::codegen("Complex callee not supported", expr.span()))
359 }
360 },
361 Expr::Field { object, field, .. } => {
362 let (obj_code, _) = emit_expr(ctx, object)?;
363
364 Ok((format!("{}.{}", obj_code, field), infer_swizzle_type(field)))
368 },
369 Expr::Index { object, index, .. } => {
370 let (obj_code, obj_ty) = emit_expr(ctx, object)?;
371 let (idx_code, _) = emit_expr(ctx, index)?;
372 let elem_ty = if obj_ty.starts_with("float") {
374 "float".to_string()
375 } else {
376 obj_ty
377 };
378 Ok((format!("{}[{}]", obj_code, idx_code), elem_ty))
379 },
380 Expr::If { condition, then_branch, else_branch, .. } => {
381 let (cond_code, _) = emit_expr(ctx, condition)?;
383
384 if then_branch.stmts.len() == 1 && else_branch.is_some() {
387 if let Stmt::Expr(then_expr) = &then_branch.stmts[0] {
388 let (then_code, then_ty) = emit_expr(ctx, then_expr)?;
389
390 if let Some(crate::ast::ElseBranch::Else(else_block)) = else_branch.as_ref().map(|b| b.as_ref()) {
391 if else_block.stmts.len() == 1 {
392 if let Stmt::Expr(else_expr) = &else_block.stmts[0] {
393 let (else_code, _) = emit_expr(ctx, else_expr)?;
394 return Ok((format!("({} ? {} : {})", cond_code, then_code, else_code), then_ty));
395 }
396 }
397 }
398 }
399 }
400
401 Err(KoreError::codegen("Complex if expressions not yet supported in HLSL backend", expr.span()))
403 },
404 Expr::Paren(inner, _) => {
405 let (inner_code, ty) = emit_expr(ctx, inner)?;
407 Ok((format!("({})", inner_code), ty))
408 },
409 _ => Err(KoreError::codegen("Unsupported expression", expr.span())),
410 }
411}
412
413fn emit_function_call(ctx: &mut HLSLContext, name: &str, args: &[crate::ast::CallArg]) -> KoreResult<(String, String)> {
414 match name {
415 "vec2" | "Vec2" => {
417 let mut arg_codes = Vec::new();
418 for arg in args {
419 let (code, _) = emit_expr(ctx, &arg.value)?;
420 arg_codes.push(code);
421 }
422 Ok((format!("float2({})", arg_codes.join(", ")), "float2".to_string()))
423 },
424 "vec3" | "Vec3" => {
425 let mut arg_codes = Vec::new();
426 for arg in args {
427 let (code, _) = emit_expr(ctx, &arg.value)?;
428 arg_codes.push(code);
429 }
430 Ok((format!("float3({})", arg_codes.join(", ")), "float3".to_string()))
431 },
432 "vec4" | "Vec4" => {
433 let mut arg_codes = Vec::new();
434 for arg in args {
435 let (code, _) = emit_expr(ctx, &arg.value)?;
436 arg_codes.push(code);
437 }
438 Ok((format!("float4({})", arg_codes.join(", ")), "float4".to_string()))
439 },
440
441 "sin" | "cos" | "tan" => {
443 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
444 Ok((format!("{}({})", name, arg_code), ty))
445 },
446 "asin" | "acos" | "atan" => {
447 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
448 Ok((format!("{}({})", name, arg_code), ty))
449 },
450 "atan2" => {
451 let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
452 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
453 Ok((format!("atan2({}, {})", arg1, arg2), ty))
454 },
455
456 "abs" | "floor" | "ceil" | "round" | "trunc" | "fract" => {
458 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
459 let hlsl_name = if name == "fract" { "frac" } else { name };
460 Ok((format!("{}({})", hlsl_name, arg_code), ty))
461 },
462 "sqrt" | "rsqrt" | "exp" | "exp2" | "log" | "log2" | "log10" => {
463 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
464 Ok((format!("{}({})", name, arg_code), ty))
465 },
466 "sign" | "saturate" => {
467 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
468 Ok((format!("{}({})", name, arg_code), ty))
469 },
470
471 "pow" | "min" | "max" | "fmod" | "step" => {
473 let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
474 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
475 Ok((format!("{}({}, {})", name, arg1, arg2), ty))
476 },
477
478 "clamp" | "smoothstep" | "mad" => {
480 let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
481 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
482 let (arg3, _) = emit_expr(ctx, &args[2].value)?;
483 Ok((format!("{}({}, {}, {})", name, arg1, arg2, arg3), ty))
484 },
485 "mix" | "lerp" => {
486 let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
487 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
488 let (arg3, _) = emit_expr(ctx, &args[2].value)?;
489 Ok((format!("lerp({}, {}, {})", arg1, arg2, arg3), ty))
490 },
491
492 "length" => {
494 let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
495 Ok((format!("length({})", arg_code), "float".to_string()))
496 },
497 "distance" => {
498 let (arg1, _) = emit_expr(ctx, &args[0].value)?;
499 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
500 Ok((format!("distance({}, {})", arg1, arg2), "float".to_string()))
501 },
502 "normalize" => {
503 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
504 Ok((format!("normalize({})", arg_code), ty))
505 },
506 "dot" => {
507 let (arg1, _) = emit_expr(ctx, &args[0].value)?;
508 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
509 Ok((format!("dot({}, {})", arg1, arg2), "float".to_string()))
510 },
511 "cross" => {
512 let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
513 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
514 Ok((format!("cross({}, {})", arg1, arg2), ty))
515 },
516 "reflect" => {
517 let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
518 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
519 Ok((format!("reflect({}, {})", arg1, arg2), ty))
520 },
521 "refract" => {
522 let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
523 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
524 let (arg3, _) = emit_expr(ctx, &args[2].value)?;
525 Ok((format!("refract({}, {}, {})", arg1, arg2, arg3), ty))
526 },
527 "faceforward" => {
528 let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
529 let (arg2, _) = emit_expr(ctx, &args[1].value)?;
530 let (arg3, _) = emit_expr(ctx, &args[2].value)?;
531 Ok((format!("faceforward({}, {}, {})", arg1, arg2, arg3), ty))
532 },
533
534 "transpose" => {
536 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
537 Ok((format!("transpose({})", arg_code), ty))
538 },
539 "determinant" => {
540 let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
541 Ok((format!("determinant({})", arg_code), "float".to_string()))
542 },
543
544 "sample" => {
546 let (sampler, _) = emit_expr(ctx, &args[0].value)?;
547 let (coords, _) = emit_expr(ctx, &args[1].value)?;
548 Ok((format!("{}.Sample({}_sampler, {})", sampler, sampler, coords), "float4".to_string()))
549 },
550 "sample_lod" => {
551 let (sampler, _) = emit_expr(ctx, &args[0].value)?;
552 let (coords, _) = emit_expr(ctx, &args[1].value)?;
553 let (lod, _) = emit_expr(ctx, &args[2].value)?;
554 Ok((format!("{}.SampleLevel({}_sampler, {}, {})", sampler, sampler, coords, lod), "float4".to_string()))
555 },
556 "sample_grad" => {
557 let (sampler, _) = emit_expr(ctx, &args[0].value)?;
558 let (coords, _) = emit_expr(ctx, &args[1].value)?;
559 let (ddx, _) = emit_expr(ctx, &args[2].value)?;
560 let (ddy, _) = emit_expr(ctx, &args[3].value)?;
561 Ok((format!("{}.SampleGrad({}_sampler, {}, {}, {})", sampler, sampler, coords, ddx, ddy), "float4".to_string()))
562 },
563 "sample_bias" => {
564 let (sampler, _) = emit_expr(ctx, &args[0].value)?;
565 let (coords, _) = emit_expr(ctx, &args[1].value)?;
566 let (bias, _) = emit_expr(ctx, &args[2].value)?;
567 Ok((format!("{}.SampleBias({}_sampler, {}, {})", sampler, sampler, coords, bias), "float4".to_string()))
568 },
569 "sample_cmp" => {
570 let (sampler, _) = emit_expr(ctx, &args[0].value)?;
571 let (coords, _) = emit_expr(ctx, &args[1].value)?;
572 let (compare, _) = emit_expr(ctx, &args[2].value)?;
573 Ok((format!("{}.SampleCmp({}_sampler, {}, {})", sampler, sampler, coords, compare), "float".to_string()))
574 },
575 "load" => {
576 let (texture, _) = emit_expr(ctx, &args[0].value)?;
577 let (location, _) = emit_expr(ctx, &args[1].value)?;
578 Ok((format!("{}.Load({})", texture, location), "float4".to_string()))
579 },
580
581 "ddx" | "ddy" | "ddx_fine" | "ddy_fine" | "ddx_coarse" | "ddy_coarse" => {
583 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
584 Ok((format!("{}({})", name, arg_code), ty))
585 },
586 "fwidth" => {
587 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
588 Ok((format!("fwidth({})", arg_code), ty))
589 },
590
591 "countbits" | "firstbithigh" | "firstbitlow" | "reversebits" => {
593 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
594 Ok((format!("{}({})", name, arg_code), ty))
595 },
596
597 "all" | "any" => {
599 let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
600 Ok((format!("{}({})", name, arg_code), "bool".to_string()))
601 },
602
603 "noise" => {
605 let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
606 Ok((format!("frac(sin(dot({}, float2(12.9898, 78.233))) * 43758.5453)", arg_code), "float".to_string()))
607 },
608 "noise3d" => {
609 let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
610 Ok((format!("frac(sin(dot({}, float3(12.9898, 78.233, 37.719))) * 43758.5453)", arg_code), "float".to_string()))
611 },
612
613 "pack_half_2x16" => {
615 let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
616 Ok((format!("f32tof16({}).x | (f32tof16({}).y << 16)", arg_code, arg_code), "uint".to_string()))
617 },
618 "unpack_half_2x16" => {
619 let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
620 Ok((format!("float2(f16tof32({} & 0xFFFF), f16tof32({} >> 16))", arg_code, arg_code), "float2".to_string()))
621 },
622
623 "texture_size" => {
625 let (texture, _) = emit_expr(ctx, &args[0].value)?;
626 let (lod, _) = if args.len() > 1 {
627 emit_expr(ctx, &args[1].value)?
628 } else {
629 ("0".to_string(), "int".to_string())
630 };
631 Ok((format!("{}.GetDimensions({})", texture, lod), "int2".to_string()))
632 },
633 "texture_query_lod" => {
634 let (sampler, _) = emit_expr(ctx, &args[0].value)?;
635 let (coords, _) = emit_expr(ctx, &args[1].value)?;
636 Ok((format!("{}.CalculateLevelOfDetail({}_sampler, {})", sampler, sampler, coords), "float".to_string()))
637 },
638 "texture_gather" => {
639 let (texture, _) = emit_expr(ctx, &args[0].value)?;
640 let (coords, _) = emit_expr(ctx, &args[1].value)?;
641 let component = if args.len() > 2 {
642 if let Expr::Int(i, _) = &args[2].value {
643 *i as u32
644 } else {
645 0
646 }
647 } else {
648 0
649 };
650 Ok((format!("{}.Gather({}_sampler, {}, {})", texture, texture, coords, component), "float4".to_string()))
651 },
652
653 "rgb_to_hsv" => {
655 let (rgb, _) = emit_expr(ctx, &args[0].value)?;
656 let code = format!(
657 "({{ \
658 float3 _rgb = {}; \
659 float4 K = float4(0.0, -1.0/3.0, 2.0/3.0, -1.0); \
660 float4 p = lerp(float4(_rgb.bg, K.wz), float4(_rgb.gb, K.xy), step(_rgb.b, _rgb.g)); \
661 float4 q = lerp(float4(p.xyw, _rgb.r), float4(_rgb.r, p.yzx), step(p.x, _rgb.r)); \
662 float d = q.x - min(q.w, q.y); \
663 float e = 1.0e-10; \
664 float3(abs(q.z + (q.w - q.y) / (6.0 * d + e)), d / (q.x + e), q.x); \
665 }})", rgb
666 );
667 Ok((code, "float3".to_string()))
668 },
669 "hsv_to_rgb" => {
670 let (hsv, _) = emit_expr(ctx, &args[0].value)?;
671 let code = format!(
672 "({{ \
673 float3 _hsv = {}; \
674 float4 K = float4(1.0, 2.0/3.0, 1.0/3.0, 3.0); \
675 float3 p = abs(frac(_hsv.xxx + K.xyz) * 6.0 - K.www); \
676 lerp(K.xxx, saturate(p - K.xxx), _hsv.y) * _hsv.z; \
677 }})", hsv
678 );
679 Ok((code, "float3".to_string()))
680 },
681
682 "mat2" | "Mat2" => {
684 let mut arg_codes = Vec::new();
685 for arg in args {
686 let (code, _) = emit_expr(ctx, &arg.value)?;
687 arg_codes.push(code);
688 }
689 Ok((format!("float2x2({})", arg_codes.join(", ")), "float2x2".to_string()))
690 },
691 "mat3" | "Mat3" => {
692 let mut arg_codes = Vec::new();
693 for arg in args {
694 let (code, _) = emit_expr(ctx, &arg.value)?;
695 arg_codes.push(code);
696 }
697 Ok((format!("float3x3({})", arg_codes.join(", ")), "float3x3".to_string()))
698 },
699 "mat4" | "Mat4" => {
700 let mut arg_codes = Vec::new();
701 for arg in args {
702 let (code, _) = emit_expr(ctx, &arg.value)?;
703 arg_codes.push(code);
704 }
705 Ok((format!("float4x4({})", arg_codes.join(", ")), "float4x4".to_string()))
706 },
707
708 "modf" => {
710 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
711 Ok((format!("modf({}, _modf_int)", arg_code), ty))
712 },
713 "frexp" => {
714 let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
715 Ok((format!("frexp({}, _frexp_exp)", arg_code), ty))
716 },
717 "ldexp" => {
718 let (x, ty) = emit_expr(ctx, &args[0].value)?;
719 let (exp, _) = emit_expr(ctx, &args[1].value)?;
720 Ok((format!("ldexp({}, {})", x, exp), ty))
721 },
722
723 "flat" | "noperspective" | "centroid" => {
725 Err(KoreError::codegen(format!("{} is an interpolation modifier, not a function", name), crate::span::Span::new(0, 0)))
728 },
729
730 "atomic_add" | "atomic_sub" | "atomic_min" | "atomic_max" |
732 "atomic_and" | "atomic_or" | "atomic_xor" | "atomic_exchange" | "atomic_cas" => {
733 let hlsl_name = match name {
734 "atomic_add" => "InterlockedAdd",
735 "atomic_sub" => "InterlockedAdd", "atomic_min" => "InterlockedMin",
737 "atomic_max" => "InterlockedMax",
738 "atomic_and" => "InterlockedAnd",
739 "atomic_or" => "InterlockedOr",
740 "atomic_xor" => "InterlockedXor",
741 "atomic_exchange" => "InterlockedExchange",
742 "atomic_cas" => "InterlockedCompareExchange",
743 _ => name,
744 };
745
746 let mut arg_codes = Vec::new();
747 for arg in args {
748 let (code, _) = emit_expr(ctx, &arg.value)?;
749 arg_codes.push(code);
750 }
751
752 Ok((format!("{}({})", hlsl_name, arg_codes.join(", ")), "void".to_string()))
753 },
754
755 "wave_active_all_true" | "wave_active_any_true" | "wave_active_ballot" |
757 "wave_active_sum" | "wave_active_product" | "wave_active_min" | "wave_active_max" |
758 "wave_prefix_sum" | "wave_prefix_product" | "wave_read_lane_first" | "wave_read_lane_at" => {
759 let hlsl_name = match name {
760 "wave_active_all_true" => "WaveActiveAllTrue",
761 "wave_active_any_true" => "WaveActiveAnyTrue",
762 "wave_active_ballot" => "WaveActiveBallot",
763 "wave_active_sum" => "WaveActiveSum",
764 "wave_active_product" => "WaveActiveProduct",
765 "wave_active_min" => "WaveActiveMin",
766 "wave_active_max" => "WaveActiveMax",
767 "wave_prefix_sum" => "WavePrefixSum",
768 "wave_prefix_product" => "WavePrefixProduct",
769 "wave_read_lane_first" => "WaveReadLaneFirst",
770 "wave_read_lane_at" => "WaveReadLaneAt",
771 _ => name,
772 };
773
774 let mut arg_codes = Vec::new();
775 for arg in args {
776 let (code, _) = emit_expr(ctx, &arg.value)?;
777 arg_codes.push(code);
778 }
779
780 let return_type = if name.contains("ballot") {
781 "uint4".to_string()
782 } else if name.contains("all_true") || name.contains("any_true") {
783 "bool".to_string()
784 } else if !args.is_empty() {
785 emit_expr(ctx, &args[0].value)?.1
786 } else {
787 "float".to_string()
788 };
789
790 Ok((format!("{}({})", hlsl_name, arg_codes.join(", ")), return_type))
791 },
792
793 _ => Err(KoreError::codegen(format!("Unknown function: {}", name), crate::span::Span::new(0, 0))),
794 }
795}
796
797fn map_type_to_hlsl(ty: &Type) -> String {
798 match ty {
799 Type::Named { name, .. } => match name.as_str() {
800 "Float" | "f32" => "float".to_string(),
801 "Int" | "i32" => "int".to_string(),
802 "UInt" | "u32" => "uint".to_string(),
803 "Bool" => "bool".to_string(),
804 "Vec2" => "float2".to_string(),
805 "Vec3" => "float3".to_string(),
806 "Vec4" => "float4".to_string(),
807 "IVec2" => "int2".to_string(),
808 "IVec3" => "int3".to_string(),
809 "IVec4" => "int4".to_string(),
810 "UVec2" => "uint2".to_string(),
811 "UVec3" => "uint3".to_string(),
812 "UVec4" => "uint4".to_string(),
813 "Mat4" => "float4x4".to_string(),
814 "Mat3" => "float3x3".to_string(),
815 "Mat2" => "float2x2".to_string(),
816 "Sampler2D" => "Texture2D".to_string(),
817 "Sampler3D" => "Texture3D".to_string(),
818 "SamplerCube" => "TextureCube".to_string(),
819 "Sampler2DArray" => "Texture2DArray".to_string(),
820 "SamplerCubeArray" => "TextureCubeArray".to_string(),
821 "Sampler2DMS" => "Texture2DMS".to_string(),
822 "RWTexture2D" => "RWTexture2D<float4>".to_string(),
823 "RWTexture3D" => "RWTexture3D<float4>".to_string(),
824 "Buffer" => "Buffer<float4>".to_string(),
825 "RWBuffer" => "RWBuffer<float4>".to_string(),
826 "StructuredBuffer" => "StructuredBuffer<float4>".to_string(),
827 "RWStructuredBuffer" => "RWStructuredBuffer<float4>".to_string(),
828 "ByteAddressBuffer" => "ByteAddressBuffer".to_string(),
829 "RWByteAddressBuffer" => "RWByteAddressBuffer".to_string(),
830 "Void" => "void".to_string(),
831 _ => "float4".to_string(),
832 },
833 Type::Array(element, _size, _span) => {
834 let elem_ty = map_type_to_hlsl(element);
835 format!("{}[{}]", elem_ty, _size)
837 },
838 _ => "float4".to_string(),
839 }
840}
841
/// Guesses the HLSL type of a field access by treating the field name as a swizzle.
///
/// Names of 2, 3 or 4 bytes give `float2`, `float3` or `float4`. Any other length,
/// including 0 and 1, gives `float`.
fn infer_swizzle_type(swizzle: &str) -> String {
    // Indexed by swizzle length; lengths past the end of the table fall back to `float`.
    const BY_WIDTH: [&str; 5] = ["float", "float", "float2", "float3", "float4"];
    BY_WIDTH
        .get(swizzle.len())
        .copied()
        .unwrap_or("float")
        .to_owned()
}