use std::{
    ffi::{CStr, NulError},
    io,
};

use thiserror::Error;

use singe_cuda_sys::{nvrtc, runtime};

/// Error type for this module, covering CUDA runtime failures, NVRTC failures,
/// I/O failures and a set of validation-style conditions.
///
/// The `Display` text of every variant is taken from its `#[error(...)]`
/// attribute.
#[derive(Error, Debug)]
pub enum Error {
    /// A CUDA runtime status converted via `From<runtime::cudaError_t>`.
    ///
    /// `code` is the numeric status; `message` is the text obtained from
    /// `cudaGetErrorString`, or `"unknown cuda error"` if that call returned
    /// a null pointer.
    #[error("cuda error ({code}): {message}")]
    Cuda { code: u32, message: String },

    /// An NVRTC status converted via `From<nvrtc::nvrtcResult>`.
    ///
    /// `code` is the numeric status; `message` is the text obtained from
    /// `nvrtcGetErrorString`, or `"unknown nvrtc error"` if that call returned
    /// a null pointer.
    #[error("nvrtc error ({code}): {message}")]
    Nvrtc { code: u32, message: String },

    /// A string could not be turned into a C string because it contains a
    /// NUL byte; produced by the `From<NulError>` conversion.
    #[error("string contains interior nul byte")]
    InteriorNul,

    /// Wraps a [`std::io::Error`]; `#[from]` generates the `From<io::Error>`
    /// conversion so `?` can be used on I/O results.
    #[error("io error: {0}")]
    Io(#[from] io::Error),

    /// No device was found.
    // NOTE(review): this variant is not constructed in the visible part of this
    // file; the exact trigger lives with the callers.
    #[error("device not found")]
    DeviceNotFound,

    /// A memory access was rejected as invalid.
    #[error("invalid memory access")]
    InvalidMemoryAccess,
    /// A memory allocation request was rejected as invalid.
    #[error("invalid memory allocation request")]
    InvalidMemoryAllocationRequest,
    /// A handle that was expected to be non-null was null.
    #[error("unexpected null handle")]
    NullHandle,
    /// A supplied value was rejected as invalid.
    #[error("invalid value")]
    InvalidValue,

    /// A graph operation was required but not available.
    // NOTE(review): the display text says "operation graph required" while the
    // variant is named `GraphOperationRequired` — confirm the intended wording.
    #[error("operation graph required")]
    GraphOperationRequired,
    /// Graph dependencies did not match what was expected.
    #[error("graph dependency mismatch")]
    GraphDependencyMismatch,
}

/// Shorthand for [`std::result::Result`] with this module's [`Error`] as the
/// error type.
pub type Result<T> = std::result::Result<T, Error>;

/// Converts a raw CUDA runtime status into [`Error::Cuda`].
///
/// The human-readable text is looked up with `cudaGetErrorString`; when that
/// lookup yields a null pointer the message falls back to
/// `"unknown cuda error"`. Bytes that are not valid UTF-8 are replaced
/// lossily.
impl From<runtime::cudaError_t> for Error {
    fn from(code: runtime::cudaError_t) -> Self {
        // SAFETY: the status code is passed by value and the call only hands
        // back a pointer, which is checked for null before it is read.
        let description = unsafe { runtime::cudaGetErrorString(code) };

        let message = if description.is_null() {
            "unknown cuda error".to_owned()
        } else {
            // SAFETY: `description` is non-null; this relies on the CUDA
            // runtime's contract that the returned pointer refers to a
            // NUL-terminated string that stays valid while it is copied.
            let text = unsafe { CStr::from_ptr(description) };
            text.to_string_lossy().into_owned()
        };

        Self::Cuda {
            code: code.into(),
            message,
        }
    }
}

impl From<NulError> for Error {
    fn from(_: NulError) -> Self {
        Self::InteriorNul
    }
}

/// Converts a raw NVRTC status into [`Error::Nvrtc`].
///
/// The human-readable text is looked up with `nvrtcGetErrorString`; when that
/// lookup yields a null pointer the message falls back to
/// `"unknown nvrtc error"`. Bytes that are not valid UTF-8 are replaced
/// lossily.
impl From<nvrtc::nvrtcResult> for Error {
    fn from(code: nvrtc::nvrtcResult) -> Self {
        // SAFETY: the status code is passed by value and the call only hands
        // back a pointer, which is checked for null before it is read.
        let description = unsafe { nvrtc::nvrtcGetErrorString(code) };

        let message = if description.is_null() {
            "unknown nvrtc error".to_owned()
        } else {
            // SAFETY: `description` is non-null; this relies on NVRTC's
            // contract that the returned pointer refers to a NUL-terminated
            // string that stays valid while it is copied.
            let text = unsafe { CStr::from_ptr(description) };
            text.to_string_lossy().into_owned()
        };

        Self::Nvrtc {
            code: code.into(),
            message,
        }
    }
}

/// Evaluates a CUDA runtime call and turns its status into a `Result`.
///
/// Expands to an expression that is `Ok(())` when the status equals
/// `cudaError_t::CUDA_SUCCESS`, and `Err($crate::error::Error::from(status))`
/// for any other status. The argument is evaluated exactly once.
///
/// The `singe_cuda_sys` path is written without `$crate`, so that name must be
/// resolvable at the call site.
#[macro_export]
macro_rules! try_cuda {
    ($expr:expr) => {{
        let status = { $expr };
        if status == singe_cuda_sys::runtime::cudaError_t::CUDA_SUCCESS {
            Ok(())
        } else {
            Err($crate::error::Error::from(status))
        }
    }};
}

/// Evaluates an NVRTC call and turns its status into a `Result`.
///
/// Expands to an expression that is `Ok(())` when the status equals
/// `nvrtcResult::NVRTC_SUCCESS`, and `Err($crate::error::Error::from(status))`
/// for any other status. The argument is evaluated exactly once.
///
/// The `singe_cuda_sys` path is written without `$crate`, so that name must be
/// resolvable at the call site.
#[macro_export]
macro_rules! try_nvrtc {
    ($expr:expr) => {{
        let status = { $expr };
        if status == singe_cuda_sys::nvrtc::nvrtcResult::NVRTC_SUCCESS {
            Ok(())
        } else {
            Err($crate::error::Error::from(status))
        }
    }};
}