saa 5.1.1

Word-sized low-level synchronization primitives providing both asynchronous and synchronous interfaces.
Documentation
//! Define base operations for synchronization primitives.

use std::pin::{Pin, pin};
use std::ptr::{addr_of, null, with_exposed_provenance};
#[cfg(not(feature = "loom"))]
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
use std::thread;

#[cfg(feature = "loom")]
use loom::sync::atomic::AtomicUsize;

use crate::opcode::Opcode;
use crate::wait_queue::{Entry, WaitQueue};

/// Defines base operations for synchronization primitives.
///
/// Implementors expose a single atomic word (see [`SyncPrimitive::state`]). This trait treats it
/// as three fields, selected with `WaitQueue`'s masks:
///
/// * `WaitQueue::ADDR_MASK` bits: the address of the anchor of the most recently pushed wait
///   queue entry (the tail), or zero when the wait queue is empty;
/// * `WaitQueue::LOCKED_FLAG`: set while a thread exclusively processes or edits the wait queue;
/// * `WaitQueue::DATA_MASK` bits: the primitive-specific ownership count.
pub(crate) trait SyncPrimitive: Sized {
    /// Returns a reference to the state.
    fn state(&self) -> &AtomicUsize;

    /// Returns the maximum number of shared owners.
    fn max_shared_owners() -> usize;

    /// Called when an enqueued wait queue entry is being dropped without acknowledging the result.
    fn drop_wait_queue_entry(entry: &Entry);

    /// Converts a reference to `Self` into a memory address.
    ///
    /// The provenance of the pointer is exposed, so the address may later be turned back into a
    /// pointer with `with_exposed_provenance`.
    #[inline]
    fn addr(&self) -> usize {
        let self_ptr: *const Self = addr_of!(*self);
        self_ptr.expose_provenance()
    }

    /// Tries to push the entry of `wait_queue` into the wait queue.
    ///
    /// `state` is the expected current value of [`Self::state`]. The new entry is first linked
    /// to the tail encoded in `state`. A single CAS then makes the state point at the new entry's
    /// anchor.
    ///
    /// Returns `None` if the push succeeded; in that case the entry has been marked pollable and
    /// `begin_wait` has been called. Returns `Some(begin_wait)` if the state changed concurrently,
    /// so the caller can retry with a fresh state.
    #[must_use]
    fn try_push_wait_queue_entry<F: FnOnce()>(
        &self,
        wait_queue: Pin<&WaitQueue>,
        state: usize,
        begin_wait: F,
    ) -> Option<F> {
        let anchor_ptr = wait_queue.anchor_ptr().0;
        let anchor_addr = anchor_ptr.expose_provenance();
        // The anchor address must fit entirely within the address bits of the state.
        debug_assert_eq!(anchor_addr & (!WaitQueue::ADDR_MASK), 0);

        // Link the new entry to the current tail before publishing it.
        let tail_anchor_ptr = WaitQueue::to_anchor_ptr(state);
        wait_queue
            .entry()
            .update_next_entry_anchor_ptr(tail_anchor_ptr);

        // The anchor pointer, instead of an entry pointer, is stored in the state.
        let next_state = (state & (!WaitQueue::ADDR_MASK)) | anchor_addr;
        if self
            .state()
            .compare_exchange(state, next_state, AcqRel, Acquire)
            .is_ok()
        {
            // The entry cannot be dropped until the result is acknowledged.
            wait_queue.entry().set_pollable();
            begin_wait();
            None
        } else {
            Some(begin_wait)
        }
    }

    /// Waits for the desired resource synchronously.
    ///
    /// Builds a pinned [`WaitQueue`] on the stack and tries to enqueue it against `state`. Once
    /// enqueued, it waits in `poll_result_sync` for the entry's result.
    ///
    /// Returns the entry's result. Returns `Err(begin_wait)` if the state changed before the
    /// entry could be enqueued; in that case nothing was waited for.
    fn wait_resources_sync<F: FnOnce()>(
        &self,
        state: usize,
        opcode: Opcode,
        begin_wait: F,
    ) -> Result<u8, F> {
        // Waiting is only expected while the wait queue is non-empty or some ownership is held.
        debug_assert!(state & WaitQueue::ADDR_MASK != 0 || state & WaitQueue::DATA_MASK != 0);

        let pinned_wait_queue = pin!(WaitQueue::default());
        pinned_wait_queue.as_ref().construct(self, opcode, true);
        if let Some(returned) =
            self.try_push_wait_queue_entry(pinned_wait_queue.as_ref(), state, begin_wait)
        {
            return Err(returned);
        }
        Ok(pinned_wait_queue.entry().poll_result_sync())
    }

    /// Releases the resource represented by the supplied operation mode.
    ///
    /// `state` is the caller's snapshot of [`Self::state`]. It is refreshed after every failed
    /// CAS.
    ///
    /// * If the wait queue is empty, or another thread already holds the wait queue lock, the
    ///   acquired count is subtracted in place.
    /// * Otherwise, this thread subtracts the count and takes the wait queue lock in the same
    ///   CAS, then hands the freed resources to waiters via `process_wait_queue`.
    ///
    /// Returns `false` if the resource cannot be released.
    fn release_loop(&self, mut state: usize, opcode: Opcode) -> bool {
        while opcode.can_release(state) {
            if state & WaitQueue::ADDR_MASK == 0
                || state & WaitQueue::LOCKED_FLAG == WaitQueue::LOCKED_FLAG
            {
                // Release the resource in-place.
                match self.state().compare_exchange(
                    state,
                    state - opcode.acquired_count(),
                    Release,
                    Relaxed,
                ) {
                    Ok(_) => return true,
                    Err(new_state) => state = new_state,
                }
            } else {
                // The wait queue is not empty and is not being processed.
                let next_state = (state | WaitQueue::LOCKED_FLAG) - opcode.acquired_count();
                if let Err(new_state) = self
                    .state()
                    .compare_exchange(state, next_state, AcqRel, Relaxed)
                {
                    state = new_state;
                    continue;
                }
                self.process_wait_queue(next_state);
                return true;
            }
        }
        false
    }

    /// Processes the wait queue and hands ownership over to waiting entries.
    ///
    /// `state` must have `WaitQueue::LOCKED_FLAG` set, meaning the caller holds the wait queue
    /// lock. The lock is released before this method returns.
    ///
    /// Entries are visited from the head (oldest) towards the tail (newest). An entry inherits
    /// ownership when nothing is owned, or when its desired count still fits within
    /// [`Self::max_shared_owners`]. The first entry that does not fit stops the scan. Resolved
    /// entries are unlinked from the list and receive result `0`. If the tail itself is
    /// resolved, the wait queue is reset to empty.
    ///
    /// The tail entry of the wait queue is either reset or stays the same.
    fn process_wait_queue(&self, mut state: usize) {
        let mut head_entry_ptr: *const Entry = null();
        let mut unlocked = false;
        while !unlocked {
            debug_assert_eq!(state & WaitQueue::LOCKED_FLAG, WaitQueue::LOCKED_FLAG);

            let anchor_ptr = WaitQueue::to_anchor_ptr(state);
            let tail_entry_ptr = WaitQueue::to_entry_ptr(anchor_ptr);
            if head_entry_ptr.is_null() {
                // First pass: walk from the tail to find the head (oldest entry).
                Entry::iter_forward(tail_entry_ptr, true, |entry, next_entry| {
                    head_entry_ptr = Entry::ref_to_ptr(entry);
                    next_entry.is_none()
                });
            } else {
                // Retry: the head is already known; only newly pushed entries need linking.
                Entry::set_prev_ptr(tail_entry_ptr);
            }

            let data = state & WaitQueue::DATA_MASK;
            let mut transferred = 0;
            let mut resolved_entry_ptr: *const Entry = null();
            let mut reset_failed = false;

            Entry::iter_backward(head_entry_ptr, |entry, prev_entry| {
                let desired = entry.opcode().desired_count();
                if data + transferred == 0
                    || data + transferred + desired <= Self::max_shared_owners()
                {
                    // The entry can inherit ownership.
                    let acquired = entry.opcode().acquired_count();
                    debug_assert!(acquired <= desired);
                    if prev_entry.is_some() {
                        transferred += acquired;
                        resolved_entry_ptr = Entry::ref_to_ptr(entry);
                        false
                    } else {
                        // This is the tail of the wait queue: try to reset.
                        debug_assert_eq!(tail_entry_ptr, addr_of!(*entry));
                        if self
                            .state()
                            .compare_exchange(state, data + transferred + acquired, AcqRel, Acquire)
                            .is_err()
                        {
                            // This entry will be processed on the next retry.
                            entry.update_next_entry_anchor_ptr(null());
                            head_entry_ptr = Entry::ref_to_ptr(entry);
                            reset_failed = true;
                            return true;
                        }

                        // The wait queue was reset.
                        unlocked = true;
                        resolved_entry_ptr = Entry::ref_to_ptr(entry);
                        true
                    }
                } else {
                    // Unlink those that have succeeded in acquiring shared ownership.
                    entry.update_next_entry_anchor_ptr(null());
                    head_entry_ptr = Entry::ref_to_ptr(entry);
                    true
                }
            });
            debug_assert!(!reset_failed || !unlocked);

            if !reset_failed && !unlocked {
                // Publish the transferred ownership and unlock in one step. This only happens if
                // the owner count has not changed since `state` was read. Otherwise, fall through
                // to `fetch_add` below and retry with the wait queue still locked.
                unlocked = self
                    .state()
                    .fetch_update(AcqRel, Acquire, |new_state| {
                        let new_data = new_state & WaitQueue::DATA_MASK;
                        debug_assert!(new_data <= data);
                        debug_assert!(new_data + transferred <= WaitQueue::DATA_MASK);

                        if new_data == data {
                            Some((new_state & WaitQueue::ADDR_MASK) | (new_data + transferred))
                        } else {
                            None
                        }
                    })
                    .is_ok();
            }

            if !unlocked {
                // Keep the lock, account for the resolved entries, and process again.
                state = self.state().fetch_add(transferred, AcqRel) + transferred;
            }

            // Ownership has been published; wake up every resolved (now unlinked) entry.
            Entry::iter_forward(resolved_entry_ptr, false, |entry, _next_entry| {
                entry.set_result(0);
                false
            });
        }
    }

    /// Removes a wait queue entry from the wait queue.
    ///
    /// `state` must have `WaitQueue::LOCKED_FLAG` set and a non-empty wait queue. The caller
    /// holds the wait queue lock.
    ///
    /// Returns the resulting state and whether `entry_ptr_to_remove` was found and unlinked.
    /// The returned state still has `WaitQueue::LOCKED_FLAG` set, unless removing the entry
    /// emptied the wait queue; in that case the queue has been reset and unlocked. If the
    /// returned state is still locked, the caller must release the lock, e.g. via
    /// `process_wait_queue`.
    fn remove_wait_queue_entry(
        &self,
        mut state: usize,
        entry_ptr_to_remove: *const Entry,
    ) -> (usize, bool) {
        loop {
            debug_assert_eq!(state & WaitQueue::LOCKED_FLAG, WaitQueue::LOCKED_FLAG);
            debug_assert_ne!(state & WaitQueue::ADDR_MASK, 0);

            // Re-initialized on every attempt, so a failed CAS from an earlier attempt is never
            // mistaken for the outcome of this one. The default means "not found".
            let mut result: Result<(usize, bool), usize> = Ok((state, false));

            let anchor_ptr = WaitQueue::to_anchor_ptr(state);
            let tail_entry_ptr = WaitQueue::to_entry_ptr(anchor_ptr);
            Entry::iter_forward(tail_entry_ptr, true, |entry, next_entry| {
                if Entry::ref_to_ptr(entry) == entry_ptr_to_remove {
                    // Found the entry to remove.
                    let prev_entry_ptr = entry.prev_entry_ptr();
                    if let Some(next_entry) = next_entry {
                        next_entry.update_prev_entry_ptr(prev_entry_ptr);
                    }
                    // SAFETY: the wait queue lock is held (asserted above). An enqueued entry
                    // cannot be dropped until its result is acknowledged, and removing it early
                    // goes through `force_remove_wait_queue_entry`, which needs this same lock.
                    // Entries that are still linked here are therefore alive.
                    result = if let Some(prev_entry) = unsafe { prev_entry_ptr.as_ref() } {
                        // Successfully unlinked the target entry without updating the state.
                        prev_entry.update_next_entry_anchor_ptr(entry.next_entry_anchor_ptr());
                        Ok((state, true))
                    } else if let Some(next_entry) = next_entry {
                        // The next entry becomes the new tail of the wait queue.
                        let next_entry_addr = Entry::ref_to_ptr(next_entry).expose_provenance();
                        let next_entry_ptr = with_exposed_provenance(next_entry_addr);
                        let new_tail_ptr = Entry::to_wait_queue_ptr(next_entry_ptr);
                        // SAFETY: `next_entry` is a live, linked entry (see above), so the
                        // `WaitQueue` derived from it is alive as well. The pointer is rebuilt
                        // from the exposed address, presumably so that it is not limited to the
                        // provenance of the `Entry` reference.
                        let new_anchor_ptr = unsafe { (*new_tail_ptr).anchor_ptr().0 };
                        debug_assert_eq!(new_anchor_ptr.addr() & (!WaitQueue::ADDR_MASK), 0);

                        let next_state =
                            (state & (!WaitQueue::ADDR_MASK)) | new_anchor_ptr.expose_provenance();
                        debug_assert_eq!(
                            next_state & WaitQueue::LOCKED_FLAG,
                            WaitQueue::LOCKED_FLAG
                        );

                        self.state()
                            .compare_exchange(state, next_state, AcqRel, Acquire)
                            .map(|_| (next_state, true))
                    } else {
                        // Reset the wait queue and unlock.
                        let next_state = state & WaitQueue::DATA_MASK;
                        self.state()
                            .compare_exchange(state, next_state, AcqRel, Acquire)
                            .map(|_| (next_state, true))
                    };
                    true
                } else {
                    false
                }
            });

            match result {
                Ok(outcome) => return outcome,
                // A new entry was pushed or the ownership data changed: retry with the new state.
                Err(new_state) => state = new_state,
            }
        }
    }

    /// Removes a [`WaitQueue`] entry that was pushed into the wait queue but has not been
    /// processed.
    ///
    /// If the entry turns out to have been processed already, this waits until its result is
    /// finalized and then releases the resources that were handed to it.
    fn force_remove_wait_queue_entry(entry: &Entry) {
        let this: &Self = entry.sync_primitive_ref();
        let this_ptr: *const Entry = addr_of!(*entry);

        // Remove the wait queue entry from the wait queue list. The loop evaluates to `true`
        // when the entry was not unlinked here, which means another thread processed it.
        let mut state = this.state().load(Acquire);
        let need_completion = loop {
            if state & WaitQueue::LOCKED_FLAG == WaitQueue::LOCKED_FLAG {
                // Another thread is processing the wait queue.
                thread::yield_now();
                state = this.state().load(Acquire);
            } else if state & WaitQueue::ADDR_MASK == 0 {
                // The wait queue is empty.
                break true;
            } else if let Err(new_state) = this.state().compare_exchange(
                state,
                state | WaitQueue::LOCKED_FLAG,
                AcqRel,
                Acquire,
            ) {
                state = new_state;
            } else {
                let (new_state, removed) =
                    this.remove_wait_queue_entry(state | WaitQueue::LOCKED_FLAG, this_ptr);
                if new_state & WaitQueue::LOCKED_FLAG == WaitQueue::LOCKED_FLAG {
                    // We need to process the wait queue if it is still locked.
                    this.process_wait_queue(new_state);
                }
                break !removed;
            }
        };

        if need_completion {
            // The entry was removed by another thread, so it will be completed.
            while !entry.result_finalized() {
                thread::yield_now();
            }
            // The local `state` snapshot predates the wait above. Start from a fresh value;
            // `release_loop` retries on CAS failure regardless.
            this.release_loop(this.state().load(Acquire), entry.opcode());
        }
    }
}