qudit_expr/codegen/
builder.rs

1use super::{codegen::CodeGenerator, module::Module};
2use qudit_core::RealScalar;
3
4use crate::Expression;
5
6pub struct CompilableUnit<'a> {
7    pub fn_name: String,
8    pub exprs: &'a [Expression],
9    pub variables: Vec<String>,
10    pub unit_size: usize,
11}
12
13impl<'a> CompilableUnit<'a> {
14    pub fn new(
15        name: &str,
16        exprs: &'a [Expression],
17        variables: Vec<String>,
18        unit_size: usize,
19    ) -> Self {
20        CompilableUnit {
21            fn_name: name.to_string(),
22            exprs,
23            variables,
24            unit_size,
25        }
26    }
27
28    pub fn add_to_module<R: RealScalar>(&self, module: &Module<R>) {
29        // println!("Adding fn_name: {} to module.", self.fn_name);
30        // for expr in &self.exprs {
31        //     println!("{:?}", expr);
32        // }
33        let mut codegen = CodeGenerator::new(module);
34        codegen
35            .gen_func(&self.fn_name, self.exprs, &self.variables, self.unit_size)
36            .expect("Error generating function.");
37    }
38}
39
/// Integer tag selecting a differentiation level.
///
/// The named levels are [`FUNCTION`], [`GRADIENT`] and [`HESSIAN`]. They are
/// consecutive, increasing values starting at `1`, so levels can be ordered
/// and compared with the ordinary integer operators.
pub type DifferentiationLevel = usize;

/// Differentiation level `1`, named for plain function evaluation.
pub const FUNCTION: DifferentiationLevel = 1;

/// Differentiation level `2`, named for gradient evaluation.
pub const GRADIENT: DifferentiationLevel = 2;

/// Differentiation level `3`, named for Hessian evaluation.
pub const HESSIAN: DifferentiationLevel = 3;
71
72pub struct ModuleBuilder<'a, R: RealScalar> {
73    name: String,
74    exprs: Vec<CompilableUnit<'a>>,
75    _phantom_c: std::marker::PhantomData<R>,
76}
77
78impl<'a, R: RealScalar> ModuleBuilder<'a, R> {
79    pub fn new(name: &str) -> Self {
80        ModuleBuilder {
81            name: name.to_string(),
82            exprs: Vec::new(),
83            _phantom_c: std::marker::PhantomData,
84        }
85    }
86
87    pub fn add_unit(mut self, unit: CompilableUnit<'a>) -> Self {
88        self.exprs.push(unit);
89        self
90    }
91
92    pub fn build(self) -> Module<R> {
93        let module = Module::new(&self.name);
94        for expr in &self.exprs {
95            expr.add_to_module(&module);
96        }
97        module
98    }
99}