use std::{ffi::CString, ptr, sync::Arc};

use singe_cuda_sys::driver;

use crate::{
    context::Context,
    error::{Error, Result},
    graph::{ExecutableGraph, Graph, GraphNode},
    kernel::{self, LibraryKernelHandle},
    module::{KernelFunction, KernelParameters, LaunchConfig, Module},
    try_cuda,
    types::{DeviceFunction, FunctionAttribute, FunctionCache},
};

/// Wrapper around a raw `driver::CUlibrary` handle paired with the
/// [`Context`] that is bound before every driver call issued through it.
///
/// Dropping a `Library` attempts to unload the handle with
/// `cuLibraryUnload` (see the `Drop` implementation).
#[derive(Debug)]
pub struct Library {
    /// Raw driver handle passed to the `cuLibrary*` calls.
    handle: driver::CUlibrary,
    /// Context bound via `Context::bind` before each driver call.
    ctx: Arc<Context>,
}

/// Address and size reported by the driver for a named symbol, as produced
/// by [`Library::global`] and [`Library::managed`].
///
/// The value borrows the [`Library`] it was resolved from, so it cannot
/// outlive that library.
#[derive(Debug, Clone, Copy)]
pub struct LibraryGlobal<'a> {
    /// Address written by the driver lookup, stored as an untyped pointer.
    ptr: *mut (),
    /// Size written by the driver lookup.
    /// NOTE(review): presumably a byte count — confirm against the driver docs.
    size: usize,
    /// Never read; held only to tie this value's lifetime to the library.
    _library: &'a Library,
}

/// A kernel handle resolved by name through [`Library::kernel`].
///
/// The value borrows the [`Library`] it was resolved from; that borrow is
/// also how the kernel reaches the context it binds before driver calls.
#[derive(Debug, Clone, Copy)]
pub struct LibraryKernel<'a> {
    /// Raw driver kernel handle passed to the `cuKernel*` calls.
    handle: driver::CUkernel,
    /// Library this kernel was looked up in; also supplies the context.
    library: &'a Library,
}

/// Offset and size of a single kernel parameter, as written by
/// `cuKernelGetParamInfo` in [`LibraryKernel::param_info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KernelParamInfo {
    /// Offset value reported by the driver for the parameter.
    /// NOTE(review): presumably a byte offset into the parameter buffer — confirm.
    pub offset: usize,
    /// Size value reported by the driver for the parameter.
    /// NOTE(review): presumably in bytes — confirm.
    pub size: usize,
}

impl Library {
    pub const unsafe fn from_raw(handle: driver::CUlibrary, ctx: Arc<Context>) -> Self {
        Self { handle, ctx }
    }

    pub fn kernel(&self, name: &str) -> Result<LibraryKernel<'_>> {
        let c_name = CString::new(name)?;
        let mut handle = ptr::null_mut();
        self.ctx.bind()?;
        unsafe {
            try_cuda!(driver::cuLibraryGetKernel(
                &raw mut handle,
                self.handle,
                c_name.as_ptr(),
            ))?;
        }
        if handle.is_null() {
            return Err(Error::NullHandle);
        }
        Ok(LibraryKernel {
            handle,
            library: self,
        })
    }

    pub fn kernel_count(&self) -> Result<usize> {
        let mut count = 0;
        self.ctx.bind()?;
        unsafe {
            try_cuda!(driver::cuLibraryGetKernelCount(&raw mut count, self.handle))?;
        }
        Ok(count as usize)
    }

    pub fn module(&self) -> Result<Module> {
        let mut handle = ptr::null_mut();
        self.ctx.bind()?;
        unsafe {
            try_cuda!(driver::cuLibraryGetModule(&raw mut handle, self.handle))?;
        }
        if handle.is_null() {
            return Err(Error::NullHandle);
        }
        Ok(unsafe { Module::from_borrowed_raw(handle, Arc::clone(&self.ctx)) })
    }

    pub fn global(&self, name: &str) -> Result<LibraryGlobal<'_>> {
        let c_name = CString::new(name)?;
        let mut ptr = 0;
        let mut size = 0;
        self.ctx.bind()?;
        unsafe {
            try_cuda!(driver::cuLibraryGetGlobal(
                &raw mut ptr,
                &raw mut size,
                self.handle,
                c_name.as_ptr(),
            ))?;
        }
        Ok(LibraryGlobal {
            ptr: ptr as *mut (),
            size: size as usize,
            _library: self,
        })
    }

    pub fn managed(&self, name: &str) -> Result<LibraryGlobal<'_>> {
        let c_name = CString::new(name)?;
        let mut ptr = 0;
        let mut size = 0;
        self.ctx.bind()?;
        unsafe {
            try_cuda!(driver::cuLibraryGetManaged(
                &raw mut ptr,
                &raw mut size,
                self.handle,
                c_name.as_ptr(),
            ))?;
        }
        Ok(LibraryGlobal {
            ptr: ptr as *mut (),
            size: size as usize,
            _library: self,
        })
    }

    pub fn unified_function(&self, symbol: &str) -> Result<*mut ()> {
        let c_symbol = CString::new(symbol)?;
        let mut ptr = ptr::null_mut();
        self.ctx.bind()?;
        unsafe {
            try_cuda!(driver::cuLibraryGetUnifiedFunction(
                &raw mut ptr,
                self.handle,
                c_symbol.as_ptr(),
            ))?;
        }
        if ptr.is_null() {
            return Err(Error::NullHandle);
        }
        Ok(ptr.cast())
    }

    pub const unsafe fn as_raw(&self) -> driver::CUlibrary {
        self.handle
    }
}

impl Drop for Library {
    /// Binds the library's context and unloads the handle with
    /// `cuLibraryUnload`.
    ///
    /// Failures cannot be propagated from `drop`, so they are reported on
    /// stderr in debug builds only and otherwise ignored. If the context
    /// cannot be bound, the unload is skipped entirely.
    fn drop(&mut self) {
        if let Err(err) = self.ctx.bind() {
            // `cfg!` (rather than a `#[cfg]` statement attribute) keeps `err`
            // used in every profile, so release builds do not emit an
            // `unused_variables` warning; the branch is a compile-time
            // constant and is removed when debug assertions are off.
            if cfg!(debug_assertions) {
                eprintln!("failed to bind context before unloading library: {err}");
            }
            return;
        }

        // SAFETY: validity of `self.handle` rests on the `from_raw`
        // contract, and `drop` runs at most once, so the handle is not
        // unloaded twice from here.
        let unloaded = unsafe { try_cuda!(driver::cuLibraryUnload(self.handle)) };
        if let Err(err) = unloaded {
            if cfg!(debug_assertions) {
                eprintln!("failed to unload cuda library: {err}");
            }
        }
    }
}

impl LibraryGlobal<'_> {
    /// Returns the address stored for this symbol as an untyped pointer.
    ///
    /// NOTE(review): the value originates from a driver lookup that writes a
    /// device-pointer-typed output, so it is presumably not dereferenceable
    /// on the host in general — confirm.
    pub const fn as_ptr(&self) -> *mut () {
        self.ptr
    }

    /// Returns the size the driver reported for this symbol.
    ///
    /// NOTE(review): presumably a byte count — confirm against the driver docs.
    pub const fn size(&self) -> usize {
        self.size
    }
}

impl LibraryKernel<'_> {
    pub fn name(&self) -> Result<String> {
        kernel::name::<LibraryKernelHandle>(self.library.ctx.as_ref(), self.handle)
    }

    pub fn function(&self) -> Result<DeviceFunction> {
        self.library.ctx.bind()?;
        let mut handle = ptr::null_mut();
        unsafe {
            try_cuda!(driver::cuKernelGetFunction(&raw mut handle, self.handle))?;
        }
        if handle.is_null() {
            return Err(Error::NullHandle);
        }
        Ok(handle.into())
    }

    pub fn add_to_graph(
        &self,
        graph: &mut Graph,
        dependencies: &[GraphNode],
        config: &LaunchConfig,
        params: &mut KernelParameters,
    ) -> Result<GraphNode> {
        let function = self.function()?;
        let module = self.library.module()?;
        let function = unsafe { KernelFunction::from_raw(function, &module) };
        function.add_to_graph(graph, dependencies, config, params)
    }

    pub fn set_graph_node_params(
        &self,
        executable: &mut ExecutableGraph,
        node: GraphNode,
        config: &LaunchConfig,
        params: &mut KernelParameters,
    ) -> Result<()> {
        let function = self.function()?;
        let module = self.library.module()?;
        let function = unsafe { KernelFunction::from_raw(function, &module) };
        function.set_graph_node_params(executable, node, config, params)
    }

    pub fn attribute(&self, attribute: FunctionAttribute) -> Result<i32> {
        kernel::attribute::<LibraryKernelHandle>(self.library.ctx.as_ref(), self.handle, attribute)
    }

    pub fn set_attribute(&self, attribute: FunctionAttribute, value: i32) -> Result<()> {
        kernel::set_attribute::<LibraryKernelHandle>(
            self.library.ctx.as_ref(),
            self.handle,
            attribute,
            value,
        )
    }

    pub fn set_cache_config(&self, config: FunctionCache) -> Result<()> {
        self.library.ctx.bind()?;
        unsafe {
            try_cuda!(driver::cuKernelSetCacheConfig(
                self.handle,
                config.into(),
                self.library.ctx.device().id() as _,
            ))?;
        }
        Ok(())
    }

    pub fn param_info(&self, index: usize) -> Result<KernelParamInfo> {
        self.library.ctx.bind()?;
        let mut offset = 0;
        let mut size = 0;
        unsafe {
            try_cuda!(driver::cuKernelGetParamInfo(
                self.handle,
                index as _,
                &raw mut offset,
                &raw mut size,
            ))?;
        }
        Ok(KernelParamInfo {
            offset: offset as usize,
            size: size as usize,
        })
    }

    pub const unsafe fn as_raw(&self) -> driver::CUkernel {
        self.handle
    }
}