use std::ffi::{CString, CStr};
use std::iter::FusedIterator;
use crate::c_api::{mts_block_t, mts_array_t, MTS_INVALID_PARAMETER_ERROR};
use crate::{ArrayRef, ArrayRefMut, Labels, Error};
use super::{TensorBlockRef, LazyMetadata};
use super::block_ref::{get_samples, get_components, get_properties};
/// Mutable reference to a block, stored as a raw `mts_block_t` pointer.
///
/// The lifetime `'a` is carried by a zero-sized `PhantomData` marker: the
/// raw pointer on its own has no lifetime attached to it.
#[derive(Debug)]
pub struct TensorBlockRefMut<'a> {
    /// Raw pointer to the underlying `mts_block_t`
    ptr: *mut mts_block_t,
    /// Makes this type behave, for the borrow checker, as if it was holding a
    /// `&'a mut mts_block_t`
    marker: std::marker::PhantomData<&'a mut mts_block_t>,
}
// SAFETY (review note): the raw pointer field is what prevents `Send` from
// being derived automatically. No `Drop` implementation for
// `TensorBlockRefMut` is visible in this file; this impl presumably also
// relies on the underlying `mts_block_t` not being tied to the thread that
// created it — TODO confirm against the C API.
unsafe impl Send for TensorBlockRefMut<'_> {}
// SAFETY (review note): in this file, every method taking `&self` goes through
// the `TensorBlockRef` view returned by `as_ref`, and every method reaching
// for the `*mut` pointer takes `&mut self`. This presumably means there is no
// mutation through a shared reference — TODO confirm for `TensorBlockRef`.
unsafe impl Sync for TensorBlockRefMut<'_> {}
/// Values and metadata of a block, as returned by
/// `TensorBlockRefMut::data_mut`.
///
/// Only the values are exposed through a mutable array reference; the three
/// sets of metadata are each wrapped in a `LazyMetadata`.
#[derive(Debug)]
pub struct TensorBlockDataMut<'a> {
    /// Mutable reference to the array of values of the block
    pub values: ArrayRefMut<'a>,
    /// Samples of the block
    pub samples: LazyMetadata<Labels>,
    /// Components of the block, one `Labels` per entry
    pub components: LazyMetadata<Vec<Labels>>,
    /// Properties of the block
    pub properties: LazyMetadata<Labels>,
}
fn block_gradient(block: *mut mts_block_t, parameter: &CStr) -> Option<*mut mts_block_t> {
let mut gradient_block = std::ptr::null_mut();
let status = unsafe { crate::c_api::mts_block_gradient(
block,
parameter.as_ptr(),
&mut gradient_block
)
};
match crate::errors::check_status(status) {
Ok(()) => Some(gradient_block),
Err(error) => {
if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
None
} else {
panic!("failed to get the gradient from a block: {:?}", error)
}
}
}
}
impl<'a> TensorBlockRefMut<'a> {
    /// Wrap a raw `mts_block_t` pointer in a `TensorBlockRefMut`.
    ///
    /// # Safety
    ///
    /// The lifetime `'a` of the returned value is chosen by the caller and is
    /// not checked against anything here. The pointer is later passed to the
    /// C API without further checks, so it is expected to point to a valid
    /// block for as long as the returned value is used.
    ///
    /// # Panics
    ///
    /// Panics if `ptr` is null.
    pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlockRefMut<'a> {
        assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");

        TensorBlockRefMut {
            ptr,
            marker: std::marker::PhantomData,
        }
    }

    /// Get the underlying raw pointer as a `*const` pointer
    pub(super) fn as_ptr(&self) -> *const mts_block_t {
        self.ptr
    }

    /// Get the underlying raw pointer as a `*mut` pointer
    pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
        self.ptr
    }

    /// Get a non-mutable `TensorBlockRef` to the same block, borrowing `self`
    /// for as long as the returned value lives.
    #[inline]
    pub fn as_ref(&self) -> TensorBlockRef<'_> {
        // SAFETY (review note): `self.ptr` was checked to be non-null in
        // `from_raw`; its validity is part of the `from_raw` contract.
        unsafe { TensorBlockRef::from_raw(self.as_ptr()) }
    }

    /// Get all the data and metadata of this block at once, with mutable
    /// access to the values.
    ///
    /// # Panics
    ///
    /// Panics in the same cases as [`TensorBlockRefMut::values_mut`].
    #[inline]
    pub fn data_mut(&mut self) -> TensorBlockDataMut<'_> {
        // The metadata wrappers only need the `*const` pointer; they are
        // created first since `values_mut` borrows `self` mutably for the
        // whole lifetime of the returned value.
        let samples = LazyMetadata::new(get_samples, self.as_ptr());
        let components = LazyMetadata::new(get_components, self.as_ptr());
        let properties = LazyMetadata::new(get_properties, self.as_ptr());

        TensorBlockDataMut {
            values: self.values_mut(),
            samples,
            components,
            properties,
        }
    }

    /// Get a mutable reference to the array of values of this block.
    ///
    /// # Panics
    ///
    /// Panics if `mts_block_data` reports an error.
    #[inline]
    pub fn values_mut(&mut self) -> ArrayRefMut<'_> {
        // start from a null array, overwritten by the C API
        let mut array = mts_array_t::null();

        let status = unsafe {
            crate::c_api::mts_block_data(self.as_mut_ptr(), &mut array)
        };
        crate::errors::check_status(status).expect("failed to get the array for a block");

        unsafe { ArrayRefMut::new(array) }
    }

    /// Get a non-mutable reference to the array of values of this block.
    ///
    /// This forwards to the `TensorBlockRef` returned by `as_ref`.
    #[inline]
    pub fn values(&self) -> ArrayRef<'_> {
        self.as_ref().values()
    }

    /// Get the samples of this block.
    ///
    /// This forwards to the `TensorBlockRef` returned by `as_ref`.
    #[inline]
    pub fn samples(&self) -> Labels {
        self.as_ref().samples()
    }

    /// Get the components of this block.
    ///
    /// This forwards to the `TensorBlockRef` returned by `as_ref`.
    #[inline]
    pub fn components(&self) -> Vec<Labels> {
        self.as_ref().components()
    }

    /// Get the properties of this block.
    ///
    /// This forwards to the `TensorBlockRef` returned by `as_ref`.
    #[inline]
    pub fn properties(&self) -> Labels {
        self.as_ref().properties()
    }

    /// Get a mutable reference to the gradient of this block with respect to
    /// `parameter`.
    ///
    /// Returns `None` when the C API answers with
    /// `MTS_INVALID_PARAMETER_ERROR` for this parameter.
    ///
    /// # Panics
    ///
    /// Panics if `parameter` can not be converted to a C string (it contains
    /// a NUL byte), or if the C API reports any other error.
    #[inline]
    pub fn gradient_mut(&mut self, parameter: &str) -> Option<TensorBlockRefMut<'_>> {
        let parameter = CString::new(parameter).expect("invalid C string");

        block_gradient(self.as_mut_ptr(), &parameter)
            // SAFETY (review note): `block_gradient` only gives back a pointer
            // after the C API call succeeded, and `from_raw` checks it for
            // null. The result borrows `self` mutably through `'_`.
            .map(|gradient_block| unsafe { TensorBlockRefMut::from_raw(gradient_block) })
    }

    /// Get an iterator over `(parameter, gradient)` pairs for this block,
    /// giving mutable access to each gradient.
    #[inline]
    pub fn gradients_mut(&mut self) -> GradientsMutIter<'_> {
        // take the `*mut` pointer first: this needs `&mut self`, while the
        // list of parameters below comes from the shared view of the block
        let block_ptr = self.as_mut_ptr();

        GradientsMutIter {
            parameters: self.as_ref().gradient_list().into_iter(),
            block: block_ptr,
        }
    }

    /// Save this block to the file at `path`.
    ///
    /// This forwards to `TensorBlockRef::save`, and returns whatever error it
    /// returns.
    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
        self.as_ref().save(path)
    }

    /// Save this block to the in-memory `buffer`.
    ///
    /// This forwards to `TensorBlockRef::save_buffer`, and returns whatever
    /// error it returns.
    pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
        self.as_ref().save_buffer(buffer)
    }
}
/// Iterator over `(parameter, gradient)` pairs of a block, created by
/// `TensorBlockRefMut::gradients_mut`.
///
/// Each gradient is yielded as a `TensorBlockRefMut`.
#[derive(Debug)]
pub struct GradientsMutIter<'a> {
    /// Names of the gradient parameters which have not been yielded yet
    parameters: std::vec::IntoIter<&'a str>,
    /// Raw pointer to the block the gradients are looked up on
    block: *mut mts_block_t,
}
impl<'a> Iterator for GradientsMutIter<'a> {
    type Item = (&'a str, TensorBlockRefMut<'a>);

    /// Yield the next parameter name together with a mutable reference to
    /// the corresponding gradient block.
    ///
    /// # Panics
    ///
    /// Panics if the parameter name can not be converted to a C string, or if
    /// no gradient can be found for it.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let parameter = self.parameters.next()?;

        let parameter_c = CString::new(parameter).expect("invalid C string");
        let block = block_gradient(self.block, &parameter_c).expect("missing gradient");
        // SAFETY (review note): `block_gradient` only gives back a pointer
        // after the C API call succeeded, and `from_raw` checks it for null.
        let block = unsafe { TensorBlockRefMut::from_raw(block) };

        Some((parameter, block))
    }

    /// Same bounds as the underlying iterator over parameter names, since
    /// `next` yields exactly one item per remaining name.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.parameters.size_hint()
    }
}
impl ExactSizeIterator for GradientsMutIter<'_> {
#[inline]
fn len(&self) -> usize {
self.parameters.len()
}
}
// `next` returns `None` whenever `parameters.next()` does, and `parameters`
// is a `std::vec::IntoIter`, which is itself a `FusedIterator`: once
// exhausted, this iterator keeps returning `None`.
impl FusedIterator for GradientsMutIter<'_> {}
#[cfg(test)]
mod tests {
    use crate::{Labels, TensorBlock};

    /// Gradients can be looked up by parameter name and iterated over through
    /// a mutable reference to the block.
    #[test]
    #[allow(clippy::float_cmp)]
    fn gradients() {
        let properties = Labels::new(["p"], &[[-2], [0], [1]]);
        let gradient_samples = Labels::new(["sample"], &[[0], [1]]);

        let mut owned_block = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 3], 1.0),
            &Labels::new(["s"], &[[0], [1]]),
            &[],
            &properties,
        ).unwrap();

        // two gradients, told apart by the value filling their arrays
        let gradient_g = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 3], -1.0),
            &gradient_samples,
            &[],
            &properties,
        ).unwrap();
        owned_block.add_gradient("g", gradient_g).unwrap();

        let gradient_f = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 3], -2.0),
            &gradient_samples,
            &[],
            &properties,
        ).unwrap();
        owned_block.add_gradient("f", gradient_f).unwrap();

        let mut block = owned_block.as_ref_mut();

        // lookup by name
        let gradient = block.gradient_mut("g").unwrap();
        assert_eq!(gradient.values().as_array()[[0, 0]], -1.0);

        let gradient = block.gradient_mut("f").unwrap();
        assert_eq!(gradient.values().as_array()[[0, 0]], -2.0);

        // there is no gradient with this name
        assert!(block.gradient_mut("h").is_none());

        // iteration over all gradients
        let mut gradients = block.gradients_mut();
        assert_eq!(gradients.len(), 2);
        assert_eq!(gradients.next().unwrap().0, "g");
        assert_eq!(gradients.next().unwrap().0, "f");
        assert!(gradients.next().is_none());
    }

    /// Values and metadata are the same through the individual accessors and
    /// through `data_mut`.
    #[test]
    fn block_data() {
        let samples = Labels::new(["samples"], &[[0], [1]]);
        let components = [Labels::new(["component"], &[[0]])];
        let properties = Labels::new(["properties"], &[[-2], [0], [1]]);

        let mut owned_block = TensorBlock::new(
            ndarray::ArrayD::from_elem(vec![2, 1, 3], 1.0),
            &samples,
            &components,
            &properties,
        ).unwrap();
        let mut block = owned_block.as_ref_mut();

        // individual accessors
        assert_eq!(block.values().as_array(), ndarray::ArrayD::from_elem(vec![2, 1, 3], 1.0));
        assert_eq!(block.samples(), samples);
        assert_eq!(block.components(), components);
        assert_eq!(block.properties(), properties);

        // everything at once
        let data = block.data_mut();
        assert_eq!(data.values.as_array(), ndarray::ArrayD::from_elem(vec![2, 1, 3], 1.0));
        assert_eq!(*data.samples, samples);
        assert_eq!(*data.components, components);
        assert_eq!(*data.properties, properties);
    }
}