use std::ffi::{CStr, c_char, c_ulong, c_ulonglong, c_void};
use std::num::NonZeroU64;
use std::ptr::null_mut;

use rustix_uring::Errno;
8
/// Opaque per-connection handle owned by the C TLS shim.
///
/// Never constructed on the Rust side; it only appears behind raw pointers. The empty
/// array keeps it zero-sized, and the phantom `(*mut u8, PhantomPinned)` makes the type
/// `!Send`, `!Sync` and `!Unpin` (the Rustonomicon's opaque-FFI-type recipe).
#[repr(C)]
struct RawTlsServer {
    _opaque: [u8; 0],
    _not_send_sync_unpin: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
14
/// Opaque TLS context handle owned by the C TLS shim.
///
/// Like `RawTlsServer`, it is zero-sized and only ever used behind a raw pointer; the
/// phantom marker removes `Send`, `Sync` and `Unpin`.
#[repr(C)]
struct RawTlsServerContext {
    _opaque: [u8; 0],
    _not_send_sync_unpin: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
21
/// Status value returned by the shim's `tls_handle_*` calls; decoded by `get_response`.
///
/// `#[repr(C)]`: field order and types must match the C-side definition.
#[repr(C)]
struct RawError {
    // Status category. `get_response` understands 0..=6 and panics on anything else.
    error_type: i32,
    // Payload whose meaning depends on `error_type`: a count for 0 (presumably bytes for
    // read/write), an OpenSSL `SSL_ERROR_*` value for 2, a raw OS errno for 3.
    error_code: i32,
}
27
/// Pointer/length pair filled in by `tls_handle_push_get_buffer` / `tls_handle_pull_get_buffer`.
///
/// `#[repr(C)]`: layout must match the C-side definition.
#[repr(C)]
struct Slice {
    // Start of the buffer (presumably owned by the shim); callers initialise it to null.
    buf: *mut u8,
    // Number of bytes available starting at `buf`.
    size: usize,
}
33
/// Result codes reported by OpenSSL's `SSL_get_error`, plus a catch-all for unknown values.
///
/// The explicit discriminants equal OpenSSL's `SSL_ERROR_*` constants, so `variant as i32`
/// yields the C value. Use [`OpensslErrorType::from_code`] to go the other way.
#[allow(non_camel_case_types)] // variant names deliberately mirror OpenSSL's C constants
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpensslErrorType {
    SSL_ERROR_NONE = 0,
    SSL_ERROR_SSL = 1,
    SSL_ERROR_WANT_READ = 2,
    SSL_ERROR_WANT_WRITE = 3,
    SSL_ERROR_WANT_X509_LOOKUP = 4,
    SSL_ERROR_SYSCALL = 5,
    SSL_ERROR_ZERO_RETURN = 6,
    SSL_ERROR_WANT_CONNECT = 7,
    SSL_ERROR_WANT_ACCEPT = 8,
    SSL_ERROR_WANT_ASYNC = 9,
    SSL_ERROR_WANT_ASYNC_JOB = 10,
    SSL_ERROR_WANT_CLIENT_HELLO_CB = 11,
    SSL_ERROR_WANT_RETRY_VERIFY = 12,

    /// Not an OpenSSL code: stands in for any value outside `0..=12`.
    /// (Its implicit discriminant is 13.)
    InvalidErrorCode,
}

impl OpensslErrorType {
    /// Maps a raw `SSL_get_error` value onto the enum.
    ///
    /// Values that do not correspond to a known `SSL_ERROR_*` constant (including negative
    /// values) map to [`OpensslErrorType::InvalidErrorCode`].
    pub fn from_code(code: i32) -> Self {
        match code {
            0 => Self::SSL_ERROR_NONE,
            1 => Self::SSL_ERROR_SSL,
            2 => Self::SSL_ERROR_WANT_READ,
            3 => Self::SSL_ERROR_WANT_WRITE,
            4 => Self::SSL_ERROR_WANT_X509_LOOKUP,
            5 => Self::SSL_ERROR_SYSCALL,
            6 => Self::SSL_ERROR_ZERO_RETURN,
            7 => Self::SSL_ERROR_WANT_CONNECT,
            8 => Self::SSL_ERROR_WANT_ACCEPT,
            9 => Self::SSL_ERROR_WANT_ASYNC,
            10 => Self::SSL_ERROR_WANT_ASYNC_JOB,
            11 => Self::SSL_ERROR_WANT_CLIENT_HELLO_CB,
            12 => Self::SSL_ERROR_WANT_RETRY_VERIFY,
            _ => Self::InvalidErrorCode,
        }
    }
}
53
/// Failure reported by a TLS operation.
#[derive(Debug)]
pub enum TlsServerError {
    /// An errno value, either reported by the C shim or synthesized by this module
    /// (e.g. `EPROTO` for OpenSSL outcomes the wrapper does not support).
    Errno(Errno),
    /// Packed OpenSSL error codes drained from the OpenSSL error queue, in pop order.
    /// May be empty if the queue held nothing when the failure was decoded.
    TlsError(Vec<u64>),
}
59
// Marker impl: relies on the derived `Debug` and the hand-written `Display`.
// No `source()` is exposed for the wrapped errno.
impl std::error::Error for TlsServerError {}
61
62impl std::fmt::Display for TlsServerError {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 TlsServerError::Errno(errno) => std::fmt::Display::fmt(errno, f),
66 TlsServerError::TlsError(errors) => {
67 for e in errors {
68 let e = *e;
69 let lib = unsafe { ERR_lib_error_string(e) };
70 let func = unsafe { ERR_func_error_string(e) };
71 let reason = unsafe { ERR_reason_error_string(e) };
72 let empty: [i8; 1] = [0; 1];
73 let lib = unsafe {
74 CStr::from_ptr(if lib.is_null() {
75 empty.as_ptr() as *const c_char
76 } else {
77 lib
78 })
79 };
80 let func = unsafe {
81 CStr::from_ptr(if func.is_null() {
82 empty.as_ptr() as *const c_char
83 } else {
84 func
85 })
86 };
87 let reason = unsafe {
88 CStr::from_ptr(if reason.is_null() {
89 empty.as_ptr() as *const c_char
90 } else {
91 reason
92 })
93 };
94
95 let message = std::fmt::format(format_args!(
96 "TlsError error:{e} lib:{lib:?} func:{func:?} reason:{reason:?}\n"
97 ));
98 f.write_str(&message)?;
99 }
100 Ok(())
101 }
102 }
103 }
104}
105
106pub fn get_error_details(code: u64) -> (String, String, Option<String>) {
107 let lib = unsafe { ERR_lib_error_string(code) };
108 let func = unsafe { ERR_func_error_string(code) };
109 let reason = unsafe { ERR_reason_error_string(code) };
110 let empty: [i8; 1] = [0; 1];
111 let lib = String::from(
112 unsafe {
113 CStr::from_ptr(if lib.is_null() {
114 empty.as_ptr() as *const c_char
115 } else {
116 lib
117 })
118 }
119 .to_string_lossy(),
120 );
121 let func = String::from(
122 unsafe {
123 CStr::from_ptr(if func.is_null() {
124 empty.as_ptr() as *const c_char
125 } else {
126 func
127 })
128 }
129 .to_string_lossy(),
130 );
131 let reason = if reason.is_null() {
132 None
133 } else {
134 Some(String::from(
135 unsafe { CStr::from_ptr(reason) }.to_string_lossy(),
136 ))
137 };
138 (lib, func, reason)
139}
140
141unsafe extern "C" {
142 fn tls_handle_close(tls: *mut RawTlsServer);
144 fn tls_handle_dup(tls: *mut RawTlsServer, server: &mut *mut RawTlsServer) -> RawError;
145 fn tls_handle_ctx_close(tls_ctx: *mut RawTlsServerContext);
146
147 fn tls_handle_create(
148 server_ctx: *mut RawTlsServerContext,
149 bufsize: usize,
150 is_server: bool,
151 server: &mut *mut RawTlsServer,
152 ) -> RawError;
153
154 fn tls_handle_push_get_buffer(tls: *mut RawTlsServer, slice: &mut Slice) -> i32;
157 fn tls_handle_push_advance(tls: *mut RawTlsServer, amount: usize) -> i32;
158
159 fn tls_handle_pull_get_buffer(tls: *mut RawTlsServer, slice: &mut Slice) -> i32;
162 fn tls_handle_pull_advance(tls: *mut RawTlsServer, amount: usize) -> i32;
163 fn tls_handle_read(tls: *mut RawTlsServer, buffer: *mut u8, length: isize) -> RawError;
164 fn tls_handle_write(tls: *mut RawTlsServer, buffer: *const u8, length: isize) -> RawError;
165 fn tls_handle_server_side_handshake(tls: *mut RawTlsServer) -> RawError;
166 fn tls_handle_client_side_handshake(tls: *mut RawTlsServer) -> RawError;
167 fn tls_handle_shutdown(tls: *mut RawTlsServer) -> RawError;
168
169 fn ERR_get_error() -> c_ulonglong;
171 fn ERR_lib_error_string(e: c_ulonglong) -> *const c_char;
172 fn ERR_func_error_string(e: c_ulonglong) -> *const c_char;
173 fn ERR_reason_error_string(e: c_ulonglong) -> *const c_char;
174
175 fn OpenSSL_version_num() -> c_ulonglong;
176
177 fn tls_get_ssl(tls: *mut RawTlsServer) -> *mut c_void;
179
180 fn tls_get_min_proto_version(ctx: *mut RawTlsServerContext) -> i32;
181}
182
/// Owning handle to a TLS context held by the C shim.
///
/// Creates per-connection [`TlsServer`] handles via `server()` / `client()` and is released
/// with `tls_handle_ctx_close` when dropped.
pub struct TlsServerContext {
    // Non-null (`from_raw` asserts it); owned by this value.
    ctx: *mut RawTlsServerContext,
}

// SAFETY(review): assumes the shim context may be moved to and used from another thread
// as long as only one thread uses it at a time — confirm the C side keeps no thread-affine
// state. No `Sync` impl is provided, so shared references never cross threads.
unsafe impl Send for TlsServerContext {}
191
/// One TLS connection (client or server side) managed by the C shim.
///
/// Closed with `tls_handle_close` when dropped; `clone()` duplicates it via `tls_handle_dup`.
pub struct TlsServer {
    // Handle produced by `tls_handle_create` or `tls_handle_dup`; owned by this value.
    server: *mut RawTlsServer,
}

// SAFETY(review): assumes the shim handle (and the OpenSSL session behind it) may be used
// from any single thread at a time — confirm on the C side. Deliberately not `Sync`.
unsafe impl Send for TlsServer {}
200
201pub enum Response {
202 Success(usize),
203 Fail(TlsServerError),
204 Eof,
205 WantRead,
206 WantWrite,
207}
208
209fn get_response(error: RawError) -> Response {
210 match error.error_type {
211 0 => Response::Success(error.error_code as usize),
212 1 => Response::Fail(TlsServerError::TlsError(get_ssl_error())),
213 2 => match error.error_code {
214 0 => Response::Fail(TlsServerError::Errno(Errno::INVAL)),
215 1 => Response::Fail(TlsServerError::TlsError(get_ssl_error())),
217 2 => Response::WantRead,
219 3 => Response::WantWrite,
221 4 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
223 5 => Response::Fail(TlsServerError::TlsError(get_ssl_error())),
225 6 => Response::Eof,
227 7 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
229 8 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
231 9 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
233 10 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
235 11 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
237 12 => Response::Fail(TlsServerError::Errno(Errno::PROTO)),
239 _ => Response::Fail(TlsServerError::Errno(Errno::INVAL)),
240 },
241 3 => Response::Fail(TlsServerError::Errno(Errno::from_raw_os_error(
242 error.error_code,
243 ))),
244 4 => Response::Eof,
245 5 => Response::WantWrite,
246 6 => Response::WantRead,
247 _ => panic!("unexpected error code"),
248 }
249}
250
251fn get_error() -> Option<NonZeroU64> {
252 let code = unsafe { ERR_get_error() };
253 NonZeroU64::new(code)
254}
255
256fn get_ssl_error() -> Vec<u64> {
257 let mut codes: Vec<u64> = Vec::with_capacity(4);
258 while let Some(code) = get_error() {
259 codes.push(code.into());
260 }
261 codes
262}
263
264impl TlsServerContext {
265 pub fn server(&self, bufsize: usize) -> Result<TlsServer, TlsServerError> {
266 let mut server: *mut RawTlsServer = null_mut();
267 let result =
268 get_response(unsafe { tls_handle_create(self.ctx, bufsize, true, &mut server) });
269 match result {
270 Response::Success(_) => (),
271 Response::Fail(e) => return Err(e),
272 _ => panic!("Unexpected response"),
273 }
274 Ok(TlsServer { server })
275 }
276
277 pub fn client(&self, bufsize: usize) -> Result<TlsServer, TlsServerError> {
278 let mut client: *mut RawTlsServer = null_mut();
279 let result =
280 get_response(unsafe { tls_handle_create(self.ctx, bufsize, false, &mut client) });
281 match result {
282 Response::Success(_) => (),
283 Response::Fail(e) => return Err(e),
284 _ => panic!("Unexpected response"),
285 }
286 Ok(TlsServer { server: client })
287 }
288
289 pub fn get_min_proto_version(&self) -> i32 {
291 unsafe { tls_get_min_proto_version(self.ctx) }
292 }
293
294 pub fn from_raw(ctx: *mut c_void) -> Self {
296 assert!(!ctx.is_null(), "Context pointer must not be null");
297 let ctx = ctx as *mut RawTlsServerContext;
298 Self { ctx }
299 }
300}
301
302impl Drop for TlsServerContext {
303 fn drop(&mut self) {
304 unsafe {
305 tls_handle_ctx_close(self.ctx);
306 }
307 }
308}
309
310impl TlsServer {
312 pub fn client_side_handshake(&mut self) -> Response {
313 get_response(unsafe { tls_handle_client_side_handshake(self.server) })
314 }
315
316 pub fn server_side_handshake(&mut self) -> Response {
318 get_response(unsafe { tls_handle_server_side_handshake(self.server) })
319 }
320
321 pub fn get_ssl_raw(&self) -> *mut c_void {
323 unsafe { tls_get_ssl(self.server) }
324 }
325
326 pub fn shutdown(&mut self) -> Response {
327 get_response(unsafe { tls_handle_shutdown(self.server) })
328 }
329
330 pub fn read(&mut self, buffer: &mut [u8]) -> Response {
331 get_response(unsafe {
332 tls_handle_read(self.server, buffer.as_mut_ptr(), buffer.len() as isize)
333 })
334 }
335
336 pub fn write(&mut self, buffer: &[u8]) -> Response {
337 get_response(unsafe {
338 tls_handle_write(self.server, buffer.as_ptr(), buffer.len() as isize)
339 })
340 }
341
342 pub fn get_push_buffer(&mut self) -> Option<&mut [u8]> {
343 let mut slice = Slice {
344 buf: null_mut(),
345 size: 0,
346 };
347 let result = unsafe { tls_handle_push_get_buffer(self.server, &mut slice) };
348 if result > 0 {
349 Some(unsafe { std::slice::from_raw_parts_mut(slice.buf, slice.size) })
350 } else {
351 None
352 }
353 }
354
355 pub fn use_push_buffer(&mut self, amount: usize) {
356 let result = unsafe { tls_handle_push_advance(self.server, amount) };
357 assert!(result as usize == amount);
358 }
359
360 pub fn get_pull_buffer(&self) -> Option<&[u8]> {
361 let mut slice = Slice {
362 buf: null_mut(),
363 size: 0,
364 };
365 let result = unsafe { tls_handle_pull_get_buffer(self.server, &mut slice) };
366 if result > 0 {
367 Some(unsafe { std::slice::from_raw_parts(slice.buf, slice.size) })
368 } else {
369 None
370 }
371 }
372
373 pub fn use_pull_buffer(&mut self, amount: usize) {
374 let result = unsafe { tls_handle_pull_advance(self.server, amount) };
375 assert!(result as usize == amount);
376 }
377}
378
379impl Clone for TlsServer {
380 fn clone(&self) -> Self {
381 let mut server: *mut RawTlsServer = null_mut();
382 let result = get_response(unsafe { tls_handle_dup(self.server, &mut server) });
383 match result {
384 Response::Success(_) => (),
385 Response::Fail(e) => panic!("dup failed {e:?}"),
386 _ => panic!("Unexpected response"),
387 }
388 TlsServer { server }
389 }
390}
391
392impl Drop for TlsServer {
393 fn drop(&mut self) {
394 unsafe {
395 tls_handle_close(self.server);
396 }
397 }
398}
399
400pub fn version() -> (u64, u64, u64) {
401 let version = unsafe { OpenSSL_version_num() };
402 (
403 (version >> 28) & 0xf,
404 (version >> 20) & 0xff,
405 (version >> 4) & 0xff,
406 )
407}