diffgeom/tensors/
variance.rs

//! Module defining variances (types of tensors).
2
3use crate::typenum::bit::Bit;
4use crate::typenum::consts::{B1, U0, U1};
5use crate::typenum::uint::{UInt, Unsigned};
6use crate::typenum::{Add1, Sub1};
7use crate::typenum::{Cmp, Greater, Same};
8use std::ops::{Add, Sub};
9
/// The kind of a single tensor index.
///
/// A tensor can have any number of indices, and each one is either covariant (a lower index)
/// or contravariant (an upper index). For example, a vector is a tensor with exactly one
/// contravariant index.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum IndexType {
    /// A lower index.
    Covariant,
    /// An upper index.
    Contravariant,
}
18
19/// Trait identifying a type as representing a tensor variance. It is implemented for
20/// `CovariantIndex`, `ContravariantIndex` and tuples (Index, Variance).
21pub trait Variance {
22    type Rank: Unsigned + Add<B1>;
23    fn rank() -> usize {
24        Self::Rank::to_usize()
25    }
26    fn variance() -> Vec<IndexType>;
27}
28
29impl Variance for () {
30    type Rank = U0;
31
32    fn variance() -> Vec<IndexType> {
33        vec![]
34    }
35}
36
37/// Trait identifying a type as representing a tensor index. It is implemented
38/// for `CovariantIndex` and `ContravariantIndex`.
39pub trait TensorIndex: Variance {
40    fn index_type() -> IndexType;
41}
42
/// Type representing a contravariant (upper) tensor index.
///
/// This is a zero-sized marker type, so it derives the common value traits to make it usable
/// in ordinary data and debug output.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct ContravariantIndex;
45impl TensorIndex for ContravariantIndex {
46    fn index_type() -> IndexType {
47        IndexType::Contravariant
48    }
49}
50
/// Type representing a covariant (lower) tensor index.
///
/// This is a zero-sized marker type, so it derives the common value traits to make it usable
/// in ordinary data and debug output.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct CovariantIndex;
53impl TensorIndex for CovariantIndex {
54    fn index_type() -> IndexType {
55        IndexType::Covariant
56    }
57}
58
59impl Variance for ContravariantIndex {
60    type Rank = U1;
61    fn variance() -> Vec<IndexType> {
62        vec![IndexType::Contravariant]
63    }
64}
65
66impl Variance for CovariantIndex {
67    type Rank = U1;
68    fn variance() -> Vec<IndexType> {
69        vec![IndexType::Covariant]
70    }
71}
72
73/// Trait representing the other index type
74///
75/// Used for identifying indices that can be contracted
76pub trait OtherIndex: TensorIndex {
77    type Output: TensorIndex;
78}
79
80impl OtherIndex for CovariantIndex {
81    type Output = ContravariantIndex;
82}
83
84impl OtherIndex for ContravariantIndex {
85    type Output = CovariantIndex;
86}
87
// Back to implementing Variance, now for compound (tuple) variances.
89
90impl<T, U> Variance for (T, U)
91where
92    U: Variance,
93    Add1<U::Rank>: Unsigned + Add<B1>,
94    T: TensorIndex,
95{
96    type Rank = Add1<U::Rank>;
97
98    fn variance() -> Vec<IndexType> {
99        let mut result = vec![T::index_type()];
100        result.append(&mut U::variance());
101        result
102    }
103}
104
105/// Operator trait used for concatenating two variances.
106///
107/// Used in tensor outer product.
108pub trait Concat<T: Variance>: Variance {
109    type Output: Variance;
110}
111
112/// Helper type for variance concatenation.
113pub type Joined<T, U> = <T as Concat<U>>::Output;
114
115impl<T, U> Concat<U> for T
116where
117    T: TensorIndex,
118    U: TensorIndex,
119    Add1<<U as Variance>::Rank>: Unsigned + Add<B1>,
120{
121    type Output = (T, U);
122}
123
124impl<T> Concat<T> for ()
125where
126    T: TensorIndex,
127{
128    type Output = T;
129}
130
131impl<T, U, V> Concat<V> for (T, U)
132where
133    T: TensorIndex,
134    V: TensorIndex,
135    U: Variance + Concat<V>,
136    <U as Concat<V>>::Output: Variance,
137    Add1<<U as Variance>::Rank>: Unsigned + Add<B1>,
138    Add1<<Joined<U, V> as Variance>::Rank>: Unsigned + Add<B1>,
139{
140    type Output = (T, <U as Concat<V>>::Output);
141}
142
143impl<T, U, V> Concat<(U, V)> for T
144where
145    T: TensorIndex,
146    U: TensorIndex,
147    V: Variance,
148    Add1<<V as Variance>::Rank>: Unsigned + Add<B1>,
149    Add1<Add1<<V as Variance>::Rank>>: Unsigned + Add<B1>,
150{
151    type Output = (T, (U, V));
152}
153
154impl<T, U, V, W> Concat<(V, W)> for (T, U)
155where
156    T: TensorIndex,
157    U: Variance + Concat<(V, W)>,
158    V: TensorIndex,
159    W: Variance,
160    Add1<<U as Variance>::Rank>: Unsigned + Add<B1>,
161    Add1<<W as Variance>::Rank>: Unsigned + Add<B1>,
162    Add1<<Joined<U, (V, W)> as Variance>::Rank>: Unsigned + Add<B1>,
163{
164    type Output = (T, Joined<U, (V, W)>);
165}
166
167/// Indexing operator trait: Output is equal to the index type at the given position
168///
169/// Warning: Indices are numbered starting from 0!
170pub trait Index<T: Unsigned>: Variance {
171    type Output: TensorIndex;
172}
173
174/// Helper type for variance indexing.
175pub type At<T, U> = <T as Index<U>>::Output;
176
177impl Index<U0> for CovariantIndex {
178    type Output = CovariantIndex;
179}
180
181impl Index<U0> for ContravariantIndex {
182    type Output = ContravariantIndex;
183}
184
185impl<T, V, U, B> Index<UInt<U, B>> for (V, T)
186where
187    V: TensorIndex,
188    U: Unsigned,
189    B: Bit,
190    UInt<U, B>: Sub<B1>,
191    Sub1<UInt<U, B>>: Unsigned,
192    T: Variance + Index<Sub1<UInt<U, B>>>,
193    Add1<<T as Variance>::Rank>: Unsigned + Add<B1>,
194{
195    type Output = At<T, Sub1<UInt<U, B>>>;
196}
197
198impl<T, V> Index<U0> for (V, T)
199where
200    V: TensorIndex,
201    T: Variance,
202    Add1<<T as Variance>::Rank>: Unsigned + Add<B1>,
203{
204    type Output = V;
205}
206
207/// An operator trait, removing the indicated index from a variance
208pub trait RemoveIndex<T: Unsigned>: Variance {
209    type Output: Variance;
210}
211
212/// Helper type for index removal
213pub type Removed<T, U> = <T as RemoveIndex<U>>::Output;
214
215impl RemoveIndex<U0> for CovariantIndex {
216    type Output = ();
217}
218
219impl RemoveIndex<U0> for ContravariantIndex {
220    type Output = ();
221}
222
223impl<U, V> RemoveIndex<U0> for (U, V)
224where
225    U: TensorIndex,
226    V: Variance,
227    Add1<<V as Variance>::Rank>: Unsigned + Add<B1>,
228{
229    type Output = V;
230}
231
232impl<T, B, U, V> RemoveIndex<UInt<T, B>> for (U, V)
233where
234    T: Unsigned,
235    B: Bit,
236    U: TensorIndex,
237    UInt<T, B>: Sub<B1>,
238    Sub1<UInt<T, B>>: Unsigned,
239    V: Variance + RemoveIndex<Sub1<UInt<T, B>>>,
240    (U, V): Variance,
241    (U, Removed<V, Sub1<UInt<T, B>>>): Variance,
242{
243    type Output = (U, Removed<V, Sub1<UInt<T, B>>>);
244}
245
246/// An operator trait representing tensor contraction
247///
248/// Used in tensor inner product
249pub trait Contract<Ul: Unsigned, Uh: Unsigned>: Variance {
250    type Output: Variance;
251}
252
253/// Helper type for contraction
254pub type Contracted<V, Ul, Uh> = <V as Contract<Ul, Uh>>::Output;
255
256impl<Ul, Uh, V> Contract<Ul, Uh> for V
257where
258    Ul: Unsigned,
259    Uh: Unsigned + Sub<B1> + Cmp<Ul>,
260    Sub1<Uh>: Unsigned,
261    <Uh as Cmp<Ul>>::Output: Same<Greater>,
262    V: Index<Ul> + Index<Uh> + RemoveIndex<Ul>,
263    At<V, Ul>: OtherIndex,
264    At<V, Uh>: Same<<At<V, Ul> as OtherIndex>::Output>,
265    Removed<V, Ul>: RemoveIndex<Sub1<Uh>>,
266    Removed<Removed<V, Ul>, Sub1<Uh>>: Variance,
267{
268    type Output = Removed<Removed<V, Ul>, Sub1<Uh>>;
269}
270
#[cfg(test)]
mod test {
    use super::IndexType::{Contravariant, Covariant};
    use super::*;
    use crate::typenum::consts::{U0, U1, U2};

    // Short aliases keep the type-level expressions readable.
    type Co = CovariantIndex;
    type Contra = ContravariantIndex;

    #[test]
    fn test_variance() {
        assert_eq!(
            <(Co, Contra) as Variance>::variance(),
            vec![Covariant, Contravariant]
        );
    }

    #[test]
    fn test_variance_concat() {
        assert_eq!(
            <Joined<Co, Contra> as Variance>::variance(),
            vec![Covariant, Contravariant]
        );
        assert_eq!(
            <Joined<(Co, Co), Contra> as Variance>::variance(),
            vec![Covariant, Covariant, Contravariant]
        );
        assert_eq!(
            <Joined<Co, (Co, Contra)> as Variance>::variance(),
            vec![Covariant, Covariant, Contravariant]
        );
        assert_eq!(
            <Joined<(Contra, Co), (Co, Contra)> as Variance>::variance(),
            vec![Contravariant, Covariant, Covariant, Contravariant]
        );
    }

    #[test]
    fn test_index() {
        assert_eq!(<At<Co, U0> as TensorIndex>::index_type(), Covariant);
        assert_eq!(<At<(Co, Contra), U0> as TensorIndex>::index_type(), Covariant);
        assert_eq!(<At<(Co, Contra), U1> as TensorIndex>::index_type(), Contravariant);
        assert_eq!(
            <At<(Contra, (Co, Co)), U0> as TensorIndex>::index_type(),
            Contravariant
        );
        assert_eq!(
            <At<(Contra, (Co, Co)), U2> as TensorIndex>::index_type(),
            Covariant
        );
    }

    #[test]
    fn test_remove() {
        let nothing: Vec<IndexType> = Vec::new();
        assert_eq!(<Removed<Co, U0> as Variance>::variance(), nothing);
        assert_eq!(
            <Removed<(Co, Contra), U0> as Variance>::variance(),
            vec![Contravariant]
        );
        assert_eq!(
            <Removed<(Co, Contra), U1> as Variance>::variance(),
            vec![Covariant]
        );
        assert_eq!(
            <Removed<(Contra, (Co, Co)), U1> as Variance>::variance(),
            vec![Contravariant, Covariant]
        );
    }

    #[test]
    fn test_contract() {
        let nothing: Vec<IndexType> = Vec::new();
        assert_eq!(
            <Contracted<(Co, Contra), U0, U1> as Variance>::variance(),
            nothing
        );
        assert_eq!(
            <Contracted<(Contra, Co), U0, U1> as Variance>::variance(),
            nothing
        );
        assert_eq!(
            <Contracted<(Contra, (Co, Co)), U0, U1> as Variance>::variance(),
            vec![Covariant]
        );
        assert_eq!(
            <Contracted<(Contra, (Co, Co)), U0, U2> as Variance>::variance(),
            vec![Covariant]
        );
    }
}