diskann_quantization/
error.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
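//! Error utilities: formatting of complete error chains ([`format`]) and an
//! allocation-free error container that stores its payload inline ([`InlineError`]).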
7
8use std::{cell::UnsafeCell, marker::PhantomData, mem::MaybeUninit};
9
/// Format the entire error chain for `err`.
///
/// The result begins with `err.to_string()`. Every error reached by repeatedly following
/// [`source`](https://doc.rust-lang.org/std/error/trait.Error.html#method.source) is then
/// appended on a new line, prefixed with `"    caused by: "`:
///
/// ```text
/// outer message
///     caused by: middle message
///     caused by: innermost message
/// ```
///
/// `E` may be unsized, so this also accepts `&dyn std::error::Error`.
pub fn format<E>(err: &E) -> String
where
    E: std::error::Error + ?Sized,
{
    // Get the base message from the error.
    let mut message = err.to_string();

    // `successors` yields `err.source()`, then the source of that, and so on, stopping at
    // the first `None`, which is exactly the source chain below `err`.
    for cause in std::iter::successors(err.source(), |current| current.source()) {
        message.push_str("\n    caused by: ");
        message.push_str(&cause.to_string());
    }
    message
}
41
/// An implementation of `Box<dyn std::error::Error>` that stores the error payload inline,
/// avoiding dynamic memory allocation. This has two practical drawbacks:
///
/// 1. The size of the error payload must be at most `N` bytes.
/// 2. The alignment of the error payload must be at most the alignment of a pointer
///    (8 bytes on 64-bit targets).
///
/// Both of these constraints are verified using post-monomorphization errors. In addition,
/// the payload must be `Send + Sync + 'static` (enforced by the bounds on
/// [`InlineError::new`]).
///
/// # Example
///
/// ```
/// use diskann_quantization::error::InlineError;
///
/// let base_error = u32::try_from(u64::MAX).unwrap_err();
/// let mut error = InlineError::<8>::new(base_error);
/// assert_eq!(error.to_string(), base_error.to_string());
///
/// // Change the dynamic type of the contained error.
/// error = InlineError::new(Box::new(base_error));
/// assert_eq!(error.to_string(), base_error.to_string());
/// ```
#[repr(C)]
pub struct InlineError<const N: usize = 16> {
    // We place the vtable first to enable the niche-optimization.
    //
    // With `repr(C)`, this also puts `object` at offset `size_of::<&ErrorVTable>()`, which is
    // a multiple of the struct's (pointer) alignment, so the payload storage always starts
    // at a pointer-aligned address. `InlineError::new` relies on this for its alignment
    // check.
    vtable: &'static ErrorVTable,

    // NOTE: We need to use `MaybeUninit` instead of `u8` to maintain the provenance of
    // any pointers/references stored in the payload.
    //
    // Additionally, the `UnsafeCell` is needed because the payload may have interior
    // mutability as a side-effect of API calls.
    object: UnsafeCell<[MaybeUninit<u8>; N]>,
}
75
// SAFETY: We only allow error payloads that are `Send`. `InlineError::new` is the only
// constructor in this module (the fields are private), and it requires the payload type
// `T: Send`. The vtable is a shared `&'static` reference to plain function pointers.
unsafe impl<const N: usize> Send for InlineError<N> {}

// SAFETY: We only allow error payloads that are `Sync`. `InlineError::new` requires
// `T: Sync`, and through `&self` the `Display`, `Debug`, and `Error` implementations only
// ever create shared references (`&T`) to the payload, so sharing an `InlineError` across
// threads is equivalent to sharing a `&T`.
unsafe impl<const N: usize> Sync for InlineError<N> {}
81
82impl<const N: usize> InlineError<N> {
83    /// Construct a new `InlineError` around `error`.
84    ///
85    /// Fails to compile if:
86    ///
87    /// 1. `std::mem::align_of::<T>() > 8`: Objects of type `T` must be compatible with the
88    ///    inline storage buffer.
89    /// 2. `std::mem::size_of::<T>() > N`: Objects of type `T` must fit within a buffer of
90    ///    size `N`.
91    pub fn new<T>(error: T) -> Self
92    where
93        T: std::error::Error + Send + Sync + 'static,
94    {
95        const { assert!(std::mem::size_of::<T>() <= N, "error type is too big") };
96        const {
97            assert!(
98                std::mem::align_of::<T>() <= std::mem::align_of::<&'static ErrorVTable>(),
99                "error type has alignment stricter than 8"
100            )
101        };
102
103        let mut this = Self {
104            vtable: &ErrorVTable {
105                debug: error_debug::<T>,
106                display: error_display::<T>,
107                source: error_source::<T>,
108                drop: error_drop::<T>,
109            },
110            object: UnsafeCell::new([MaybeUninit::uninit(); N]),
111        };
112
113        // SAFETY: We have const assertions that the size and alignment of `T` are
114        // compatible with the buffer we created.
115        //
116        // Additionally, the memory we are writing to does not have a valid object stored,
117        // so using `ptr::write` will not leak memory.
118        unsafe { this.object.get_mut().as_mut_ptr().cast::<T>().write(error) };
119
120        this
121    }
122
123    // Return the base pointer of the inline storage in a type that propagates the lifetime
124    // of `self`. This allows the `.source()` implementation to propagate the correct
125    // lifetime.
126    fn ptr_ref(&self) -> Ref<'_> {
127        Ref {
128            ptr: self.object.get().cast::<MaybeUninit<u8>>(),
129            _lifetime: PhantomData,
130        }
131    }
132}
133
134impl<const N: usize> Drop for InlineError<N> {
135    fn drop(&mut self) {
136        // SAFETY: The constructor invariants of `InlineError` ensure that the vtable method
137        // is safe to call.
138        //
139        // Since the only place where the `drop` function is called is in the implementation
140        // of `Drop` for `InlineError`, we are guaranteed that the underlying object is
141        // valid.
142        unsafe { (self.vtable.drop)(self.object.get().cast::<MaybeUninit<u8>>()) }
143    }
144}
145
146impl<const N: usize> std::fmt::Display for InlineError<N> {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        // SAFETY: The constructor invariants of `InlineError` ensure that the vtable method
149        // is safe to call.
150        unsafe { (self.vtable.display)(self.object.get().cast::<MaybeUninit<u8>>(), f) }
151    }
152}
153
154impl<const N: usize> std::fmt::Debug for InlineError<N> {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        write!(f, "InlineError<{}> {{ object: ", N)?;
157        // SAFETY: The constructor invariants of `InlineError` ensure that the vtable method
158        // is safe to call.
159        unsafe { (self.vtable.debug)(self.object.get().cast::<MaybeUninit<u8>>(), f) }?;
160        write!(f, ", vtable: {:?} }}", self.vtable)
161    }
162}
163
164impl<const N: usize> std::error::Error for InlineError<N> {
165    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
166        // SAFETY: The constructor invariants of `InlineError` ensure that the vtable method
167        // is safe to call.
168        unsafe { (self.vtable.source)(self.ptr_ref()) }
169    }
170}
171
/// Type-erased function table for the payload stored in an `InlineError`.
///
/// `InlineError::new` fills each entry with the `error_*` function instantiated for the
/// concrete payload type `T`. Every entry is `unsafe` because it reinterprets the untyped
/// storage pointer as a `T`.
///
/// The derived `Debug` prints the function-pointer values; it is used by the `Debug`
/// implementation of `InlineError`.
#[derive(Debug)]
struct ErrorVTable {
    // Calls `<T as Debug>::fmt` on the payload.
    debug: unsafe fn(*const MaybeUninit<u8>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result,
    // Calls `<T as Display>::fmt` on the payload.
    display: unsafe fn(*const MaybeUninit<u8>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result,
    // Calls `<T as Error>::source` on the payload; the returned reference borrows for the
    // lifetime carried by the `Ref` argument.
    source: unsafe fn(Ref<'_>) -> Option<&(dyn std::error::Error + 'static)>,
    // Drops the payload in place.
    drop: unsafe fn(*mut MaybeUninit<u8>),
}
179
/// Formats the `T` stored at `object` with its `Debug` implementation.
///
/// # Safety
///
/// `object` must point to a valid, properly aligned object of type `T`.
unsafe fn error_debug<T>(
    object: *const MaybeUninit<u8>,
    f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result
where
    T: std::fmt::Debug,
{
    let typed = object.cast::<T>();
    // SAFETY: Required of caller.
    let value: &T = unsafe { &*typed };
    std::fmt::Debug::fmt(value, f)
}
191
/// Formats the `T` stored at `object` with its `Display` implementation.
///
/// # Safety
///
/// `object` must point to a valid, properly aligned object of type `T`.
unsafe fn error_display<T>(
    object: *const MaybeUninit<u8>,
    f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result
where
    T: std::fmt::Display,
{
    let typed = object.cast::<T>();
    // SAFETY: Required of caller.
    let value: &T = unsafe { &*typed };
    std::fmt::Display::fmt(value, f)
}
203
204// SAFETY: A valid instance of type `T` must be stored in `object` beginning at the start
205// of the slice. Note that this implies that `std::mem::size_of::<T>() <= object.len()` and
206// that the start of the slice is properly aligned.
207unsafe fn error_source<T>(object: Ref<'_>) -> Option<&(dyn std::error::Error + 'static)>
208where
209    T: std::error::Error + 'static,
210{
211    // SAFETY: Required of caller.
212    unsafe { &*object.ptr.cast::<T>() }.source()
213}
214
// A pointer with a tagged lifetime.
//
// `InlineError::ptr_ref` creates one from `&self`, and the vtable's `source` entry uses the
// lifetime `'a` to bound the reference it returns, so a `source()` result cannot outlive
// the borrow of the `InlineError` it came from.
struct Ref<'a> {
    // Base pointer of the inline storage; the payload begins at this address.
    ptr: *const MaybeUninit<u8>,
    // Ties this raw pointer to the lifetime `'a` without storing an actual reference.
    _lifetime: PhantomData<&'a MaybeUninit<u8>>,
}
220
/// Drops the `T` stored at `object` in place.
///
/// # Safety
///
/// `object` must point to a valid, properly aligned object of type `T`. As a side effect,
/// the pointed-to object is dropped, so the caller must not use or drop it again.
unsafe fn error_drop<T>(object: *mut MaybeUninit<u8>) {
    let typed: *mut T = object.cast();
    // SAFETY: Required of caller.
    unsafe { typed.drop_in_place() }
}
227
228///////////
229// Tests //
230///////////
231
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    };

    use thiserror::Error;

    use super::*;

    /// Size of the pointer-sized vtable header at the start of every `InlineError`.
    ///
    /// Used instead of a hard-coded `8` so the size assertions also hold on targets whose
    /// pointers are not 8 bytes wide.
    const VTABLE_SIZE: usize = std::mem::size_of::<&'static ()>();

    #[derive(Error, Debug)]
    #[error("error A")]
    struct ErrorA;

    #[derive(Error, Debug)]
    #[error("error B with val {val}")]
    struct ErrorB<Inner: std::error::Error> {
        val: usize,
        #[source]
        source: Inner,
    }

    #[derive(Error, Debug)]
    #[error("error C with message {message}")]
    struct ErrorC<Inner: std::error::Error> {
        message: String,
        /// `thiserror` automatically marks this as the error source.
        source: Inner,
    }

    #[test]
    fn test_formatting() {
        // No Nesting
        let message = format(&ErrorA);
        assert_eq!(message, "error A");

        // One Level of Nesting
        let error = ErrorB {
            val: 10,
            source: ErrorA,
        };

        let expected = "error B with val 10\n    caused by: error A";
        assert_eq!(format(&error), expected);

        // Multiple Levels of Nesting
        let error = ErrorC {
            message: "Hello World".to_string(),
            source: error,
        };
        let expected = "error C with message Hello World\n    \
                        caused by: error B with val 10\n    \
                        caused by: error A";
        assert_eq!(format(&error), expected);
    }

    ///////////
    // Error //
    ///////////

    #[derive(Debug, Error)]
    #[error("zero sized error")]
    struct ZeroSizedError;

    #[derive(Debug, Error)]
    #[error("error with drop: {}", self.0.load(Ordering::Relaxed))]
    struct ErrorWithDrop(Arc<AtomicUsize>);

    impl Drop for ErrorWithDrop {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[derive(Debug, Error)]
    #[error("error with source")]
    struct ErrorWithSource(#[from] ZeroSizedError);

    // This tests (using Miri) that it's safe to contain error types with interior mutability.
    struct ErrorWithInteriorMutability(Mutex<usize>);

    impl std::fmt::Debug for ErrorWithInteriorMutability {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let current = {
                let mut guard = self.0.lock().unwrap();
                let current = *guard;
                *guard += 1;
                current
            };

            write!(f, "{}", current)
        }
    }

    impl std::fmt::Display for ErrorWithInteriorMutability {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let current = {
                let mut guard = self.0.lock().unwrap();
                let current = *guard;
                *guard += 1;
                current
            };

            write!(f, "{}", current)
        }
    }

    impl std::error::Error for ErrorWithInteriorMutability {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            *self.0.lock().unwrap() += 1;
            None
        }
    }

    #[test]
    fn sizes_and_offsets() {
        let ref_size = std::mem::size_of::<&'static ()>();
        let ref_align = std::mem::align_of::<&'static ()>();

        assert_eq!(std::mem::offset_of!(InlineError<0>, object), ref_size);
        assert_eq!(std::mem::offset_of!(InlineError<8>, object), ref_size);
        assert_eq!(std::mem::offset_of!(InlineError<16>, object), ref_size);

        assert_eq!(std::mem::size_of::<InlineError<0>>(), ref_size);
        assert_eq!(std::mem::size_of::<Option<InlineError<0>>>(), ref_size);
        assert_eq!(std::mem::align_of::<InlineError<0>>(), ref_align);
        assert_eq!(std::mem::align_of::<Option<InlineError<0>>>(), ref_align);

        assert_eq!(std::mem::size_of::<InlineError<8>>(), ref_size + 8);
        assert_eq!(std::mem::size_of::<Option<InlineError<8>>>(), ref_size + 8);
        assert_eq!(std::mem::align_of::<InlineError<8>>(), ref_align);
        assert_eq!(std::mem::align_of::<Option<InlineError<8>>>(), ref_align);

        assert_eq!(std::mem::size_of::<InlineError<16>>(), ref_size + 16);
        assert_eq!(
            std::mem::size_of::<Option<InlineError<16>>>(),
            ref_size + 16
        );
        assert_eq!(std::mem::align_of::<InlineError<16>>(), ref_align);
        assert_eq!(std::mem::align_of::<Option<InlineError<16>>>(), ref_align);
    }

    #[test]
    fn inline_error_zst() {
        use std::error::Error;

        let error = InlineError::<0>::new(ZeroSizedError);
        assert_eq!(
            std::mem::size_of_val(&error),
            VTABLE_SIZE,
            "expected 0 bytes for the payload and a pointer-sized vtable"
        );
        assert_eq!(error.to_string(), "zero sized error");

        let debug = format!("{:?}", error);
        assert!(
            debug.starts_with(&format!("InlineError<0> {{ object: {:?}", ZeroSizedError)),
            "debug message: {}",
            debug
        );

        assert!(error.source().is_none());

        // Move it into a box. This is mainly a Miri test.
        let _ = Box::new(error);
    }

    #[test]
    fn inline_error_with_drop() {
        use std::error::Error;

        let count = Arc::new(AtomicUsize::new(10));
        let mut error = InlineError::<8>::new(ErrorWithDrop(count.clone()));
        assert_eq!(
            std::mem::size_of_val(&error),
            VTABLE_SIZE + 8,
            "expected 8 bytes for the payload plus a pointer-sized vtable"
        );
        assert_eq!(error.to_string(), "error with drop: 10");
        assert!(error.source().is_none());

        // Overwrite the error with a different payload type. The assignment must drop the
        // old `ErrorWithDrop` payload, which bumps the counter.
        error = InlineError::new(ZeroSizedError);
        assert_eq!(error.to_string(), "zero sized error");

        assert_eq!(count.load(Ordering::Relaxed), 11, "failed to run \"drop\"");
    }

    #[test]
    fn inline_error_drops_on_scope_exit() {
        let count = Arc::new(AtomicUsize::new(0));
        {
            let _error = InlineError::<8>::new(ErrorWithDrop(count.clone()));
            assert_eq!(count.load(Ordering::Relaxed), 0);
        }
        assert_eq!(count.load(Ordering::Relaxed), 1, "failed to run \"drop\"");
        assert_eq!(Arc::strong_count(&count), 1, "payload leaked its `Arc`");
    }

    #[test]
    fn inline_error_with_interior_mutability() {
        use std::error::Error;

        let error = InlineError::<16>::new(ErrorWithInteriorMutability(Mutex::new(0)));
        assert_eq!(
            std::mem::size_of_val(&error),
            VTABLE_SIZE + 16,
            "expected 16 bytes for the payload plus a pointer-sized vtable"
        );
        assert_eq!(error.to_string(), "0");
        let debug = format!("{:?}", error);
        assert!(debug.contains("object: 1"), "got {}", debug);
        assert_eq!(error.to_string(), "2");

        let debug = format!("{:?}", error);
        assert!(debug.contains("object: 3"), "got {}", debug);

        assert!(error.source().is_none());
        assert_eq!(error.to_string(), "5");
    }

    #[test]
    fn inline_error_with_source() {
        use std::error::Error;

        let error = InlineError::<8>::new(ErrorWithSource(ZeroSizedError));
        assert_eq!(
            std::mem::size_of_val(&error),
            VTABLE_SIZE + 8,
            "expected 8 bytes for the payload plus a pointer-sized vtable"
        );
        assert_eq!(error.to_string(), "error with source");
        assert_eq!(error.source().unwrap().to_string(), "zero sized error");

        // Move it into a box. This is mainly a Miri test.
        let _ = Box::new(error);
    }

    #[test]
    fn inline_error_format_chain() {
        let error = InlineError::<8>::new(ErrorWithSource(ZeroSizedError));
        assert_eq!(
            format(&error),
            "error with source\n    caused by: zero sized error"
        );
    }
}
459}