use std::{mem, ptr};

use singe_cuda_sys::{driver, runtime};

use crate::{
    error::{Error, Result},
    try_cuda,
};

bitflags::bitflags! {
    /// Option flags accepted by [`IpcMemoryHandle::open`]; the bits are forwarded
    /// unchanged to `cudaIpcOpenMemHandle`.
    #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
    pub struct IpcMemoryFlags: u32 {
        /// Mirrors the driver constant `CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS`.
        /// NOTE(review): per the CUDA docs this enables peer access to the exporting
        /// device lazily — confirm against the targeted CUDA version.
        const LAZY_ENABLE_PEER_ACCESS =
            driver::CUipcMem_flags::CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS as u32;
    }
}

/// Opaque handle for sharing a CUDA event across processes; a thin newtype over the
/// runtime's `cudaIpcEventHandle_t`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct IpcEventHandle(runtime::cudaIpcEventHandle_t);

impl IpcEventHandle {
    pub const unsafe fn from_raw(handle: runtime::cudaIpcEventHandle_t) -> Self {
        Self(handle)
    }

    pub const fn zeroed() -> Self {
        unsafe { mem::zeroed() }
    }

    pub const unsafe fn as_raw(&self) -> runtime::cudaIpcEventHandle_t {
        self.0
    }
}

/// Opaque handle for sharing a CUDA device allocation across processes; a thin newtype
/// over the runtime's `cudaIpcMemHandle_t`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct IpcMemoryHandle(runtime::cudaIpcMemHandle_t);

impl IpcMemoryHandle {
    pub const unsafe fn from_raw(handle: runtime::cudaIpcMemHandle_t) -> Self {
        Self(handle)
    }

    pub const fn zeroed() -> Self {
        unsafe { mem::zeroed() }
    }

    pub const unsafe fn as_raw(&self) -> runtime::cudaIpcMemHandle_t {
        self.0
    }

    pub fn open<T>(self, flags: IpcMemoryFlags) -> Result<OpenedIpcMemory<T>> {
        let mut dev_ptr = ptr::null_mut();
        unsafe {
            try_cuda!(runtime::cudaIpcOpenMemHandle(
                &raw mut dev_ptr,
                self.as_raw(),
                flags.bits()
            ))?;
        }
        if dev_ptr.is_null() {
            return Err(Error::NullHandle);
        }

        Ok(OpenedIpcMemory {
            ptr: dev_ptr.cast(),
        })
    }
}

/// Device memory opened from an [`IpcMemoryHandle`], exposed as a raw `*mut T`.
///
/// The mapping is released with `cudaIpcCloseMemHandle`, either explicitly through
/// `close` or implicitly when the value is dropped.
pub struct OpenedIpcMemory<T> {
    /// Device pointer obtained from `cudaIpcOpenMemHandle` (in `IpcMemoryHandle::open`),
    /// cast to `*mut T`.
    ptr: *mut T,
}

// Written by hand rather than `#[derive(Debug)]`: the derive would add a `T: Debug` bound,
// yet only the raw pointer is printed and `*mut T` is `Debug` for every `T`. Output is the
// same as the derived impl.
impl<T> std::fmt::Debug for OpenedIpcMemory<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenedIpcMemory")
            .field("ptr", &self.ptr)
            .finish()
    }
}

impl<T> OpenedIpcMemory<T> {
    pub const fn as_ptr(&self) -> *mut T {
        self.ptr
    }

    pub fn close(self) -> Result<()> {
        let ptr: *mut () = self.ptr.cast();
        mem::forget(self);
        if ptr.is_null() {
            return Ok(());
        }
        unsafe { try_cuda!(runtime::cudaIpcCloseMemHandle(ptr as _)) }
    }
}

impl<T> Drop for OpenedIpcMemory<T> {
    /// Best-effort unmap via `cudaIpcCloseMemHandle`.
    ///
    /// `drop` cannot return an error, so a failure is only printed to stderr in builds with
    /// `debug_assertions`; use `close` to observe the result.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }

        // SAFETY: `self.ptr` is non-null and presumably came from `cudaIpcOpenMemHandle`
        // (the only constructor visible is `IpcMemoryHandle::open`); `close` forgets
        // `self`, so this path never double-closes.
        let result = unsafe { try_cuda!(runtime::cudaIpcCloseMemHandle(self.ptr.cast())) };

        if let Err(err) = result {
            #[cfg(debug_assertions)]
            eprintln!("failed to close cuda ipc memory handle: {err}");
            // Without debug assertions the print above is compiled out; consume `err`
            // explicitly so release builds don't emit an `unused_variables` warning.
            #[cfg(not(debug_assertions))]
            let _ = err;
        }
    }
}