// kimojio_tls/lib.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
3use std::ffi::{CStr, c_char, c_ulonglong, c_void};
4use std::num::NonZeroU64;
5use std::ptr::null_mut;
6
7use rustix_uring::Errno;
8
/// Opaque stand-in for the native connection object that `tls_handle_create` and
/// `tls_handle_dup` hand back; it is only ever used behind a raw pointer.
///
/// The zero-sized `_data` field keeps the type unconstructible from Rust. The
/// `PhantomData<(*mut u8, PhantomPinned)>` marker makes it `!Send`, `!Sync` and
/// `!Unpin`, so none of those properties are assumed for the foreign object.
#[repr(C)]
struct RawTlsServer {
    _data: [u8; 0],
    _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}
14
/// Opaque stand-in for OpenSSL's `SSL_CTX`; only ever used behind a raw pointer.
///
/// Uses the same zero-sized-field plus `PhantomData` pattern as `RawTlsServer`,
/// which keeps it unconstructible from Rust and `!Send`/`!Sync`/`!Unpin`.
#[repr(C)]
struct RawTlsServerContext {
    _data: [u8; 0],
    _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}
21
/// Status value returned by the native `tls_handle_*` calls and decoded by
/// `get_response`.
///
/// `error_type` selects how `error_code` is interpreted:
/// * 0 = success, and `error_code` is the result value;
/// * 1 = failure recorded in OpenSSL's error queue;
/// * 2 = an `SSL_ERROR_*` code in `error_code`;
/// * 3 = an OS errno in `error_code`;
/// * 4 = end of stream;
/// * 5 = want write;
/// * 6 = want read.
///
/// Any other `error_type` makes `get_response` panic.
#[repr(C)]
struct RawError {
    error_type: i32,
    error_code: i32,
}
27
/// Out-parameter describing a native buffer of `size` bytes starting at `buf`.
/// It is filled in by `tls_handle_push_get_buffer` and `tls_handle_pull_get_buffer`.
#[repr(C)]
struct Slice {
    buf: *mut u8,
    size: usize,
}
33
/// Rust mirror of OpenSSL's `SSL_ERROR_*` codes (as returned by `SSL_get_error`).
///
/// The discriminants equal the C values, so `variant as i32` yields the OpenSSL
/// code. `InvalidErrorCode` (discriminant 13) presumably stands for any value
/// outside the known range.
// The variant names intentionally copy OpenSSL's C macro names so they can be
// grepped against the OpenSSL headers, hence the naming-lint exemption.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpensslErrorType {
    SSL_ERROR_NONE = 0,
    SSL_ERROR_SSL = 1,
    SSL_ERROR_WANT_READ = 2,
    SSL_ERROR_WANT_WRITE = 3,
    SSL_ERROR_WANT_X509_LOOKUP = 4,
    SSL_ERROR_SYSCALL = 5,
    SSL_ERROR_ZERO_RETURN = 6,
    SSL_ERROR_WANT_CONNECT = 7,
    SSL_ERROR_WANT_ACCEPT = 8,
    SSL_ERROR_WANT_ASYNC = 9,
    SSL_ERROR_WANT_ASYNC_JOB = 10,
    SSL_ERROR_WANT_CLIENT_HELLO_CB = 11,
    SSL_ERROR_WANT_RETRY_VERIFY = 12,

    InvalidErrorCode,
}
53
/// Error produced by the TLS wrapper.
#[derive(Debug)]
pub enum TlsServerError {
    /// An errno value. It is either reported by the native layer (error type 3)
    /// or chosen by `get_response`, e.g. `PROTO` for unhandled `SSL_ERROR_*`
    /// codes and `INVAL` for unknown ones.
    Errno(Errno),
    /// Packed OpenSSL error codes drained from the error queue. The list may be
    /// empty if the queue held nothing when the failure was decoded.
    TlsError(Vec<u64>),
}

// Marker impl: relies on the `Debug` and `Display` impls; no `source()` is exposed.
impl std::error::Error for TlsServerError {}
61
62impl std::fmt::Display for TlsServerError {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        match self {
65            TlsServerError::Errno(errno) => std::fmt::Display::fmt(errno, f),
66            TlsServerError::TlsError(errors) => {
67                for e in errors {
68                    let e = *e;
69                    let lib = unsafe { ERR_lib_error_string(e) };
70                    let func = unsafe { ERR_func_error_string(e) };
71                    let reason = unsafe { ERR_reason_error_string(e) };
72                    let empty: [i8; 1] = [0; 1];
73                    let lib = unsafe {
74                        CStr::from_ptr(if lib.is_null() {
75                            empty.as_ptr() as *const c_char
76                        } else {
77                            lib
78                        })
79                    };
80                    let func = unsafe {
81                        CStr::from_ptr(if func.is_null() {
82                            empty.as_ptr() as *const c_char
83                        } else {
84                            func
85                        })
86                    };
87                    let reason = unsafe {
88                        CStr::from_ptr(if reason.is_null() {
89                            empty.as_ptr() as *const c_char
90                        } else {
91                            reason
92                        })
93                    };
94
95                    let message = std::fmt::format(format_args!(
96                        "TlsError error:{e} lib:{lib:?} func:{func:?} reason:{reason:?}\n"
97                    ));
98                    f.write_str(&message)?;
99                }
100                Ok(())
101            }
102        }
103    }
104}
105
106pub fn get_error_details(code: u64) -> (String, String, Option<String>) {
107    let lib = unsafe { ERR_lib_error_string(code) };
108    let func = unsafe { ERR_func_error_string(code) };
109    let reason = unsafe { ERR_reason_error_string(code) };
110    let empty: [i8; 1] = [0; 1];
111    let lib = String::from(
112        unsafe {
113            CStr::from_ptr(if lib.is_null() {
114                empty.as_ptr() as *const c_char
115            } else {
116                lib
117            })
118        }
119        .to_string_lossy(),
120    );
121    let func = String::from(
122        unsafe {
123            CStr::from_ptr(if func.is_null() {
124                empty.as_ptr() as *const c_char
125            } else {
126                func
127            })
128        }
129        .to_string_lossy(),
130    );
131    let reason = if reason.is_null() {
132        None
133    } else {
134        Some(String::from(
135            unsafe { CStr::from_ptr(reason) }.to_string_lossy(),
136        ))
137    };
138    (lib, func, reason)
139}
140
// Foreign functions: the `tls_*` wrappers come from the native TLS shim, and the
// `ERR_*` / `OpenSSL_version_num` functions come from OpenSSL itself. Calls
// returning `RawError` are decoded with `get_response`.
unsafe extern "C" {
    // "fat" wrapper methods
    /// Closes a connection handle; called from `TlsServer::drop`.
    fn tls_handle_close(tls: *mut RawTlsServer);
    /// Duplicates `tls` into `*server`; used by `TlsServer::clone`.
    fn tls_handle_dup(tls: *mut RawTlsServer, server: &mut *mut RawTlsServer) -> RawError;
    /// Closes a context; called from `TlsServerContext::drop`.
    fn tls_handle_ctx_close(tls_ctx: *mut RawTlsServerContext);

    /// Creates a connection handle from `server_ctx`. `is_server` selects the
    /// server (true) or client (false) role. On success the handle is written to
    /// `*server`.
    fn tls_handle_create(
        server_ctx: *mut RawTlsServerContext,
        bufsize: usize,
        is_server: bool,
        server: &mut *mut RawTlsServer,
    ) -> RawError;

    // tls_handle_push_get_buffer should be called to get access to openssl buffer
    // and then the actual amounts written should be passed to tls_handle_push_advance.
    fn tls_handle_push_get_buffer(tls: *mut RawTlsServer, slice: &mut Slice) -> i32;
    fn tls_handle_push_advance(tls: *mut RawTlsServer, amount: usize) -> i32;

    // tls_handle_pull_get_buffer should be called to get access to openssl buffer
    // and then the actual amounts read should be passed to tls_handle_pull_advance.
    fn tls_handle_pull_get_buffer(tls: *mut RawTlsServer, slice: &mut Slice) -> i32;
    fn tls_handle_pull_advance(tls: *mut RawTlsServer, amount: usize) -> i32;
    fn tls_handle_read(tls: *mut RawTlsServer, buffer: *mut u8, length: isize) -> RawError;
    fn tls_handle_write(tls: *mut RawTlsServer, buffer: *const u8, length: isize) -> RawError;
    fn tls_handle_server_side_handshake(tls: *mut RawTlsServer) -> RawError;
    fn tls_handle_client_side_handshake(tls: *mut RawTlsServer) -> RawError;
    fn tls_handle_shutdown(tls: *mut RawTlsServer) -> RawError;

    // Extra openssl methods for getting information about errors.
    // NOTE(review): OpenSSL declares these with `unsigned long`. That matches
    // `c_ulonglong` only on LP64 targets such as 64-bit Linux; confirm the
    // supported targets.
    /// Pops the earliest code from OpenSSL's error queue; 0 means the queue is empty.
    fn ERR_get_error() -> c_ulonglong;
    fn ERR_lib_error_string(e: c_ulonglong) -> *const c_char;
    // NOTE(review): deprecated in OpenSSL 3.0, where it always returns NULL.
    fn ERR_func_error_string(e: c_ulonglong) -> *const c_char;
    fn ERR_reason_error_string(e: c_ulonglong) -> *const c_char;

    /// Packed runtime OpenSSL version number; decoded by `version()`.
    fn OpenSSL_version_num() -> c_ulonglong;

    // This gets the reference of the SSL object.
    fn tls_get_ssl(tls: *mut RawTlsServer) -> *mut c_void;

    /// Minimum protocol version configured on the context; exposed via
    /// `TlsServerContext::get_min_proto_version`.
    fn tls_get_min_proto_version(ctx: *mut RawTlsServerContext) -> i32;
}
182
/// Owning wrapper around a native TLS context (an OpenSSL `SSL_CTX`, see
/// `from_raw`). Connection handles are created from it with `server` and `client`.
///
/// The pointer is passed to `tls_handle_ctx_close` when this value is dropped.
pub struct TlsServerContext {
    ctx: *mut RawTlsServerContext,
}

// SAFETY: OpenSSL is safe to call from different
// threads as long as not the same time (Send is
// ok, but not Sync). Sync is deliberately not implemented,
// so `&TlsServerContext` can never be shared across threads.
unsafe impl Send for TlsServerContext {}
191
/// One TLS connection handle created by `TlsServerContext::server`,
/// `TlsServerContext::client`, or `Clone`. Despite the name, it also serves the
/// client role.
///
/// The handle is passed to `tls_handle_close` when this value is dropped.
pub struct TlsServer {
    server: *mut RawTlsServer,
}

// SAFETY: OpenSSL is safe to call from different
// threads as long as not the same time (Send is
// ok, but not Sync). Sync is deliberately not implemented.
unsafe impl Send for TlsServer {}
200
201pub enum Response {
202    Success(usize),
203    Fail(TlsServerError),
204    Eof,
205    WantRead,
206    WantWrite,
207}
208
209fn get_response(error: RawError) -> Response {
210    match error.error_type {
211        0 => Response::Success(error.error_code as usize),
212        1 => Response::Fail(TlsServerError::TlsError(get_ssl_error())),
213        2 => match error.error_code {
214            0 => Response::Fail(TlsServerError::Errno(Errno::INVAL)),
215            // SSL_ERROR_SSL
216            1 => Response::Fail(TlsServerError::TlsError(get_ssl_error())),
217            // SSL_ERROR_WANT_READ
218            2 => Response::WantRead,
219            // SSL_ERROR_WANT_WRITE
220            3 => Response::WantWrite,
221            // SSL_ERROR_WANT_X509_LOOKUP
222            4 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
223            // SSL_ERROR_SYSCALL
224            5 => Response::Fail(TlsServerError::TlsError(get_ssl_error())),
225            // SSL_ERROR_ZERO_RETURN
226            6 => Response::Eof,
227            // SSL_ERROR_WANT_CONNECT
228            7 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
229            // SSL_ERROR_WANT_ACCEPT
230            8 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
231            // SSL_ERROR_WANT_ASYNC
232            9 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
233            // SSL_ERROR_WANT_ASYNC_JOB
234            10 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
235            // SSL_ERROR_WANT_HELLO_CB
236            11 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
237            // SSL_ERROR_WANT_RETRY_VERIFY
238            12 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
239            _ => Response::Fail(TlsServerError::Errno(Errno::INVAL)),
240        },
241        3 => Response::Fail(TlsServerError::Errno(Errno::from_raw_os_error(
242            error.error_code,
243        ))),
244        4 => Response::Eof,
245        5 => Response::WantWrite,
246        6 => Response::WantRead,
247        _ => panic!("unexpected error code"),
248    }
249}
250
251fn get_error() -> Option<NonZeroU64> {
252    let code = unsafe { ERR_get_error() };
253    NonZeroU64::new(code)
254}
255
256fn get_ssl_error() -> Vec<u64> {
257    let mut codes: Vec<u64> = Vec::with_capacity(4);
258    while let Some(code) = get_error() {
259        codes.push(code.into());
260    }
261    codes
262}
263
264impl TlsServerContext {
265    pub fn server(&self, bufsize: usize) -> Result<TlsServer, TlsServerError> {
266        let mut server: *mut RawTlsServer = null_mut();
267        let result =
268            get_response(unsafe { tls_handle_create(self.ctx, bufsize, true, &mut server) });
269        match result {
270            Response::Success(_) => (),
271            Response::Fail(e) => return Err(e),
272            _ => panic!("Unexpected response"),
273        }
274        Ok(TlsServer { server })
275    }
276
277    pub fn client(&self, bufsize: usize) -> Result<TlsServer, TlsServerError> {
278        let mut client: *mut RawTlsServer = null_mut();
279        let result =
280            get_response(unsafe { tls_handle_create(self.ctx, bufsize, false, &mut client) });
281        match result {
282            Response::Success(_) => (),
283            Response::Fail(e) => return Err(e),
284            _ => panic!("Unexpected response"),
285        }
286        Ok(TlsServer { server: client })
287    }
288
289    /// Get the minimum TLS protocol version for this context
290    pub fn get_min_proto_version(&self) -> i32 {
291        unsafe { tls_get_min_proto_version(self.ctx) }
292    }
293
294    /// From the raw *mut SSL_CTX pointer.
295    pub fn from_raw(ctx: *mut c_void) -> Self {
296        assert!(!ctx.is_null(), "Context pointer must not be null");
297        let ctx = ctx as *mut RawTlsServerContext;
298        Self { ctx }
299    }
300}
301
302impl Drop for TlsServerContext {
303    fn drop(&mut self) {
304        unsafe {
305            tls_handle_ctx_close(self.ctx);
306        }
307    }
308}
309
310// Note: The input parameter authorized_server_name is optional.
311impl TlsServer {
312    pub fn client_side_handshake(&mut self) -> Response {
313        get_response(unsafe { tls_handle_client_side_handshake(self.server) })
314    }
315
316    /// Represents the server side execution of a TLS handshake with a client.
317    pub fn server_side_handshake(&mut self) -> Response {
318        get_response(unsafe { tls_handle_server_side_handshake(self.server) })
319    }
320
321    /// Gets the reference to SSL object.
322    pub fn get_ssl_raw(&self) -> *mut c_void {
323        unsafe { tls_get_ssl(self.server) }
324    }
325
326    pub fn shutdown(&mut self) -> Response {
327        get_response(unsafe { tls_handle_shutdown(self.server) })
328    }
329
330    pub fn read(&mut self, buffer: &mut [u8]) -> Response {
331        get_response(unsafe {
332            tls_handle_read(self.server, buffer.as_mut_ptr(), buffer.len() as isize)
333        })
334    }
335
336    pub fn write(&mut self, buffer: &[u8]) -> Response {
337        get_response(unsafe {
338            tls_handle_write(self.server, buffer.as_ptr(), buffer.len() as isize)
339        })
340    }
341
342    pub fn get_push_buffer(&mut self) -> Option<&mut [u8]> {
343        let mut slice = Slice {
344            buf: null_mut(),
345            size: 0,
346        };
347        let result = unsafe { tls_handle_push_get_buffer(self.server, &mut slice) };
348        if result > 0 {
349            Some(unsafe { std::slice::from_raw_parts_mut(slice.buf, slice.size) })
350        } else {
351            None
352        }
353    }
354
355    pub fn use_push_buffer(&mut self, amount: usize) {
356        let result = unsafe { tls_handle_push_advance(self.server, amount) };
357        assert!(result as usize == amount);
358    }
359
360    pub fn get_pull_buffer(&self) -> Option<&[u8]> {
361        let mut slice = Slice {
362            buf: null_mut(),
363            size: 0,
364        };
365        let result = unsafe { tls_handle_pull_get_buffer(self.server, &mut slice) };
366        if result > 0 {
367            Some(unsafe { std::slice::from_raw_parts(slice.buf, slice.size) })
368        } else {
369            None
370        }
371    }
372
373    pub fn use_pull_buffer(&mut self, amount: usize) {
374        let result = unsafe { tls_handle_pull_advance(self.server, amount) };
375        assert!(result as usize == amount);
376    }
377}
378
379impl Clone for TlsServer {
380    fn clone(&self) -> Self {
381        let mut server: *mut RawTlsServer = null_mut();
382        let result = get_response(unsafe { tls_handle_dup(self.server, &mut server) });
383        match result {
384            Response::Success(_) => (),
385            Response::Fail(e) => panic!("dup failed {e:?}"),
386            _ => panic!("Unexpected response"),
387        }
388        TlsServer { server }
389    }
390}
391
392impl Drop for TlsServer {
393    fn drop(&mut self) {
394        unsafe {
395            tls_handle_close(self.server);
396        }
397    }
398}
399
400pub fn version() -> (u64, u64, u64) {
401    let version = unsafe { OpenSSL_version_num() };
402    (
403        (version >> 28) & 0xf,
404        (version >> 20) & 0xff,
405        (version >> 4) & 0xff,
406    )
407}