// qudit_expr/codegen/codegen.rs

1use inkwell::AddressSpace;
2use inkwell::builder::Builder;
3use inkwell::values::FloatValue;
4use inkwell::values::FunctionValue;
5
6use coe::is_same;
7use qudit_core::RealScalar;
8use std::collections::HashMap;
9
10use crate::ComplexExpression;
11use crate::Expression;
12
13use super::builtins::Builtins;
14use super::module::Module;
15
/// Error produced while lowering expressions to LLVM IR.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodeGenError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl CodeGenError {
    /// Creates an error carrying an owned copy of `message`.
    pub fn new(message: &str) -> Self {
        CodeGenError {
            message: message.to_string(),
        }
    }
}

impl std::fmt::Display for CodeGenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

// Lets callers propagate code-generation failures through `Box<dyn Error>` and similar.
impl std::error::Error for CodeGenError {}

/// Result alias used throughout code generation.
type CodeGenResult<T> = Result<T, CodeGenError>;
30
/// Lowers [`Expression`]s to LLVM IR inside a [`Module`], one generated function at a time.
#[derive(Debug)]
pub struct CodeGenerator<'ctx, R: RealScalar> {
    /// Module wrapper that supplies the LLVM context and receives generated functions.
    pub context: &'ctx Module<R>,
    /// IR builder used to emit instructions into the function being generated.
    pub builder: Builder<'ctx>,

    /// Loaded parameter values keyed by variable name (rebuilt by `build_var_table`).
    variables: HashMap<String, FloatValue<'ctx>>,
    /// Memo of already-lowered expressions, keyed by the expression's `to_string()` form.
    expressions: HashMap<String, FloatValue<'ctx>>,
    /// Memo of declared intrinsic functions, keyed by builtin name.
    functions: HashMap<String, FunctionValue<'ctx>>,
    /// Function currently being generated, if any (set by `gen_func`).
    fn_value_opt: Option<FunctionValue<'ctx>>,
    /// Index of the output-pointer parameter read by `compile_expr`.
    // NOTE(review): nothing in this file assigns this after `new` — confirm where it is set.
    output_ptr_idx: Option<u32>,
}
42
43impl<'ctx, R: RealScalar> CodeGenerator<'ctx, R> {
44    pub fn new(context: &'ctx Module<R>) -> Self {
45        let builder = context.context().create_builder();
46        CodeGenerator {
47            context,
48            builder,
49            variables: HashMap::new(),
50            functions: HashMap::new(),
51            expressions: HashMap::new(),
52            fn_value_opt: None,
53            output_ptr_idx: None,
54        }
55    }
56
57    fn int_type(&self) -> inkwell::types::IntType<'ctx> {
58        if is_same::<R, f32>() {
59            self.context.context().i32_type()
60        } else if is_same::<R, f64>() {
61            self.context.context().i64_type()
62        } else {
63            panic!("Unknown bit width");
64        }
65    }
66
67    fn float_type(&self) -> inkwell::types::FloatType<'ctx> {
68        if is_same::<R, f32>() {
69            self.context.context().f32_type()
70        } else if is_same::<R, f64>() {
71            self.context.context().f64_type()
72        } else {
73            panic!("Unknown bit width");
74        }
75    }
76
77    fn gen_write_func_proto(&self, name: &str) -> CodeGenResult<FunctionValue<'ctx>> {
78        let ret_type = self.context.context().void_type();
79        let ptr_type = self.context.context().ptr_type(AddressSpace::default());
80        // Match Rust WriteFunc<R> signature: (*const R, *mut R, *const u64, *const u64, u64, *const bool)
81        let param_types = vec![
82            ptr_type.into(),                          // *const R
83            ptr_type.into(),                          // *mut R
84            ptr_type.into(),                          // *const u64
85            ptr_type.into(),                          // *const u64
86            self.context.context().i64_type().into(), // u64
87            ptr_type.into(),                          // *const bool
88        ];
89        let func_type = ret_type.fn_type(&param_types, false);
90        let func = self
91            .context
92            .with_module(|module| module.add_function(name, func_type, None));
93        Ok(func)
94    }
95
96    fn build_expression(&mut self, expr: &Expression) -> CodeGenResult<FloatValue<'ctx>> {
97        let expr_str = expr.to_string();
98        let cached = self.expressions.get(&expr_str);
99        if let Some(c) = cached {
100            return Ok(*c);
101        }
102
103        let val = match expr {
104            Expression::Pi => Ok(self.float_type().const_float(std::f64::consts::PI)),
105            Expression::Constant(_) => Ok(self.float_type().const_float(expr.to_float())),
106            Expression::Variable(name) => self
107                .variables
108                .get(name)
109                .ok_or(CodeGenError::new(&format!("Variable {} not found", name)))
110                .copied(),
111            Expression::Neg(expr) => {
112                let val = self.build_expression(expr)?;
113                Ok(self.builder.build_float_neg(val, "tmp").unwrap())
114            }
115            Expression::Add(lhs, rhs) => {
116                let lhs_val = self.build_expression(lhs)?;
117                let rhs_val = self.build_expression(rhs)?;
118                Ok(self
119                    .builder
120                    .build_float_add(lhs_val, rhs_val, "tmp")
121                    .unwrap())
122            }
123            Expression::Sub(lhs, rhs) => {
124                let lhs_val = self.build_expression(lhs)?;
125                let rhs_val = self.build_expression(rhs)?;
126                Ok(self
127                    .builder
128                    .build_float_sub(lhs_val, rhs_val, "tmp")
129                    .unwrap())
130            }
131            Expression::Mul(lhs, rhs) => {
132                let lhs_val = self.build_expression(lhs)?;
133                let rhs_val = self.build_expression(rhs)?;
134                Ok(self
135                    .builder
136                    .build_float_mul(lhs_val, rhs_val, "tmp")
137                    .unwrap())
138            }
139            Expression::Div(lhs, rhs) => {
140                let lhs_val = self.build_expression(lhs)?;
141                let rhs_val = self.build_expression(rhs)?;
142                Ok(self
143                    .builder
144                    .build_float_div(lhs_val, rhs_val, "tmp")
145                    .unwrap())
146            }
147            Expression::Pow(base, exponent) => {
148                let base_val = self.build_expression(base)?;
149                let exponent_val = self.build_expression(exponent)?;
150                let pow = self.get_builtin("pow");
151                let args = [base_val.into(), exponent_val.into()];
152                let val = self
153                    .builder
154                    .build_call(pow, &args, "tmp")
155                    .unwrap()
156                    .try_as_basic_value()
157                    .left()
158                    .unwrap()
159                    .into_float_value();
160                Ok(val)
161            }
162            Expression::Sqrt(expr) => {
163                let arg = self.build_expression(expr)?;
164                let sqrt = self.get_builtin("sqrt");
165                let val = self
166                    .builder
167                    .build_call(sqrt, &[arg.into()], "tmp")
168                    .unwrap()
169                    .try_as_basic_value()
170                    .left()
171                    .unwrap()
172                    .into_float_value();
173                Ok(val)
174            }
175            Expression::Sin(expr) => {
176                let arg = self.build_expression(expr)?;
177                let sin = self.get_builtin("sin");
178                let val = self
179                    .builder
180                    .build_call(sin, &[arg.into()], "tmp")
181                    .unwrap()
182                    .try_as_basic_value()
183                    .left()
184                    .unwrap()
185                    .into_float_value();
186                Ok(val)
187            }
188            Expression::Cos(expr) => {
189                let arg = self.build_expression(expr)?;
190                let cos = self.get_builtin("cos");
191                let val = self
192                    .builder
193                    .build_call(cos, &[arg.into()], "tmp")
194                    .unwrap()
195                    .try_as_basic_value()
196                    .left()
197                    .unwrap()
198                    .into_float_value();
199                Ok(val)
200            }
201        };
202
203        if let Ok(val) = val {
204            self.expressions.insert(expr_str, val);
205        }
206        val
207    }
208
209    pub fn compile_expr(
210        &mut self,
211        expr: &ComplexExpression,
212        re_offset: usize,
213        need_to_write_real_zero: bool,
214    ) -> CodeGenResult<()> {
215        let re_offset: u64 = re_offset as u64;
216        let ptr_idx = self
217            .output_ptr_idx
218            .to_owned()
219            .expect("Output pointer index not set");
220        let ptr = self
221            .fn_value_opt
222            .unwrap()
223            .get_nth_param(ptr_idx)
224            .unwrap()
225            .into_pointer_value();
226
227        if need_to_write_real_zero || !expr.real.is_zero_fast() {
228            let val = self.build_expression(&expr.real)?;
229            let offset = self.int_type().const_int(re_offset, false);
230            let offset_ptr = unsafe {
231                self.builder
232                    .build_gep(self.float_type(), ptr, &[offset], "offset_ptr")
233                    .unwrap()
234            };
235
236            match self.builder.build_store(offset_ptr, val) {
237                Ok(_) => {}
238                Err(e) => {
239                    return Err(CodeGenError::new(&format!("Error storing value: {}", e)));
240                }
241            };
242        }
243
244        if !expr.imag.is_zero_fast() {
245            let val = self.build_expression(&expr.imag)?;
246            let offset = self.int_type().const_int(re_offset + 1, false);
247            let offset_ptr = unsafe {
248                self.builder
249                    .build_gep(self.float_type(), ptr, &[offset], "offset_ptr")
250                    .unwrap()
251            };
252
253            match self.builder.build_store(offset_ptr, val) {
254                Ok(_) => {}
255                Err(e) => {
256                    return Err(CodeGenError::new(&format!("Error storing value: {}", e)));
257                }
258            };
259        }
260
261        Ok(())
262    }
263
264    fn build_var_table(&mut self, variables: &[String]) {
265        self.variables.clear();
266        let params_ptr = self
267            .fn_value_opt
268            .unwrap()
269            .get_nth_param(0)
270            .unwrap()
271            .into_pointer_value();
272        let param_offset_ptr = self
273            .fn_value_opt
274            .unwrap()
275            .get_nth_param(2)
276            .unwrap()
277            .into_pointer_value();
278
279        for (map_idx, var_name) in variables.iter().enumerate() {
280            // Load the actual offset from param_offset_ptr using map_idx
281            // param_offset_ptr points to u64 values according to Rust signature
282            let map_idx_val = self
283                .context
284                .context()
285                .i64_type()
286                .const_int(map_idx as u64, false);
287            let actual_offset_ptr = unsafe {
288                self.builder
289                    .build_gep(
290                        self.context.context().i64_type(),
291                        param_offset_ptr,
292                        &[map_idx_val],
293                        "actual_offset_ptr_gep",
294                    )
295                    .unwrap()
296            };
297            let actual_offset_val = self
298                .builder
299                .build_load(
300                    self.context.context().i64_type(),
301                    actual_offset_ptr,
302                    "actual_offset_val",
303                )
304                .unwrap()
305                .into_int_value();
306
307            // Use the actual_offset_val to index into params_ptr
308            let var_ptr = unsafe {
309                self.builder
310                    .build_gep(
311                        self.float_type(),
312                        params_ptr,
313                        &[actual_offset_val],
314                        "var_ptr_gep",
315                    )
316                    .unwrap()
317            };
318
319            let val = self
320                .builder
321                .build_load(self.float_type(), var_ptr, var_name)
322                .unwrap()
323                .into_float_value();
324            self.variables.insert(var_name.to_owned(), val);
325        }
326    }
327
328    fn get_builtin(&mut self, name: &str) -> FunctionValue<'ctx> {
329        if let Some(f) = self.functions.get(name) {
330            return *f;
331        }
332
333        let b = match Builtins::from_str(name) {
334            Some(b) => b,
335            None => {
336                panic!("Unsupported builtin function: {}", name);
337            }
338        };
339
340        let intr = match b.intrinsic() {
341            Some(i) => i,
342            None => {
343                panic!("Unsupported builtin function: {}", name);
344            }
345        };
346
347        let decl = self
348            .context
349            .with_module(|module| intr.get_declaration(&module, &[self.float_type().into()]));
350
351        let fn_value = match decl {
352            Some(f) => f,
353            None => {
354                panic!("Unsupported builtin function: {}", name);
355            }
356        };
357
358        self.functions.insert(name.to_string(), fn_value);
359        fn_value
360    }
361
362    // fn get_expression(&mut self, name: &str) -> Option<FloatValue<'ctx>> {
363    //     if let Some(c) = self.expressions.get(name) {
364    //         return Some(c.clone());
365    //     }
366
367    //     let c = match name {
368    //         "pi" => Some(self.float_type().const_float(std::f64::consts::PI)),
369    //         "π" => Some(self.float_type().const_float(std::f64::consts::PI)),
370    //         "e" => Some(self.float_type().const_float(std::f64::consts::E)),
371    //         _ => None
372    //     };
373
374    //     if let Some(c) = c {
375    //         self.expressions.insert(name.to_string(), c);
376    //         return Some(c);
377    //     }
378
379    //     None
380    // }
381
382    pub fn gen_func(
383        &mut self,
384        fn_name: &str,
385        fn_expr: &[Expression],
386        var_table: &[String],
387        fn_len: usize,
388    ) -> CodeGenResult<()> {
389        self.expressions.clear();
390        let func = self.gen_write_func_proto(fn_name)?;
391        let entry = self.context.context().append_basic_block(func, "entry");
392        self.builder.position_at_end(entry);
393        self.fn_value_opt = Some(func);
394        // println!("name: {:?}, var_table: {:?}", fn_name, var_table);
395        self.build_var_table(var_table);
396
397        let output_ptr = self
398            .fn_value_opt
399            .unwrap()
400            .get_nth_param(1)
401            .unwrap()
402            .into_pointer_value();
403        let output_map_ptr = self
404            .fn_value_opt
405            .unwrap()
406            .get_nth_param(3)
407            .unwrap()
408            .into_pointer_value();
409        let fn_unit_offset = self
410            .fn_value_opt
411            .unwrap()
412            .get_nth_param(4)
413            .unwrap()
414            .into_int_value();
415        let const_param_ptr = self
416            .fn_value_opt
417            .unwrap()
418            .get_nth_param(5)
419            .unwrap()
420            .into_pointer_value();
421        let int_increment = self.context.context().i64_type().const_int(1u64, false);
422        let mut dyn_fn_unit_idx = self.context.context().i64_type().const_int(0u64, false);
423
424        for (fn_unit_idx, fn_unit_exprs) in fn_expr.chunks(fn_len).enumerate() {
425            let current_fn = self.fn_value_opt.unwrap();
426            let compute_block = self
427                .context
428                .context()
429                .append_basic_block(current_fn, &format!("compute_unit_{}", fn_unit_idx));
430            let skip_block = self
431                .context
432                .context()
433                .append_basic_block(current_fn, &format!("skip_unit_{}", fn_unit_idx));
434
435            if fn_unit_idx == 0 {
436                // First function unit is the function, always compute.
437                self.builder
438                    .build_unconditional_branch(compute_block)
439                    .unwrap();
440            } else if fn_unit_idx <= var_table.len() {
441                // Next var_table.len() function units are the partials in the gradient.
442                // Check if the corresponding parameter is constant.
443                let param_idx_for_const_check = self
444                    .context
445                    .context()
446                    .i64_type()
447                    .const_int((fn_unit_idx - 1) as u64, false);
448                let const_param_elem_ptr = unsafe {
449                    self.builder
450                        .build_gep(
451                            self.context.context().i8_type(),
452                            const_param_ptr,
453                            &[param_idx_for_const_check],
454                            "const_param_elem_ptr",
455                        )
456                        .unwrap()
457                };
458                let is_constant_val = self
459                    .builder
460                    .build_load(
461                        self.context.context().i8_type(),
462                        const_param_elem_ptr,
463                        "is_constant",
464                    )
465                    .unwrap()
466                    .into_int_value();
467
468                let is_constant_true = self
469                    .builder
470                    .build_int_compare(
471                        inkwell::IntPredicate::EQ,
472                        is_constant_val,
473                        self.context.context().i8_type().const_int(1, false),
474                        "is_constant_true",
475                    )
476                    .unwrap();
477
478                // If is_constant_true, branch to skip_block. Else, branch to compute_block.
479                self.builder
480                    .build_conditional_branch(is_constant_true, skip_block, compute_block)
481                    .unwrap();
482            } else {
483                // Last are the hessian units
484
485                let index = fn_unit_idx - var_table.len() - 1;
486                let param_j = ((((8 * index + 1) as f64).sqrt().floor() as usize) - 1) / 2;
487                let param_i = index - param_j * (param_j + 1) / 2;
488
489                // Check if the corresponding parameter is constant.
490                let param_idx_i_for_const_check = self
491                    .context
492                    .context()
493                    .i64_type()
494                    .const_int((param_i - 1) as u64, false);
495                let const_param_elem_ptr = unsafe {
496                    self.builder
497                        .build_gep(
498                            self.context.context().i8_type(),
499                            const_param_ptr,
500                            &[param_idx_i_for_const_check],
501                            "const_param_i_elem_ptr",
502                        )
503                        .unwrap()
504                };
505                let is_i_constant_val = self
506                    .builder
507                    .build_load(
508                        self.context.context().i8_type(),
509                        const_param_elem_ptr,
510                        "is_i_constant",
511                    )
512                    .unwrap()
513                    .into_int_value();
514
515                let is_i_constant_true = self
516                    .builder
517                    .build_int_compare(
518                        inkwell::IntPredicate::EQ,
519                        is_i_constant_val,
520                        self.context.context().i8_type().const_int(1, false),
521                        "is_i_constant_true",
522                    )
523                    .unwrap();
524
525                // Check if the corresponding parameter is constant.
526                let param_idx_j_for_const_check = self
527                    .context
528                    .context()
529                    .i64_type()
530                    .const_int((param_j - 1) as u64, false);
531                let const_param_elem_ptr = unsafe {
532                    self.builder
533                        .build_gep(
534                            self.context.context().i8_type(),
535                            const_param_ptr,
536                            &[param_idx_j_for_const_check],
537                            "const_param_j_elem_ptr",
538                        )
539                        .unwrap()
540                };
541                let is_j_constant_val = self
542                    .builder
543                    .build_load(
544                        self.context.context().i8_type(),
545                        const_param_elem_ptr,
546                        "is_j_constant",
547                    )
548                    .unwrap()
549                    .into_int_value();
550
551                let is_j_constant_true = self
552                    .builder
553                    .build_int_compare(
554                        inkwell::IntPredicate::EQ,
555                        is_j_constant_val,
556                        self.context.context().i8_type().const_int(1, false),
557                        "is_j_constant_true",
558                    )
559                    .unwrap();
560
561                let is_either_constant = self
562                    .builder
563                    .build_or(is_i_constant_true, is_j_constant_true, "is_either_constant")
564                    .unwrap();
565
566                // If is_constant_true, branch to skip_block. Else, branch to compute_block.
567                self.builder
568                    .build_conditional_branch(is_either_constant, skip_block, compute_block)
569                    .unwrap();
570            }
571
572            // Position builder at the compute block to emit the actual computation logic
573            self.builder.position_at_end(compute_block);
574
575            let this_unit_offset = self
576                .builder
577                .build_int_mul(fn_unit_offset, dyn_fn_unit_idx, "this_unit_offset")
578                .unwrap();
579            dyn_fn_unit_idx = self
580                .builder
581                .build_int_add(dyn_fn_unit_idx, int_increment, "dyn_fn_unit_idx")
582                .unwrap();
583
584            for (i, expr) in fn_unit_exprs.iter().enumerate() {
585                if expr.is_zero_fast() {
586                    continue;
587                }
588
589                let val = self.build_expression(expr)?;
590                let offset_ptr = unsafe {
591                    let output_idx = self.context.context().i64_type().const_int(i as u64, false);
592                    let output_map_elem_ptr = self
593                        .builder
594                        .build_gep(
595                            self.context.context().i64_type(),
596                            output_map_ptr,
597                            &[output_idx],
598                            "output_map_elem_ptr",
599                        )
600                        .unwrap();
601                    let output_map_offset = self
602                        .builder
603                        .build_load(
604                            self.context.context().i64_type(),
605                            output_map_elem_ptr,
606                            "output_map_offset",
607                        )
608                        .unwrap()
609                        .into_int_value();
610                    let combined_offset = self
611                        .builder
612                        .build_int_add(output_map_offset, this_unit_offset, "combined_offset")
613                        .unwrap();
614                    self.builder
615                        .build_gep(
616                            self.float_type(),
617                            output_ptr,
618                            &[combined_offset],
619                            "offset_ptr",
620                        )
621                        .unwrap()
622                };
623
624                match self.builder.build_store(offset_ptr, val) {
625                    Ok(_) => {}
626                    Err(e) => {
627                        return Err(CodeGenError::new(&format!("Error storing value: {}", e)));
628                    }
629                };
630            }
631
632            // After computation (if compute_block was taken), branch to the skip_block
633            // to ensure control flow continues to the next iteration of the loop.
634            self.builder.build_unconditional_branch(skip_block).unwrap();
635
636            // Position builder at the skip block for the next iteration.
637            self.builder.position_at_end(skip_block);
638        }
639
640        match self.builder.build_return(None) {
641            Ok(_) => Ok(()),
642            Err(e) => Err(CodeGenError::new(&format!("Error building return: {}", e))),
643        }
644    }
645}