// qudit_expr/codegen/builder.rs
use super::{codegen::CodeGenerator, module::Module};
2use qudit_core::RealScalar;
3
4use crate::Expression;
5
6pub struct CompilableUnit<'a> {
7 pub fn_name: String,
8 pub exprs: &'a [Expression],
9 pub variables: Vec<String>,
10 pub unit_size: usize,
11}
12
13impl<'a> CompilableUnit<'a> {
14 pub fn new(
15 name: &str,
16 exprs: &'a [Expression],
17 variables: Vec<String>,
18 unit_size: usize,
19 ) -> Self {
20 CompilableUnit {
21 fn_name: name.to_string(),
22 exprs,
23 variables,
24 unit_size,
25 }
26 }
27
28 pub fn add_to_module<R: RealScalar>(&self, module: &Module<R>) {
29 let mut codegen = CodeGenerator::new(module);
34 codegen
35 .gen_func(&self.fn_name, self.exprs, &self.variables, self.unit_size)
36 .expect("Error generating function.");
37 }
38}
39
/// Integer tag identifying a differentiation level.
// NOTE(review): nothing in this file consumes these values; judging by the
// names they presumably select how far derivatives are generated (function
// value, gradient, Hessian) -- confirm against the call sites.
pub type DifferentiationLevel = usize;

/// Level `1`: by name, the function itself.
pub const FUNCTION: DifferentiationLevel = 1;

/// Level `2`: by name, the gradient.
pub const GRADIENT: DifferentiationLevel = 2;

/// Level `3`: by name, the Hessian.
pub const HESSIAN: DifferentiationLevel = 3;
44
45pub struct ModuleBuilder<'a, R: RealScalar> {
73 name: String,
74 exprs: Vec<CompilableUnit<'a>>,
75 _phantom_c: std::marker::PhantomData<R>,
76}
77
78impl<'a, R: RealScalar> ModuleBuilder<'a, R> {
79 pub fn new(name: &str) -> Self {
80 ModuleBuilder {
81 name: name.to_string(),
82 exprs: Vec::new(),
83 _phantom_c: std::marker::PhantomData,
84 }
85 }
86
87 pub fn add_unit(mut self, unit: CompilableUnit<'a>) -> Self {
88 self.exprs.push(unit);
89 self
90 }
91
92 pub fn build(self) -> Module<R> {
93 let module = Module::new(&self.name);
94 for expr in &self.exprs {
95 expr.add_to_module(&module);
96 }
97 module
98 }
99}