Skip to main content

compio_py_dynamic_openssl/
ssl.rs

1// This module is mostly copied and modified from:
2// https://github.com/rust-openssl/rust-openssl/blob/openssl-v0.10.75/openssl/src/ssl/mod.rs
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright 2011-2017 Google Inc.
6//           2013 Jack Lloyd
7//           2013-2014 Steven Fackler
8//
9// SPDX-License-Identifier: Apache-2.0 OR MulanPSL-2.0
10// Copyright 2026 Fantix King
11
12use std::{
13    cell::UnsafeCell,
14    ffi::{CString, c_int, c_uchar, c_uint},
15    fmt,
16    io::{self, Read, Write},
17    marker::PhantomData,
18    mem::ManuallyDrop,
19    net::IpAddr,
20    panic::resume_unwind,
21    ptr,
22};
23
24use self::error::InnerError;
25pub use self::error::{Error, ErrorCode, HandshakeError};
26use crate::{
27    bio::{self, BioMethod, cvt, cvt_p},
28    error::ErrorStack,
29    sys as ffi,
30};
31
32pub struct Ssl(*mut ffi::SSL);
33
34impl Drop for Ssl {
35    fn drop(&mut self) {
36        let ossl = crate::get();
37        unsafe {
38            (ossl.SSL_free)(self.0);
39        }
40    }
41}
42
43impl Ssl {
44    pub fn new(ctx: *mut ffi::SSL_CTX) -> Result<Ssl, ErrorStack> {
45        let ossl = crate::get();
46        cvt_p(unsafe { (ossl.SSL_new)(ctx) }).map(Self)
47    }
48
49    pub fn connect<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
50    where
51        S: Read + Write,
52    {
53        let mut stream = SslStream::new(self, stream)?;
54        match stream.connect() {
55            Ok(()) => Ok(stream),
56            Err(error) => match error.code() {
57                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
58                    Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
59                        stream,
60                        error,
61                    }))
62                }
63                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
64                    stream,
65                    error,
66                })),
67            },
68        }
69    }
70
71    pub fn accept<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
72    where
73        S: Read + Write,
74    {
75        let mut stream = SslStream::new(self, stream)?;
76        match stream.accept() {
77            Ok(()) => Ok(stream),
78            Err(error) => match error.code() {
79                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
80                    Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
81                        stream,
82                        error,
83                    }))
84                }
85                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
86                    stream,
87                    error,
88                })),
89            },
90        }
91    }
92
93    fn get_raw_rbio(&self) -> *mut ffi::BIO {
94        let ffi = crate::get();
95        unsafe { (ffi.SSL_get_rbio)(self.0) }
96    }
97
98    fn get_error(&self, ret: c_int) -> ErrorCode {
99        let ffi = crate::get();
100        unsafe { ErrorCode::from_raw((ffi.SSL_get_error)(self.0, ret)) }
101    }
102
103    pub fn set_hostname(&mut self, hostname: &str) -> Result<(), ErrorStack> {
104        let ffi = crate::get();
105        let cstr = CString::new(hostname).unwrap();
106        unsafe {
107            cvt(ffi.SSL_set_tlsext_host_name(self.0, cstr.as_ptr() as *mut _) as c_int).map(|_| ())
108        }
109    }
110
111    pub fn selected_alpn_protocol(&self) -> Option<&[u8]> {
112        let ffi = crate::get();
113        unsafe {
114            let mut data: *const c_uchar = ptr::null();
115            let mut len: c_uint = 0;
116            (ffi.SSL_get0_alpn_selected)(self.0, &mut data, &mut len);
117
118            if data.is_null() {
119                None
120            } else {
121                Some(bio::from_raw_parts(data, len as usize))
122            }
123        }
124    }
125
126    pub fn param_mut(&mut self) -> &mut X509VerifyParamRef {
127        let ffi = crate::get();
128        unsafe { X509VerifyParamRef::from_ptr_mut((ffi.SSL_get0_param)(self.0)) }
129    }
130}
131
132pub struct MidHandshakeSslStream<S> {
133    stream: SslStream<S>,
134    error: Error,
135}
136
137impl<S> MidHandshakeSslStream<S> {
138    pub fn get_mut(&mut self) -> &mut S {
139        self.stream.get_mut()
140    }
141
142    pub fn into_error(self) -> Error {
143        self.error
144    }
145}
146
147impl<S> MidHandshakeSslStream<S>
148where
149    S: Read + Write,
150{
151    pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> {
152        match self.stream.do_handshake() {
153            Ok(()) => Ok(self.stream),
154            Err(error) => {
155                self.error = error;
156                match self.error.code() {
157                    ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
158                        Err(HandshakeError::WouldBlock(self))
159                    }
160                    _ => Err(HandshakeError::Failure(self)),
161                }
162            }
163        }
164    }
165}
166
167pub struct SslStream<S> {
168    ssl: ManuallyDrop<Ssl>,
169    method: ManuallyDrop<BioMethod>,
170    _p: PhantomData<S>,
171}
172
173impl<S> Drop for SslStream<S> {
174    fn drop(&mut self) {
175        unsafe {
176            ManuallyDrop::drop(&mut self.ssl);
177            ManuallyDrop::drop(&mut self.method);
178        }
179    }
180}
181
182impl<S> fmt::Debug for SslStream<S>
183where
184    S: fmt::Debug,
185{
186    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
187        fmt.debug_struct("SslStream")
188            .field("stream", &self.get_ref())
189            .field("ssl", &self.ssl.0)
190            .finish()
191    }
192}
193
194impl<S: Read + Write> SslStream<S> {
195    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
196        let ffi = crate::get();
197        let (bio, method) = bio::new(stream)?;
198        unsafe {
199            (ffi.SSL_set_bio)(ssl.0, bio, bio);
200        }
201
202        Ok(Self {
203            ssl: ManuallyDrop::new(ssl),
204            method: ManuallyDrop::new(method),
205            _p: PhantomData,
206        })
207    }
208
209    pub fn connect(&mut self) -> Result<(), Error> {
210        let ffi = crate::get();
211        let ret = unsafe { (ffi.SSL_connect)(self.ssl.0) };
212        if ret > 0 {
213            Ok(())
214        } else {
215            Err(self.make_error(ret))
216        }
217    }
218
219    pub fn accept(&mut self) -> Result<(), Error> {
220        let ffi = crate::get();
221        let ret = unsafe { (ffi.SSL_accept)(self.ssl.0) };
222        if ret > 0 {
223            Ok(())
224        } else {
225            Err(self.make_error(ret))
226        }
227    }
228
229    pub fn do_handshake(&mut self) -> Result<(), Error> {
230        let ffi = crate::get();
231        let ret = unsafe { (ffi.SSL_do_handshake)(self.ssl.0) };
232        if ret > 0 {
233            Ok(())
234        } else {
235            Err(self.make_error(ret))
236        }
237    }
238
239    pub fn ssl_read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
240        let ffi = crate::get();
241        let mut readbytes = 0;
242        let ret = unsafe {
243            (ffi.SSL_read_ex)(
244                self.ssl.0,
245                buf.as_mut_ptr().cast(),
246                buf.len(),
247                &mut readbytes,
248            )
249        };
250
251        if ret > 0 {
252            Ok(readbytes)
253        } else {
254            Err(self.make_error(ret))
255        }
256    }
257
258    pub fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, Error> {
259        let ffi = crate::get();
260        let mut written = 0;
261        let ret =
262            unsafe { (ffi.SSL_write_ex)(self.ssl.0, buf.as_ptr().cast(), buf.len(), &mut written) };
263
264        if ret > 0 {
265            Ok(written)
266        } else {
267            Err(self.make_error(ret))
268        }
269    }
270
271    pub fn shutdown(&mut self) -> Result<ShutdownResult, Error> {
272        let ffi = crate::get();
273        match unsafe { (ffi.SSL_shutdown)(self.ssl.0) } {
274            0 => Ok(ShutdownResult::Sent),
275            1 => Ok(ShutdownResult::Received),
276            n => Err(self.make_error(n)),
277        }
278    }
279}
280
281impl<S> SslStream<S> {
282    fn make_error(&mut self, ret: c_int) -> Error {
283        self.check_panic();
284
285        let code = self.ssl.get_error(ret);
286
287        let cause = match code {
288            ErrorCode::SSL => Some(InnerError::Ssl(ErrorStack::get())),
289            ErrorCode::SYSCALL => {
290                let errs = ErrorStack::get();
291                if errs.errors().is_empty() {
292                    self.get_bio_error().map(InnerError::Io)
293                } else {
294                    Some(InnerError::Ssl(errs))
295                }
296            }
297            ErrorCode::ZERO_RETURN => None,
298            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
299                self.get_bio_error().map(InnerError::Io)
300            }
301            _ => None,
302        };
303
304        Error { code, cause }
305    }
306
307    fn check_panic(&mut self) {
308        if let Some(err) = unsafe { bio::take_panic::<S>(self.ssl.get_raw_rbio()) } {
309            resume_unwind(err);
310        }
311    }
312
313    fn get_bio_error(&mut self) -> Option<io::Error> {
314        unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) }
315    }
316
317    pub fn get_ref(&self) -> &S {
318        unsafe {
319            let bio = self.ssl.get_raw_rbio();
320            bio::get_ref(bio)
321        }
322    }
323
324    pub fn get_mut(&mut self) -> &mut S {
325        unsafe {
326            let bio = self.ssl.get_raw_rbio();
327            bio::get_mut(bio)
328        }
329    }
330
331    pub fn ssl(&self) -> &Ssl {
332        &self.ssl
333    }
334}
335
/// Successful outcome of `SslStream::shutdown`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ShutdownResult {
    /// `SSL_shutdown` returned 0.
    Sent,
    /// `SSL_shutdown` returned 1.
    Received,
}
341
342pub struct X509VerifyParamRef(UnsafeCell<()>);
343
344impl X509VerifyParamRef {
345    #[inline]
346    unsafe fn from_ptr_mut<'a>(ptr: *mut ffi::X509_VERIFY_PARAM) -> &'a mut Self {
347        unsafe { &mut *(ptr as *mut _) }
348    }
349
350    #[inline]
351    fn as_ptr(&self) -> *mut ffi::X509_VERIFY_PARAM {
352        self as *const _ as *mut _
353    }
354
355    pub fn set_host(&mut self, host: &str) -> Result<(), ErrorStack> {
356        let ffi = crate::get();
357        unsafe {
358            let raw_host = if host.is_empty() { "\0" } else { host };
359            cvt((ffi.X509_VERIFY_PARAM_set1_host)(
360                self.as_ptr(),
361                raw_host.as_ptr() as *const _,
362                isize::try_from(host.len()).expect("host length <= isize::MAX"),
363            ))
364            .map(|_| ())
365        }
366    }
367
368    pub fn set_ip(&mut self, ip: IpAddr) -> Result<(), ErrorStack> {
369        let ffi = crate::get();
370        unsafe {
371            let mut buf = [0; 16];
372            let len = match ip {
373                IpAddr::V4(addr) => {
374                    buf[..4].copy_from_slice(&addr.octets());
375                    4
376                }
377                IpAddr::V6(addr) => {
378                    buf.copy_from_slice(&addr.octets());
379                    16
380                }
381            };
382            cvt((ffi.X509_VERIFY_PARAM_set1_ip)(
383                self.as_ptr(),
384                buf.as_ptr() as *const _,
385                len,
386            ))
387            .map(|_| ())
388        }
389    }
390}
391
392mod error {
393    use std::{error, ffi::c_int, fmt, io};
394
395    use crate::{error::ErrorStack, ssl::MidHandshakeSslStream, sys as ffi};
396
397    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
398    pub struct ErrorCode(c_int);
399
400    impl ErrorCode {
401        pub const ZERO_RETURN: ErrorCode = ErrorCode(ffi::SSL_ERROR_ZERO_RETURN);
402        pub const WANT_READ: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_READ);
403        pub const WANT_WRITE: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_WRITE);
404        pub const SYSCALL: ErrorCode = ErrorCode(ffi::SSL_ERROR_SYSCALL);
405        pub const SSL: ErrorCode = ErrorCode(ffi::SSL_ERROR_SSL);
406        pub const WANT_CLIENT_HELLO_CB: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_CLIENT_HELLO_CB);
407
408        pub fn from_raw(raw: c_int) -> ErrorCode {
409            ErrorCode(raw)
410        }
411    }
412
413    #[derive(Debug)]
414    pub(crate) enum InnerError {
415        Io(io::Error),
416        Ssl(ErrorStack),
417    }
418
419    #[derive(Debug)]
420    pub struct Error {
421        pub(crate) code: ErrorCode,
422        pub(crate) cause: Option<InnerError>,
423    }
424
425    impl Error {
426        pub fn code(&self) -> ErrorCode {
427            self.code
428        }
429
430        pub fn io_error(&self) -> Option<&io::Error> {
431            match self.cause {
432                Some(InnerError::Io(ref e)) => Some(e),
433                _ => None,
434            }
435        }
436
437        pub fn into_io_error(self) -> Result<io::Error, Error> {
438            match self.cause {
439                Some(InnerError::Io(e)) => Ok(e),
440                _ => Err(self),
441            }
442        }
443
444        pub fn ssl_error(&self) -> Option<&ErrorStack> {
445            match self.cause {
446                Some(InnerError::Ssl(ref e)) => Some(e),
447                _ => None,
448            }
449        }
450    }
451
452    impl fmt::Display for Error {
453        fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
454            match self.code {
455                ErrorCode::ZERO_RETURN => fmt.write_str("the SSL session has been shut down"),
456                ErrorCode::WANT_READ => match self.io_error() {
457                    Some(_) => fmt.write_str("a nonblocking read call would have blocked"),
458                    None => fmt.write_str("the operation should be retried"),
459                },
460                ErrorCode::WANT_WRITE => match self.io_error() {
461                    Some(_) => fmt.write_str("a nonblocking write call would have blocked"),
462                    None => fmt.write_str("the operation should be retried"),
463                },
464                ErrorCode::SYSCALL => match self.io_error() {
465                    Some(err) => write!(fmt, "{}", err),
466                    None => fmt.write_str("unexpected EOF"),
467                },
468                ErrorCode::SSL => match self.ssl_error() {
469                    Some(e) => write!(fmt, "{}", e),
470                    None => fmt.write_str("OpenSSL error"),
471                },
472                ErrorCode(code) => write!(fmt, "unknown error code {}", code),
473            }
474        }
475    }
476
477    impl error::Error for Error {
478        fn source(&self) -> Option<&(dyn error::Error + 'static)> {
479            match self.cause {
480                Some(InnerError::Io(ref e)) => Some(e),
481                Some(InnerError::Ssl(ref e)) => Some(e),
482                None => None,
483            }
484        }
485    }
486
487    pub enum HandshakeError<S> {
488        SetupFailure(ErrorStack),
489        Failure(MidHandshakeSslStream<S>),
490        WouldBlock(MidHandshakeSslStream<S>),
491    }
492
493    impl<S> From<ErrorStack> for HandshakeError<S> {
494        fn from(e: ErrorStack) -> HandshakeError<S> {
495            HandshakeError::SetupFailure(e)
496        }
497    }
498}