wnf 0.3.0

Safe bindings for the Windows Notification Facility
Documentation
//! Methods for updating state data
//!
//! This module only adds inherent impls to [`OwnedState<T>`] and [`BorrowedState<'_, T>`](BorrowedState).

use std::ffi::c_void;
use std::{io, mem, ptr};

use tracing::debug;
use windows::Win32::Foundation::{NTSTATUS, STATUS_UNSUCCESSFUL};

use crate::bytes::NoUninit;
use crate::data::ChangeStamp;
use crate::ntapi;
use crate::state::{BorrowedState, OwnedState, RawState};

impl<T> OwnedState<T>
where
    T: NoUninit + ?Sized,
{
    /// Updates the data of this state with the given value
    ///
    /// The update is performed regardless of the current change stamp of the state. In order to perform the update
    /// conditionally based on the change stamp, use the [`update`](OwnedState::update) method.
    ///
    /// # Errors
    /// Returns an error if updating fails
    pub fn set(&self, data: &T) -> io::Result<()> {
        self.raw.set(data)
    }

    /// Updates the data of this state with the given value
    ///
    /// The update is only performed if the change stamp of the state before the update matches the given
    /// `expected_change_stamp`. In this case, the method returns `true`. Otherwise, the update is not performed and the
    /// method returns `false`.
    ///
    /// Note that this check is not guaranteed to work reliably in all situations. If the size of the given data exceeds
    /// the internal capacity of the state (causing a reallocation) while there is another concurrent update, it may
    /// happen that the data is updated even though the change stamp is already greater than the given one.
    ///
    /// In order to update the data without checking the change stamp, use the [`set`](OwnedState::set) method.
    ///
    /// # Errors
    /// Returns an error if updating fails
    pub fn update(&self, data: &T, expected_change_stamp: impl Into<ChangeStamp>) -> io::Result<bool> {
        self.raw.update(data, expected_change_stamp.into())
    }
}

impl<T> BorrowedState<'_, T>
where
    T: NoUninit + ?Sized,
{
    /// Updates the data of this state with the given value
    ///
    /// See [`OwnedState::set`]
    pub fn set(self, data: &T) -> io::Result<()> {
        self.raw.set(data)
    }

    /// Updates the data of this state with the given value
    ///
    /// See [`OwnedState::update`]
    pub fn update(self, data: &T, expected_change_stamp: impl Into<ChangeStamp>) -> io::Result<bool> {
        self.raw.update(data, expected_change_stamp.into())
    }
}

impl<T> RawState<T>
where
    T: NoUninit + ?Sized,
{
    /// Updates the data of this state with the given value
    ///
    /// The update is performed regardless of the current change stamp of the state.
    ///
    /// # Errors
    /// Returns an error if the size of `data` does not fit into a `u32` or if updating fails
    fn set(self, data: &T) -> io::Result<()> {
        self.update_internal(data, None)?.ok()?;
        Ok(())
    }

    /// Updates the data of this state with the given value
    ///
    /// The update is only performed if the change stamp of the state before the update matches the given
    /// `expected_change_stamp`. In this case, the method returns `true`. Otherwise, the update is not performed and the
    /// method returns `false`.
    ///
    /// # Errors
    /// Returns an error if the size of `data` does not fit into a `u32` or if updating fails for any reason other
    /// than a change stamp mismatch
    pub(crate) fn update(self, data: &T, expected_change_stamp: ChangeStamp) -> io::Result<bool> {
        let result = self.update_internal(data, Some(expected_change_stamp))?;

        // `STATUS_UNSUCCESSFUL` is treated as "change stamp did not match" and reported as `false`
        // rather than as an error
        if result == STATUS_UNSUCCESSFUL {
            return Ok(false);
        }

        result.ok()?;
        Ok(true)
    }

    /// Calls `NtUpdateWnfStateData` with the bytes of `data` and returns the raw status
    ///
    /// If `expected_change_stamp` is `Some`, the change stamp check is enabled (`check_stamp == 1`) with the given
    /// stamp; if it is `None`, the check is disabled (`check_stamp == 0`) and the default stamp is passed.
    ///
    /// # Errors
    /// Returns an [`io::ErrorKind::InvalidInput`] error, without calling into the OS, if the size of `data` does not
    /// fit into a `u32`. Any failure reported by the OS is returned as the `Ok` status for the caller to interpret.
    fn update_internal(self, data: &T, expected_change_stamp: Option<ChangeStamp>) -> io::Result<NTSTATUS> {
        // A plain `as u32` cast would silently truncate oversized data, so the OS would be handed a shorter
        // buffer than `data` and might report success after storing only part of it. Reject such input instead.
        let buffer_size = u32::try_from(mem::size_of_val(data)).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "size of state data exceeds u32::MAX bytes",
            )
        })?;
        let matching_change_stamp = expected_change_stamp.unwrap_or_default().into();
        let check_stamp: u32 = expected_change_stamp.is_some().into();

        // SAFETY:
        // - The pointer in the first argument points to a valid `u64` because it comes from a live reference
        // - The pointer in the second argument is valid for reads of size `buffer_size` because it comes from a live
        //   reference `data` (of type `T`) and `buffer_size == mem::size_of_val(data)` (the conversion above is
        //   lossless or we have already returned)
        // - The memory range of size `buffer_size` starting at `buffer` is initialized because `T: NoUninit`
        // - The pointer in the fourth argument is either a null pointer or points to a valid `GUID` by the guarantees
        //   of `TypeId::as_ptr`
        // - The pointer in the fifth argument is a null pointer
        let result = unsafe {
            ntapi::NtUpdateWnfStateData(
                &self.state_name.opaque_value(),
                data as *const T as *const c_void,
                buffer_size,
                self.type_id.as_ptr(),
                ptr::null(),
                matching_change_stamp,
                check_stamp,
            )
        };

        debug!(
            target: ntapi::TRACING_TARGET,
            ?result,
            input.state_name = %self.state_name,
            input.buffer_size = buffer_size,
            input.type_id = %self.type_id,
            input.matching_change_stamp = matching_change_stamp,
            input.check_stamp = check_stamp,
            "NtUpdateWnfStateData",
        );

        Ok(result)
    }
}