use std::{ffi::CString, ptr, sync::Arc};

use singe_cuda_sys::driver;

use crate::{
    device::Device,
    error::{Error, Result},
    jit::JitOptions,
    library::Library,
    module::{Module, ModuleImage},
    nvrtc::{self, CompilationArtifact, OutputKind},
    try_cuda,
    types::Limit,
};

bitflags::bitflags! {
    /// Context creation flags.
    ///
    /// Each flag is the corresponding `driver::CUctx_flags` constant. The raw bits
    /// are passed to `cuCtxCreate_v4` by `Context::create_for_device_with_flags`,
    /// and `Context::flags` decodes `cuCtxGetFlags` output with
    /// `from_bits_truncate`, so bits without a named flag here are dropped.
    ///
    /// NOTE(review): in the CUDA headers `CU_CTX_SCHED_AUTO` is `0`, which would
    /// make `SCHEDULE_AUTO` an empty flag: it is indistinguishable from
    /// `ContextFlags::empty()` and `contains(SCHEDULE_AUTO)` is always true —
    /// confirm before testing for it explicitly.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct ContextFlags: u32 {
        /// Driver flag `CU_CTX_SCHED_AUTO`.
        const SCHEDULE_AUTO = driver::CUctx_flags::CU_CTX_SCHED_AUTO as _;
        /// Driver flag `CU_CTX_SCHED_SPIN`.
        const SCHEDULE_SPIN = driver::CUctx_flags::CU_CTX_SCHED_SPIN as _;
        /// Driver flag `CU_CTX_SCHED_YIELD`.
        const SCHEDULE_YIELD = driver::CUctx_flags::CU_CTX_SCHED_YIELD as _;
        /// Driver flag `CU_CTX_SCHED_BLOCKING_SYNC`.
        const SCHEDULE_BLOCKING_SYNC = driver::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC as _;
        /// Driver flag `CU_CTX_MAP_HOST`.
        const MAP_HOST = driver::CUctx_flags::CU_CTX_MAP_HOST as _;
        /// Driver flag `CU_CTX_LMEM_RESIZE_TO_MAX`.
        const LOCAL_MEMORY_RESIZE_TO_MAX = driver::CUctx_flags::CU_CTX_LMEM_RESIZE_TO_MAX as _;
        /// Driver flag `CU_CTX_COREDUMP_ENABLE`.
        const COREDUMP_ENABLE = driver::CUctx_flags::CU_CTX_COREDUMP_ENABLE as _;
        /// Driver flag `CU_CTX_USER_COREDUMP_ENABLE`.
        const USER_COREDUMP_ENABLE = driver::CUctx_flags::CU_CTX_USER_COREDUMP_ENABLE as _;
        /// Driver flag `CU_CTX_SYNC_MEMOPS`.
        const SYNC_MEMORY_OPERATIONS = driver::CUctx_flags::CU_CTX_SYNC_MEMOPS as _;
    }
}

/// A shared CUDA driver context.
///
/// Unlike cuBLAS, cuDNN, cuFFT, and similar library handles, a CUDA context is
/// the underlying execution environment for a device. It is intended to be
/// shared by streams, modules, libraries, events, allocations, and higher-level
/// library wrappers.
///
/// This type is therefore reference-counted: every constructor returns
/// `Arc<Self>`, and the type is `Send + Sync`. Shared references never mutate
/// Rust-visible state on the `Context` value itself. Methods such as `bind`
/// only change which CUDA context the driver considers current for the
/// calling thread.
///
/// Prefer one long-lived context per device and share it across dependent CUDA
/// objects instead of creating many short-lived contexts. The driver context
/// is destroyed when the last `Arc` referencing it is dropped.
#[derive(Debug)]
pub struct Context {
    // Raw driver handle. The constructors reject null, and `Drop` destroys it.
    handle: driver::CUcontext,
    // Device the context was created on.
    device: Device,
}

impl Context {
    /// Creates a context on [`Device::current`] with no creation flags.
    ///
    /// # Errors
    ///
    /// See [`Context::create_with_flags`].
    pub fn create() -> Result<Arc<Self>> {
        let no_flags = ContextFlags::empty();
        Self::create_with_flags(no_flags)
    }

    /// Creates a context on [`Device::current`] with the given creation flags.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Device::current` or from
    /// [`Context::create_for_device_with_flags`].
    pub fn create_with_flags(flags: ContextFlags) -> Result<Arc<Self>> {
        Self::create_for_device_with_flags(Device::current()?, flags)
    }

    /// Creates a context on `device` with no creation flags.
    ///
    /// # Errors
    ///
    /// See [`Context::create_for_device_with_flags`].
    pub fn create_for_device(device: Device) -> Result<Arc<Self>> {
        let no_flags = ContextFlags::empty();
        Self::create_for_device_with_flags(device, no_flags)
    }

    /// Initializes the driver and creates a new context on `device`.
    ///
    /// Every call first runs `cuInit(0)`. It then calls `cuCtxCreate_v4` with no
    /// `CUctxCreateParams` (a null pointer) and the raw bits of `flags`.
    ///
    /// NOTE(review): per the driver API docs, the new context is also pushed onto
    /// the calling thread's context stack, i.e. it becomes current there — confirm.
    ///
    /// # Errors
    ///
    /// Returns any driver error reported by `cuInit` or `cuCtxCreate_v4`. Returns
    /// [`Error::NullHandle`] if creation reports success but yields a null handle.
    pub fn create_for_device_with_flags(device: Device, flags: ContextFlags) -> Result<Arc<Self>> {
        let mut ctx = ptr::null_mut();
        // SAFETY: `cuInit` takes no pointers. `ctx` is a live local the driver
        // writes the new handle into, and the creation-params pointer is
        // deliberately null.
        unsafe {
            try_cuda!(driver::cuInit(0))?;
            try_cuda!(driver::cuCtxCreate_v4(
                &raw mut ctx,
                ptr::null_mut(),
                flags.bits(),
                device.id() as _,
            ))?;
        }
        let handle = Self::require_non_null(ctx)?;
        Ok(Arc::new(Self { handle, device }))
    }

    /// Binds this CUDA context to the calling CPU thread.
    ///
    /// The "current context" is thread-local driver state. Calling this method
    /// does not mutate the Rust `Context` value itself. It makes this context
    /// current for subsequent CUDA driver and interoperating runtime calls on
    /// the current host thread.
    ///
    /// If this context is already current, `cuCtxSetCurrent` is skipped.
    ///
    /// # Errors
    ///
    /// Propagates driver errors from `cuCtxGetCurrent` / `cuCtxSetCurrent`.
    pub fn bind(&self) -> Result<()> {
        let mut active = ptr::null_mut();
        // SAFETY: `active` is a live local the driver writes the current
        // context into. `self.as_raw()` is the handle owned by `self`, which
        // outlives this call.
        unsafe {
            try_cuda!(driver::cuCtxGetCurrent(&raw mut active))?;
            if active != self.as_raw() {
                try_cuda!(driver::cuCtxSetCurrent(self.as_raw()))?;
            }
        }
        Ok(())
    }

    /// Loads `image` as a module in this context via `cuModuleLoadData`.
    ///
    /// Binds this context to the calling thread first. A clone of this `Arc`
    /// is handed to the returned [`Module`].
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Context::bind`] and `cuModuleLoadData`.
    /// Returns [`Error::NullHandle`] if the driver yields a null module.
    pub fn load_module(self: &Arc<Self>, image: &ModuleImage<'_>) -> Result<Module> {
        self.bind()?;

        let mut raw_module = ptr::null_mut();
        // SAFETY: `raw_module` is a live local the driver writes into, and
        // `image` stays borrowed for the whole call.
        unsafe {
            try_cuda!(driver::cuModuleLoadData(&raw mut raw_module, image.as_ptr() as _))?;
        }
        let raw_module = Self::require_non_null(raw_module)?;
        // SAFETY: `raw_module` is a non-null module just loaded into this context.
        Ok(unsafe { Module::from_raw(raw_module, Arc::clone(self)) })
    }

    /// Unloads `module` by consuming and dropping it.
    ///
    /// Always returns `Ok(())`. Whatever cleanup happens is performed by
    /// `Module`'s own drop logic.
    pub fn unload_module(self: &Arc<Self>, module: Module) -> Result<()> {
        std::mem::drop(module);
        Ok(())
    }

    /// Loads `image` as a module via `cuModuleLoadDataEx`, passing the JIT
    /// options built from `jit_options`.
    ///
    /// Binds this context to the calling thread first.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Context::bind`] and `cuModuleLoadDataEx`.
    /// Returns [`Error::NullHandle`] if the driver yields a null module.
    pub fn load_module_with_options(
        self: &Arc<Self>,
        image: &ModuleImage<'_>,
        mut jit_options: JitOptions<'_>,
    ) -> Result<Module> {
        self.bind()?;

        let mut built = jit_options.build();
        let mut raw_module = ptr::null_mut();
        // SAFETY: the option name/value arrays belong to `built`, which lives
        // until the end of this function. The option count is `names.len()`.
        // This presumes `build` yields equally long `names` and `values` (TODO confirm).
        unsafe {
            try_cuda!(driver::cuModuleLoadDataEx(
                &raw mut raw_module,
                image.as_ptr() as _,
                built.names.len() as _,
                built.names.as_mut_ptr() as _,
                built.values.as_mut_ptr() as _,
            ))?;
        }
        let raw_module = Self::require_non_null(raw_module)?;
        // SAFETY: `raw_module` is a non-null module just loaded into this context.
        Ok(unsafe { Module::from_raw(raw_module, Arc::clone(self)) })
    }

    /// Loads the `output` artifact of an NVRTC `program` as a module, using
    /// default JIT options.
    ///
    /// # Errors
    ///
    /// See [`Context::load_nvrtc_module_with_options`].
    pub fn load_nvrtc_module(
        self: &Arc<Self>,
        program: &nvrtc::Program,
        output: OutputKind,
    ) -> Result<Module> {
        let defaults = JitOptions::default();
        self.load_nvrtc_module_with_options(program, output, defaults)
    }

    /// Loads the `output` artifact of an NVRTC `program` as a module with the
    /// given JIT options.
    ///
    /// # Errors
    ///
    /// Fails if the artifact cannot be fetched. Fails with `Error::InvalidValue`
    /// for LTO-IR / OptiX-IR artifacts. Also propagates errors from
    /// [`Context::load_module_with_options`].
    pub fn load_nvrtc_module_with_options(
        self: &Arc<Self>,
        program: &nvrtc::Program,
        output: OutputKind,
        jit_options: JitOptions<'_>,
    ) -> Result<Module> {
        let artifact = program.artifact(output)?;
        let image = module_loadable_image(artifact)?;
        self.load_module_with_options(&image, jit_options)
    }

    /// Loads `image` as a library via `cuLibraryLoadData`, using default JIT
    /// options.
    ///
    /// # Errors
    ///
    /// See [`Context::load_library_with_options`].
    pub fn load_library(self: &Arc<Self>, image: &ModuleImage<'_>) -> Result<Library> {
        let defaults = JitOptions::default();
        self.load_library_with_options(image, defaults)
    }

    /// Loads `image` as a library via `cuLibraryLoadData`.
    ///
    /// The JIT options built from `jit_options` are passed through. No library
    /// options are passed. Binds this context to the calling thread first.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Context::bind`] and `cuLibraryLoadData`.
    /// Returns [`Error::NullHandle`] if the driver yields a null library.
    pub fn load_library_with_options(
        self: &Arc<Self>,
        image: &ModuleImage<'_>,
        mut jit_options: JitOptions<'_>,
    ) -> Result<Library> {
        self.bind()?;

        let mut built = jit_options.build();
        let mut raw_library = ptr::null_mut();
        // SAFETY: `raw_library` is a live local the driver writes into. The JIT
        // arrays belong to `built`, which outlives the call. The library-option
        // arrays are null with a count of 0.
        unsafe {
            try_cuda!(driver::cuLibraryLoadData(
                &raw mut raw_library,
                image.as_ptr() as _,
                built.names.as_mut_ptr() as _,
                built.values.as_mut_ptr() as _,
                built.names.len() as _,
                ptr::null_mut(),
                ptr::null_mut(),
                0,
            ))?;
        }
        let raw_library = Self::require_non_null(raw_library)?;
        // SAFETY: `raw_library` is a non-null library the driver just loaded.
        Ok(unsafe { Library::from_raw(raw_library, Arc::clone(self)) })
    }

    /// Loads the `output` artifact of an NVRTC `program` as a library, using
    /// default JIT options.
    ///
    /// # Errors
    ///
    /// See [`Context::load_nvrtc_library_with_options`].
    pub fn load_nvrtc_library(
        self: &Arc<Self>,
        program: &nvrtc::Program,
        output: OutputKind,
    ) -> Result<Library> {
        let defaults = JitOptions::default();
        self.load_nvrtc_library_with_options(program, output, defaults)
    }

    /// Loads the `output` artifact of an NVRTC `program` as a library with the
    /// given JIT options.
    ///
    /// # Errors
    ///
    /// Fails if the artifact cannot be fetched. Fails with `Error::InvalidValue`
    /// for LTO-IR / OptiX-IR artifacts. Also propagates errors from
    /// [`Context::load_library_with_options`].
    pub fn load_nvrtc_library_with_options(
        self: &Arc<Self>,
        program: &nvrtc::Program,
        output: OutputKind,
        jit_options: JitOptions<'_>,
    ) -> Result<Library> {
        let artifact = program.artifact(output)?;
        let image = library_loadable_image(artifact)?;
        self.load_library_with_options(&image, jit_options)
    }

    /// Loads a library from the file at `path` via `cuLibraryLoadFromFile`.
    ///
    /// No JIT or library options are passed. Binds this context to the calling
    /// thread before converting `path`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Context::bind`]. Propagates the conversion error
    /// if `path` contains an interior NUL byte, and errors from
    /// `cuLibraryLoadFromFile`. Returns [`Error::NullHandle`] if the driver
    /// yields a null library.
    pub fn load_library_from_file(self: &Arc<Self>, path: &str) -> Result<Library> {
        self.bind()?;

        let c_path = CString::new(path)?;
        let mut raw_library = ptr::null_mut();
        // SAFETY: `raw_library` is a live local the driver writes into, and
        // `c_path` is a NUL-terminated string that outlives the call. Both
        // option arrays are null with counts of 0.
        unsafe {
            try_cuda!(driver::cuLibraryLoadFromFile(
                &raw mut raw_library,
                c_path.as_ptr(),
                ptr::null_mut(),
                ptr::null_mut(),
                0,
                ptr::null_mut(),
                ptr::null_mut(),
                0,
            ))?;
        }
        let raw_library = Self::require_non_null(raw_library)?;
        // SAFETY: `raw_library` is a non-null library the driver just loaded.
        Ok(unsafe { Library::from_raw(raw_library, Arc::clone(self)) })
    }

    /// Binds this context, then blocks on `cuCtxSynchronize`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Context::bind`] and `cuCtxSynchronize`.
    pub fn synchronize(&self) -> Result<()> {
        self.bind()?;
        // SAFETY: takes no arguments and acts on the context bound just above.
        unsafe {
            try_cuda!(driver::cuCtxSynchronize())?;
        }
        Ok(())
    }

    /// Binds this context and returns its flags as reported by `cuCtxGetFlags`.
    ///
    /// Bits without a named [`ContextFlags`] constant are discarded.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Context::bind`] and `cuCtxGetFlags`.
    pub fn flags(&self) -> Result<ContextFlags> {
        self.bind()?;
        let mut raw_flags = 0;
        // SAFETY: `raw_flags` is a live local the driver writes into.
        unsafe {
            try_cuda!(driver::cuCtxGetFlags(&raw mut raw_flags))?;
        }
        Ok(ContextFlags::from_bits_truncate(raw_flags))
    }

    /// Binds this context and reads `limit` via `cuCtxGetLimit`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Context::bind`] and `cuCtxGetLimit`.
    pub fn limit(&self, limit: Limit) -> Result<usize> {
        self.bind()?;
        let mut value = 0;
        // SAFETY: `value` is a live local the driver writes into.
        unsafe {
            try_cuda!(driver::cuCtxGetLimit(&raw mut value, limit.into()))?;
        }
        Ok(value as usize)
    }

    /// Binds this context and sets `limit` to `value` via `cuCtxSetLimit`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Context::bind`] and `cuCtxSetLimit`.
    pub fn set_limit(&self, limit: Limit, value: usize) -> Result<()> {
        self.bind()?;
        // SAFETY: both arguments are passed by value; no pointers are involved.
        unsafe {
            try_cuda!(driver::cuCtxSetLimit(limit.into(), value as _))?;
        }
        Ok(())
    }

    /// The device this context was created on.
    pub const fn device(&self) -> Device {
        self.device
    }

    /// Returns the raw driver handle.
    ///
    /// # Safety
    ///
    /// The handle is owned by this `Context` and is destroyed when it is
    /// dropped. Callers must not destroy it themselves or use it after the
    /// owning `Context` is gone.
    pub const unsafe fn as_raw(&self) -> driver::CUcontext {
        self.handle
    }

    /// Turns a handle written by a successful driver call into `Ok(handle)`.
    /// Returns [`Error::NullHandle`] if the driver nonetheless left it null.
    fn require_non_null<T>(handle: *mut T) -> Result<*mut T> {
        if handle.is_null() {
            Err(Error::NullHandle)
        } else {
            Ok(handle)
        }
    }
}

// SAFETY: `Context` only holds a raw driver handle and its `Device`. Nothing
// in this wrapper ties the handle to the thread that created it: `bind` makes
// it current on whichever thread calls it, and `Drop` destroys it on whichever
// thread drops the last owner. NOTE(review): this relies on the CUDA driver
// API accepting a context handle from any host thread — confirm against the
// driver's multithreading guarantees.
unsafe impl Send for Context {}

// SAFETY: CUDA driver contexts are shared execution environments, not
// per-thread library handles. The Rust wrapper only stores the raw context
// pointer and the owning device, and none of the `&self` methods in this file
// mutate them. Current-context selection is maintained by CUDA as
// thread-local driver state.
unsafe impl Sync for Context {}

impl Drop for Context {
    /// Destroys the underlying driver context via `cuCtxDestroy_v2`.
    ///
    /// `drop` cannot return errors, so a failed destroy is only reported on
    /// stderr in debug builds and is otherwise ignored.
    fn drop(&mut self) {
        // SAFETY: `handle` is the non-null context created by this type's
        // constructor. It is private to `Context` and is destroyed only here,
        // exactly once.
        let destroyed = unsafe { try_cuda!(driver::cuCtxDestroy_v2(self.handle)) };
        if let Err(err) = destroyed {
            // `cfg!` instead of `#[cfg]` keeps `err` used in release builds,
            // avoiding an `unused_variables` warning there. The branch is
            // still a compile-time constant and is optimized out.
            if cfg!(debug_assertions) {
                eprintln!("failed to destroy CUDA context: {err}");
            }
        }
    }
}

impl PartialEq for Context {
    fn eq(&self, other: &Self) -> bool {
        unsafe { self.as_raw() == other.as_raw() }
    }
}

impl Eq for Context {}

fn module_loadable_image(artifact: CompilationArtifact) -> Result<ModuleImage<'static>> {
    match artifact {
        CompilationArtifact::Ptx(image) | CompilationArtifact::Cubin(image) => Ok(image),
        CompilationArtifact::LtoIr(_) | CompilationArtifact::OptixIr(_) => Err(Error::InvalidValue),
    }
}

fn library_loadable_image(artifact: CompilationArtifact) -> Result<ModuleImage<'static>> {
    match artifact {
        CompilationArtifact::Ptx(image) | CompilationArtifact::Cubin(image) => Ok(image),
        CompilationArtifact::LtoIr(_) | CompilationArtifact::OptixIr(_) => Err(Error::InvalidValue),
    }
}