qudit_expr/codegen/
module.rs

1use inkwell::context::Context;
2use qudit_core::RealScalar;
3
4use std::borrow::Cow;
5use std::ffi::{CStr, CString};
6use std::mem::{ManuallyDrop, MaybeUninit};
7use std::sync::Mutex;
8
9use inkwell::targets::{InitializationConfig, Target};
10use llvm_sys::core::{LLVMContextCreate, LLVMModuleCreateWithNameInContext};
11use llvm_sys::execution_engine::{
12    LLVMCreateJITCompilerForModule, LLVMDisposeExecutionEngine, LLVMExecutionEngineRef,
13    LLVMGetFunctionAddress,
14};
15use llvm_sys::prelude::LLVMModuleRef;
16
17use inkwell::module::Module as InkwellModule;
18
19use crate::WriteFunc;
20
21use super::process_name_for_gen;
22
/// Converts a Rust string slice into a C string suitable for passing to LLVM.
///
/// - If `s` contains a NUL byte, the result borrows `s` up to and including the
///   first NUL; anything after it is ignored, exactly as a C consumer would.
/// - Otherwise a new NUL-terminated `CString` is allocated.
/// - An empty input yields an empty C string without allocating.
pub(crate) fn to_c_str(s: &str) -> Cow<'_, CStr> {
    if s.is_empty() {
        // A static, already-terminated empty string: no allocation needed.
        return Cow::Borrowed(
            CStr::from_bytes_with_nul(b"\0").expect("literal is a valid NUL-terminated string"),
        );
    }

    match s.bytes().position(|b| b == 0) {
        // Borrow the prefix ending at the first NUL. This is the checked,
        // `unsafe`-free equivalent of `CStr::from_ptr(s.as_ptr())`.
        Some(nul) => Cow::Borrowed(
            CStr::from_bytes_with_nul(&s.as_bytes()[..=nul])
                .expect("slice ends at its first NUL byte"),
        ),
        None => Cow::Owned(CString::new(s).expect("string contains no NUL bytes")),
    }
}
35
/// Copies a NUL-terminated C string (e.g. an error message filled in by LLVM)
/// into an owned Rust `String`, replacing invalid UTF-8 with U+FFFD.
///
/// The pointer is only read. Ownership of the buffer stays with the caller, who
/// remains responsible for freeing it.
///
/// Takes `c_char` rather than a hard-coded `i8` so it matches LLVM's
/// `*mut c_char` out-parameters on targets where `c_char` is unsigned.
///
/// # Panics
///
/// Panics if `c_str` is null.
///
/// NOTE(review): this is a safe fn that dereferences a raw pointer. Callers must
/// guarantee `c_str` points to a valid NUL-terminated buffer.
fn convert_c_string(c_str: *mut std::os::raw::c_char) -> String {
    assert!(!c_str.is_null(), "convert_c_string called with a null pointer");

    // SAFETY: non-null is checked above; validity and NUL termination are the
    // caller's obligation (see the note above).
    let c_str = unsafe { CStr::from_ptr(c_str) };

    c_str.to_string_lossy().into_owned()
}
46
47#[derive(Debug)]
48pub struct Module<R: RealScalar> {
49    engine: Mutex<LLVMExecutionEngineRef>,
50    module: Mutex<LLVMModuleRef>,
51    context: Context,
52    phantom: std::marker::PhantomData<R>,
53}
54
55impl<R: RealScalar> Module<R> {
56    pub fn new(module_name: &str) -> Self {
57        unsafe {
58            let core_context = LLVMContextCreate();
59
60            let c_string = to_c_str(module_name);
61            let core_module = LLVMModuleCreateWithNameInContext(c_string.as_ptr(), core_context);
62            // LLVMLinkInMCJIT();
63            match Target::initialize_native(&InitializationConfig::default()) {
64                Ok(_) => {}
65                Err(string) => panic!("Error initializing native target: {:?}", string),
66            }
67
68            let mut execution_engine = MaybeUninit::uninit();
69            let mut err_string = MaybeUninit::uninit();
70
71            let code = LLVMCreateJITCompilerForModule(
72                execution_engine.as_mut_ptr(),
73                core_module,
74                3,
75                err_string.as_mut_ptr(),
76            );
77
78            if code == 1 {
79                panic!(
80                    "Error creating JIT compiler: {:?}",
81                    convert_c_string(err_string.assume_init())
82                );
83            }
84
85            let execution_engine = execution_engine.assume_init();
86
87            Module {
88                context: Context::new(core_context),
89                module: core_module.into(),
90                engine: execution_engine.into(),
91                phantom: std::marker::PhantomData,
92            }
93        }
94    }
95
96    pub fn with_module<'a, F, G>(&self, f: F) -> G
97    where
98        F: FnOnce(ManuallyDrop<InkwellModule<'a>>) -> G,
99    {
100        let module_ref = self.module.lock().unwrap();
101        let module = unsafe { ManuallyDrop::new(InkwellModule::new(*module_ref)) };
102        f(module)
103    }
104
105    pub fn context(&self) -> &Context {
106        &self.context
107    }
108
109    // pub fn get_function<'a>(&'a self, name: &str) -> Option<WriteFuncWithLifeTime<'a, R>> {
110    pub fn get_function(&self, name: &str) -> Option<WriteFunc<R>> {
111        let name = process_name_for_gen(name);
112        let engine_ref = self.engine.lock().unwrap();
113
114        assert!(!(*engine_ref).is_null());
115
116        let address = {
117            let c_string = to_c_str(&name);
118            let address = unsafe { LLVMGetFunctionAddress(*engine_ref, c_string.as_ptr()) };
119            if address == 0 {
120                return None;
121            }
122            address as usize
123        };
124
125        Some(unsafe { std::mem::transmute_copy(&address) })
126    }
127}
128
129impl<R: RealScalar> Drop for Module<R> {
130    fn drop(&mut self) {
131        let engine_ref = self.engine.lock().unwrap();
132        unsafe {
133            LLVMDisposeExecutionEngine(*engine_ref);
134        }
135    }
136}
137
138impl<R: RealScalar> std::fmt::Display for Module<R> {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        self.with_module(|module| module.print_to_string().to_string().fmt(f))
141    }
142}
143
144unsafe impl<R: RealScalar> Send for Module<R> {}