use ndarray::{ArrayD, IxDyn};
use serde::{Deserialize, Serialize};
/// Contracts `a` and `b` over the paired axes `axes_a` and `axes_b`
/// (a generalized matrix product in the spirit of `numpy.tensordot`).
pub fn tensordot(
    a: &ArrayD<f64>,
    b: &ArrayD<f64>,
    axes_a: &[usize],
    axes_b: &[usize],
) -> Result<ArrayD<f64>, String> {
    if axes_a.len() != axes_b.len() {
        return Err("Contracted axes must have the same length.".to_string());
    }

    // Reject out-of-range axis indices instead of panicking on the shape lookup.
    if axes_a.iter().any(|&ax| ax >= a.ndim()) || axes_b.iter().any(|&ax| ax >= b.ndim()) {
        return Err("Contraction axis index out of bounds.".to_string());
    }

    for (&ax_a, &ax_b) in axes_a.iter().zip(axes_b.iter()) {
        if a.shape()[ax_a] != b.shape()[ax_b] {
            return Err(format!(
                "Dimension mismatch on contracted axes: {} != {}",
                a.shape()[ax_a],
                b.shape()[ax_b]
            ));
        }
    }

    // Axes that are not contracted; they survive into the output.
    let free_axes_a: Vec<_> = (0..a.ndim()).filter(|i| !axes_a.contains(i)).collect();
    let free_axes_b: Vec<_> = (0..b.ndim()).filter(|i| !axes_b.contains(i)).collect();

    // Permute so the contracted axes come last in `a` and first in `b`.
    let perm_a: Vec<_> = free_axes_a.iter().chain(axes_a.iter()).copied().collect();
    let perm_b: Vec<_> = axes_b.iter().chain(free_axes_b.iter()).copied().collect();

    let a_perm = a.clone().permuted_axes(perm_a);
    let b_perm = b.clone().permuted_axes(perm_b);

    // Collapse the free axes and the contracted axes into one dimension each.
    let free_dim_a: usize = free_axes_a.iter().map(|&i| a.shape()[i]).product();
    let free_dim_b: usize = free_axes_b.iter().map(|&i| b.shape()[i]).product();
    let contracted_dim: usize = axes_a.iter().map(|&i| a.shape()[i]).product();

    let a_mat = a_perm
        .to_shape((free_dim_a, contracted_dim))
        .map_err(|e| e.to_string())?
        .to_owned();
    let b_mat = b_perm
        .to_shape((contracted_dim, free_dim_b))
        .map_err(|e| e.to_string())?
        .to_owned();

    // The contraction is now an ordinary matrix product.
    let result_mat = a_mat.dot(&b_mat);

    // The output keeps the free dimensions of `a` followed by those of `b`.
    let mut final_shape_dims = Vec::new();
    final_shape_dims.extend(free_axes_a.iter().map(|&i| a.shape()[i]));
    final_shape_dims.extend(free_axes_b.iter().map(|&i| b.shape()[i]));

    Ok(result_mat
        .to_shape(IxDyn(&final_shape_dims))
        .map_err(|e| e.to_string())?
        .to_owned())
}

/// Computes the outer product: the result has the shape of `a` followed by the
/// shape of `b`, with every element of `a` multiplied by every element of `b`.
pub fn outer_product(a: &ArrayD<f64>, b: &ArrayD<f64>) -> Result<ArrayD<f64>, String> {
    let mut new_shape = a.shape().to_vec();
    new_shape.extend_from_slice(b.shape());

    let a_flat = a
        .as_slice()
        .ok_or_else(|| "Input tensor 'a' is not contiguous".to_string())?;
    let b_flat = b
        .as_slice()
        .ok_or_else(|| "Input tensor 'b' is not contiguous".to_string())?;

    let mut result_data = Vec::with_capacity(a.len() * b.len());
    for val_a in a_flat {
        for val_b in b_flat {
            result_data.push(val_a * val_b);
        }
    }

    ArrayD::from_shape_vec(IxDyn(&new_shape), result_data).map_err(|e| e.to_string())
}

/// Multiplies a tensor by a vector along the tensor's last axis, reducing the
/// rank by one.
pub fn tensor_vec_mul(tensor: &ArrayD<f64>, vector: &[f64]) -> Result<ArrayD<f64>, String> {
    if tensor.ndim() == 0 {
        return Err("Tensor must have at least one dimension.".to_string());
    }

    let last_dim = tensor.shape()[tensor.ndim() - 1];
    if last_dim != vector.len() {
        return Err(format!(
            "Dimension mismatch: last tensor dim {} != vector length {}",
            last_dim,
            vector.len()
        ));
    }

    // Contract the tensor's last axis against the vector's single axis.
    let vec_arr = ndarray::Array1::from_vec(vector.to_vec());
    tensordot(tensor, &vec_arr.into_dyn(), &[tensor.ndim() - 1], &[0])
}

/// Computes the Frobenius inner product: the sum of element-wise products of
/// two tensors with identical shapes.
pub fn inner_product(a: &ArrayD<f64>, b: &ArrayD<f64>) -> Result<f64, String> {
    if a.shape() != b.shape() {
        return Err("Tensors must have the same shape for inner product.".to_string());
    }

    let a_flat = a.as_slice().ok_or("Tensor 'a' is not contiguous")?;
    let b_flat = b.as_slice().ok_or("Tensor 'b' is not contiguous")?;

    Ok(a_flat.iter().zip(b_flat.iter()).map(|(x, y)| x * y).sum())
}

/// Contracts (traces over) two axes of a single tensor. Only the rank-2 case,
/// i.e. the matrix trace, is currently implemented.
pub fn contract(a: &ArrayD<f64>, axis1: usize, axis2: usize) -> Result<ArrayD<f64>, String> {
    if axis1 == axis2 {
        return Err("Axes must be different for contraction.".to_string());
    }

    if a.shape()[axis1] != a.shape()[axis2] {
        return Err("Dimensions along contraction axes must be equal.".to_string());
    }

    let n = a.shape()[axis1];

    if a.ndim() == 2 {
        // For a matrix, contracting its two axes yields the trace.
        let mut sum = 0.0;
        for i in 0..n {
            sum += a[[i, i]];
        }
        return Ok(ndarray::Array0::from_elem((), sum).into_dyn());
    }

    Err("General tensor contraction (trace) for rank > 2 not yet implemented.".to_string())
}

/// Returns the Frobenius norm: the square root of the sum of squared elements.
#[must_use]
pub fn norm(a: &ArrayD<f64>) -> f64 {
    a.iter().map(|x| x * x).sum::<f64>().sqrt()
}

/// A plain, serializable representation of a dense tensor: its shape together
/// with its elements in row-major order.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TensorData {
    pub shape: Vec<usize>,
    pub data: Vec<f64>,
}

impl From<&ArrayD<f64>> for TensorData {
    fn from(arr: &ArrayD<f64>) -> Self {
        Self {
            shape: arr.shape().to_vec(),
            // Iterate in logical (row-major) order so the data matches `shape`
            // even when the source array is not in standard memory layout.
            data: arr.iter().copied().collect(),
        }
    }
}

impl TensorData {
    /// Rebuilds an `ArrayD` from the stored shape and row-major data.
    pub fn to_arrayd(&self) -> Result<ArrayD<f64>, String> {
        ArrayD::from_shape_vec(IxDyn(&self.shape), self.data.clone())
            .map_err(|e| e.to_string())
    }
}

#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    #[test]
    fn test_tensordot() {
        let a = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
        let b = array![[5.0, 6.0], [7.0, 8.0]].into_dyn();

        let res = tensordot(&a, &b, &[1], &[0]).unwrap();

        assert_eq!(res.shape(), &[2, 2]);
        assert_eq!(res[[0, 0]], 1.0 * 5.0 + 2.0 * 7.0);
    }

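    // Illustrative sketches for `tensor_vec_mul` and `inner_product`: the
    // inputs, expected values, and test names here are hand-picked examples,
    // not part of the original suite.
    #[test]
    fn test_tensor_vec_mul() {
        let t = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
        let v = [10.0, 100.0];

        // Contracting the last axis: [1*10 + 2*100, 3*10 + 4*100].
        let res = tensor_vec_mul(&t, &v).unwrap();

        assert_eq!(res.shape(), &[2]);
        assert_eq!(res[[0]], 210.0);
        assert_eq!(res[[1]], 430.0);
    }

    #[test]
    fn test_inner_product() {
        let a = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
        let b = array![[5.0, 6.0], [7.0, 8.0]].into_dyn();

        // 1*5 + 2*6 + 3*7 + 4*8 = 70.
        assert_eq!(inner_product(&a, &b).unwrap(), 70.0);
    }
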
    #[test]
    fn test_outer_product() {
        let a = array![1.0, 2.0].into_dyn();
        let b = array![3.0, 4.0].into_dyn();

        let res = outer_product(&a, &b).unwrap();

        assert_eq!(res.shape(), &[2, 2]);
        assert_eq!(res[[0, 0]], 3.0);
        assert_eq!(res[[1, 1]], 8.0);
    }
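
    // Illustrative sketches for `contract` (the rank-2 trace), `norm`, and a
    // `TensorData` round-trip; expected values are hand-computed examples.
    #[test]
    fn test_contract_trace() {
        let a = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();

        // For a rank-2 tensor the contraction is the matrix trace: 1 + 4 = 5.
        let res = contract(&a, 0, 1).unwrap();

        assert_eq!(res.ndim(), 0);
        assert_eq!(res.sum(), 5.0);
    }

    #[test]
    fn test_norm() {
        // sqrt(3^2 + 4^2) = 5.
        let a = array![3.0, 4.0].into_dyn();

        assert_eq!(norm(&a), 5.0);
    }

    #[test]
    fn test_tensor_data_round_trip() {
        let a = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();

        let data = TensorData::from(&a);
        let back = data.to_arrayd().unwrap();

        assert_eq!(back, a);
    }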
}