ffi_helpers/
error_handling.rs

1//! Common error handling routines.
2//!
3//! The main error handling method employed is a thread-local variable called
4//! `LAST_ERROR` which holds the most recent error as well as some convenience
5//! functions for getting/clearing this variable.
6//!
7//! The theory is if a function fails then it should return an *"obviously
8//! invalid"* value (typically `-1` or `0` when returning integers or `NULL` for
9//! pointers, see the [`Nullable`] trait for more). The user can then check for
10//! this and consult the most recent error for more information. Of course that
11//! means all fallible operations *must* update the most recent error if they
12//! fail and that you *must* check the returned value of any fallible operation.
13//!
14//! While it isn't as elegant as Rust's monad-style `Result<T, E>` with `?` and
15//! the various combinators, it actually turns out to be a pretty robust error
16//! handling technique in practice.
17//!
18//! > **Note:** It is highly recommended to have a skim through libgit2's
19//! > [error handling docs][libgit2]. The error handling mechanism used here
20//! > takes a lot of inspiration from `libgit2`.
21//!
22//! ## Examples
23//!
//! The following shows a full example where `some_fallible_operation()` fails
//! and records its error. The caller then asks for the length of the error
//! message, allocates a big enough buffer, and copies the message into it.
26//!
27//! ```rust
28//! use libc::{c_char, c_int};
29//! use std::slice;
30//! use ffi_helpers::error_handling;
31//! # use anyhow::Error;
32//!
33//! fn main() {
34//!     if unsafe { some_fallible_operation() } != 1 {
35//!         // Before we can retrieve the message we need to know how long it is.
36//!         let err_msg_length = error_handling::last_error_length();
37//!
38//!         // then allocate a big enough buffer
39//!         let mut buffer = vec![0; err_msg_length as usize];
40//!         let bytes_written = unsafe {
41//!             let buf = buffer.as_mut_ptr() as *mut c_char;
42//!             let len = buffer.len() as c_int;
43//!             error_handling::error_message_utf8(buf, len)
44//!         };
45//!
46//!         // then interpret the message
47//!         match bytes_written {
48//!             -1 => panic!("Our buffer wasn't big enough!"),
49//!             0 => panic!("There wasn't an error message... Huh?"),
50//!             len if len > 0 => {
51//!                 buffer.truncate(len as usize - 1);
52//!                 let msg = String::from_utf8(buffer).unwrap();
53//!                 println!("Error: {}", msg);
54//!             }
55//!             _ => unreachable!(),
56//!         }
57//!     }
58//! }
59//!
60//! /// pretend to do some complicated operation, returning whether the
61//! /// operation was successful.
62//! #[no_mangle]
63//! unsafe extern "C" fn some_fallible_operation() -> c_int {
64//!     match do_stuff() {
65//!         Ok(_) => 1, // do_stuff() always errors, so this is unreachable
66//!         Err(e) => {
67//!             ffi_helpers::update_last_error(e);
68//!             0
69//!         }
70//!     }
71//! }
72//!
73//! # fn do_stuff() -> Result<(), Error> { Err(anyhow::anyhow!("An error occurred")) }
74//! ```
75//!
76//! [`Nullable`]: trait.Nullable.html
77//! [libgit2]: https://github.com/libgit2/libgit2/blob/master/docs/error-handling.md
78
79use anyhow::Error;
80use libc::{c_char, c_int};
81use std::{cell::RefCell, slice};
82
83use crate::nullable::Nullable;
84
85thread_local! {
86    static LAST_ERROR: RefCell<Option<Error>> = RefCell::new(None);
87}
88
89/// Clear the `LAST_ERROR`.
90pub extern "C" fn clear_last_error() { let _ = take_last_error(); }
91
92/// Take the most recent error, clearing `LAST_ERROR` in the process.
93pub fn take_last_error() -> Option<Error> {
94    LAST_ERROR.with(|prev| prev.borrow_mut().take())
95}
96
97/// Update the `thread_local` error, taking ownership of the `Error`.
98pub fn update_last_error<E: Into<Error>>(err: E) {
99    LAST_ERROR.with(|prev| *prev.borrow_mut() = Some(err.into()));
100}
101
102/// Get the length of the last error message in bytes when encoded as UTF-8,
103/// including the trailing null.
104pub fn last_error_length() -> c_int {
105    LAST_ERROR.with(|prev| {
106        prev.borrow()
107            .as_ref()
108            .map(|e| e.to_string().len() + 1)
109            .unwrap_or(0)
110    }) as c_int
111}
112
113/// Get the length of the last error message in bytes when encoded as UTF-16,
114/// including the trailing null.
115pub fn last_error_length_utf16() -> c_int {
116    LAST_ERROR.with(|prev| {
117        prev.borrow()
118            .as_ref()
119            .map(|e| e.to_string().encode_utf16().count() + 1)
120            .unwrap_or(0)
121    }) as c_int
122}
123
124/// Peek at the most recent error and get its error message as a Rust `String`.
125pub fn error_message() -> Option<String> {
126    LAST_ERROR.with(|prev| prev.borrow().as_ref().map(|e| e.to_string()))
127}
128
129/// Peek at the most recent error and write its error message (`Display` impl)
130/// into the provided buffer as a UTF-8 encoded string.
131///
132/// This returns the number of bytes written, or `-1` if there was an error.
133pub unsafe fn error_message_utf8(buf: *mut c_char, length: c_int) -> c_int {
134    crate::null_pointer_check!(buf);
135    let buffer = slice::from_raw_parts_mut(buf as *mut u8, length as usize);
136
137    copy_error_into_buffer(buffer, |msg| msg.into())
138}
139
140/// Peek at the most recent error and write its error message (`Display` impl)
141/// into the provided buffer as a UTF-16 encoded string.
142///
143/// This returns the number of bytes written, or `-1` if there was an error.
144pub unsafe fn error_message_utf16(buf: *mut u16, length: c_int) -> c_int {
145    crate::null_pointer_check!(buf);
146    let buffer = slice::from_raw_parts_mut(buf, length as usize);
147
148    let ret =
149        copy_error_into_buffer(buffer, |msg| msg.encode_utf16().collect());
150
151    if ret > 0 {
152        // utf16 uses two bytes per character
153        ret * 2
154    } else {
155        ret
156    }
157}
158
159fn copy_error_into_buffer<B, F>(buffer: &mut [B], error_msg: F) -> c_int
160where
161    F: FnOnce(String) -> Vec<B>,
162    B: Copy + Nullable,
163{
164    let maybe_error_message: Option<Vec<B>> =
165        error_message().map(|msg| error_msg(msg));
166
167    let err_msg = match maybe_error_message {
168        Some(msg) => msg,
169        None => return 0,
170    };
171
172    if err_msg.len() + 1 > buffer.len() {
173        // buffer isn't big enough
174        return -1;
175    }
176
177    buffer[..err_msg.len()].copy_from_slice(&err_msg);
178    // Make sure to add a trailing null in case people use this as a bare char*
179    buffer[err_msg.len()] = B::NULL;
180
181    (err_msg.len() + 1) as c_int
182}
183
/// Generate a `#[no_mangle]` `extern "C"` wrapper named `$name` that forwards
/// its arguments to `$crate::error_handling::$name`.
///
/// The second arm accepts a signature without a return type and re-dispatches
/// to the first arm with `-> ()`. The recursive call goes through `$crate::`
/// so it resolves even when the caller invoked the macro by path and never
/// brought `export_c_symbol!` into scope with `#[macro_use]`.
#[doc(hidden)]
#[macro_export]
macro_rules! export_c_symbol {
    (fn $name:ident($( $arg:ident : $type:ty ),*) -> $ret:ty) => {
        #[no_mangle]
        pub unsafe extern "C" fn $name($( $arg : $type),*) -> $ret {
            $crate::error_handling::$name($( $arg ),*)
        }
    };
    (fn $name:ident($( $arg:ident : $type:ty ),*)) => {
        $crate::export_c_symbol!(fn $name($( $arg : $type),*) -> ());
    }
}

/// As a workaround for [rust-lang/rfcs#2771][2771], you can use this macro to
/// make sure the symbols for `ffi_helpers`'s error handling are correctly
/// exported in your `cdylib`.
///
/// The expansion is a hidden `__ffi_helpers_errors` module containing one
/// `#[no_mangle]` wrapper per error handling function. The helper macro is
/// invoked through `$crate::` so this works whether the macro was imported
/// with `#[macro_use]` or called by path. The expansion names `::libc`, so
/// the calling crate needs `libc` as a dependency.
///
/// [2771]: https://github.com/rust-lang/rfcs/issues/2771
#[macro_export]
macro_rules! export_error_handling_functions {
    () => {
        #[allow(missing_docs)]
        #[doc(hidden)]
        pub mod __ffi_helpers_errors {
            $crate::export_c_symbol!(fn clear_last_error());
            $crate::export_c_symbol!(fn last_error_length() -> ::libc::c_int);
            $crate::export_c_symbol!(fn last_error_length_utf16() -> ::libc::c_int);
            $crate::export_c_symbol!(fn error_message_utf8(buf: *mut ::libc::c_char, length: ::libc::c_int) -> ::libc::c_int);
            $crate::export_c_symbol!(fn error_message_utf16(buf: *mut u16, length: ::libc::c_int) -> ::libc::c_int);
        }
    };
}
217
#[cfg(test)]
mod tests {
    use super::*;
    use std::str;

    /// The message every test records and expects to read back.
    const MESSAGE: &str = "An Error Occurred";

    /// Empty this thread's `LAST_ERROR` slot directly, bypassing the public
    /// API under test.
    fn clear_last_error() {
        LAST_ERROR.with(|slot| {
            let _ = slot.borrow_mut().take();
        });
    }

    /// Start from an empty slot and record `MESSAGE` as the last error.
    fn record_message() {
        clear_last_error();
        update_last_error(anyhow::anyhow!(MESSAGE));
    }

    #[test]
    fn update_the_error() {
        record_message();

        // Read the slot directly so only `update_last_error` is exercised.
        let stored = LAST_ERROR.with(|slot| slot.borrow_mut().take());
        assert_eq!(stored.unwrap().to_string(), MESSAGE);
    }

    #[test]
    fn take_the_last_error() {
        record_message();

        let taken = take_last_error();
        assert_eq!(taken.unwrap().to_string(), MESSAGE);
    }

    #[test]
    fn get_the_last_error_messages_length() {
        record_message();

        // A stored error reports its UTF-8 length plus the trailing null.
        let expected = MESSAGE.len() + 1;
        assert_eq!(last_error_length(), expected as c_int);

        // Once cleared, the reported length drops to zero.
        clear_last_error();
        assert_eq!(last_error_length(), 0);
    }

    #[test]
    fn write_the_last_error_message_into_a_buffer() {
        record_message();

        let mut buffer = vec![0_u8; 40];
        let capacity = buffer.len() as c_int;
        let destination = buffer.as_mut_ptr() as *mut c_char;
        let bytes_written =
            unsafe { error_message_utf8(destination, capacity) };

        assert!(bytes_written > 0);
        let written = bytes_written as usize;
        assert_eq!(written, MESSAGE.len() + 1);

        // Everything before the trailing null must be the original message.
        let msg = str::from_utf8(&buffer[..written - 1]).unwrap();
        assert_eq!(msg, MESSAGE);
    }
}