pingora_openssl/
ext.rs

1// Copyright 2025 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use foreign_types::ForeignTypeRef;
16use libc::*;
17use openssl::error::ErrorStack;
18use openssl::pkey::{HasPrivate, PKeyRef};
19use openssl::ssl::{Ssl, SslAcceptor, SslRef};
20use openssl::x509::store::X509StoreRef;
21use openssl::x509::verify::X509VerifyParamRef;
22use openssl::x509::X509Ref;
23use openssl_sys::{
24    SSL_ctrl, EVP_PKEY, SSL, SSL_CTRL_SET_GROUPS_LIST, SSL_CTRL_SET_VERIFY_CERT_STORE, X509,
25    X509_VERIFY_PARAM,
26};
27use std::ffi::CString;
28use std::os::raw;
29
30fn cvt(r: c_long) -> Result<c_long, ErrorStack> {
31    if r != 1 {
32        Err(ErrorStack::get())
33    } else {
34        Ok(r)
35    }
36}
37
38extern "C" {
39    pub fn X509_VERIFY_PARAM_add1_host(
40        param: *mut X509_VERIFY_PARAM,
41        name: *const c_char,
42        namelen: size_t,
43    ) -> c_int;
44
45    pub fn SSL_use_certificate(ssl: *mut SSL, cert: *mut X509) -> c_int;
46    pub fn SSL_use_PrivateKey(ssl: *mut SSL, key: *mut EVP_PKEY) -> c_int;
47
48    pub fn SSL_set_cert_cb(
49        ssl: *mut SSL,
50        cb: ::std::option::Option<
51            unsafe extern "C" fn(ssl: *mut SSL, arg: *mut raw::c_void) -> raw::c_int,
52        >,
53        arg: *mut raw::c_void,
54    );
55}
56
57/// Add name as an additional reference identifier that can match the peer's certificate
58///
59/// See [X509_VERIFY_PARAM_set1_host](https://www.openssl.org/docs/man3.1/man3/X509_VERIFY_PARAM_set1_host.html).
60pub fn add_host(verify_param: &mut X509VerifyParamRef, host: &str) -> Result<(), ErrorStack> {
61    if host.is_empty() {
62        return Ok(());
63    }
64    unsafe {
65        cvt(X509_VERIFY_PARAM_add1_host(
66            verify_param.as_ptr(),
67            host.as_ptr() as *const c_char,
68            host.len(),
69        ) as c_long)
70        .map(|_| ())
71    }
72}
73
74/// Set the verify cert store of `ssl`
75///
76/// See [SSL_set1_verify_cert_store](https://www.openssl.org/docs/man1.1.1/man3/SSL_set1_verify_cert_store.html).
77pub fn ssl_set_verify_cert_store(
78    ssl: &mut SslRef,
79    cert_store: &X509StoreRef,
80) -> Result<(), ErrorStack> {
81    unsafe {
82        cvt(SSL_ctrl(
83            ssl.as_ptr(),
84            SSL_CTRL_SET_VERIFY_CERT_STORE,
85            1, // increase the ref count of X509Store so that ssl_ctx can outlive X509StoreRef
86            cert_store.as_ptr() as *mut c_void,
87        ))?;
88    }
89    Ok(())
90}
91
92/// Load the certificate into `ssl`
93///
94/// See [SSL_use_certificate](https://www.openssl.org/docs/man1.1.1/man3/SSL_use_certificate.html).
95pub fn ssl_use_certificate(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorStack> {
96    unsafe {
97        cvt(SSL_use_certificate(ssl.as_ptr(), cert.as_ptr()) as c_long)?;
98    }
99    Ok(())
100}
101
102/// Load the private key into `ssl`
103///
104/// See [SSL_use_certificate](https://www.openssl.org/docs/man1.1.1/man3/SSL_use_PrivateKey.html).
105pub fn ssl_use_private_key<T>(ssl: &mut SslRef, key: &PKeyRef<T>) -> Result<(), ErrorStack>
106where
107    T: HasPrivate,
108{
109    unsafe {
110        cvt(SSL_use_PrivateKey(ssl.as_ptr(), key.as_ptr()) as c_long)?;
111    }
112    Ok(())
113}
114
115/// Add the certificate into the cert chain of `ssl`
116///
117/// See [SSL_add1_chain_cert](https://www.openssl.org/docs/man1.1.1/man3/SSL_add1_chain_cert.html)
118pub fn ssl_add_chain_cert(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorStack> {
119    const SSL_CTRL_CHAIN_CERT: i32 = 89;
120    unsafe {
121        cvt(SSL_ctrl(
122            ssl.as_ptr(),
123            SSL_CTRL_CHAIN_CERT,
124            1, // increase the ref count of X509 so that ssl can outlive X509StoreRef
125            cert.as_ptr() as *mut c_void,
126        ))?;
127    }
128    Ok(())
129}
130
131/// Set renegotiation
132///
133/// This function is specific to BoringSSL. This function is noop for OpenSSL.
134pub fn ssl_set_renegotiate_mode_freely(_ssl: &mut SslRef) {}
135
136/// Set the curves/groups of `ssl`
137///
138/// See [set_groups_list](https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set1_curves.html).
139pub fn ssl_set_groups_list(ssl: &mut SslRef, groups: &str) -> Result<(), ErrorStack> {
140    if groups.contains('\0') {
141        return Err(ErrorStack::get());
142    }
143    let groups = CString::new(groups).map_err(|_| ErrorStack::get())?;
144    unsafe {
145        cvt(SSL_ctrl(
146            ssl.as_ptr(),
147            SSL_CTRL_SET_GROUPS_LIST,
148            0,
149            groups.as_ptr() as *mut c_void,
150        ))?;
151    }
152    Ok(())
153}
154
155/// Set's whether a second keyshare to be sent in client hello when PQ is used.
156///
157/// This function is specific to BoringSSL. This function is noop for OpenSSL.
158pub fn ssl_use_second_key_share(_ssl: &mut SslRef, _enabled: bool) {}
159
160/// Clear the error stack
161///
162/// SSL calls should check and clear the OpenSSL error stack. But some calls fail to do so.
163/// This causes the next unrelated SSL call to fail due to the leftover errors. This function allows
164/// caller to clear the error stack before performing SSL calls to avoid this issue.
165pub fn clear_error_stack() {
166    let _ = ErrorStack::get();
167}
168
169/// Create a new [Ssl] from &[SslAcceptor]
170///
171/// this function is to unify the interface between this crate and [`pingora-boringssl`](https://docs.rs/pingora-boringssl)
172pub fn ssl_from_acceptor(acceptor: &SslAcceptor) -> Result<Ssl, ErrorStack> {
173    Ssl::new(acceptor.context())
174}
175
176/// Suspend the TLS handshake when a certificate is needed.
177///
178/// This function will cause tls handshake to pause and return the error: SSL_ERROR_WANT_X509_LOOKUP.
179/// The caller should set the certificate and then call [unblock_ssl_cert()] before continue the
180/// handshake on the tls connection.
181pub fn suspend_when_need_ssl_cert(ssl: &mut SslRef) {
182    unsafe {
183        SSL_set_cert_cb(ssl.as_ptr(), Some(raw_cert_block), std::ptr::null_mut());
184    }
185}
186
187/// Unblock a TLS handshake after the certificate is set.
188///
189/// The user should continue to call tls handshake after this function is called.
190pub fn unblock_ssl_cert(ssl: &mut SslRef) {
191    unsafe {
192        SSL_set_cert_cb(ssl.as_ptr(), None, std::ptr::null_mut());
193    }
194}
195
196// Just block the handshake
197extern "C" fn raw_cert_block(_ssl: *mut openssl_sys::SSL, _arg: *mut c_void) -> c_int {
198    -1
199}
200
201/// Whether the TLS error is SSL_ERROR_WANT_X509_LOOKUP
202pub fn is_suspended_for_cert(error: &openssl::ssl::Error) -> bool {
203    error.code().as_raw() == openssl_sys::SSL_ERROR_WANT_X509_LOOKUP
204}
205
206#[allow(clippy::mut_from_ref)]
207/// Get a mutable SslRef ouf of SslRef, which is a missing functionality even when holding &mut SslStream
208/// # Safety
209/// the caller needs to make sure that they hold a &mut SslStream (or other types of mutable ref to the Ssl)
210pub unsafe fn ssl_mut(ssl: &SslRef) -> &mut SslRef {
211    SslRef::from_ptr_mut(ssl.as_ptr())
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use openssl::ssl::{SslContextBuilder, SslMethod};

    #[test]
    fn test_ssl_set_groups_list() {
        let ctx = SslContextBuilder::new(SslMethod::tls()).unwrap().build();
        let mut ssl = Ssl::new(&ctx).unwrap();
        // `Ssl` derefs mutably to `SslRef`, so no `unsafe` (`ssl_mut`) is needed here.
        let ssl_ref: &mut SslRef = &mut ssl;

        // Valid input
        assert!(ssl_set_groups_list(ssl_ref, "P-256:P-384").is_ok());

        // Invalid input (contains null byte)
        assert!(ssl_set_groups_list(ssl_ref, "P-256\0P-384").is_err());
    }
}