qudit_tensor/cpu/
result.rs

1use super::buffer::SizedTensorBuffer;
2use faer::ColRef;
3use faer::MatRef;
4use faer::RowRef;
5use qudit_core::array::TensorRef;
6use qudit_core::{ComplexScalar, memory::MemoryBuffer};
7use qudit_expr::GenerationShape;
8
9pub struct TNVMReturnType2<'a, C: ComplexScalar> {
10    pub ncols: usize,
11    pub nrows: usize,
12    pub nmats: usize,
13    pub ntens: usize,
14    pub buffer: &'a [C],
15}
16
17impl<'a, C: ComplexScalar> TNVMReturnType2<'a, C> {
18    pub fn num_elements(&self) -> usize {
19        self.ncols * self.nrows * self.nmats * self.ntens
20    }
21
22    pub fn unpack_scalar(self) -> &'a C {
23        if self.num_elements() != 1 {
24            panic!("Cannot unpack a non-scalar type as a scalar.");
25        }
26        &self.buffer[0]
27    }
28
29    pub fn unpack_row(self) -> RowRef<'a, C> {
30        RowRef::from_slice(self.buffer)
31    }
32
33    pub fn unpack_col(self) -> ColRef<'a, C> {
34        ColRef::from_slice(self.buffer)
35    }
36
37    pub fn unpack_mat(self) -> MatRef<'a, C> {
38        if self.ntens != 1 {
39            MatRef::from_column_major_slice(
40                self.buffer,
41                self.ntens,
42                self.nmats * self.nrows * self.ncols,
43            )
44        } else if self.nmats != 1 {
45            MatRef::from_column_major_slice(self.buffer, self.nmats, self.nrows * self.ncols)
46        } else {
47            MatRef::from_column_major_slice(self.buffer, self.nrows, self.ncols)
48        }
49    }
50
51    pub fn unpack_tensor3d(self) -> TensorRef<'a, C, 3> {
52        // println!("{}, {}, {}, {}", self.ntens, self.nmats, self.nrows, self.ncols);
53        if self.ntens != 1 {
54            if self.nmats != 1 {
55                unsafe {
56                    TensorRef::from_raw_parts(
57                        self.buffer.as_ptr(),
58                        [self.ntens, self.nmats, self.nrows * self.ncols],
59                        [
60                            self.nmats * self.nrows * self.ncols,
61                            self.nrows * self.ncols,
62                            1,
63                        ],
64                    )
65                }
66            } else {
67                unsafe {
68                    TensorRef::from_raw_parts(
69                        self.buffer.as_ptr(),
70                        [self.ntens, self.nrows, self.ncols],
71                        [self.nrows * self.ncols, 1, self.nrows],
72                    )
73                }
74            }
75        } else {
76            unsafe {
77                TensorRef::from_raw_parts(
78                    self.buffer.as_ptr(),
79                    [self.nmats, self.nrows, self.ncols],
80                    [self.nrows * self.ncols, 1, self.nrows],
81                )
82            }
83        }
84    }
85
86    pub fn unpack_tensor4d(self) -> TensorRef<'a, C, 4> {
87        unsafe {
88            TensorRef::from_raw_parts(
89                self.buffer.as_ptr(),
90                [self.ntens, self.nmats, self.nrows, self.ncols],
91                [
92                    self.nmats * self.nrows * self.ncols,
93                    self.nrows * self.ncols,
94                    1,
95                    self.nrows,
96                ],
97            )
98        }
99    }
100}
101
102pub enum TNVMReturnType<'a, C: ComplexScalar> {
103    Scalar(&'a C),
104    Vector(faer::RowRef<'a, C>),
105    Matrix(faer::MatRef<'a, C>),
106    Tensor3D(qudit_core::array::TensorRef<'a, C, 3>),
107    Tensor4D(qudit_core::array::TensorRef<'a, C, 4>),
108    SymSqMatrix(qudit_core::array::SymSqTensorRef<'a, C, 2>),
109    SymSqTensor3D(qudit_core::array::SymSqTensorRef<'a, C, 3>),
110    SymSqTensor4D(qudit_core::array::SymSqTensorRef<'a, C, 4>),
111    SymSqTensor5D(qudit_core::array::SymSqTensorRef<'a, C, 5>),
112}
113
114impl<'a, C: ComplexScalar> TNVMReturnType<'a, C> {
115    pub fn unpack_scalar(self) -> &'a C {
116        match self {
117            TNVMReturnType::Scalar(s) => s,
118            _ => panic!("cannot unpack a non-scalar type as a scalar"),
119        }
120    }
121
122    pub fn unpack_vector(self) -> faer::RowRef<'a, C> {
123        match self {
124            TNVMReturnType::Vector(v) => v,
125            _ => panic!("cannot unpack a non-row-vector type as a row-vector"),
126        }
127    }
128
129    // TODO: mutate the row into a col
130    // pub fn unpack_col_vector(self) -> qudit_core::matrix::ColRef<'a, C> {
131    //     match self {
132    //         TNVMReturnType::ColVector(v) => v,
133    //         _ => panic!("cannot unpack a non-col-vector type as a col-vector"),
134    //     }
135    // }
136
137    pub fn unpack_matrix(self) -> faer::MatRef<'a, C> {
138        match self {
139            TNVMReturnType::Matrix(m) => m,
140            _ => panic!("cannot unpack a non-matrix type as a matrix"),
141        }
142    }
143
144    pub fn unpack_tensor3d(self) -> qudit_core::array::TensorRef<'a, C, 3> {
145        match self {
146            TNVMReturnType::Tensor3D(m) => m,
147            _ => panic!("cannot unpack a non-tensor3d type as a tensor3d"),
148        }
149    }
150
151    pub fn unpack_tensor4d(self) -> qudit_core::array::TensorRef<'a, C, 4> {
152        match self {
153            TNVMReturnType::Tensor4D(t) => t,
154            _ => panic!("cannot unpack a non-tensor4d type as a tensor4d"),
155        }
156    }
157
158    pub fn unpack_symsq_matrix(self) -> qudit_core::array::SymSqTensorRef<'a, C, 2> {
159        match self {
160            TNVMReturnType::SymSqMatrix(t) => t,
161            _ => panic!("cannot unpack a non-symsq-matrix type as a symsq-matrix"),
162        }
163    }
164
165    pub fn unpack_symsq_tensor3d(self) -> qudit_core::array::SymSqTensorRef<'a, C, 3> {
166        match self {
167            TNVMReturnType::SymSqTensor3D(t) => t,
168            _ => panic!("cannot unpack a non-symsq-tensor3d type as a symsq-tensor3d"),
169        }
170    }
171
172    pub fn unpack_symsq_tensor4d(self) -> qudit_core::array::SymSqTensorRef<'a, C, 4> {
173        match self {
174            TNVMReturnType::SymSqTensor4D(t) => t,
175            _ => panic!("cannot unpack a non-symsq-tensor4d type as a symsq-tensor4d"),
176        }
177    }
178
179    pub fn unpack_symsq_tensor5d(self) -> qudit_core::array::SymSqTensorRef<'a, C, 5> {
180        match self {
181            TNVMReturnType::SymSqTensor5D(t) => t,
182            _ => panic!("cannot unpack a non-symsq-tensor5d type as a symsq-tensor5d"),
183        }
184    }
185
186    // TODO: Decide, do I want unpack_symsq_tensor3D or do I want unpack_tensor3d to
187    // un-symsq it? ... or both, why not?
188}
189
190pub struct TNVMResult<'a, C: ComplexScalar> {
191    memory: &'a MemoryBuffer<C>,
192    buffer: &'a SizedTensorBuffer<C>,
193}
194
195impl<'a, C: ComplexScalar> TNVMResult<'a, C> {
196    pub fn new(memory: &'a MemoryBuffer<C>, buffer: &'a SizedTensorBuffer<C>) -> Self {
197        TNVMResult { memory, buffer }
198    }
199
200    pub fn get_fn_result2(&self) -> TNVMReturnType2<'a, C> {
201        // println!("{}, {}, {}, {}", self.buffer.ncols(), self.buffer.nrows(), self.buffer.nmats(), self.buffer.unit_memory_size());
202        TNVMReturnType2 {
203            ncols: self.buffer.ncols(),
204            nrows: self.buffer.nrows(),
205            nmats: self.buffer.nmats(),
206            ntens: 1,
207            buffer: unsafe {
208                std::slice::from_raw_parts(
209                    self.buffer.as_ptr(self.memory),
210                    self.buffer.unit_memory_size(),
211                )
212            },
213        }
214    }
215
216    pub fn get_grad_result2(&self) -> TNVMReturnType2<'a, C> {
217        TNVMReturnType2 {
218            ncols: self.buffer.ncols(),
219            nrows: self.buffer.nrows(),
220            nmats: self.buffer.nmats(),
221            ntens: self.buffer.nparams(),
222            buffer: unsafe {
223                std::slice::from_raw_parts(
224                    self.buffer
225                        .as_ptr(self.memory)
226                        .add(self.buffer.unit_memory_size()),
227                    self.buffer.grad_memory_size(),
228                )
229            },
230        }
231    }
232
233    pub fn get_fn_result(&self) -> TNVMReturnType<'a, C> {
234        match self.buffer.shape() {
235            // Safety: TNVM told me this output buffer is mine
236            GenerationShape::Scalar => {
237                TNVMReturnType::Scalar(unsafe { self.buffer.as_scalar_ref(self.memory) })
238            }
239            GenerationShape::Vector(_) => {
240                TNVMReturnType::Vector(unsafe { self.buffer.as_vector_ref(self.memory) })
241            }
242            GenerationShape::Matrix(_, _) => {
243                TNVMReturnType::Matrix(unsafe { self.buffer.as_matrix_ref(self.memory) })
244            }
245            GenerationShape::Tensor3D(_, _, _) => {
246                TNVMReturnType::Tensor3D(unsafe { self.buffer.as_tensor3d_ref(self.memory) })
247            }
248            _ => panic!("No Tensor4D should be exposed a function value output."),
249        }
250    }
251
252    // TODO: this needs to be made more safe by gating it behind const DifferentiationLevel generic
253    // impls
254    pub fn get_grad_result(&self) -> TNVMReturnType<'a, C> {
255        // TODO: Has to ensure nparams is always outer most vector when unpacking
256        match self.buffer.shape() {
257            // Safety: TNVM told me this output buffer is mine
258            GenerationShape::Scalar => {
259                TNVMReturnType::Vector(unsafe { self.buffer.grad_as_vector_ref(self.memory) })
260            }
261            GenerationShape::Vector(_) => {
262                TNVMReturnType::Matrix(unsafe { self.buffer.grad_as_matrix_ref(self.memory) })
263            }
264            GenerationShape::Matrix(_, _) => {
265                TNVMReturnType::Tensor3D(unsafe { self.buffer.grad_as_tensor3d_ref(self.memory) })
266            }
267            GenerationShape::Tensor3D(_, _, _) => {
268                TNVMReturnType::Tensor4D(unsafe { self.buffer.grad_as_tensor4d_ref(self.memory) })
269            }
270            _ => panic!("No Tensor4D should be exposed a function value output."),
271        }
272    }
273
274    pub fn get_hess_result(&self) -> TNVMReturnType<'a, C> {
275        match self.buffer.shape() {
276            // Safety: TNVM told me this output buffer is mine
277            GenerationShape::Scalar => TNVMReturnType::SymSqMatrix(unsafe {
278                self.buffer.hess_as_symsq_matrix_ref(self.memory)
279            }),
280            GenerationShape::Vector(_) => TNVMReturnType::SymSqTensor3D(unsafe {
281                self.buffer.hess_as_symsq_tensor3d_ref(self.memory)
282            }),
283            GenerationShape::Matrix(_, _) => TNVMReturnType::SymSqTensor4D(unsafe {
284                self.buffer.hess_as_symsq_tensor4d_ref(self.memory)
285            }),
286            GenerationShape::Tensor3D(_, _, _) => TNVMReturnType::SymSqTensor5D(unsafe {
287                self.buffer.hess_as_symsq_tensor5d_ref(self.memory)
288            }),
289            _ => panic!("No Tensor4D should be exposed a function value output."),
290        }
291    }
292}