1use crate::optic::Optic;
2use karpal_profunctor::Traversing;
3use std::rc::Rc;
4
/// An optic stored as two reference-counted closures.
///
/// * `S` / `T` — the structure type before and after modification.
/// * `A` / `B` — the focus type before and after modification.
// The boxed closure signatures are inherently long; naming them with aliases
// would not make the fields easier to read, so the lint is silenced for the
// whole struct.
#[allow(clippy::type_complexity)]
pub struct Traversal<S, T, A, B> {
    /// Collects the foci of a borrowed `S` into a `Vec<A>`.
    get_all: Rc<dyn Fn(&S) -> Vec<A>>,
    /// Consumes an `S` and produces a `T`, given a function from `A` to `B`.
    modify_all: Rc<dyn Fn(S, &dyn Fn(A) -> B) -> T>,
}
14
15pub type SimpleTraversal<S, A> = Traversal<S, S, A, A>;
17
18impl<S, T, A, B> Optic for Traversal<S, T, A, B> {}
19
20impl<S, T, A, B> Traversal<S, T, A, B> {
21 pub fn new(
22 get_all: impl Fn(&S) -> Vec<A> + 'static,
23 modify_all: impl Fn(S, &dyn Fn(A) -> B) -> T + 'static,
24 ) -> Self {
25 Self {
26 get_all: Rc::new(get_all),
27 modify_all: Rc::new(modify_all),
28 }
29 }
30
31 pub fn get_all(&self, s: &S) -> Vec<A> {
32 (self.get_all)(s)
33 }
34
35 pub fn over(&self, s: S, f: impl Fn(A) -> B) -> T {
36 (self.modify_all)(s, &f)
37 }
38
39 pub fn set(&self, s: S, b: B) -> T
40 where
41 B: Clone,
42 {
43 (self.modify_all)(s, &|_| b.clone())
44 }
45
46 pub fn transform<P: Traversing>(&self, pab: P::P<A, B>) -> P::P<S, T>
48 where
49 S: 'static,
50 T: 'static,
51 A: 'static,
52 B: 'static,
53 {
54 let get_all = Rc::clone(&self.get_all);
55 let modify_all = Rc::clone(&self.modify_all);
56 P::wander(move |s| get_all(s), move |s, f| modify_all(s, f), pab)
57 }
58
59 pub fn to_fold(&self) -> crate::fold::Fold<S, A>
61 where
62 S: 'static,
63 A: 'static,
64 {
65 let get_all = Rc::clone(&self.get_all);
66 crate::fold::Fold::new(move |s| get_all(s))
67 }
68
69 pub fn then<X, Y>(self, inner: Traversal<A, B, X, Y>) -> ComposedTraversal<S, T, X, Y>
71 where
72 S: 'static,
73 T: 'static,
74 A: 'static,
75 B: 'static,
76 X: 'static,
77 Y: 'static,
78 {
79 let outer_get_all = self.get_all;
80 let inner_get_all = Rc::clone(&inner.get_all);
81 let outer_modify_all = self.modify_all;
82 let inner_modify_all = inner.modify_all;
83 ComposedTraversal {
84 get_all: Box::new(move |s| {
85 outer_get_all(s)
86 .into_iter()
87 .flat_map(|a| inner_get_all(&a))
88 .collect()
89 }),
90 modify_all: Box::new(move |s, f| outer_modify_all(s, &|a| (inner_modify_all)(a, f))),
91 }
92 }
93}
94
/// An optic stored as two boxed closures, with the same closure signatures as
/// the fields of `Traversal`.
///
/// * `S` / `T` — the structure type before and after modification.
/// * `A` / `B` — the focus type before and after modification.
// The boxed closure signatures are inherently long; the lint is silenced for
// the whole struct rather than once per field.
#[allow(clippy::type_complexity)]
pub struct ComposedTraversal<S, T, A, B> {
    /// Collects the foci of a borrowed `S` into a `Vec<A>`.
    get_all: Box<dyn Fn(&S) -> Vec<A>>,
    /// Consumes an `S` and produces a `T`, given a function from `A` to `B`.
    modify_all: Box<dyn Fn(S, &dyn Fn(A) -> B) -> T>,
}
102
103impl<S, T, A, B> Optic for ComposedTraversal<S, T, A, B> {}
104
105impl<S, T, A, B> ComposedTraversal<S, T, A, B> {
106 pub fn get_all(&self, s: &S) -> Vec<A> {
107 (self.get_all)(s)
108 }
109
110 pub fn over(&self, s: S, f: impl Fn(A) -> B) -> T {
111 (self.modify_all)(s, &f)
112 }
113
114 pub fn set(&self, s: S, b: B) -> T
115 where
116 B: Clone,
117 {
118 (self.modify_all)(s, &|_| b.clone())
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_profunctor::{FnP, ForgetF};

    /// Fixture: a traversal whose foci are the elements of a `Vec<i32>`.
    fn vec_each_traversal() -> SimpleTraversal<Vec<i32>, i32> {
        Traversal::new(
            |items: &Vec<i32>| items.clone(),
            |items: Vec<i32>, f: &dyn Fn(i32) -> i32| items.into_iter().map(f).collect(),
        )
    }

    /// `get_all` yields every element, in order.
    #[test]
    fn traversal_get_all() {
        let each = vec_each_traversal();
        let source = vec![1, 2, 3];
        assert_eq!(each.get_all(&source), vec![1, 2, 3]);
    }

    /// `over` applies the function to every element.
    #[test]
    fn traversal_over() {
        let each = vec_each_traversal();
        let scaled = each.over(vec![1, 2, 3], |x| x * 10);
        assert_eq!(scaled, vec![10, 20, 30]);
    }

    /// `set` overwrites every element with the given value.
    #[test]
    fn traversal_set() {
        let each = vec_each_traversal();
        let zeroed = each.set(vec![1, 2, 3], 0);
        assert_eq!(zeroed, vec![0, 0, 0]);
    }

    /// `transform` through `FnP` lifts an element function to a whole-vector function.
    #[test]
    fn traversal_transform_fnp() {
        let each = vec_each_traversal();
        let double: Box<dyn Fn(i32) -> i32> = Box::new(|x| x * 2);
        let lifted = each.transform::<FnP>(double);
        assert_eq!(lifted(vec![1, 2, 3]), vec![2, 4, 6]);
    }

    /// `transform` through `ForgetF<String>` combines the per-element strings.
    #[test]
    fn traversal_transform_forget() {
        let each = vec_each_traversal();
        let to_string: Box<dyn Fn(i32) -> String> = Box::new(|x| x.to_string());
        let lifted = each.transform::<ForgetF<String>>(to_string);
        assert_eq!(lifted(vec![1, 2, 3]), "123");
    }

    /// Mapping the identity function leaves the structure unchanged.
    #[test]
    fn traversal_identity_law() {
        let each = vec_each_traversal();
        let original = vec![1, 2, 3];
        let mapped = each.over(original.clone(), |x| x);
        assert_eq!(mapped, original);
    }

    /// `then` reaches the elements of the inner vectors of a nested vector.
    #[test]
    fn traversal_composition() {
        let rows: SimpleTraversal<Vec<Vec<i32>>, Vec<i32>> = Traversal::new(
            |grid: &Vec<Vec<i32>>| grid.clone(),
            |grid: Vec<Vec<i32>>, f: &dyn Fn(Vec<i32>) -> Vec<i32>| {
                grid.into_iter().map(f).collect::<Vec<_>>()
            },
        );
        let cells = rows.then(vec_each_traversal());

        let flattened = cells.get_all(&vec![vec![1, 2], vec![3]]);
        assert_eq!(flattened, vec![1, 2, 3]);

        let scaled = cells.over(vec![vec![1, 2], vec![3]], |x| x * 10);
        assert_eq!(scaled, vec![vec![10, 20], vec![30]]);
    }

    /// A lens converted with `to_traversal` has exactly one focus.
    #[test]
    fn traversal_from_lens() {
        use crate::lens::Lens;

        #[derive(Debug, Clone, PartialEq)]
        struct Point {
            x: i32,
            y: i32,
        }

        let x_lens = Lens::new(|p: &Point| p.x, |p: Point, x| Point { x, ..p });
        let x_trav = x_lens.to_traversal();
        let point = Point { x: 1, y: 2 };

        assert_eq!(x_trav.get_all(&point), vec![1]);
        assert_eq!(x_trav.over(point, |x| x + 10), Point { x: 11, y: 2 });
    }

    /// A prism converted with `to_traversal` has one focus on a match and none otherwise.
    #[test]
    fn traversal_from_prism() {
        use crate::prism::Prism;

        #[derive(Debug, Clone, PartialEq)]
        enum Val {
            Int(i32),
            Str(String),
        }

        let int_prism = Prism::new(
            |v: Val| match v {
                Val::Int(n) => Ok(n),
                Val::Str(s) => Err(Val::Str(s)),
            },
            Val::Int,
        );
        let int_trav = int_prism.to_traversal();

        // Matching variant: the payload is the single focus.
        assert_eq!(int_trav.get_all(&Val::Int(5)), vec![5]);
        assert_eq!(int_trav.over(Val::Int(5), |x| x * 2), Val::Int(10));

        // Non-matching variant: no foci, and `over` returns the value unchanged.
        assert_eq!(
            int_trav.get_all(&Val::Str("hi".into())),
            Vec::<i32>::new()
        );
        assert_eq!(
            int_trav.over(Val::Str("hi".into()), |x| x * 2),
            Val::Str("hi".into())
        );
    }

    /// Two successive `over` calls equal one `over` with the composed function.
    #[test]
    fn traversal_composition_law() {
        let each = vec_each_traversal();
        let input = vec![1, 2, 3];

        let incremented = each.over(input.clone(), |x| x + 1);
        let two_passes = each.over(incremented, |x| x * 2);
        let one_pass = each.over(input, |x| (x + 1) * 2);

        assert_eq!(two_passes, one_pass);
    }
}