1use super::buffer::SizedTensorBuffer;
2use faer::ColRef;
3use faer::MatRef;
4use faer::RowRef;
5use qudit_core::array::TensorRef;
6use qudit_core::{ComplexScalar, memory::MemoryBuffer};
7use qudit_expr::GenerationShape;
8
9pub struct TNVMReturnType2<'a, C: ComplexScalar> {
10 pub ncols: usize,
11 pub nrows: usize,
12 pub nmats: usize,
13 pub ntens: usize,
14 pub buffer: &'a [C],
15}
16
17impl<'a, C: ComplexScalar> TNVMReturnType2<'a, C> {
18 pub fn num_elements(&self) -> usize {
19 self.ncols * self.nrows * self.nmats * self.ntens
20 }
21
22 pub fn unpack_scalar(self) -> &'a C {
23 if self.num_elements() != 1 {
24 panic!("Cannot unpack a non-scalar type as a scalar.");
25 }
26 &self.buffer[0]
27 }
28
29 pub fn unpack_row(self) -> RowRef<'a, C> {
30 RowRef::from_slice(self.buffer)
31 }
32
33 pub fn unpack_col(self) -> ColRef<'a, C> {
34 ColRef::from_slice(self.buffer)
35 }
36
37 pub fn unpack_mat(self) -> MatRef<'a, C> {
38 if self.ntens != 1 {
39 MatRef::from_column_major_slice(
40 self.buffer,
41 self.ntens,
42 self.nmats * self.nrows * self.ncols,
43 )
44 } else if self.nmats != 1 {
45 MatRef::from_column_major_slice(self.buffer, self.nmats, self.nrows * self.ncols)
46 } else {
47 MatRef::from_column_major_slice(self.buffer, self.nrows, self.ncols)
48 }
49 }
50
51 pub fn unpack_tensor3d(self) -> TensorRef<'a, C, 3> {
52 if self.ntens != 1 {
54 if self.nmats != 1 {
55 unsafe {
56 TensorRef::from_raw_parts(
57 self.buffer.as_ptr(),
58 [self.ntens, self.nmats, self.nrows * self.ncols],
59 [
60 self.nmats * self.nrows * self.ncols,
61 self.nrows * self.ncols,
62 1,
63 ],
64 )
65 }
66 } else {
67 unsafe {
68 TensorRef::from_raw_parts(
69 self.buffer.as_ptr(),
70 [self.ntens, self.nrows, self.ncols],
71 [self.nrows * self.ncols, 1, self.nrows],
72 )
73 }
74 }
75 } else {
76 unsafe {
77 TensorRef::from_raw_parts(
78 self.buffer.as_ptr(),
79 [self.nmats, self.nrows, self.ncols],
80 [self.nrows * self.ncols, 1, self.nrows],
81 )
82 }
83 }
84 }
85
86 pub fn unpack_tensor4d(self) -> TensorRef<'a, C, 4> {
87 unsafe {
88 TensorRef::from_raw_parts(
89 self.buffer.as_ptr(),
90 [self.ntens, self.nmats, self.nrows, self.ncols],
91 [
92 self.nmats * self.nrows * self.ncols,
93 self.nrows * self.ncols,
94 1,
95 self.nrows,
96 ],
97 )
98 }
99 }
100}
101
102pub enum TNVMReturnType<'a, C: ComplexScalar> {
103 Scalar(&'a C),
104 Vector(faer::RowRef<'a, C>),
105 Matrix(faer::MatRef<'a, C>),
106 Tensor3D(qudit_core::array::TensorRef<'a, C, 3>),
107 Tensor4D(qudit_core::array::TensorRef<'a, C, 4>),
108 SymSqMatrix(qudit_core::array::SymSqTensorRef<'a, C, 2>),
109 SymSqTensor3D(qudit_core::array::SymSqTensorRef<'a, C, 3>),
110 SymSqTensor4D(qudit_core::array::SymSqTensorRef<'a, C, 4>),
111 SymSqTensor5D(qudit_core::array::SymSqTensorRef<'a, C, 5>),
112}
113
114impl<'a, C: ComplexScalar> TNVMReturnType<'a, C> {
115 pub fn unpack_scalar(self) -> &'a C {
116 match self {
117 TNVMReturnType::Scalar(s) => s,
118 _ => panic!("cannot unpack a non-scalar type as a scalar"),
119 }
120 }
121
122 pub fn unpack_vector(self) -> faer::RowRef<'a, C> {
123 match self {
124 TNVMReturnType::Vector(v) => v,
125 _ => panic!("cannot unpack a non-row-vector type as a row-vector"),
126 }
127 }
128
129 pub fn unpack_matrix(self) -> faer::MatRef<'a, C> {
138 match self {
139 TNVMReturnType::Matrix(m) => m,
140 _ => panic!("cannot unpack a non-matrix type as a matrix"),
141 }
142 }
143
144 pub fn unpack_tensor3d(self) -> qudit_core::array::TensorRef<'a, C, 3> {
145 match self {
146 TNVMReturnType::Tensor3D(m) => m,
147 _ => panic!("cannot unpack a non-tensor3d type as a tensor3d"),
148 }
149 }
150
151 pub fn unpack_tensor4d(self) -> qudit_core::array::TensorRef<'a, C, 4> {
152 match self {
153 TNVMReturnType::Tensor4D(t) => t,
154 _ => panic!("cannot unpack a non-tensor4d type as a tensor4d"),
155 }
156 }
157
158 pub fn unpack_symsq_matrix(self) -> qudit_core::array::SymSqTensorRef<'a, C, 2> {
159 match self {
160 TNVMReturnType::SymSqMatrix(t) => t,
161 _ => panic!("cannot unpack a non-symsq-matrix type as a symsq-matrix"),
162 }
163 }
164
165 pub fn unpack_symsq_tensor3d(self) -> qudit_core::array::SymSqTensorRef<'a, C, 3> {
166 match self {
167 TNVMReturnType::SymSqTensor3D(t) => t,
168 _ => panic!("cannot unpack a non-symsq-tensor3d type as a symsq-tensor3d"),
169 }
170 }
171
172 pub fn unpack_symsq_tensor4d(self) -> qudit_core::array::SymSqTensorRef<'a, C, 4> {
173 match self {
174 TNVMReturnType::SymSqTensor4D(t) => t,
175 _ => panic!("cannot unpack a non-symsq-tensor4d type as a symsq-tensor4d"),
176 }
177 }
178
179 pub fn unpack_symsq_tensor5d(self) -> qudit_core::array::SymSqTensorRef<'a, C, 5> {
180 match self {
181 TNVMReturnType::SymSqTensor5D(t) => t,
182 _ => panic!("cannot unpack a non-symsq-tensor5d type as a symsq-tensor5d"),
183 }
184 }
185
186 }
189
190pub struct TNVMResult<'a, C: ComplexScalar> {
191 memory: &'a MemoryBuffer<C>,
192 buffer: &'a SizedTensorBuffer<C>,
193}
194
195impl<'a, C: ComplexScalar> TNVMResult<'a, C> {
196 pub fn new(memory: &'a MemoryBuffer<C>, buffer: &'a SizedTensorBuffer<C>) -> Self {
197 TNVMResult { memory, buffer }
198 }
199
200 pub fn get_fn_result2(&self) -> TNVMReturnType2<'a, C> {
201 TNVMReturnType2 {
203 ncols: self.buffer.ncols(),
204 nrows: self.buffer.nrows(),
205 nmats: self.buffer.nmats(),
206 ntens: 1,
207 buffer: unsafe {
208 std::slice::from_raw_parts(
209 self.buffer.as_ptr(self.memory),
210 self.buffer.unit_memory_size(),
211 )
212 },
213 }
214 }
215
216 pub fn get_grad_result2(&self) -> TNVMReturnType2<'a, C> {
217 TNVMReturnType2 {
218 ncols: self.buffer.ncols(),
219 nrows: self.buffer.nrows(),
220 nmats: self.buffer.nmats(),
221 ntens: self.buffer.nparams(),
222 buffer: unsafe {
223 std::slice::from_raw_parts(
224 self.buffer
225 .as_ptr(self.memory)
226 .add(self.buffer.unit_memory_size()),
227 self.buffer.grad_memory_size(),
228 )
229 },
230 }
231 }
232
233 pub fn get_fn_result(&self) -> TNVMReturnType<'a, C> {
234 match self.buffer.shape() {
235 GenerationShape::Scalar => {
237 TNVMReturnType::Scalar(unsafe { self.buffer.as_scalar_ref(self.memory) })
238 }
239 GenerationShape::Vector(_) => {
240 TNVMReturnType::Vector(unsafe { self.buffer.as_vector_ref(self.memory) })
241 }
242 GenerationShape::Matrix(_, _) => {
243 TNVMReturnType::Matrix(unsafe { self.buffer.as_matrix_ref(self.memory) })
244 }
245 GenerationShape::Tensor3D(_, _, _) => {
246 TNVMReturnType::Tensor3D(unsafe { self.buffer.as_tensor3d_ref(self.memory) })
247 }
248 _ => panic!("No Tensor4D should be exposed a function value output."),
249 }
250 }
251
252 pub fn get_grad_result(&self) -> TNVMReturnType<'a, C> {
255 match self.buffer.shape() {
257 GenerationShape::Scalar => {
259 TNVMReturnType::Vector(unsafe { self.buffer.grad_as_vector_ref(self.memory) })
260 }
261 GenerationShape::Vector(_) => {
262 TNVMReturnType::Matrix(unsafe { self.buffer.grad_as_matrix_ref(self.memory) })
263 }
264 GenerationShape::Matrix(_, _) => {
265 TNVMReturnType::Tensor3D(unsafe { self.buffer.grad_as_tensor3d_ref(self.memory) })
266 }
267 GenerationShape::Tensor3D(_, _, _) => {
268 TNVMReturnType::Tensor4D(unsafe { self.buffer.grad_as_tensor4d_ref(self.memory) })
269 }
270 _ => panic!("No Tensor4D should be exposed a function value output."),
271 }
272 }
273
274 pub fn get_hess_result(&self) -> TNVMReturnType<'a, C> {
275 match self.buffer.shape() {
276 GenerationShape::Scalar => TNVMReturnType::SymSqMatrix(unsafe {
278 self.buffer.hess_as_symsq_matrix_ref(self.memory)
279 }),
280 GenerationShape::Vector(_) => TNVMReturnType::SymSqTensor3D(unsafe {
281 self.buffer.hess_as_symsq_tensor3d_ref(self.memory)
282 }),
283 GenerationShape::Matrix(_, _) => TNVMReturnType::SymSqTensor4D(unsafe {
284 self.buffer.hess_as_symsq_tensor4d_ref(self.memory)
285 }),
286 GenerationShape::Tensor3D(_, _, _) => TNVMReturnType::SymSqTensor5D(unsafe {
287 self.buffer.hess_as_symsq_tensor5d_ref(self.memory)
288 }),
289 _ => panic!("No Tensor4D should be exposed a function value output."),
290 }
291 }
292}