compile_in_memory/
lib.rs

1use std::{fs::File, os::unix::prelude::FromRawFd, io::Write, ffi::{CString, CStr}, process::{Command, Stdio}};
2
3use loader::{Library, c_str};
4
// `make_fd` calls `libc::memfd_create`, which the `libc` crate only exposes on the
// targets listed below, so refuse to build anywhere else.
// NOTE(review): `compile` also hands the compiler `/proc/<pid>/fd/<n>` paths; confirm
// those resolve on FreeBSD before relying on that target.
#[cfg(not(any(
    all(target_os = "linux", any(target_env = "gnu", target_env = "musl")),
    target_os = "freebsd",
)))]
compile_error!("The libc crate only has the memfd_create syscall under linux-gnu, linux-musl, and freebsd.");
11
12unsafe fn make_fd(name: &CStr, flags: u32) -> Result<i32, std::io::Error> {
13    let fd = libc::memfd_create(name.as_ptr(), flags);
14    if fd >= 0 {
15        Ok(fd)
16    } else {
17        Err(std::io::Error::last_os_error())
18    }
19}
20
/// Ways in which [`compile`] can fail.
#[derive(Debug)]
pub enum CompileError {
    /// One of the following failed:
    /// - creating or writing one of the in-memory files;
    /// - spawning the compiler;
    /// - waiting on the compiler.
    IOError(std::io::Error),
    /// The compiler exited unsuccessfully; holds the output it printed.
    CompileError(CString),
    /// Loading the compiled shared object failed; holds the error message reported by
    /// `Library::new`.
    DLError(CString),
}

impl std::fmt::Display for CompileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CompileError::IOError(err) => write!(f, "I/O error: {err}"),
            // The compiler output is usually multi-line, so start it on its own line.
            CompileError::CompileError(output) => {
                write!(f, "compilation failed:\n{}", output.to_string_lossy())
            }
            CompileError::DLError(msg) => {
                write!(f, "failed to load compiled code: {}", msg.to_string_lossy())
            }
        }
    }
}

impl std::error::Error for CompileError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            CompileError::IOError(err) => Some(err),
            CompileError::CompileError(_) | CompileError::DLError(_) => None,
        }
    }
}

impl From<std::io::Error> for CompileError {
    fn from(err: std::io::Error) -> Self {
        CompileError::IOError(err)
    }
}

/// A bare `CString` error is treated as a dynamic-loading failure. This is what lets
/// `compile` use `?` on the result of `Library::new`.
impl From<CString> for CompileError {
    fn from(err: CString) -> Self {
        CompileError::DLError(err)
    }
}
39
/// Optimisation level passed to the compiler by [`compile`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationLevel {
    /// `-O0`: no optimisation.
    NoOptimization,
    /// `-O1`.
    One,
    /// `-O2`.
    Two,
    /// `-O3`.
    Three,
    /// `-Og`: optimisation level intended for debugging.
    ForDebugging,
}

impl OptimizationLevel {
    /// The compiler command-line flag for this level.
    fn arg(&self) -> &'static str {
        match self {
            OptimizationLevel::NoOptimization => "-O0",
            OptimizationLevel::One => "-O1",
            OptimizationLevel::Two => "-O2",
            OptimizationLevel::Three => "-O3",
            OptimizationLevel::ForDebugging => "-Og",
        }
    }
}
59
60pub fn compile(compiler: &str, source: &str, language: &str, optimization: OptimizationLevel, debug_symbols: bool) -> Result<Library, CompileError> {
61    let pid = unsafe {
62        libc::getpid()
63    };
64    let source_fd = unsafe {
65        make_fd(c_str!("source"), 0)
66    }?;
67    let code_fd = unsafe {
68        make_fd(c_str!("code"), 0)
69    }?;
70
71    let source_path = format!("/proc/{pid}/fd/{source_fd}");
72    let code_path = format!("/proc/{pid}/fd/{code_fd}");
73
74    let mut source_file = unsafe { File::from_raw_fd(source_fd) };
75
76    source_file.write_all(source.as_bytes())?;
77    source_file.flush()?;
78
79    let mut args = vec![
80        "-shared", optimization.arg(),
81        "-x", language,
82        &source_path,
83        "-o", &code_path,
84    ];
85    if debug_symbols {
86        args.push("-g");
87    }
88    
89    let handle = Command::new(compiler)
90        .stdin(Stdio::null())
91        .stdout(Stdio::piped())
92        .stderr(Stdio::piped())
93        .args(args)
94        .spawn()?;
95
96    let mut x = handle.wait_with_output()?;
97
98    if !x.status.success() {
99        let mut errmsg = std::mem::take(&mut x.stdout);
100        errmsg.append(&mut x.stderr);
101        return Err(CompileError::CompileError(CString::new(errmsg).unwrap()));
102    }
103
104    let code_path = CString::new(code_path).unwrap();
105
106    let handle = Library::new(&code_path, false)?;
107    
108    Ok(handle)
109}
110
/// End-to-end check: compiles a one-line C function with gcc, looks it up by name in
/// the loaded library, and calls it.
#[test]
fn main() {
    type AddFn = extern "C" fn(i32, i32) -> i32;

    let c_source = "int add(int x, int y) { return x + y; }";
    let library = compile("gcc", c_source, "c", OptimizationLevel::ForDebugging, true).unwrap();

    let symbol = library.sym_func::<AddFn>(c_str!("add")).unwrap();
    // SAFETY: `add` is defined by the C source above with a signature matching `AddFn`.
    let add = unsafe { symbol.assert_callable_shared() };

    let sum = add(3, 4);
    assert_eq!(sum, 7);
}
126
127
/// Compiles one Mandelbrot step (`*result = z*z + c`) over C99 `_Complex double`s. It
/// then checks two iterations starting from `z = 0`, `c = -0.75` against `num_complex`.
#[test]
fn float() {
    use num_complex::Complex;

    type StepFn = extern "C" fn(*mut Complex<f64>, *const Complex<f64>, *const Complex<f64>);

    let library = compile(
        "gcc",
        "void mandelbrot(
            _Complex double *result,
            const _Complex double *c,
            const _Complex double *z
        ) {
            *result = *z * *z + *c;
        }
        ",
        "c",
        OptimizationLevel::ForDebugging,
        true,
    )
    .unwrap();

    let symbol = library.sym_func::<StepFn>(c_str!("mandelbrot")).unwrap();
    // SAFETY: `mandelbrot` is defined by the C source above with three pointer
    // parameters, matching `StepFn`.
    let step = unsafe { symbol.assert_callable_shared() };

    let c = Complex { re: -0.75, im: 0.0 };
    let mut z = Complex { re: 0.0, im: 0.0 };
    let mut next = Complex::default();

    // From z = 0, one step yields exactly c.
    step(&mut next, &c, &z);
    assert_eq!(next, c);

    // Feed the result back in: the second step yields c*c + c.
    z = next;
    step(&mut next, &c, &z);
    assert_eq!(next, c * c + c);
}