Skip to main content

diskann_quantization/
error.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Formatting utilities for error chains.
7
8use std::{cell::UnsafeCell, marker::PhantomData, mem::MaybeUninit};
9
/// An uninhabited error type: no value of `Infallible` can ever be constructed.
///
/// It can serve as the error type of operations that never fail (for example
/// `Result<T, Infallible>`) while still satisfying `std::error::Error` bounds.
#[derive(Debug)]
pub enum Infallible {}

impl std::fmt::Display for Infallible {
    fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Infallible` has no variants, so this empty match is exhaustive: the compiler
        // proves this method can never run instead of relying on a runtime panic.
        match *self {}
    }
}

impl std::error::Error for Infallible {}
20
/// Format the entire error chain for `err` by first calling `err.to_string()` and then
/// by walking the error's
/// [source tree](https://doc.rust-lang.org/std/error/trait.Error.html#method.source).
///
/// The returned string starts with the `Display` output of `err`. Each error reached by
/// repeatedly calling `source()` is then appended on its own line, prefixed with
/// `"    caused by: "`, in the order the chain is walked.
///
/// # Panics
///
/// Panics if a `Display` implementation in the chain returns an error, matching the
/// behavior of `ToString::to_string`.
pub fn format<E>(err: &E) -> String
where
    E: std::error::Error + ?Sized,
{
    use std::fmt::Write as _;

    // Get the base message from the error.
    let mut message = err.to_string();

    // Walk the source chain. Destructuring `&cause` yields the inner reference, so the
    // next `source()` borrow lives as long as `err` rather than the closure argument.
    for cause in std::iter::successors(err.source(), |&cause| cause.source()) {
        // Formatting straight into `message` avoids allocating a temporary `String` per
        // cause.
        write!(message, "\n    caused by: {}", cause)
            .expect("a Display implementation returned an error unexpectedly");
    }
    message
}
52
/// An implementation of `Box<dyn std::error::Error>` that stores the error payload inline,
/// avoiding dynamic memory allocation. This has several practical drawbacks:
///
/// 1. The size of the error payload must be at most `N` bytes.
/// 2. The alignment of the error payload must be at most the alignment of a pointer
///    (8 bytes on 64-bit targets).
///
/// Both of these constraints are verified using post-monomorphization errors.
///
/// # Example
///
/// ```
/// use diskann_quantization::error::InlineError;
///
/// let base_error = u32::try_from(u64::MAX).unwrap_err();
/// let mut error = InlineError::<8>::new(base_error);
/// assert_eq!(error.to_string(), base_error.to_string());
///
/// // Change the dynamic type of the contained error.
/// error = InlineError::new(Box::new(base_error));
/// assert_eq!(error.to_string(), base_error.to_string());
/// ```
#[repr(C)]
pub struct InlineError<const N: usize = 16> {
    // Type-erased behavior (debug/display/source/drop) of the payload.
    //
    // `repr(C)` keeps this pointer-sized field first, so `object` starts at a
    // pointer-aligned offset. Because a `&'static` reference is never null, it also
    // provides the niche that lets `Option<InlineError<N>>` be the same size as
    // `InlineError<N>`.
    vtable: &'static ErrorVTable,

    // Inline storage for the payload: `InlineError::new` writes the error value at the
    // start of this buffer, and the `vtable` functions reinterpret it as that type.
    //
    // NOTE: We need to use `MaybeUninit` instead of `u8` to maintain the provenance of
    // any pointers/references stored in the payload.
    //
    // Additionally, the `UnsafeCell` is needed because the payload may have interior
    // mutability as a side-effect of API calls.
    object: UnsafeCell<[MaybeUninit<u8>; N]>,
}
86
// SAFETY: The payload is only ever written by `InlineError::new`, which requires
// `T: Send`; moving an `InlineError` to another thread therefore only moves a `Send`
// value.
unsafe impl<const N: usize> Send for InlineError<N> {}

// SAFETY: `InlineError::new` requires `T: Sync`, and a shared `&InlineError` only hands
// the payload out as a shared reference (`&T` inside the vtable's `debug`, `display`,
// and `source` functions), so concurrent shared access is exactly as safe as sharing `&T`.
unsafe impl<const N: usize> Sync for InlineError<N> {}
92
93impl<const N: usize> InlineError<N> {
94    /// Construct a new `InlineError` around `error`.
95    ///
96    /// Fails to compile if:
97    ///
98    /// 1. `std::mem::align_of::<T>() > 8`: Objects of type `T` must be compatible with the
99    ///    inline storage buffer.
100    /// 2. `std::mem::size_of::<T>() > N`: Objects of type `T` must fit within a buffer of
101    ///    size `N`.
102    pub fn new<T>(error: T) -> Self
103    where
104        T: std::error::Error + Send + Sync + 'static,
105    {
106        const { assert!(std::mem::size_of::<T>() <= N, "error type is too big") };
107        const {
108            assert!(
109                std::mem::align_of::<T>() <= std::mem::align_of::<&'static ErrorVTable>(),
110                "error type has alignment stricter than 8"
111            )
112        };
113
114        let mut this = Self {
115            vtable: &ErrorVTable {
116                debug: error_debug::<T>,
117                display: error_display::<T>,
118                source: error_source::<T>,
119                drop: error_drop::<T>,
120            },
121            object: UnsafeCell::new([MaybeUninit::uninit(); N]),
122        };
123
124        // SAFETY: We have const assertions that the size and alignment of `T` are
125        // compatible with the buffer we created.
126        //
127        // Additionally, the memory we are writing to does not have a valid object stored,
128        // so using `ptr::write` will not leak memory.
129        unsafe { this.object.get_mut().as_mut_ptr().cast::<T>().write(error) };
130
131        this
132    }
133
134    // Return the base pointer of the inline storage in a type that propagates the lifetime
135    // of `self`. This allows the `.source()` implementation to propagate the correct
136    // lifetime.
137    fn ptr_ref(&self) -> Ref<'_> {
138        Ref {
139            ptr: self.object.get().cast::<MaybeUninit<u8>>(),
140            _lifetime: PhantomData,
141        }
142    }
143}
144
145impl<const N: usize> Drop for InlineError<N> {
146    fn drop(&mut self) {
147        // SAFETY: The constructor invariants of `InlineError` ensure that the vtable method
148        // is safe to call.
149        //
150        // Since the only place where the `drop` function is called is in the implementation
151        // of `Drop` for `InlineError`, we are guaranteed that the underlying object is
152        // valid.
153        unsafe { (self.vtable.drop)(self.object.get().cast::<MaybeUninit<u8>>()) }
154    }
155}
156
157impl<const N: usize> std::fmt::Display for InlineError<N> {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        // SAFETY: The constructor invariants of `InlineError` ensure that the vtable method
160        // is safe to call.
161        unsafe { (self.vtable.display)(self.object.get().cast::<MaybeUninit<u8>>(), f) }
162    }
163}
164
165impl<const N: usize> std::fmt::Debug for InlineError<N> {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        write!(f, "InlineError<{}> {{ object: ", N)?;
168        // SAFETY: The constructor invariants of `InlineError` ensure that the vtable method
169        // is safe to call.
170        unsafe { (self.vtable.debug)(self.object.get().cast::<MaybeUninit<u8>>(), f) }?;
171        write!(f, ", vtable: {:?} }}", self.vtable)
172    }
173}
174
175impl<const N: usize> std::error::Error for InlineError<N> {
176    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
177        // SAFETY: The constructor invariants of `InlineError` ensure that the vtable method
178        // is safe to call.
179        unsafe { (self.vtable.source)(self.ptr_ref()) }
180    }
181}
182
/// Hand-written vtable for the type-erased payload of an `InlineError`.
///
/// `InlineError::new::<T>` fills every entry with the matching function instantiated for
/// the concrete error type `T`. Every entry is `unsafe` to call: the pointer it receives
/// must point to a live `T` at the start of the inline buffer.
#[derive(Debug)]
struct ErrorVTable {
    /// Forwards to `<T as Debug>::fmt`.
    debug: unsafe fn(*const MaybeUninit<u8>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result,
    /// Forwards to `<T as Display>::fmt`.
    display: unsafe fn(*const MaybeUninit<u8>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result,
    /// Forwards to `<T as Error>::source`; the result borrows for the lifetime carried by
    /// the `Ref` argument.
    source: unsafe fn(Ref<'_>) -> Option<&(dyn std::error::Error + 'static)>,
    /// Runs `T`'s destructor in place.
    drop: unsafe fn(*mut MaybeUninit<u8>),
}
190
/// Type-erased `Debug` entry point stored in `ErrorVTable::debug`.
///
/// # Safety
///
/// `object` must point to a valid, properly aligned object of type `T` that stays alive
/// for the duration of the call.
unsafe fn error_debug<T>(
    object: *const MaybeUninit<u8>,
    f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result
where
    T: std::fmt::Debug,
{
    let typed = object.cast::<T>();
    // SAFETY: The caller guarantees that `typed` points to a live, aligned `T`.
    let value: &T = unsafe { &*typed };
    std::fmt::Debug::fmt(value, f)
}
202
/// Type-erased `Display` entry point stored in `ErrorVTable::display`.
///
/// # Safety
///
/// `object` must point to a valid, properly aligned object of type `T` that stays alive
/// for the duration of the call.
unsafe fn error_display<T>(
    object: *const MaybeUninit<u8>,
    f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result
where
    T: std::fmt::Display,
{
    let typed = object.cast::<T>();
    // SAFETY: The caller guarantees that `typed` points to a live, aligned `T`.
    let value: &T = unsafe { &*typed };
    std::fmt::Display::fmt(value, f)
}
214
215// SAFETY: A valid instance of type `T` must be stored in `object` beginning at the start
216// of the slice. Note that this implies that `std::mem::size_of::<T>() <= object.len()` and
217// that the start of the slice is properly aligned.
218unsafe fn error_source<T>(object: Ref<'_>) -> Option<&(dyn std::error::Error + 'static)>
219where
220    T: std::error::Error + 'static,
221{
222    // SAFETY: Required of caller.
223    unsafe { &*object.ptr.cast::<T>() }.source()
224}
225
/// A raw pointer to the start of an `InlineError`'s inline buffer, tagged with the
/// lifetime `'a` of the borrow it was derived from.
///
/// Passing a `Ref<'a>` instead of a bare pointer to `ErrorVTable::source` ties the
/// returned `&(dyn Error + 'static)` to `'a` through lifetime elision.
struct Ref<'a> {
    ptr: *const MaybeUninit<u8>,
    _lifetime: PhantomData<&'a MaybeUninit<u8>>,
}
231
/// Type-erased destructor stored in `ErrorVTable::drop`.
///
/// # Safety
///
/// `object` must point to a valid, properly aligned object of type `T`. That object is
/// dropped in place, so it must not be used or dropped again afterwards.
unsafe fn error_drop<T>(object: *mut MaybeUninit<u8>) {
    let typed: *mut T = object.cast();
    // SAFETY: The caller guarantees `typed` points to a live `T` that is never used again.
    unsafe { std::ptr::drop_in_place(typed) }
}
238
239///////////
240// Tests //
241///////////
242
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicUsize, Ordering},
    };

    use thiserror::Error;

    use super::*;

    /// Size of a reference, and therefore of the `vtable` field of `InlineError`.
    ///
    /// Layout assertions use this rather than a literal `8` so that they also hold on
    /// 32-bit targets.
    const REF_SIZE: usize = std::mem::size_of::<&'static ()>();

    #[derive(Error, Debug)]
    #[error("error A")]
    struct ErrorA;

    #[derive(Error, Debug)]
    #[error("error B with val {val}")]
    struct ErrorB<Inner: std::error::Error> {
        val: usize,
        #[source]
        source: Inner,
    }

    #[derive(Error, Debug)]
    #[error("error C with message {message}")]
    struct ErrorC<Inner: std::error::Error> {
        message: String,
        /// `thiserror` automatically marks this as the error source.
        source: Inner,
    }

    #[test]
    fn test_formatting() {
        // No Nesting
        let message = format(&ErrorA);
        assert_eq!(message, "error A");

        // One Level of Nesting
        let error = ErrorB {
            val: 10,
            source: ErrorA,
        };

        let expected = "error B with val 10\n    caused by: error A";
        assert_eq!(format(&error), expected);

        // Multiple Levels of Nesting
        let error = ErrorC {
            message: "Hello World".to_string(),
            source: error,
        };
        let expected = "error C with message Hello World\n    \
                        caused by: error B with val 10\n    \
                        caused by: error A";
        assert_eq!(format(&error), expected);
    }

    ///////////
    // Error //
    ///////////

    #[derive(Debug, Error)]
    #[error("zero sized error")]
    struct ZeroSizedError;

    #[derive(Debug, Error)]
    #[error("error with drop: {}", self.0.load(Ordering::Relaxed))]
    struct ErrorWithDrop(Arc<AtomicUsize>);

    impl Drop for ErrorWithDrop {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[derive(Debug, Error)]
    #[error("error with source")]
    struct ErrorWithSource(#[from] ZeroSizedError);

    // This tests (using Miri) that it's safe to contain error types with interior mutability.
    //
    // Every `Debug`/`Display` call and every call to `source` bumps the counter through
    // `&self`.
    struct ErrorWithInteriorMutability(Mutex<usize>);

    impl ErrorWithInteriorMutability {
        /// Return the current counter value and increment it, mutating through `&self`.
        fn bump(&self) -> usize {
            let mut guard = self.0.lock().unwrap();
            let current = *guard;
            *guard += 1;
            current
        }
    }

    impl std::fmt::Debug for ErrorWithInteriorMutability {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.bump())
        }
    }

    impl std::fmt::Display for ErrorWithInteriorMutability {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.bump())
        }
    }

    impl std::error::Error for ErrorWithInteriorMutability {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            self.bump();
            None
        }
    }

    #[test]
    fn sizes_and_offsets() {
        let ref_align = std::mem::align_of::<&'static ()>();

        assert_eq!(std::mem::offset_of!(InlineError<0>, object), REF_SIZE);
        assert_eq!(std::mem::offset_of!(InlineError<8>, object), REF_SIZE);
        assert_eq!(std::mem::offset_of!(InlineError<16>, object), REF_SIZE);

        assert_eq!(std::mem::size_of::<InlineError<0>>(), REF_SIZE);
        assert_eq!(std::mem::size_of::<Option<InlineError<0>>>(), REF_SIZE);
        assert_eq!(std::mem::align_of::<InlineError<0>>(), ref_align);
        assert_eq!(std::mem::align_of::<Option<InlineError<0>>>(), ref_align);

        assert_eq!(std::mem::size_of::<InlineError<8>>(), REF_SIZE + 8);
        assert_eq!(std::mem::size_of::<Option<InlineError<8>>>(), REF_SIZE + 8);
        assert_eq!(std::mem::align_of::<InlineError<8>>(), ref_align);
        assert_eq!(std::mem::align_of::<Option<InlineError<8>>>(), ref_align);

        assert_eq!(std::mem::size_of::<InlineError<16>>(), REF_SIZE + 16);
        assert_eq!(
            std::mem::size_of::<Option<InlineError<16>>>(),
            REF_SIZE + 16
        );
        assert_eq!(std::mem::align_of::<InlineError<16>>(), ref_align);
        assert_eq!(std::mem::align_of::<Option<InlineError<16>>>(), ref_align);
    }

    #[test]
    fn inline_error_zst() {
        use std::error::Error;

        let error = InlineError::<0>::new(ZeroSizedError);
        assert_eq!(
            std::mem::size_of_val(&error),
            REF_SIZE,
            "expected a pointer-sized vtable and a 0-byte payload"
        );
        assert_eq!(error.to_string(), "zero sized error");

        let debug = format!("{:?}", error);
        assert!(
            debug.starts_with(&format!("InlineError<0> {{ object: {:?}", ZeroSizedError)),
            "debug message: {}",
            debug
        );

        assert!(error.source().is_none());

        // Move it into a box. This is mainly a Miri test.
        let _ = Box::new(error);
    }

    #[test]
    fn inline_error_with_drop() {
        use std::error::Error;

        let count = Arc::new(AtomicUsize::new(10));
        let mut error = InlineError::<8>::new(ErrorWithDrop(count.clone()));
        assert_eq!(
            std::mem::size_of_val(&error),
            REF_SIZE + 8,
            "expected 8 bytes for the payload and a pointer-sized vtable"
        );
        assert_eq!(error.to_string(), "error with drop: 10");
        assert!(error.source().is_none());
        assert_eq!(count.load(Ordering::Relaxed), 10, "payload dropped too early");

        // Overwrite the error with a payload of a different type; the old payload's
        // destructor must run through the vtable.
        error = InlineError::new(ZeroSizedError);
        assert_eq!(error.to_string(), "zero sized error");

        assert_eq!(count.load(Ordering::Relaxed), 11, "failed to run \"drop\"");

        // Dropping the replacement must not run the old payload's destructor again.
        drop(error);
        assert_eq!(count.load(Ordering::Relaxed), 11, "payload dropped twice");
    }

    #[test]
    fn inline_error_with_interior_mutability() {
        use std::error::Error;

        let error = InlineError::<16>::new(ErrorWithInteriorMutability(Mutex::new(0)));
        assert_eq!(
            std::mem::size_of_val(&error),
            REF_SIZE + 16,
            "expected 16 bytes for the payload and a pointer-sized vtable"
        );
        assert_eq!(error.to_string(), "0");
        let debug = format!("{:?}", error);
        assert!(debug.contains("object: 1"), "got {}", debug);
        assert_eq!(error.to_string(), "2");

        let debug = format!("{:?}", error);
        assert!(debug.contains("object: 3"), "got {}", debug);

        assert!(error.source().is_none());
        assert_eq!(error.to_string(), "5");
    }

    #[test]
    fn inline_error_with_source() {
        use std::error::Error;

        let error = InlineError::<8>::new(ErrorWithSource(ZeroSizedError));
        assert_eq!(
            std::mem::size_of_val(&error),
            REF_SIZE + 8,
            "expected 8 bytes for the payload and a pointer-sized vtable"
        );
        assert_eq!(error.to_string(), "error with source");
        assert_eq!(error.source().unwrap().to_string(), "zero sized error");

        // Move it into a box. This is mainly a Miri test.
        let _ = Box::new(error);
    }
}
466
467        // Move it into a box. This is mainly a Miri tests.
468        let _ = Box::new(error);
469    }
470}