qudit_expr/codegen/
module.rs

1use inkwell::context::Context;
2use qudit_core::RealScalar;
3
4use std::borrow::Cow;
5use std::ffi::{CStr, CString};
6use std::mem::{ManuallyDrop, MaybeUninit};
7use std::sync::Mutex;
8
9use inkwell::targets::{InitializationConfig, Target};
10use llvm_sys::core::{LLVMContextCreate, LLVMModuleCreateWithNameInContext};
11use llvm_sys::execution_engine::{
12    LLVMCreateJITCompilerForModule, LLVMDisposeExecutionEngine, LLVMExecutionEngineRef,
13    LLVMGetFunctionAddress, LLVMLinkInMCJIT,
14};
15use llvm_sys::prelude::LLVMModuleRef;
16
17use inkwell::module::Module as InkwellModule;
18
19use crate::WriteFunc;
20
21use super::process_name_for_gen;
22
/// Converts `s` into a C string suitable for passing to the LLVM C API.
///
/// * If `s` contains no NUL byte, a freshly allocated, NUL-terminated [`CString`] is
///   returned (`Cow::Owned`).
/// * If `s` already contains a NUL byte, it is borrowed in place and the resulting C string
///   ends at the *first* NUL (`Cow::Borrowed`); anything after that byte is ignored.
/// * An empty `s` yields an empty, borrowed C string without allocating.
pub(crate) fn to_c_str(s: &str) -> Cow<'_, CStr> {
    // Map the empty string onto a lone terminator so it can be borrowed instead of allocated.
    let text = if s.is_empty() { "\0" } else { s };

    match CStr::from_bytes_until_nul(text.as_bytes()) {
        // A terminator is already present: borrow up to (and including) the first NUL.
        Ok(borrowed) => Cow::Borrowed(borrowed),
        // No NUL anywhere in `text`: allocate an owned copy with a terminator appended.
        Err(_) => Cow::Owned(
            CString::new(text).expect("unreachable since null bytes are checked"),
        ),
    }
}
35
/// Copies a NUL-terminated C string into an owned Rust [`String`].
///
/// Invalid UTF-8 sequences are replaced with `U+FFFD` instead of being rejected. The
/// pointed-to memory is only read, never freed, so ownership stays with the caller.
///
/// The parameter is `c_char` rather than a hard-coded `i8` so that it matches the `char**`
/// out-parameters of the LLVM C API on every target (`c_char` is `u8` on e.g. aarch64 Linux,
/// where an `*mut i8` parameter would not type-check).
///
/// # Panics
///
/// Panics if `c_str` is null.
///
/// # Caller obligation
///
/// Although this is a safe `fn`, `c_str` must point to a valid, NUL-terminated string that
/// stays alive and unmodified for the duration of the call.
fn convert_c_string(c_str: *mut std::os::raw::c_char) -> String {
    assert!(!c_str.is_null());

    // SAFETY: the pointer was checked to be non-null above; validity and NUL-termination of
    // the pointed-to bytes are the caller's obligation (see the doc comment).
    let c_str = unsafe { CStr::from_ptr(c_str) };

    c_str.to_string_lossy().into_owned()
}
46
/// An LLVM module together with the LLVM context it lives in and an MCJIT execution engine
/// created for it (see `Module::new`).
///
/// `R` is not stored; it only fixes the scalar type of the [`WriteFunc<R>`] pointers handed
/// out by `Module::get_function`.
///
/// The raw LLVM handles sit behind `Mutex`es so every access to them goes through a lock.
///
/// NOTE: keep the field order. It determines the `#[derive(Debug)]` output and the order in
/// which fields are dropped. The `Drop` impl disposes the engine before any field is dropped,
/// so `context` is always released after the engine.
#[derive(Debug)]
pub struct Module<R: RealScalar> {
    /// MCJIT execution engine created by `LLVMCreateJITCompilerForModule` in `new`;
    /// disposed explicitly by the `Drop` impl.
    engine: Mutex<LLVMExecutionEngineRef>,
    /// Raw handle of the module created in `context`. It is never disposed directly;
    /// presumably the execution engine owns it once the JIT has been created (LLVM's
    /// ExecutionEngine takes ownership of its module) — confirm before changing `Drop`.
    module: Mutex<LLVMModuleRef>,
    /// inkwell wrapper around the raw LLVM context that `module` was created in.
    context: Context,
    /// Marker tying this module to the scalar type `R`.
    phantom: std::marker::PhantomData<R>,
}
54
55impl<R: RealScalar> Module<R> {
56    pub fn new(module_name: &str) -> Self {
57        unsafe {
58            let core_context = LLVMContextCreate();
59
60            let c_string = to_c_str(module_name);
61            let core_module = LLVMModuleCreateWithNameInContext(c_string.as_ptr(), core_context);
62            // LLVMLinkInMCJIT();
63            match Target::initialize_native(&InitializationConfig::default()) {
64                Ok(_) => {}
65                Err(string) => panic!("Error initializing native target: {:?}", string),
66            }
67
68            let mut execution_engine = MaybeUninit::uninit();
69            let mut err_string = MaybeUninit::uninit();
70            LLVMLinkInMCJIT();
71
72            let code = LLVMCreateJITCompilerForModule(
73                execution_engine.as_mut_ptr(),
74                core_module,
75                3,
76                err_string.as_mut_ptr(),
77            );
78
79            if code == 1 {
80                panic!(
81                    "Error creating JIT compiler: {:?}",
82                    convert_c_string(err_string.assume_init())
83                );
84            }
85
86            let execution_engine = execution_engine.assume_init();
87
88            Module {
89                context: Context::new(core_context),
90                module: core_module.into(),
91                engine: execution_engine.into(),
92                phantom: std::marker::PhantomData,
93            }
94        }
95    }
96
97    pub fn with_module<'a, F, G>(&self, f: F) -> G
98    where
99        F: FnOnce(ManuallyDrop<InkwellModule<'a>>) -> G,
100    {
101        let module_ref = self.module.lock().unwrap();
102        let module = unsafe { ManuallyDrop::new(InkwellModule::new(*module_ref)) };
103        f(module)
104    }
105
106    pub fn context(&self) -> &Context {
107        &self.context
108    }
109
110    // pub fn get_function<'a>(&'a self, name: &str) -> Option<WriteFuncWithLifeTime<'a, R>> {
111    pub fn get_function(&self, name: &str) -> Option<WriteFunc<R>> {
112        let name = process_name_for_gen(name);
113        let engine_ref = self.engine.lock().unwrap();
114
115        assert!(!(*engine_ref).is_null());
116
117        let address = {
118            let c_string = to_c_str(&name);
119            let address = unsafe { LLVMGetFunctionAddress(*engine_ref, c_string.as_ptr()) };
120            if address == 0 {
121                return None;
122            }
123            address as usize
124        };
125
126        Some(unsafe { std::mem::transmute_copy(&address) })
127    }
128}
129
130impl<R: RealScalar> Drop for Module<R> {
131    fn drop(&mut self) {
132        let engine_ref = self.engine.lock().unwrap();
133        unsafe {
134            LLVMDisposeExecutionEngine(*engine_ref);
135        }
136    }
137}
138
139impl<R: RealScalar> std::fmt::Display for Module<R> {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        self.with_module(|module| module.print_to_string().to_string().fmt(f))
142    }
143}
144
145unsafe impl<R: RealScalar> Send for Module<R> {}