rcublas 0.6.0

Safe Rust wrapper for CUDA's cuBLAS.
use super::{Operation, PointerMode};
use crate::ffi::*;
use crate::{Error, API};

#[derive(Debug, Clone)]
/// Provides the low-level cuBLAS context.
///
/// Wraps a `cublasHandle_t`; dropping a `Context` calls `API::destroy` on it.
// NOTE(review): the derived `Clone` copies the raw handle, and every copy calls
// `API::destroy` when dropped — presumably destroying the same handle more than
// once. Confirm whether `Clone` is intended on this type.
pub struct Context {
    /// The underlying cuBLAS handle (`cublasHandle_t` from the FFI bindings).
    id: cublasHandle_t,

// Implementing `Sync` would avoid the need for `Arc<Mutex<..>>` and could yield
// a large performance gain, but it is deliberately left out for the time being
// to play it safe:
// unsafe impl ::std::marker::Sync for Context {}

// `Send` is implemented only so that a `Context` can be wrapped in
// `Arc<Mutex<..>>` and moved across threads.
// SAFETY(review): assumes a cuBLAS handle may be moved to and used from another
// host thread as long as it is never used concurrently without synchronization
// — confirm against the cuBLAS thread-safety documentation.
unsafe impl ::std::marker::Send for Context {}

impl Drop for Context {
    /// Calls `API::destroy` on this context.
    ///
    /// Whatever `API::destroy` returns is discarded by the trailing `;`, so a
    /// failure during destruction is silently ignored.
    fn drop(&mut self) {
        // SAFETY(review): assumes `self.id` is a live handle that has not already
        // been destroyed — this would not hold if a derived clone of this
        // context was dropped first.
        unsafe { API::destroy(self) };

impl Context {
    /// Create a new cuBLAS Context by calling the low-level API.
    ///
    /// Context creation should be done as sparingly as possible.
    /// It is best to keep a context around as long as possible.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the low-level API fails to create the context.
    pub fn new() -> Result<Context, Error> {

    /// Create a new cuBLAS Context from its C type.
    ///
    /// The returned `Context` takes ownership of `id`: it will call
    /// `API::destroy` on the handle when it is dropped.
    pub fn from_c(id: cublasHandle_t) -> Context {
        Context { id }

    /// Returns the cuBLAS Context as its C type.
    ///
    /// The handle is only borrowed; it remains owned by this `Context`.
    pub fn id_c(&self) -> &cublasHandle_t {

    /// Returns this context's [`PointerMode`].
    ///
    /// The tests in this file expect a freshly created context to report
    /// `PointerMode::Host`.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the mode cannot be queried.
    pub fn pointer_mode(&self) -> Result<PointerMode, Error> {

    /// Sets this context's [`PointerMode`] by delegating to
    /// `API::set_pointer_mode`.
    // NOTE(review): in cuBLAS the pointer mode decides whether scalar arguments
    // and results (e.g. `alpha`, `beta`, `result`) live in host or device
    // memory — confirm this wrapper follows that contract.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if `API::set_pointer_mode` fails.
    pub fn set_pointer_mode(&mut self, pointer_mode: PointerMode) -> Result<(), Error> {
        API::set_pointer_mode(self, pointer_mode)

    // Level 1 operations

    /// Thin wrapper around `API::asum` (presumably BLAS `asum`, the sum of the
    /// absolute values of a vector — confirm).
    ///
    /// * `x` - pointer to the input elements.
    /// * `result` - pointer where the result is presumably stored.
    /// * `n` - number of elements.
    /// * `stride` - element stride of `x`; how `None` is interpreted is decided
    ///   by `API::asum`.
    ///
    /// The raw pointers are forwarded to `API::asum` without any checks here.
    // NOTE(review): no `&self` receiver appears in this signature, yet the body
    // passes `self` to `API::asum`.
    pub fn asum(
        x: *mut f32,
        result: *mut f32,
        n: i32,
        stride: Option<i32>,
    ) -> Result<(), Error> {
        API::asum(self, x, result, n, stride)

    /// Thin wrapper around `API::axpy` (presumably BLAS `axpy`,
    /// `y = alpha * x + y` — confirm).
    ///
    /// * `alpha` - pointer to the scalar multiplier.
    /// * `x` - pointer to the first vector.
    /// * `y` - pointer to the second vector (presumably updated in place).
    /// * `n` - number of elements.
    /// * `stride_x` / `stride_y` - element strides of `x` and `y`; how `None`
    ///   is interpreted is decided by `API::axpy`.
    ///
    /// The raw pointers are forwarded to `API::axpy` without any checks here.
    // NOTE(review): no `&self` receiver appears in this signature, yet the body
    // passes `self` to `API::axpy`.
    pub fn axpy(
        alpha: *mut f32,
        x: *mut f32,
        y: *mut f32,
        n: i32,
        stride_x: Option<i32>,
        stride_y: Option<i32>,
    ) -> Result<(), Error> {
        API::axpy(self, alpha, x, y, n, stride_x, stride_y)

    /// Thin wrapper around `API::copy` (presumably BLAS `copy`, copying `x`
    /// into `y` — confirm).
    ///
    /// * `x` - pointer to the source vector.
    /// * `y` - pointer to the destination vector.
    /// * `n` - number of elements.
    /// * `stride_x` / `stride_y` - element strides of `x` and `y`; how `None`
    ///   is interpreted is decided by `API::copy`.
    ///
    /// The raw pointers are forwarded to `API::copy` without any checks here.
    // NOTE(review): no `&self` receiver appears in this signature, yet the body
    // passes `self` to `API::copy`.
    pub fn copy(
        x: *mut f32,
        y: *mut f32,
        n: i32,
        stride_x: Option<i32>,
        stride_y: Option<i32>,
    ) -> Result<(), Error> {
        API::copy(self, x, y, n, stride_x, stride_y)

    /// Thin wrapper around `API::dot` (presumably BLAS `dot`, the dot product
    /// of `x` and `y` — confirm).
    ///
    /// * `x` / `y` - pointers to the two input vectors.
    /// * `result` - pointer where the result is presumably stored.
    /// * `n` - number of elements.
    /// * `stride_x` / `stride_y` - element strides of `x` and `y`; how `None`
    ///   is interpreted is decided by `API::dot`.
    ///
    /// The raw pointers are forwarded to `API::dot` without any checks here.
    // NOTE(review): no `&self` receiver appears in this signature, yet the body
    // passes `self` to `API::dot`.
    pub fn dot(
        x: *mut f32,
        y: *mut f32,
        result: *mut f32,
        n: i32,
        stride_x: Option<i32>,
        stride_y: Option<i32>,
    ) -> Result<(), Error> {
        API::dot(self, x, y, result, n, stride_x, stride_y)

    /// Thin wrapper around `API::nrm2` (presumably BLAS `nrm2`, the Euclidean
    /// norm of `x` — confirm).
    ///
    /// * `x` - pointer to the input vector.
    /// * `result` - pointer where the result is presumably stored.
    /// * `n` - number of elements.
    /// * `stride_x` - element stride of `x`; how `None` is interpreted is
    ///   decided by `API::nrm2`.
    ///
    /// The raw pointers are forwarded to `API::nrm2` without any checks here.
    // NOTE(review): no `&self` receiver appears in this signature, yet the body
    // passes `self` to `API::nrm2`.
    pub fn nrm2(
        x: *mut f32,
        result: *mut f32,
        n: i32,
        stride_x: Option<i32>,
    ) -> Result<(), Error> {
        API::nrm2(self, x, result, n, stride_x)

    /// Thin wrapper around `API::scal` (presumably BLAS `scal`,
    /// `x = alpha * x` — confirm).
    ///
    /// * `alpha` - pointer to the scalar multiplier.
    /// * `x` - pointer to the vector (presumably scaled in place).
    /// * `n` - number of elements.
    /// * `stride_x` - element stride of `x`; how `None` is interpreted is
    ///   decided by `API::scal`.
    ///
    /// The raw pointers are forwarded to `API::scal` without any checks here.
    // NOTE(review): no `&self` receiver appears in this signature, yet the body
    // passes `self` to `API::scal`.
    pub fn scal(
        alpha: *mut f32,
        x: *mut f32,
        n: i32,
        stride_x: Option<i32>,
    ) -> Result<(), Error> {
        API::scal(self, alpha, x, n, stride_x)

    /// Thin wrapper around `API::swap` (presumably BLAS `swap`, exchanging the
    /// contents of `x` and `y` — confirm).
    ///
    /// * `x` / `y` - pointers to the two vectors.
    /// * `n` - number of elements.
    /// * `stride_x` / `stride_y` - element strides of `x` and `y`; how `None`
    ///   is interpreted is decided by `API::swap`.
    ///
    /// The raw pointers are forwarded to `API::swap` without any checks here.
    // NOTE(review): no `&self` receiver appears in this signature, yet the body
    // passes `self` to `API::swap`.
    pub fn swap(
        x: *mut f32,
        y: *mut f32,
        n: i32,
        stride_x: Option<i32>,
        stride_y: Option<i32>,
    ) -> Result<(), Error> {
        API::swap(self, x, y, n, stride_x, stride_y)

    // Level 3 operations

    /// General matrix-matrix multiply (presumably BLAS `gemm`,
    /// `C = alpha * op(A) * op(B) + beta * C` — confirm).
    ///
    /// * `transa` / `transb` - the [`Operation`] applied to `a` and `b`.
    /// * `m`, `n`, `k` - matrix dimensions (presumably `op(A)` is `m x k`,
    ///   `op(B)` is `k x n` and `C` is `m x n` — confirm).
    /// * `alpha` / `beta` - pointers to the scalar multipliers.
    /// * `a`, `b`, `c` - pointers to the matrices; `c` is presumably the output.
    /// * `lda`, `ldb`, `ldc` - leading dimensions of `a`, `b` and `c`.
    ///
    /// The raw pointers and dimensions are forwarded unchecked.
    // NOTE(review): no `&self` receiver appears in this signature, yet the body
    // passes `self`.
    pub fn gemm(
        transa: Operation,
        transb: Operation,
        m: i32,
        n: i32,
        k: i32,
        alpha: *mut f32,
        a: *mut f32,
        lda: i32,
        b: *mut f32,
        ldb: i32,
        beta: *mut f32,
        c: *mut f32,
        ldc: i32,
    ) -> Result<(), Error> {
        // NOTE(review): the call that receives this argument list is not visible
        // here — presumably `API::gemm(`, matching the other wrappers.
            self, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,

    /// Returns a version number as an `i32` (presumably the cuBLAS library
    /// version reported by the low-level API — the body is not visible here).
    pub fn get_version(&self) -> i32 {

// NOTE(review): no `#[cfg(test)]` attribute is visible on this module, so as
// shown it would also be compiled into non-test builds.
/// Tests for [`Context`] creation and pointer-mode handling.
mod test {
    use super::super::PointerMode;
    use super::*;
    use crate::chore::*;

    // NOTE(review): no `#[test]` attribute is visible on the functions in this
    // module, so as shown the test harness would not run them.
    /// Intended to check that a context can be created (body not visible here).
    fn create_context() {



    /// A freshly created context should report `PointerMode::Host`.
    fn default_pointer_mode_is_host() {

        let ctx = Context::new().unwrap();
        let mode = ctx.pointer_mode().unwrap();
        assert_eq!(PointerMode::Host, mode);


    /// Switching the pointer mode to `Device` and back to `Host` should be
    /// reflected by `pointer_mode()`.
    fn can_set_pointer_mode() {

        let mut context = Context::new().unwrap();
        // set to Device
        // NOTE(review): no `set_pointer_mode(PointerMode::Device)` call is visible
        // before this assertion; as shown it would fail on a context that defaults
        // to `Host` (see `default_pointer_mode_is_host`).
        let mode = context.pointer_mode().unwrap();
        assert_eq!(PointerMode::Device, mode);
        // set to Host
        // NOTE(review): likewise, no `set_pointer_mode(PointerMode::Host)` call is
        // visible before this assertion.
        let mode2 = context.pointer_mode().unwrap();
        assert_eq!(PointerMode::Host, mode2);
