1use karpal_arrow::Arrow;
5#[cfg(any(feature = "std", feature = "alloc"))]
6use karpal_arrow::FnA;
7
/// Left-nested triple `((A, B), C)`: the domain of [`Tensor::associate`].
pub type AssocLeft<A, B, C> = ((A, B), C);
/// Right-nested triple `(A, (B, C))`: the codomain of [`Tensor::associate`].
pub type AssocRight<A, B, C> = (A, (B, C));
/// Right-nested triple with the components rotated left by one: `(B, (C, A))`.
///
/// NOTE(review): not referenced in this section; presumably the target type of a
/// hexagon (braiding) coherence path defined elsewhere — TODO confirm.
pub type HexagonTarget<A, B, C> = (B, (C, A));
11
12pub trait Tensor: Arrow {
16 fn tensor<A: Clone + 'static, B: Clone + 'static, C: Clone + 'static, D: Clone + 'static>(
18 left: Self::P<A, B>,
19 right: Self::P<C, D>,
20 ) -> Self::P<(A, C), (B, D)>;
21
22 fn associate<A: Clone + 'static, B: Clone + 'static, C: Clone + 'static>()
24 -> Self::P<AssocLeft<A, B, C>, AssocRight<A, B, C>> {
25 Self::arr(|((a, b), c): AssocLeft<A, B, C>| (a, (b, c)))
26 }
27
28 fn associate_inv<A: Clone + 'static, B: Clone + 'static, C: Clone + 'static>()
30 -> Self::P<AssocRight<A, B, C>, AssocLeft<A, B, C>> {
31 Self::arr(|(a, (b, c)): AssocRight<A, B, C>| ((a, b), c))
32 }
33
34 fn left_unitor<A: Clone + 'static>() -> Self::P<((), A), A> {
36 Self::arr(|(_, a): ((), A)| a)
37 }
38
39 fn left_unitor_inv<A: Clone + 'static>() -> Self::P<A, ((), A)> {
41 Self::arr(|a: A| ((), a))
42 }
43
44 fn right_unitor<A: Clone + 'static>() -> Self::P<(A, ()), A> {
46 Self::arr(|(a, _): (A, ())| a)
47 }
48
49 fn right_unitor_inv<A: Clone + 'static>() -> Self::P<A, (A, ())> {
51 Self::arr(|a: A| (a, ()))
52 }
53}
54
#[cfg(any(feature = "std", feature = "alloc"))]
impl Tensor for FnA {
    /// Function arrows get their tensor product directly from `split`.
    ///
    /// `left` and `right` are handed through unchanged, so the resulting arrow
    /// maps `(A, C)` to `(B, D)`.
    fn tensor<A: Clone + 'static, B: Clone + 'static, C: Clone + 'static, D: Clone + 'static>(
        left: Self::P<A, B>,
        right: Self::P<C, D>,
    ) -> Self::P<(A, C), (B, D)> {
        FnA::split(left, right)
    }
}
64
#[cfg(all(test, any(feature = "std", feature = "alloc")))]
mod tests {
    use super::*;
    use karpal_arrow::Semigroupoid;

    #[test]
    fn fna_tensor_runs_both_sides_in_parallel() {
        let left = FnA::arr(|n: i32| n * 2);
        let right = FnA::arr(|flag: bool| !flag);
        let combined = FnA::tensor(left, right);

        assert_eq!(combined((3, true)), (6, false));
    }

    /// Pins the actual re-nesting performed by each direction of the associator.
    /// A round trip alone would not catch two mutually-inverse but wrong
    /// rearrangements.
    #[test]
    fn associator_reassociates() {
        let assoc = FnA::associate::<i32, bool, &'static str>();
        let assoc_inv = FnA::associate_inv::<i32, bool, &'static str>();

        assert_eq!(assoc(((4, true), "x")), (4, (true, "x")));
        assert_eq!(assoc_inv((4, (true, "x"))), ((4, true), "x"));
    }

    /// Both composites must be the identity for the associator to be an
    /// isomorphism, so each direction is checked.
    #[test]
    fn associator_round_trips() {
        // `compose(f, g)` applies `g` first, then `f`.
        let left_to_left = FnA::compose(
            FnA::associate_inv::<i32, bool, &'static str>(),
            FnA::associate::<i32, bool, &'static str>(),
        );
        let right_to_right = FnA::compose(
            FnA::associate::<i32, bool, &'static str>(),
            FnA::associate_inv::<i32, bool, &'static str>(),
        );

        assert_eq!(left_to_left(((4, true), "x")), ((4, true), "x"));
        assert_eq!(right_to_right((4, (true, "x"))), (4, (true, "x")));
    }

    /// Both composites of each unitor pair must be the identity.
    #[test]
    fn unitors_round_trip() {
        let left = FnA::compose(FnA::left_unitor::<i32>(), FnA::left_unitor_inv::<i32>());
        let right = FnA::compose(FnA::right_unitor::<i32>(), FnA::right_unitor_inv::<i32>());

        assert_eq!(left(5), 5);
        assert_eq!(right(5), 5);

        let left_back = FnA::compose(FnA::left_unitor_inv::<i32>(), FnA::left_unitor::<i32>());
        let right_back =
            FnA::compose(FnA::right_unitor_inv::<i32>(), FnA::right_unitor::<i32>());

        assert_eq!(left_back(((), 5)), ((), 5));
        assert_eq!(right_back((5, ())), (5, ()));
    }
}