1use async_trait::async_trait;
4use ranvier_core::{bus::Bus, outcome::Outcome, transition::Transition};
5use serde::{Deserialize, Serialize};
6use std::fmt::Debug;
7use std::marker::PhantomData;
8
/// Transition that turns an `Option<T>` into a `T`, faulting when the value is absent.
///
/// `field_name` is only used to build the fault message. No `T` is stored; the
/// marker exists solely to tie the node to its output type.
pub struct RequiredNode<T> {
    field_name: String,
    _marker: PhantomData<T>,
}

impl<T> RequiredNode<T> {
    /// Creates a node that reports `field_name` when its input is `None`.
    pub fn new(field_name: impl Into<String>) -> Self {
        Self {
            field_name: field_name.into(),
            _marker: PhantomData,
        }
    }
}

// `Clone` and `Debug` are implemented by hand rather than derived: `#[derive]`
// would add `T: Clone` / `T: Debug` bounds even though no `T` is ever stored,
// making e.g. `RequiredNode<NonCloneType>` un-cloneable for no reason.
impl<T> Clone for RequiredNode<T> {
    fn clone(&self) -> Self {
        Self {
            field_name: self.field_name.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T> Debug for RequiredNode<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same layout the derive produced; `PhantomData<T>: Debug` holds for every `T`.
        f.debug_struct("RequiredNode")
            .field("field_name", &self.field_name)
            .field("_marker", &self._marker)
            .finish()
    }
}
24
25#[async_trait]
26impl<T> Transition<Option<T>, T> for RequiredNode<T>
27where
28 T: Send + Sync + 'static,
29{
30 type Error = String;
31 type Resources = ();
32
33 async fn run(
34 &self,
35 input: Option<T>,
36 _resources: &Self::Resources,
37 _bus: &mut Bus,
38 ) -> Outcome<T, Self::Error> {
39 match input {
40 Some(value) => Outcome::next(value),
41 None => Outcome::fault(format!("Required field '{}' is missing", self.field_name)),
42 }
43 }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct RangeValidator<T> {
49 pub min: T,
50 pub max: T,
51 pub field_name: String,
52}
53
54impl<T> RangeValidator<T> {
55 pub fn new(min: T, max: T, field_name: impl Into<String>) -> Self {
56 Self {
57 min,
58 max,
59 field_name: field_name.into(),
60 }
61 }
62}
63
64#[async_trait]
65impl<T> Transition<T, T> for RangeValidator<T>
66where
67 T: PartialOrd + Debug + Send + Sync + Clone + 'static,
68{
69 type Error = String;
70 type Resources = ();
71
72 async fn run(
73 &self,
74 input: T,
75 _resources: &Self::Resources,
76 _bus: &mut Bus,
77 ) -> Outcome<T, Self::Error> {
78 if input < self.min || input > self.max {
79 Outcome::fault(format!(
80 "Field '{}' value {:?} out of range [{:?}, {:?}]",
81 self.field_name, input, self.min, self.max
82 ))
83 } else {
84 Outcome::next(input)
85 }
86 }
87}
88
/// Transition that checks a string field against a simple textual pattern.
///
/// Matching is a plain substring test; the pattern `"*"` accepts any input.
#[derive(Debug, Clone)]
pub struct PatternValidator {
    /// Substring the input must contain (or `"*"` for "anything").
    pattern: String,
    /// Name of the validated field, reported in fault messages.
    field_name: String,
}

impl PatternValidator {
    /// Creates a validator requiring `pattern` to occur in the `field_name` value.
    pub fn new(pattern: impl Into<String>, field_name: impl Into<String>) -> Self {
        let pattern: String = pattern.into();
        let field_name: String = field_name.into();
        Self { pattern, field_name }
    }
}
104
105#[async_trait]
106impl Transition<String, String> for PatternValidator {
107 type Error = String;
108 type Resources = ();
109
110 async fn run(
111 &self,
112 input: String,
113 _resources: &Self::Resources,
114 _bus: &mut Bus,
115 ) -> Outcome<String, Self::Error> {
116 if input.contains(&self.pattern) || self.pattern == "*" {
119 Outcome::next(input)
120 } else {
121 Outcome::fault(format!(
122 "Field '{}' value '{}' does not match pattern '{}'",
123 self.field_name, input, self.pattern
124 ))
125 }
126 }
127}
128
/// Transition that requires a JSON object to contain a fixed set of keys.
#[derive(Debug, Clone)]
pub struct SchemaValidator {
    /// Keys that must be present on the input object, checked in this order.
    required_fields: Vec<String>,
}

impl SchemaValidator {
    /// Creates a validator that requires every key listed in `required_fields`.
    pub fn new(required_fields: Vec<String>) -> Self {
        SchemaValidator { required_fields }
    }
}
140
141#[async_trait]
142impl Transition<serde_json::Value, serde_json::Value> for SchemaValidator {
143 type Error = String;
144 type Resources = ();
145
146 async fn run(
147 &self,
148 input: serde_json::Value,
149 _resources: &Self::Resources,
150 _bus: &mut Bus,
151 ) -> Outcome<serde_json::Value, Self::Error> {
152 if let serde_json::Value::Object(ref map) = input {
153 for field in &self.required_fields {
154 if !map.contains_key(field) {
155 return Outcome::fault(format!("Missing required field: '{field}'"));
156 }
157 }
158 Outcome::next(input)
159 } else {
160 Outcome::fault("Expected JSON object".to_string())
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn required_node_passes_some() {
        let node = RequiredNode::<i32>::new("count");
        let mut bus = Bus::new();
        let result = node.run(Some(42), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(42)));
    }

    #[tokio::test]
    async fn required_node_faults_none() {
        let node = RequiredNode::<i32>::new("count");
        let mut bus = Bus::new();
        let result = node.run(None, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    // The fault message must identify which field was missing.
    #[tokio::test]
    async fn required_node_fault_names_field() {
        let node = RequiredNode::<i32>::new("count");
        let mut bus = Bus::new();
        let result = node.run(None, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(ref msg) if msg.contains("'count'")));
    }

    #[tokio::test]
    async fn range_validator_in_range() {
        let node = RangeValidator::new(1, 100, "age");
        let mut bus = Bus::new();
        let result = node.run(25, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(25)));
    }

    #[tokio::test]
    async fn range_validator_out_of_range() {
        let node = RangeValidator::new(1, 100, "age");
        let mut bus = Bus::new();
        let result = node.run(200, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    // Both bounds are inclusive.
    #[tokio::test]
    async fn range_validator_bounds_are_inclusive() {
        let node = RangeValidator::new(1, 100, "age");
        let mut bus = Bus::new();
        assert!(matches!(node.run(1, &(), &mut bus).await, Outcome::Next(1)));
        assert!(matches!(node.run(100, &(), &mut bus).await, Outcome::Next(100)));
    }

    #[tokio::test]
    async fn range_validator_below_min() {
        let node = RangeValidator::new(1, 100, "age");
        let mut bus = Bus::new();
        let result = node.run(0, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn pattern_validator_matches() {
        let node = PatternValidator::new("@", "email");
        let mut bus = Bus::new();
        let result = node.run("user@example.com".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn pattern_validator_no_match() {
        let node = PatternValidator::new("@", "email");
        let mut bus = Bus::new();
        let result = node.run("invalid-email".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    // `"*"` is a wildcard: it accepts input that does not contain a literal `*`.
    #[tokio::test]
    async fn pattern_validator_wildcard_accepts_anything() {
        let node = PatternValidator::new("*", "any");
        let mut bus = Bus::new();
        let result = node.run("no star here".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn schema_validator_passes() {
        let node = SchemaValidator::new(vec!["name".into(), "age".into()]);
        let mut bus = Bus::new();
        let input = serde_json::json!({"name": "Alice", "age": 30});
        let result = node.run(input, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn schema_validator_missing_field() {
        let node = SchemaValidator::new(vec!["name".into(), "age".into()]);
        let mut bus = Bus::new();
        let input = serde_json::json!({"name": "Alice"});
        let result = node.run(input, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    // Non-object JSON (here an array) is rejected outright.
    #[tokio::test]
    async fn schema_validator_rejects_non_object() {
        let node = SchemaValidator::new(vec!["name".into()]);
        let mut bus = Bus::new();
        let input = serde_json::json!(["name"]);
        let result = node.run(input, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }
}