Skip to main content

compio_py_dynamic_openssl/
lib.rs

1// SPDX-License-Identifier: Apache-2.0 OR MulanPSL-2.0
2// Copyright 2026 Fantix King
3
4use std::{
5    ffi::OsStr,
6    io::{Read, Write},
7    net::IpAddr,
8};
9
10use compio_log::debug;
11pub use pyo3;
12use pyo3::{
13    ffi::{PyObject, c_str},
14    prelude::*,
15    types::PyDict,
16};
17
18use self::{
19    loader::get,
20    ssl::{HandshakeError, Ssl, SslStream},
21    sys as ffi,
22};
23
24pub mod bio;
25pub mod error;
26pub mod loader;
27pub mod ssl;
28pub mod sys;
29
30pub struct SSLContext {
31    ptr: *mut ffi::SSL_CTX,
32    pyobj: Py<PyAny>,
33}
34
35impl TryFrom<Bound<'_, PyAny>> for SSLContext {
36    type Error = PyErr;
37
38    fn try_from(obj: Bound<PyAny>) -> PyResult<Self> {
39        #[repr(C)]
40        struct PySSLContext {
41            ob_base: PyObject,
42            ctx: *mut ffi::SSL_CTX,
43        }
44
45        unsafe {
46            let ptr = obj.as_ptr() as *const PySSLContext;
47            let ptr = (*ptr).ctx;
48            if ptr.is_null() {
49                return Err(pyo3::exceptions::PyValueError::new_err(
50                    "SSLContext has null SSL_CTX",
51                ));
52            }
53            Ok(Self {
54                ptr,
55                pyobj: obj.unbind(),
56            })
57        }
58    }
59}
60
61impl SSLContext {
62    pub fn connect<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
63    where
64        S: Read + Write,
65    {
66        let mut ssl = Ssl::new(self.ptr)?;
67        let ip = domain.parse::<IpAddr>().ok();
68        if ip.is_none() {
69            ssl.set_hostname(domain)?;
70        }
71        let res = Python::attach(|py| {
72            self.pyobj
73                .bind(py)
74                .getattr("check_hostname")?
75                .extract::<bool>()
76        });
77        match res {
78            Ok(true) => {
79                let param = ssl.param_mut();
80                match ip {
81                    Some(ip) => param.set_ip(ip)?,
82                    None => param.set_host(domain)?,
83                }
84            }
85            Ok(false) => {}
86            Err(e) => panic!("{e}"),
87        }
88
89        ssl.connect(stream)
90    }
91
92    pub fn accept<S>(&self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
93    where
94        S: Read + Write,
95    {
96        let ssl = Ssl::new(self.ptr)?;
97        ssl.accept(stream)
98    }
99}
100
101impl Clone for SSLContext {
102    fn clone(&self) -> Self {
103        Python::attach(|py| Self {
104            ptr: self.ptr,
105            pyobj: self.pyobj.clone_ref(py),
106        })
107    }
108}
109
110pub fn load_py(py: Python) -> PyResult<bool> {
111    if loader::is_loaded() {
112        return Ok(true);
113    }
114
115    let find_lib = c_str!(
116        r#"
117import inspect, ssl
118try:
119    lib = inspect.getfile(ssl.SSLSession)
120except TypeError:
121    import os, sysconfig
122    lib = os.path.join(sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("LDLIBRARY"))
123"#
124    );
125    let locals = PyDict::new(py);
126    py.run(find_lib, None, Some(&locals))?;
127    let lib = locals
128        .get_item("lib")?
129        .expect("defined lib")
130        .extract::<String>()?;
131    match py.detach(|| loader::load(OsStr::new(&lib))) {
132        Ok(()) => Ok(true),
133        Err(loader::Error::AlreadyLoaded) => Ok(true),
134        Err(loader::Error::LibraryNotFound) => {
135            debug!("Failed to load OpenSSL: library not found");
136            Ok(false)
137        }
138        Err(loader::Error::IoError(_e)) => {
139            debug!("Failed to load OpenSSL: {_e}");
140            Ok(false)
141        }
142        Err(loader::Error::VersionTooOld) => {
143            debug!("Failed to load OpenSSL: version is too old");
144            Ok(false)
145        }
146        Err(loader::Error::Loader(_e)) => {
147            debug!("Failed to load OpenSSL: {_e}");
148            Ok(false)
149        }
150        #[cfg(windows)]
151        Err(loader::Error::PE(_e)) => {
152            debug!("Failed to load OpenSSL: {_e}");
153            Ok(false)
154        }
155    }
156}