matio_rs/
mat.rs

1use crate::{
2    MatArray, MatFile, MatFileRead, MatFileWrite, MatType, MatioError, MayBeFrom, MayBeInto, Result,
3};
4use std::{ffi::CStr, marker::PhantomData, ptr, slice::from_raw_parts};
5
/// Matlab variable
///
/// Wraps a raw matio `matvar_t` pointer together with the variable name.
/// Unless `as_ref` is set, dropping a [Mat] passes the pointer to `Mat_VarFree`.
pub struct Mat<'a> {
    /// Name given to the variable when the [Mat] was built
    pub(crate) name: String,
    /// Raw pointer to the matio variable; handed to `Mat_VarFree` on drop unless `as_ref` is `true`
    pub(crate) matvar_t: *mut ffi::matvar_t,
    /// Per-field variables; `from_ptr` sets this to `Some` for structure variables and `None` otherwise
    pub(crate) fields: Option<Vec<Mat<'a>>>,
    /// Zero-sized marker carrying the lifetime `'a` of the underlying `matvar_t`
    pub(crate) marker: PhantomData<&'a ffi::matvar_t>,
    /// When `true`, `Drop` does not call `Mat_VarFree` on `matvar_t`
    pub(crate) as_ref: bool,
}
14impl<'a> Drop for Mat<'a> {
15    fn drop(&mut self) {
16        if let Some(mut fields) = self.fields.take() {
17            fields.iter_mut().for_each(|mat| {
18                mat.matvar_t = ptr::null_mut();
19            })
20        }
21        if !self.as_ref {
22            unsafe {
23                ffi::Mat_VarFree(self.matvar_t);
24            }
25        }
26    }
27}
28impl<'a> MatFile<'a> {
29    /// Read from a [MatFile] the Matlab [Mat] variable `name`
30    pub fn read<S: Into<String>>(&self, name: S) -> Result<Mat<'a>> {
31        let c_name = std::ffi::CString::new(name.into())?;
32        let matvar_t = unsafe { ffi::Mat_VarRead(self.mat_t, c_name.as_ptr()) };
33        if matvar_t.is_null() {
34            Err(MatioError::MatVarRead(c_name.to_str().unwrap().to_string()))
35        } else {
36            Mat::from_ptr(c_name.to_str()?, matvar_t)
37        }
38    }
39    /// Write to a [MatFile] the Matlab [Mat] variable `name`
40    pub fn write(&self, var: Mat<'a>) -> &Self {
41        unsafe {
42            ffi::Mat_VarWrite(
43                self.mat_t,
44                var.matvar_t,
45                ffi::matio_compression_MAT_COMPRESSION_NONE,
46            );
47        }
48        self
49    }
50}
51impl<'a> MatFileRead<'a> {
52    /// Read from a [MatFileRead]er the Matlab [Mat] variable `name`
53    ///
54    /// Reading a scalar Matlab variable: a = π
55    /// ```
56    /// use matio_rs::MatFile;
57    /// # let file = tempfile::NamedTempFile::new().unwrap();
58    /// # let data_path = file.path();
59    /// # let mat_file = MatFile::save(&data_path)?.var("a", std::f64::consts::PI)?;
60    /// let a: f64 = MatFile::load(data_path)?.var("a")?;
61    /// println!("{a:?}");
62    /// # Ok::<(), matio_rs::MatioError>(())
63    /// ```
64    ///
65    /// Reading a Matlab vector: b = [3.0, 1.0, 4.0, 1.0, 6.0]
66    /// ```
67    /// use matio_rs::MatFile;
68    /// # let file = tempfile::NamedTempFile::new().unwrap();
69    /// # let data_path = file.path();
70    /// # let mat_file = MatFile::save(&data_path)?.var("b", vec![3.0, 1.0, 4.0, 1.0, 6.0])?;
71    /// let b: Vec<f64> = MatFile::load(data_path)?.var("b")?;
72    /// println!("{b:?}");
73    /// # Ok::<(), matio_rs::MatioError>(())
74    /// ```
75    pub fn var<S: Into<String>, T>(&self, name: S) -> Result<T>
76    where
77        Mat<'a>: MayBeInto<T>,
78    {
79        self.read(name).and_then(|mat| mat.maybe_into())
80    }
81}
82impl<'a> MatFileWrite<'a> {
83    /// Write to a [MatFileWrite]r the Matlab [Mat] variable `name`
84    ///
85    /// Saving to a mat file
86    /// ```
87    /// use matio_rs::MatFile;
88    /// # let file = tempfile::NamedTempFile::new().unwrap();
89    /// # let data_path = file.path();
90    /// let mut b = (0..5).map(|x| (x as f64).cosh()).collect::<Vec<f64>>();
91    /// MatFile::save(data_path)?
92    /// .var("a", 2f64.sqrt())?
93    /// .var("b", &b)?;
94    /// # Ok::<(), matio_rs::MatioError>(())
95    /// ```
96    pub fn var<S: Into<String>, T>(&self, name: S, data: T) -> Result<&Self>
97    where
98        Mat<'a>: MayBeFrom<T>,
99    {
100        let mat: Mat<'a> = MayBeFrom::<T>::maybe_from(name, data)?;
101        self.write(mat);
102        Ok(self)
103    }
104    /// Write to a [MatFileWrite]r the Matlab [Mat] variable `name` as a N-dimensition array [MatArray]
105    ///
106    /// The data is aligned according to and in the order of the dimension vector dims
107    pub fn array<S: Into<String>, T>(&self, name: S, data: &'a [T], dims: Vec<u64>) -> Result<&Self>
108    where
109        Mat<'a>: MayBeFrom<MatArray<'a, T>>,
110    {
111        let mat_array = MatArray::new(data, dims);
112        self.var(name, mat_array)?;
113        Ok(self)
114    }
115}
116impl<'a> Mat<'a> {
117    /// Returns the rank (# of dimensions) of the Matlab variable
118    pub fn rank(&self) -> usize {
119        unsafe { (*self.matvar_t).rank as usize }
120    }
121    /// Returns the dimensions of the Matlab variable
122    pub fn dims(&self) -> Vec<usize> {
123        let rank = self.rank();
124        let mut dims: Vec<usize> = Vec::with_capacity(rank);
125        unsafe {
126            ptr::copy((*self.matvar_t).dims, dims.as_mut_ptr(), rank);
127            dims.set_len(rank);
128        };
129        dims
130    }
131    /// Returns the number of elements of the Matlab variable
132    pub fn len(&self) -> usize {
133        self.dims().into_iter().product::<usize>() as usize
134    }
135    pub(crate) fn mat_type(&self) -> Option<MatType> {
136        MatType::from_ptr(self.matvar_t)
137    }
138    pub(crate) fn as_ptr<S: Into<String>>(name: S, ptr: *mut ffi::matvar_t) -> Result<Self> {
139        Self::from_ptr(name, ptr).map(|mut mat| {
140            mat.as_ref = true;
141            mat
142        })
143    }
144    pub(crate) fn from_ptr<S: Into<String>>(name: S, ptr: *mut ffi::matvar_t) -> Result<Self> {
145        if let Some(MatType::STRUCT) = MatType::from_ptr(ptr) {
146            let rank = unsafe { (*ptr).rank as usize };
147            let mut dims: Vec<usize> = Vec::with_capacity(rank);
148            unsafe {
149                ptr::copy((*ptr).dims, dims.as_mut_ptr(), rank);
150                dims.set_len(rank);
151            };
152            let nel: usize = dims.iter().product();
153            let n = unsafe { ffi::Mat_VarGetNumberOfFields(ptr) } as usize;
154            // fields name
155            let field_names = unsafe {
156                from_raw_parts(ffi::Mat_VarGetStructFieldnames(ptr), n)
157                    .into_iter()
158                    .map(|&s| CStr::from_ptr(s).to_str())
159                    .collect::<std::result::Result<Vec<&str>, std::str::Utf8Error>>()?
160            };
161            // fields data pointer
162            let field_ptr =
163                unsafe { from_raw_parts((*ptr).data as *mut *mut ffi::matvar_t, n * nel) };
164            let mut fields: Vec<Mat> = Vec::new();
165            for (name, &ptr) in field_names.into_iter().cycle().zip(field_ptr.iter()) {
166                let mat = Mat::from_ptr(name, ptr)?;
167                fields.push(mat);
168            }
169            Ok(Mat {
170                name: name.into(),
171                matvar_t: ptr,
172                fields: Some(fields),
173                marker: PhantomData,
174                as_ref: false,
175            })
176        } else {
177            Ok(Mat {
178                name: name.into(),
179                matvar_t: ptr,
180                fields: None,
181                marker: PhantomData,
182                as_ref: false,
183            })
184        }
185    }
186    /// Returns the field `name` from a Matlab structure
187    pub fn field<S: Into<String>>(&self, name: S) -> Result<Vec<&Mat<'_>>> {
188        let fields = if let Some(MatType::STRUCT) = self.mat_type() {
189            self.fields.as_ref().unwrap()
190        } else {
191            return Err(MatioError::TypeMismatch(
192                self.name.clone(),
193                stringify!(MatType::STRUCT).to_string(),
194                stringify!(self.mat_type()).to_string(),
195            ));
196        };
197        let field_name: String = name.into();
198        let field_value: Vec<&Mat> = fields
199            .iter()
200            .filter(|field| field.name == field_name)
201            .collect();
202        if field_value.is_empty() {
203            Err(MatioError::FieldNotFound(field_name))
204        } else {
205            Ok(field_value)
206        }
207    }
208}