Skip to main content

open_hypergraphs/lax/
optic.rs

1//! # Optics for lax open hypergraphs
2//!
3//! This module provides an interface for defining optics on [`crate::lax::OpenHypergraph`]
4//! via the [`Optic`] trait.
5//!
//! By defining the forward and reverse mappings on objects and operations, together with the
//! residual of each operation, the [`Optic`] trait gives you `map_arrow` and `map_adapted`
//! methods for free.
8use std::fmt::Debug;
9
10use crate::lax::functor::dyn_functor::{to_dyn_functor, DynFunctor};
11use crate::lax::functor::Functor;
12
13use crate::operations::Operations;
14use crate::strict::vec::VecArray;
15use crate::strict::vec::VecKind;
16use crate::strict::IndexedCoproduct;
17use crate::strict::SemifiniteFunction;
18use crate::{lax, lax::OpenHypergraph, strict::functor::optic::Optic as StrictOptic};
19
20/// #
21///
22/// foo
23pub trait Optic<
24    O1: Clone + PartialEq,
25    A1: Clone,
26    O2: Clone + PartialEq + std::fmt::Debug,
27    A2: Clone,
28>: Clone + 'static
29{
30    fn fwd_object(&self, o: &O1) -> Vec<O2>;
31    fn fwd_operation(&self, a: &A1, source: &[O1], target: &[O1]) -> OpenHypergraph<O2, A2>;
32    fn rev_object(&self, o: &O1) -> Vec<O2>;
33    fn rev_operation(&self, a: &A1, source: &[O1], target: &[O1]) -> OpenHypergraph<O2, A2>;
34    fn residual(&self, a: &A1) -> Vec<O2>;
35
36    fn map_arrow(&self, term: OpenHypergraph<O1, A1>) -> OpenHypergraph<O2, A2> {
37        let optic = to_strict_optic(self);
38        let strict = term.to_strict();
39        lax::OpenHypergraph::from_strict({
40            // Get the right trait in scope.
41            use crate::strict::functor::Functor;
42            optic.map_arrow(&strict)
43        })
44    }
45
46    fn map_adapted(&self, term: OpenHypergraph<O1, A1>) -> OpenHypergraph<O2, A2> {
47        let optic = to_strict_optic(self);
48        let strict = term.to_strict();
49        lax::OpenHypergraph::from_strict({
50            use crate::strict::functor::Functor;
51            let optic_term = optic.map_arrow(&strict);
52            // Adapt the produced term so it's monogamous again (as long as the input was).
53            optic.adapt(&optic_term, &strict.source(), &strict.target())
54        })
55    }
56}
57
58fn to_strict_optic<
59    T: Optic<O1, A1, O2, A2> + 'static,
60    O1: Clone + PartialEq,
61    A1: Clone,
62    O2: Clone + PartialEq + Debug,
63    A2: Clone,
64>(
65    this: &T,
66) -> StrictOptic<
67    DynFunctor<Fwd<T, O1, A1, O2, A2>, O1, A1, O2, A2>,
68    DynFunctor<Rev<T, O1, A1, O2, A2>, O1, A1, O2, A2>,
69    VecKind,
70    O1,
71    A1,
72    O2,
73    A2,
74> {
75    let fwd = to_dyn_functor(Fwd::new(this.clone()));
76    let rev = to_dyn_functor(Rev::new(this.clone()));
77
78    // Clone self to avoid lifetime issues in the closure
79    let self_clone = this.clone();
80
81    StrictOptic::new(
82        fwd,
83        rev,
84        Box::new(move |ops: &Operations<VecKind, O1, A1>| {
85            let mut sources_vec = Vec::new();
86            let mut residuals = Vec::new();
87
88            for (op, _, _) in ops.iter() {
89                let m = self_clone.residual(op);
90                sources_vec.push(m.len());
91                residuals.extend(m);
92            }
93
94            let sources = SemifiniteFunction::<VecKind, usize>(VecArray(sources_vec));
95            let values = SemifiniteFunction(VecArray(residuals));
96            IndexedCoproduct::from_semifinite(sources, values).unwrap()
97        }),
98    )
99}
100
101////////////////////////////////////////////////////////////////////////////////
102// Fwd and Rev lax functor helpers, needed for Optic
103// **IMPORTANT NOTE**: never expose these in the public API.
104// They rely on never having their `map_arrow` methods called, and panic! in that case.
105
106#[derive(Clone, PartialEq)]
107struct Fwd<T, O1, A1, O2, A2> {
108    _phantom: std::marker::PhantomData<(O1, A1, O2, A2)>,
109    optic: Box<T>,
110}
111
112impl<
113        T: Optic<O1, A1, O2, A2>,
114        O1: Clone + PartialEq,
115        A1: Clone,
116        O2: Clone + PartialEq + Debug,
117        A2: Clone,
118    > Functor<O1, A1, O2, A2> for Fwd<T, O1, A1, O2, A2>
119{
120    fn map_object(&self, o: &O1) -> impl ExactSizeIterator<Item = O2> {
121        self.optic.fwd_object(o).into_iter()
122    }
123
124    fn map_operation(&self, a: &A1, source: &[O1], target: &[O1]) -> OpenHypergraph<O2, A2> {
125        self.optic.fwd_operation(a, source, target)
126    }
127
128    // NOTE: this method is never called; and the struct *must* remain private.
129    fn map_arrow(&self, _f: &OpenHypergraph<O1, A1>) -> OpenHypergraph<O2, A2> {
130        panic!("Fwd is not a functor!");
131    }
132}
133
134impl<
135        T: Optic<O1, A1, O2, A2>,
136        O1: Clone + PartialEq,
137        A1: Clone,
138        O2: Clone + PartialEq + Debug,
139        A2: Clone,
140    > Fwd<T, O1, A1, O2, A2>
141{
142    fn new(t: T) -> Self {
143        Self {
144            _phantom: std::marker::PhantomData,
145            optic: Box::new(t),
146        }
147    }
148}
149
150#[derive(Clone, PartialEq)]
151struct Rev<T, O1, A1, O2, A2> {
152    _phantom: std::marker::PhantomData<(O1, A1, O2, A2)>,
153    optic: Box<T>,
154}
155
156impl<
157        T: Optic<O1, A1, O2, A2>,
158        O1: Clone + PartialEq,
159        A1: Clone,
160        O2: Clone + PartialEq + Debug,
161        A2: Clone,
162    > Functor<O1, A1, O2, A2> for Rev<T, O1, A1, O2, A2>
163{
164    fn map_object(&self, o: &O1) -> impl ExactSizeIterator<Item = O2> {
165        self.optic.rev_object(o).into_iter()
166    }
167
168    fn map_operation(&self, a: &A1, source: &[O1], target: &[O1]) -> OpenHypergraph<O2, A2> {
169        self.optic.rev_operation(a, source, target)
170    }
171
172    // NOTE: this method is never called; and the struct *must* remain private.
173    fn map_arrow(&self, _f: &OpenHypergraph<O1, A1>) -> OpenHypergraph<O2, A2> {
174        panic!("Rev is not a functor!");
175    }
176}
177
178impl<
179        T: Optic<O1, A1, O2, A2>,
180        O1: Clone + PartialEq,
181        A1: Clone,
182        O2: Clone + PartialEq + Debug,
183        A2: Clone,
184    > Rev<T, O1, A1, O2, A2>
185{
186    fn new(t: T) -> Self {
187        Self {
188            _phantom: std::marker::PhantomData,
189            optic: Box::new(t),
190        }
191    }
192}