compio_py_dynamic_openssl/
ssl.rs1use std::{
13 cell::UnsafeCell,
14 ffi::{CString, c_int, c_uchar, c_uint},
15 fmt,
16 io::{self, Read, Write},
17 marker::PhantomData,
18 mem::ManuallyDrop,
19 net::IpAddr,
20 panic::resume_unwind,
21 ptr,
22};
23
24use self::error::InnerError;
25pub use self::error::{Error, ErrorCode, HandshakeError};
26use crate::{
27 bio::{self, BioMethod, cvt, cvt_p},
28 error::ErrorStack,
29 sys as ffi,
30};
31
32pub struct Ssl(*mut ffi::SSL);
33
34impl Drop for Ssl {
35 fn drop(&mut self) {
36 let ossl = crate::get();
37 unsafe {
38 (ossl.SSL_free)(self.0);
39 }
40 }
41}
42
43impl Ssl {
44 pub fn new(ctx: *mut ffi::SSL_CTX) -> Result<Ssl, ErrorStack> {
45 let ossl = crate::get();
46 cvt_p(unsafe { (ossl.SSL_new)(ctx) }).map(Self)
47 }
48
49 pub fn connect<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
50 where
51 S: Read + Write,
52 {
53 let mut stream = SslStream::new(self, stream)?;
54 match stream.connect() {
55 Ok(()) => Ok(stream),
56 Err(error) => match error.code() {
57 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
58 Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
59 stream,
60 error,
61 }))
62 }
63 _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
64 stream,
65 error,
66 })),
67 },
68 }
69 }
70
71 pub fn accept<S>(self, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
72 where
73 S: Read + Write,
74 {
75 let mut stream = SslStream::new(self, stream)?;
76 match stream.accept() {
77 Ok(()) => Ok(stream),
78 Err(error) => match error.code() {
79 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
80 Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
81 stream,
82 error,
83 }))
84 }
85 _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
86 stream,
87 error,
88 })),
89 },
90 }
91 }
92
93 fn get_raw_rbio(&self) -> *mut ffi::BIO {
94 let ffi = crate::get();
95 unsafe { (ffi.SSL_get_rbio)(self.0) }
96 }
97
98 fn get_error(&self, ret: c_int) -> ErrorCode {
99 let ffi = crate::get();
100 unsafe { ErrorCode::from_raw((ffi.SSL_get_error)(self.0, ret)) }
101 }
102
103 pub fn set_hostname(&mut self, hostname: &str) -> Result<(), ErrorStack> {
104 let ffi = crate::get();
105 let cstr = CString::new(hostname).unwrap();
106 unsafe {
107 cvt(ffi.SSL_set_tlsext_host_name(self.0, cstr.as_ptr() as *mut _) as c_int).map(|_| ())
108 }
109 }
110
111 pub fn selected_alpn_protocol(&self) -> Option<&[u8]> {
112 let ffi = crate::get();
113 unsafe {
114 let mut data: *const c_uchar = ptr::null();
115 let mut len: c_uint = 0;
116 (ffi.SSL_get0_alpn_selected)(self.0, &mut data, &mut len);
117
118 if data.is_null() {
119 None
120 } else {
121 Some(bio::from_raw_parts(data, len as usize))
122 }
123 }
124 }
125
126 pub fn param_mut(&mut self) -> &mut X509VerifyParamRef {
127 let ffi = crate::get();
128 unsafe { X509VerifyParamRef::from_ptr_mut((ffi.SSL_get0_param)(self.0)) }
129 }
130}
131
132pub struct MidHandshakeSslStream<S> {
133 stream: SslStream<S>,
134 error: Error,
135}
136
137impl<S> MidHandshakeSslStream<S> {
138 pub fn get_mut(&mut self) -> &mut S {
139 self.stream.get_mut()
140 }
141
142 pub fn into_error(self) -> Error {
143 self.error
144 }
145}
146
147impl<S> MidHandshakeSslStream<S>
148where
149 S: Read + Write,
150{
151 pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> {
152 match self.stream.do_handshake() {
153 Ok(()) => Ok(self.stream),
154 Err(error) => {
155 self.error = error;
156 match self.error.code() {
157 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
158 Err(HandshakeError::WouldBlock(self))
159 }
160 _ => Err(HandshakeError::Failure(self)),
161 }
162 }
163 }
164 }
165}
166
167pub struct SslStream<S> {
168 ssl: ManuallyDrop<Ssl>,
169 method: ManuallyDrop<BioMethod>,
170 _p: PhantomData<S>,
171}
172
173impl<S> Drop for SslStream<S> {
174 fn drop(&mut self) {
175 unsafe {
176 ManuallyDrop::drop(&mut self.ssl);
177 ManuallyDrop::drop(&mut self.method);
178 }
179 }
180}
181
182impl<S> fmt::Debug for SslStream<S>
183where
184 S: fmt::Debug,
185{
186 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
187 fmt.debug_struct("SslStream")
188 .field("stream", &self.get_ref())
189 .field("ssl", &self.ssl.0)
190 .finish()
191 }
192}
193
194impl<S: Read + Write> SslStream<S> {
195 pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
196 let ffi = crate::get();
197 let (bio, method) = bio::new(stream)?;
198 unsafe {
199 (ffi.SSL_set_bio)(ssl.0, bio, bio);
200 }
201
202 Ok(Self {
203 ssl: ManuallyDrop::new(ssl),
204 method: ManuallyDrop::new(method),
205 _p: PhantomData,
206 })
207 }
208
209 pub fn connect(&mut self) -> Result<(), Error> {
210 let ffi = crate::get();
211 let ret = unsafe { (ffi.SSL_connect)(self.ssl.0) };
212 if ret > 0 {
213 Ok(())
214 } else {
215 Err(self.make_error(ret))
216 }
217 }
218
219 pub fn accept(&mut self) -> Result<(), Error> {
220 let ffi = crate::get();
221 let ret = unsafe { (ffi.SSL_accept)(self.ssl.0) };
222 if ret > 0 {
223 Ok(())
224 } else {
225 Err(self.make_error(ret))
226 }
227 }
228
229 pub fn do_handshake(&mut self) -> Result<(), Error> {
230 let ffi = crate::get();
231 let ret = unsafe { (ffi.SSL_do_handshake)(self.ssl.0) };
232 if ret > 0 {
233 Ok(())
234 } else {
235 Err(self.make_error(ret))
236 }
237 }
238
239 pub fn ssl_read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
240 let ffi = crate::get();
241 let mut readbytes = 0;
242 let ret = unsafe {
243 (ffi.SSL_read_ex)(
244 self.ssl.0,
245 buf.as_mut_ptr().cast(),
246 buf.len(),
247 &mut readbytes,
248 )
249 };
250
251 if ret > 0 {
252 Ok(readbytes)
253 } else {
254 Err(self.make_error(ret))
255 }
256 }
257
258 pub fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, Error> {
259 let ffi = crate::get();
260 let mut written = 0;
261 let ret =
262 unsafe { (ffi.SSL_write_ex)(self.ssl.0, buf.as_ptr().cast(), buf.len(), &mut written) };
263
264 if ret > 0 {
265 Ok(written)
266 } else {
267 Err(self.make_error(ret))
268 }
269 }
270
271 pub fn shutdown(&mut self) -> Result<ShutdownResult, Error> {
272 let ffi = crate::get();
273 match unsafe { (ffi.SSL_shutdown)(self.ssl.0) } {
274 0 => Ok(ShutdownResult::Sent),
275 1 => Ok(ShutdownResult::Received),
276 n => Err(self.make_error(n)),
277 }
278 }
279}
280
281impl<S> SslStream<S> {
282 fn make_error(&mut self, ret: c_int) -> Error {
283 self.check_panic();
284
285 let code = self.ssl.get_error(ret);
286
287 let cause = match code {
288 ErrorCode::SSL => Some(InnerError::Ssl(ErrorStack::get())),
289 ErrorCode::SYSCALL => {
290 let errs = ErrorStack::get();
291 if errs.errors().is_empty() {
292 self.get_bio_error().map(InnerError::Io)
293 } else {
294 Some(InnerError::Ssl(errs))
295 }
296 }
297 ErrorCode::ZERO_RETURN => None,
298 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
299 self.get_bio_error().map(InnerError::Io)
300 }
301 _ => None,
302 };
303
304 Error { code, cause }
305 }
306
307 fn check_panic(&mut self) {
308 if let Some(err) = unsafe { bio::take_panic::<S>(self.ssl.get_raw_rbio()) } {
309 resume_unwind(err);
310 }
311 }
312
313 fn get_bio_error(&mut self) -> Option<io::Error> {
314 unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) }
315 }
316
317 pub fn get_ref(&self) -> &S {
318 unsafe {
319 let bio = self.ssl.get_raw_rbio();
320 bio::get_ref(bio)
321 }
322 }
323
324 pub fn get_mut(&mut self) -> &mut S {
325 unsafe {
326 let bio = self.ssl.get_raw_rbio();
327 bio::get_mut(bio)
328 }
329 }
330
331 pub fn ssl(&self) -> &Ssl {
332 &self.ssl
333 }
334}
335
/// Successful outcome of one `SslStream::shutdown` call.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ShutdownResult {
    /// `SSL_shutdown` returned 0: our `close_notify` was sent, but the peer's
    /// has not been seen yet (per the OpenSSL `SSL_shutdown` docs).
    Sent,
    /// `SSL_shutdown` returned 1: the shutdown completed in both directions.
    Received,
}
341
342pub struct X509VerifyParamRef(UnsafeCell<()>);
343
344impl X509VerifyParamRef {
345 #[inline]
346 unsafe fn from_ptr_mut<'a>(ptr: *mut ffi::X509_VERIFY_PARAM) -> &'a mut Self {
347 unsafe { &mut *(ptr as *mut _) }
348 }
349
350 #[inline]
351 fn as_ptr(&self) -> *mut ffi::X509_VERIFY_PARAM {
352 self as *const _ as *mut _
353 }
354
355 pub fn set_host(&mut self, host: &str) -> Result<(), ErrorStack> {
356 let ffi = crate::get();
357 unsafe {
358 let raw_host = if host.is_empty() { "\0" } else { host };
359 cvt((ffi.X509_VERIFY_PARAM_set1_host)(
360 self.as_ptr(),
361 raw_host.as_ptr() as *const _,
362 isize::try_from(host.len()).expect("host length <= isize::MAX"),
363 ))
364 .map(|_| ())
365 }
366 }
367
368 pub fn set_ip(&mut self, ip: IpAddr) -> Result<(), ErrorStack> {
369 let ffi = crate::get();
370 unsafe {
371 let mut buf = [0; 16];
372 let len = match ip {
373 IpAddr::V4(addr) => {
374 buf[..4].copy_from_slice(&addr.octets());
375 4
376 }
377 IpAddr::V6(addr) => {
378 buf.copy_from_slice(&addr.octets());
379 16
380 }
381 };
382 cvt((ffi.X509_VERIFY_PARAM_set1_ip)(
383 self.as_ptr(),
384 buf.as_ptr() as *const _,
385 len,
386 ))
387 .map(|_| ())
388 }
389 }
390}
391
392mod error {
393 use std::{error, ffi::c_int, fmt, io};
394
395 use crate::{error::ErrorStack, ssl::MidHandshakeSslStream, sys as ffi};
396
397 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
398 pub struct ErrorCode(c_int);
399
400 impl ErrorCode {
401 pub const ZERO_RETURN: ErrorCode = ErrorCode(ffi::SSL_ERROR_ZERO_RETURN);
402 pub const WANT_READ: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_READ);
403 pub const WANT_WRITE: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_WRITE);
404 pub const SYSCALL: ErrorCode = ErrorCode(ffi::SSL_ERROR_SYSCALL);
405 pub const SSL: ErrorCode = ErrorCode(ffi::SSL_ERROR_SSL);
406 pub const WANT_CLIENT_HELLO_CB: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_CLIENT_HELLO_CB);
407
408 pub fn from_raw(raw: c_int) -> ErrorCode {
409 ErrorCode(raw)
410 }
411 }
412
413 #[derive(Debug)]
414 pub(crate) enum InnerError {
415 Io(io::Error),
416 Ssl(ErrorStack),
417 }
418
419 #[derive(Debug)]
420 pub struct Error {
421 pub(crate) code: ErrorCode,
422 pub(crate) cause: Option<InnerError>,
423 }
424
425 impl Error {
426 pub fn code(&self) -> ErrorCode {
427 self.code
428 }
429
430 pub fn io_error(&self) -> Option<&io::Error> {
431 match self.cause {
432 Some(InnerError::Io(ref e)) => Some(e),
433 _ => None,
434 }
435 }
436
437 pub fn into_io_error(self) -> Result<io::Error, Error> {
438 match self.cause {
439 Some(InnerError::Io(e)) => Ok(e),
440 _ => Err(self),
441 }
442 }
443
444 pub fn ssl_error(&self) -> Option<&ErrorStack> {
445 match self.cause {
446 Some(InnerError::Ssl(ref e)) => Some(e),
447 _ => None,
448 }
449 }
450 }
451
452 impl fmt::Display for Error {
453 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
454 match self.code {
455 ErrorCode::ZERO_RETURN => fmt.write_str("the SSL session has been shut down"),
456 ErrorCode::WANT_READ => match self.io_error() {
457 Some(_) => fmt.write_str("a nonblocking read call would have blocked"),
458 None => fmt.write_str("the operation should be retried"),
459 },
460 ErrorCode::WANT_WRITE => match self.io_error() {
461 Some(_) => fmt.write_str("a nonblocking write call would have blocked"),
462 None => fmt.write_str("the operation should be retried"),
463 },
464 ErrorCode::SYSCALL => match self.io_error() {
465 Some(err) => write!(fmt, "{}", err),
466 None => fmt.write_str("unexpected EOF"),
467 },
468 ErrorCode::SSL => match self.ssl_error() {
469 Some(e) => write!(fmt, "{}", e),
470 None => fmt.write_str("OpenSSL error"),
471 },
472 ErrorCode(code) => write!(fmt, "unknown error code {}", code),
473 }
474 }
475 }
476
477 impl error::Error for Error {
478 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
479 match self.cause {
480 Some(InnerError::Io(ref e)) => Some(e),
481 Some(InnerError::Ssl(ref e)) => Some(e),
482 None => None,
483 }
484 }
485 }
486
487 pub enum HandshakeError<S> {
488 SetupFailure(ErrorStack),
489 Failure(MidHandshakeSslStream<S>),
490 WouldBlock(MidHandshakeSslStream<S>),
491 }
492
493 impl<S> From<ErrorStack> for HandshakeError<S> {
494 fn from(e: ErrorStack) -> HandshakeError<S> {
495 HandshakeError::SetupFailure(e)
496 }
497 }
498}