1use std::{fs::File, os::unix::prelude::FromRawFd, io::Write, ffi::{CString, CStr}, process::{Command, Stdio}};
2
3use loader::{Library, c_str};
4
// `make_fd` below calls `libc::memfd_create`, so refuse to build for any
// target other than the ones named in the message.
#[cfg(not(any(
    all(
        target_os = "linux",
        any(target_env = "gnu", target_env = "musl"),
    ),
    target_os = "freebsd",
)))]
compile_error!("The libc crate only has the memfd_create syscall under linux-gnu, linux-musl, and freebsd.");
11
12unsafe fn make_fd(name: &CStr, flags: u32) -> Result<i32, std::io::Error> {
13 let fd = libc::memfd_create(name.as_ptr(), flags);
14 if fd >= 0 {
15 Ok(fd)
16 } else {
17 Err(std::io::Error::last_os_error())
18 }
19}
20
/// Everything that can go wrong while compiling source code and loading the
/// resulting library.
#[derive(Debug)]
pub enum CompileError {
    /// An I/O operation failed (creating the in-memory files, writing the
    /// source, or running the compiler process).
    IOError(std::io::Error),
    /// The compiler exited unsuccessfully; holds its captured stdout followed
    /// by its captured stderr.
    CompileError(CString),
    /// A bare `CString` error message, as produced when loading the compiled
    /// library fails.
    DLError(CString),
}

impl std::fmt::Display for CompileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CompileError::IOError(err) => write!(f, "I/O error: {err}"),
            // The payloads are raw bytes from external tools, so they are not
            // guaranteed to be valid UTF-8; render them lossily.
            CompileError::CompileError(output) => {
                write!(f, "compilation failed: {}", output.to_string_lossy())
            }
            CompileError::DLError(message) => {
                write!(f, "failed to load library: {}", message.to_string_lossy())
            }
        }
    }
}

impl std::error::Error for CompileError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            CompileError::IOError(err) => Some(err),
            CompileError::CompileError(_) | CompileError::DLError(_) => None,
        }
    }
}
27
28impl From<std::io::Error> for CompileError {
29 fn from(err: std::io::Error) -> Self {
30 CompileError::IOError(err)
31 }
32}
33
34impl From<CString> for CompileError {
35 fn from(err: CString) -> Self {
36 CompileError::DLError(err)
37 }
38}
39
/// Optimization setting that is passed to the compiler as an `-O…` flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationLevel {
    /// Passed as `-O0`.
    NoOptimization,
    /// Passed as `-O1`.
    One,
    /// Passed as `-O2`.
    Two,
    /// Passed as `-O3`.
    Three,
    /// Passed as `-Og`.
    ForDebugging,
}

impl OptimizationLevel {
    /// Returns the command-line flag that selects this level.
    fn arg(&self) -> &'static str {
        match self {
            OptimizationLevel::NoOptimization => "-O0",
            OptimizationLevel::One => "-O1",
            OptimizationLevel::Two => "-O2",
            OptimizationLevel::Three => "-O3",
            OptimizationLevel::ForDebugging => "-Og",
        }
    }
}
59
60pub fn compile(compiler: &str, source: &str, language: &str, optimization: OptimizationLevel, debug_symbols: bool) -> Result<Library, CompileError> {
61 let pid = unsafe {
62 libc::getpid()
63 };
64 let source_fd = unsafe {
65 make_fd(c_str!("source"), 0)
66 }?;
67 let code_fd = unsafe {
68 make_fd(c_str!("code"), 0)
69 }?;
70
71 let source_path = format!("/proc/{pid}/fd/{source_fd}");
72 let code_path = format!("/proc/{pid}/fd/{code_fd}");
73
74 let mut source_file = unsafe { File::from_raw_fd(source_fd) };
75
76 source_file.write_all(source.as_bytes())?;
77 source_file.flush()?;
78
79 let mut args = vec![
80 "-shared", optimization.arg(),
81 "-x", language,
82 &source_path,
83 "-o", &code_path,
84 ];
85 if debug_symbols {
86 args.push("-g");
87 }
88
89 let handle = Command::new(compiler)
90 .stdin(Stdio::null())
91 .stdout(Stdio::piped())
92 .stderr(Stdio::piped())
93 .args(args)
94 .spawn()?;
95
96 let mut x = handle.wait_with_output()?;
97
98 if !x.status.success() {
99 let mut errmsg = std::mem::take(&mut x.stdout);
100 errmsg.append(&mut x.stderr);
101 return Err(CompileError::CompileError(CString::new(errmsg).unwrap()));
102 }
103
104 let code_path = CString::new(code_path).unwrap();
105
106 let handle = Library::new(&code_path, false)?;
107
108 Ok(handle)
109}
110
/// Compiles a one-line C `add` function with gcc, looks it up in the loaded
/// library and checks that calling it adds two integers.
#[test]
fn main() {
    let source = "int add(int x, int y) { return x + y; }";
    let library = compile("gcc", source, "c", OptimizationLevel::ForDebugging, true).unwrap();

    let add = library
        .sym_func::<extern "C" fn(i32, i32) -> i32>(c_str!("add"))
        .unwrap();
    // The declared signature matches the C definition in `source` above.
    let add = unsafe { add.assert_callable_shared() };

    assert_eq!(add(3, 4), 7);
}
126
127
/// Compiles a C function computing one Mandelbrot step (`z * z + c`) on
/// `_Complex double` values and checks two successive steps against the same
/// arithmetic done in Rust.
#[test]
fn float() {
    let handle = compile(
        "gcc",
        "void mandelbrot(
            _Complex double *result,
            const _Complex double *c,
            const _Complex double *z
        ) {
            *result = *z * *z + *c;
        }
        ",
        "c",
        OptimizationLevel::ForDebugging,
        true
    ).unwrap();

    use num_complex::Complex;

    // Values are exchanged through pointers only.
    // NOTE(review): this assumes `Complex<f64>` has the same memory layout as
    // C's `_Complex double` — confirm against the num_complex documentation.
    let func = handle.sym_func::<
        extern "C" fn(*mut Complex<f64>, *const Complex<f64>, *const Complex<f64>)
    >(c_str!("mandelbrot")).unwrap();
    let func = unsafe { func.assert_callable_shared() };

    let mut parameter = Complex { re: 0.0, im: 0.0 };
    let negative_three_fourths = Complex { re: -0.75, im: 0.0 };

    let mut result = Complex::default();
    // First step: z = 0, so the result is c itself.
    func(&mut result, &negative_three_fourths, &parameter);
    assert_eq!(result, negative_three_fourths);
    // Second step: feed the previous result back in as z.
    parameter = result;
    func(&mut result, &negative_three_fourths, &parameter);
    assert_eq!(result, negative_three_fourths * negative_three_fourths + negative_three_fourths);
}