// karpal_optics/traversal.rs

1use crate::optic::Optic;
2use karpal_profunctor::Traversing;
3use std::rc::Rc;
4
/// A multi-focus optic that can get/modify zero or more foci.
///
/// `S` — source, `T` — modified source, `A` — focus, `B` — replacement.
///
/// Both accessors live behind [`Rc`], so cloning a `Traversal` only bumps two
/// reference counts and never requires `S`, `T`, `A` or `B` to be `Clone`.
pub struct Traversal<S, T, A, B> {
    /// Returns every focus of a source.
    // The boxed-closure signature is intentional; silence clippy's size heuristic.
    #[allow(clippy::type_complexity)]
    get_all: Rc<dyn Fn(&S) -> Vec<A>>,
    /// Consumes a source and a per-focus function, producing the modified source.
    // The boxed-closure signature is intentional; silence clippy's size heuristic.
    #[allow(clippy::type_complexity)]
    modify_all: Rc<dyn Fn(S, &dyn Fn(A) -> B) -> T>,
}

// Hand-written instead of `#[derive(Clone)]`: the derive would add
// `S: Clone, T: Clone, A: Clone, B: Clone` bounds, which are unnecessary
// because only the shared `Rc` handles are duplicated.
impl<S, T, A, B> Clone for Traversal<S, T, A, B> {
    fn clone(&self) -> Self {
        Self {
            get_all: Rc::clone(&self.get_all),
            modify_all: Rc::clone(&self.modify_all),
        }
    }
}
14
15/// A simple (monomorphic) traversal where `S == T` and `A == B`.
16pub type SimpleTraversal<S, A> = Traversal<S, S, A, A>;
17
18impl<S, T, A, B> Optic for Traversal<S, T, A, B> {}
19
20impl<S, T, A, B> Traversal<S, T, A, B> {
21    pub fn new(
22        get_all: impl Fn(&S) -> Vec<A> + 'static,
23        modify_all: impl Fn(S, &dyn Fn(A) -> B) -> T + 'static,
24    ) -> Self {
25        Self {
26            get_all: Rc::new(get_all),
27            modify_all: Rc::new(modify_all),
28        }
29    }
30
31    pub fn get_all(&self, s: &S) -> Vec<A> {
32        (self.get_all)(s)
33    }
34
35    pub fn over(&self, s: S, f: impl Fn(A) -> B) -> T {
36        (self.modify_all)(s, &f)
37    }
38
39    pub fn set(&self, s: S, b: B) -> T
40    where
41        B: Clone,
42    {
43        (self.modify_all)(s, &|_| b.clone())
44    }
45
46    /// Profunctor encoding via `Traversing::wander`.
47    pub fn transform<P: Traversing>(&self, pab: P::P<A, B>) -> P::P<S, T>
48    where
49        S: 'static,
50        T: 'static,
51        A: 'static,
52        B: 'static,
53    {
54        let get_all = Rc::clone(&self.get_all);
55        let modify_all = Rc::clone(&self.modify_all);
56        P::wander(move |s| get_all(s), move |s, f| modify_all(s, f), pab)
57    }
58
59    /// Convert to a `Fold` (read-only).
60    pub fn to_fold(&self) -> crate::fold::Fold<S, A>
61    where
62        S: 'static,
63        A: 'static,
64    {
65        let get_all = Rc::clone(&self.get_all);
66        crate::fold::Fold::new(move |s| get_all(s))
67    }
68
69    /// Compose with another traversal for deeper multi-focus access.
70    pub fn then<X, Y>(self, inner: Traversal<A, B, X, Y>) -> ComposedTraversal<S, T, X, Y>
71    where
72        S: 'static,
73        T: 'static,
74        A: 'static,
75        B: 'static,
76        X: 'static,
77        Y: 'static,
78    {
79        let outer_get_all = self.get_all;
80        let inner_get_all = Rc::clone(&inner.get_all);
81        let outer_modify_all = self.modify_all;
82        let inner_modify_all = inner.modify_all;
83        ComposedTraversal {
84            get_all: Box::new(move |s| {
85                outer_get_all(s)
86                    .into_iter()
87                    .flat_map(|a| inner_get_all(&a))
88                    .collect()
89            }),
90            modify_all: Box::new(move |s, f| outer_modify_all(s, &|a| (inner_modify_all)(a, f))),
91        }
92    }
93}
94
/// Traversal produced by [`Traversal::then`], holding its two accessors as
/// owned, boxed closures.
// The boxed-closure field types are intentional; silence clippy's size heuristic.
#[allow(clippy::type_complexity)]
pub struct ComposedTraversal<S, T, A, B> {
    /// Yields the foci of a source.
    get_all: Box<dyn Fn(&S) -> Vec<A>>,
    /// Rebuilds a source, passing each focus through the supplied function.
    modify_all: Box<dyn Fn(S, &dyn Fn(A) -> B) -> T>,
}
102
103impl<S, T, A, B> Optic for ComposedTraversal<S, T, A, B> {}
104
105impl<S, T, A, B> ComposedTraversal<S, T, A, B> {
106    pub fn get_all(&self, s: &S) -> Vec<A> {
107        (self.get_all)(s)
108    }
109
110    pub fn over(&self, s: S, f: impl Fn(A) -> B) -> T {
111        (self.modify_all)(s, &f)
112    }
113
114    pub fn set(&self, s: S, b: B) -> T
115    where
116        B: Clone,
117    {
118        (self.modify_all)(s, &|_| b.clone())
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_profunctor::{FnP, ForgetF};

    /// Focuses on every element of a `Vec<i32>`.
    fn vec_each_traversal() -> SimpleTraversal<Vec<i32>, i32> {
        Traversal::new(
            |v: &Vec<i32>| v.clone(),
            |v: Vec<i32>, f: &dyn Fn(i32) -> i32| v.into_iter().map(f).collect(),
        )
    }

    /// Focuses on every inner `Vec<i32>` of a `Vec<Vec<i32>>`.
    fn vec_vec_each_traversal() -> SimpleTraversal<Vec<Vec<i32>>, Vec<i32>> {
        Traversal::new(
            |v: &Vec<Vec<i32>>| v.clone(),
            |v: Vec<Vec<i32>>, f: &dyn Fn(Vec<i32>) -> Vec<i32>| {
                v.into_iter().map(f).collect::<Vec<_>>()
            },
        )
    }

    #[test]
    fn traversal_get_all() {
        let trav = vec_each_traversal();
        assert_eq!(trav.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn traversal_over() {
        let trav = vec_each_traversal();
        assert_eq!(trav.over(vec![1, 2, 3], |x| x * 10), vec![10, 20, 30]);
    }

    #[test]
    fn traversal_set() {
        let trav = vec_each_traversal();
        assert_eq!(trav.set(vec![1, 2, 3], 0), vec![0, 0, 0]);
    }

    #[test]
    fn traversal_empty_source() {
        // A traversal with zero foci must leave the source untouched.
        let trav = vec_each_traversal();
        assert_eq!(trav.get_all(&vec![]), Vec::<i32>::new());
        assert_eq!(trav.over(vec![], |x| x + 1), Vec::<i32>::new());
        assert_eq!(trav.set(vec![], 9), Vec::<i32>::new());
    }

    #[test]
    fn traversal_transform_fnp() {
        let trav = vec_each_traversal();
        let double: Box<dyn Fn(i32) -> i32> = Box::new(|x| x * 2);
        let f = trav.transform::<FnP>(double);
        assert_eq!(f(vec![1, 2, 3]), vec![2, 4, 6]);
    }

    #[test]
    fn traversal_transform_forget() {
        let trav = vec_each_traversal();
        let to_string: Box<dyn Fn(i32) -> String> = Box::new(|x| x.to_string());
        let f = trav.transform::<ForgetF<String>>(to_string);
        // ForgetF with String Monoid concatenates
        assert_eq!(f(vec![1, 2, 3]), "123");
    }

    #[test]
    fn traversal_identity_law() {
        let trav = vec_each_traversal();
        let v = vec![1, 2, 3];
        assert_eq!(trav.over(v.clone(), |x| x), v);
    }

    // Composition: traverse into nested vecs
    #[test]
    fn traversal_composition() {
        let composed = vec_vec_each_traversal().then(vec_each_traversal());
        assert_eq!(composed.get_all(&vec![vec![1, 2], vec![3]]), vec![1, 2, 3]);
        assert_eq!(
            composed.over(vec![vec![1, 2], vec![3]], |x| x * 10),
            vec![vec![10, 20], vec![30]]
        );
    }

    #[test]
    fn traversal_composition_set() {
        let composed = vec_vec_each_traversal().then(vec_each_traversal());
        assert_eq!(
            composed.set(vec![vec![1, 2], vec![], vec![3]], 0),
            vec![vec![0, 0], vec![], vec![0]]
        );
    }

    #[test]
    fn traversal_composition_empty_inner() {
        // Outer foci with no inner foci contribute nothing and survive `over`.
        let composed = vec_vec_each_traversal().then(vec_each_traversal());
        assert_eq!(composed.get_all(&vec![vec![], vec![]]), Vec::<i32>::new());
        assert_eq!(
            composed.over(vec![vec![], vec![5]], |x| x + 1),
            vec![vec![], vec![6]]
        );
    }

    #[test]
    fn traversal_from_lens() {
        use crate::lens::Lens;

        #[derive(Debug, Clone, PartialEq)]
        struct Point {
            x: i32,
            y: i32,
        }

        let lens = Lens::new(|p: &Point| p.x, |p: Point, x| Point { x, ..p });
        let trav = lens.to_traversal();
        let p = Point { x: 1, y: 2 };
        assert_eq!(trav.get_all(&p), vec![1]);
        assert_eq!(trav.over(p, |x| x + 10), Point { x: 11, y: 2 });
    }

    #[test]
    fn traversal_from_prism() {
        use crate::prism::Prism;

        #[derive(Debug, Clone, PartialEq)]
        enum Val {
            Int(i32),
            Str(String),
        }
        let prism = Prism::new(
            |v: Val| match v {
                Val::Int(n) => Ok(n),
                Val::Str(s) => Err(Val::Str(s)),
            },
            Val::Int,
        );
        let trav = prism.to_traversal();
        assert_eq!(trav.get_all(&Val::Int(5)), vec![5]);
        assert_eq!(trav.get_all(&Val::Str("hi".into())), Vec::<i32>::new());
        assert_eq!(trav.over(Val::Int(5), |x| x * 2), Val::Int(10));
        assert_eq!(
            trav.over(Val::Str("hi".into()), |x| x * 2),
            Val::Str("hi".into())
        );
    }

    #[test]
    fn traversal_composition_law() {
        // over(over(s, f), g) == over(s, g . f)
        let trav = vec_each_traversal();
        let v = vec![1, 2, 3];
        let left = trav.over(trav.over(v.clone(), |x| x + 1), |x| x * 2);
        let right = trav.over(v, |x| (x + 1) * 2);
        assert_eq!(left, right);
    }
}