use std::{ptr, sync::Arc};

use singe_cuda_sys::{driver, runtime};

use crate::{
    context::Context,
    error::{Error, Result},
    stream::{BorrowedStream, Stream, StreamBinding},
    try_cuda,
};

bitflags::bitflags! {
    /// Creation flags for an [`Event`], forwarded to `cudaEventCreateWithFlags`.
    ///
    /// Bit values are taken from the driver API's `CUevent_flags` enum.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct EventFlags: u32 {
        /// Mirrors `CU_EVENT_DEFAULT` (no special behavior requested).
        const DEFAULT = driver::CUevent_flags::CU_EVENT_DEFAULT as u32;
        /// Mirrors `CU_EVENT_BLOCKING_SYNC`; see the CUDA docs for its wait semantics.
        const BLOCKING_SYNC = driver::CUevent_flags::CU_EVENT_BLOCKING_SYNC as u32;
        /// Mirrors `CU_EVENT_DISABLE_TIMING`; see the CUDA docs for its timing semantics.
        const DISABLE_TIMING = driver::CUevent_flags::CU_EVENT_DISABLE_TIMING as u32;
        /// Mirrors `CU_EVENT_INTERPROCESS`; see the CUDA docs for IPC requirements.
        const INTERPROCESS = driver::CUevent_flags::CU_EVENT_INTERPROCESS as u32;
    }
}

bitflags::bitflags! {
    /// Recording flags accepted by [`Event::record`] and its siblings, forwarded to
    /// `cudaEventRecordWithFlags`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct EventRecordFlags: u32 {
        /// Mirrors `cudaEventRecordDefault`.
        const DEFAULT = runtime::cudaEventRecordDefault;
        /// Mirrors `cudaEventRecordExternal`; see the CUDA docs for its stream-capture semantics.
        const EXTERNAL = runtime::cudaEventRecordExternal;
    }
}

/// An owned CUDA event handle tied to the [`Context`] it was created in.
///
/// The handle is destroyed with `cudaEventDestroy` when the value is dropped, and the
/// stored `Arc<Context>` keeps the `Context` value alive at least as long as the event.
#[derive(Debug)]
pub struct Event {
    /// Raw runtime event handle (as produced by `Context::create_event_with_flags`).
    handle: runtime::cudaEvent_t,
    /// Owning context; `Event` methods bind it before issuing most CUDA calls.
    ctx: Arc<Context>,
}

impl Event {
    /// Records this event on an owned [`Stream`].
    ///
    /// # Errors
    ///
    /// Returns `CUDA_ERROR_INVALID_CONTEXT` if `stream` belongs to a different [`Context`]
    /// than this event; otherwise propagates failures from binding the context or from
    /// `cudaEventRecordWithFlags`.
    pub fn record(&self, stream: &Stream, flags: EventRecordFlags) -> Result<()> {
        self.ensure_same_context(stream.context())?;
        // SAFETY: the raw stream handle is only used for this call, while `stream` is borrowed.
        self.record_raw(unsafe { stream.as_raw() }, flags)
    }

    /// Records this event on a [`BorrowedStream`].
    ///
    /// # Errors
    ///
    /// Same as [`Event::record`].
    pub fn record_borrowed(&self, stream: &BorrowedStream, flags: EventRecordFlags) -> Result<()> {
        self.ensure_same_context(stream.context())?;
        self.record_raw(stream.as_raw(), flags)
    }

    /// Records this event on a [`StreamBinding`].
    ///
    /// # Errors
    ///
    /// Same as [`Event::record`].
    pub fn record_on(&self, stream: &StreamBinding, flags: EventRecordFlags) -> Result<()> {
        self.ensure_same_context(stream.context())?;
        self.record_raw(stream.as_raw(), flags)
    }

    /// Fails with `CUDA_ERROR_INVALID_CONTEXT` unless `other` equals this event's context.
    ///
    /// Shared guard for every operation that combines this event with another CUDA object.
    fn ensure_same_context(&self, other: &Context) -> Result<()> {
        if other != self.context() {
            return Err(runtime::cudaError_t::CUDA_ERROR_INVALID_CONTEXT.into());
        }
        Ok(())
    }

    /// Binds this event's context and issues `cudaEventRecordWithFlags` on `stream`.
    ///
    /// Callers must already have checked that `stream` belongs to the same context.
    fn record_raw(&self, stream: runtime::cudaStream_t, flags: EventRecordFlags) -> Result<()> {
        self.ctx.bind()?;
        // SAFETY: `self.handle` is a live event owned by `self`, and `stream` comes from a
        // stream wrapper borrowed by the caller for the duration of this call.
        unsafe {
            try_cuda!(runtime::cudaEventRecordWithFlags(
                self.as_raw(),
                stream,
                flags.bits(),
            ))?;
        }
        Ok(())
    }

    /// Queries the event's status with `cudaEventQuery`.
    ///
    /// Returns `Ok(true)` when CUDA reports `CUDA_SUCCESS` and `Ok(false)` when it reports
    /// `CUDA_ERROR_NOT_READY`.
    ///
    /// # Errors
    ///
    /// Propagates a failure to bind the context, or any other status returned by
    /// `cudaEventQuery`.
    pub fn query(&self) -> Result<bool> {
        // Bind first, like every other CUDA call on this type, so the query never runs
        // against whatever context (or none) happens to be current on the calling thread.
        self.ctx.bind()?;
        // SAFETY: `self.handle` is a live event owned by `self`.
        let status = unsafe { runtime::cudaEventQuery(self.as_raw()) };
        match status {
            runtime::cudaError_t::CUDA_SUCCESS => Ok(true),
            runtime::cudaError_t::CUDA_ERROR_NOT_READY => Ok(false),
            _ => Err(status.into()),
        }
    }

    /// Waits for this event via `cudaEventSynchronize`.
    ///
    /// # Errors
    ///
    /// Propagates failures from binding the context or from `cudaEventSynchronize`.
    pub fn synchronize(&self) -> Result<()> {
        self.ctx.bind()?;
        // SAFETY: `self.handle` is a live event owned by `self`.
        unsafe {
            try_cuda!(runtime::cudaEventSynchronize(self.as_raw()))?;
        }
        Ok(())
    }

    /// Returns the elapsed time between `start` and this event, in milliseconds, as
    /// reported by `cudaEventElapsedTime(start, self)`.
    ///
    /// # Errors
    ///
    /// Returns `CUDA_ERROR_INVALID_CONTEXT` if the two events belong to different
    /// contexts; otherwise propagates failures from binding the context or from
    /// `cudaEventElapsedTime`.
    pub fn elapsed_time_since(&self, start: &Event) -> Result<f32> {
        self.ensure_same_context(start.context())?;

        self.ctx.bind()?;
        let mut milliseconds = 0.0f32;
        // SAFETY: both handles are live events owned by their `Event` values, and
        // `milliseconds` is a valid, writable out-pointer for the call.
        unsafe {
            try_cuda!(runtime::cudaEventElapsedTime(
                &raw mut milliseconds,
                start.as_raw(),
                self.as_raw(),
            ))?;
        }
        Ok(milliseconds)
    }

    /// The [`Context`] this event was created in.
    pub fn context(&self) -> &Context {
        &self.ctx
    }

    /// Returns the raw `cudaEvent_t` handle without transferring ownership.
    ///
    /// # Safety
    ///
    /// The handle stays owned by `self` and is destroyed when `self` is dropped. Callers
    /// must not destroy it themselves or use it after `self` is gone.
    pub const unsafe fn as_raw(&self) -> runtime::cudaEvent_t {
        self.handle
    }
}

// SAFETY: `Event` holds only an opaque CUDA event pointer plus an `Arc<Context>`. The pointer
// is never dereferenced on the Rust side; it is only handed to CUDA runtime calls.
// NOTE(review): soundness relies on the CUDA runtime event APIs being callable from any host
// thread, and on `Context` itself being `Send + Sync` (so the `Arc<Context>` may move) —
// confirm both.
unsafe impl Send for Event {}

// SAFETY: every `&self` method only reads `handle`/`ctx` and passes them to CUDA calls; no
// Rust-side state is mutated through a shared reference.
// NOTE(review): assumes concurrent CUDA runtime calls on the same event are thread-safe —
// confirm against the CUDA documentation.
unsafe impl Sync for Event {}

impl Drop for Event {
    /// Destroys the underlying CUDA event with `cudaEventDestroy`.
    ///
    /// This is best-effort because `drop` cannot return errors. A failure to bind the owning
    /// context or to destroy the event is only printed to stderr, and only in debug builds.
    /// Destruction is still attempted even if binding fails.
    fn drop(&mut self) {
        if let Err(err) = self.ctx.bind() {
            // `cfg!` (rather than `#[cfg]`) keeps `err` used in release builds, avoiding an
            // unused-variable warning while the print is still compiled away.
            if cfg!(debug_assertions) {
                eprintln!("failed to bind context before destroying event: {err}");
            }
        }
        // SAFETY: `self.handle` is owned exclusively by this `Event` and is destroyed exactly
        // once, here.
        let destroyed = unsafe { try_cuda!(runtime::cudaEventDestroy(self.handle)) };
        if let Err(err) = destroyed {
            if cfg!(debug_assertions) {
                eprintln!("failed to destroy CUDA event: {err}");
            }
        }
    }
}

impl Context {
    pub fn create_event(self: &Arc<Self>) -> Result<Event> {
        self.create_event_with_flags(EventFlags::DEFAULT)
    }

    pub fn create_event_with_flags(self: &Arc<Self>, flags: EventFlags) -> Result<Event> {
        self.bind()?;
        let mut handle = ptr::null_mut();
        unsafe {
            try_cuda!(runtime::cudaEventCreateWithFlags(
                &raw mut handle,
                flags.bits()
            ))?;
        }
        if handle.is_null() {
            return Err(Error::NullHandle);
        }
        Ok(Event {
            handle,
            ctx: Arc::clone(self),
        })
    }
}