Skip to main content

karpal_optics/
prism.rs

1// Copyright (C) 2026 Industrial Algebra
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::fold::Fold;
5use crate::optic::Optic;
6use crate::review::Review;
7use crate::setter::Setter;
8use crate::traversal::Traversal;
9use karpal_profunctor::choice::Choice;
10
/// An optic that targets a single variant of a sum type.
///
/// Type parameters:
/// - `S` — the source type being inspected,
/// - `T` — the source type after a modification,
/// - `A` — the focus (the value carried by the targeted variant),
/// - `B` — the value that replaces the focus.
///
/// A [`Lens`](crate::Lens) relies on [`Strong`](karpal_profunctor::Strong) to
/// take products apart; a `Prism` relies on [`Choice`] to take coproducts
/// apart.
///
/// When `S == T` and `A == B`, prefer the [`SimplePrism`] alias.
pub struct Prism<S, T, A, B> {
    /// Matcher: yields `Ok(a)` when the variant is present, otherwise `Err(t)`
    /// carrying the pass-through value.
    match_: fn(S) -> Result<A, T>,
    /// Builder: produces a `T` out of a replacement value `B`.
    build: fn(B) -> T,
}
26
/// A simple (monomorphic) prism where `S == T` and `A == B`.
///
/// Expands to `Prism<S, S, A, A>`: the source type and the focus type are the
/// same before and after an update.
pub type SimplePrism<S, A> = Prism<S, S, A, A>;
29
// Marks every `Prism` as an `Optic`. The impl body is empty: nothing beyond
// the trait membership itself is provided here.
impl<S, T, A, B> Optic for Prism<S, T, A, B> {}
31
32impl<S, T, A, B> Prism<S, T, A, B> {
33    pub fn new(match_: fn(S) -> Result<A, T>, build: fn(B) -> T) -> Self {
34        Self { match_, build }
35    }
36
37    /// Try to extract the focus. Returns `Some(a)` if the variant matches.
38    pub fn preview(&self, s: &S) -> Option<A>
39    where
40        S: Clone,
41    {
42        (self.match_)(s.clone()).ok()
43    }
44
45    /// Construct a `T` from a replacement value (inject/construct).
46    pub fn review(&self, b: B) -> T {
47        (self.build)(b)
48    }
49
50    /// Replace the focus if the variant matches; otherwise pass through.
51    pub fn set(&self, s: S, b: B) -> T {
52        match (self.match_)(s) {
53            Ok(_) => (self.build)(b),
54            Err(t) => t,
55        }
56    }
57
58    /// Modify the focus if the variant matches; otherwise pass through.
59    pub fn over(&self, s: S, f: impl FnOnce(A) -> B) -> T {
60        match (self.match_)(s) {
61            Ok(a) => (self.build)(f(a)),
62            Err(t) => t,
63        }
64    }
65
66    /// Convert to a `Review` (write-only, construction).
67    pub fn to_review(&self) -> Review<T, B> {
68        Review::new(self.build)
69    }
70
71    /// Convert to a `Setter` (modify-only).
72    pub fn to_setter(&self) -> Setter<S, T, A, B>
73    where
74        S: 'static,
75        T: 'static,
76        A: 'static,
77        B: 'static,
78    {
79        let match_ = self.match_;
80        let build = self.build;
81        Setter::new(move |s: S, f: &dyn Fn(A) -> B| match match_(s) {
82            Ok(a) => build(f(a)),
83            Err(t) => t,
84        })
85    }
86
87    /// Convert to a `Traversal` (0-or-1 element focus).
88    pub fn to_traversal(&self) -> Traversal<S, T, A, B>
89    where
90        S: Clone + 'static,
91        T: 'static,
92        A: 'static,
93        B: 'static,
94    {
95        let match_ = self.match_;
96        let build = self.build;
97        Traversal::new(
98            move |s: &S| match match_(s.clone()) {
99                Ok(a) => vec![a],
100                Err(_) => vec![],
101            },
102            move |s: S, f: &dyn Fn(A) -> B| match match_(s) {
103                Ok(a) => build(f(a)),
104                Err(t) => t,
105            },
106        )
107    }
108
109    /// Convert to a `Fold` (0-or-1 element, read-only).
110    pub fn to_fold(&self) -> Fold<S, A>
111    where
112        S: Clone + 'static,
113        T: 'static,
114        A: 'static,
115    {
116        let match_ = self.match_;
117        Fold::new(move |s: &S| match match_(s.clone()) {
118            Ok(a) => vec![a],
119            Err(_) => vec![],
120        })
121    }
122
123    /// Profunctor encoding: transform a `P<A, B>` into a `P<S, T>` via this prism.
124    ///
125    /// This connects prisms to the profunctor hierarchy through [`Choice`].
126    /// Given any `Choice` profunctor `P` and a value `pab: P<A, B>`,
127    /// `transform` produces `P<S, T>` by:
128    ///
129    /// 1. `right(pab)` lifts to `P<Result<T, A>, Result<T, B>>`
130    /// 2. `dimap` pre-composes with `match_` (swapping arms) and post-composes
131    ///    with `build` (reassembling)
132    ///
133    /// The arm-swapping (`Ok→Err`, `Err→Ok` in the pre-composition) is necessary
134    /// because `Choice::right` acts on the `Err` branch of `Result`.
135    pub fn transform<P: Choice>(&self, pab: P::P<A, B>) -> P::P<S, T>
136    where
137        S: 'static,
138        T: 'static,
139        A: 'static,
140        B: 'static,
141    {
142        let match_ = self.match_;
143        let build = self.build;
144        let right_pab = P::right::<A, B, T>(pab);
145        P::dimap(
146            move |s: S| match match_(s) {
147                Ok(a) => Err(a), // focus found → Err arm for Choice::right
148                Err(t) => Ok(t), // no match → Ok arm passes through
149            },
150            move |result: Result<T, B>| match result {
151                Ok(t) => t,         // passed through unchanged
152                Err(b) => build(b), // transformed, rebuild
153            },
154            right_pab,
155        )
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_profunctor::FnP;
    use proptest::prelude::*;

    /// Two-variant sum type used as the prism source throughout these tests.
    #[derive(Debug, Clone, PartialEq)]
    enum Shape {
        Circle(f64),
        Rectangle(f64, f64),
    }

    /// Prism focusing on the radius of `Shape::Circle`; a `Rectangle` is
    /// handed back unchanged through the `Err` arm.
    fn circle_prism() -> SimplePrism<Shape, f64> {
        Prism::new(
            |s| match s {
                Shape::Circle(r) => Ok(r),
                Shape::Rectangle(w, h) => Err(Shape::Rectangle(w, h)),
            },
            Shape::Circle,
        )
    }

    /// A source the prism matches.
    fn sample_circle() -> Shape {
        Shape::Circle(5.0)
    }

    /// A source the prism does not match.
    fn sample_rect() -> Shape {
        Shape::Rectangle(3.0, 4.0)
    }

    // --- Unit tests ---

    #[test]
    fn preview_match() {
        let prism = circle_prism();
        assert_eq!(prism.preview(&sample_circle()), Some(5.0));
    }

    #[test]
    fn preview_no_match() {
        let prism = circle_prism();
        assert_eq!(prism.preview(&sample_rect()), None);
    }

    #[test]
    fn review() {
        let prism = circle_prism();
        assert_eq!(prism.review(10.0), Shape::Circle(10.0));
    }

    #[test]
    fn set_match() {
        let prism = circle_prism();
        assert_eq!(prism.set(sample_circle(), 10.0), Shape::Circle(10.0));
    }

    #[test]
    fn set_no_match() {
        let prism = circle_prism();
        assert_eq!(prism.set(sample_rect(), 10.0), sample_rect());
    }

    #[test]
    fn over_match() {
        let prism = circle_prism();
        assert_eq!(
            prism.over(sample_circle(), |r| r * 2.0),
            Shape::Circle(10.0)
        );
    }

    #[test]
    fn over_no_match() {
        let prism = circle_prism();
        assert_eq!(prism.over(sample_rect(), |r| r * 2.0), sample_rect());
    }

    // --- Prism laws (proptest) ---

    // A range with finite bounds only yields finite values, so NaN and
    // infinity are excluded without any extra filtering.
    fn finite_f64() -> impl Strategy<Value = f64> {
        -1e6f64..1e6f64
    }

    // ReviewPreview: preview(review(b)) == Some(b)
    proptest! {
        #[test]
        fn law_review_preview(b in finite_f64()) {
            let prism = circle_prism();
            let s = prism.review(b);
            prop_assert_eq!(prism.preview(&s), Some(b));
        }
    }

    // PreviewReview: if preview(s) == Some(a) then review(a) == s
    proptest! {
        #[test]
        fn law_preview_review(r in finite_f64()) {
            let prism = circle_prism();
            let s = Shape::Circle(r);
            let focus = prism.preview(&s);
            // Every generated source is a `Circle`, so a `None` here would make
            // the law hold vacuously and hide a broken matcher.
            prop_assert!(focus.is_some(), "circle prism did not match {:?}", s);
            if let Some(a) = focus {
                prop_assert_eq!(prism.review(a), s);
            }
        }
    }

    // OverIdentity: over(s, id) == s
    proptest! {
        #[test]
        fn law_over_identity_circle(r in finite_f64()) {
            let prism = circle_prism();
            let s = Shape::Circle(r);
            prop_assert_eq!(prism.over(s.clone(), |x| x), s);
        }

        #[test]
        fn law_over_identity_rect(w in finite_f64(), h in finite_f64()) {
            let prism = circle_prism();
            let s = Shape::Rectangle(w, h);
            prop_assert_eq!(prism.over(s.clone(), |x| x), s);
        }
    }

    // --- FnP integration ---

    #[test]
    fn transform_fnp_match() {
        let prism = circle_prism();
        let double: Box<dyn Fn(f64) -> f64> = Box::new(|r| r * 2.0);
        let transform_fn = prism.transform::<FnP>(double);
        assert_eq!(transform_fn(sample_circle()), Shape::Circle(10.0));
    }

    #[test]
    fn transform_fnp_no_match() {
        let prism = circle_prism();
        let double: Box<dyn Fn(f64) -> f64> = Box::new(|r| r * 2.0);
        let transform_fn = prism.transform::<FnP>(double);
        assert_eq!(transform_fn(sample_rect()), sample_rect());
    }
}