Skip to main content

open_hypergraphs/lax/
optic.rs

1//! # Optics for lax open hypergraphs
2//!
3//! This module provides an interface for defining optics on [`crate::lax::OpenHypergraph`]
4//! via the [`Optic`] trait.
5//!
//! By implementing the forward and reverse mappings on objects and operations (together with the
//! per-operation residual), you get the `map_arrow` and `map_adapted` methods of the [`Optic`]
//! trait for free.
8use std::fmt::Debug;
9
10use crate::lax::functor::dyn_functor::{to_dyn_functor, DynFunctor};
11use crate::lax::functor::Functor;
12
13use crate::operations::Operations;
14use crate::strict::vec::VecArray;
15use crate::strict::vec::VecKind;
16use crate::strict::IndexedCoproduct;
17use crate::strict::SemifiniteFunction;
18use crate::{lax, lax::OpenHypergraph, strict::functor::optic::Optic as StrictOptic};
19
20/// #
21///
22/// foo
23pub trait Optic<
24    O1: Clone + PartialEq,
25    A1: Clone,
26    O2: Clone + PartialEq + std::fmt::Debug,
27    A2: Clone,
28>: Clone + 'static
29{
30    fn fwd_object(&self, o: &O1) -> Vec<O2>;
31    fn fwd_operation(&self, a: &A1, source: &[O1], target: &[O1]) -> OpenHypergraph<O2, A2>;
32    fn rev_object(&self, o: &O1) -> Vec<O2>;
33    fn rev_operation(&self, a: &A1, source: &[O1], target: &[O1]) -> OpenHypergraph<O2, A2>;
34    fn residual(&self, a: &A1) -> Vec<O2>;
35
36    fn map_arrow(&self, term: OpenHypergraph<O1, A1>) -> OpenHypergraph<O2, A2> {
37        let optic = to_strict_optic(self);
38        let strict = term.to_strict();
39        lax::OpenHypergraph::from_strict({
40            // Get the right trait in scope.
41            use crate::strict::functor::Functor;
42            optic.map_arrow(&strict)
43        })
44    }
45
46    fn map_adapted(&self, term: OpenHypergraph<O1, A1>) -> OpenHypergraph<O2, A2> {
47        let optic = to_strict_optic(self);
48        let strict = term.to_strict();
49        lax::OpenHypergraph::from_strict({
50            use crate::strict::functor::Functor;
51            let optic_term = optic.map_arrow(&strict);
52            // Adapt the produced term so it's monogamous again (as long as the input was).
53            optic.adapt(&optic_term, &strict.source(), &strict.target())
54        })
55    }
56}
57
58#[allow(clippy::type_complexity)]
59fn to_strict_optic<
60    T: Optic<O1, A1, O2, A2> + 'static,
61    O1: Clone + PartialEq,
62    A1: Clone,
63    O2: Clone + PartialEq + Debug,
64    A2: Clone,
65>(
66    this: &T,
67) -> StrictOptic<
68    DynFunctor<Fwd<T, O1, A1, O2, A2>, O1, A1, O2, A2>,
69    DynFunctor<Rev<T, O1, A1, O2, A2>, O1, A1, O2, A2>,
70    VecKind,
71    O1,
72    A1,
73    O2,
74    A2,
75> {
76    let fwd = to_dyn_functor(Fwd::new(this.clone()));
77    let rev = to_dyn_functor(Rev::new(this.clone()));
78
79    // Clone self to avoid lifetime issues in the closure
80    let self_clone = this.clone();
81
82    StrictOptic::new(
83        fwd,
84        rev,
85        Box::new(move |ops: &Operations<VecKind, O1, A1>| {
86            let mut sources_vec = Vec::new();
87            let mut residuals = Vec::new();
88
89            for (op, _, _) in ops.iter() {
90                let m = self_clone.residual(op);
91                sources_vec.push(m.len());
92                residuals.extend(m);
93            }
94
95            let sources = SemifiniteFunction::<VecKind, usize>(VecArray(sources_vec));
96            let values = SemifiniteFunction(VecArray(residuals));
97            IndexedCoproduct::from_semifinite(sources, values).unwrap()
98        }),
99    )
100}
101
102////////////////////////////////////////////////////////////////////////////////
103// Fwd and Rev lax functor helpers, needed for Optic
104// **IMPORTANT NOTE**: never expose these in the public API.
// They rely on their `map_arrow` methods never being called, and panic if they are.
106
107#[derive(Clone, PartialEq)]
108struct Fwd<T, O1, A1, O2, A2> {
109    _phantom: std::marker::PhantomData<(O1, A1, O2, A2)>,
110    optic: Box<T>,
111}
112
113impl<
114        T: Optic<O1, A1, O2, A2>,
115        O1: Clone + PartialEq,
116        A1: Clone,
117        O2: Clone + PartialEq + Debug,
118        A2: Clone,
119    > Functor<O1, A1, O2, A2> for Fwd<T, O1, A1, O2, A2>
120{
121    fn map_object(&self, o: &O1) -> impl ExactSizeIterator<Item = O2> {
122        self.optic.fwd_object(o).into_iter()
123    }
124
125    fn map_operation(&self, a: &A1, source: &[O1], target: &[O1]) -> OpenHypergraph<O2, A2> {
126        self.optic.fwd_operation(a, source, target)
127    }
128
129    // NOTE: this method is never called; and the struct *must* remain private.
130    fn map_arrow(&self, _f: &OpenHypergraph<O1, A1>) -> OpenHypergraph<O2, A2> {
131        panic!("Fwd is not a functor!");
132    }
133}
134
135impl<
136        T: Optic<O1, A1, O2, A2>,
137        O1: Clone + PartialEq,
138        A1: Clone,
139        O2: Clone + PartialEq + Debug,
140        A2: Clone,
141    > Fwd<T, O1, A1, O2, A2>
142{
143    fn new(t: T) -> Self {
144        Self {
145            _phantom: std::marker::PhantomData,
146            optic: Box::new(t),
147        }
148    }
149}
150
151#[derive(Clone, PartialEq)]
152struct Rev<T, O1, A1, O2, A2> {
153    _phantom: std::marker::PhantomData<(O1, A1, O2, A2)>,
154    optic: Box<T>,
155}
156
157impl<
158        T: Optic<O1, A1, O2, A2>,
159        O1: Clone + PartialEq,
160        A1: Clone,
161        O2: Clone + PartialEq + Debug,
162        A2: Clone,
163    > Functor<O1, A1, O2, A2> for Rev<T, O1, A1, O2, A2>
164{
165    fn map_object(&self, o: &O1) -> impl ExactSizeIterator<Item = O2> {
166        self.optic.rev_object(o).into_iter()
167    }
168
169    fn map_operation(&self, a: &A1, source: &[O1], target: &[O1]) -> OpenHypergraph<O2, A2> {
170        self.optic.rev_operation(a, source, target)
171    }
172
173    // NOTE: this method is never called; and the struct *must* remain private.
174    fn map_arrow(&self, _f: &OpenHypergraph<O1, A1>) -> OpenHypergraph<O2, A2> {
175        panic!("Rev is not a functor!");
176    }
177}
178
179impl<
180        T: Optic<O1, A1, O2, A2>,
181        O1: Clone + PartialEq,
182        A1: Clone,
183        O2: Clone + PartialEq + Debug,
184        A2: Clone,
185    > Rev<T, O1, A1, O2, A2>
186{
187    fn new(t: T) -> Self {
188        Self {
189            _phantom: std::marker::PhantomData,
190            optic: Box::new(t),
191        }
192    }
193}