cosmian_ffi_utils/
error.rs

1use core::{cell::RefCell, fmt::Display};
2use std::ffi::CString;
3
4use crate::ErrorCode;
5
/// Error type recorded by the FFI helpers of this crate.
///
/// Values of this type are what a foreign caller eventually reads back as the
/// "last error" message, rendered through [`Display`].
#[derive(Debug)]
pub enum FfiError {
    /// A pointer argument was null; the payload names that pointer.
    NullPointer(String),
    /// Any other failure; the payload is the message rendered verbatim.
    Generic(String),
}
11
12impl Display for FfiError {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        match self {
15            Self::NullPointer(pointer_name) => write!(f, "{pointer_name} shouldn't be null"),
16            Self::Generic(err) => write!(f, "{err}"),
17        }
18    }
19}
20
21thread_local! {
22    /// a thread-local variable which holds the most recent error
23    static LAST_ERROR: RefCell<Option<Box<FfiError>>> = const { RefCell::new(None) };
24}
25
26/// Sets the most recent error, clearing whatever may have been there before.
27///
28/// - `err` : error to set
29#[inline]
30pub fn set_last_error(err: FfiError) {
31    LAST_ERROR.with(|prev| {
32        *prev.borrow_mut() = Some(Box::new(err));
33    });
34}
35
36/// Gets the last error message.
37#[inline]
38#[must_use]
39pub fn get_last_error() -> String {
40    LAST_ERROR
41        .with(|prev| prev.borrow_mut().take())
42        .map_or(String::new(), |e| e.to_string())
43}
44
45/// Externally sets the last error recorded on the Rust side.
46///
47/// # Safety
48///
49/// The pointer must point to a null-terminated string.
50///
51/// This function is meant to be called from the Foreign Function
52/// Interface.
53///
54/// # Parameters
55///
56/// - `error_message_ptr`   : pointer to the error message to set
57#[no_mangle]
58pub unsafe extern "C" fn h_set_error(error_message_ptr: *const i8) -> i32 {
59    let error_message = ffi_read_string!("error message", error_message_ptr);
60    set_last_error(FfiError::Generic(error_message));
61    0
62}
63
64/// Externally gets the most recent error recorded on the Rust side, clearing
65/// it in the process.
66///
67/// # Safety
68///
69/// The pointer `error_ptr` should point to a buffer which has been allocated
70/// `error_len` bytes. If the allocated size is smaller than `error_len`, a
71/// call to this function may result in a buffer overflow.
72///
73/// # Parameters
74///
75/// - `error_ptr`: pointer to the buffer to which to write the error
76/// - `error_len`: size of the allocated memory
77#[no_mangle]
78pub unsafe extern "C" fn h_get_error(error_ptr: *mut i8, error_len: *mut i32) -> i32 {
79    // Get the error message as a null terminated string.
80    let cs = ffi_unwrap!(
81        CString::new(get_last_error()),
82        "failed to convert error to CString",
83        ErrorCode::InvalidArgument("CString".to_string())
84    );
85
86    ffi_write_bytes!("error", cs.as_bytes(), error_ptr, error_len);
87}
88
#[cfg(test)]
mod tests {
    use std::{ffi::CString, ptr::null_mut};

    use super::*;

    /// Reads the last error through `h_get_error` into a large stack buffer.
    ///
    /// Returns the bytes the call reports as written, decoded as UTF-8.
    fn read_last_error() -> String {
        let mut bytes = [0u8; 8192];
        let mut len = bytes.len() as i32;
        // SAFETY: `bytes` is a live, writable buffer of `len` bytes, and `len`
        // is a valid `i32` for the duration of the call.
        unsafe { h_get_error(bytes.as_mut_ptr().cast(), &mut len) };
        String::from_utf8(bytes[..len as usize].to_vec()).unwrap()
    }

    #[test]
    fn test_error() {
        let error_msg = "Emergency!!!";
        // `h_set_error` expects a NUL-terminated C string. A Rust `&str` is
        // not NUL-terminated, so passing `error_msg.as_ptr()` directly would
        // read past the end of the literal.
        let c_error_msg = CString::new(error_msg).unwrap();

        // Set the error message.
        // SAFETY: `c_error_msg` is NUL-terminated and outlives the call.
        let res = unsafe { h_set_error(c_error_msg.as_ptr().cast::<i8>()) };
        assert_eq!(res, 0);

        // Read the error message back.
        assert!(read_last_error().contains(error_msg));

        // Passing a null output buffer should record a null-pointer error.
        unsafe {
            let ptr = null_mut::<u8>();
            let mut len = 10;
            h_get_error(ptr.cast(), &mut len);
        }

        // Read the error recorded by the previous call.
        assert!(read_last_error().contains("shouldn't be null"));
    }
}
130}