1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
use ffi::*;
use ::{API, Error};
use super::{Operation, PointerMode};
/// Owning wrapper around a raw cuBLAS handle (`cublasHandle_t`).
///
/// Dropping a `Context` calls `API::destroy` on it.
///
/// NOTE(review): `Clone` is derived, so a clone copies the raw `id` handle
/// verbatim and both values will call `API::destroy` for that same handle when
/// they are dropped. Unless `API::destroy` tolerates destroying a handle twice,
/// cloning a `Context` is unsound in practice. Consider dropping `Clone`
/// (breaking change) or reference-counting the handle.
#[derive(Debug, Clone)]
pub struct Context {
/// Raw cuBLAS handle owned by this context.
id: cublasHandle_t,
}
// SAFETY: NOTE(review) — this asserts that `&Context` may be shared between
// threads. The only state is the raw handle. The method that changes handle
// configuration (`set_pointer_mode`) takes `&mut self`, but the BLAS wrappers
// and `pointer_mode` take `&self` and can therefore run concurrently on one
// handle. This relies on cuBLAS permitting concurrent calls on a shared handle;
// confirm against the cuBLAS thread-safety documentation.
unsafe impl ::std::marker::Sync for Context {}
impl Drop for Context {
    /// Releases the cuBLAS handle owned by this context via `API::destroy`.
    ///
    /// `Drop::drop` has no way to report failure, so whatever `API::destroy`
    /// returns is intentionally discarded.
    fn drop(&mut self) {
        // SAFETY: NOTE(review) — assumes `self.id` is a live handle owned
        // exclusively by this value; a derived `Clone` of this `Context` would
        // share it and break that assumption.
        let _ = unsafe { API::destroy(self) };
    }
}
impl Context {
pub fn new() -> Result<Context, Error> {
API::create()
}
pub fn from_c(id: cublasHandle_t) -> Context {
Context { id: id }
}
pub fn id_c(&self) -> &cublasHandle_t {
&self.id
}
pub fn pointer_mode(&self) -> Result<PointerMode, Error> {
API::get_pointer_mode(self)
}
pub fn set_pointer_mode(&mut self, pointer_mode: PointerMode) -> Result<(), Error> {
API::set_pointer_mode(self, pointer_mode)
}
pub fn asum(&self, x: *mut f32, result: *mut f32, n: i32, stride: Option<i32>) -> Result<(), Error> {
API::asum(self, x, result, n, stride)
}
pub fn axpy(&self, alpha: *mut f32, x: *mut f32, y: *mut f32, n: i32, stride_x: Option<i32>, stride_y: Option<i32>) -> Result<(), Error> {
API::axpy(self, alpha, x, y, n, stride_x, stride_y)
}
pub fn copy(&self, x: *mut f32, y: *mut f32, n: i32, stride_x: Option<i32>, stride_y: Option<i32>) -> Result<(), Error> {
API::copy(self, x, y, n, stride_x, stride_y)
}
pub fn dot(&self, x: *mut f32, y: *mut f32, result: *mut f32, n: i32, stride_x: Option<i32>, stride_y: Option<i32>) -> Result<(), Error> {
API::dot(self, x, y, result, n, stride_x, stride_y)
}
pub fn nrm2(&self, x: *mut f32, result: *mut f32, n: i32, stride_x: Option<i32>) -> Result<(), Error> {
API::nrm2(self, x, result, n, stride_x)
}
pub fn scal(&self, alpha: *mut f32, x: *mut f32, n: i32, stride_x: Option<i32>) -> Result<(), Error> {
API::scal(self, alpha, x, n, stride_x)
}
pub fn swap(&self, x: *mut f32, y: *mut f32, n: i32, stride_x: Option<i32>, stride_y: Option<i32>) -> Result<(), Error> {
API::swap(self, x, y, n, stride_x, stride_y)
}
pub fn gemm(&self,
transa: Operation, transb: Operation,
m: i32, n: i32, k: i32,
alpha: *mut f32,
a: *mut f32, lda: i32,
b: *mut f32, ldb: i32,
beta: *mut f32,
c: *mut f32, ldc: i32) -> Result<(), Error> {
API::gemm(self, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use super::super::PointerMode;

    /// Applies `requested` to `context`, then checks that the context reports
    /// `expected` back.
    fn assert_round_trip(context: &mut Context, requested: PointerMode, expected: PointerMode) {
        context.set_pointer_mode(requested).unwrap();
        let reported = context.pointer_mode().unwrap();
        assert_eq!(expected, reported);
    }

    /// A context can be created (and is dropped straight away).
    #[test]
    fn create_context() {
        let _ = Context::new().unwrap();
    }

    /// A freshly created context reports `PointerMode::Host`.
    #[test]
    fn default_pointer_mode_is_host() {
        let context = Context::new().unwrap();
        assert_eq!(PointerMode::Host, context.pointer_mode().unwrap());
    }

    /// Switching to `Device` and back to `Host` is reflected by `pointer_mode`.
    #[test]
    fn can_set_pointer_mode() {
        let mut context = Context::new().unwrap();
        assert_round_trip(&mut context, PointerMode::Device, PointerMode::Device);
        assert_round_trip(&mut context, PointerMode::Host, PointerMode::Host);
    }
}