use crate::call;
use crate::common::{DataType, OneDimArrayInt32, PlaceType, TwoDimArraySize};
use crate::ctypes::{
PD_Tensor, PD_TensorCopyFromCpuFloat, PD_TensorCopyFromCpuInt32, PD_TensorCopyFromCpuInt64,
PD_TensorCopyFromCpuInt8, PD_TensorCopyFromCpuUint8, PD_TensorCopyToCpuFloat,
PD_TensorCopyToCpuInt32, PD_TensorCopyToCpuInt64, PD_TensorCopyToCpuInt8,
PD_TensorCopyToCpuUint8, PD_TensorDataFloat, PD_TensorDataInt32, PD_TensorDataInt64,
PD_TensorDataInt8, PD_TensorDataUint8, PD_TensorDestroy, PD_TensorGetDataType, PD_TensorGetLod,
PD_TensorGetName, PD_TensorGetShape, PD_TensorMutableDataFloat, PD_TensorMutableDataInt32,
PD_TensorMutableDataInt64, PD_TensorMutableDataInt8, PD_TensorMutableDataUint8,
PD_TensorReshape, PD_TensorSetLod,
};
use std::borrow::Cow;
use std::ffi::CStr;
/// Wrapper around a raw `PD_Tensor` pointer.
///
/// Every method forwards the stored pointer to one of the `PD_Tensor*` functions;
/// dropping the wrapper passes the pointer to `PD_TensorDestroy`.
pub struct Tensor {
/// Raw handle handed to every native call made through this wrapper.
ptr: *mut PD_Tensor,
}
impl Tensor {
pub fn from_ptr(ptr: *mut PD_Tensor) -> Self {
Self { ptr }
}
}
impl Tensor {
    /// Sets the tensor's shape by forwarding `shape` to `PD_TensorReshape`.
    ///
    /// The slice length is passed as the number of dimensions and the slice data as
    /// the dimension array.
    pub fn reshape(&self, shape: &[i32]) {
        let rank = shape.len();
        // The native signature wants a mutable pointer; the slice itself is only
        // borrowed immutably here.
        call! { PD_TensorReshape(self.ptr, rank, shape.as_ptr() as *mut _) };
    }

    /// Returns the current shape as an owned vector of dimensions.
    ///
    /// The array returned by `PD_TensorGetShape` is wrapped in `OneDimArrayInt32`
    /// and copied out with `to_vec`.
    pub fn shape(&self) -> Vec<i32> {
        let raw_dims = call! { PD_TensorGetShape(self.ptr) };
        let dims = OneDimArrayInt32::from_ptr(raw_dims);
        dims.to_vec()
    }

    /// Returns the element type reported by `PD_TensorGetDataType`.
    pub fn data_type(&self) -> DataType {
        call!(PD_TensorGetDataType(self.ptr))
    }

    /// Returns the tensor's name as reported by `PD_TensorGetName`.
    ///
    /// The C string is converted lossily, so invalid UTF-8 sequences are replaced
    /// rather than reported as an error.
    pub fn name(&self) -> Cow<str> {
        let raw_name = call! { PD_TensorGetName(self.ptr) };
        // SAFETY: relies on `PD_TensorGetName` returning a non-null, NUL-terminated
        // string that stays alive while the tensor does — not verified here.
        let c_name = unsafe { CStr::from_ptr(raw_name) };
        c_name.to_string_lossy()
    }
}
impl Tensor {
    /// Panics unless `len` elements are enough to fill the tensor's current shape.
    ///
    /// The `PD_TensorCopyFromCpu*` functions receive only a pointer, so the slice
    /// length has to be validated on the Rust side before the call. The native copy
    /// presumably reads as many elements as the shape implies; a shorter slice
    /// would then be read out of bounds.
    fn ensure_source_len(&self, len: usize) {
        let required = self.size();
        assert!(
            len >= required,
            "source slice holds {} elements but the tensor shape requires {}",
            len,
            required
        );
    }

    /// Copies `f32` host data into the tensor via `PD_TensorCopyFromCpuFloat`.
    ///
    /// # Panics
    ///
    /// Panics if `data` has fewer elements than the tensor's current shape implies.
    pub fn copy_from_f32(&self, data: &[f32]) {
        self.ensure_source_len(data.len());
        call! {
            PD_TensorCopyFromCpuFloat(self.ptr, data.as_ptr())
        };
    }

    /// Copies `i64` host data into the tensor via `PD_TensorCopyFromCpuInt64`.
    ///
    /// # Panics
    ///
    /// Panics if `data` has fewer elements than the tensor's current shape implies.
    pub fn copy_from_i64(&self, data: &[i64]) {
        self.ensure_source_len(data.len());
        call! {
            PD_TensorCopyFromCpuInt64(self.ptr, data.as_ptr())
        };
    }

    /// Copies `i32` host data into the tensor via `PD_TensorCopyFromCpuInt32`.
    ///
    /// # Panics
    ///
    /// Panics if `data` has fewer elements than the tensor's current shape implies.
    pub fn copy_from_i32(&self, data: &[i32]) {
        self.ensure_source_len(data.len());
        call! {
            PD_TensorCopyFromCpuInt32(self.ptr, data.as_ptr())
        };
    }

    /// Copies `u8` host data into the tensor via `PD_TensorCopyFromCpuUint8`.
    ///
    /// # Panics
    ///
    /// Panics if `data` has fewer elements than the tensor's current shape implies.
    pub fn copy_from_u8(&self, data: &[u8]) {
        self.ensure_source_len(data.len());
        call! {
            PD_TensorCopyFromCpuUint8(self.ptr, data.as_ptr())
        };
    }

    /// Copies `i8` host data into the tensor via `PD_TensorCopyFromCpuInt8`.
    ///
    /// # Panics
    ///
    /// Panics if `data` has fewer elements than the tensor's current shape implies.
    pub fn copy_from_i8(&self, data: &[i8]) {
        self.ensure_source_len(data.len());
        call! {
            PD_TensorCopyFromCpuInt8(self.ptr, data.as_ptr())
        };
    }
}
impl Tensor {
    /// Number of elements implied by the current shape: the product of all dimensions.
    ///
    /// An empty shape yields 1. Each dimension is converted with `as usize`, so a
    /// negative dimension is not rejected here.
    #[inline]
    fn size(&self) -> usize {
        let mut total = 1usize;
        for dim in self.shape() {
            total *= dim as usize;
        }
        total
    }

    /// Returns `true` when the tensor's element type is known and equals `ty`.
    fn check_data_type(&self, ty: DataType) -> bool {
        let actual = self.data_type();
        if actual == DataType::Unknown {
            return false;
        }
        actual == ty
    }

    /// Returns `true` when a buffer of `size` elements can hold the whole tensor
    /// and the tensor's element type is `ty`.
    fn check(&self, size: usize, ty: DataType) -> bool {
        if size < self.size() {
            return false;
        }
        self.check_data_type(ty)
    }
}
impl Tensor {
    /// Copies the tensor's contents into `data` via `PD_TensorCopyToCpuFloat`.
    ///
    /// Returns `false` without copying when the tensor's element type is not
    /// `Float32` or `data` has fewer elements than the tensor's shape implies.
    pub fn copy_to_f32(&self, data: &mut [f32]) -> bool {
        let ok = self.check(data.len(), DataType::Float32);
        if ok {
            call! { PD_TensorCopyToCpuFloat(self.ptr, data.as_mut_ptr()) };
        }
        ok
    }

    /// Copies the tensor's contents into `data` via `PD_TensorCopyToCpuInt64`.
    ///
    /// Returns `false` without copying when the tensor's element type is not
    /// `Int64` or `data` has fewer elements than the tensor's shape implies.
    pub fn copy_to_i64(&self, data: &mut [i64]) -> bool {
        let ok = self.check(data.len(), DataType::Int64);
        if ok {
            call! { PD_TensorCopyToCpuInt64(self.ptr, data.as_mut_ptr()) };
        }
        ok
    }

    /// Copies the tensor's contents into `data` via `PD_TensorCopyToCpuInt32`.
    ///
    /// Returns `false` without copying when the tensor's element type is not
    /// `Int32` or `data` has fewer elements than the tensor's shape implies.
    pub fn copy_to_i32(&self, data: &mut [i32]) -> bool {
        let ok = self.check(data.len(), DataType::Int32);
        if ok {
            call! { PD_TensorCopyToCpuInt32(self.ptr, data.as_mut_ptr()) };
        }
        ok
    }

    /// Copies the tensor's contents into `data` via `PD_TensorCopyToCpuUint8`.
    ///
    /// Returns `false` without copying when the tensor's element type is not
    /// `Uint8` or `data` has fewer elements than the tensor's shape implies.
    pub fn copy_to_u8(&self, data: &mut [u8]) -> bool {
        let ok = self.check(data.len(), DataType::Uint8);
        if ok {
            call! { PD_TensorCopyToCpuUint8(self.ptr, data.as_mut_ptr()) };
        }
        ok
    }

    /// Copies the tensor's contents into `data` via `PD_TensorCopyToCpuInt8`.
    ///
    /// Returns `false` without copying when the tensor's element type is not
    /// `Int8` or `data` has fewer elements than the tensor's shape implies.
    pub fn copy_to_i8(&self, data: &mut [i8]) -> bool {
        // The element type must be Int8 here; checking Uint8 would reject real
        // Int8 tensors and let Uint8 tensors through to the Int8 copy routine.
        let ok = self.check(data.len(), DataType::Int8);
        if ok {
            call! { PD_TensorCopyToCpuInt8(self.ptr, data.as_mut_ptr()) };
        }
        ok
    }
}
impl Tensor {
    /// Returns the tensor's `f32` storage for `place_type` as a mutable slice.
    ///
    /// Yields `None` when the tensor's element type is not `Float32`. The slice
    /// length is the element count implied by the current shape.
    ///
    /// NOTE(review): the slice is handed out from `&self`, so nothing here prevents
    /// two overlapping mutable slices; if `place_type` is not host memory the slice
    /// presumably must not be dereferenced on the CPU — confirm.
    pub fn as_mut_slice_f32(&self, place_type: PlaceType) -> Option<&mut [f32]> {
        if !self.check_data_type(DataType::Float32) {
            return None;
        }
        let data = call! { PD_TensorMutableDataFloat(self.ptr, place_type) };
        let len = self.size();
        // SAFETY: relies on the native call returning a pointer valid for `len`
        // elements of this type — not verified here.
        Some(unsafe { std::slice::from_raw_parts_mut(data, len) })
    }

    /// Returns the tensor's `i64` storage for `place_type` as a mutable slice.
    ///
    /// Yields `None` when the tensor's element type is not `Int64`. The slice
    /// length is the element count implied by the current shape.
    pub fn as_mut_slice_i64(&self, place_type: PlaceType) -> Option<&mut [i64]> {
        if !self.check_data_type(DataType::Int64) {
            return None;
        }
        let data = call! { PD_TensorMutableDataInt64(self.ptr, place_type) };
        let len = self.size();
        // SAFETY: relies on the native call returning a pointer valid for `len`
        // elements of this type — not verified here.
        Some(unsafe { std::slice::from_raw_parts_mut(data, len) })
    }

    /// Returns the tensor's `i32` storage for `place_type` as a mutable slice.
    ///
    /// Yields `None` when the tensor's element type is not `Int32`. The slice
    /// length is the element count implied by the current shape.
    pub fn as_mut_slice_i32(&self, place_type: PlaceType) -> Option<&mut [i32]> {
        if !self.check_data_type(DataType::Int32) {
            return None;
        }
        let data = call! { PD_TensorMutableDataInt32(self.ptr, place_type) };
        let len = self.size();
        // SAFETY: relies on the native call returning a pointer valid for `len`
        // elements of this type — not verified here.
        Some(unsafe { std::slice::from_raw_parts_mut(data, len) })
    }

    /// Returns the tensor's `u8` storage for `place_type` as a mutable slice.
    ///
    /// Yields `None` when the tensor's element type is not `Uint8`. The slice
    /// length is the element count implied by the current shape.
    pub fn as_mut_slice_u8(&self, place_type: PlaceType) -> Option<&mut [u8]> {
        if !self.check_data_type(DataType::Uint8) {
            return None;
        }
        let data = call! { PD_TensorMutableDataUint8(self.ptr, place_type) };
        let len = self.size();
        // SAFETY: relies on the native call returning a pointer valid for `len`
        // elements of this type — not verified here.
        Some(unsafe { std::slice::from_raw_parts_mut(data, len) })
    }

    /// Returns the tensor's `i8` storage for `place_type` as a mutable slice.
    ///
    /// Yields `None` when the tensor's element type is not `Int8`. The slice
    /// length is the element count implied by the current shape.
    pub fn as_mut_slice_i8(&self, place_type: PlaceType) -> Option<&mut [i8]> {
        // The element type must be Int8 here; checking Uint8 would expose Uint8
        // storage through an `i8` view and refuse genuine Int8 tensors.
        if !self.check_data_type(DataType::Int8) {
            return None;
        }
        let data = call! { PD_TensorMutableDataInt8(self.ptr, place_type) };
        let len = self.size();
        // SAFETY: relies on the native call returning a pointer valid for `len`
        // elements of this type — not verified here.
        Some(unsafe { std::slice::from_raw_parts_mut(data, len) })
    }
}
impl Tensor {
    /// Returns the tensor's `f32` data as a read-only slice, together with the
    /// place reported by `PD_TensorDataFloat`.
    ///
    /// Yields `None` when the tensor's element type is not `Float32`. The slice
    /// length is the element count written back by the native call.
    ///
    /// NOTE(review): if the reported place is not host memory the slice presumably
    /// must not be dereferenced on the CPU — confirm.
    pub fn as_slice_f32(&self) -> Option<(PlaceType, &[f32])> {
        if !self.check_data_type(DataType::Float32) {
            return None;
        }
        let mut place_type = PlaceType::Unknown;
        let mut size = 0;
        let ptr = call! { PD_TensorDataFloat(self.ptr, &mut place_type, &mut size) };
        // SAFETY: relies on the native call returning a pointer valid for `size`
        // elements of this type — not verified here.
        let data = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
        Some((place_type, data))
    }

    /// Returns the tensor's `i64` data as a read-only slice, together with the
    /// place reported by `PD_TensorDataInt64`.
    ///
    /// Yields `None` when the tensor's element type is not `Int64`.
    pub fn as_slice_i64(&self) -> Option<(PlaceType, &[i64])> {
        if !self.check_data_type(DataType::Int64) {
            return None;
        }
        let mut place_type = PlaceType::Unknown;
        let mut size = 0;
        let ptr = call! { PD_TensorDataInt64(self.ptr, &mut place_type, &mut size) };
        // SAFETY: relies on the native call returning a pointer valid for `size`
        // elements of this type — not verified here.
        let data = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
        Some((place_type, data))
    }

    /// Returns the tensor's `i32` data as a read-only slice, together with the
    /// place reported by `PD_TensorDataInt32`.
    ///
    /// Yields `None` when the tensor's element type is not `Int32`.
    pub fn as_slice_i32(&self) -> Option<(PlaceType, &[i32])> {
        if !self.check_data_type(DataType::Int32) {
            return None;
        }
        let mut place_type = PlaceType::Unknown;
        let mut size = 0;
        let ptr = call! { PD_TensorDataInt32(self.ptr, &mut place_type, &mut size) };
        // SAFETY: relies on the native call returning a pointer valid for `size`
        // elements of this type — not verified here.
        let data = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
        Some((place_type, data))
    }

    /// Returns the tensor's `u8` data as a read-only slice, together with the
    /// place reported by `PD_TensorDataUint8`.
    ///
    /// Yields `None` when the tensor's element type is not `Uint8`.
    pub fn as_slice_u8(&self) -> Option<(PlaceType, &[u8])> {
        if !self.check_data_type(DataType::Uint8) {
            return None;
        }
        let mut place_type = PlaceType::Unknown;
        let mut size = 0;
        let ptr = call! { PD_TensorDataUint8(self.ptr, &mut place_type, &mut size) };
        // SAFETY: relies on the native call returning a pointer valid for `size`
        // elements of this type — not verified here.
        let data = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
        Some((place_type, data))
    }

    /// Returns the tensor's `i8` data as a read-only slice, together with the
    /// place reported by `PD_TensorDataInt8`.
    ///
    /// Yields `None` when the tensor's element type is not `Int8`.
    pub fn as_slice_i8(&self) -> Option<(PlaceType, &[i8])> {
        // The element type must be Int8 here; checking Uint8 would expose Uint8
        // data through an `i8` view and refuse genuine Int8 tensors.
        if !self.check_data_type(DataType::Int8) {
            return None;
        }
        let mut place_type = PlaceType::Unknown;
        let mut size = 0;
        let ptr = call! { PD_TensorDataInt8(self.ptr, &mut place_type, &mut size) };
        // SAFETY: relies on the native call returning a pointer valid for `size`
        // elements of this type — not verified here.
        let data = unsafe { std::slice::from_raw_parts(ptr, size as usize) };
        Some((place_type, data))
    }
}
impl Tensor {
    /// Passes the LoD array held by `lod` to `PD_TensorSetLod`.
    ///
    /// `lod` is taken by value and goes out of scope when this function returns.
    pub fn set_lod(&self, lod: TwoDimArraySize) {
        let raw_lod = lod.ptr;
        call!(PD_TensorSetLod(self.ptr, raw_lod));
    }

    /// Returns the LoD array reported by `PD_TensorGetLod`, wrapped in
    /// `TwoDimArraySize`.
    pub fn lod(&self) -> TwoDimArraySize {
        let raw_lod = call! { PD_TensorGetLod(self.ptr) };
        TwoDimArraySize::from_ptr(raw_lod)
    }
}
impl Drop for Tensor {
    /// Hands the wrapped pointer to `PD_TensorDestroy` when the wrapper goes away.
    fn drop(&mut self) {
        let raw = self.ptr;
        call! { PD_TensorDestroy(raw) };
    }
}