// singe-cublas 0.1.0-alpha.5
//
// Safe Rust wrappers for the NVIDIA cuBLAS dense linear algebra library (with cuBLASLt).
use singe_cuda::memory::DeviceMemory;

use crate::{
    error::{Error, Result},
    types::PointerMode,
};

/// A scalar argument that lives in host memory.
///
/// This is a thin wrapper around a shared borrow of the value, so it is cheap
/// to copy regardless of what `T` is.
#[derive(Debug)]
pub struct HostScalar<'a, T> {
    value: &'a T,
}

// `Clone`/`Copy` are implemented by hand on purpose: `#[derive]` would add
// `T: Clone` / `T: Copy` bounds, although the wrapper only stores `&'a T`,
// which is `Copy` for every `T`.
impl<T> Clone for HostScalar<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for HostScalar<'_, T> {}

impl<'a, T> HostScalar<'a, T> {
    /// Wraps a borrowed host value.
    pub const fn new(value: &'a T) -> Self {
        Self { value }
    }

    /// Returns a raw pointer to the wrapped host value.
    ///
    /// The pointer is derived from the `&'a T` borrow, so it is valid for reads
    /// for as long as that borrow is alive.
    pub const fn as_ptr(&self) -> *const T {
        self.value
    }
}

/// A scalar argument that lives in a device allocation.
///
/// Wraps a shared borrow of a [`DeviceMemory`] buffer that is guaranteed to be
/// non-empty (see [`DeviceScalar::create`]).
#[derive(Debug)]
pub struct DeviceScalar<'a, T> {
    value: &'a DeviceMemory<T>,
}

// Implemented by hand rather than derived so that copying does not require
// `T: Clone` / `T: Copy`: the wrapper only stores a shared reference, which is
// always `Copy`.
impl<T> Clone for DeviceScalar<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for DeviceScalar<'_, T> {}

impl<'a, T> DeviceScalar<'a, T> {
    /// Wraps a borrowed device allocation as a scalar argument.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidVectorShape`] if `value` is empty, since there
    /// would be no element for the pointer returned by [`Self::as_ptr`] to
    /// refer to.
    pub fn create(value: &'a DeviceMemory<T>) -> Result<Self> {
        if value.is_empty() {
            return Err(Error::InvalidVectorShape);
        }
        Ok(Self { value })
    }

    /// Returns the pointer reported by `DeviceMemory::as_ptr` for the wrapped
    /// allocation.
    ///
    /// NOTE(review): presumably only the first element is read when this is
    /// used as a scalar; confirm against the call sites.
    pub fn as_ptr(&self) -> *const T {
        self.value.as_ptr()
    }
}

/// A scalar argument that is borrowed either from host memory or from a
/// device allocation.
///
/// The variant selects the [`PointerMode`] reported by [`Scalar::pointer_mode`].
#[derive(Debug)]
pub enum Scalar<'a, T> {
    /// The value is borrowed from host memory.
    Host(HostScalar<'a, T>),
    /// The value is borrowed from a device allocation.
    Device(DeviceScalar<'a, T>),
}

impl<'a, T> Scalar<'a, T> {
    /// Builds a host-side scalar from a borrowed value.
    pub const fn host(value: &'a T) -> Self {
        let wrapped = HostScalar::new(value);
        Self::Host(wrapped)
    }

    /// Builds a device-side scalar from a borrowed device allocation.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`DeviceScalar::create`].
    pub fn device(value: &'a DeviceMemory<T>) -> Result<Self> {
        DeviceScalar::create(value).map(Self::Device)
    }

    /// Reports the pointer mode that matches where this scalar lives.
    pub const fn pointer_mode(&self) -> PointerMode {
        match *self {
            Scalar::Host(..) => PointerMode::Host,
            Scalar::Device(..) => PointerMode::Device,
        }
    }

    /// Returns the raw pointer to the scalar. It is a host pointer or a device
    /// pointer depending on the variant (see [`Scalar::pointer_mode`]).
    pub fn as_ptr(&self) -> *const T {
        match self {
            Scalar::Host(host) => host.as_ptr(),
            Scalar::Device(device) => device.as_ptr(),
        }
    }
}

impl<'a, T> From<&'a T> for Scalar<'a, T> {
    fn from(value: &'a T) -> Self {
        Self::host(value)
    }
}