qudit_expr/codegen/
module.rs1use inkwell::context::Context;
2use qudit_core::RealScalar;
3
4use std::borrow::Cow;
5use std::ffi::{CStr, CString};
6use std::mem::{ManuallyDrop, MaybeUninit};
7use std::sync::Mutex;
8
9use inkwell::targets::{InitializationConfig, Target};
10use llvm_sys::core::{LLVMContextCreate, LLVMModuleCreateWithNameInContext};
11use llvm_sys::execution_engine::{
12 LLVMCreateJITCompilerForModule, LLVMDisposeExecutionEngine, LLVMExecutionEngineRef,
13 LLVMGetFunctionAddress,
14};
15use llvm_sys::prelude::LLVMModuleRef;
16
17use inkwell::module::Module as InkwellModule;
18
19use crate::WriteFunc;
20
21use super::process_name_for_gen;
22
/// Converts `s` into a NUL-terminated C string, borrowing whenever possible.
///
/// * An empty `s` is treated as `"\0"` and yields a borrowed, empty C string.
/// * If `s` already contains a NUL byte, the result borrows `s` up to (and
///   including) the *first* NUL; anything after it is ignored.
/// * Otherwise `s` is copied into an owned [`CString`] with a terminator
///   appended.
pub(crate) fn to_c_str(s: &str) -> Cow<'_, CStr> {
    // A zero-length slice has no terminator to borrow; substitute a static one.
    let s = if s.is_empty() { "\0" } else { s };

    // `from_bytes_until_nul` performs a bounded scan for the first NUL, so no
    // raw-pointer read past the slice is possible. It fails only when `s`
    // holds no NUL at all, which is exactly the case that needs an allocation.
    match CStr::from_bytes_until_nul(s.as_bytes()) {
        Ok(borrowed) => Cow::Borrowed(borrowed),
        Err(_) => {
            Cow::Owned(CString::new(s).expect("unreachable since null bytes are checked"))
        }
    }
}
35
/// Copies the NUL-terminated C string at `c_str` into an owned [`String`].
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD. The buffer behind
/// `c_str` is only read; it is not freed here, so it remains the caller's
/// responsibility.
///
/// The parameter is typed as `c_char` rather than `i8` because `c_char` is
/// unsigned on some targets, and both `CStr::from_ptr` and the C bindings use
/// the platform's `c_char`.
///
/// # Panics
///
/// Panics if `c_str` is null.
///
/// A non-null `c_str` must point to a valid NUL-terminated string; this
/// function cannot check that.
fn convert_c_string(c_str: *mut std::ffi::c_char) -> String {
    assert!(!c_str.is_null());

    // SAFETY: the pointer was just checked to be non-null; the caller
    // guarantees it addresses a live, NUL-terminated buffer.
    let borrowed = unsafe { CStr::from_ptr(c_str) };

    borrowed.to_string_lossy().into_owned()
}
46
47#[derive(Debug)]
48pub struct Module<R: RealScalar> {
49 engine: Mutex<LLVMExecutionEngineRef>,
50 module: Mutex<LLVMModuleRef>,
51 context: Context,
52 phantom: std::marker::PhantomData<R>,
53}
54
55impl<R: RealScalar> Module<R> {
56 pub fn new(module_name: &str) -> Self {
57 unsafe {
58 let core_context = LLVMContextCreate();
59
60 let c_string = to_c_str(module_name);
61 let core_module = LLVMModuleCreateWithNameInContext(c_string.as_ptr(), core_context);
62 match Target::initialize_native(&InitializationConfig::default()) {
64 Ok(_) => {}
65 Err(string) => panic!("Error initializing native target: {:?}", string),
66 }
67
68 let mut execution_engine = MaybeUninit::uninit();
69 let mut err_string = MaybeUninit::uninit();
70
71 let code = LLVMCreateJITCompilerForModule(
72 execution_engine.as_mut_ptr(),
73 core_module,
74 3,
75 err_string.as_mut_ptr(),
76 );
77
78 if code == 1 {
79 panic!(
80 "Error creating JIT compiler: {:?}",
81 convert_c_string(err_string.assume_init())
82 );
83 }
84
85 let execution_engine = execution_engine.assume_init();
86
87 Module {
88 context: Context::new(core_context),
89 module: core_module.into(),
90 engine: execution_engine.into(),
91 phantom: std::marker::PhantomData,
92 }
93 }
94 }
95
96 pub fn with_module<'a, F, G>(&self, f: F) -> G
97 where
98 F: FnOnce(ManuallyDrop<InkwellModule<'a>>) -> G,
99 {
100 let module_ref = self.module.lock().unwrap();
101 let module = unsafe { ManuallyDrop::new(InkwellModule::new(*module_ref)) };
102 f(module)
103 }
104
105 pub fn context(&self) -> &Context {
106 &self.context
107 }
108
109 pub fn get_function(&self, name: &str) -> Option<WriteFunc<R>> {
111 let name = process_name_for_gen(name);
112 let engine_ref = self.engine.lock().unwrap();
113
114 assert!(!(*engine_ref).is_null());
115
116 let address = {
117 let c_string = to_c_str(&name);
118 let address = unsafe { LLVMGetFunctionAddress(*engine_ref, c_string.as_ptr()) };
119 if address == 0 {
120 return None;
121 }
122 address as usize
123 };
124
125 Some(unsafe { std::mem::transmute_copy(&address) })
126 }
127}
128
129impl<R: RealScalar> Drop for Module<R> {
130 fn drop(&mut self) {
131 let engine_ref = self.engine.lock().unwrap();
132 unsafe {
133 LLVMDisposeExecutionEngine(*engine_ref);
134 }
135 }
136}
137
138impl<R: RealScalar> std::fmt::Display for Module<R> {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 self.with_module(|module| module.print_to_string().to_string().fmt(f))
141 }
142}
143
144unsafe impl<R: RealScalar> Send for Module<R> {}