Skip to main content

karpal_diagram/
tensor.rs

// Copyright (C) 2026 Industrial Algebra
// SPDX-License-Identifier: Apache-2.0
3
4use karpal_arrow::Arrow;
5#[cfg(any(feature = "std", feature = "alloc"))]
6use karpal_arrow::FnA;
7
/// Three objects tensored with the left pair grouped first: `(a ⊗ b) ⊗ c`,
/// encoded as the nested tuple `((a, b), c)`.
pub type AssocLeft<A, B, C> = ((A, B), C);

/// Three objects tensored with the right pair grouped first: `a ⊗ (b ⊗ c)`,
/// encoded as the nested tuple `(a, (b, c))`.
pub type AssocRight<A, B, C> = (A, (B, C));

/// The object `b ⊗ (c ⊗ a)`, encoded as `(b, (c, a))`.
///
/// Its shape matches the endpoint of the standard hexagon coherence diagram for a
/// braiding that starts at `(a ⊗ b) ⊗ c`.
pub type HexagonTarget<A, B, C> = (B, (C, A));
11
12/// Monoidal structure for categories whose objects can be tensored in parallel.
13///
14/// In this initial encoding, the tensor product is modeled with Rust tuples.
15pub trait Tensor: Arrow {
16    /// Tensor two morphisms in parallel.
17    fn tensor<A: Clone + 'static, B: Clone + 'static, C: Clone + 'static, D: Clone + 'static>(
18        left: Self::P<A, B>,
19        right: Self::P<C, D>,
20    ) -> Self::P<(A, C), (B, D)>;
21
22    /// The left-associated product `((a, b), c) -> (a, (b, c))`.
23    fn associate<A: Clone + 'static, B: Clone + 'static, C: Clone + 'static>()
24    -> Self::P<AssocLeft<A, B, C>, AssocRight<A, B, C>> {
25        Self::arr(|((a, b), c): AssocLeft<A, B, C>| (a, (b, c)))
26    }
27
28    /// The inverse associator `(a, (b, c)) -> ((a, b), c)`.
29    fn associate_inv<A: Clone + 'static, B: Clone + 'static, C: Clone + 'static>()
30    -> Self::P<AssocRight<A, B, C>, AssocLeft<A, B, C>> {
31        Self::arr(|(a, (b, c)): AssocRight<A, B, C>| ((a, b), c))
32    }
33
34    /// Left unitor `((), a) -> a`.
35    fn left_unitor<A: Clone + 'static>() -> Self::P<((), A), A> {
36        Self::arr(|(_, a): ((), A)| a)
37    }
38
39    /// Inverse left unitor `a -> ((), a)`.
40    fn left_unitor_inv<A: Clone + 'static>() -> Self::P<A, ((), A)> {
41        Self::arr(|a: A| ((), a))
42    }
43
44    /// Right unitor `(a, ()) -> a`.
45    fn right_unitor<A: Clone + 'static>() -> Self::P<(A, ()), A> {
46        Self::arr(|(a, _): (A, ())| a)
47    }
48
49    /// Inverse right unitor `a -> (a, ())`.
50    fn right_unitor_inv<A: Clone + 'static>() -> Self::P<A, (A, ())> {
51        Self::arr(|a: A| (a, ()))
52    }
53}
54
/// Tensor for the function arrow `FnA`.
///
/// `split` already has exactly the tensor's shape: it takes `A -> B` and `C -> D` and
/// yields `(A, C) -> (B, D)`. The tensor therefore simply delegates to it.
#[cfg(any(feature = "std", feature = "alloc"))]
impl Tensor for FnA {
    fn tensor<A, B, C, D>(left: Self::P<A, B>, right: Self::P<C, D>) -> Self::P<(A, C), (B, D)>
    where
        A: Clone + 'static,
        B: Clone + 'static,
        C: Clone + 'static,
        D: Clone + 'static,
    {
        FnA::split(left, right)
    }
}
64
#[cfg(all(test, any(feature = "std", feature = "alloc")))]
mod tests {
    // Law checks for the tuple-based monoidal structure on `FnA`.
    //
    // Note: `FnA::compose(f, g)` is `f ∘ g` (run `g` first, then `f`).
    use super::*;
    use karpal_arrow::Semigroupoid;

    #[test]
    fn fna_tensor_runs_both_sides_in_parallel() {
        let left = FnA::arr(|n: i32| n * 2);
        let right = FnA::arr(|flag: bool| !flag);
        let combined = FnA::tensor(left, right);

        assert_eq!(combined((3, true)), (6, false));
    }

    /// The tensor is a bifunctor: `(f ∘ g) ⊗ (h ∘ k) == (f ⊗ h) ∘ (g ⊗ k)`.
    #[test]
    fn fna_tensor_interchanges_with_composition() {
        let composed_then_tensored = FnA::tensor(
            FnA::compose(FnA::arr(|n: i32| n + 1), FnA::arr(|n: i32| n * 10)),
            FnA::compose(
                FnA::arr(|s: usize| s * 2),
                FnA::arr(|flag: bool| usize::from(flag)),
            ),
        );
        let tensored_then_composed = FnA::compose(
            FnA::tensor(FnA::arr(|n: i32| n + 1), FnA::arr(|s: usize| s * 2)),
            FnA::tensor(
                FnA::arr(|n: i32| n * 10),
                FnA::arr(|flag: bool| usize::from(flag)),
            ),
        );

        for input in [(0, false), (3, true), (-7, true)] {
            assert_eq!(composed_then_tensored(input), tensored_then_composed(input));
        }
    }

    #[test]
    fn associator_rebrackets() {
        let assoc = FnA::associate::<i32, bool, char>();
        let assoc_inv = FnA::associate_inv::<i32, bool, char>();

        assert_eq!(assoc(((1, false), 'z')), (1, (false, 'z')));
        assert_eq!(assoc_inv((1, (false, 'z'))), ((1, false), 'z'));
    }

    /// Both composites must be identities for the associator to be an isomorphism.
    #[test]
    fn associator_round_trips() {
        let left_round_trip = FnA::compose(
            FnA::associate_inv::<i32, bool, &'static str>(),
            FnA::associate::<i32, bool, &'static str>(),
        );
        assert_eq!(left_round_trip(((4, true), "x")), ((4, true), "x"));

        let right_round_trip = FnA::compose(
            FnA::associate::<i32, bool, &'static str>(),
            FnA::associate_inv::<i32, bool, &'static str>(),
        );
        assert_eq!(right_round_trip((4, (true, "x"))), (4, (true, "x")));
    }

    /// Both composites must be identities for each unitor to be an isomorphism.
    #[test]
    fn unitors_round_trip() {
        let left = FnA::compose(FnA::left_unitor::<i32>(), FnA::left_unitor_inv::<i32>());
        let right = FnA::compose(FnA::right_unitor::<i32>(), FnA::right_unitor_inv::<i32>());

        assert_eq!(left(5), 5);
        assert_eq!(right(5), 5);

        let left_inv = FnA::compose(FnA::left_unitor_inv::<i32>(), FnA::left_unitor::<i32>());
        let right_inv =
            FnA::compose(FnA::right_unitor_inv::<i32>(), FnA::right_unitor::<i32>());

        assert_eq!(left_inv(((), 5)), ((), 5));
        assert_eq!(right_inv((5, ())), (5, ()));
    }

    /// Triangle coherence: `(id ⊗ λ) ∘ α == ρ ⊗ id` on `((a, ()), b)`.
    #[test]
    fn triangle_identity_holds() {
        let via_associator = FnA::compose(
            FnA::tensor(FnA::arr(|a: i32| a), FnA::left_unitor::<bool>()),
            FnA::associate::<i32, (), bool>(),
        );
        let via_right_unitor =
            FnA::tensor(FnA::right_unitor::<i32>(), FnA::arr(|b: bool| b));

        for input in [((1, ()), true), ((-2, ()), false)] {
            assert_eq!(via_associator(input), via_right_unitor(input));
        }
    }
}