compio_py_dynamic_openssl/
lib.rs1use std::{
5 ffi::OsStr,
6 io::{Read, Write},
7 net::IpAddr,
8};
9
10use compio_log::debug;
11pub use pyo3;
12use pyo3::{
13 ffi::{PyObject, c_str},
14 prelude::*,
15 types::PyDict,
16};
17
18use self::{
19 loader::get,
20 ssl::{HandshakeError, Ssl, SslStream},
21 sys as ffi,
22};
23
24pub mod bio;
25pub mod error;
26pub mod loader;
27pub mod ssl;
28pub mod sys;
29
30pub struct SSLContext {
31 ptr: *mut ffi::SSL_CTX,
32 pyobj: Py<PyAny>,
33}
34
35impl TryFrom<Bound<'_, PyAny>> for SSLContext {
36 type Error = PyErr;
37
38 fn try_from(obj: Bound<PyAny>) -> PyResult<Self> {
39 #[repr(C)]
40 struct PySSLContext {
41 ob_base: PyObject,
42 ctx: *mut ffi::SSL_CTX,
43 }
44
45 unsafe {
46 let ptr = obj.as_ptr() as *const PySSLContext;
47 let ptr = (*ptr).ctx;
48 if ptr.is_null() {
49 return Err(pyo3::exceptions::PyValueError::new_err(
50 "SSLContext has null SSL_CTX",
51 ));
52 }
53 Ok(Self {
54 ptr,
55 pyobj: obj.unbind(),
56 })
57 }
58 }
59}
60
61impl SSLContext {
62 pub fn connect<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
63 where
64 S: Read + Write,
65 {
66 let mut ssl = Ssl::new(self.ptr)?;
67 let ip = domain.parse::<IpAddr>().ok();
68 if ip.is_none() {
69 ssl.set_hostname(domain)?;
70 }
71 let res = Python::attach(|py| {
72 self.pyobj
73 .bind(py)
74 .getattr("check_hostname")?
75 .extract::<bool>()
76 });
77 match res {
78 Ok(true) => {
79 let param = ssl.param_mut();
80 match ip {
81 Some(ip) => param.set_ip(ip)?,
82 None => param.set_host(domain)?,
83 }
84 }
85 Ok(false) => {}
86 Err(e) => panic!("{e}"),
87 }
88
89 ssl.connect(stream)
90 }
91
92 pub fn accept<S>(&self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
93 where
94 S: Read + Write,
95 {
96 let ssl = Ssl::new(self.ptr)?;
97 ssl.accept(stream)
98 }
99}
100
101impl Clone for SSLContext {
102 fn clone(&self) -> Self {
103 Python::attach(|py| Self {
104 ptr: self.ptr,
105 pyobj: self.pyobj.clone_ref(py),
106 })
107 }
108}
109
110pub fn load_py(py: Python) -> PyResult<bool> {
111 if loader::is_loaded() {
112 return Ok(true);
113 }
114
115 let find_lib = c_str!(
116 r#"
117import inspect, ssl
118try:
119 lib = inspect.getfile(ssl.SSLSession)
120except TypeError:
121 import os, sysconfig
122 lib = os.path.join(sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("LDLIBRARY"))
123"#
124 );
125 let locals = PyDict::new(py);
126 py.run(find_lib, None, Some(&locals))?;
127 let lib = locals
128 .get_item("lib")?
129 .expect("defined lib")
130 .extract::<String>()?;
131 match py.detach(|| loader::load(OsStr::new(&lib))) {
132 Ok(()) => Ok(true),
133 Err(loader::Error::AlreadyLoaded) => Ok(true),
134 Err(loader::Error::LibraryNotFound) => {
135 debug!("Failed to load OpenSSL: library not found");
136 Ok(false)
137 }
138 Err(loader::Error::IoError(_e)) => {
139 debug!("Failed to load OpenSSL: {_e}");
140 Ok(false)
141 }
142 Err(loader::Error::VersionTooOld) => {
143 debug!("Failed to load OpenSSL: version is too old");
144 Ok(false)
145 }
146 Err(loader::Error::Loader(_e)) => {
147 debug!("Failed to load OpenSSL: {_e}");
148 Ok(false)
149 }
150 #[cfg(windows)]
151 Err(loader::Error::PE(_e)) => {
152 debug!("Failed to load OpenSSL: {_e}");
153 Ok(false)
154 }
155 }
156}