use std::ffi::{CString, CStr};
use std::iter::FusedIterator;
use crate::c_api::{mts_block_t, mts_array_t, MTS_INVALID_PARAMETER_ERROR};
use crate::{ArrayRef, ArrayRefMut, Labels, Error};
use super::{TensorBlockRef, LazyMetadata};
use super::block_ref::{get_samples, get_components, get_properties};
/// Mutable reference to a single block, wrapping a raw `mts_block_t` pointer
/// coming from the metatensor C API.
///
/// The lifetime `'a` ties this value to the borrow it was created from: the
/// `PhantomData<&'a mut mts_block_t>` marker makes the type behave like an
/// exclusive `&'a mut mts_block_t` for lifetime and variance purposes.
#[derive(Debug)]
pub struct TensorBlockRefMut<'a> {
// Raw pointer to the block; `from_raw` asserts that it is not NULL.
ptr: *mut mts_block_t,
// Zero-sized marker carrying the lifetime of the exclusive borrow.
marker: std::marker::PhantomData<&'a mut mts_block_t>,
}
// SAFETY: `TensorBlockRefMut` is semantically an exclusive `&mut mts_block_t`
// (see its `PhantomData` marker), so moving it to another thread moves the
// unique access along with it.
// NOTE(review): this relies on the metatensor C library allowing a block to be
// used from a thread other than the one that created it — confirm.
unsafe impl Send for TensorBlockRefMut<'_> {}

// SAFETY: every `&self` method only reads through the pointer (`as_ptr`
// hands out a `*const`, and `as_ref` produces a read-only `TensorBlockRef`),
// while mutation requires `&mut self`.
// NOTE(review): assumes concurrent read-only C API calls on one block are
// safe — confirm.
unsafe impl Sync for TensorBlockRefMut<'_> {}
/// Mutable access to the values of a block, bundled with its metadata.
///
/// Returned by `TensorBlockRefMut::data_mut`, so the values can be modified
/// while the samples, components and properties of the same block remain
/// available.
#[derive(Debug)]
pub struct TensorBlockDataMut<'a> {
/// Mutable reference to the values array of the block
pub values: ArrayRefMut<'a>,
/// Samples of the block (wrapped in `LazyMetadata`; presumably only
/// fetched from the C API when first accessed — confirm)
pub samples: LazyMetadata<Labels>,
/// Components of the block, one `Labels` per component dimension
/// (wrapped in `LazyMetadata`)
pub components: LazyMetadata<Vec<Labels>>,
/// Properties of the block (wrapped in `LazyMetadata`)
pub properties: LazyMetadata<Labels>,
}
fn block_gradient(block: *mut mts_block_t, parameter: &CStr) -> Option<*mut mts_block_t> {
let mut gradient_block = std::ptr::null_mut();
let status = unsafe { crate::c_api::mts_block_gradient(
block,
parameter.as_ptr(),
&mut gradient_block
)
};
match crate::errors::check_status(status) {
Ok(()) => Some(gradient_block),
Err(error) => {
if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
None
} else {
panic!("failed to get the gradient from a block: {:?}", error)
}
}
}
}
impl<'a> TensorBlockRefMut<'a> {
pub(crate) unsafe fn from_raw(ptr: *mut mts_block_t) -> TensorBlockRefMut<'a> {
assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
TensorBlockRefMut {
ptr: ptr,
marker: std::marker::PhantomData,
}
}
pub(super) fn as_ptr(&self) -> *const mts_block_t {
self.ptr
}
pub(super) fn as_mut_ptr(&mut self) -> *mut mts_block_t {
self.ptr
}
#[inline]
pub fn as_ref(&self) -> TensorBlockRef<'_> {
unsafe {
TensorBlockRef::from_raw(self.as_ptr())
}
}
#[inline]
pub fn data_mut(&mut self) -> TensorBlockDataMut<'_> {
let samples = LazyMetadata::new(get_samples, self.as_ptr());
let components = LazyMetadata::new(get_components, self.as_ptr());
let properties = LazyMetadata::new(get_properties, self.as_ptr());
TensorBlockDataMut {
values: self.values_mut(),
samples: samples,
components: components,
properties: properties,
}
}
#[inline]
pub fn values_mut(&mut self) -> ArrayRefMut<'_> {
let mut array = mts_array_t::null();
unsafe {
crate::errors::check_status(crate::c_api::mts_block_data(
self.as_mut_ptr(),
&mut array
)).expect("failed to get the array for a block");
};
unsafe { ArrayRefMut::new(array) }
}
#[inline]
pub fn values(&self) -> ArrayRef<'_> {
return self.as_ref().values();
}
#[inline]
pub fn samples(&self) -> Labels {
return self.as_ref().samples();
}
#[inline]
pub fn components(&self) -> Vec<Labels> {
return self.as_ref().components();
}
#[inline]
pub fn properties(&self) -> Labels {
return self.as_ref().properties();
}
#[inline]
pub fn gradient_mut(&mut self, parameter: &str) -> Option<TensorBlockRefMut<'_>> {
let parameter = CString::new(parameter).expect("invalid C string");
block_gradient(self.as_mut_ptr(), ¶meter)
.map(|gradient_block| {
unsafe { TensorBlockRefMut::from_raw(gradient_block) }
})
}
#[inline]
pub fn gradients_mut(&mut self) -> GradientsMutIter<'_> {
let block_ptr = self.as_mut_ptr();
GradientsMutIter {
parameters: self.as_ref().gradient_list().into_iter(),
block: block_ptr,
}
}
pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
self.as_ref().save(path)
}
pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
self.as_ref().save_buffer(buffer)
}
}
/// Iterator over the gradients of a block, yielding `(parameter, gradient)`
/// pairs where `gradient` is a mutable reference to the gradient block.
///
/// Created by `TensorBlockRefMut::gradients_mut`.
#[derive(Debug)]
pub struct GradientsMutIter<'a> {
    /// Names of the gradient parameters that have not been yielded yet
    parameters: std::vec::IntoIter<&'a str>,
    /// Raw pointer to the parent block, used to look up each gradient
    block: *mut mts_block_t,
}
impl<'a> Iterator for GradientsMutIter<'a> {
    type Item = (&'a str, TensorBlockRefMut<'a>);

    /// Yield the next gradient parameter together with a mutable reference
    /// to the corresponding gradient block.
    ///
    /// # Panics
    ///
    /// Panics if the parameter contains an interior NUL byte, or if the C API
    /// cannot find or return the gradient.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.parameters.next().map(|parameter| {
            let parameter_c = CString::new(parameter).expect("invalid C string");
            // The parameter names come from the block's own gradient list
            // (see `gradients_mut`), so the lookup is expected to succeed.
            let block = block_gradient(self.block, &parameter_c).expect("missing gradient");
            // SAFETY: the pointer was just returned by the C API. Each
            // parameter is yielded once by the underlying `IntoIter`, so
            // (assuming gradient names are unique) each gradient block is
            // handed out at most once.
            let block = unsafe { TensorBlockRefMut::from_raw(block) };
            (parameter, block)
        })
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.parameters.size_hint()
    }
}
/// The number of remaining items is exactly the number of parameter names
/// left in the underlying list.
impl ExactSizeIterator for GradientsMutIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        ExactSizeIterator::len(&self.parameters)
    }
}
// `next` only maps the output of a `std::vec::IntoIter`, which is itself
// fused, so once `None` is returned it will keep being returned.
impl FusedIterator for GradientsMutIter<'_> {}
#[cfg(test)]
mod tests {
// NOTE(review): empty — this module has no unit tests yet for
// `TensorBlockRefMut` or `GradientsMutIter`.
}