Skip to main content

compio_py_dynamic_openssl/
loader.rs

// SPDX-License-Identifier: Apache-2.0 OR MulanPSL-2.0
// Copyright 2026 Fantix King
3
4use std::{
5    ffi::{OsStr, c_char, c_int, c_long, c_uchar, c_uint, c_ulong, c_void},
6    io,
7    sync::OnceLock,
8};
9
10use libloading::{Library, Symbol};
11
12use crate::sys::*;
13
14static OPENSSL: OnceLock<OpenSSL> = OnceLock::new();
15
16pub enum Error {
17    IoError(io::Error),
18    Loader(libloading::Error),
19    #[cfg(windows)]
20    PE(pelite::Error),
21    LibraryNotFound,
22    AlreadyLoaded,
23    VersionTooOld,
24}
25
26impl From<libloading::Error> for Error {
27    fn from(value: libloading::Error) -> Self {
28        Self::Loader(value)
29    }
30}
31
32impl From<io::Error> for Error {
33    fn from(value: io::Error) -> Self {
34        Self::IoError(value)
35    }
36}
37
38#[cfg(windows)]
39impl From<pelite::Error> for Error {
40    fn from(value: pelite::Error) -> Self {
41        Self::PE(value)
42    }
43}
44
45#[allow(bad_style)]
46pub struct OpenSSL {
47    lib: OsslLibraries,
48    pub version_num: c_ulong,
49
50    pub BIO_meth_new: unsafe extern "C" fn(type_: c_int, name: *const c_char) -> *mut BIO_METHOD,
51    pub BIO_meth_free: unsafe extern "C" fn(biom: *mut BIO_METHOD),
52    pub BIO_meth_set_write: unsafe extern "C" fn(
53        biom: *mut BIO_METHOD,
54        write: Option<unsafe extern "C" fn(*mut BIO, *const c_char, c_int) -> c_int>,
55    ) -> c_int,
56    pub BIO_meth_set_read: unsafe extern "C" fn(
57        biom: *mut BIO_METHOD,
58        read: Option<unsafe extern "C" fn(*mut BIO, *mut c_char, c_int) -> c_int>,
59    ) -> c_int,
60    pub BIO_meth_set_puts: unsafe extern "C" fn(
61        biom: *mut BIO_METHOD,
62        puts: Option<unsafe extern "C" fn(*mut BIO, *const c_char) -> c_int>,
63    ) -> c_int,
64    pub BIO_meth_set_ctrl: unsafe extern "C" fn(
65        biom: *mut BIO_METHOD,
66        ctrl: Option<unsafe extern "C" fn(*mut BIO, c_int, c_long, *mut c_void) -> c_long>,
67    ) -> c_int,
68    pub BIO_meth_set_create: unsafe extern "C" fn(
69        biom: *mut BIO_METHOD,
70        create: Option<unsafe extern "C" fn(*mut BIO) -> c_int>,
71    ) -> c_int,
72    pub BIO_meth_set_destroy: unsafe extern "C" fn(
73        biom: *mut BIO_METHOD,
74        destroy: Option<unsafe extern "C" fn(*mut BIO) -> c_int>,
75    ) -> c_int,
76
77    pub BIO_new: unsafe extern "C" fn(type_: *const BIO_METHOD) -> *mut BIO,
78    pub BIO_get_data: unsafe extern "C" fn(b: *mut BIO) -> *mut c_void,
79    pub BIO_set_data: unsafe extern "C" fn(b: *mut BIO, data: *mut c_void),
80    pub BIO_set_init: unsafe extern "C" fn(b: *mut BIO, init: c_int),
81    pub BIO_set_flags: unsafe extern "C" fn(b: *mut BIO, flags: c_int),
82    pub BIO_clear_flags: unsafe extern "C" fn(b: *mut BIO, flags: c_int),
83
84    pub SSL_new: unsafe extern "C" fn(ctx: *mut SSL_CTX) -> *mut SSL,
85    pub SSL_free: unsafe extern "C" fn(ssl: *mut SSL),
86    pub SSL_connect: unsafe extern "C" fn(ssl: *mut SSL) -> i32,
87    pub SSL_accept: unsafe extern "C" fn(ssl: *mut SSL) -> c_int,
88    pub SSL_ctrl:
89        unsafe extern "C" fn(ssl: *mut SSL, cmd: c_int, larg: c_long, parg: *mut c_void) -> c_long,
90    pub SSL_do_handshake: unsafe extern "C" fn(ssl: *mut SSL) -> c_int,
91    pub SSL_set_bio: unsafe extern "C" fn(ssl: *mut SSL, rbio: *mut BIO, wbio: *mut BIO),
92    pub SSL_get_rbio: unsafe extern "C" fn(ssl: *mut SSL) -> *mut BIO,
93    pub SSL_get_error: unsafe extern "C" fn(ssl: *mut SSL, ret: c_int) -> c_int,
94    pub SSL_read_ex: unsafe extern "C" fn(
95        ssl: *mut SSL,
96        buf: *mut c_void,
97        num: usize,
98        readbytes: *mut usize,
99    ) -> c_int,
100    pub SSL_write_ex: unsafe extern "C" fn(
101        ssl: *mut SSL,
102        buf: *const c_void,
103        num: usize,
104        written: *mut usize,
105    ) -> c_int,
106    pub SSL_get0_alpn_selected:
107        unsafe extern "C" fn(s: *const SSL, data: *mut *const c_uchar, len: *mut c_uint),
108    pub SSL_get0_param: unsafe extern "C" fn(ssl: *mut SSL) -> *mut X509_VERIFY_PARAM,
109    pub SSL_shutdown: unsafe extern "C" fn(*mut SSL) -> c_int,
110
111    pub X509_VERIFY_PARAM_set1_host: unsafe extern "C" fn(
112        param: *mut X509_VERIFY_PARAM,
113        name: *const c_char,
114        namelen: isize,
115    ) -> c_int,
116    pub X509_VERIFY_PARAM_set1_ip: unsafe extern "C" fn(
117        param: *mut X509_VERIFY_PARAM,
118        ip: *const c_uchar,
119        iplen: isize,
120    ) -> c_int,
121
122    pub ERR_get_error_all: Option<
123        unsafe extern "C" fn(
124            file: *mut *const c_char,
125            line: *mut c_int,
126            func: *mut *const c_char,
127            data: *mut *const c_char,
128            flags: *mut c_int,
129        ) -> c_ulong,
130    >,
131    pub ERR_get_error_line_data: Option<
132        unsafe extern "C" fn(
133            file: *mut *const c_char,
134            line: *mut c_int,
135            data: *mut *const c_char,
136            flags: *mut c_int,
137        ) -> c_ulong,
138    >,
139    pub ERR_func_error_string: Option<unsafe extern "C" fn(err: c_ulong) -> *const c_char>,
140    pub ERR_lib_error_string: unsafe extern "C" fn(err: c_ulong) -> *const c_char,
141    pub ERR_reason_error_string: unsafe extern "C" fn(err: c_ulong) -> *const c_char,
142}
143
144struct OsslLibraries(Vec<Library>);
145
146impl OsslLibraries {
147    #[cfg(windows)]
148    fn from(filename: &OsStr) -> Result<Self, Error> {
149        use std::{ffi::OsString, fs, os::windows::ffi::OsStringExt, path::Path};
150
151        use pelite::pe::{Pe, PeFile};
152        use windows_sys::Win32::{
153            Foundation::MAX_PATH,
154            System::ProcessStatus::{EnumProcessModules, GetModuleFileNameExW},
155        };
156
157        let mut linked_dlls = Vec::new();
158        let file_data = fs::read(filename)?;
159        let pe = PeFile::from_bytes(&file_data)?;
160        if let Ok(imports) = pe.imports() {
161            for desc in imports {
162                if let Ok(dll_name) = desc.dll_name()
163                    && let Ok(dll_name) = dll_name.to_str()
164                {
165                    let lower_name = dll_name.to_lowercase();
166                    if lower_name.contains("ssl") || lower_name.contains("crypto") {
167                        linked_dlls.push(OsString::from(lower_name));
168                    }
169                }
170            }
171        }
172
173        let mut libs = Vec::new();
174        let process = unsafe { windows_sys::Win32::System::Threading::GetCurrentProcess() };
175        let mut h_mods = vec![0isize; 1024];
176        let mut cb_needed = 0u32;
177        if unsafe {
178            EnumProcessModules(
179                process,
180                h_mods.as_mut_ptr() as *mut _,
181                (h_mods.len() * size_of::<isize>()) as u32,
182                &mut cb_needed,
183            )
184        } != 0
185        {
186            let count = (cb_needed as usize) / size_of::<isize>();
187            for i in 0..count {
188                let mut sz_mod_name = vec![0u16; MAX_PATH as usize];
189                if unsafe {
190                    GetModuleFileNameExW(
191                        process,
192                        h_mods[i] as *mut _,
193                        sz_mod_name.as_mut_ptr(),
194                        MAX_PATH,
195                    )
196                } > 0
197                {
198                    let len = sz_mod_name
199                        .iter()
200                        .position(|&c| c == 0)
201                        .unwrap_or(sz_mod_name.len());
202                    let path = OsString::from_wide(&sz_mod_name[..len]);
203                    if let Some(name) = Path::new(&path).file_name()
204                        && linked_dlls.iter().any(|n| n == name)
205                    {
206                        let lib = libloading::os::windows::Library::open_already_loaded(path)?;
207                        libs.push(lib.into());
208                    }
209                }
210            }
211        }
212        Ok(Self(libs))
213    }
214
215    #[cfg(unix)]
216    fn from(filename: &OsStr) -> Result<Self, Error> {
217        cfg_if::cfg_if! {
218            if #[cfg(any(
219                target_os = "linux",
220                target_os = "android",
221                target_os = "emscripten",
222                target_os = "solaris",
223                target_os = "illumos",
224                target_os = "fuchsia",
225                target_os = "hurd",
226            ))] {
227                const RTLD_NOLOAD: c_int = 0x4;
228            } else if #[cfg(any(
229                target_os = "macos",
230                target_os = "ios",
231                target_os = "tvos",
232                target_os = "visionos",
233                target_os = "watchos",
234                target_os = "cygwin",
235            ))] {
236                const RTLD_NOLOAD: c_int = 0x10;
237            } else if #[cfg(any(
238                target_os = "freebsd",
239                target_os = "dragonfly",
240                target_os = "netbsd",
241            ))] {
242                const RTLD_NOLOAD: c_int = 0x2000;
243            } else {
244                compile_error!(
245                    "Target has no known `RTLD_NOLOAD` value. Please submit an issue or PR adding it."
246                );
247            }
248        }
249        unsafe {
250            let lib = libloading::os::unix::Library::open(
251                Some(filename),
252                RTLD_NOLOAD | libloading::os::unix::RTLD_LAZY,
253            )?;
254            Ok(Self(vec![lib.into()]))
255        }
256    }
257
258    pub unsafe fn get<T>(&self, symbol: &[u8]) -> Result<Symbol<'_, T>, Error> {
259        let mut res = Err(Error::LibraryNotFound);
260        for lib in self.0.iter() {
261            match unsafe { lib.get::<T>(symbol) } {
262                Ok(sym) => return Ok(sym),
263                Err(e) => res = Err(e.into()),
264            }
265        }
266        res
267    }
268}
269
270impl OpenSSL {
271    fn load(filename: &OsStr) -> Result<Self, Error> {
272        let lib = OsslLibraries::from(filename)?;
273
274        let mut rv = Self {
275            BIO_meth_new: *unsafe { lib.get(b"BIO_meth_new")? },
276            BIO_meth_free: *unsafe { lib.get(b"BIO_meth_free")? },
277            BIO_meth_set_write: *unsafe { lib.get(b"BIO_meth_set_write")? },
278            BIO_meth_set_read: *unsafe { lib.get(b"BIO_meth_set_read")? },
279            BIO_meth_set_puts: *unsafe { lib.get(b"BIO_meth_set_puts")? },
280            BIO_meth_set_ctrl: *unsafe { lib.get(b"BIO_meth_set_ctrl")? },
281            BIO_meth_set_create: *unsafe { lib.get(b"BIO_meth_set_create")? },
282            BIO_meth_set_destroy: *unsafe { lib.get(b"BIO_meth_set_destroy")? },
283
284            BIO_new: *unsafe { lib.get(b"BIO_new")? },
285            BIO_get_data: *unsafe { lib.get(b"BIO_get_data")? },
286            BIO_set_data: *unsafe { lib.get(b"BIO_set_data")? },
287            BIO_set_init: *unsafe { lib.get(b"BIO_set_init")? },
288            BIO_set_flags: *unsafe { lib.get(b"BIO_set_flags")? },
289            BIO_clear_flags: *unsafe { lib.get(b"BIO_clear_flags")? },
290
291            SSL_new: *unsafe { lib.get(b"SSL_new")? },
292            SSL_free: *unsafe { lib.get(b"SSL_free")? },
293            SSL_connect: *unsafe { lib.get(b"SSL_connect")? },
294            SSL_accept: *unsafe { lib.get(b"SSL_accept")? },
295            SSL_ctrl: *unsafe { lib.get(b"SSL_ctrl")? },
296            SSL_do_handshake: *unsafe { lib.get(b"SSL_do_handshake")? },
297            SSL_set_bio: *unsafe { lib.get(b"SSL_set_bio")? },
298            SSL_get_rbio: *unsafe { lib.get(b"SSL_get_rbio")? },
299            SSL_get_error: *unsafe { lib.get(b"SSL_get_error")? },
300            SSL_read_ex: *unsafe { lib.get(b"SSL_read_ex")? },
301            SSL_write_ex: *unsafe { lib.get(b"SSL_write_ex")? },
302            SSL_get0_alpn_selected: *unsafe { lib.get(b"SSL_get0_alpn_selected")? },
303            SSL_get0_param: *unsafe { lib.get(b"SSL_get0_param")? },
304            SSL_shutdown: *unsafe { lib.get(b"SSL_shutdown")? },
305
306            X509_VERIFY_PARAM_set1_host: *unsafe { lib.get(b"X509_VERIFY_PARAM_set1_host")? },
307            X509_VERIFY_PARAM_set1_ip: *unsafe { lib.get(b"X509_VERIFY_PARAM_set1_ip")? },
308
309            ERR_get_error_all: None,
310            ERR_get_error_line_data: None,
311            ERR_func_error_string: None,
312            ERR_lib_error_string: *unsafe { lib.get(b"ERR_lib_error_string")? },
313            ERR_reason_error_string: *unsafe { lib.get(b"ERR_reason_error_string")? },
314
315            version_num: unsafe {
316                lib.get::<unsafe extern "C" fn() -> c_ulong>(b"OpenSSL_version_num")?()
317            },
318
319            lib,
320        };
321        if rv.version_num < 0x10100010 {
322            return Err(Error::VersionTooOld);
323        }
324        if rv.version_num < 0x30000000 {
325            rv.ERR_get_error_line_data = Some(*unsafe { rv.lib.get(b"ERR_get_error_line_data")? });
326            rv.ERR_func_error_string = Some(*unsafe { rv.lib.get(b"ERR_func_error_string")? });
327        } else {
328            rv.ERR_get_error_all = Some(*unsafe { rv.lib.get(b"ERR_get_error_all")? });
329        }
330        Ok(rv)
331    }
332}
333
334pub fn is_loaded() -> bool {
335    OPENSSL.get().is_some()
336}
337
338pub fn load(filename: &OsStr) -> Result<(), Error> {
339    if is_loaded() {
340        return Err(Error::AlreadyLoaded);
341    }
342    OPENSSL
343        .set(OpenSSL::load(filename)?)
344        .map_err(|_| Error::AlreadyLoaded)
345}
346
347pub fn get() -> &'static OpenSSL {
348    OPENSSL.get().expect("OpenSSL library not loaded")
349}