#[allow(unused_imports)]
use crate::error::Status;
use std::ptr;
use singe_cuda::{data_type::DataType, memory::DeviceMemory};
use crate::{
context::Context,
error::{Error, Result},
sys, try_ffi,
utility::to_i32,
};
pub fn sscal(ctx: &Context, alpha: &f32, x: &mut DeviceMemory<f32>, incx: usize) -> Result<()> {
ctx.bind()?;
ctx.require_host_pointer_mode()?;
if incx == 0 {
return Err(Error::InvalidIncrement);
}
let n = vector_len(x.len(), incx)?;
if n == 0 {
return Ok(());
}
let incx = to_i32(incx, "incx")?;
unsafe {
try_ffi!(sys::cublasSscal_v2(
ctx.as_raw(),
n,
alpha,
x.as_mut_ptr(),
incx,
))?;
}
Ok(())
}
pub fn nrm2_ex<TX, TResult>(
ctx: &Context,
x: &DeviceMemory<TX>,
x_type: DataType,
incx: usize,
result: &mut TResult,
result_type: DataType,
execution_type: DataType,
) -> Result<()> {
ctx.bind()?;
ctx.require_host_pointer_mode()?;
if incx == 0 {
return Err(Error::InvalidIncrement);
}
let n = vector_len(x.len(), incx)?;
if n == 0 {
return Ok(());
}
let incx = to_i32(incx, "incx")?;
unsafe {
try_ffi!(sys::cublasNrm2Ex(
ctx.as_raw(),
n,
x.as_ptr() as _,
x_type.into(),
incx,
ptr::from_mut(result) as _,
result_type.into(),
execution_type.into(),
))?;
}
Ok(())
}
/// Wraps `cublasDotEx`, letting cuBLAS write its output into the host value `result`.
///
/// Per the cuBLAS reference this routine computes the dot product of `x` and `y`.
/// The type tags are forwarded to cuBLAS unchanged. Both strided views must
/// describe the same number of elements; when that number is zero the call is
/// skipped and `result` is left untouched.
///
/// NOTE(review): nothing here checks that the `DataType` tags agree with
/// `TX` / `TY` / `TResult`; presumably the caller must guarantee that — confirm.
///
/// # Errors
///
/// Returns [`Error::InvalidIncrement`] when either increment is zero, and
/// propagates any failure from binding the context, the host-pointer-mode
/// check, the length and stride conversions (including a length mismatch
/// between `x` and `y`), or the cuBLAS call itself.
pub fn dot_ex<TX, TY, TResult>(
    ctx: &Context,
    x: &DeviceMemory<TX>,
    x_type: DataType,
    incx: usize,
    y: &DeviceMemory<TY>,
    y_type: DataType,
    incy: usize,
    result: &mut TResult,
    result_type: DataType,
    execution_type: DataType,
) -> Result<()> {
    ctx.bind()?;
    ctx.require_host_pointer_mode()?;

    if [incx, incy].contains(&0) {
        return Err(Error::InvalidIncrement);
    }
    let count = match vector_binary_len(x.len(), incx, y.len(), incy)? {
        0 => return Ok(()),
        len => len,
    };
    let stride_x = to_i32(incx, "incx")?;
    let stride_y = to_i32(incy, "incy")?;

    // SAFETY: all three buffers are borrowed for the whole call and `count`
    // was derived from their lengths and strides. The element sizes cuBLAS
    // assumes come from the caller-supplied type tags, which are not verified.
    unsafe {
        try_ffi!(sys::cublasDotEx(
            ctx.as_raw(),
            count,
            x.as_ptr() as _,
            x_type.into(),
            stride_x,
            y.as_ptr() as _,
            y_type.into(),
            stride_y,
            ptr::from_mut(result) as _,
            result_type.into(),
            execution_type.into(),
        ))?;
    }
    Ok(())
}
/// Wraps `cublasDotcEx`, letting cuBLAS write its output into the host value `result`.
///
/// Per the cuBLAS reference this is the conjugated variant of the dot product.
/// The type tags are forwarded to cuBLAS unchanged. Both strided views must
/// describe the same number of elements; when that number is zero the call is
/// skipped and `result` is left untouched.
///
/// NOTE(review): nothing here checks that the `DataType` tags agree with
/// `TX` / `TY` / `TResult`; presumably the caller must guarantee that — confirm.
///
/// # Errors
///
/// Returns [`Error::InvalidIncrement`] when either increment is zero, and
/// propagates any failure from binding the context, the host-pointer-mode
/// check, the length and stride conversions (including a length mismatch
/// between `x` and `y`), or the cuBLAS call itself.
pub fn dotc_ex<TX, TY, TResult>(
    ctx: &Context,
    x: &DeviceMemory<TX>,
    x_type: DataType,
    incx: usize,
    y: &DeviceMemory<TY>,
    y_type: DataType,
    incy: usize,
    result: &mut TResult,
    result_type: DataType,
    execution_type: DataType,
) -> Result<()> {
    ctx.bind()?;
    ctx.require_host_pointer_mode()?;

    if [incx, incy].contains(&0) {
        return Err(Error::InvalidIncrement);
    }
    let count = match vector_binary_len(x.len(), incx, y.len(), incy)? {
        0 => return Ok(()),
        len => len,
    };
    let stride_x = to_i32(incx, "incx")?;
    let stride_y = to_i32(incy, "incy")?;

    // SAFETY: all three buffers are borrowed for the whole call and `count`
    // was derived from their lengths and strides. The element sizes cuBLAS
    // assumes come from the caller-supplied type tags, which are not verified.
    unsafe {
        try_ffi!(sys::cublasDotcEx(
            ctx.as_raw(),
            count,
            x.as_ptr() as _,
            x_type.into(),
            stride_x,
            y.as_ptr() as _,
            y_type.into(),
            stride_y,
            ptr::from_mut(result) as _,
            result_type.into(),
            execution_type.into(),
        ))?;
    }
    Ok(())
}
/// Applies a plane rotation to the strided vectors `x` and `y` through `cublasDrot_v2`.
///
/// The test in this file expects each pair to become `x ← c·x + s·y` and
/// `y ← c·y − s·x`. Both strided views must describe the same number of
/// elements; when that number is zero nothing is submitted to cuBLAS.
///
/// # Errors
///
/// Returns [`Error::InvalidIncrement`] when either increment is zero, and
/// propagates any failure from binding the context, the host-pointer-mode
/// check, the length and stride conversions (including a length mismatch
/// between `x` and `y`), or the cuBLAS call itself.
pub fn drot(
    ctx: &Context,
    x: &mut DeviceMemory<f64>,
    incx: usize,
    y: &mut DeviceMemory<f64>,
    incy: usize,
    c: &f64,
    s: &f64,
) -> Result<()> {
    ctx.bind()?;
    ctx.require_host_pointer_mode()?;

    if [incx, incy].contains(&0) {
        return Err(Error::InvalidIncrement);
    }
    let count = match vector_binary_len(x.len(), incx, y.len(), incy)? {
        0 => return Ok(()),
        len => len,
    };
    let stride_x = to_i32(incx, "incx")?;
    let stride_y = to_i32(incy, "incy")?;

    // SAFETY: `x`, `y`, `c` and `s` are borrowed for the whole call, and
    // `count` was derived from the buffer lengths and strides above.
    unsafe {
        try_ffi!(sys::cublasDrot_v2(
            ctx.as_raw(),
            count,
            x.as_mut_ptr(),
            stride_x,
            y.as_mut_ptr(),
            stride_y,
            c,
            s,
        ))?;
    }
    Ok(())
}
/// Wraps `cublasRotEx`, the type-tagged form of the plane rotation on `x` and `y`.
///
/// `c` and `s` are passed as host pointers and described to cuBLAS by
/// `cs_type`; the remaining type tags are forwarded unchanged. Both strided
/// views must describe the same number of elements; when that number is zero
/// nothing is submitted to cuBLAS.
///
/// NOTE(review): nothing here checks that the `DataType` tags agree with
/// `TX` / `TY` / `TCS`; presumably the caller must guarantee that — confirm.
///
/// # Errors
///
/// Returns [`Error::InvalidIncrement`] when either increment is zero, and
/// propagates any failure from binding the context, the host-pointer-mode
/// check, the length and stride conversions (including a length mismatch
/// between `x` and `y`), or the cuBLAS call itself.
pub fn rot_ex<TX, TY, TCS>(
    ctx: &Context,
    x: &mut DeviceMemory<TX>,
    x_type: DataType,
    incx: usize,
    y: &mut DeviceMemory<TY>,
    y_type: DataType,
    incy: usize,
    c: &TCS,
    s: &TCS,
    cs_type: DataType,
    execution_type: DataType,
) -> Result<()> {
    ctx.bind()?;
    ctx.require_host_pointer_mode()?;

    if [incx, incy].contains(&0) {
        return Err(Error::InvalidIncrement);
    }
    let count = match vector_binary_len(x.len(), incx, y.len(), incy)? {
        0 => return Ok(()),
        len => len,
    };
    let stride_x = to_i32(incx, "incx")?;
    let stride_y = to_i32(incy, "incy")?;

    // SAFETY: every pointer comes from a borrow that lasts for the whole call
    // and `count` was derived from the buffer lengths and strides. The element
    // sizes cuBLAS assumes come from the caller-supplied type tags, which are
    // not verified here.
    unsafe {
        try_ffi!(sys::cublasRotEx(
            ctx.as_raw(),
            count,
            x.as_mut_ptr() as _,
            x_type.into(),
            stride_x,
            y.as_mut_ptr() as _,
            y_type.into(),
            stride_y,
            ptr::from_ref(c) as _,
            ptr::from_ref(s) as _,
            cs_type.into(),
            execution_type.into(),
        ))?;
    }
    Ok(())
}
/// Wraps `cublasDrotg_v2`, passing all four scalars as mutable host pointers.
///
/// Per the cuBLAS reference this routine constructs a Givens rotation from
/// `a` and `b`; cuBLAS may write to any of `a`, `b`, `c` and `s`.
///
/// # Errors
///
/// Propagates any failure from binding the context, the host-pointer-mode
/// check, or the cuBLAS call itself.
pub fn drotg(ctx: &Context, a: &mut f64, b: &mut f64, c: &mut f64, s: &mut f64) -> Result<()> {
    ctx.bind()?;
    ctx.require_host_pointer_mode()?;

    // SAFETY: each pointer is taken from a distinct exclusive borrow that
    // stays alive until the call returns.
    unsafe {
        try_ffi!(sys::cublasDrotg_v2(
            ctx.as_raw(),
            &raw mut *a,
            &raw mut *b,
            &raw mut *c,
            &raw mut *s,
        ))?;
    }
    Ok(())
}
/// Wraps `cublasDrotm_v2`, applying the transform described by `param` to `x` and `y`.
///
/// `param` is handed to cuBLAS as a pointer to its five `f64` values; see the
/// cuBLAS reference (modified Givens rotation) for how they are interpreted.
/// Both strided views must describe the same number of elements; when that
/// number is zero nothing is submitted to cuBLAS.
///
/// # Errors
///
/// Returns [`Error::InvalidIncrement`] when either increment is zero, and
/// propagates any failure from binding the context, the host-pointer-mode
/// check, the length and stride conversions (including a length mismatch
/// between `x` and `y`), or the cuBLAS call itself.
pub fn drotm(
    ctx: &Context,
    x: &mut DeviceMemory<f64>,
    incx: usize,
    y: &mut DeviceMemory<f64>,
    incy: usize,
    param: &[f64; 5],
) -> Result<()> {
    ctx.bind()?;
    ctx.require_host_pointer_mode()?;

    if [incx, incy].contains(&0) {
        return Err(Error::InvalidIncrement);
    }
    let count = match vector_binary_len(x.len(), incx, y.len(), incy)? {
        0 => return Ok(()),
        len => len,
    };
    let stride_x = to_i32(incx, "incx")?;
    let stride_y = to_i32(incy, "incy")?;

    // SAFETY: `x`, `y` and `param` are borrowed for the whole call, and
    // `count` was derived from the buffer lengths and strides above.
    unsafe {
        try_ffi!(sys::cublasDrotm_v2(
            ctx.as_raw(),
            count,
            x.as_mut_ptr(),
            stride_x,
            y.as_mut_ptr(),
            stride_y,
            param.as_ptr(),
        ))?;
    }
    Ok(())
}
/// Wraps `cublasDrotmg_v2`, passing every argument as a host pointer.
///
/// Per the cuBLAS reference this routine constructs the modified Givens
/// transform. `d1`, `d2`, `x1` and `param` are handed over as mutable
/// pointers, `y1` as a read-only one.
///
/// # Errors
///
/// Propagates any failure from binding the context, the host-pointer-mode
/// check, or the cuBLAS call itself.
pub fn drotmg(
    ctx: &Context,
    d1: &mut f64,
    d2: &mut f64,
    x1: &mut f64,
    y1: &f64,
    param: &mut [f64; 5],
) -> Result<()> {
    ctx.bind()?;
    ctx.require_host_pointer_mode()?;

    // SAFETY: each pointer is taken from a borrow that stays alive until the
    // call returns, and the mutable ones come from exclusive borrows.
    unsafe {
        try_ffi!(sys::cublasDrotmg_v2(
            ctx.as_raw(),
            &raw mut *d1,
            &raw mut *d2,
            &raw mut *x1,
            &raw const *y1,
            param.as_mut_ptr(),
        ))?;
    }
    Ok(())
}
pub fn scal_ex<TAlpha, TX>(
ctx: &Context,
alpha: &TAlpha,
alpha_type: DataType,
x: &mut DeviceMemory<TX>,
x_type: DataType,
incx: usize,
execution_type: DataType,
) -> Result<()> {
ctx.bind()?;
ctx.require_host_pointer_mode()?;
if incx == 0 {
return Err(Error::InvalidIncrement);
}
let n = vector_len(x.len(), incx)?;
if n == 0 {
return Ok(());
}
let incx = to_i32(incx, "incx")?;
unsafe {
try_ffi!(sys::cublasScalEx(
ctx.as_raw(),
n,
ptr::from_ref(alpha) as _,
alpha_type.into(),
x.as_mut_ptr() as _,
x_type.into(),
incx,
execution_type.into(),
))?;
}
Ok(())
}
/// Wraps `cublasAxpyEx`, which per the cuBLAS reference updates `y ← alpha·x + y`.
///
/// `alpha` is passed as a host pointer described by `alpha_type`; the other
/// type tags are forwarded unchanged. Both strided views must describe the
/// same number of elements; when that number is zero nothing is submitted to
/// cuBLAS.
///
/// NOTE(review): nothing here checks that the `DataType` tags agree with
/// `TAlpha` / `TX` / `TY`; presumably the caller must guarantee that — confirm.
///
/// # Errors
///
/// Returns [`Error::InvalidIncrement`] when either increment is zero, and
/// propagates any failure from binding the context, the host-pointer-mode
/// check, the length and stride conversions (including a length mismatch
/// between `x` and `y`), or the cuBLAS call itself.
pub fn axpy_ex<TAlpha, TX, TY>(
    ctx: &Context,
    alpha: &TAlpha,
    alpha_type: DataType,
    x: &DeviceMemory<TX>,
    x_type: DataType,
    incx: usize,
    y: &mut DeviceMemory<TY>,
    y_type: DataType,
    incy: usize,
    execution_type: DataType,
) -> Result<()> {
    ctx.bind()?;
    ctx.require_host_pointer_mode()?;

    if [incx, incy].contains(&0) {
        return Err(Error::InvalidIncrement);
    }
    let count = match vector_binary_len(x.len(), incx, y.len(), incy)? {
        0 => return Ok(()),
        len => len,
    };
    let stride_x = to_i32(incx, "incx")?;
    let stride_y = to_i32(incy, "incy")?;

    // SAFETY: `alpha`, `x` and `y` are borrowed for the whole call and `count`
    // was derived from the buffer lengths and strides. The element sizes
    // cuBLAS assumes come from the caller-supplied type tags, which are not
    // verified here.
    unsafe {
        try_ffi!(sys::cublasAxpyEx(
            ctx.as_raw(),
            count,
            ptr::from_ref(alpha) as _,
            alpha_type.into(),
            x.as_ptr() as _,
            x_type.into(),
            stride_x,
            y.as_mut_ptr() as _,
            y_type.into(),
            stride_y,
            execution_type.into(),
        ))?;
    }
    Ok(())
}
/// Wraps `cublasCopyEx`, which per the cuBLAS reference copies `x` into `y`.
///
/// No host scalar is involved, so unlike the scalar-taking wrappers in this
/// file the pointer mode is not checked. Both strided views must describe the
/// same number of elements; when that number is zero nothing is submitted to
/// cuBLAS.
///
/// NOTE(review): nothing here checks that `x_type` / `y_type` agree with
/// `TX` / `TY`; presumably the caller must guarantee that — confirm.
///
/// # Errors
///
/// Returns [`Error::InvalidIncrement`] when either increment is zero, and
/// propagates any failure from binding the context, the length and stride
/// conversions (including a length mismatch between `x` and `y`), or the
/// cuBLAS call itself.
pub fn copy_ex<TX, TY>(
    ctx: &Context,
    x: &DeviceMemory<TX>,
    x_type: DataType,
    incx: usize,
    y: &mut DeviceMemory<TY>,
    y_type: DataType,
    incy: usize,
) -> Result<()> {
    ctx.bind()?;

    if [incx, incy].contains(&0) {
        return Err(Error::InvalidIncrement);
    }
    let count = match vector_binary_len(x.len(), incx, y.len(), incy)? {
        0 => return Ok(()),
        len => len,
    };
    let stride_x = to_i32(incx, "incx")?;
    let stride_y = to_i32(incy, "incy")?;

    // SAFETY: `x` and `y` are borrowed for the whole call and `count` was
    // derived from their lengths and strides. The element sizes cuBLAS assumes
    // come from the caller-supplied type tags, which are not verified here.
    unsafe {
        try_ffi!(sys::cublasCopyEx(
            ctx.as_raw(),
            count,
            x.as_ptr() as _,
            x_type.into(),
            stride_x,
            y.as_mut_ptr() as _,
            y_type.into(),
            stride_y,
        ))?;
    }
    Ok(())
}
/// Wraps `cublasSwapEx`, which per the cuBLAS reference exchanges the elements of `x` and `y`.
///
/// No host scalar is involved, so unlike the scalar-taking wrappers in this
/// file the pointer mode is not checked. Both strided views must describe the
/// same number of elements; when that number is zero nothing is submitted to
/// cuBLAS.
///
/// NOTE(review): nothing here checks that `x_type` / `y_type` agree with
/// `TX` / `TY`; presumably the caller must guarantee that — confirm.
///
/// # Errors
///
/// Returns [`Error::InvalidIncrement`] when either increment is zero, and
/// propagates any failure from binding the context, the length and stride
/// conversions (including a length mismatch between `x` and `y`), or the
/// cuBLAS call itself.
pub fn swap_ex<TX, TY>(
    ctx: &Context,
    x: &mut DeviceMemory<TX>,
    x_type: DataType,
    incx: usize,
    y: &mut DeviceMemory<TY>,
    y_type: DataType,
    incy: usize,
) -> Result<()> {
    ctx.bind()?;

    if [incx, incy].contains(&0) {
        return Err(Error::InvalidIncrement);
    }
    let count = match vector_binary_len(x.len(), incx, y.len(), incy)? {
        0 => return Ok(()),
        len => len,
    };
    let stride_x = to_i32(incx, "incx")?;
    let stride_y = to_i32(incy, "incy")?;

    // SAFETY: `x` and `y` are exclusively borrowed for the whole call and
    // `count` was derived from their lengths and strides. The element sizes
    // cuBLAS assumes come from the caller-supplied type tags, which are not
    // verified here.
    unsafe {
        try_ffi!(sys::cublasSwapEx(
            ctx.as_raw(),
            count,
            x.as_mut_ptr() as _,
            x_type.into(),
            stride_x,
            y.as_mut_ptr() as _,
            y_type.into(),
            stride_y,
        ))?;
    }
    Ok(())
}
pub fn iamax_ex<TX>(
ctx: &Context,
x: &DeviceMemory<TX>,
x_type: DataType,
incx: usize,
) -> Result<i32> {
ctx.bind()?;
ctx.require_host_pointer_mode()?;
if incx == 0 {
return Err(Error::InvalidIncrement);
}
let n = vector_len(x.len(), incx)?;
if n == 0 {
return Ok(0);
}
let incx = to_i32(incx, "incx")?;
let mut result = 0;
unsafe {
try_ffi!(sys::cublasIamaxEx(
ctx.as_raw(),
n,
x.as_ptr() as _,
x_type.into(),
incx,
&raw mut result,
))?;
}
Ok(result)
}
pub fn iamin_ex<TX>(
ctx: &Context,
x: &DeviceMemory<TX>,
x_type: DataType,
incx: usize,
) -> Result<i32> {
ctx.bind()?;
ctx.require_host_pointer_mode()?;
if incx == 0 {
return Err(Error::InvalidIncrement);
}
let n = vector_len(x.len(), incx)?;
if n == 0 {
return Ok(0);
}
let incx = to_i32(incx, "incx")?;
let mut result = 0;
unsafe {
try_ffi!(sys::cublasIaminEx(
ctx.as_raw(),
n,
x.as_ptr() as _,
x_type.into(),
incx,
&raw mut result,
))?;
}
Ok(result)
}
pub fn asum_ex<TX, TResult>(
ctx: &Context,
x: &DeviceMemory<TX>,
x_type: DataType,
incx: usize,
result: &mut TResult,
result_type: DataType,
execution_type: DataType,
) -> Result<()> {
ctx.bind()?;
ctx.require_host_pointer_mode()?;
if incx == 0 {
return Err(Error::InvalidIncrement);
}
let n = vector_len(x.len(), incx)?;
if n == 0 {
return Ok(());
}
let incx = to_i32(incx, "incx")?;
unsafe {
try_ffi!(sys::cublasAsumEx(
ctx.as_raw(),
n,
x.as_ptr() as _,
x_type.into(),
incx,
ptr::from_mut(result) as _,
result_type.into(),
execution_type.into(),
))?;
}
Ok(())
}
fn vector_len(length: usize, inc: usize) -> Result<i32> {
if length == 0 {
return Ok(0);
}
let n = 1 + (length - 1) / inc;
i32::try_from(n).map_err(|_| Error::OutOfRange { name: "n".into() })
}
fn vector_binary_len(x_length: usize, incx: usize, y_length: usize, incy: usize) -> Result<i32> {
let x_n = vector_len(x_length, incx)?;
let y_n = vector_len(y_length, incy)?;
if x_n != y_n {
return Err(Error::InvalidVectorShape);
}
Ok(x_n)
}
#[cfg(all(test, feature = "testing"))]
mod tests {
    use super::*;
    use crate::testing::setup_context;

    /// `sscal` with unit stride multiplies every element by `alpha`.
    #[test]
    fn test_sscal_scales_vector() -> Result<()> {
        let ctx = setup_context()?;
        let mut x = DeviceMemory::from_slice(&[1.0_f32, 2.0, 3.0, 4.0])?;

        sscal(&ctx, &2.5, &mut x, 1)?;

        let result = x.copy_to_host_vec()?;
        assert_eq!(result, vec![2.5, 5.0, 7.5, 10.0]);
        Ok(())
    }

    /// `drot` with `c = 2.1`, `s = 1.2` yields `x ← c·x + s·y` and `y ← c·y − s·x`.
    #[test]
    fn test_drot_rotates_vectors() -> Result<()> {
        let ctx = setup_context()?;
        let mut x = DeviceMemory::from_slice(&[1.0_f64, 2.0, 3.0, 4.0])?;
        let mut y = DeviceMemory::from_slice(&[5.0_f64, 6.0, 7.0, 8.0])?;

        drot(&ctx, &mut x, 1, &mut y, 1, &2.1, &1.2)?;

        let x_result = x.copy_to_host_vec()?;
        let y_result = y.copy_to_host_vec()?;
        let x_expected = [8.1, 11.4, 14.7, 18.0];
        let y_expected = [9.3, 10.2, 11.1, 12.0];

        // `zip` stops at the shorter side, so pin the lengths first; otherwise
        // a truncated download would pass the element checks vacuously.
        assert_eq!(x_result.len(), x_expected.len());
        assert_eq!(y_result.len(), y_expected.len());

        for (index, (actual, expected)) in x_result.iter().zip(x_expected).enumerate() {
            assert!(
                (actual - expected).abs() < 1.0e-12,
                "x[{index}]: got {actual}, expected {expected}"
            );
        }
        for (index, (actual, expected)) in y_result.iter().zip(y_expected).enumerate() {
            assert!(
                (actual - expected).abs() < 1.0e-12,
                "y[{index}]: got {actual}, expected {expected}"
            );
        }
        Ok(())
    }

    /// The stride helpers need no device, so their arithmetic is checked directly.
    #[test]
    fn test_vector_len_helpers_count_strided_elements() {
        assert!(matches!(vector_len(0, 3), Ok(0)));
        assert!(matches!(vector_len(4, 1), Ok(4)));
        assert!(matches!(vector_len(5, 2), Ok(3)));
        assert!(matches!(vector_len(3, 10), Ok(1)));
        assert!(matches!(vector_binary_len(4, 1, 8, 2), Ok(4)));
        assert!(matches!(
            vector_binary_len(4, 1, 5, 1),
            Err(Error::InvalidVectorShape)
        ));
    }
}