use std::ffi::{CStr, CString};
use std::iter::FusedIterator;
use crate::c_api::{mts_block_t, mts_array_t, mts_labels_t};
use crate::c_api::MTS_INVALID_PARAMETER_ERROR;
use crate::errors::check_status;
use crate::{ArrayRef, Labels, Error};
use super::{TensorBlock, LazyMetadata};
/// Lightweight, copyable handle to a `mts_block_t` managed through the C API.
///
/// The handle only stores a raw pointer; the lifetime `'a` is attached with
/// `PhantomData` so that the borrow checker treats it like a `&'a mts_block_t`.
/// The only constructor visible in this file (`from_raw`) asserts that the
/// pointer is not NULL.
#[derive(Debug, Clone, Copy)]
pub struct TensorBlockRef<'a> {
/// Raw pointer to the underlying block, handed as-is to the C API functions.
ptr: *const mts_block_t,
/// Zero-sized marker tying this handle to the lifetime `'a`.
marker: std::marker::PhantomData<&'a mts_block_t>,
}
// Raw pointers are neither `Send` nor `Sync` by default, so both markers are
// opted into explicitly.
// NOTE(review): soundness relies on the C library allowing read access to a
// `mts_block_t` from several threads — confirm against the C API guarantees.
unsafe impl Send for TensorBlockRef<'_> {}
unsafe impl Sync for TensorBlockRef<'_> {}
/// All the data and metadata of a block, as produced by `TensorBlockRef::data`.
///
/// The values are stored as an `ArrayRef`, while each set of labels is
/// wrapped in `LazyMetadata` (presumably so that the labels are only fetched
/// from the C API when first accessed — see `LazyMetadata` to confirm).
#[derive(Debug)]
pub struct TensorBlockData<'a> {
/// Reference to the array holding the values of the block.
pub values: ArrayRef<'a>,
/// Labels for the first dimension of `values`.
pub samples: LazyMetadata<Labels>,
/// Labels for every dimension of `values` between the first and the last.
pub components: LazyMetadata<Vec<Labels>>,
/// Labels for the last dimension of `values`.
pub properties: LazyMetadata<Labels>,
}
impl<'a> TensorBlockRef<'a> {
pub(crate) unsafe fn from_raw(ptr: *const mts_block_t) -> TensorBlockRef<'a> {
assert!(!ptr.is_null(), "pointer to mts_block_t should not be NULL");
TensorBlockRef {
ptr: ptr,
marker: std::marker::PhantomData,
}
}
pub(crate) fn as_ptr(&self) -> *const mts_block_t {
self.ptr
}
}
fn block_gradient(block: *const mts_block_t, parameter: &CStr) -> Option<*const mts_block_t> {
let mut gradient_block = std::ptr::null_mut();
let status = unsafe { crate::c_api::mts_block_gradient(
block.cast_mut(),
parameter.as_ptr(),
&mut gradient_block
)
};
match crate::errors::check_status(status) {
Ok(()) => Some(gradient_block.cast_const()),
Err(error) => {
if error.code == Some(MTS_INVALID_PARAMETER_ERROR) {
None
} else {
panic!("failed to get the gradient from a block: {:?}", error)
}
}
}
}
/// Free-function version of `TensorBlockRef::samples` working on a raw
/// pointer; `TensorBlockRef::data` passes it to `LazyMetadata::new`.
///
/// NOTE(review): this is a safe function wrapping an unsafe call, `ptr` is
/// assumed to point to a valid block. Panics if `ptr` is NULL.
pub(super) fn get_samples(ptr: *const mts_block_t) -> Labels {
    let block = unsafe { TensorBlockRef::from_raw(ptr) };
    block.samples()
}
/// Free-function version of `TensorBlockRef::components` working on a raw
/// pointer; `TensorBlockRef::data` passes it to `LazyMetadata::new`.
///
/// NOTE(review): this is a safe function wrapping an unsafe call, `ptr` is
/// assumed to point to a valid block. Panics if `ptr` is NULL.
pub(super) fn get_components(ptr: *const mts_block_t) -> Vec<Labels> {
    let block = unsafe { TensorBlockRef::from_raw(ptr) };
    block.components()
}
/// Free-function version of `TensorBlockRef::properties` working on a raw
/// pointer; `TensorBlockRef::data` passes it to `LazyMetadata::new`.
///
/// NOTE(review): this is a safe function wrapping an unsafe call, `ptr` is
/// assumed to point to a valid block. Panics if `ptr` is NULL.
pub(super) fn get_properties(ptr: *const mts_block_t) -> Labels {
    let block = unsafe { TensorBlockRef::from_raw(ptr) };
    block.properties()
}
impl<'a> TensorBlockRef<'a> {
#[inline]
pub fn data(&'a self) -> TensorBlockData<'a> {
TensorBlockData {
values: self.values(),
samples: LazyMetadata::new(get_samples, self.as_ptr()),
components: LazyMetadata::new(get_components, self.as_ptr()),
properties: LazyMetadata::new(get_properties, self.as_ptr()),
}
}
#[inline]
pub fn values(&self) -> ArrayRef<'a> {
let mut array = mts_array_t::null();
unsafe {
crate::errors::check_status(crate::c_api::mts_block_data(
self.as_ptr().cast_mut(),
&mut array
)).expect("failed to get the array for a block");
};
unsafe { ArrayRef::from_raw(array) }
}
#[inline]
fn labels(&self, dimension: usize) -> Labels {
let mut labels = mts_labels_t::null();
unsafe {
check_status(crate::c_api::mts_block_labels(
self.as_ptr(),
dimension,
&mut labels,
)).expect("failed to get labels");
}
return unsafe { Labels::from_raw(labels) };
}
#[inline]
pub fn samples(&self) -> Labels {
return self.labels(0);
}
#[inline]
pub fn components(&self) -> Vec<Labels> {
let values = self.values();
let shape = values.as_raw().shape().expect("failed to get the data shape");
let mut result = Vec::new();
for i in 1..(shape.len() - 1) {
result.push(self.labels(i));
}
return result;
}
#[inline]
pub fn properties(&self) -> Labels {
let values = self.values();
let shape = values.as_raw().shape().expect("failed to get the data shape");
return self.labels(shape.len() - 1);
}
#[inline]
pub fn gradient_list(&self) -> Vec<&'a str> {
let mut parameters_ptr = std::ptr::null();
let mut parameters_count = 0;
unsafe {
check_status(crate::c_api::mts_block_gradients_list(
self.as_ptr(),
&mut parameters_ptr,
&mut parameters_count
)).expect("failed to get gradient list");
}
if parameters_count == 0 {
return Vec::new();
} else {
assert!(!parameters_ptr.is_null());
unsafe {
let parameters = std::slice::from_raw_parts(parameters_ptr, parameters_count);
return parameters.iter()
.map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap())
.collect();
}
}
}
#[inline]
pub fn gradient(&self, parameter: &str) -> Option<TensorBlockRef<'a>> {
let parameter = CString::new(parameter).expect("invalid C string");
block_gradient(self.as_ptr(), ¶meter)
.map(|gradient_block| {
unsafe { TensorBlockRef::from_raw(gradient_block) }
})
}
#[inline]
pub fn try_clone(&self) -> Result<TensorBlock, Error> {
let ptr = unsafe {
crate::c_api::mts_block_copy(self.as_ptr())
};
crate::errors::check_ptr(ptr)?;
return Ok(unsafe { TensorBlock::from_raw(ptr) });
}
#[inline]
pub fn gradients(&self) -> GradientsIter<'_> {
GradientsIter {
parameters: self.gradient_list().into_iter(),
block: self.as_ptr(),
}
}
pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<(), Error> {
return crate::io::save_block(path, *self);
}
pub fn save_buffer(&self, buffer: &mut Vec<u8>) -> Result<(), Error> {
return crate::io::save_block_buffer(*self, buffer);
}
}
/// Iterator over the `(parameter, gradient block)` pairs of a block, created
/// by `TensorBlockRef::gradients`.
pub struct GradientsIter<'a> {
/// Gradient parameters which have not been yielded yet.
parameters: std::vec::IntoIter<&'a str>,
/// Raw pointer to the block the gradients are taken from.
block: *const mts_block_t,
}
impl<'a> Iterator for GradientsIter<'a> {
    type Item = (&'a str, TensorBlockRef<'a>);

    /// Yield the next gradient parameter together with the matching
    /// gradient block.
    ///
    /// # Panics
    ///
    /// Panics if the parameter contains a NUL byte, or if `block_gradient`
    /// finds no gradient for a parameter taken from the gradient list.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let parameter = self.parameters.next()?;

        let parameter_c = CString::new(parameter).expect("invalid C string");
        let gradient = block_gradient(self.block, &parameter_c).expect("missing gradient");
        let gradient = unsafe { TensorBlockRef::from_raw(gradient) };

        Some((parameter, gradient))
    }

    /// The number of remaining items is known exactly: it is the number of
    /// parameters left to process.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.len();
        (remaining, Some(remaining))
    }
}
impl ExactSizeIterator for GradientsIter<'_> {
    /// Number of gradients left to yield, i.e. the number of parameters not
    /// yet consumed.
    #[inline]
    fn len(&self) -> usize {
        let remaining = &self.parameters;
        remaining.len()
    }
}
// `next` returns `None` exactly when the inner `std::vec::IntoIter` does, and
// that iterator is itself fused.
impl FusedIterator for GradientsIter<'_> {}
// NOTE(review): this test module is empty, none of the code above is covered
// by unit tests defined here.
#[cfg(test)]
mod tests {
}