Skip to main content

karpal_optics/
traversal.rs

1// Copyright (C) 2026 Industrial Algebra
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::optic::Optic;
5use karpal_profunctor::Traversing;
6use std::rc::Rc;
7
/// A multi-focus optic that can get/modify zero or more foci.
///
/// `S` — source, `T` — modified source, `A` — focus, `B` — replacement.
///
/// Both operations are stored behind `Rc`, so cloning a `Traversal` is cheap:
/// the clone shares the same closures instead of copying them.
pub struct Traversal<S, T, A, B> {
    /// Returns every focus of a borrowed source.
    // The trait-object types are inherently verbose; the lint is silenced here
    // rather than hiding the signatures behind aliases.
    #[allow(clippy::type_complexity)]
    get_all: Rc<dyn Fn(&S) -> Vec<A>>,
    /// Given a source and a per-focus mapping, produces the modified source.
    #[allow(clippy::type_complexity)]
    modify_all: Rc<dyn Fn(S, &dyn Fn(A) -> B) -> T>,
}

// Manual impl: `#[derive(Clone)]` would add `S: Clone, T: Clone, A: Clone,
// B: Clone` bounds, which are unnecessary because only the `Rc` handles are
// duplicated.
impl<S, T, A, B> Clone for Traversal<S, T, A, B> {
    fn clone(&self) -> Self {
        Self {
            get_all: Rc::clone(&self.get_all),
            modify_all: Rc::clone(&self.modify_all),
        }
    }
}
17
18/// A simple (monomorphic) traversal where `S == T` and `A == B`.
19pub type SimpleTraversal<S, A> = Traversal<S, S, A, A>;
20
21impl<S, T, A, B> Optic for Traversal<S, T, A, B> {}
22
23impl<S, T, A, B> Traversal<S, T, A, B> {
24    pub fn new(
25        get_all: impl Fn(&S) -> Vec<A> + 'static,
26        modify_all: impl Fn(S, &dyn Fn(A) -> B) -> T + 'static,
27    ) -> Self {
28        Self {
29            get_all: Rc::new(get_all),
30            modify_all: Rc::new(modify_all),
31        }
32    }
33
34    pub fn get_all(&self, s: &S) -> Vec<A> {
35        (self.get_all)(s)
36    }
37
38    pub fn over(&self, s: S, f: impl Fn(A) -> B) -> T {
39        (self.modify_all)(s, &f)
40    }
41
42    pub fn set(&self, s: S, b: B) -> T
43    where
44        B: Clone,
45    {
46        (self.modify_all)(s, &|_| b.clone())
47    }
48
49    /// Profunctor encoding via `Traversing::wander`.
50    pub fn transform<P: Traversing>(&self, pab: P::P<A, B>) -> P::P<S, T>
51    where
52        S: 'static,
53        T: 'static,
54        A: 'static,
55        B: 'static,
56    {
57        let get_all = Rc::clone(&self.get_all);
58        let modify_all = Rc::clone(&self.modify_all);
59        P::wander(move |s| get_all(s), move |s, f| modify_all(s, f), pab)
60    }
61
62    /// Convert to a `Fold` (read-only).
63    pub fn to_fold(&self) -> crate::fold::Fold<S, A>
64    where
65        S: 'static,
66        A: 'static,
67    {
68        let get_all = Rc::clone(&self.get_all);
69        crate::fold::Fold::new(move |s| get_all(s))
70    }
71
72    /// Compose with another traversal for deeper multi-focus access.
73    pub fn then<X, Y>(self, inner: Traversal<A, B, X, Y>) -> ComposedTraversal<S, T, X, Y>
74    where
75        S: 'static,
76        T: 'static,
77        A: 'static,
78        B: 'static,
79        X: 'static,
80        Y: 'static,
81    {
82        let outer_get_all = self.get_all;
83        let inner_get_all = Rc::clone(&inner.get_all);
84        let outer_modify_all = self.modify_all;
85        let inner_modify_all = inner.modify_all;
86        ComposedTraversal {
87            get_all: Box::new(move |s| {
88                outer_get_all(s)
89                    .into_iter()
90                    .flat_map(|a| inner_get_all(&a))
91                    .collect()
92            }),
93            modify_all: Box::new(move |s, f| outer_modify_all(s, &|a| (inner_modify_all)(a, f))),
94        }
95    }
96}
97
/// Boxed "collect every focus" function held by [`ComposedTraversal`].
type ComposedGetAll<S, A> = Box<dyn Fn(&S) -> Vec<A>>;

/// Boxed "map every focus, then rebuild" function held by [`ComposedTraversal`].
type ComposedModifyAll<S, T, A, B> = Box<dyn Fn(S, &dyn Fn(A) -> B) -> T>;

/// A composed traversal that stores its two operations as boxed closures.
///
/// Values of this type are returned by [`Traversal::then`].
pub struct ComposedTraversal<S, T, A, B> {
    /// Returns every focus of a borrowed source.
    get_all: ComposedGetAll<S, A>,
    /// Given a source and a per-focus mapping, produces the modified source.
    modify_all: ComposedModifyAll<S, T, A, B>,
}
105
106impl<S, T, A, B> Optic for ComposedTraversal<S, T, A, B> {}
107
108impl<S, T, A, B> ComposedTraversal<S, T, A, B> {
109    pub fn get_all(&self, s: &S) -> Vec<A> {
110        (self.get_all)(s)
111    }
112
113    pub fn over(&self, s: S, f: impl Fn(A) -> B) -> T {
114        (self.modify_all)(s, &f)
115    }
116
117    pub fn set(&self, s: S, b: B) -> T
118    where
119        B: Clone,
120    {
121        (self.modify_all)(s, &|_| b.clone())
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_profunctor::{FnP, ForgetF};

    /// Traversal focusing on every element of a `Vec<T>`.
    ///
    /// Generic so the same definition can serve as both the outer
    /// (`Vec<Vec<i32>>` → `Vec<i32>`) and inner (`Vec<i32>` → `i32`) optic.
    fn vec_each<T: Clone + 'static>() -> SimpleTraversal<Vec<T>, T> {
        Traversal::new(
            |v: &Vec<T>| v.clone(),
            |v: Vec<T>, f: &dyn Fn(T) -> T| v.into_iter().map(f).collect(),
        )
    }

    fn vec_each_traversal() -> SimpleTraversal<Vec<i32>, i32> {
        vec_each()
    }

    #[test]
    fn traversal_get_all() {
        let trav = vec_each_traversal();
        assert_eq!(trav.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn traversal_over() {
        let trav = vec_each_traversal();
        assert_eq!(trav.over(vec![1, 2, 3], |x| x * 10), vec![10, 20, 30]);
    }

    #[test]
    fn traversal_set() {
        let trav = vec_each_traversal();
        assert_eq!(trav.set(vec![1, 2, 3], 0), vec![0, 0, 0]);
    }

    #[test]
    fn traversal_empty_source() {
        // A traversal with zero foci must be a no-op for every operation.
        let trav = vec_each_traversal();
        assert_eq!(trav.get_all(&Vec::new()), Vec::<i32>::new());
        assert_eq!(trav.over(Vec::new(), |x| x + 1), Vec::<i32>::new());
        assert_eq!(trav.set(Vec::new(), 9), Vec::<i32>::new());
    }

    #[test]
    fn traversal_transform_fnp() {
        let trav = vec_each_traversal();
        let double: Box<dyn Fn(i32) -> i32> = Box::new(|x| x * 2);
        let f = trav.transform::<FnP>(double);
        assert_eq!(f(vec![1, 2, 3]), vec![2, 4, 6]);
    }

    #[test]
    fn traversal_transform_forget() {
        let trav = vec_each_traversal();
        let to_string: Box<dyn Fn(i32) -> String> = Box::new(|x| x.to_string());
        let f = trav.transform::<ForgetF<String>>(to_string);
        // ForgetF with String Monoid concatenates
        assert_eq!(f(vec![1, 2, 3]), "123");
    }

    #[test]
    fn traversal_identity_law() {
        let trav = vec_each_traversal();
        let v = vec![1, 2, 3];
        assert_eq!(trav.over(v.clone(), |x| x), v);
    }

    // Composition: traverse into nested vecs
    #[test]
    fn traversal_composition() {
        let outer = vec_each::<Vec<i32>>();
        let inner = vec_each_traversal();
        let composed = outer.then(inner);
        assert_eq!(composed.get_all(&vec![vec![1, 2], vec![3]]), vec![1, 2, 3]);
        assert_eq!(
            composed.over(vec![vec![1, 2], vec![3]], |x| x * 10),
            vec![vec![10, 20], vec![30]]
        );
    }

    #[test]
    fn composed_traversal_set() {
        // Empty inner vectors must survive untouched while every other focus is replaced.
        let composed = vec_each::<Vec<i32>>().then(vec_each_traversal());
        assert_eq!(
            composed.set(vec![vec![1, 2], vec![], vec![3]], 7),
            vec![vec![7, 7], vec![], vec![7]]
        );
    }

    #[test]
    fn traversal_from_lens() {
        use crate::lens::Lens;

        #[derive(Debug, Clone, PartialEq)]
        struct Point {
            x: i32,
            y: i32,
        }

        let lens = Lens::new(|p: &Point| p.x, |p: Point, x| Point { x, ..p });
        let trav = lens.to_traversal();
        let p = Point { x: 1, y: 2 };
        assert_eq!(trav.get_all(&p), vec![1]);
        assert_eq!(trav.over(p, |x| x + 10), Point { x: 11, y: 2 });
    }

    #[test]
    fn traversal_from_prism() {
        use crate::prism::Prism;

        #[derive(Debug, Clone, PartialEq)]
        enum Val {
            Int(i32),
            Str(String),
        }
        let prism = Prism::new(
            |v: Val| match v {
                Val::Int(n) => Ok(n),
                Val::Str(s) => Err(Val::Str(s)),
            },
            Val::Int,
        );
        let trav = prism.to_traversal();
        assert_eq!(trav.get_all(&Val::Int(5)), vec![5]);
        assert_eq!(trav.get_all(&Val::Str("hi".into())), Vec::<i32>::new());
        assert_eq!(trav.over(Val::Int(5), |x| x * 2), Val::Int(10));
        assert_eq!(
            trav.over(Val::Str("hi".into()), |x| x * 2),
            Val::Str("hi".into())
        );
    }

    #[test]
    fn traversal_composition_law() {
        // over(over(s, f), g) == over(s, g . f)
        let trav = vec_each_traversal();
        let v = vec![1, 2, 3];
        let left = trav.over(trav.over(v.clone(), |x| x + 1), |x| x * 2);
        let right = trav.over(v, |x| (x + 1) * 2);
        assert_eq!(left, right);
    }
}