hooking 0.4.0

hooking libs in rust
Documentation
use std::ffi::c_void;
use std::ptr::NonNull;
use std::sync::{Mutex, MutexGuard};

use crate::mem::page::MemoryProtectionGuard;
use crate::mem::{AllocationInfo, DefaultMemoryController, MemoryHandle, MemoryProtection};

use super::MemoryController;
use super::{MemoryError, Result};

#[derive(Debug)]
pub struct HeapState<C: MemoryController> {
    allocation: Option<C::AllocationInfoType>,
    written: usize,
}

impl<C: MemoryController> HeapState<C> {
    pub const fn empty() -> Self {
        Self {
            allocation: None,
            written: 0,
        }
    }

    pub unsafe fn ensure_allocated(&mut self, mem: &C, min_size: Option<usize>) -> Result<()> {
        if self.allocation.is_none() {
            self.allocation = Some(unsafe { mem.allocate_memory(min_size)? });
        }
        Ok(())
    }

    fn allocation(&self) -> Result<&C::AllocationInfoType> {
        match &self.allocation {
            Some(allocation) => Ok(allocation),
            None => Err(MemoryError::TableHeapNotAllocated),
        }
    }
}

pub struct HookHeap<C: MemoryController> {
    pub mem: C,
    state: Mutex<HeapState<C>>,
}
unsafe impl<C: MemoryController> Send for HookHeap<C> where C: Send {}
unsafe impl<C: MemoryController> Sync for HookHeap<C> where C: Sync {}

impl HookHeap<DefaultMemoryController> {
    pub const fn new() -> Self {
        Self::with_memory_controller(DefaultMemoryController::new())
    }
}

impl<C: MemoryController> HookHeap<C> {
    pub const fn with_memory_controller(controller: C) -> Self {
        Self {
            mem: controller,
            state: Mutex::new(HeapState::empty()),
        }
    }

    fn state<'a>(&'a self) -> Result<MutexGuard<'a, HeapState<C>>> {
        self.state
            .lock()
            .map_err(|_| MemoryError::BadTableHeapState)
    }

    pub unsafe fn ensure_allocated(&self, min_size: Option<usize>) -> Result<()> {
        unsafe { self.state()?.ensure_allocated(&self.mem, min_size) }
    }

    pub fn get_handle<'a>(&'a self) -> Result<MemoryHeapHandle<'a, C>> {
        let mut state = self.state()?;
        unsafe {
            state.ensure_allocated(&self.mem, None)?;
        }

        Ok(MemoryHeapHandle {
            state,
            mem: &self.mem,
        })
    }
}

pub struct MemoryHeapHandle<'a, C: MemoryController> {
    state: MutexGuard<'a, HeapState<C>>,
    mem: &'a C,
}

impl<'a, M: MemoryController> MemoryHeapHandle<'a, M> {
    pub fn begin_write(&mut self) -> Result<MemoryWriteHandle<'a, '_, M>> {
        MemoryWriteHandle::new_from(self)
    }

    pub unsafe fn write_address(&self) -> Result<NonNull<c_void>> {
        let allocation = self.state.allocation()?;
        Ok(unsafe {
            allocation
                .allocation_start()
                .as_ptr()
                .add(self.state.written)
        })
    }

    pub unsafe fn reserve(&mut self, size: usize) -> Result<NonNull<c_void>> {
        let allocation = self.state.allocation()?;

        if self.state.written + size > allocation.allocation_size() {
            return Err(MemoryError::NoMemory {
                needs: self.state.written + size,
                has: allocation.allocation_size(),
            });
        }

        let write_address = unsafe {
            allocation
                .allocation_start()
                .as_ptr()
                .add(self.state.written)
        };

        self.state.written += size;

        Ok(write_address)
    }

    pub fn protection_guard(
        &self,
        on_enter: MemoryProtection,
        on_exit: MemoryProtection,
    ) -> Result<MemoryProtectionGuard<'a, M>> {
        let allocation = self.state.allocation()?;

        self.mem.protection_guard(
            allocation.allocation_start(),
            allocation.allocation_size(),
            on_enter,
            on_exit,
        )
    }
}

pub struct MemoryWriteHandle<'a, 'b, M: MemoryController> {
    heap: &'b mut MemoryHeapHandle<'a, M>,
    _guard: MemoryProtectionGuard<'a, M>,
}

impl<'a, 'b, M: MemoryController> MemoryWriteHandle<'a, 'b, M> {
    pub fn new_from(handle: &'b mut MemoryHeapHandle<'a, M>) -> Result<Self> {
        let handle = Self {
            _guard: handle
                .protection_guard(MemoryProtection::ReadWrite, MemoryProtection::ReadExecute)?,
            heap: handle,
        };

        Ok(handle)
    }

    pub unsafe fn write_address(&self) -> Result<NonNull<c_void>> {
        unsafe { self.heap.write_address() }
    }

    pub unsafe fn reserve(&mut self, size: usize) -> Result<NonNull<c_void>> {
        unsafe { self.heap.reserve(size) }
    }

    pub unsafe fn write_bytes(&mut self, buffer: &[u8]) -> Result<NonNull<c_void>> {
        let write_address = unsafe { self.reserve(buffer.len())? };
        unsafe {
            std::ptr::copy_nonoverlapping(
                buffer.as_ptr(),
                write_address.as_ptr() as *mut _,
                buffer.len(),
            );
        }
        Ok(write_address)
    }
}