// qudit_expr/codegen/module.rs

use std::borrow::Cow;
use std::ffi::{CStr, CString};
use std::mem::{ManuallyDrop, MaybeUninit};
use std::os::raw::c_char;
use std::sync::{Mutex, PoisonError};

use inkwell::context::Context;
use inkwell::module::Module as InkwellModule;
use inkwell::targets::{InitializationConfig, Target};
use llvm_sys::core::{LLVMContextCreate, LLVMModuleCreateWithNameInContext};
use llvm_sys::execution_engine::{
    LLVMCreateJITCompilerForModule, LLVMDisposeExecutionEngine, LLVMExecutionEngineRef,
    LLVMGetFunctionAddress, LLVMLinkInMCJIT,
};
use llvm_sys::prelude::LLVMModuleRef;
use qudit_core::RealScalar;

use crate::WriteFunc;

use super::process_name_for_gen;
22
/// Converts `s` into a NUL-terminated C string suitable for passing to LLVM.
///
/// - If `s` contains no NUL byte, a new owned `CString` is allocated.
/// - If `s` contains a NUL byte, the result borrows `s` up to and including the
///   first NUL; anything after it is ignored.
/// - An empty `s` yields a borrowed, empty C string (no allocation).
pub(crate) fn to_c_str(mut s: &str) -> Cow<'_, CStr> {
    if s.is_empty() {
        s = "\0";
    }

    match s.bytes().position(|b| b == 0) {
        // Borrow the prefix ending at the first NUL. The checked constructor replaces
        // the former `unsafe` raw-pointer scan; the slice has exactly one NUL, at its end.
        Some(nul) => Cow::Borrowed(
            CStr::from_bytes_with_nul(&s.as_bytes()[..=nul])
                .expect("slice ends at its first NUL byte"),
        ),
        None => Cow::Owned(CString::new(s).expect("unreachable since null bytes are checked")),
    }
}
35
/// Copies a NUL-terminated C string (such as an LLVM error message) into an owned `String`.
///
/// Invalid UTF-8 sequences are replaced with U+FFFD. The pointed-to buffer is only
/// read, never freed, so ownership stays with the caller.
///
/// Takes `c_char` rather than `i8` because `c_char` is `u8` on some targets
/// (e.g. aarch64 Linux), where LLVM's `*mut c_char` error pointer would not be an `*mut i8`.
///
/// # Panics
///
/// Panics if `c_str` is null.
fn convert_c_string(c_str: *const c_char) -> String {
    assert!(!c_str.is_null());

    // SAFETY: `c_str` is non-null (checked above); callers must pass a valid,
    // NUL-terminated string that stays alive for the duration of this call.
    let c_str = unsafe { CStr::from_ptr(c_str) };

    c_str.to_string_lossy().into_owned()
}
46
47#[derive(Debug)]
48pub struct Module<R: RealScalar> {
49 engine: Mutex<LLVMExecutionEngineRef>,
50 module: Mutex<LLVMModuleRef>,
51 context: Context,
52 phantom: std::marker::PhantomData<R>,
53}
54
55impl<R: RealScalar> Module<R> {
56 pub fn new(module_name: &str) -> Self {
57 unsafe {
58 let core_context = LLVMContextCreate();
59
60 let c_string = to_c_str(module_name);
61 let core_module = LLVMModuleCreateWithNameInContext(c_string.as_ptr(), core_context);
62 match Target::initialize_native(&InitializationConfig::default()) {
64 Ok(_) => {}
65 Err(string) => panic!("Error initializing native target: {:?}", string),
66 }
67
68 let mut execution_engine = MaybeUninit::uninit();
69 let mut err_string = MaybeUninit::uninit();
70 LLVMLinkInMCJIT();
71
72 let code = LLVMCreateJITCompilerForModule(
73 execution_engine.as_mut_ptr(),
74 core_module,
75 3,
76 err_string.as_mut_ptr(),
77 );
78
79 if code == 1 {
80 panic!(
81 "Error creating JIT compiler: {:?}",
82 convert_c_string(err_string.assume_init())
83 );
84 }
85
86 let execution_engine = execution_engine.assume_init();
87
88 Module {
89 context: Context::new(core_context),
90 module: core_module.into(),
91 engine: execution_engine.into(),
92 phantom: std::marker::PhantomData,
93 }
94 }
95 }
96
97 pub fn with_module<'a, F, G>(&self, f: F) -> G
98 where
99 F: FnOnce(ManuallyDrop<InkwellModule<'a>>) -> G,
100 {
101 let module_ref = self.module.lock().unwrap();
102 let module = unsafe { ManuallyDrop::new(InkwellModule::new(*module_ref)) };
103 f(module)
104 }
105
106 pub fn context(&self) -> &Context {
107 &self.context
108 }
109
110 pub fn get_function(&self, name: &str) -> Option<WriteFunc<R>> {
112 let name = process_name_for_gen(name);
113 let engine_ref = self.engine.lock().unwrap();
114
115 assert!(!(*engine_ref).is_null());
116
117 let address = {
118 let c_string = to_c_str(&name);
119 let address = unsafe { LLVMGetFunctionAddress(*engine_ref, c_string.as_ptr()) };
120 if address == 0 {
121 return None;
122 }
123 address as usize
124 };
125
126 Some(unsafe { std::mem::transmute_copy(&address) })
127 }
128}
129
130impl<R: RealScalar> Drop for Module<R> {
131 fn drop(&mut self) {
132 let engine_ref = self.engine.lock().unwrap();
133 unsafe {
134 LLVMDisposeExecutionEngine(*engine_ref);
135 }
136 }
137}
138
139impl<R: RealScalar> std::fmt::Display for Module<R> {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 self.with_module(|module| module.print_to_string().to_string().fmt(f))
142 }
143}
144
145unsafe impl<R: RealScalar> Send for Module<R> {}