// minikalman/buffers/types/temporary_state_matrix_buffer.rs

1use core::marker::PhantomData;
2use core::ops::{Index, IndexMut};
3
4use crate::kalman::TemporaryStateMatrix;
5use crate::matrix::{IntoInnerData, MatrixData, MatrixDataArray, MatrixDataMut};
6use crate::matrix::{Matrix, MatrixMut};
7use crate::prelude::{RowMajorSequentialData, RowMajorSequentialDataMut};
8
9/// Mutable buffer for the temporary system matrix (`num_states` × `num_states`).
10///
11/// ## Example
12/// ```
13/// use minikalman::buffers::types::TemporaryStateMatrixBuffer;
14/// use minikalman::prelude::*;
15///
16/// // From owned data
17/// let buffer = TemporaryStateMatrixBuffer::new(MatrixData::new_array::<2, 2, 4, f32>([0.0; 4]));
18///
19/// // From a reference
20/// let mut data = [0.0; 4];
21/// let buffer = TemporaryStateMatrixBuffer::<2, f32, _>::from(data.as_mut_slice());
22/// ```
23pub struct TemporaryStateMatrixBuffer<const STATES: usize, T, M>(M, PhantomData<T>)
24where
25    M: MatrixMut<STATES, STATES, T>;
26
27// -----------------------------------------------------------
28
29impl<'a, const STATES: usize, T> From<&'a mut [T]>
30    for TemporaryStateMatrixBuffer<STATES, T, MatrixDataMut<'a, STATES, STATES, T>>
31{
32    fn from(value: &'a mut [T]) -> Self {
33        #[cfg(not(feature = "no_assert"))]
34        {
35            debug_assert!(STATES * STATES <= value.len());
36        }
37        Self::new(MatrixData::new_mut::<STATES, STATES, T>(value))
38    }
39}
40
41impl<const STATES: usize, const TOTAL: usize, T> From<[T; TOTAL]>
42    for TemporaryStateMatrixBuffer<STATES, T, MatrixDataArray<STATES, STATES, TOTAL, T>>
43{
44    fn from(value: [T; TOTAL]) -> Self {
45        #[cfg(not(feature = "no_assert"))]
46        {
47            debug_assert!(STATES * STATES <= TOTAL);
48        }
49        Self::new(MatrixData::new_array::<STATES, STATES, TOTAL, T>(value))
50    }
51}
52
53// -----------------------------------------------------------
54
55impl<const STATES: usize, T, M> TemporaryStateMatrixBuffer<STATES, T, M>
56where
57    M: MatrixMut<STATES, STATES, T>,
58{
59    pub const fn new(matrix: M) -> Self {
60        Self(matrix, PhantomData)
61    }
62
63    pub const fn len(&self) -> usize {
64        STATES * STATES
65    }
66
67    pub const fn is_empty(&self) -> bool {
68        STATES * STATES == 0
69    }
70
71    /// Ensures the underlying buffer has enough space for the expected number of values.
72    pub fn is_valid(&self) -> bool {
73        self.0.is_valid()
74    }
75}
76
77impl<const STATES: usize, T, M> RowMajorSequentialData<STATES, STATES, T>
78    for TemporaryStateMatrixBuffer<STATES, T, M>
79where
80    M: MatrixMut<STATES, STATES, T>,
81{
82    #[inline(always)]
83    fn as_slice(&self) -> &[T] {
84        self.0.as_slice()
85    }
86}
87
88impl<const STATES: usize, T, M> RowMajorSequentialDataMut<STATES, STATES, T>
89    for TemporaryStateMatrixBuffer<STATES, T, M>
90where
91    M: MatrixMut<STATES, STATES, T>,
92{
93    #[inline(always)]
94    fn as_mut_slice(&mut self) -> &mut [T] {
95        self.0.as_mut_slice()
96    }
97}
98
99impl<const STATES: usize, T, M> Matrix<STATES, STATES, T>
100    for TemporaryStateMatrixBuffer<STATES, T, M>
101where
102    M: MatrixMut<STATES, STATES, T>,
103{
104}
105
106impl<const STATES: usize, T, M> MatrixMut<STATES, STATES, T>
107    for TemporaryStateMatrixBuffer<STATES, T, M>
108where
109    M: MatrixMut<STATES, STATES, T>,
110{
111}
112
113impl<const STATES: usize, T, M> TemporaryStateMatrix<STATES, T>
114    for TemporaryStateMatrixBuffer<STATES, T, M>
115where
116    M: MatrixMut<STATES, STATES, T>,
117{
118    type Target = M;
119    type TargetMut = M;
120
121    fn as_matrix(&self) -> &Self::Target {
122        &self.0
123    }
124
125    fn as_matrix_mut(&mut self) -> &mut Self::TargetMut {
126        &mut self.0
127    }
128}
129
130impl<const STATES: usize, T, M> Index<usize> for TemporaryStateMatrixBuffer<STATES, T, M>
131where
132    M: MatrixMut<STATES, STATES, T>,
133{
134    type Output = T;
135
136    fn index(&self, index: usize) -> &Self::Output {
137        self.0.index(index)
138    }
139}
140
141impl<const STATES: usize, T, M> IndexMut<usize> for TemporaryStateMatrixBuffer<STATES, T, M>
142where
143    M: MatrixMut<STATES, STATES, T>,
144{
145    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
146        self.0.index_mut(index)
147    }
148}
149
150// -----------------------------------------------------------
151
152impl<const STATES: usize, T, M> IntoInnerData for TemporaryStateMatrixBuffer<STATES, T, M>
153where
154    M: MatrixMut<STATES, STATES, T> + IntoInnerData,
155{
156    type Target = M::Target;
157
158    fn into_inner(self) -> Self::Target {
159        self.0.into_inner()
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_array() {
        let buffer: TemporaryStateMatrixBuffer<5, f32, _> = [0.0; 100].into();
        assert_eq!(buffer.len(), 25);
        assert!(!buffer.is_empty());
        assert!(buffer.is_valid());
    }

    #[test]
    fn test_from_mut() {
        let mut storage = [0.0_f32; 100];
        let buffer: TemporaryStateMatrixBuffer<5, f32, _> = storage.as_mut_slice().into();
        assert_eq!(buffer.len(), 25);
        assert!(!buffer.is_empty());
        assert!(buffer.is_valid());
        // The buffer must alias the caller's storage rather than copy it.
        assert!(core::ptr::eq(buffer.as_slice(), &storage));
    }

    #[test]
    #[cfg(feature = "no_assert")]
    fn test_from_array_invalid_size() {
        // With assertions compiled out, an undersized array is accepted but reported invalid.
        let buffer: TemporaryStateMatrixBuffer<5, f32, _> = [0.0; 1].into();
        assert!(!buffer.is_valid());
    }

    #[test]
    #[rustfmt::skip]
    fn test_access() {
        let mut buffer: TemporaryStateMatrixBuffer<5, f32, _> = [0.0; 25].into();

        // Fill the first row/column symmetrically and the diagonal with the index value.
        {
            let matrix = buffer.as_matrix_mut();
            for k in 0..matrix.cols() {
                matrix.set_symmetric(0, k, k as _);
                matrix.set_at(k, k, k as _);
            }
        }

        // Shift every element by 10 through the flat index operator.
        for k in 0..buffer.len() {
            buffer[k] += 10.0;
        }

        // The first row and first column must both read back as 10 + index.
        {
            let matrix = buffer.as_matrix();
            for k in 0..matrix.rows() {
                assert_eq!(matrix.get_at(0, k), 10.0 + k as f32);
                assert_eq!(matrix.get_at(k, 0), 10.0 + k as f32);
            }
        }

        let expected: [f32; 25] = [
            10.0, 11.0, 12.0, 13.0, 14.0,
            11.0, 11.0, 10.0, 10.0, 10.0,
            12.0, 10.0, 12.0, 10.0, 10.0,
            13.0, 10.0, 10.0, 13.0, 10.0,
            14.0, 10.0, 10.0, 10.0, 14.0,
        ];
        assert_eq!(buffer.into_inner(), expected);
    }
}