dlt/tensor/
ops.rs

1// Tensor operation implementations
2
3use crate::complex::c64;
4use crate::dimension::{Dimension, InvertDimension, MultiplyDimensions};
5use crate::tensor::element::TensorElement;
6use crate::tensor::Tensor;
7use std::marker::PhantomData;
8use std::ops::{Add, Mul, Neg, Sub, Div};
9use crate::*;
10
11// -----------------------------------------
12// ============= OPERATIONS ================
13// -----------------------------------------
14
15impl<E: TensorElement + Add<Output = E> + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
16Add for Tensor<E, D, LAYERS, ROWS, COLS>
17where
18    [(); LAYERS * ROWS * COLS]:,
19{
20    type Output = Self;
21
22    fn add(self, other: Self) -> Self {
23        let data: [E; LAYERS * ROWS * COLS] = self
24            .data
25            .iter()
26            .zip(other.data.iter())
27            .map(|(&a, &b)| a + b)
28            .collect::<Vec<_>>()
29            .try_into()
30            .unwrap();
31
32        Self {
33            data,
34            _phantom: PhantomData,
35        }
36    }
37}
38
39impl<E: TensorElement + Sub<Output = E> + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
40Sub for Tensor<E, D, LAYERS, ROWS, COLS>
41where
42    [(); LAYERS * ROWS * COLS]:,
43{
44    type Output = Self;
45
46    fn sub(self, other: Self) -> Self {
47        let data: [E; LAYERS * ROWS * COLS] = self
48            .data
49            .iter()
50            .zip(other.data.iter())
51            .map(|(&a, &b)| a - b)
52            .collect::<Vec<_>>()
53            .try_into()
54            .unwrap();
55
56        Self {
57            data,
58            _phantom: PhantomData,
59        }
60    }
61}
62
63impl<
64    E: TensorElement + Mul<Output = E> + Add<Output = E> + Copy,
65    const LAYERS: usize,
66    const L1: i32,
67    const M1: i32,
68    const T1: i32,
69    const Θ1: i32,
70    const I1: i32,
71    const N1: i32,
72    const J1: i32,
73    const ROWS: usize,
74    const COMMON: usize,
75> Tensor<E, Dimension<L1, M1, T1, Θ1, I1, N1, J1>, LAYERS, ROWS, COMMON>
76where
77    [(); LAYERS * ROWS * COMMON]:,
78{
79    /// Performs matrix multiplication between two tensors.
80    pub fn matmul<
81        const L2: i32,
82        const M2: i32,
83        const T2: i32,
84        const Θ2: i32,
85        const I2: i32,
86        const N2: i32,
87        const J2: i32,
88        const COLS: usize,
89    >(
90        self,
91        other: Tensor<E, Dimension<L2, M2, T2, Θ2, I2, N2, J2>, LAYERS, COMMON, COLS>,
92    ) -> Tensor<
93        E,
94        <Dimension<L1, M1, T1, Θ1, I1, N1, J1> as MultiplyDimensions<
95            Dimension<L2, M2, T2, Θ2, I2, N2, J2>
96        >>::Output,
97        LAYERS,
98        ROWS,
99        COLS,
100    >
101    where
102        Dimension<L1, M1, T1, Θ1, I1, N1, J1>: MultiplyDimensions<Dimension<L2, M2, T2, Θ2, I2, N2, J2>>,
103        [(); LAYERS * COMMON * COLS]:,
104        [(); LAYERS * ROWS * COLS]:,
105        [(); COLS]:,
106    {
107        // Create a vector to store the output tensor, initializing all entries to zero.
108        let mut result = vec![E::zero(); LAYERS * ROWS * COLS];
109
110        // Iterate over layers, rows, and columns to compute each element.
111        for layer in 0..LAYERS {
112            for row in 0..ROWS {
113                for col in 0..COLS {
114                    let mut sum = E::zero();
115                    // Compute the dot product for the current (layer, row, col) element.
116                    for k in 0..COMMON {
117                        let index_a = layer * (ROWS * COMMON) + row * COMMON + k;
118                        let index_b = layer * (COMMON * COLS) + k * COLS + col;
119                        sum = sum + self.data[index_a] * other.data[index_b];
120                    }
121                    let index_result = layer * (ROWS * COLS) + row * COLS + col;
122                    result[index_result] = sum;
123                }
124            }
125        }
126
127        // Convert the result vector into an array.
128        let data: [E; LAYERS * ROWS * COLS] =
129            result.into_iter().collect::<Vec<E>>().try_into().unwrap();
130
131        Tensor {
132            data,
133            _phantom: PhantomData,
134        }
135    }
136}
137
138impl<E: TensorElement + Mul<Output = E> + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
139Tensor<E, D, LAYERS, ROWS, COLS>
140where
141    [(); LAYERS * ROWS * COLS]:,
142{
143    /// Multiplies every element of the tensor by a scalar.
144    pub fn scale<DS>(
145        self,
146        scalar: Tensor<E, DS, 1, 1, 1>,
147    ) -> Tensor<E, <D as MultiplyDimensions<DS>>::Output, LAYERS, ROWS, COLS>
148    where
149        D: MultiplyDimensions<DS>,
150        <D as MultiplyDimensions<DS>>::Output:,
151            {
152        let s = scalar.data[0];
153        let data: [E; LAYERS * ROWS * COLS] = self
154            .data
155            .iter()
156            .map(|&v| v * s)
157            .collect::<Vec<_>>()
158            .try_into()
159            .unwrap();
160
161        Tensor {
162            data,
163            _phantom: PhantomData::<<D as MultiplyDimensions<DS>>::Output>,
164        }
165    }
166}
167
168impl<E, D, DS, const LAYERS: usize, const ROWS: usize, const COLS: usize>
169    Mul<Tensor<E, DS, 1, 1, 1>> for Tensor<E, D, LAYERS, ROWS, COLS>
170where
171    E: TensorElement + Mul<Output = E> + Copy,
172    D: MultiplyDimensions<DS>,
173    [(); LAYERS * ROWS * COLS]:,
174{
175    type Output = Tensor<E, <D as MultiplyDimensions<DS>>::Output, LAYERS, ROWS, COLS>;
176
177    fn mul(self, rhs: Tensor<E, DS, 1, 1, 1>) -> Self::Output {
178        self.scale(rhs)
179    }
180}
181
182
183impl<E, D, DS, const LAYERS: usize, const ROWS: usize, const COLS: usize>
184    Div<Tensor<E, DS, 1, 1, 1>> for Tensor<E, D, LAYERS, ROWS, COLS>
185where
186    E: TensorElement + Div<Output = E> + Copy,
187    DS: InvertDimension,
188    D: MultiplyDimensions<<DS as InvertDimension>::Output>,
189    [(); LAYERS * ROWS * COLS]:,
190{
191    type Output = Tensor<
192        E,
193        <D as MultiplyDimensions<<DS as InvertDimension>::Output>>::Output,
194        LAYERS,
195        ROWS,
196        COLS
197    >;
198
199    fn div(self, rhs: Tensor<E, DS, 1, 1, 1>) -> Self::Output {
200        self.scale(rhs.inv())
201    }
202}
203
204impl<E: TensorElement + Div<Output = E> + Copy + PartialEq, D> Tensor<E, D, 1, 1, 1>
205where
206    [(); 1]:,
207{
208    pub fn inv(self) -> Tensor<E, <D as InvertDimension>::Output, 1, 1, 1>
209    where
210        D: InvertDimension,
211    {
212        let data: [E; 1] = [E::one() / self.data[0]];
213        Tensor {
214            data,
215            _phantom: PhantomData,
216        }
217    }
218}
219
220impl<E: TensorElement + Neg<Output = E> + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
221Neg for Tensor<E, D, LAYERS, ROWS, COLS>
222where
223    [(); LAYERS * ROWS * COLS]:,
224{
225    type Output = Self;
226
227    fn neg(self) -> Self {
228        let data: [E; LAYERS * ROWS * COLS] = self
229            .data
230            .iter()
231            .map(|&v| -v)
232            .collect::<Vec<_>>()
233            .try_into()
234            .unwrap();
235
236        Self {
237            data,
238            _phantom: PhantomData,
239        }
240    }
241}
242
243impl<E, D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E, D, LAYERS, ROWS, COLS>
244where
245    E: TensorElement + Into<c64> + Copy,
246    [(); LAYERS * ROWS * COLS]:,
247{
248    /// Converts the tensor to one with c64 elements by mapping each element via Into<c64>.
249    pub fn to_c64(&self) -> Tensor<c64, D, LAYERS, ROWS, COLS> {
250        let data: [c64; LAYERS * ROWS * COLS] = self
251            .data
252            .iter()
253            .map(|&v| v.into())
254            .collect::<Vec<_>>()
255            .try_into()
256            .unwrap();
257        Tensor {
258            data,
259            _phantom: PhantomData,
260        }
261    }
262}
263
264// implement conjugate for all tensors
265impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
266where
267    [(); LAYERS * ROWS * COLS]:,
268{
269    pub fn conjugate(self) -> Self {
270        let data: [E; LAYERS * ROWS * COLS] = self
271            .data
272            .iter()
273            .map(|&v| v.conjugate())
274            .collect::<Vec<_>>()
275            .try_into()
276            .unwrap();
277
278        Self {
279            data,
280            _phantom: PhantomData,
281        }
282    }
283
284    /// Returns the conjugate transpose of this tensor.
285    pub fn conjugate_transpose(self) -> Tensor<E,D, LAYERS, COLS, ROWS>
286    where
287        [(); LAYERS * COLS * ROWS]:,
288    {
289        self.transpose().conjugate()
290    }
291}
292
293impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
294where
295    [(); LAYERS * ROWS * COLS]:,
296{
297    /// Returns the transpose of this tensor.
298    pub fn transpose(self) -> Tensor<E,D, LAYERS, COLS, ROWS>
299    where
300        [(); LAYERS * COLS * ROWS]:,
301    {
302        let mut transposed = [E::zero(); LAYERS * COLS * ROWS];
303        for l in 0..LAYERS {
304            for i in 0..ROWS {
305                for j in 0..COLS {
306                    // Element at (i, j) moves to (j, i)
307                    let src = l * (ROWS * COLS) + i * COLS + j;
308                    let dst = l * (COLS * ROWS) + j * ROWS + i;
309                    transposed[dst] = self.data[src];
310                }
311            }
312        }
313        Tensor::<E,D, LAYERS, COLS, ROWS> {
314            data: transposed,
315            _phantom: PhantomData,
316        }
317    }
318}
319
320
321
322
// Euclidean norm and distance for single-layer column vectors (shape 1 × ROWS × 1)
// whose physical dimension is given by the `Dimension<L, M, T, Θ, I, N, J>` exponents.
impl<
    E: TensorElement + Into<f64> + Copy,
    const L: i32,
    const M: i32,
    const T: i32,
    const Θ: i32,
    const I: i32,
    const N: i32,
    const J: i32,
    const ROWS: usize
>
Tensor<E, Dimension<L, M, T, Θ, I, N, J>, 1, ROWS, 1>
where
    [(); 1 * ROWS * 1]:,
    [(); 1 * 1 * ROWS]:,
    [(); ROWS * 1 * 1]:,
{
    /// Returns the Euclidean norm `sqrt(xᴴ x)` of this column vector as a 1×1×1 `f64` tensor.
    ///
    /// The result keeps the vector's own physical dimension (the return type
    /// reuses `Dimension<L, M, T, Θ, I, N, J>`).
    // NOTE(review): the `ConstAdd<X, X>` bounds below presumably make the squared
    // dimension D×D, needed by `matmul`, well-formed — confirm against `MultiplyDimensions`.
    pub fn norm(
        self
    ) -> Tensor<f64, Dimension<L, M, T, Θ, I, N, J>, 1, 1, 1>
    where
        [(); { <() as ConstAdd<L, L>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<M, M>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<T, T>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<Θ, Θ>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<I, I>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<N, N>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<J, J>>::OUTPUT } as usize]:,
    {
        // xᴴ has shape 1×1×ROWS; multiplying by x (1×ROWS×1) yields the 1×1×1 sum of conj(xᵢ)·xᵢ.
        // NOTE(review): `self` is used again after `conjugate_transpose(self)`, which only
        // compiles if `Tensor` is `Copy` — presumably derived where the struct is defined.
        let ct: Tensor<E, Dimension<L, M, T, Θ, I, N, J>, 1, 1, ROWS> = self.conjugate_transpose();
        let i: Tensor<E, Dimension<_, _, _, _, _, _, _>, 1, 1, 1> = ct.matmul(self);

        // Lift the single squared-magnitude value to c64, take its square root, and map back to f64.
        // NOTE(review): `Into<c64>` is not among this impl's explicit bounds (only `Into<f64>` is);
        // presumably `TensorElement` supplies it — confirm.
        // NOTE(review): how `f64::from(c64)` collapses a complex value is defined elsewhere —
        // confirm it returns the real part / modulus, which coincide for a non-negative real.
        let val: c64 = i.data[0].into();
        let sqrt_val = f64::from(val.sqrt());

        Tensor {
            data: [sqrt_val],
            _phantom: PhantomData,
        }
    }

    /// Returns the Euclidean distance `‖self − other‖`, computed as `(self - other).norm()`.
    // NOTE(review): relies on the tensor `Sub` impl, whose `E: Sub` bound is not listed here —
    // presumably satisfied through `TensorElement`; confirm. The first two array bounds repeat
    // the impl-level ones.
    pub fn dist(
        self,
        other: Self,
    ) -> Tensor<f64, Dimension<L, M, T, Θ, I, N, J>, 1, 1, 1>
    where
        [(); 1 * ROWS * 1]:,
        [(); 1 * 1 * ROWS]:,
        [(); { <() as ConstAdd<L, L>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<M, M>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<T, T>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<Θ, Θ>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<I, I>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<N, N>>::OUTPUT } as usize]:,
        [(); { <() as ConstAdd<J, J>>::OUTPUT } as usize]:,
    {
       let sub = self - other;
       sub.norm()
    }

}
386    
387// Implement elementwise equality for all tensors.
388impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> PartialEq for Tensor<E,D, LAYERS, ROWS, COLS>
389where
390    [(); LAYERS * ROWS * COLS]:,
391{
392    fn eq(&self, other: &Self) -> bool {
393        self.data
394            .iter()
395            .zip(other.data.iter())
396            .all(|(&a, &b)| a == b)
397    }
398}
399
400// Optionally, if c64: Eq then implement Eq.
401impl<E:TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Eq for Tensor<E,D, LAYERS, ROWS, COLS>
402where
403    [(); LAYERS * ROWS * COLS]:,
404    c64: Eq,
405{
406}
407
408// Implement ordering (>, >=, <, <=) for 1×1×1 tensors only.
409impl<E: TensorElement,D> PartialOrd for Tensor<E,D, 1, 1, 1>
410where
411    [(); 1]:,
412{
413    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
414        self.data[0].partial_cmp(&other.data[0])
415    }
416}
417
// Dot product `aᵀ · b`: transpose the left operand, then matrix-multiply with the right.
#[macro_export]
macro_rules! dot {
    ($a:expr, $b:expr) => {{
        // Bind both operands first so each expression is evaluated exactly once, left to right.
        let (lhs, rhs) = ($a, $b);
        lhs.transpose().matmul(rhs)
    }};
}
429
// Inner product `aᴴ · b`: conjugate-transpose the left operand, then matrix-multiply.
#[macro_export]
macro_rules! inner_product {
    ($a:expr, $b:expr) => {{
        // Evaluate each operand exactly once, left operand first.
        let (lhs, rhs) = ($a, $b);
        lhs.conjugate_transpose().matmul(rhs)
    }};
}
440
// Short alias for `inner_product!`.
#[macro_export]
macro_rules! ip {
    ($x:expr, $y:expr) => {
        // `$crate::` makes the nested macro resolve from any invoking crate, even one
        // that has not imported `inner_product!` itself.
        $crate::inner_product!($x, $y)
    };
}
447
448// linear algebraic stuff
449// reshape, flatten, etc.
450
451impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
452where
453    [(); LAYERS * ROWS * COLS]:,
454{
455    /// Returns the number of elements in the tensor.
456    pub fn size(&self) -> usize {
457        LAYERS * ROWS * COLS
458    }
459
460    pub fn shape(&self) -> (usize, usize, usize) {
461        (LAYERS, ROWS, COLS)
462    }
463
464    /// Returns the number of layers in the tensor.
465    pub fn layers(&self) -> usize {
466        LAYERS
467    }
468
469    /// Returns the number of rows in the tensor.
470    pub fn rows(&self) -> usize {
471        ROWS
472    }
473
474    /// Returns the number of columns in the tensor.
475    pub fn cols(&self) -> usize {
476        COLS
477    }
478
479    /// Returns the data as a slice.
480    pub fn data(&self) -> &[E] {
481        &self.data
482    }
483
484    pub fn reshape<const L: usize, const R: usize, const C: usize>(
485        &self,
486    ) -> Tensor<E,D, L, R, C>
487    where
488        [(); L * R * C]:,
489    {
490        assert_eq!(LAYERS * ROWS * COLS, L * R * C);
491        let data: [E; L * R * C] = self
492            .data
493            .iter()
494            .copied()
495            .collect::<Vec<_>>()
496            .try_into()
497            .unwrap();
498
499        Tensor {
500            data,
501            _phantom: PhantomData,
502        }
503    }
504
505    pub fn flatten(&self) -> Tensor<E,D, 1, 1, {LAYERS * ROWS * COLS}>
506    where
507        [(); LAYERS * ROWS * COLS]:,
508        [(); 1 * 1 * (LAYERS * ROWS * COLS)]:,
509    {
510        self.reshape::<1, 1, {LAYERS * ROWS * COLS}>()
511    }
512
513}
514
// Element-wise boolean operations (and, or, xor, not); any non-zero element counts as true.
// They should return a tensor of the same shape filled with 0s and 1s (E::zero() / E::one()).
517impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
518where
519    [(); LAYERS * ROWS * COLS]:,
520{
521    pub fn and(self, other: Self) -> Self {
522        let data: [E; LAYERS * ROWS * COLS] = self
523            .data
524            .iter()
525            .zip(other.data.iter())
526            .map(|(&a, &b)| if a != E::zero() && b != E::zero() { E::one() } else { E::zero() })
527            .collect::<Vec<_>>()
528            .try_into()
529            .unwrap();
530
531        Self {
532            data,
533            _phantom: PhantomData,
534        }
535    }
536}
537
538impl<E: TensorElement + PartialEq + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
539Tensor<E, D, LAYERS, ROWS, COLS>
540where
541    [(); LAYERS * ROWS * COLS]:,
542{
543    pub fn or(self, other: Self) -> Self {
544        let data: [E; LAYERS * ROWS * COLS] = self
545            .data
546            .iter()
547            .zip(other.data.iter())
548            .map(|(&a, &b)| if a != E::zero() || b != E::zero() { E::one() } else { E::zero() })
549            .collect::<Vec<_>>()
550            .try_into()
551            .unwrap();
552
553        Self {
554            data,
555            _phantom: PhantomData,
556        }
557    }
558}
559
560
561
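// Element-wise comparison operations (eq, ne, gt, ge, lt, le), intended to return a
// same-shaped tensor holding E::one() where the comparison holds and E::zero() elsewhere.
// These are inherent methods, not operator overloads: `==`/`!=` go through the PartialEq
// impl, and `<`/`>` exist only for 1×1×1 tensors via PartialOrd.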
564impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
565where
566    [(); LAYERS * ROWS * COLS]:,
567{
568    pub fn eq(self, other: Self) -> Self {
569        let data: [E; LAYERS * ROWS * COLS] = self
570            .data
571            .iter()
572            .zip(other.data.iter())
573            .map(|(&a, &b)| if a == b { E::one() } else { E::zero() })
574            .collect::<Vec<_>>()
575            .try_into()
576            .unwrap();
577
578        Self {
579            data,
580            _phantom: PhantomData,
581        }
582    }
583}
584impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
585where
586    [(); LAYERS * ROWS * COLS]:,
587{
588    pub fn ne(self, other: Self) -> Self {
589        let data: [E; LAYERS * ROWS * COLS] = self
590            .data
591            .iter()
592            .zip(other.data.iter())
593            .map(|(&a, &b)| if a != b { E::one() } else { E::one() })
594            .collect::<Vec<_>>()
595            .try_into()
596            .unwrap();
597
598        Self {
599            data,
600            _phantom: PhantomData,
601        }
602    }
603}
604impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
605where
606    [(); LAYERS * ROWS * COLS]:,
607{
608    pub fn gt(self, other: Self) -> Self {
609        let data: [E; LAYERS * ROWS * COLS] = self
610            .data
611            .iter()
612            .zip(other.data.iter())
613            .map(|(&a, &b)| if a > b { E::one() } else { E::zero() })
614            .collect::<Vec<_>>()    
615            .try_into()
616            .unwrap();
617        Self {
618            data,
619            _phantom: PhantomData,
620        }
621    }
622}
623
624impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
625where
626    [(); LAYERS * ROWS * COLS]:,
627{
628    pub fn ge(self, other: Self) -> Self {
629        let data: [E; LAYERS * ROWS * COLS] = self
630            .data
631            .iter()
632            .zip(other.data.iter())
633            .map(|(&a, &b)| if a >= b { E::zero() } else { E::one() })
634            .collect::<Vec<_>>()
635            .try_into()
636            .unwrap();
637
638        Self {
639            data,
640            _phantom: PhantomData,
641        }
642    }
643}