1use crate::optic::Optic;
5use karpal_profunctor::Traversing;
6use std::rc::Rc;
7
/// A (possibly type-changing) traversal: an optic that reads zero or more
/// foci of type `A` out of an `S`, and rebuilds a `T` once each focus has been
/// turned into a `B`.
///
/// Both halves are stored behind [`Rc`], so cloning a `Traversal` only bumps
/// reference counts and shares the underlying closures.
pub struct Traversal<S, T, A, B> {
    /// Returns every focus of a source value.
    get_all: SharedGetAll<S, A>,
    /// Rebuilds the source as a `T`, given the function to apply to the foci.
    modify_all: SharedModifyAll<S, T, A, B>,
}

/// Shared, type-erased getter half of a [`Traversal`].
type SharedGetAll<S, A> = Rc<dyn Fn(&S) -> Vec<A>>;

/// Shared, type-erased modifier half of a [`Traversal`].
type SharedModifyAll<S, T, A, B> = Rc<dyn Fn(S, &dyn Fn(A) -> B) -> T>;

// Implemented by hand: `#[derive(Clone)]` would require `S`, `T`, `A`, `B`
// to be `Clone`, but only the two `Rc` handles are ever copied.
impl<S, T, A, B> Clone for Traversal<S, T, A, B> {
    fn clone(&self) -> Self {
        Self {
            get_all: Rc::clone(&self.get_all),
            modify_all: Rc::clone(&self.modify_all),
        }
    }
}

/// A traversal that does not change types: source `S` to `S`, foci `A` to `A`.
pub type SimpleTraversal<S, A> = Traversal<S, S, A, A>;
20
21impl<S, T, A, B> Optic for Traversal<S, T, A, B> {}
22
23impl<S, T, A, B> Traversal<S, T, A, B> {
24 pub fn new(
25 get_all: impl Fn(&S) -> Vec<A> + 'static,
26 modify_all: impl Fn(S, &dyn Fn(A) -> B) -> T + 'static,
27 ) -> Self {
28 Self {
29 get_all: Rc::new(get_all),
30 modify_all: Rc::new(modify_all),
31 }
32 }
33
34 pub fn get_all(&self, s: &S) -> Vec<A> {
35 (self.get_all)(s)
36 }
37
38 pub fn over(&self, s: S, f: impl Fn(A) -> B) -> T {
39 (self.modify_all)(s, &f)
40 }
41
42 pub fn set(&self, s: S, b: B) -> T
43 where
44 B: Clone,
45 {
46 (self.modify_all)(s, &|_| b.clone())
47 }
48
49 pub fn transform<P: Traversing>(&self, pab: P::P<A, B>) -> P::P<S, T>
51 where
52 S: 'static,
53 T: 'static,
54 A: 'static,
55 B: 'static,
56 {
57 let get_all = Rc::clone(&self.get_all);
58 let modify_all = Rc::clone(&self.modify_all);
59 P::wander(move |s| get_all(s), move |s, f| modify_all(s, f), pab)
60 }
61
62 pub fn to_fold(&self) -> crate::fold::Fold<S, A>
64 where
65 S: 'static,
66 A: 'static,
67 {
68 let get_all = Rc::clone(&self.get_all);
69 crate::fold::Fold::new(move |s| get_all(s))
70 }
71
72 pub fn then<X, Y>(self, inner: Traversal<A, B, X, Y>) -> ComposedTraversal<S, T, X, Y>
74 where
75 S: 'static,
76 T: 'static,
77 A: 'static,
78 B: 'static,
79 X: 'static,
80 Y: 'static,
81 {
82 let outer_get_all = self.get_all;
83 let inner_get_all = Rc::clone(&inner.get_all);
84 let outer_modify_all = self.modify_all;
85 let inner_modify_all = inner.modify_all;
86 ComposedTraversal {
87 get_all: Box::new(move |s| {
88 outer_get_all(s)
89 .into_iter()
90 .flat_map(|a| inner_get_all(&a))
91 .collect()
92 }),
93 modify_all: Box::new(move |s, f| outer_modify_all(s, &|a| (inner_modify_all)(a, f))),
94 }
95 }
96}
97
/// A traversal produced by [`Traversal::then`]. Its foci are the inner foci
/// reached through each outer focus.
///
/// Unlike [`Traversal`], both halves are uniquely owned `Box` closures.
pub struct ComposedTraversal<S, T, A, B> {
    /// Returns every (inner) focus of a source value.
    get_all: BoxedGetAll<S, A>,
    /// Rebuilds the source as a `T`, given the function to apply to the foci.
    modify_all: BoxedModifyAll<S, T, A, B>,
}

/// Owned, type-erased getter half of a [`ComposedTraversal`].
type BoxedGetAll<S, A> = Box<dyn Fn(&S) -> Vec<A>>;

/// Owned, type-erased modifier half of a [`ComposedTraversal`].
type BoxedModifyAll<S, T, A, B> = Box<dyn Fn(S, &dyn Fn(A) -> B) -> T>;
105
106impl<S, T, A, B> Optic for ComposedTraversal<S, T, A, B> {}
107
108impl<S, T, A, B> ComposedTraversal<S, T, A, B> {
109 pub fn get_all(&self, s: &S) -> Vec<A> {
110 (self.get_all)(s)
111 }
112
113 pub fn over(&self, s: S, f: impl Fn(A) -> B) -> T {
114 (self.modify_all)(s, &f)
115 }
116
117 pub fn set(&self, s: S, b: B) -> T
118 where
119 B: Clone,
120 {
121 (self.modify_all)(s, &|_| b.clone())
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_profunctor::{FnP, ForgetF};

    /// Traversal over every element of a `Vec<i32>`.
    fn vec_each_traversal() -> SimpleTraversal<Vec<i32>, i32> {
        Traversal::new(
            |v: &Vec<i32>| v.clone(),
            |v: Vec<i32>, f: &dyn Fn(i32) -> i32| v.into_iter().map(f).collect(),
        )
    }

    /// Traversal over every inner vector of a `Vec<Vec<i32>>`.
    fn nested_vec_traversal() -> SimpleTraversal<Vec<Vec<i32>>, Vec<i32>> {
        Traversal::new(
            |v: &Vec<Vec<i32>>| v.clone(),
            |v: Vec<Vec<i32>>, f: &dyn Fn(Vec<i32>) -> Vec<i32>| {
                v.into_iter().map(f).collect::<Vec<_>>()
            },
        )
    }

    #[test]
    fn traversal_get_all() {
        let trav = vec_each_traversal();
        assert_eq!(trav.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn traversal_over() {
        let trav = vec_each_traversal();
        assert_eq!(trav.over(vec![1, 2, 3], |x| x * 10), vec![10, 20, 30]);
    }

    #[test]
    fn traversal_set() {
        let trav = vec_each_traversal();
        assert_eq!(trav.set(vec![1, 2, 3], 0), vec![0, 0, 0]);
    }

    #[test]
    fn traversal_empty_source() {
        let trav = vec_each_traversal();
        assert_eq!(trav.get_all(&Vec::new()), Vec::<i32>::new());
        assert_eq!(trav.over(Vec::new(), |x| x + 1), Vec::<i32>::new());
        assert_eq!(trav.set(Vec::new(), 9), Vec::<i32>::new());
    }

    #[test]
    fn traversal_transform_fnp() {
        let trav = vec_each_traversal();
        let double: Box<dyn Fn(i32) -> i32> = Box::new(|x| x * 2);
        let f = trav.transform::<FnP>(double);
        assert_eq!(f(vec![1, 2, 3]), vec![2, 4, 6]);
    }

    #[test]
    fn traversal_transform_forget() {
        let trav = vec_each_traversal();
        let to_string: Box<dyn Fn(i32) -> String> = Box::new(|x| x.to_string());
        let f = trav.transform::<ForgetF<String>>(to_string);
        assert_eq!(f(vec![1, 2, 3]), "123");
    }

    #[test]
    fn traversal_identity_law() {
        let trav = vec_each_traversal();
        let v = vec![1, 2, 3];
        assert_eq!(trav.over(v.clone(), |x| x), v);
    }

    #[test]
    fn traversal_composition() {
        let composed = nested_vec_traversal().then(vec_each_traversal());
        assert_eq!(composed.get_all(&vec![vec![1, 2], vec![3]]), vec![1, 2, 3]);
        assert_eq!(
            composed.over(vec![vec![1, 2], vec![3]], |x| x * 10),
            vec![vec![10, 20], vec![30]]
        );
    }

    #[test]
    fn traversal_composition_set() {
        let composed = nested_vec_traversal().then(vec_each_traversal());
        assert_eq!(
            composed.set(vec![vec![1, 2], vec![3]], 7),
            vec![vec![7, 7], vec![7]]
        );
    }

    #[test]
    fn traversal_composition_with_empty_inner() {
        let composed = nested_vec_traversal().then(vec_each_traversal());
        let v = vec![vec![], vec![4], vec![]];
        assert_eq!(composed.get_all(&v), vec![4]);
        assert_eq!(composed.over(v, |x| x + 1), vec![vec![], vec![5], vec![]]);
    }

    #[test]
    fn traversal_from_lens() {
        use crate::lens::Lens;

        #[derive(Debug, Clone, PartialEq)]
        struct Point {
            x: i32,
            y: i32,
        }

        let lens = Lens::new(|p: &Point| p.x, |p: Point, x| Point { x, ..p });
        let trav = lens.to_traversal();
        let p = Point { x: 1, y: 2 };
        assert_eq!(trav.get_all(&p), vec![1]);
        assert_eq!(trav.over(p, |x| x + 10), Point { x: 11, y: 2 });
    }

    #[test]
    fn traversal_from_prism() {
        use crate::prism::Prism;

        #[derive(Debug, Clone, PartialEq)]
        enum Val {
            Int(i32),
            Str(String),
        }
        let prism = Prism::new(
            |v: Val| match v {
                Val::Int(n) => Ok(n),
                Val::Str(s) => Err(Val::Str(s)),
            },
            Val::Int,
        );
        let trav = prism.to_traversal();
        assert_eq!(trav.get_all(&Val::Int(5)), vec![5]);
        assert_eq!(trav.get_all(&Val::Str("hi".into())), Vec::<i32>::new());
        assert_eq!(trav.over(Val::Int(5), |x| x * 2), Val::Int(10));
        assert_eq!(
            trav.over(Val::Str("hi".into()), |x| x * 2),
            Val::Str("hi".into())
        );
    }

    #[test]
    fn traversal_composition_law() {
        let trav = vec_each_traversal();
        let v = vec![1, 2, 3];
        let left = trav.over(trav.over(v.clone(), |x| x + 1), |x| x * 2);
        let right = trav.over(v, |x| (x + 1) * 2);
        assert_eq!(left, right);
    }
}