1use async_trait::async_trait;
4use ranvier_core::{bus::Bus, outcome::Outcome, transition::Transition};
5use std::marker::PhantomData;
6use std::sync::Arc;
7
/// A transition node that applies a synchronous function to its input.
///
/// The function is stored behind an [`Arc`], so cloning the node is cheap and
/// every clone shares the very same function value.
pub struct MapNode<In, Out, F> {
    /// Mapping function shared by all clones of this node.
    f: Arc<F>,
    /// Ties `In`/`Out` to the type without storing values of either.
    _marker: PhantomData<(In, Out)>,
}

impl<In, Out, F> MapNode<In, Out, F>
where
    F: Fn(In) -> Out + Send + Sync + 'static,
{
    /// Wraps `f` in a new node.
    pub fn new(f: F) -> Self {
        let shared = Arc::new(f);
        MapNode {
            f: shared,
            _marker: PhantomData,
        }
    }
}

// Implemented manually: `#[derive(Clone)]` would add `In: Clone`, `Out: Clone`
// and `F: Clone` bounds, while cloning here only bumps the `Arc` refcount.
impl<In, Out, F> Clone for MapNode<In, Out, F> {
    fn clone(&self) -> Self {
        MapNode {
            f: Arc::clone(&self.f),
            _marker: PhantomData,
        }
    }
}

// Implemented manually so no type parameter needs to be `Debug`; only the
// node's name is printed (the stored function is opaque).
impl<In, Out, F> std::fmt::Debug for MapNode<In, Out, F> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("MapNode")
    }
}
40
41#[async_trait]
42impl<In, Out, F> Transition<In, Out> for MapNode<In, Out, F>
43where
44 In: Send + Sync + 'static,
45 Out: Send + Sync + 'static,
46 F: Fn(In) -> Out + Send + Sync + 'static,
47{
48 type Error = String;
49 type Resources = ();
50
51 async fn run(
52 &self,
53 input: In,
54 _resources: &Self::Resources,
55 _bus: &mut Bus,
56 ) -> Outcome<Out, Self::Error> {
57 Outcome::next((self.f)(input))
58 }
59}
60
/// A transition node that keeps only the elements of a `Vec<T>` that satisfy
/// a predicate.
///
/// The predicate lives behind an [`Arc`], so clones of the node share it.
pub struct FilterTransformNode<T, F> {
    /// Predicate shared by all clones; `true` means "keep".
    predicate: Arc<F>,
    /// Ties the element type `T` to the node without storing any `T`.
    _marker: PhantomData<T>,
}

impl<T, F> FilterTransformNode<T, F>
where
    F: Fn(&T) -> bool + Send + Sync + 'static,
{
    /// Builds a node from `predicate`.
    pub fn new(predicate: F) -> Self {
        let shared = Arc::new(predicate);
        FilterTransformNode {
            predicate: shared,
            _marker: PhantomData,
        }
    }
}

// Implemented manually: `#[derive(Clone)]` would add `T: Clone` and `F: Clone`
// bounds, while cloning here only bumps the `Arc` refcount.
impl<T, F> Clone for FilterTransformNode<T, F> {
    fn clone(&self) -> Self {
        FilterTransformNode {
            predicate: Arc::clone(&self.predicate),
            _marker: PhantomData,
        }
    }
}

// Implemented manually so neither `T` nor `F` needs to be `Debug`; only the
// node's name is printed.
impl<T, F> std::fmt::Debug for FilterTransformNode<T, F> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("FilterTransformNode")
    }
}
93
94#[async_trait]
95impl<T, F> Transition<Vec<T>, Vec<T>> for FilterTransformNode<T, F>
96where
97 T: Send + Sync + 'static,
98 F: Fn(&T) -> bool + Send + Sync + 'static,
99{
100 type Error = String;
101 type Resources = ();
102
103 async fn run(
104 &self,
105 input: Vec<T>,
106 _resources: &Self::Resources,
107 _bus: &mut Bus,
108 ) -> Outcome<Vec<T>, Self::Error> {
109 let filtered: Vec<T> = input.into_iter().filter(|x| (self.predicate)(x)).collect();
110 Outcome::next(filtered)
111 }
112}
113
/// A transition node that flattens a `Vec<Vec<T>>` into a single `Vec<T>`.
///
/// The node stores no `T`; the type parameter only fixes the element type of
/// the transition.
pub struct FlattenNode<T> {
    _marker: PhantomData<T>,
}

impl<T> FlattenNode<T> {
    /// Creates a new flatten node.
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T> Default for FlattenNode<T> {
    fn default() -> Self {
        Self::new()
    }
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive adds a
// `T: Clone` bound even though no `T` is stored, which made e.g.
// `FlattenNode<SomeNonCloneType>` uncloneable.
impl<T> Clone for FlattenNode<T> {
    fn clone(&self) -> Self {
        Self::new()
    }
}

// Implemented by hand for the same reason (`#[derive(Debug)]` requires
// `T: Debug`); like the other nodes in this module, only the name is printed.
impl<T> std::fmt::Debug for FlattenNode<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FlattenNode").finish()
    }
}
133
134#[async_trait]
135impl<T> Transition<Vec<Vec<T>>, Vec<T>> for FlattenNode<T>
136where
137 T: Send + Sync + 'static,
138{
139 type Error = String;
140 type Resources = ();
141
142 async fn run(
143 &self,
144 input: Vec<Vec<T>>,
145 _resources: &Self::Resources,
146 _bus: &mut Bus,
147 ) -> Outcome<Vec<T>, Self::Error> {
148 Outcome::next(input.into_iter().flatten().collect())
149 }
150}
151
/// A transition node that shallow-merges a pair of JSON objects.
///
/// It carries no state; every instance behaves identically.
#[derive(Debug, Clone, Default)]
pub struct MergeNode;

impl MergeNode {
    /// Creates a new merge node (equivalent to `MergeNode::default()`).
    pub fn new() -> Self {
        MergeNode
    }
}
167
168#[async_trait]
169impl Transition<(serde_json::Value, serde_json::Value), serde_json::Value> for MergeNode {
170 type Error = String;
171 type Resources = ();
172
173 async fn run(
174 &self,
175 input: (serde_json::Value, serde_json::Value),
176 _resources: &Self::Resources,
177 _bus: &mut Bus,
178 ) -> Outcome<serde_json::Value, Self::Error> {
179 let (mut base, overlay) = input;
180 if let (serde_json::Value::Object(base_map), serde_json::Value::Object(overlay_map)) =
181 (&mut base, overlay)
182 {
183 for (k, v) in overlay_map {
184 base_map.insert(k, v);
185 }
186 Outcome::next(base)
187 } else {
188 Outcome::fault("MergeNode requires two JSON objects".to_string())
189 }
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn map_node_transforms_value() {
        let node = MapNode::new(|x: i32| x * 2);
        let mut bus = Bus::new();
        let result = node.run(21, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(42)));
    }

    #[tokio::test]
    async fn map_node_type_conversion() {
        let node = MapNode::new(|x: i32| x.to_string());
        let mut bus = Bus::new();
        let result = node.run(42, &(), &mut bus).await;
        match result {
            Outcome::Next(s) => assert_eq!(s, "42"),
            _ => panic!("Expected Next"),
        }
    }

    #[tokio::test]
    async fn filter_transform_keeps_matching() {
        let node = FilterTransformNode::new(|x: &i32| *x > 3);
        let mut bus = Bus::new();
        let result = node.run(vec![1, 2, 3, 4, 5], &(), &mut bus).await;
        match result {
            Outcome::Next(v) => assert_eq!(v, vec![4, 5]),
            _ => panic!("Expected Next"),
        }
    }

    #[tokio::test]
    async fn filter_transform_handles_empty_input() {
        let node = FilterTransformNode::new(|x: &i32| *x > 3);
        let mut bus = Bus::new();
        let result = node.run(Vec::new(), &(), &mut bus).await;
        match result {
            Outcome::Next(v) => assert!(v.is_empty()),
            _ => panic!("Expected Next"),
        }
    }

    #[tokio::test]
    async fn flatten_node_flattens() {
        let node = FlattenNode::<i32>::new();
        let mut bus = Bus::new();
        let result = node.run(vec![vec![1, 2], vec![3, 4]], &(), &mut bus).await;
        match result {
            Outcome::Next(v) => assert_eq!(v, vec![1, 2, 3, 4]),
            _ => panic!("Expected Next"),
        }
    }

    #[tokio::test]
    async fn flatten_node_skips_empty_inner_vecs() {
        let node = FlattenNode::<i32>::new();
        let mut bus = Bus::new();
        let input = vec![vec![], vec![1], vec![], vec![2, 3]];
        let result = node.run(input, &(), &mut bus).await;
        match result {
            Outcome::Next(v) => assert_eq!(v, vec![1, 2, 3]),
            _ => panic!("Expected Next"),
        }
    }

    #[tokio::test]
    async fn merge_node_combines_objects() {
        let node = MergeNode::new();
        let mut bus = Bus::new();
        let a = serde_json::json!({"name": "Alice", "age": 30});
        let b = serde_json::json!({"age": 31, "city": "NYC"});
        let result = node.run((a, b), &(), &mut bus).await;
        match result {
            Outcome::Next(v) => {
                assert_eq!(v["name"], "Alice");
                assert_eq!(v["age"], 31);
                assert_eq!(v["city"], "NYC");
                assert_eq!(v.as_object().map(|m| m.len()), Some(3));
            }
            _ => panic!("Expected Next"),
        }
    }

    #[tokio::test]
    async fn merge_node_rejects_non_objects() {
        let node = MergeNode::new();
        let mut bus = Bus::new();
        let a = serde_json::json!({"name": "Alice"});
        let b = serde_json::json!([1, 2, 3]);
        let result = node.run((a, b), &(), &mut bus).await;
        assert!(!matches!(result, Outcome::Next(_)));
    }
}