1use inkwell::AddressSpace;
2use inkwell::builder::Builder;
3use inkwell::values::FloatValue;
4use inkwell::values::FunctionValue;
5
6use coe::is_same;
7use qudit_core::RealScalar;
8use std::collections::HashMap;
9
10use crate::ComplexExpression;
11use crate::Expression;
12
13use super::builtins::Builtins;
14use super::module::Module;
15
/// Error produced while lowering expressions to LLVM IR.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// printed with `{}` and boxed as `dyn Error` by callers.
#[derive(Debug)]
pub struct CodeGenError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl CodeGenError {
    /// Builds an error that owns a copy of `message`.
    pub fn new(message: &str) -> Self {
        Self {
            message: message.to_owned(),
        }
    }
}

impl std::fmt::Display for CodeGenError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for CodeGenError {}

/// Result alias used by the fallible code-generation routines in this file.
type CodeGenResult<T> = Result<T, CodeGenError>;
30
31#[derive(Debug)]
32pub struct CodeGenerator<'ctx, R: RealScalar> {
33 pub context: &'ctx Module<R>,
34 pub builder: Builder<'ctx>,
35
36 variables: HashMap<String, FloatValue<'ctx>>,
37 expressions: HashMap<String, FloatValue<'ctx>>,
38 functions: HashMap<String, FunctionValue<'ctx>>,
39 fn_value_opt: Option<FunctionValue<'ctx>>,
40 output_ptr_idx: Option<u32>,
41}
42
43impl<'ctx, R: RealScalar> CodeGenerator<'ctx, R> {
44 pub fn new(context: &'ctx Module<R>) -> Self {
45 let builder = context.context().create_builder();
46 CodeGenerator {
47 context,
48 builder,
49 variables: HashMap::new(),
50 functions: HashMap::new(),
51 expressions: HashMap::new(),
52 fn_value_opt: None,
53 output_ptr_idx: None,
54 }
55 }
56
57 fn int_type(&self) -> inkwell::types::IntType<'ctx> {
58 if is_same::<R, f32>() {
59 self.context.context().i32_type()
60 } else if is_same::<R, f64>() {
61 self.context.context().i64_type()
62 } else {
63 panic!("Unknown bit width");
64 }
65 }
66
67 fn float_type(&self) -> inkwell::types::FloatType<'ctx> {
68 if is_same::<R, f32>() {
69 self.context.context().f32_type()
70 } else if is_same::<R, f64>() {
71 self.context.context().f64_type()
72 } else {
73 panic!("Unknown bit width");
74 }
75 }
76
77 fn gen_write_func_proto(&self, name: &str) -> CodeGenResult<FunctionValue<'ctx>> {
78 let ret_type = self.context.context().void_type();
79 let ptr_type = self.context.context().ptr_type(AddressSpace::default());
80 let param_types = vec![
82 ptr_type.into(), ptr_type.into(), ptr_type.into(), ptr_type.into(), self.context.context().i64_type().into(), ptr_type.into(), ];
89 let func_type = ret_type.fn_type(¶m_types, false);
90 let func = self
91 .context
92 .with_module(|module| module.add_function(name, func_type, None));
93 Ok(func)
94 }
95
96 fn build_expression(&mut self, expr: &Expression) -> CodeGenResult<FloatValue<'ctx>> {
97 let expr_str = expr.to_string();
98 let cached = self.expressions.get(&expr_str);
99 if let Some(c) = cached {
100 return Ok(*c);
101 }
102
103 let val = match expr {
104 Expression::Pi => Ok(self.float_type().const_float(std::f64::consts::PI)),
105 Expression::Constant(_) => Ok(self.float_type().const_float(expr.to_float())),
106 Expression::Variable(name) => self
107 .variables
108 .get(name)
109 .ok_or(CodeGenError::new(&format!("Variable {} not found", name)))
110 .copied(),
111 Expression::Neg(expr) => {
112 let val = self.build_expression(expr)?;
113 Ok(self.builder.build_float_neg(val, "tmp").unwrap())
114 }
115 Expression::Add(lhs, rhs) => {
116 let lhs_val = self.build_expression(lhs)?;
117 let rhs_val = self.build_expression(rhs)?;
118 Ok(self
119 .builder
120 .build_float_add(lhs_val, rhs_val, "tmp")
121 .unwrap())
122 }
123 Expression::Sub(lhs, rhs) => {
124 let lhs_val = self.build_expression(lhs)?;
125 let rhs_val = self.build_expression(rhs)?;
126 Ok(self
127 .builder
128 .build_float_sub(lhs_val, rhs_val, "tmp")
129 .unwrap())
130 }
131 Expression::Mul(lhs, rhs) => {
132 let lhs_val = self.build_expression(lhs)?;
133 let rhs_val = self.build_expression(rhs)?;
134 Ok(self
135 .builder
136 .build_float_mul(lhs_val, rhs_val, "tmp")
137 .unwrap())
138 }
139 Expression::Div(lhs, rhs) => {
140 let lhs_val = self.build_expression(lhs)?;
141 let rhs_val = self.build_expression(rhs)?;
142 Ok(self
143 .builder
144 .build_float_div(lhs_val, rhs_val, "tmp")
145 .unwrap())
146 }
147 Expression::Pow(base, exponent) => {
148 let base_val = self.build_expression(base)?;
149 let exponent_val = self.build_expression(exponent)?;
150 let pow = self.get_builtin("pow");
151 let args = [base_val.into(), exponent_val.into()];
152 let val = self
153 .builder
154 .build_call(pow, &args, "tmp")
155 .unwrap()
156 .try_as_basic_value()
157 .left()
158 .unwrap()
159 .into_float_value();
160 Ok(val)
161 }
162 Expression::Sqrt(expr) => {
163 let arg = self.build_expression(expr)?;
164 let sqrt = self.get_builtin("sqrt");
165 let val = self
166 .builder
167 .build_call(sqrt, &[arg.into()], "tmp")
168 .unwrap()
169 .try_as_basic_value()
170 .left()
171 .unwrap()
172 .into_float_value();
173 Ok(val)
174 }
175 Expression::Sin(expr) => {
176 let arg = self.build_expression(expr)?;
177 let sin = self.get_builtin("sin");
178 let val = self
179 .builder
180 .build_call(sin, &[arg.into()], "tmp")
181 .unwrap()
182 .try_as_basic_value()
183 .left()
184 .unwrap()
185 .into_float_value();
186 Ok(val)
187 }
188 Expression::Cos(expr) => {
189 let arg = self.build_expression(expr)?;
190 let cos = self.get_builtin("cos");
191 let val = self
192 .builder
193 .build_call(cos, &[arg.into()], "tmp")
194 .unwrap()
195 .try_as_basic_value()
196 .left()
197 .unwrap()
198 .into_float_value();
199 Ok(val)
200 }
201 };
202
203 if let Ok(val) = val {
204 self.expressions.insert(expr_str, val);
205 }
206 val
207 }
208
209 pub fn compile_expr(
210 &mut self,
211 expr: &ComplexExpression,
212 re_offset: usize,
213 need_to_write_real_zero: bool,
214 ) -> CodeGenResult<()> {
215 let re_offset: u64 = re_offset as u64;
216 let ptr_idx = self
217 .output_ptr_idx
218 .to_owned()
219 .expect("Output pointer index not set");
220 let ptr = self
221 .fn_value_opt
222 .unwrap()
223 .get_nth_param(ptr_idx)
224 .unwrap()
225 .into_pointer_value();
226
227 if need_to_write_real_zero || !expr.real.is_zero_fast() {
228 let val = self.build_expression(&expr.real)?;
229 let offset = self.int_type().const_int(re_offset, false);
230 let offset_ptr = unsafe {
231 self.builder
232 .build_gep(self.float_type(), ptr, &[offset], "offset_ptr")
233 .unwrap()
234 };
235
236 match self.builder.build_store(offset_ptr, val) {
237 Ok(_) => {}
238 Err(e) => {
239 return Err(CodeGenError::new(&format!("Error storing value: {}", e)));
240 }
241 };
242 }
243
244 if !expr.imag.is_zero_fast() {
245 let val = self.build_expression(&expr.imag)?;
246 let offset = self.int_type().const_int(re_offset + 1, false);
247 let offset_ptr = unsafe {
248 self.builder
249 .build_gep(self.float_type(), ptr, &[offset], "offset_ptr")
250 .unwrap()
251 };
252
253 match self.builder.build_store(offset_ptr, val) {
254 Ok(_) => {}
255 Err(e) => {
256 return Err(CodeGenError::new(&format!("Error storing value: {}", e)));
257 }
258 };
259 }
260
261 Ok(())
262 }
263
264 fn build_var_table(&mut self, variables: &[String]) {
265 self.variables.clear();
266 let params_ptr = self
267 .fn_value_opt
268 .unwrap()
269 .get_nth_param(0)
270 .unwrap()
271 .into_pointer_value();
272 let param_offset_ptr = self
273 .fn_value_opt
274 .unwrap()
275 .get_nth_param(2)
276 .unwrap()
277 .into_pointer_value();
278
279 for (map_idx, var_name) in variables.iter().enumerate() {
280 let map_idx_val = self
283 .context
284 .context()
285 .i64_type()
286 .const_int(map_idx as u64, false);
287 let actual_offset_ptr = unsafe {
288 self.builder
289 .build_gep(
290 self.context.context().i64_type(),
291 param_offset_ptr,
292 &[map_idx_val],
293 "actual_offset_ptr_gep",
294 )
295 .unwrap()
296 };
297 let actual_offset_val = self
298 .builder
299 .build_load(
300 self.context.context().i64_type(),
301 actual_offset_ptr,
302 "actual_offset_val",
303 )
304 .unwrap()
305 .into_int_value();
306
307 let var_ptr = unsafe {
309 self.builder
310 .build_gep(
311 self.float_type(),
312 params_ptr,
313 &[actual_offset_val],
314 "var_ptr_gep",
315 )
316 .unwrap()
317 };
318
319 let val = self
320 .builder
321 .build_load(self.float_type(), var_ptr, var_name)
322 .unwrap()
323 .into_float_value();
324 self.variables.insert(var_name.to_owned(), val);
325 }
326 }
327
328 fn get_builtin(&mut self, name: &str) -> FunctionValue<'ctx> {
329 if let Some(f) = self.functions.get(name) {
330 return *f;
331 }
332
333 let b = match Builtins::from_str(name) {
334 Some(b) => b,
335 None => {
336 panic!("Unsupported builtin function: {}", name);
337 }
338 };
339
340 let intr = match b.intrinsic() {
341 Some(i) => i,
342 None => {
343 panic!("Unsupported builtin function: {}", name);
344 }
345 };
346
347 let decl = self
348 .context
349 .with_module(|module| intr.get_declaration(&module, &[self.float_type().into()]));
350
351 let fn_value = match decl {
352 Some(f) => f,
353 None => {
354 panic!("Unsupported builtin function: {}", name);
355 }
356 };
357
358 self.functions.insert(name.to_string(), fn_value);
359 fn_value
360 }
361
362 pub fn gen_func(
383 &mut self,
384 fn_name: &str,
385 fn_expr: &[Expression],
386 var_table: &[String],
387 fn_len: usize,
388 ) -> CodeGenResult<()> {
389 self.expressions.clear();
390 let func = self.gen_write_func_proto(fn_name)?;
391 let entry = self.context.context().append_basic_block(func, "entry");
392 self.builder.position_at_end(entry);
393 self.fn_value_opt = Some(func);
394 self.build_var_table(var_table);
396
397 let output_ptr = self
398 .fn_value_opt
399 .unwrap()
400 .get_nth_param(1)
401 .unwrap()
402 .into_pointer_value();
403 let output_map_ptr = self
404 .fn_value_opt
405 .unwrap()
406 .get_nth_param(3)
407 .unwrap()
408 .into_pointer_value();
409 let fn_unit_offset = self
410 .fn_value_opt
411 .unwrap()
412 .get_nth_param(4)
413 .unwrap()
414 .into_int_value();
415 let const_param_ptr = self
416 .fn_value_opt
417 .unwrap()
418 .get_nth_param(5)
419 .unwrap()
420 .into_pointer_value();
421 let int_increment = self.context.context().i64_type().const_int(1u64, false);
422 let mut dyn_fn_unit_idx = self.context.context().i64_type().const_int(0u64, false);
423
424 for (fn_unit_idx, fn_unit_exprs) in fn_expr.chunks(fn_len).enumerate() {
425 let current_fn = self.fn_value_opt.unwrap();
426 let compute_block = self
427 .context
428 .context()
429 .append_basic_block(current_fn, &format!("compute_unit_{}", fn_unit_idx));
430 let skip_block = self
431 .context
432 .context()
433 .append_basic_block(current_fn, &format!("skip_unit_{}", fn_unit_idx));
434
435 if fn_unit_idx == 0 {
436 self.builder
438 .build_unconditional_branch(compute_block)
439 .unwrap();
440 } else if fn_unit_idx <= var_table.len() {
441 let param_idx_for_const_check = self
444 .context
445 .context()
446 .i64_type()
447 .const_int((fn_unit_idx - 1) as u64, false);
448 let const_param_elem_ptr = unsafe {
449 self.builder
450 .build_gep(
451 self.context.context().i8_type(),
452 const_param_ptr,
453 &[param_idx_for_const_check],
454 "const_param_elem_ptr",
455 )
456 .unwrap()
457 };
458 let is_constant_val = self
459 .builder
460 .build_load(
461 self.context.context().i8_type(),
462 const_param_elem_ptr,
463 "is_constant",
464 )
465 .unwrap()
466 .into_int_value();
467
468 let is_constant_true = self
469 .builder
470 .build_int_compare(
471 inkwell::IntPredicate::EQ,
472 is_constant_val,
473 self.context.context().i8_type().const_int(1, false),
474 "is_constant_true",
475 )
476 .unwrap();
477
478 self.builder
480 .build_conditional_branch(is_constant_true, skip_block, compute_block)
481 .unwrap();
482 } else {
483 let index = fn_unit_idx - var_table.len() - 1;
486 let param_j = ((((8 * index + 1) as f64).sqrt().floor() as usize) - 1) / 2;
487 let param_i = index - param_j * (param_j + 1) / 2;
488
489 let param_idx_i_for_const_check = self
491 .context
492 .context()
493 .i64_type()
494 .const_int((param_i - 1) as u64, false);
495 let const_param_elem_ptr = unsafe {
496 self.builder
497 .build_gep(
498 self.context.context().i8_type(),
499 const_param_ptr,
500 &[param_idx_i_for_const_check],
501 "const_param_i_elem_ptr",
502 )
503 .unwrap()
504 };
505 let is_i_constant_val = self
506 .builder
507 .build_load(
508 self.context.context().i8_type(),
509 const_param_elem_ptr,
510 "is_i_constant",
511 )
512 .unwrap()
513 .into_int_value();
514
515 let is_i_constant_true = self
516 .builder
517 .build_int_compare(
518 inkwell::IntPredicate::EQ,
519 is_i_constant_val,
520 self.context.context().i8_type().const_int(1, false),
521 "is_i_constant_true",
522 )
523 .unwrap();
524
525 let param_idx_j_for_const_check = self
527 .context
528 .context()
529 .i64_type()
530 .const_int((param_j - 1) as u64, false);
531 let const_param_elem_ptr = unsafe {
532 self.builder
533 .build_gep(
534 self.context.context().i8_type(),
535 const_param_ptr,
536 &[param_idx_j_for_const_check],
537 "const_param_j_elem_ptr",
538 )
539 .unwrap()
540 };
541 let is_j_constant_val = self
542 .builder
543 .build_load(
544 self.context.context().i8_type(),
545 const_param_elem_ptr,
546 "is_j_constant",
547 )
548 .unwrap()
549 .into_int_value();
550
551 let is_j_constant_true = self
552 .builder
553 .build_int_compare(
554 inkwell::IntPredicate::EQ,
555 is_j_constant_val,
556 self.context.context().i8_type().const_int(1, false),
557 "is_j_constant_true",
558 )
559 .unwrap();
560
561 let is_either_constant = self
562 .builder
563 .build_or(is_i_constant_true, is_j_constant_true, "is_either_constant")
564 .unwrap();
565
566 self.builder
568 .build_conditional_branch(is_either_constant, skip_block, compute_block)
569 .unwrap();
570 }
571
572 self.builder.position_at_end(compute_block);
574
575 let this_unit_offset = self
576 .builder
577 .build_int_mul(fn_unit_offset, dyn_fn_unit_idx, "this_unit_offset")
578 .unwrap();
579 dyn_fn_unit_idx = self
580 .builder
581 .build_int_add(dyn_fn_unit_idx, int_increment, "dyn_fn_unit_idx")
582 .unwrap();
583
584 for (i, expr) in fn_unit_exprs.iter().enumerate() {
585 if expr.is_zero_fast() {
586 continue;
587 }
588
589 let val = self.build_expression(expr)?;
590 let offset_ptr = unsafe {
591 let output_idx = self.context.context().i64_type().const_int(i as u64, false);
592 let output_map_elem_ptr = self
593 .builder
594 .build_gep(
595 self.context.context().i64_type(),
596 output_map_ptr,
597 &[output_idx],
598 "output_map_elem_ptr",
599 )
600 .unwrap();
601 let output_map_offset = self
602 .builder
603 .build_load(
604 self.context.context().i64_type(),
605 output_map_elem_ptr,
606 "output_map_offset",
607 )
608 .unwrap()
609 .into_int_value();
610 let combined_offset = self
611 .builder
612 .build_int_add(output_map_offset, this_unit_offset, "combined_offset")
613 .unwrap();
614 self.builder
615 .build_gep(
616 self.float_type(),
617 output_ptr,
618 &[combined_offset],
619 "offset_ptr",
620 )
621 .unwrap()
622 };
623
624 match self.builder.build_store(offset_ptr, val) {
625 Ok(_) => {}
626 Err(e) => {
627 return Err(CodeGenError::new(&format!("Error storing value: {}", e)));
628 }
629 };
630 }
631
632 self.builder.build_unconditional_branch(skip_block).unwrap();
635
636 self.builder.position_at_end(skip_block);
638 }
639
640 match self.builder.build_return(None) {
641 Ok(_) => Ok(()),
642 Err(e) => Err(CodeGenError::new(&format!("Error building return: {}", e))),
643 }
644 }
645}