1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
//! Common error handling routines.
//!
//! The main error handling method employed is a thread-local variable called
//! `LAST_ERROR` which holds the most recent error as well as some convenience
//! functions for getting/clearing this variable.
//!
//! The theory is if a function fails then it should return an *"obviously
//! invalid"* value (typically `-1` or `0` when returning integers or `NULL` for
//! pointers, see the [`Nullable`] trait for more). The user can then check for
//! this and consult the most recent error for more information. Of course that
//! means all fallible operations *must* update the most recent error if they
//! fail and that you *must* check the returned value of any fallible operation.
//!
//! While it isn't as elegant as Rust's monad-style `Result<T, E>` with `?` and
//! the various combinators, it actually turns out to be a pretty robust error
//! handling technique in practice.
//!
//! > **Note:** It is highly recommended to have a skim through libgit2's
//! > [error handling docs][libgit2]. The error handling mechanism used here
//! > takes a lot of inspiration from `libgit2`.
//!
//! ## Examples
//!
//! The following shows a full example where our `write_data()` function will
//! try to write some data into a buffer. The first time through
//!
//! ```rust
//! use libc::{c_char, c_int};
//! use std::slice;
//! use ffi_helpers::error_handling;
//! # use anyhow::Error;
//!
//! fn main() {
//!     if unsafe { some_fallible_operation() } != 1 {
//!         // Before we can retrieve the message we need to know how long it is.
//!         let err_msg_length = error_handling::last_error_length();
//!
//!         // then allocate a big enough buffer
//!         let mut buffer = vec![0; err_msg_length as usize];
//!         let bytes_written = unsafe {
//!             let buf = buffer.as_mut_ptr() as *mut c_char;
//!             let len = buffer.len() as c_int;
//!             error_handling::error_message_utf8(buf, len)
//!         };
//!
//!         // then interpret the message
//!         match bytes_written {
//!             -1 => panic!("Our buffer wasn't big enough!"),
//!             0 => panic!("There wasn't an error message... Huh?"),
//!             len if len > 0 => {
//!                 buffer.truncate(len as usize - 1);
//!                 let msg = String::from_utf8(buffer).unwrap();
//!                 println!("Error: {}", msg);
//!             }
//!             _ => unreachable!(),
//!         }
//!     }
//! }
//!
//! /// pretend to do some complicated operation, returning whether the
//! /// operation was successful.
//! #[no_mangle]
//! unsafe extern "C" fn some_fallible_operation() -> c_int {
//!     match do_stuff() {
//!         Ok(_) => 1, // do_stuff() always errors, so this is unreachable
//!         Err(e) => {
//!             ffi_helpers::update_last_error(e);
//!             0
//!         }
//!     }
//! }
//!
//! # fn do_stuff() -> Result<(), Error> { Err(anyhow::anyhow!("An error occurred")) }
//! ```
//!
//! [`Nullable`]: trait.Nullable.html
//! [libgit2]: https://github.com/libgit2/libgit2/blob/master/docs/error-handling.md

use anyhow::Error;
use libc::{c_char, c_int};
use std::{cell::RefCell, slice};

use crate::nullable::Nullable;

thread_local! {
    static LAST_ERROR: RefCell<Option<Error>> = RefCell::new(None);
}

/// Clear the `LAST_ERROR`.
pub extern "C" fn clear_last_error() { let _ = take_last_error(); }

/// Take the most recent error, clearing `LAST_ERROR` in the process.
pub fn take_last_error() -> Option<Error> {
    LAST_ERROR.with(|prev| prev.borrow_mut().take())
}

/// Update the `thread_local` error, taking ownership of the `Error`.
pub fn update_last_error<E: Into<Error>>(err: E) {
    LAST_ERROR.with(|prev| *prev.borrow_mut() = Some(err.into()));
}

/// Get the length of the last error message in bytes when encoded as UTF-8,
/// including the trailing null.
pub fn last_error_length() -> c_int {
    LAST_ERROR.with(|prev| {
        prev.borrow()
            .as_ref()
            .map(|e| e.to_string().len() + 1)
            .unwrap_or(0)
    }) as c_int
}

/// Get the length of the last error message in bytes when encoded as UTF-16,
/// including the trailing null.
pub fn last_error_length_utf16() -> c_int {
    LAST_ERROR.with(|prev| {
        prev.borrow()
            .as_ref()
            .map(|e| e.to_string().encode_utf16().count() + 1)
            .unwrap_or(0)
    }) as c_int
}

/// Peek at the most recent error and get its error message as a Rust `String`.
pub fn error_message() -> Option<String> {
    LAST_ERROR.with(|prev| prev.borrow().as_ref().map(|e| e.to_string()))
}

/// Peek at the most recent error and write its error message (`Display` impl)
/// into the provided buffer as a UTF-8 encoded string.
///
/// This returns the number of bytes written, or `-1` if there was an error.
pub unsafe fn error_message_utf8(buf: *mut c_char, length: c_int) -> c_int {
    crate::null_pointer_check!(buf);
    let buffer = slice::from_raw_parts_mut(buf as *mut u8, length as usize);

    copy_error_into_buffer(buffer, |msg| msg.into())
}

/// Peek at the most recent error and write its error message (`Display` impl)
/// into the provided buffer as a UTF-16 encoded string.
///
/// This returns the number of bytes written, or `-1` if there was an error.
pub unsafe fn error_message_utf16(buf: *mut u16, length: c_int) -> c_int {
    crate::null_pointer_check!(buf);
    let buffer = slice::from_raw_parts_mut(buf, length as usize);

    let ret =
        copy_error_into_buffer(buffer, |msg| msg.encode_utf16().collect());

    if ret > 0 {
        // utf16 uses two bytes per character
        ret * 2
    } else {
        ret
    }
}

fn copy_error_into_buffer<B, F>(buffer: &mut [B], error_msg: F) -> c_int
where
    F: FnOnce(String) -> Vec<B>,
    B: Copy + Nullable,
{
    let maybe_error_message: Option<Vec<B>> =
        error_message().map(|msg| error_msg(msg));

    let err_msg = match maybe_error_message {
        Some(msg) => msg,
        None => return 0,
    };

    if err_msg.len() + 1 > buffer.len() {
        // buffer isn't big enough
        return -1;
    }

    buffer[..err_msg.len()].copy_from_slice(&err_msg);
    // Make sure to add a trailing null in case people use this as a bare char*
    buffer[err_msg.len()] = B::NULL;

    (err_msg.len() + 1) as c_int
}

#[doc(hidden)]
#[macro_export]
macro_rules! export_c_symbol {
    (fn $name:ident($( $arg:ident : $type:ty ),*) -> $ret:ty) => {
        #[no_mangle]
        pub unsafe extern "C" fn $name($( $arg : $type),*) -> $ret {
            $crate::error_handling::$name($( $arg ),*)
        }
    };
    (fn $name:ident($( $arg:ident : $type:ty ),*)) => {
        export_c_symbol!(fn $name($( $arg : $type),*) -> ());
    }
}

/// As a workaround for [rust-lang/rfcs#2771][2771], you can use this macro to
/// make sure the symbols for `ffi_helpers`'s error handling are correctly
/// exported in your `cdylib`.
///
/// [2771]: https://github.com/rust-lang/rfcs/issues/2771
#[macro_export]
macro_rules! export_error_handling_functions {
    () => {
        #[allow(missing_docs)]
        #[doc(hidden)]
        pub mod __ffi_helpers_errors {
            export_c_symbol!(fn clear_last_error());
            export_c_symbol!(fn last_error_length() -> ::libc::c_int);
            export_c_symbol!(fn last_error_length_utf16() -> ::libc::c_int);
            export_c_symbol!(fn error_message_utf8(buf: *mut ::libc::c_char, length: ::libc::c_int) -> ::libc::c_int);
            export_c_symbol!(fn error_message_utf16(buf: *mut u16, length: ::libc::c_int) -> ::libc::c_int);
        }
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::str;

    fn clear_last_error() {
        let _ = LAST_ERROR.with(|e| e.borrow_mut().take());
    }

    #[test]
    fn update_the_error() {
        clear_last_error();

        let err_msg = "An Error Occurred";
        let e = anyhow::anyhow!(err_msg);

        update_last_error(e);

        let got_err_msg =
            LAST_ERROR.with(|e| e.borrow_mut().take().unwrap().to_string());
        assert_eq!(got_err_msg, err_msg);
    }

    #[test]
    fn take_the_last_error() {
        clear_last_error();

        let err_msg = "An Error Occurred";
        let e = anyhow::anyhow!(err_msg);
        update_last_error(e);

        let got_err_msg = take_last_error().unwrap().to_string();
        assert_eq!(got_err_msg, err_msg);
    }

    #[test]
    fn get_the_last_error_messages_length() {
        clear_last_error();

        let err_msg = "An Error Occurred";
        let should_be = err_msg.len() + 1;

        let e = anyhow::anyhow!(err_msg);
        update_last_error(e);

        // Get a valid error message's length
        let got = last_error_length();
        assert_eq!(got, should_be as _);

        // Then clear the error message and make sure we get 0
        clear_last_error();
        let got = last_error_length();
        assert_eq!(got, 0);
    }

    #[test]
    fn write_the_last_error_message_into_a_buffer() {
        clear_last_error();

        let err_msg = "An Error Occurred";

        let e = anyhow::anyhow!(err_msg);
        update_last_error(e);

        let mut buffer: Vec<u8> = vec![0; 40];
        let bytes_written = unsafe {
            error_message_utf8(
                buffer.as_mut_ptr() as *mut c_char,
                buffer.len() as _,
            )
        };

        assert!(bytes_written > 0);
        assert_eq!(bytes_written as usize, err_msg.len() + 1);

        let msg =
            str::from_utf8(&buffer[..bytes_written as usize - 1]).unwrap();
        assert_eq!(msg, err_msg);
    }
}