1use crate::typenum::bit::Bit;
4use crate::typenum::consts::{B1, U0, U1};
5use crate::typenum::uint::{UInt, Unsigned};
6use crate::typenum::{Add1, Sub1};
7use crate::typenum::{Cmp, Greater, Same};
8use std::ops::{Add, Sub};
9
/// Variance of a single tensor index.
///
/// Conventionally a covariant index is written as a lower index and a
/// contravariant index as an upper index.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so that index types can
/// be used as keys in hashed collections and compared with total equality.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum IndexType {
    /// A covariant index.
    Covariant,
    /// A contravariant index.
    Contravariant,
}
18
19pub trait Variance {
22 type Rank: Unsigned + Add<B1>;
23 fn rank() -> usize {
24 Self::Rank::to_usize()
25 }
26 fn variance() -> Vec<IndexType>;
27}
28
29impl Variance for () {
30 type Rank = U0;
31
32 fn variance() -> Vec<IndexType> {
33 vec![]
34 }
35}
36
37pub trait TensorIndex: Variance {
40 fn index_type() -> IndexType;
41}
42
43pub struct ContravariantIndex;
45impl TensorIndex for ContravariantIndex {
46 fn index_type() -> IndexType {
47 IndexType::Contravariant
48 }
49}
50
51pub struct CovariantIndex;
53impl TensorIndex for CovariantIndex {
54 fn index_type() -> IndexType {
55 IndexType::Covariant
56 }
57}
58
59impl Variance for ContravariantIndex {
60 type Rank = U1;
61 fn variance() -> Vec<IndexType> {
62 vec![IndexType::Contravariant]
63 }
64}
65
66impl Variance for CovariantIndex {
67 type Rank = U1;
68 fn variance() -> Vec<IndexType> {
69 vec![IndexType::Covariant]
70 }
71}
72
73pub trait OtherIndex: TensorIndex {
77 type Output: TensorIndex;
78}
79
80impl OtherIndex for CovariantIndex {
81 type Output = ContravariantIndex;
82}
83
84impl OtherIndex for ContravariantIndex {
85 type Output = CovariantIndex;
86}
87
88impl<T, U> Variance for (T, U)
91where
92 U: Variance,
93 Add1<U::Rank>: Unsigned + Add<B1>,
94 T: TensorIndex,
95{
96 type Rank = Add1<U::Rank>;
97
98 fn variance() -> Vec<IndexType> {
99 let mut result = vec![T::index_type()];
100 result.append(&mut U::variance());
101 result
102 }
103}
104
105pub trait Concat<T: Variance>: Variance {
109 type Output: Variance;
110}
111
112pub type Joined<T, U> = <T as Concat<U>>::Output;
114
115impl<T, U> Concat<U> for T
116where
117 T: TensorIndex,
118 U: TensorIndex,
119 Add1<<U as Variance>::Rank>: Unsigned + Add<B1>,
120{
121 type Output = (T, U);
122}
123
124impl<T> Concat<T> for ()
125where
126 T: TensorIndex,
127{
128 type Output = T;
129}
130
131impl<T, U, V> Concat<V> for (T, U)
132where
133 T: TensorIndex,
134 V: TensorIndex,
135 U: Variance + Concat<V>,
136 <U as Concat<V>>::Output: Variance,
137 Add1<<U as Variance>::Rank>: Unsigned + Add<B1>,
138 Add1<<Joined<U, V> as Variance>::Rank>: Unsigned + Add<B1>,
139{
140 type Output = (T, <U as Concat<V>>::Output);
141}
142
143impl<T, U, V> Concat<(U, V)> for T
144where
145 T: TensorIndex,
146 U: TensorIndex,
147 V: Variance,
148 Add1<<V as Variance>::Rank>: Unsigned + Add<B1>,
149 Add1<Add1<<V as Variance>::Rank>>: Unsigned + Add<B1>,
150{
151 type Output = (T, (U, V));
152}
153
154impl<T, U, V, W> Concat<(V, W)> for (T, U)
155where
156 T: TensorIndex,
157 U: Variance + Concat<(V, W)>,
158 V: TensorIndex,
159 W: Variance,
160 Add1<<U as Variance>::Rank>: Unsigned + Add<B1>,
161 Add1<<W as Variance>::Rank>: Unsigned + Add<B1>,
162 Add1<<Joined<U, (V, W)> as Variance>::Rank>: Unsigned + Add<B1>,
163{
164 type Output = (T, Joined<U, (V, W)>);
165}
166
167pub trait Index<T: Unsigned>: Variance {
171 type Output: TensorIndex;
172}
173
174pub type At<T, U> = <T as Index<U>>::Output;
176
177impl Index<U0> for CovariantIndex {
178 type Output = CovariantIndex;
179}
180
181impl Index<U0> for ContravariantIndex {
182 type Output = ContravariantIndex;
183}
184
185impl<T, V, U, B> Index<UInt<U, B>> for (V, T)
186where
187 V: TensorIndex,
188 U: Unsigned,
189 B: Bit,
190 UInt<U, B>: Sub<B1>,
191 Sub1<UInt<U, B>>: Unsigned,
192 T: Variance + Index<Sub1<UInt<U, B>>>,
193 Add1<<T as Variance>::Rank>: Unsigned + Add<B1>,
194{
195 type Output = At<T, Sub1<UInt<U, B>>>;
196}
197
198impl<T, V> Index<U0> for (V, T)
199where
200 V: TensorIndex,
201 T: Variance,
202 Add1<<T as Variance>::Rank>: Unsigned + Add<B1>,
203{
204 type Output = V;
205}
206
207pub trait RemoveIndex<T: Unsigned>: Variance {
209 type Output: Variance;
210}
211
212pub type Removed<T, U> = <T as RemoveIndex<U>>::Output;
214
215impl RemoveIndex<U0> for CovariantIndex {
216 type Output = ();
217}
218
219impl RemoveIndex<U0> for ContravariantIndex {
220 type Output = ();
221}
222
223impl<U, V> RemoveIndex<U0> for (U, V)
224where
225 U: TensorIndex,
226 V: Variance,
227 Add1<<V as Variance>::Rank>: Unsigned + Add<B1>,
228{
229 type Output = V;
230}
231
232impl<T, B, U, V> RemoveIndex<UInt<T, B>> for (U, V)
233where
234 T: Unsigned,
235 B: Bit,
236 U: TensorIndex,
237 UInt<T, B>: Sub<B1>,
238 Sub1<UInt<T, B>>: Unsigned,
239 V: Variance + RemoveIndex<Sub1<UInt<T, B>>>,
240 (U, V): Variance,
241 (U, Removed<V, Sub1<UInt<T, B>>>): Variance,
242{
243 type Output = (U, Removed<V, Sub1<UInt<T, B>>>);
244}
245
246pub trait Contract<Ul: Unsigned, Uh: Unsigned>: Variance {
250 type Output: Variance;
251}
252
253pub type Contracted<V, Ul, Uh> = <V as Contract<Ul, Uh>>::Output;
255
256impl<Ul, Uh, V> Contract<Ul, Uh> for V
257where
258 Ul: Unsigned,
259 Uh: Unsigned + Sub<B1> + Cmp<Ul>,
260 Sub1<Uh>: Unsigned,
261 <Uh as Cmp<Ul>>::Output: Same<Greater>,
262 V: Index<Ul> + Index<Uh> + RemoveIndex<Ul>,
263 At<V, Ul>: OtherIndex,
264 At<V, Uh>: Same<<At<V, Ul> as OtherIndex>::Output>,
265 Removed<V, Ul>: RemoveIndex<Sub1<Uh>>,
266 Removed<Removed<V, Ul>, Sub1<Uh>>: Variance,
267{
268 type Output = Removed<Removed<V, Ul>, Sub1<Uh>>;
269}
270
#[cfg(test)]
mod test {
    use super::*;
    use crate::typenum::consts::{U0, U1, U2};

    /// `rank()` must report the number of indices of each list shape.
    #[test]
    fn test_rank() {
        assert_eq!(<() as Variance>::rank(), 0);
        assert_eq!(<CovariantIndex as Variance>::rank(), 1);
        assert_eq!(<ContravariantIndex as Variance>::rank(), 1);
        assert_eq!(
            <(CovariantIndex, ContravariantIndex) as Variance>::rank(),
            2
        );
        assert_eq!(
            <(ContravariantIndex, (CovariantIndex, CovariantIndex)) as Variance>::rank(),
            3
        );
    }

    #[test]
    fn test_variance() {
        assert_eq!(
            <(CovariantIndex, ContravariantIndex) as Variance>::variance(),
            vec![IndexType::Covariant, IndexType::Contravariant]
        );
    }

    #[test]
    fn test_variance_concat() {
        assert_eq!(
            <Joined<CovariantIndex, ContravariantIndex> as Variance>::variance(),
            vec![IndexType::Covariant, IndexType::Contravariant]
        );

        assert_eq!(
            <Joined<(CovariantIndex, CovariantIndex), ContravariantIndex> as Variance>::variance(),
            vec![
                IndexType::Covariant,
                IndexType::Covariant,
                IndexType::Contravariant
            ]
        );

        assert_eq!(
            <Joined<CovariantIndex, (CovariantIndex, ContravariantIndex)> as Variance>::variance(),
            vec![
                IndexType::Covariant,
                IndexType::Covariant,
                IndexType::Contravariant
            ]
        );

        assert_eq!(
            <Joined<(ContravariantIndex, CovariantIndex), (CovariantIndex, ContravariantIndex)>
                as Variance>::variance(),
            vec![
                IndexType::Contravariant,
                IndexType::Covariant,
                IndexType::Covariant,
                IndexType::Contravariant
            ]
        );
    }

    #[test]
    fn test_index() {
        assert_eq!(
            <At<CovariantIndex, U0> as TensorIndex>::index_type(),
            IndexType::Covariant
        );

        assert_eq!(
            <At<(CovariantIndex, ContravariantIndex), U0> as TensorIndex>::index_type(),
            IndexType::Covariant
        );

        assert_eq!(
            <At<(CovariantIndex, ContravariantIndex), U1> as TensorIndex>::index_type(),
            IndexType::Contravariant
        );

        assert_eq!(
            <At<(ContravariantIndex, (CovariantIndex, CovariantIndex)), U0> as TensorIndex>
                ::index_type(),
            IndexType::Contravariant
        );

        assert_eq!(
            <At<(ContravariantIndex, (CovariantIndex, CovariantIndex)), U2> as TensorIndex>
                ::index_type(),
            IndexType::Covariant
        );
    }

    #[test]
    fn test_remove() {
        assert_eq!(
            <Removed<CovariantIndex, U0> as Variance>::variance(),
            vec![]
        );

        assert_eq!(
            <Removed<(CovariantIndex, ContravariantIndex), U0> as Variance>::variance(),
            vec![IndexType::Contravariant]
        );

        assert_eq!(
            <Removed<(CovariantIndex, ContravariantIndex), U1> as Variance>::variance(),
            vec![IndexType::Covariant]
        );

        assert_eq!(
            <Removed<(ContravariantIndex, (CovariantIndex, CovariantIndex)), U1> as Variance>
                ::variance(),
            vec![IndexType::Contravariant, IndexType::Covariant]
        );

        // Removing the last position of a rank-3 list.
        assert_eq!(
            <Removed<(ContravariantIndex, (CovariantIndex, CovariantIndex)), U2> as Variance>
                ::variance(),
            vec![IndexType::Contravariant, IndexType::Covariant]
        );
    }

    #[test]
    fn test_contract() {
        assert_eq!(
            <Contracted<(CovariantIndex, ContravariantIndex), U0, U1> as Variance>::variance(),
            vec![]
        );

        assert_eq!(
            <Contracted<(ContravariantIndex, CovariantIndex), U0, U1> as Variance>::variance(),
            vec![]
        );

        assert_eq!(
            <Contracted<(ContravariantIndex, (CovariantIndex, CovariantIndex)), U0, U1>
                as Variance>::variance(),
            vec![IndexType::Covariant]
        );

        assert_eq!(
            <Contracted<(ContravariantIndex, (CovariantIndex, CovariantIndex)), U0, U2>
                as Variance>::variance(),
            vec![IndexType::Covariant]
        );

        // Contracting two non-leading indices keeps the first one.
        assert_eq!(
            <Contracted<(CovariantIndex, (ContravariantIndex, CovariantIndex)), U1, U2>
                as Variance>::variance(),
            vec![IndexType::Covariant]
        );
    }
}