use crate::c_api::mts_block_t;
use crate::errors::check_status;
use crate::{Array, ArrayRef, Labels, Error};
use super::{TensorBlockRef, TensorBlockRefMut};
/// Owning handle to a `mts_block_t` allocated through the metatensor C API.
///
/// The wrapped pointer is released with `mts_block_free` when the handle is
/// dropped. `#[repr(transparent)]` gives this type the same size and
/// alignment as `*mut mts_block_t` (see the `check_repr` test).
#[derive(Debug)]
#[repr(transparent)]
pub struct TensorBlock {
    /// Pointer to the block owned by this handle; `from_raw` rejects NULL.
    ptr: *mut mts_block_t,
}
// SAFETY: NOTE(review): these impls assume the metatensor C API allows a
// `mts_block_t` to be moved to another thread (`Send`), and allows concurrent
// read-only use through shared references (`Sync`; mutation here requires
// `&mut TensorBlock`). Nothing in this file enforces that — confirm against
// the C API documentation.
unsafe impl Send for TensorBlock {}
unsafe impl Sync for TensorBlock {}
impl std::ops::Drop for TensorBlock {
    /// Release the underlying `mts_block_t` through the C API.
    fn drop(&mut self) {
        let raw = self.as_mut_ptr();
        // `drop` has no way to report failures, so the status returned by
        // `mts_block_free` is intentionally discarded.
        // SAFETY: `raw` comes from `from_raw` (non-NULL and owned by this
        // handle), and the handle is never used again after `drop`.
        let _ = unsafe { crate::c_api::mts_block_free(raw) };
    }
}
impl TensorBlock {
    /// Create a new `TensorBlock` taking ownership of the block behind `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a valid `mts_block_t` that nothing else owns. The
    /// returned handle passes it to `mts_block_free` when dropped, so the
    /// caller must neither free it nor wrap it in another `TensorBlock`.
    ///
    /// # Panics
    ///
    /// Panics if `ptr` is NULL.
    pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlock {
        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
        TensorBlock { ptr }
    }

    /// Get the raw pointer to the underlying block, for read-only C API calls.
    pub(super) fn as_ptr(&self) -> *const mts_block_t {
        self.ptr
    }

    /// Get the raw pointer to the underlying block, for mutating C API calls.
    pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
        self.ptr
    }

    /// Get a non-owning, read-only view of this block.
    #[inline]
    pub fn as_ref(&self) -> TensorBlockRef<'_> {
        // SAFETY: the pointer is non-NULL (checked in `from_raw`) and is only
        // freed when `self` is dropped, which cannot happen while the
        // returned view borrows `self`.
        unsafe { TensorBlockRef::from_raw(self.as_ptr()) }
    }

    /// Get a non-owning, mutable view of this block.
    #[inline]
    pub fn as_ref_mut(&mut self) -> TensorBlockRefMut<'_> {
        // SAFETY: same validity argument as `as_ref`; the exclusive borrow of
        // `self` prevents any other access while the view is alive.
        unsafe { TensorBlockRefMut::from_raw(self.as_mut_ptr()) }
    }

    /// Get the values array of this block.
    #[inline]
    pub fn values(&self) -> ArrayRef<'_> {
        return self.as_ref().values();
    }

    /// Get the samples metadata of this block.
    #[inline]
    pub fn samples(&self) -> Labels {
        return self.as_ref().samples();
    }

    /// Get the components metadata of this block, one `Labels` per component.
    #[inline]
    pub fn components(&self) -> Vec<Labels> {
        return self.as_ref().components();
    }

    /// Get the properties metadata of this block.
    #[inline]
    pub fn properties(&self) -> Labels {
        return self.as_ref().properties();
    }

    /// Create a new block from `data` and the given metadata.
    ///
    /// `data` is moved into the new block.
    ///
    /// # Errors
    ///
    /// Returns an error when `mts_block` fails, i.e. when `check_ptr` rejects
    /// the pointer it returned.
    #[inline]
    pub fn new(
        data: impl Array,
        samples: &Labels,
        components: &[Labels],
        properties: &Labels
    ) -> Result<TensorBlock, Error> {
        // The C API expects a contiguous array of `mts_labels_t`, one per
        // component.
        let c_components: Vec<_> = components
            .iter()
            .map(|component| component.as_mts_labels_t())
            .collect();

        let ptr = unsafe {
            crate::c_api::mts_block(
                (Box::new(data) as Box<dyn Array>).into(),
                samples.as_mts_labels_t(),
                c_components.as_ptr(),
                c_components.len(),
                properties.as_mts_labels_t(),
            )
        };
        crate::errors::check_ptr(ptr)?;

        // SAFETY: `check_ptr` succeeded, and the freshly created block is
        // owned by nothing else yet.
        return Ok(unsafe { TensorBlock::from_raw(ptr) });
    }

    /// Attach `gradient` to this block as the gradient with respect to
    /// `parameter`.
    ///
    /// `gradient` is handed over to the C API. It is never freed from Rust
    /// once the call is made, even if this function returns an error.
    ///
    /// # Errors
    ///
    /// Returns an error if `mts_block_add_gradient` reports a failure.
    ///
    /// # Panics
    ///
    /// Panics if `parameter` contains a NUL byte. The C side reads it as a
    /// NUL-terminated string, so the name would otherwise be silently
    /// truncated.
    // `gradient` is taken by value because its ownership is transferred.
    #[allow(clippy::needless_pass_by_value)]
    #[inline]
    pub fn add_gradient(
        &mut self,
        parameter: &str,
        gradient: TensorBlock
    ) -> Result<(), Error> {
        // Validate the name before taking ownership of `gradient`, so that
        // the gradient is still dropped normally if this panics.
        let parameter = std::ffi::CString::new(parameter)
            .expect("gradient parameter name should not contain NUL bytes");

        // The C API takes ownership of the gradient block. `ManuallyDrop`
        // keeps `TensorBlock::drop` from freeing it a second time.
        let mut gradient = std::mem::ManuallyDrop::new(gradient);
        let gradient_ptr = gradient.as_mut_ptr();

        unsafe {
            check_status(crate::c_api::mts_block_add_gradient(
                self.as_mut_ptr(),
                parameter.as_ptr().cast(),
                gradient_ptr,
            ))?;
        }
        return Ok(());
    }

    /// Load a serialized block from the file at `path`.
    ///
    /// This is a thin wrapper around `crate::io::load_block`.
    pub fn load(path: impl AsRef<std::path::Path>) -> Result<TensorBlock, Error> {
        return crate::io::load_block(path);
    }

    /// Load a serialized block from an in-memory `buffer`.
    ///
    /// This is a thin wrapper around `crate::io::load_block_buffer`.
    pub fn load_buffer(buffer: &[u8]) -> Result<TensorBlock, Error> {
        return crate::io::load_block_buffer(buffer);
    }

    /// Serialize this block to the file at `path`.
    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
        self.as_ref().save(path)
    }

    /// Serialize this block into `buffer`.
    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
        self.as_ref().save_buffer(buffer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::c_api::mts_block_t;

    /// A block built from known data and metadata hands the same data and
    /// metadata back through its accessors.
    #[test]
    fn block() {
        let values = ndarray::ArrayD::from_elem(vec![2, 1, 3], 1.0);
        let samples = Labels::new(["samples"], &[[0], [1]]);
        let components = [Labels::new(["component"], &[[0]])];
        let properties = Labels::new(["properties"], &[[-2], [0], [1]]);

        let block = TensorBlock::new(values.clone(), &samples, &components, &properties).unwrap();

        assert_eq!(block.values().as_array(), values);
        assert_eq!(block.samples(), samples);
        assert_eq!(block.components(), components);
        assert_eq!(block.properties(), properties);
    }

    /// `TensorBlock` must stay layout-compatible with a raw block pointer.
    #[test]
    fn check_repr() {
        use std::mem::{align_of, size_of};

        assert_eq!(size_of::<TensorBlock>(), size_of::<*mut mts_block_t>());
        assert_eq!(align_of::<TensorBlock>(), align_of::<*mut mts_block_t>());
    }
}