use core::ffi::CStr;
use std::ptr::NonNull;
use core::ops::DerefMut;
use crate::convert::FromMatlab;
use rustmex_core::{
convert::{
FromMatlabError,
},
classid::ClassID,
shim::{
rustmex_get_field_number as mxGetFieldNumber,
rustmex_get_field_by_number as mxGetFieldByNumber,
rustmex_set_field_by_number as mxSetFieldByNumber,
rustmex_get_number_of_fields as mxGetNumberOfFields,
rustmex_get_field_name_by_number as mxGetFieldNameByNumber,
rustmex_create_struct_array as mxCreateStructArray,
rustmex_add_field as mxAddField,
rustmex_remove_field as mxRemoveField,
},
mxArray,
pointers::{
MxArray,
MatlabPtr,
MutMatlabPtr,
},
MatlabClass,
MutMatlabClass,
OwnedMatlabClass,
NewEmpty,
};
pub use super::index::Index;
/// Errors reported by the struct-array accessors in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum StructError {
/// The requested field (given by name or by number) is not a field of the
/// struct array.
NotAField,
/// The requested element index could not be resolved to an element of the
/// struct array.
OutOfBounds,
}
/// Wrapper around a pointer-like value `P` that refers to a struct array.
///
/// Through the public constructor (`MatlabClass::from_mx_array`) a value is
/// only wrapped after its class id was checked to be `ClassID::Struct`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Struct<P>(P);
impl<'a> FromMatlab<'a> for Struct<&'a mxArray> {
fn from_matlab(mx: &'a mxArray) -> Result<Self, FromMatlabError<&'a mxArray>> {
Self::from_mx_array(mx)
}
}
/// A `Struct` can be used wherever a shared reference to the underlying
/// `mxArray` is expected.
impl<P> std::ops::Deref for Struct<P>
where
    P: MatlabPtr,
{
    type Target = mxArray;

    /// Borrows the wrapped array.
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
/// With a mutable pointer type, a `Struct` also yields exclusive access to
/// the underlying `mxArray`.
impl<P> std::ops::DerefMut for Struct<P>
where
    P: MutMatlabPtr,
{
    /// Mutably borrows the wrapped array.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}
/// Something that can designate a field of a [`Struct`].
///
/// In this file it is implemented for `&CStr` (lookup by field name) and for
/// `i32` (a field number that is range-checked).
pub trait FieldIndex {
/// Resolves `self` to a field number of `s`, or returns a [`StructError`]
/// if `self` does not designate one of its fields.
fn index_into<P: MatlabPtr>(&self, s: &Struct<P>) -> Result<i32, StructError>;
}
/// Field lookup by name.
impl FieldIndex for &CStr {
    /// Resolves the name through [`Struct::field_number`].
    fn index_into<P: MatlabPtr>(&self, s: &Struct<P>) -> Result<i32, StructError> {
        let name: &CStr = *self;
        s.field_number(name)
    }
}
/// Field lookup by number: the number is accepted as-is once it is known to
/// lie in `0..num_fields`.
impl FieldIndex for i32 {
    /// Returns the number itself when it is in range and
    /// [`StructError::NotAField`] otherwise.
    fn index_into<P: MatlabPtr>(&self, s: &Struct<P>) -> Result<i32, StructError> {
        match *self {
            // Negative numbers are rejected before the field count is queried.
            negative if negative < 0 => Err(StructError::NotAField),
            too_large if too_large >= s.num_fields() => Err(StructError::NotAField),
            valid => Ok(valid),
        }
    }
}
/// Core class plumbing: checked construction, access to the wrapped pointer
/// and duplication into an owned struct.
impl<'p, P: MatlabPtr + 'p> MatlabClass<P> for Struct<P> {
    type Owned = Struct<MxArray>;

    /// Wraps `mx` when its class id is `ClassID::Struct`; any other outcome
    /// of the class query yields a bad-class error that carries `mx` back.
    fn from_mx_array(mx: P) -> Result<Self, FromMatlabError<P>> {
        if mx.class_id() != Ok(ClassID::Struct) {
            return Err(FromMatlabError::new_badclass(mx));
        }
        Ok(Self(mx))
    }

    /// Unwraps the pointer, consuming the wrapper.
    fn into_inner(self) -> P {
        let Self(ptr) = self;
        ptr
    }

    /// Borrows the wrapped pointer.
    fn inner(&self) -> &P {
        let Self(ptr) = self;
        ptr
    }

    /// Duplicates the wrapped array (via the pointer's `duplicate`) and wraps
    /// the result as an owned struct.
    fn duplicate(&self) -> Self::Owned {
        let copy = self.0.duplicate();
        Struct(copy)
    }
}
/// Extra plumbing for structs backed by a mutable pointer.
impl<'p, P: MutMatlabPtr + 'p> MutMatlabClass<P> for Struct<P> {
    type AsBorrowed<'a> = Struct<&'a mxArray> where P: 'a;

    /// Produces a read-only view of the same array, borrowed from `self`.
    fn as_borrowed<'a>(&'a self) -> Self::AsBorrowed<'a> {
        let shared: &'a mxArray = &*self.0;
        Struct(shared)
    }

    /// Mutably borrows the wrapped pointer.
    fn inner_mut(&mut self) -> &mut P {
        let Self(ptr) = self;
        ptr
    }
}
/// Shorthand for a [`Struct`] backed by an `MxArray`.
pub type OwnedStruct = Struct<MxArray>;
/// Lets an owned struct hand out a mutable view of itself.
impl OwnedMatlabClass for Struct<MxArray> {
    type AsMutable<'a> = Struct<&'a mut mxArray> where Self: 'a;

    /// Produces a mutable view of the same array, borrowed from `self`.
    fn as_mutable<'a>(&'a mut self) -> Self::AsMutable<'a> {
        let exclusive: &'a mut mxArray = &mut *self.0;
        Struct(exclusive)
    }
}
impl NewEmpty for Struct<MxArray> {
fn new_empty() -> Self {
Self::new_empty_with_fields(&[])
}
}
/// Read-only operations, available for every pointer type backing the struct.
impl<'p, P: MatlabPtr + 'p> Struct<P> {
    /// Converts this struct array into a [`ScalarStruct`].
    ///
    /// # Errors
    /// Gives `self` back unchanged when `ScalarStruct::from_struct` rejects
    /// it.
    pub fn into_scalar(self) -> Result<ScalarStruct<P>, Self> {
        ScalarStruct::from_struct(self)
    }

    /// Number of fields of this struct array.
    ///
    /// `mxGetNumberOfFields` reports `0` both when its argument is not a
    /// struct and when the struct simply has no fields. The latter is a
    /// perfectly valid state (it is what `new_empty()` creates), so a zero
    /// count must not be treated as a failure here.
    fn num_fields(&self) -> i32 {
        // SAFETY: NOTE(review) relies on `self.0` dereferencing to a live
        // `mxArray` of struct class; the class is checked in
        // `from_mx_array` — confirm for the other construction paths.
        unsafe { mxGetNumberOfFields(self.0.deref()) }
    }

    /// Looks up the number of the field called `field`.
    ///
    /// # Errors
    /// [`StructError::NotAField`] when the lookup yields a negative number,
    /// i.e. there is no such field.
    pub fn field_number(&self, field: &CStr) -> Result<i32, StructError> {
        // SAFETY: `field` is a `CStr`, so the pointer refers to a
        // NUL-terminated string for the duration of the call.
        let num = unsafe { mxGetFieldNumber(self.0.deref(), field.as_ptr()) };
        if num < 0 {
            Err(StructError::NotAField)
        } else {
            Ok(num)
        }
    }

    /// Returns the name of field number `fieldnum`.
    ///
    /// # Errors
    /// [`StructError::NotAField`] when `fieldnum` is negative or when the
    /// name lookup returns a null pointer.
    pub fn field_name(&self, fieldnum: i32) -> Result<&'p CStr, StructError> {
        if fieldnum < 0 {
            return Err(StructError::NotAField);
        }
        let name = unsafe { mxGetFieldNameByNumber(self.0.deref(), fieldnum) };
        if name.is_null() {
            Err(StructError::NotAField)
        } else {
            // SAFETY: non-null was checked above; NOTE(review) assumes the
            // returned name is NUL-terminated and stays valid for `'p` —
            // confirm against the API documentation.
            Ok(unsafe { CStr::from_ptr(name) })
        }
    }

    /// Iterates over the names of all fields, in field-number order.
    ///
    /// Yields nothing for a struct without fields.
    pub fn field_names(&self) -> impl Iterator<Item = &'p CStr> + '_ {
        let n = self.num_fields();
        (0..n).map(|idx| self.field_name(idx).unwrap())
    }

    /// Reads the value stored in `field` of element `idx`.
    ///
    /// Returns `Ok(None)` when the lookup yields a null pointer, i.e. the
    /// field holds no value for that element.
    ///
    /// # Errors
    /// [`StructError::OutOfBounds`] when `idx` cannot be resolved, or the
    /// error from resolving `field`.
    pub fn get<I: Index, F: FieldIndex>(&self, idx: I, field: F) -> Result<Option<&'p mxArray>, StructError> {
        let linidx = idx.index_into(&self.0).ok_or(StructError::OutOfBounds)?;
        let fieldidx = field.index_into(self)?;
        Ok(unsafe { NonNull::new(mxGetFieldByNumber(self.0.deref(), linidx as _, fieldidx)) }
            .map(|mx| unsafe { mx.as_ref() }))
    }

    /// Iterates over the value of `field` in every element of the array, in
    /// linear-index order.
    ///
    /// # Errors
    /// The error from resolving `field`.
    pub fn fields<F: FieldIndex>(&self, field: F) -> Result<impl Iterator<Item = Option<&'p mxArray>> + '_, StructError> {
        let n = self.0.numel();
        let f = field.index_into(self)?;
        Ok((0..n).map(move |linidx| self.get(linidx, f).expect("linidx should be in range")))
    }
}
/// Mutating operations, available when the struct is backed by a mutable
/// pointer.
impl<'p, P: MutMatlabPtr + 'p> Struct<P> {
    /// Produces a read-only view of the same array, borrowed from `self`.
    pub fn as_ref_struct<'s>(&'s self) -> Struct<&'s mxArray> {
        let shared: &'s mxArray = &self.0;
        Struct(shared)
    }

    /// Mutable access to the value stored in `field` of element `idx`;
    /// `Ok(None)` when the lookup yields a null pointer.
    ///
    /// NOTE(review): the returned reference carries lifetime `'p`, which is
    /// not tied to the `&mut self` borrow — confirm callers cannot obtain
    /// aliasing mutable references this way.
    ///
    /// # Errors
    /// [`StructError::OutOfBounds`] when `idx` cannot be resolved, or the
    /// error from resolving `field`.
    pub fn get_mut<I: Index, F: FieldIndex>(&mut self, idx: I, field: F) -> Result<Option<&'p mut mxArray>, StructError> {
        let element = idx.index_into(&self.0).ok_or(StructError::OutOfBounds)?;
        let field_no = field.index_into(self)?;
        let raw = unsafe { mxGetFieldByNumber(self.0.deref_mut(), element as _, field_no) };
        Ok(NonNull::new(raw).map(|mut found| unsafe { found.as_mut() }))
    }

    /// Stores `value` in `field` of element `idx` and returns what was
    /// stored there before, if anything.
    ///
    /// # Errors
    /// Same as [`Struct::replace`].
    pub fn set<I: Index, F: FieldIndex>(&mut self, idx: I, field: F, value: MxArray) -> Result<Option<MxArray>, StructError> {
        let replacement = Some(value);
        self.replace(idx, field, replacement)
    }

    /// Swaps the content of `field` of element `idx` for `value` (`None`
    /// clears the slot) and returns the previous content, if any.
    ///
    /// # Errors
    /// [`StructError::OutOfBounds`] when `idx` cannot be resolved, or the
    /// error from resolving `field`.
    pub fn replace<I: Index, F: FieldIndex>(&mut self, idx: I, field: F, value: Option<MxArray>) -> Result<Option<MxArray>, StructError> {
        let element = idx.index_into(&self.0).ok_or(StructError::OutOfBounds)?;
        let field_no = field.index_into(self)?;

        // Take the old value out first so that it is not lost when the slot
        // is overwritten below.
        let previous = NonNull::new(unsafe { mxGetFieldByNumber(self.0.deref_mut(), element as _, field_no) })
            .map(|mut found| unsafe { MxArray::assume_responsibility(found.as_mut()) });

        unsafe {
            let incoming = match value {
                Some(owned) => MxArray::transfer_responsibility_ptr(owned),
                None => std::ptr::null_mut(),
            };
            mxSetFieldByNumber(self.0.deref_mut(), element as _, field_no, incoming);
        }

        Ok(previous)
    }

    /// Clears `field` of element `idx` and returns what was stored there, if
    /// anything.
    ///
    /// # Errors
    /// Same as [`Struct::replace`].
    pub fn unset<I: Index, F: FieldIndex>(&mut self, idx: I, field: F) -> Result<Option<MxArray>, StructError> {
        self.replace(idx, field, None)
    }

    /// Iterates mutably over the value of `field` in every element, in
    /// linear-index order.
    ///
    /// # Errors
    /// The error from resolving `field`.
    pub fn fields_mut<F: FieldIndex>(&mut self, field: F) -> Result<impl Iterator<Item = Option<&'p mut mxArray>> + '_, StructError> {
        let count = self.0.numel();
        let field_no = field.index_into(self)?;
        let elements = 0..count;
        Ok(elements.map(move |element| self.get_mut(element, field_no).unwrap()))
    }

    /// Iterates over every element, taking the value of `field` out of the
    /// struct as it goes.
    ///
    /// The iterator is lazy: an element is only unset once it is reached.
    ///
    /// # Errors
    /// The error from resolving `field`.
    pub fn fields_values<F: FieldIndex>(&mut self, field: F) -> Result<impl Iterator<Item = Option<MxArray>> + '_, StructError> {
        let count = self.0.numel();
        let field_no = field.index_into(self)?;
        let elements = 0..count;
        Ok(elements.map(move |element| self.unset(element, field_no).unwrap()))
    }

    /// Adds a field called `field` and returns its number.
    ///
    /// # Errors
    /// `Err(existing_number)` when a field of that name already exists.
    ///
    /// # Panics
    /// When the underlying add call reports `-1`.
    pub fn add_field(&mut self, field: &CStr) -> Result<i32, i32> {
        match self.field_number(field) {
            Ok(existing) => Err(existing),
            Err(_) => {
                let created = unsafe { mxAddField(self.0.deref_mut(), field.as_ptr()) };
                if created == -1 {
                    panic!("OOM")
                }
                Ok(created)
            }
        }
    }

    /// Removes `field` from the struct array without touching the values
    /// stored in it first; see [`Struct::delete_field`] for the variant that
    /// takes the values out beforehand.
    ///
    /// # Errors
    /// The error from resolving `field`.
    pub fn remove_field<F: FieldIndex>(&mut self, field: F) -> Result<(), StructError> {
        field.index_into(self).map(|field_no| {
            unsafe { mxRemoveField(self.0.deref_mut(), field_no) };
        })
    }

    /// Takes every value out of `field`, drops it, and then removes the
    /// field itself.
    ///
    /// # Errors
    /// The error from resolving `field`.
    pub fn delete_field<F: FieldIndex>(&mut self, field: F) -> Result<(), StructError> {
        let field_no = field.index_into(self)?;
        for taken in self.fields_values(field_no).unwrap() {
            drop(taken);
        }
        self.remove_field(field_no)
    }
}
/// Constructors for owned struct arrays.
impl Struct<MxArray> {
    /// Creates a `0 x 0` struct array with the given field names.
    pub fn new_empty_with_fields(fieldnames: &[&CStr]) -> Self {
        Self::new(&[0, 0], fieldnames)
    }

    /// Creates a struct array of the given `shape` with the given field
    /// names.
    ///
    /// # Panics
    /// When `shape` is rejected by `shape_ok!`, when there are `i32::MAX` or
    /// more field names, or when the creation call returns a null pointer.
    pub fn new(shape: &[usize], fieldnames: &[&CStr]) -> Self {
        rustmex_core::shape_ok!(shape);
        more_asserts::assert_lt!(fieldnames.len(), i32::MAX as usize, "Too many field names");

        // The creation call takes an array of raw C-string pointers.
        let mut name_ptrs: Vec<_> = fieldnames.iter().map(|name| name.as_ptr()).collect();
        let field_count = name_ptrs.len() as i32;
        let ndim = shape.len();

        let raw = unsafe {
            mxCreateStructArray(ndim, shape.as_ptr(), field_count, name_ptrs.as_mut_ptr())
        };
        if raw.is_null() {
            panic!("OOM")
        }
        Self(unsafe { MxArray::assume_responsibility(&mut *raw) })
    }
}
/// A [`Struct`] wrapper whose accessors always address element `0`.
///
/// `ScalarStruct::from_struct` only wraps struct arrays whose `numel()` is
/// exactly `1`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct ScalarStruct<P>(Struct<P>);
impl<'p, P: MatlabPtr + 'p> ScalarStruct<P> {
pub fn from_struct(s: Struct<P>) -> Result<Self, Struct<P>> {
if s.numel() == 1 {
Ok(Self(s))
} else {
Err(s)
}
}
pub fn into_struct(self) -> Struct<P> {
self.0
}
pub fn into_inner(self) -> P {
self.into_struct().into_inner()
}
#[inline]
pub fn get<F: FieldIndex>(&self, f: F) -> Result<Option<&'p mxArray>, StructError> {
self.0.get(0, f)
}
#[inline]
pub fn field_name(&self, fieldnum: i32) -> Result<&'p CStr, StructError> {
self.0.field_name(fieldnum)
}
#[inline]
pub fn field_number(&self, fieldname: &CStr) -> Result<i32, StructError> {
self.0.field_number(fieldname)
}
#[inline]
pub fn field_names(&self) -> impl Iterator<Item = &'p CStr> + '_ {
self.0.field_names()
}
#[inline]
pub fn fields<F: FieldIndex>(&self, field: F) -> Result<impl Iterator<Item = Option<&'p mxArray>> + '_, StructError> {
self.0.fields(field)
}
}
/// Mutating operations on a scalar struct; each forwards to the wrapped
/// [`Struct`], addressing element `0` where an index is needed.
impl<'p, P: MutMatlabPtr + 'p> ScalarStruct<P> {
    /// Produces a read-only scalar view of the same array, borrowed from
    /// `self`.
    pub fn as_ref_struct<'s>(&'s self) -> ScalarStruct<&'s mxArray> {
        let view = self.0.as_ref_struct();
        ScalarStruct(view)
    }

    /// Mutable access to field `f` of the single element; see
    /// [`Struct::get_mut`].
    #[inline]
    pub fn get_mut<F: FieldIndex>(&mut self, f: F) -> Result<Option<&'p mut mxArray>, StructError> {
        Struct::get_mut(&mut self.0, 0, f)
    }

    /// Stores `value` in `field` and returns the previous content, if any.
    #[inline]
    pub fn set<F: FieldIndex>(&mut self, field: F, value: MxArray) -> Result<Option<MxArray>, StructError> {
        let replacement = Some(value);
        self.replace(field, replacement)
    }

    /// Swaps the content of `field` for `value` and returns the previous
    /// content, if any; see [`Struct::replace`].
    #[inline]
    pub fn replace<F: FieldIndex>(&mut self, field: F, value: Option<MxArray>) -> Result<Option<MxArray>, StructError> {
        Struct::replace(&mut self.0, 0, field, value)
    }

    /// Clears `field` and returns what was stored there, if anything.
    #[inline]
    pub fn unset<F: FieldIndex>(&mut self, field: F) -> Result<Option<MxArray>, StructError> {
        let nothing = None;
        self.replace(field, nothing)
    }

    /// See [`Struct::fields_mut`].
    #[inline]
    pub fn fields_mut<F: FieldIndex>(&mut self, field: F) -> Result<impl Iterator<Item = Option<&'p mut mxArray>> + '_, StructError> {
        Struct::fields_mut(&mut self.0, field)
    }

    /// See [`Struct::fields_values`].
    #[inline]
    pub fn fields_values<F: FieldIndex>(&mut self, field: F) -> Result<impl Iterator<Item = Option<MxArray>> + '_, StructError> {
        Struct::fields_values(&mut self.0, field)
    }

    /// See [`Struct::remove_field`].
    #[inline]
    pub fn remove_field<F: FieldIndex>(&mut self, field: F) -> Result<(), StructError> {
        Struct::remove_field(&mut self.0, field)
    }

    /// See [`Struct::delete_field`].
    #[inline]
    pub fn delete_field<F: FieldIndex>(&mut self, field: F) -> Result<(), StructError> {
        Struct::delete_field(&mut self.0, field)
    }

    /// See [`Struct::add_field`].
    #[inline]
    pub fn add_field(&mut self, field: &CStr) -> Result<i32, i32> {
        Struct::add_field(&mut self.0, field)
    }
}