Skip to main content

kore/codegen/
hlsl.rs

1//! HLSL Code Generation - Direct KORE to HLSL/USF
2//! Bypasses SPIR-V for maximum control and code generation power
3
4use crate::types::{TypedProgram, TypedItem, TypedShader};
5use crate::error::{KoreResult, KoreError};
6use crate::ast::{Type, ShaderStage, Expr, Stmt, Block, BinaryOp, Pattern};
7use std::collections::HashMap;
8
9pub fn generate(program: &TypedProgram) -> KoreResult<String> {
10    let mut output = String::new();
11    
12    // HLSL header
13    output.push_str("// Generated by KORE Compiler\n");
14    output.push_str("// Direct HLSL codegen - SUPERCHARGED\n\n");
15    
16    for item in &program.items {
17        if let TypedItem::Shader(shader) = item {
18            output.push_str(&emit_shader(shader)?);
19        }
20    }
21    
22    Ok(output)
23}
24
/// Mutable state threaded through HLSL emission for a single shader.
struct HLSLContext {
    /// Maps a KORE identifier to the HLSL expression emitted in its place
    /// (e.g. a stage input `position` is stored as `input.position`; a local
    /// is stored under its own name).
    vars: HashMap<String, String>,
    /// Current nesting depth; each level renders as four spaces.
    indent_level: usize,
    /// Uniforms declared by the shader, as `(name, HLSL type, binding)`.
    uniform_bindings: Vec<(String, String, u32)>,
}

impl HLSLContext {
    /// Creates an empty context at indentation depth zero.
    fn new() -> Self {
        Self {
            vars: HashMap::new(),
            indent_level: 0,
            uniform_bindings: Vec::new(),
        }
    }

    /// Returns the whitespace prefix for the current depth (four spaces per level).
    fn indent(&self) -> String {
        "    ".repeat(self.indent_level)
    }

    /// Enters one nesting level.
    fn push_indent(&mut self) {
        self.indent_level += 1;
    }

    /// Leaves one nesting level; a no-op when already at depth zero.
    fn pop_indent(&mut self) {
        self.indent_level = self.indent_level.saturating_sub(1);
    }
}
56
57fn emit_shader(shader: &TypedShader) -> KoreResult<String> {
58    let mut output = String::new();
59    let mut ctx = HLSLContext::new();
60    
61    // Collect uniforms
62    for uniform in &shader.ast.uniforms {
63        let hlsl_type = map_type_to_hlsl(&uniform.ty);
64        ctx.uniform_bindings.push((uniform.name.clone(), hlsl_type, uniform.binding));
65    }
66    
67    // Generate constant buffers for non-texture uniforms
68    let mut cbuffer_uniforms = Vec::new();
69    let mut texture_uniforms = Vec::new();
70    let mut buffer_uniforms = Vec::new();
71    
72    for (name, ty, binding) in &ctx.uniform_bindings {
73        if ty.contains("Texture") || ty.contains("Sampler") {
74            texture_uniforms.push((name.clone(), ty.clone(), *binding));
75        } else if ty.contains("Buffer") || ty.contains("RWBuffer") || ty.contains("StructuredBuffer") {
76            buffer_uniforms.push((name.clone(), ty.clone(), *binding));
77        } else {
78            cbuffer_uniforms.push((name.clone(), ty.clone(), *binding));
79        }
80    }
81    
82    // Emit constant buffer
83    if !cbuffer_uniforms.is_empty() {
84        output.push_str("cbuffer ShaderParams : register(b0)\n{\n");
85        for (name, ty, _) in &cbuffer_uniforms {
86            output.push_str(&format!("    {} {};\n", ty, name));
87        }
88        output.push_str("};\n\n");
89    }
90    
91    // Emit texture declarations
92    for (name, ty, binding) in &texture_uniforms {
93        output.push_str(&format!("{} {} : register(t{});\n", ty, name, binding));
94        output.push_str(&format!("SamplerState {}_sampler : register(s{});\n", name, binding));
95    }
96    if !texture_uniforms.is_empty() {
97        output.push_str("\n");
98    }
99    
100    // Emit buffer declarations
101    for (name, ty, binding) in &buffer_uniforms {
102        output.push_str(&format!("{} {} : register(u{});\n", ty, name, binding));
103    }
104    if !buffer_uniforms.is_empty() {
105        output.push_str("\n");
106    }
107    
108    match shader.ast.stage {
109        ShaderStage::Compute => {
110            // Compute shader - different structure
111            output.push_str("[numthreads(8, 8, 1)]\n");
112            output.push_str("void CSMain(uint3 dispatchThreadID : SV_DispatchThreadID,\n");
113            output.push_str("            uint3 groupThreadID : SV_GroupThreadID,\n");
114            output.push_str("            uint3 groupID : SV_GroupID,\n");
115            output.push_str("            uint groupIndex : SV_GroupIndex)\n{\n");
116            ctx.push_indent();
117            
118            // Add compute shader built-ins to context
119            ctx.vars.insert("dispatch_thread_id".to_string(), "dispatchThreadID".to_string());
120            ctx.vars.insert("group_thread_id".to_string(), "groupThreadID".to_string());
121            ctx.vars.insert("group_id".to_string(), "groupID".to_string());
122            ctx.vars.insert("group_index".to_string(), "groupIndex".to_string());
123            
124            // Emit function body
125            let body_code = emit_block(&mut ctx, &shader.ast.body)?;
126            output.push_str(&body_code);
127            
128            ctx.pop_indent();
129            output.push_str("}\n");
130        },
131        ShaderStage::Vertex => {
132            // Generate input struct
133            output.push_str("struct VSInput\n{\n");
134            for (i, param) in shader.ast.inputs.iter().enumerate() {
135                let hlsl_type = map_type_to_hlsl(&param.ty);
136                let semantic = match param.name.as_str() {
137                    "position" => "POSITION",
138                    "normal" => "NORMAL",
139                    "tangent" => "TANGENT",
140                    "color" => "COLOR",
141                    _ => "TEXCOORD",
142                };
143                output.push_str(&format!("    {} {} : {}{};\n", 
144                    hlsl_type, param.name, semantic, 
145                    if semantic == "TEXCOORD" { i.to_string() } else { "".to_string() }
146                ));
147            }
148            output.push_str("};\n\n");
149            
150            // Generate output struct
151            output.push_str("struct VSOutput\n{\n");
152            output.push_str("    float4 position : SV_Position;\n");
153            // TODO: Add other vertex outputs based on shader.ast.outputs
154            output.push_str("};\n\n");
155            
156            // Generate main function
157            output.push_str("VSOutput VSMain(VSInput input)\n{\n");
158            ctx.push_indent();
159            
160            // Add input variables to context
161            for param in &shader.ast.inputs {
162                ctx.vars.insert(param.name.clone(), format!("input.{}", param.name));
163            }
164            
165            // Emit function body
166            let body_code = emit_block(&mut ctx, &shader.ast.body)?;
167            output.push_str(&body_code);
168            
169            ctx.pop_indent();
170            output.push_str("}\n");
171        },
172        ShaderStage::Fragment => {
173            // Generate input struct
174            output.push_str("struct VSInput\n{\n");
175            for (i, param) in shader.ast.inputs.iter().enumerate() {
176                let hlsl_type = map_type_to_hlsl(&param.ty);
177                output.push_str(&format!("    {} {} : TEXCOORD{};\n", hlsl_type, param.name, i));
178            }
179            output.push_str("};\n\n");
180            
181            // Generate output struct
182            output.push_str("struct PSOutput\n{\n");
183            let out_type = map_type_to_hlsl(&shader.ast.outputs);
184            output.push_str(&format!("    {} color : SV_Target0;\n", out_type));
185            // TODO: Support multiple render targets
186            output.push_str("};\n\n");
187            
188            // Generate main function
189            output.push_str("PSOutput PSMain(VSInput input)\n{\n");
190            ctx.push_indent();
191            
192            // Add input variables to context
193            for param in &shader.ast.inputs {
194                ctx.vars.insert(param.name.clone(), format!("input.{}", param.name));
195            }
196            
197            // Emit function body
198            let body_code = emit_block(&mut ctx, &shader.ast.body)?;
199            output.push_str(&body_code);
200            
201            ctx.pop_indent();
202            output.push_str("}\n");
203        },
204    }
205    
206    Ok(output)
207}
208
209fn emit_block(ctx: &mut HLSLContext, block: &Block) -> KoreResult<String> {
210    let mut output = String::new();
211    
212    for stmt in &block.stmts {
213        output.push_str(&emit_stmt(ctx, stmt)?);
214    }
215    
216    Ok(output)
217}
218
219fn emit_stmt(ctx: &mut HLSLContext, stmt: &Stmt) -> KoreResult<String> {
220    let mut output = String::new();
221    
222    match stmt {
223        Stmt::Let { pattern, value, .. } => {
224            if let Some(value) = value {
225                if let Pattern::Binding { name, .. } = pattern {
226                    let (expr_code, expr_type) = emit_expr(ctx, value)?;
227                    output.push_str(&format!("{}{} {} = {};\n", ctx.indent(), expr_type, name, expr_code));
228                    ctx.vars.insert(name.clone(), name.clone());
229                }
230            }
231        },
232        Stmt::Return(Some(expr), _) => {
233            let (expr_code, _) = emit_expr(ctx, expr)?;
234            output.push_str(&format!("{}PSOutput _result;\n", ctx.indent()));
235            output.push_str(&format!("{}_result.color = {};\n", ctx.indent(), expr_code));
236            output.push_str(&format!("{}return _result;\n", ctx.indent()));
237        },
238        Stmt::Return(None, _) => {
239            output.push_str(&format!("{}return;\n", ctx.indent()));
240        },
241        Stmt::Expr(expr) => {
242            let (expr_code, _) = emit_expr(ctx, expr)?;
243            output.push_str(&format!("{}{};\n", ctx.indent(), expr_code));
244        },
245        // If expressions are handled in Expr::If, not Stmt::If
246        Stmt::While { condition, body, .. } => {
247            let (cond_code, _) = emit_expr(ctx, condition)?;
248            output.push_str(&format!("{}while ({})\n", ctx.indent(), cond_code));
249            output.push_str(&format!("{}{{\n", ctx.indent()));
250            ctx.push_indent();
251            output.push_str(&emit_block(ctx, body)?);
252            ctx.pop_indent();
253            output.push_str(&format!("{}}}\n", ctx.indent()));
254        },
255        Stmt::For { binding, iter: _, body, .. } => {
256            // Simple for loop support - assumes range-like iteration
257            if let Pattern::Binding { name, .. } = binding {
258                output.push_str(&format!("{}for (int {} = 0; {} < 10; {}++)\n", 
259                    ctx.indent(), name, name, name));
260                output.push_str(&format!("{}{{\n", ctx.indent()));
261                ctx.push_indent();
262                ctx.vars.insert(name.clone(), name.clone());
263                output.push_str(&emit_block(ctx, body)?);
264                ctx.pop_indent();
265                output.push_str(&format!("{}}}\n", ctx.indent()));
266            }
267        },
268        Stmt::Break(_, _) => {
269            output.push_str(&format!("{}break;\n", ctx.indent()));
270        },
271        Stmt::Continue(_) => {
272            output.push_str(&format!("{}continue;\n", ctx.indent()));
273        },
274        _ => {}
275    }
276    
277    Ok(output)
278}
279
280fn emit_expr(ctx: &mut HLSLContext, expr: &Expr) -> KoreResult<(String, String)> {
281    match expr {
282        Expr::Ident(name, _) => {
283            if let Some(mapped) = ctx.vars.get(name) {
284                Ok((mapped.clone(), "float4".to_string()))
285            } else {
286                Ok((name.clone(), "float4".to_string()))
287            }
288        },
289        Expr::Float(f, _) => {
290            Ok((format!("{:.6}", f), "float".to_string()))
291        },
292        Expr::Int(i, _) => {
293            Ok((format!("{}", i), "int".to_string()))
294        },
295        Expr::Bool(b, _) => {
296            Ok((format!("{}", b), "bool".to_string()))
297        },
298        Expr::String(s, _) => {
299            // HLSL doesn't have strings, but we can use this for debug/comments
300            Ok((format!("\"{}\"", s), "string".to_string()))
301        },
302        Expr::Binary { left, op, right, .. } => {
303            let (left_code, left_ty) = emit_expr(ctx, left)?;
304            let (right_code, _) = emit_expr(ctx, right)?;
305            
306            let op_str = match op {
307                BinaryOp::Add => "+",
308                BinaryOp::Sub => "-",
309                BinaryOp::Mul => "*",
310                BinaryOp::Div => "/",
311                BinaryOp::Mod => "%",
312                BinaryOp::Eq => "==",
313                BinaryOp::Ne => "!=",
314                BinaryOp::Lt => "<",
315                BinaryOp::Le => "<=",
316                BinaryOp::Gt => ">",
317                BinaryOp::Ge => ">=",
318                BinaryOp::And => "&&",
319                BinaryOp::Or => "||",
320                BinaryOp::BitAnd => "&",
321                BinaryOp::BitOr => "|",
322                BinaryOp::BitXor => "^",
323                BinaryOp::Shl => "<<",
324                BinaryOp::Shr => ">>",
325                _ => return Err(KoreError::codegen("Unsupported binary op", expr.span())),
326            };
327            
328            // Preserve type from left operand for most ops
329            let result_ty = match op {
330                BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Lt | BinaryOp::Le | 
331                BinaryOp::Gt | BinaryOp::Ge | BinaryOp::And | BinaryOp::Or => "bool".to_string(),
332                _ => left_ty,
333            };
334            
335            Ok((format!("({} {} {})", left_code, op_str, right_code), result_ty))
336        },
337        Expr::Unary { op, operand, .. } => {
338            let (operand_code, ty) = emit_expr(ctx, operand)?;
339            let op_str = match op {
340                crate::ast::UnaryOp::Neg => "-",
341                crate::ast::UnaryOp::Not => "!",
342                crate::ast::UnaryOp::BitNot => "~",
343                crate::ast::UnaryOp::Ref | crate::ast::UnaryOp::RefMut => {
344                    // HLSL doesn't have explicit references, just pass through
345                    return Ok((operand_code, ty));
346                },
347                crate::ast::UnaryOp::Deref => {
348                    // HLSL doesn't have explicit dereferencing, just pass through
349                    return Ok((operand_code, ty));
350                },
351            };
352            Ok((format!("({}{})", op_str, operand_code), ty))
353        },
354        Expr::Call { callee, args, .. } => {
355            if let Expr::Ident(name, _) = &**callee {
356                emit_function_call(ctx, name, args)
357            } else {
358                Err(KoreError::codegen("Complex callee not supported", expr.span()))
359            }
360        },
361        Expr::Field { object, field, .. } => {
362            let (obj_code, _) = emit_expr(ctx, object)?;
363            
364            // HLSL supports direct swizzling - just pass it through!
365            // Supports: .x .y .z .w .r .g .b .a
366            // And combinations: .xy .xyz .rgba .xyzw .bgra etc.
367            Ok((format!("{}.{}", obj_code, field), infer_swizzle_type(field)))
368        },
369        Expr::Index { object, index, .. } => {
370            let (obj_code, obj_ty) = emit_expr(ctx, object)?;
371            let (idx_code, _) = emit_expr(ctx, index)?;
372            // Array indexing preserves element type
373            let elem_ty = if obj_ty.starts_with("float") {
374                "float".to_string()
375            } else {
376                obj_ty
377            };
378            Ok((format!("{}[{}]", obj_code, idx_code), elem_ty))
379        },
380        Expr::If { condition, then_branch, else_branch, .. } => {
381            // HLSL ternary operator for simple if expressions
382            let (cond_code, _) = emit_expr(ctx, condition)?;
383            
384            // For now, emit as ternary if we can extract simple expressions
385            // TODO: Handle complex if expressions with multiple statements
386            if then_branch.stmts.len() == 1 && else_branch.is_some() {
387                if let Stmt::Expr(then_expr) = &then_branch.stmts[0] {
388                    let (then_code, then_ty) = emit_expr(ctx, then_expr)?;
389                    
390                    if let Some(crate::ast::ElseBranch::Else(else_block)) = else_branch.as_ref().map(|b| b.as_ref()) {
391                        if else_block.stmts.len() == 1 {
392                            if let Stmt::Expr(else_expr) = &else_block.stmts[0] {
393                                let (else_code, _) = emit_expr(ctx, else_expr)?;
394                                return Ok((format!("({} ? {} : {})", cond_code, then_code, else_code), then_ty));
395                            }
396                        }
397                    }
398                }
399            }
400            
401            // Fallback: can't emit as expression
402            Err(KoreError::codegen("Complex if expressions not yet supported in HLSL backend", expr.span()))
403        },
404        Expr::Paren(inner, _) => {
405            // Parenthesized expression - just emit the inner expression with parens
406            let (inner_code, ty) = emit_expr(ctx, inner)?;
407            Ok((format!("({})", inner_code), ty))
408        },
409        _ => Err(KoreError::codegen("Unsupported expression", expr.span())),
410    }
411}
412
413fn emit_function_call(ctx: &mut HLSLContext, name: &str, args: &[crate::ast::CallArg]) -> KoreResult<(String, String)> {
414    match name {
415        // Vector constructors
416        "vec2" | "Vec2" => {
417            let mut arg_codes = Vec::new();
418            for arg in args {
419                let (code, _) = emit_expr(ctx, &arg.value)?;
420                arg_codes.push(code);
421            }
422            Ok((format!("float2({})", arg_codes.join(", ")), "float2".to_string()))
423        },
424        "vec3" | "Vec3" => {
425            let mut arg_codes = Vec::new();
426            for arg in args {
427                let (code, _) = emit_expr(ctx, &arg.value)?;
428                arg_codes.push(code);
429            }
430            Ok((format!("float3({})", arg_codes.join(", ")), "float3".to_string()))
431        },
432        "vec4" | "Vec4" => {
433            let mut arg_codes = Vec::new();
434            for arg in args {
435                let (code, _) = emit_expr(ctx, &arg.value)?;
436                arg_codes.push(code);
437            }
438            Ok((format!("float4({})", arg_codes.join(", ")), "float4".to_string()))
439        },
440        
441        // Math functions - Trigonometry
442        "sin" | "cos" | "tan" => {
443            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
444            Ok((format!("{}({})", name, arg_code), ty))
445        },
446        "asin" | "acos" | "atan" => {
447            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
448            Ok((format!("{}({})", name, arg_code), ty))
449        },
450        "atan2" => {
451            let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
452            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
453            Ok((format!("atan2({}, {})", arg1, arg2), ty))
454        },
455        
456        // Math functions - Common
457        "abs" | "floor" | "ceil" | "round" | "trunc" | "fract" => {
458            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
459            let hlsl_name = if name == "fract" { "frac" } else { name };
460            Ok((format!("{}({})", hlsl_name, arg_code), ty))
461        },
462        "sqrt" | "rsqrt" | "exp" | "exp2" | "log" | "log2" | "log10" => {
463            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
464            Ok((format!("{}({})", name, arg_code), ty))
465        },
466        "sign" | "saturate" => {
467            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
468            Ok((format!("{}({})", name, arg_code), ty))
469        },
470        
471        // Math functions - Two arguments
472        "pow" | "min" | "max" | "fmod" | "step" => {
473            let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
474            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
475            Ok((format!("{}({}, {})", name, arg1, arg2), ty))
476        },
477        
478        // Math functions - Three arguments
479        "clamp" | "smoothstep" | "mad" => {
480            let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
481            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
482            let (arg3, _) = emit_expr(ctx, &args[2].value)?;
483            Ok((format!("{}({}, {}, {})", name, arg1, arg2, arg3), ty))
484        },
485        "mix" | "lerp" => {
486            let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
487            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
488            let (arg3, _) = emit_expr(ctx, &args[2].value)?;
489            Ok((format!("lerp({}, {}, {})", arg1, arg2, arg3), ty))
490        },
491        
492        // Vector functions
493        "length" => {
494            let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
495            Ok((format!("length({})", arg_code), "float".to_string()))
496        },
497        "distance" => {
498            let (arg1, _) = emit_expr(ctx, &args[0].value)?;
499            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
500            Ok((format!("distance({}, {})", arg1, arg2), "float".to_string()))
501        },
502        "normalize" => {
503            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
504            Ok((format!("normalize({})", arg_code), ty))
505        },
506        "dot" => {
507            let (arg1, _) = emit_expr(ctx, &args[0].value)?;
508            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
509            Ok((format!("dot({}, {})", arg1, arg2), "float".to_string()))
510        },
511        "cross" => {
512            let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
513            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
514            Ok((format!("cross({}, {})", arg1, arg2), ty))
515        },
516        "reflect" => {
517            let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
518            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
519            Ok((format!("reflect({}, {})", arg1, arg2), ty))
520        },
521        "refract" => {
522            let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
523            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
524            let (arg3, _) = emit_expr(ctx, &args[2].value)?;
525            Ok((format!("refract({}, {}, {})", arg1, arg2, arg3), ty))
526        },
527        "faceforward" => {
528            let (arg1, ty) = emit_expr(ctx, &args[0].value)?;
529            let (arg2, _) = emit_expr(ctx, &args[1].value)?;
530            let (arg3, _) = emit_expr(ctx, &args[2].value)?;
531            Ok((format!("faceforward({}, {}, {})", arg1, arg2, arg3), ty))
532        },
533        
534        // Matrix functions
535        "transpose" => {
536            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
537            Ok((format!("transpose({})", arg_code), ty))
538        },
539        "determinant" => {
540            let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
541            Ok((format!("determinant({})", arg_code), "float".to_string()))
542        },
543        
544        // Texture sampling
545        "sample" => {
546            let (sampler, _) = emit_expr(ctx, &args[0].value)?;
547            let (coords, _) = emit_expr(ctx, &args[1].value)?;
548            Ok((format!("{}.Sample({}_sampler, {})", sampler, sampler, coords), "float4".to_string()))
549        },
550        "sample_lod" => {
551            let (sampler, _) = emit_expr(ctx, &args[0].value)?;
552            let (coords, _) = emit_expr(ctx, &args[1].value)?;
553            let (lod, _) = emit_expr(ctx, &args[2].value)?;
554            Ok((format!("{}.SampleLevel({}_sampler, {}, {})", sampler, sampler, coords, lod), "float4".to_string()))
555        },
556        "sample_grad" => {
557            let (sampler, _) = emit_expr(ctx, &args[0].value)?;
558            let (coords, _) = emit_expr(ctx, &args[1].value)?;
559            let (ddx, _) = emit_expr(ctx, &args[2].value)?;
560            let (ddy, _) = emit_expr(ctx, &args[3].value)?;
561            Ok((format!("{}.SampleGrad({}_sampler, {}, {}, {})", sampler, sampler, coords, ddx, ddy), "float4".to_string()))
562        },
563        "sample_bias" => {
564            let (sampler, _) = emit_expr(ctx, &args[0].value)?;
565            let (coords, _) = emit_expr(ctx, &args[1].value)?;
566            let (bias, _) = emit_expr(ctx, &args[2].value)?;
567            Ok((format!("{}.SampleBias({}_sampler, {}, {})", sampler, sampler, coords, bias), "float4".to_string()))
568        },
569        "sample_cmp" => {
570            let (sampler, _) = emit_expr(ctx, &args[0].value)?;
571            let (coords, _) = emit_expr(ctx, &args[1].value)?;
572            let (compare, _) = emit_expr(ctx, &args[2].value)?;
573            Ok((format!("{}.SampleCmp({}_sampler, {}, {})", sampler, sampler, coords, compare), "float".to_string()))
574        },
575        "load" => {
576            let (texture, _) = emit_expr(ctx, &args[0].value)?;
577            let (location, _) = emit_expr(ctx, &args[1].value)?;
578            Ok((format!("{}.Load({})", texture, location), "float4".to_string()))
579        },
580        
581        // Derivative functions
582        "ddx" | "ddy" | "ddx_fine" | "ddy_fine" | "ddx_coarse" | "ddy_coarse" => {
583            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
584            Ok((format!("{}({})", name, arg_code), ty))
585        },
586        "fwidth" => {
587            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
588            Ok((format!("fwidth({})", arg_code), ty))
589        },
590        
591        // Bit operations
592        "countbits" | "firstbithigh" | "firstbitlow" | "reversebits" => {
593            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
594            Ok((format!("{}({})", name, arg_code), ty))
595        },
596        
597        // Interpolation
598        "all" | "any" => {
599            let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
600            Ok((format!("{}({})", name, arg_code), "bool".to_string()))
601        },
602        
603        // Noise functions (custom implementations)
604        "noise" => {
605            let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
606            Ok((format!("frac(sin(dot({}, float2(12.9898, 78.233))) * 43758.5453)", arg_code), "float".to_string()))
607        },
608        "noise3d" => {
609            let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
610            Ok((format!("frac(sin(dot({}, float3(12.9898, 78.233, 37.719))) * 43758.5453)", arg_code), "float".to_string()))
611        },
612        
613        // Packing/Unpacking functions
614        "pack_half_2x16" => {
615            let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
616            Ok((format!("f32tof16({}).x | (f32tof16({}).y << 16)", arg_code, arg_code), "uint".to_string()))
617        },
618        "unpack_half_2x16" => {
619            let (arg_code, _) = emit_expr(ctx, &args[0].value)?;
620            Ok((format!("float2(f16tof32({} & 0xFFFF), f16tof32({} >> 16))", arg_code, arg_code), "float2".to_string()))
621        },
622        
623        // Advanced texture operations
624        "texture_size" => {
625            let (texture, _) = emit_expr(ctx, &args[0].value)?;
626            let (lod, _) = if args.len() > 1 {
627                emit_expr(ctx, &args[1].value)?
628            } else {
629                ("0".to_string(), "int".to_string())
630            };
631            Ok((format!("{}.GetDimensions({})", texture, lod), "int2".to_string()))
632        },
633        "texture_query_lod" => {
634            let (sampler, _) = emit_expr(ctx, &args[0].value)?;
635            let (coords, _) = emit_expr(ctx, &args[1].value)?;
636            Ok((format!("{}.CalculateLevelOfDetail({}_sampler, {})", sampler, sampler, coords), "float".to_string()))
637        },
638        "texture_gather" => {
639            let (texture, _) = emit_expr(ctx, &args[0].value)?;
640            let (coords, _) = emit_expr(ctx, &args[1].value)?;
641            let component = if args.len() > 2 {
642                if let Expr::Int(i, _) = &args[2].value {
643                    *i as u32
644                } else {
645                    0
646                }
647            } else {
648                0
649            };
650            Ok((format!("{}.Gather({}_sampler, {}, {})", texture, texture, coords, component), "float4".to_string()))
651        },
652        
653        // Color space conversions
654        "rgb_to_hsv" => {
655            let (rgb, _) = emit_expr(ctx, &args[0].value)?;
656            let code = format!(
657                "({{ \
658                    float3 _rgb = {}; \
659                    float4 K = float4(0.0, -1.0/3.0, 2.0/3.0, -1.0); \
660                    float4 p = lerp(float4(_rgb.bg, K.wz), float4(_rgb.gb, K.xy), step(_rgb.b, _rgb.g)); \
661                    float4 q = lerp(float4(p.xyw, _rgb.r), float4(_rgb.r, p.yzx), step(p.x, _rgb.r)); \
662                    float d = q.x - min(q.w, q.y); \
663                    float e = 1.0e-10; \
664                    float3(abs(q.z + (q.w - q.y) / (6.0 * d + e)), d / (q.x + e), q.x); \
665                }})", rgb
666            );
667            Ok((code, "float3".to_string()))
668        },
669        "hsv_to_rgb" => {
670            let (hsv, _) = emit_expr(ctx, &args[0].value)?;
671            let code = format!(
672                "({{ \
673                    float3 _hsv = {}; \
674                    float4 K = float4(1.0, 2.0/3.0, 1.0/3.0, 3.0); \
675                    float3 p = abs(frac(_hsv.xxx + K.xyz) * 6.0 - K.www); \
676                    lerp(K.xxx, saturate(p - K.xxx), _hsv.y) * _hsv.z; \
677                }})", hsv
678            );
679            Ok((code, "float3".to_string()))
680        },
681        
682        // Matrix construction
683        "mat2" | "Mat2" => {
684            let mut arg_codes = Vec::new();
685            for arg in args {
686                let (code, _) = emit_expr(ctx, &arg.value)?;
687                arg_codes.push(code);
688            }
689            Ok((format!("float2x2({})", arg_codes.join(", ")), "float2x2".to_string()))
690        },
691        "mat3" | "Mat3" => {
692            let mut arg_codes = Vec::new();
693            for arg in args {
694                let (code, _) = emit_expr(ctx, &arg.value)?;
695                arg_codes.push(code);
696            }
697            Ok((format!("float3x3({})", arg_codes.join(", ")), "float3x3".to_string()))
698        },
699        "mat4" | "Mat4" => {
700            let mut arg_codes = Vec::new();
701            for arg in args {
702                let (code, _) = emit_expr(ctx, &arg.value)?;
703                arg_codes.push(code);
704            }
705            Ok((format!("float4x4({})", arg_codes.join(", ")), "float4x4".to_string()))
706        },
707        
708        // Advanced math
709        "modf" => {
710            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
711            Ok((format!("modf({}, _modf_int)", arg_code), ty))
712        },
713        "frexp" => {
714            let (arg_code, ty) = emit_expr(ctx, &args[0].value)?;
715            Ok((format!("frexp({}, _frexp_exp)", arg_code), ty))
716        },
717        "ldexp" => {
718            let (x, ty) = emit_expr(ctx, &args[0].value)?;
719            let (exp, _) = emit_expr(ctx, &args[1].value)?;
720            Ok((format!("ldexp({}, {})", x, exp), ty))
721        },
722        
723        // Interpolation attributes (for vertex shader outputs)
724        "flat" | "noperspective" | "centroid" => {
725            // These are interpolation modifiers, not functions
726            // They would be handled in struct field declarations
727            Err(KoreError::codegen(format!("{} is an interpolation modifier, not a function", name), crate::span::Span::new(0, 0)))
728        },
729        
730        // Atomic operations (for compute shaders)
731        "atomic_add" | "atomic_sub" | "atomic_min" | "atomic_max" | 
732        "atomic_and" | "atomic_or" | "atomic_xor" | "atomic_exchange" | "atomic_cas" => {
733            let hlsl_name = match name {
734                "atomic_add" => "InterlockedAdd",
735                "atomic_sub" => "InterlockedAdd", // with negated value
736                "atomic_min" => "InterlockedMin",
737                "atomic_max" => "InterlockedMax",
738                "atomic_and" => "InterlockedAnd",
739                "atomic_or" => "InterlockedOr",
740                "atomic_xor" => "InterlockedXor",
741                "atomic_exchange" => "InterlockedExchange",
742                "atomic_cas" => "InterlockedCompareExchange",
743                _ => name,
744            };
745            
746            let mut arg_codes = Vec::new();
747            for arg in args {
748                let (code, _) = emit_expr(ctx, &arg.value)?;
749                arg_codes.push(code);
750            }
751            
752            Ok((format!("{}({})", hlsl_name, arg_codes.join(", ")), "void".to_string()))
753        },
754        
755        // Wave intrinsics (shader model 6.0+)
756        "wave_active_all_true" | "wave_active_any_true" | "wave_active_ballot" |
757        "wave_active_sum" | "wave_active_product" | "wave_active_min" | "wave_active_max" |
758        "wave_prefix_sum" | "wave_prefix_product" | "wave_read_lane_first" | "wave_read_lane_at" => {
759            let hlsl_name = match name {
760                "wave_active_all_true" => "WaveActiveAllTrue",
761                "wave_active_any_true" => "WaveActiveAnyTrue",
762                "wave_active_ballot" => "WaveActiveBallot",
763                "wave_active_sum" => "WaveActiveSum",
764                "wave_active_product" => "WaveActiveProduct",
765                "wave_active_min" => "WaveActiveMin",
766                "wave_active_max" => "WaveActiveMax",
767                "wave_prefix_sum" => "WavePrefixSum",
768                "wave_prefix_product" => "WavePrefixProduct",
769                "wave_read_lane_first" => "WaveReadLaneFirst",
770                "wave_read_lane_at" => "WaveReadLaneAt",
771                _ => name,
772            };
773            
774            let mut arg_codes = Vec::new();
775            for arg in args {
776                let (code, _) = emit_expr(ctx, &arg.value)?;
777                arg_codes.push(code);
778            }
779            
780            let return_type = if name.contains("ballot") {
781                "uint4".to_string()
782            } else if name.contains("all_true") || name.contains("any_true") {
783                "bool".to_string()
784            } else if !args.is_empty() {
785                emit_expr(ctx, &args[0].value)?.1
786            } else {
787                "float".to_string()
788            };
789            
790            Ok((format!("{}({})", hlsl_name, arg_codes.join(", ")), return_type))
791        },
792        
793        _ => Err(KoreError::codegen(format!("Unknown function: {}", name), crate::span::Span::new(0, 0))),
794    }
795}
796
797fn map_type_to_hlsl(ty: &Type) -> String {
798    match ty {
799        Type::Named { name, .. } => match name.as_str() {
800            "Float" | "f32" => "float".to_string(),
801            "Int" | "i32" => "int".to_string(),
802            "UInt" | "u32" => "uint".to_string(),
803            "Bool" => "bool".to_string(),
804            "Vec2" => "float2".to_string(),
805            "Vec3" => "float3".to_string(),
806            "Vec4" => "float4".to_string(),
807            "IVec2" => "int2".to_string(),
808            "IVec3" => "int3".to_string(),
809            "IVec4" => "int4".to_string(),
810            "UVec2" => "uint2".to_string(),
811            "UVec3" => "uint3".to_string(),
812            "UVec4" => "uint4".to_string(),
813            "Mat4" => "float4x4".to_string(),
814            "Mat3" => "float3x3".to_string(),
815            "Mat2" => "float2x2".to_string(),
816            "Sampler2D" => "Texture2D".to_string(),
817            "Sampler3D" => "Texture3D".to_string(),
818            "SamplerCube" => "TextureCube".to_string(),
819            "Sampler2DArray" => "Texture2DArray".to_string(),
820            "SamplerCubeArray" => "TextureCubeArray".to_string(),
821            "Sampler2DMS" => "Texture2DMS".to_string(),
822            "RWTexture2D" => "RWTexture2D<float4>".to_string(),
823            "RWTexture3D" => "RWTexture3D<float4>".to_string(),
824            "Buffer" => "Buffer<float4>".to_string(),
825            "RWBuffer" => "RWBuffer<float4>".to_string(),
826            "StructuredBuffer" => "StructuredBuffer<float4>".to_string(),
827            "RWStructuredBuffer" => "RWStructuredBuffer<float4>".to_string(),
828            "ByteAddressBuffer" => "ByteAddressBuffer".to_string(),
829            "RWByteAddressBuffer" => "RWByteAddressBuffer".to_string(),
830            "Void" => "void".to_string(),
831            _ => "float4".to_string(),
832        },
833        Type::Array(element, _size, _span) => {
834            let elem_ty = map_type_to_hlsl(element);
835            // HLSL arrays use fixed size
836            format!("{}[{}]", elem_ty, _size)
837        },
838        _ => "float4".to_string(),
839    }
840}
841
/// Returns the HLSL type produced by a swizzle such as `x`, `xy` or `rgba`.
///
/// The result depends only on the byte length of `swizzle`: lengths 2
/// through 4 give `float2`..`float4`, and every other length (including 1,
/// 0 and anything above 4) gives `float`. The result is always a float type;
/// the swizzled value's own element type is not an input here.
fn infer_swizzle_type(swizzle: &str) -> String {
    match swizzle.len() {
        components @ 2..=4 => format!("float{}", components),
        _ => String::from("float"),
    }
}