use anyhow::Error;
use libc::{c_char, c_int};
use std::{cell::RefCell, slice};
use crate::nullable::Nullable;
thread_local! {
/// The most recent error recorded on the current thread, or `None`.
///
/// Written by `update_last_error`, emptied by `take_last_error`, and read
/// without being removed by `last_error_length`, `last_error_length_utf16`
/// and `error_message`. Because it is thread-local, every thread has its own
/// independent slot.
static LAST_ERROR: RefCell<Option<Error>> = RefCell::new(None);
}
/// Discard the error most recently recorded on the calling thread, if any.
///
/// Calling this when nothing is stored is a no-op.
pub extern "C" fn clear_last_error() {
    drop(take_last_error());
}
/// Remove and return the error most recently recorded on the calling thread.
///
/// Returns `None` when nothing is stored. Either way the thread's slot is
/// empty afterwards, so an immediate second call yields `None`.
pub fn take_last_error() -> Option<Error> {
    LAST_ERROR.with(|slot| slot.replace(None))
}
pub fn update_last_error<E: Into<Error>>(err: E) {
LAST_ERROR.with(|prev| *prev.borrow_mut() = Some(err.into()));
}
pub fn last_error_length() -> c_int {
LAST_ERROR.with(|prev| {
prev.borrow()
.as_ref()
.map(|e| e.to_string().len() + 1)
.unwrap_or(0)
}) as c_int
}
pub fn last_error_length_utf16() -> c_int {
LAST_ERROR.with(|prev| {
prev.borrow()
.as_ref()
.map(|e| e.to_string().encode_utf16().count() + 1)
.unwrap_or(0)
}) as c_int
}
/// Render the calling thread's last error with `Display`, leaving it stored.
///
/// Returns `None` when no error is recorded.
pub fn error_message() -> Option<String> {
    LAST_ERROR.with(|slot| {
        let stored = slot.borrow();
        stored.as_ref().map(ToString::to_string)
    })
}
/// Write the calling thread's last error message into `buf` as a
/// NUL-terminated UTF-8 string.
///
/// Reading the message does not clear it. `length` is the capacity of `buf`
/// in bytes; `last_error_length()` reports how many are required.
///
/// # Returns
///
/// * `> 0`: number of bytes written, including the trailing NUL.
/// * `0`: there is no last error.
/// * `-1`: `length` is negative, or `buf` is too small for the message plus NUL.
/// * If `buf` is null, this returns early via `crate::null_pointer_check!`.
///   The value returned in that case is defined by that macro.
///
/// # Safety
///
/// When `length` is non-negative, `buf` must point to `length` initialized,
/// writable bytes that nothing else aliases for the duration of the call.
pub unsafe fn error_message_utf8(buf: *mut c_char, length: c_int) -> c_int {
    crate::null_pointer_check!(buf);
    // `length as usize` would turn a negative C length into an enormous slice
    // length, and building that slice over the caller's buffer is UB.
    if length < 0 {
        return -1;
    }
    let buffer = slice::from_raw_parts_mut(buf as *mut u8, length as usize);
    copy_error_into_buffer(buffer, |msg| msg.into())
}
/// Write the calling thread's last error message into `buf` as
/// terminator-ended UTF-16.
///
/// Reading the message does not clear it.
///
/// Units differ between input and output:
/// * `length` is the capacity of `buf` in `u16` code units, the unit that
///   `last_error_length_utf16()` reports.
/// * A positive return value is measured in **bytes**: two per code unit
///   written, terminator included.
///
/// # Returns
///
/// * `> 0`: number of bytes written, including the terminator.
/// * `0`: there is no last error.
/// * `-1`: `length` is negative, or `buf` is too small for the message plus terminator.
/// * If `buf` is null, this returns early via `crate::null_pointer_check!`.
///   The value returned in that case is defined by that macro.
///
/// # Safety
///
/// When `length` is non-negative, `buf` must point to `length` initialized,
/// writable `u16` values that nothing else aliases for the duration of the call.
pub unsafe fn error_message_utf16(buf: *mut u16, length: c_int) -> c_int {
    crate::null_pointer_check!(buf);
    // A negative C length would become an enormous `usize` slice length (UB).
    if length < 0 {
        return -1;
    }
    let buffer = slice::from_raw_parts_mut(buf, length as usize);
    let units_written =
        copy_error_into_buffer(buffer, |msg| msg.encode_utf16().collect());
    if units_written > 0 {
        // Two bytes per UTF-16 code unit. Saturate rather than overflow: in
        // debug builds an overflow would panic inside the exported
        // `extern "C"` wrapper.
        units_written.saturating_mul(2)
    } else {
        units_written
    }
}
/// Encode the calling thread's last error with `error_msg` and copy it into
/// `buffer`, followed by a `B::NULL` terminator.
///
/// The stored error is left in place.
///
/// Returns:
/// * the number of elements written (terminator included);
/// * `0` when there is no last error;
/// * `-1` when `buffer` cannot hold the encoded message plus the terminator.
///   In that case `buffer` is left untouched.
fn copy_error_into_buffer<B, F>(buffer: &mut [B], error_msg: F) -> c_int
where
    F: FnOnce(String) -> Vec<B>,
    B: Copy + Nullable,
{
    let encoded: Vec<B> = match error_message() {
        Some(text) => error_msg(text),
        None => return 0,
    };

    let needed = encoded.len() + 1;
    if buffer.len() < needed {
        return -1;
    }

    // The length check above guarantees `tail` has at least one slot.
    let (body, tail) = buffer.split_at_mut(encoded.len());
    body.copy_from_slice(&encoded);
    tail[0] = B::NULL;
    needed as c_int
}
/// Generate a `#[no_mangle] pub unsafe extern "C"` function that forwards its
/// arguments to the function of the same name in `$crate::error_handling`.
///
/// The second arm accepts a signature with no return type and re-invokes the
/// first arm with `-> ()`.
#[doc(hidden)]
#[macro_export]
macro_rules! export_c_symbol {
    (fn $name:ident($( $arg:ident : $type:ty ),*) -> $ret:ty) => {
        #[no_mangle]
        pub unsafe extern "C" fn $name($( $arg : $type),*) -> $ret {
            $crate::error_handling::$name($( $arg ),*)
        }
    };
    (fn $name:ident($( $arg:ident : $type:ty ),*)) => {
        // Path-qualify the recursive call so it resolves in downstream crates
        // even when this macro is not imported or in textual scope there.
        $crate::export_c_symbol!(fn $name($( $arg : $type),*) -> ());
    };
}
/// Declare a hidden `__ffi_helpers_errors` module that exports the
/// error-handling functions of `$crate::error_handling` as `#[no_mangle]`
/// C symbols.
///
/// Invoke it once, at module scope, in the crate that builds the C library:
///
/// ```rust,ignore
/// export_error_handling_functions!();
/// ```
///
/// The exported symbols are `clear_last_error`, `last_error_length`,
/// `last_error_length_utf16`, `error_message_utf8` and `error_message_utf16`.
#[macro_export]
macro_rules! export_error_handling_functions {
    () => {
        #[allow(missing_docs)]
        #[doc(hidden)]
        pub mod __ffi_helpers_errors {
            // `$crate::` lets the helper macro resolve even when the caller
            // imported only `export_error_handling_functions`.
            $crate::export_c_symbol!(fn clear_last_error());
            $crate::export_c_symbol!(fn last_error_length() -> ::libc::c_int);
            $crate::export_c_symbol!(fn last_error_length_utf16() -> ::libc::c_int);
            $crate::export_c_symbol!(
                fn error_message_utf8(buf: *mut ::libc::c_char, length: ::libc::c_int) -> ::libc::c_int
            );
            $crate::export_c_symbol!(
                fn error_message_utf16(buf: *mut u16, length: ::libc::c_int) -> ::libc::c_int
            );
        }
    };
}
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the module's real `clear_last_error`, so
    // the tests exercise it instead of a private re-implementation.
    use super::*;
    use std::str;

    /// Start from an empty slot, then record `msg` as the thread's last error.
    fn set_error(msg: &'static str) {
        clear_last_error();
        update_last_error(anyhow::anyhow!(msg));
    }

    #[test]
    fn update_the_error() {
        let err_msg = "An Error Occurred";
        set_error(err_msg);
        let got_err_msg =
            LAST_ERROR.with(|slot| slot.borrow_mut().take().unwrap().to_string());
        assert_eq!(got_err_msg, err_msg);
    }

    #[test]
    fn take_the_last_error() {
        let err_msg = "An Error Occurred";
        set_error(err_msg);
        let got_err_msg = take_last_error().unwrap().to_string();
        assert_eq!(got_err_msg, err_msg);
        // Taking empties the slot.
        assert!(take_last_error().is_none());
    }

    #[test]
    fn get_the_last_error_messages_length() {
        let err_msg = "An Error Occurred";
        set_error(err_msg);
        assert_eq!(last_error_length(), (err_msg.len() + 1) as c_int);
        clear_last_error();
        assert_eq!(last_error_length(), 0);
    }

    #[test]
    fn get_the_last_error_messages_utf16_length() {
        // Non-ASCII so that byte length and UTF-16 unit count differ.
        let err_msg = "crab \u{1F980}";
        set_error(err_msg);
        let units = err_msg.encode_utf16().count();
        assert_ne!(units, err_msg.len());
        assert_eq!(last_error_length_utf16(), (units + 1) as c_int);
        assert_eq!(last_error_length(), (err_msg.len() + 1) as c_int);
        clear_last_error();
        assert_eq!(last_error_length_utf16(), 0);
    }

    #[test]
    fn write_the_last_error_message_into_a_buffer() {
        let err_msg = "An Error Occurred";
        set_error(err_msg);
        let mut buffer: Vec<u8> = vec![0; 40];
        let bytes_written = unsafe {
            error_message_utf8(buffer.as_mut_ptr() as *mut c_char, buffer.len() as _)
        };
        assert!(bytes_written > 0);
        assert_eq!(bytes_written as usize, err_msg.len() + 1);
        let msg = str::from_utf8(&buffer[..bytes_written as usize - 1]).unwrap();
        assert_eq!(msg, err_msg);
    }

    #[test]
    fn write_the_last_error_message_into_a_utf16_buffer() {
        let err_msg = "crab \u{1F980}";
        set_error(err_msg);
        let expected: Vec<u16> = err_msg.encode_utf16().collect();
        let mut buffer: Vec<u16> = vec![0; 40];
        let written = unsafe { error_message_utf16(buffer.as_mut_ptr(), buffer.len() as _) };
        // The return value counts bytes: two per code unit, terminator included.
        assert_eq!(written as usize, (expected.len() + 1) * 2);
        assert_eq!(&buffer[..expected.len()], expected.as_slice());
    }

    #[test]
    fn buffer_size_boundaries() {
        let err_msg = "An Error Occurred";
        set_error(err_msg);

        let mut exact: Vec<u8> = vec![0; err_msg.len() + 1];
        let written = unsafe {
            error_message_utf8(exact.as_mut_ptr() as *mut c_char, exact.len() as _)
        };
        assert_eq!(written as usize, exact.len());

        // One byte short: there is no room for the terminator.
        let mut short: Vec<u8> = vec![0; err_msg.len()];
        let written = unsafe {
            error_message_utf8(short.as_mut_ptr() as *mut c_char, short.len() as _)
        };
        assert_eq!(written, -1);
    }

    #[test]
    fn nothing_is_written_without_an_error() {
        clear_last_error();
        let mut buffer: Vec<u8> = vec![0; 8];
        let written = unsafe {
            error_message_utf8(buffer.as_mut_ptr() as *mut c_char, buffer.len() as _)
        };
        assert_eq!(written, 0);
        assert!(error_message().is_none());
    }
}