cosmian_ffi_utils/
error.rs1use core::{cell::RefCell, fmt::Display};
2use std::ffi::CString;
3
4use crate::ErrorCode;
5
/// Error values recorded by this module and rendered to text on retrieval.
#[derive(Debug)]
pub enum FfiError {
    /// A pointer that must not be null was null. The payload is the name of
    /// the offending pointer, which is embedded in the rendered message.
    NullPointer(String),
    /// Any other failure. The payload is the message, rendered verbatim.
    Generic(String),
}

impl Display for FfiError {
    /// Writes the human-readable message associated with this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NullPointer(pointer_name) => write!(f, "{pointer_name} shouldn't be null"),
            Self::Generic(err) => write!(f, "{err}"),
        }
    }
}

// Implementing the standard error trait lets `FfiError` be boxed as
// `dyn Error` and propagated with `?` into generic error types. The default
// methods are sufficient: the variants carry no underlying source error.
impl std::error::Error for FfiError {}
20
21thread_local! {
22 static LAST_ERROR: RefCell<Option<Box<FfiError>>> = const { RefCell::new(None) };
24}
25
26#[inline]
30pub fn set_last_error(err: FfiError) {
31 LAST_ERROR.with(|prev| {
32 *prev.borrow_mut() = Some(Box::new(err));
33 });
34}
35
36#[inline]
38#[must_use]
39pub fn get_last_error() -> String {
40 LAST_ERROR
41 .with(|prev| prev.borrow_mut().take())
42 .map_or(String::new(), |e| e.to_string())
43}
44
45#[no_mangle]
58pub unsafe extern "C" fn h_set_error(error_message_ptr: *const i8) -> i32 {
59 let error_message = ffi_read_string!("error message", error_message_ptr);
60 set_last_error(FfiError::Generic(error_message));
61 0
62}
63
64#[no_mangle]
78pub unsafe extern "C" fn h_get_error(error_ptr: *mut i8, error_len: *mut i32) -> i32 {
79 let cs = ffi_unwrap!(
81 CString::new(get_last_error()),
82 "failed to convert error to CString",
83 ErrorCode::InvalidArgument("CString".to_string())
84 );
85
86 ffi_write_bytes!("error", cs.as_bytes(), error_ptr, error_len);
87}
88
#[cfg(test)]
mod tests {
    use std::{ffi::CString, ptr::null_mut};

    use super::*;

    /// Drains the last error through `h_get_error` into a stack buffer and
    /// returns the written bytes as an owned string.
    fn read_error_via_ffi() -> String {
        let mut bytes = [0u8; 8192];
        let mut len = bytes.len() as i32;
        // SAFETY: `bytes` is a live, writable buffer and `len` holds its
        // exact capacity.
        unsafe {
            h_get_error(bytes.as_mut_ptr().cast(), &mut len);
        }
        String::from_utf8(bytes[..len as usize].to_vec()).unwrap()
    }

    /// Round-trips an error through the C entry points and checks that a null
    /// output buffer is itself reported as an error.
    #[test]
    fn test_error() {
        let error_msg = "Emergency!!!";
        // `h_set_error` receives no length, so the message must carry its own
        // terminator. A plain `&str` literal is not NUL-terminated; passing
        // its pointer would let the callee read past the end of the literal.
        let c_error_msg = CString::new(error_msg).expect("message contains no interior NUL");

        // SAFETY: `c_error_msg` is a valid NUL-terminated string that outlives
        // the call.
        let res = unsafe { h_set_error(c_error_msg.as_ptr().cast()) };
        assert_eq!(res, 0);

        let res = read_error_via_ffi();
        assert!(res.contains(error_msg));

        // Requesting the error with a null output buffer must fail and record
        // a null-pointer error instead.
        // SAFETY: the null pointer is passed on purpose to exercise the
        // callee's null check; `len` is a valid, writable `i32`.
        unsafe {
            let ptr = null_mut::<u8>();
            let mut len = 10;
            h_get_error(ptr.cast(), &mut len);
        };

        let res = read_error_via_ffi();
        assert!(res.contains("shouldn't be null"));
    }
}