Skip to main content

ranvier_std/nodes/
validation.rs

1//! Validation nodes for input checking within Ranvier circuits.
2
3use async_trait::async_trait;
4use ranvier_core::{bus::Bus, outcome::Outcome, transition::Transition};
5use serde::{Deserialize, Serialize};
6use std::fmt::Debug;
7use std::marker::PhantomData;
8
/// Validates that an `Option<T>` is `Some`, faulting if `None`.
///
/// `field_name` is used to identify the field in the fault message. `T` is tracked
/// only through `PhantomData`; no `T` value is stored in the node itself.
pub struct RequiredNode<T> {
    field_name: String,
    _marker: PhantomData<T>,
}

impl<T> RequiredNode<T> {
    /// Creates a node that reports a missing value under `field_name`.
    pub fn new(field_name: impl Into<String>) -> Self {
        Self {
            field_name: field_name.into(),
            _marker: PhantomData,
        }
    }
}

// `Clone` and `Debug` are written by hand rather than derived: `#[derive]` would add
// `T: Clone` / `T: Debug` bounds even though only a `PhantomData<T>` is stored, making
// e.g. `RequiredNode<SomeNonCloneType>` impossible to clone or print. The output of the
// `Debug` impl matches what the derive produced.
impl<T> Clone for RequiredNode<T> {
    fn clone(&self) -> Self {
        Self {
            field_name: self.field_name.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for RequiredNode<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RequiredNode")
            .field("field_name", &self.field_name)
            .field("_marker", &self._marker)
            .finish()
    }
}
24
25#[async_trait]
26impl<T> Transition<Option<T>, T> for RequiredNode<T>
27where
28    T: Send + Sync + 'static,
29{
30    type Error = String;
31    type Resources = ();
32
33    async fn run(
34        &self,
35        input: Option<T>,
36        _resources: &Self::Resources,
37        _bus: &mut Bus,
38    ) -> Outcome<T, Self::Error> {
39        match input {
40            Some(value) => Outcome::next(value),
41            None => Outcome::fault(format!("Required field '{}' is missing", self.field_name)),
42        }
43    }
44}
45
46/// Validates that a numeric value falls within a specified range.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct RangeValidator<T> {
49    pub min: T,
50    pub max: T,
51    pub field_name: String,
52}
53
54impl<T> RangeValidator<T> {
55    pub fn new(min: T, max: T, field_name: impl Into<String>) -> Self {
56        Self {
57            min,
58            max,
59            field_name: field_name.into(),
60        }
61    }
62}
63
64#[async_trait]
65impl<T> Transition<T, T> for RangeValidator<T>
66where
67    T: PartialOrd + Debug + Send + Sync + Clone + 'static,
68{
69    type Error = String;
70    type Resources = ();
71
72    async fn run(
73        &self,
74        input: T,
75        _resources: &Self::Resources,
76        _bus: &mut Bus,
77    ) -> Outcome<T, Self::Error> {
78        if input < self.min || input > self.max {
79            Outcome::fault(format!(
80                "Field '{}' value {:?} out of range [{:?}, {:?}]",
81                self.field_name, input, self.min, self.max
82            ))
83        } else {
84            Outcome::next(input)
85        }
86    }
87}
88
/// Validates that a string contains a fixed substring.
///
/// Despite the name, `pattern` is **not** interpreted as a regular expression or glob.
/// An input passes when it contains `pattern` verbatim, or when `pattern` is exactly
/// `"*"`, which accepts every input. Because every string contains the empty string, an
/// empty `pattern` also accepts everything. For real regular expressions, use the
/// `regex` crate directly.
#[derive(Debug, Clone)]
pub struct PatternValidator {
    /// Substring the input must contain (or `"*"` to accept any input).
    pattern: String,
    /// Name used to identify the field in fault messages.
    field_name: String,
}

impl PatternValidator {
    /// Creates a validator requiring `pattern` to appear as a substring of the input.
    pub fn new(pattern: impl Into<String>, field_name: impl Into<String>) -> Self {
        Self {
            pattern: pattern.into(),
            field_name: field_name.into(),
        }
    }
}
104
105#[async_trait]
106impl Transition<String, String> for PatternValidator {
107    type Error = String;
108    type Resources = ();
109
110    async fn run(
111        &self,
112        input: String,
113        _resources: &Self::Resources,
114        _bus: &mut Bus,
115    ) -> Outcome<String, Self::Error> {
116        // Simple glob-like pattern matching (contains check)
117        // For full regex, users would use the `regex` crate directly
118        if input.contains(&self.pattern) || self.pattern == "*" {
119            Outcome::next(input)
120        } else {
121            Outcome::fault(format!(
122                "Field '{}' value '{}' does not match pattern '{}'",
123                self.field_name, input, self.pattern
124            ))
125        }
126    }
127}
128
/// Validates that a JSON value is an object containing a set of required top-level keys.
///
/// Only key presence is checked. The types of the values are not inspected, nested
/// objects are not descended into, and a key whose value is `null` still counts as
/// present. A non-object input (array, string, number, ...) is rejected. When several
/// keys are missing, only the first one in `required_fields` order is reported.
#[derive(Debug, Clone)]
pub struct SchemaValidator {
    /// Top-level keys that must be present, checked in this order.
    required_fields: Vec<String>,
}

impl SchemaValidator {
    /// Creates a validator requiring every key in `required_fields`.
    ///
    /// An empty list accepts any JSON object (but still rejects non-objects).
    pub fn new(required_fields: Vec<String>) -> Self {
        Self { required_fields }
    }
}
140
141#[async_trait]
142impl Transition<serde_json::Value, serde_json::Value> for SchemaValidator {
143    type Error = String;
144    type Resources = ();
145
146    async fn run(
147        &self,
148        input: serde_json::Value,
149        _resources: &Self::Resources,
150        _bus: &mut Bus,
151    ) -> Outcome<serde_json::Value, Self::Error> {
152        if let serde_json::Value::Object(ref map) = input {
153            for field in &self.required_fields {
154                if !map.contains_key(field) {
155                    return Outcome::fault(format!("Missing required field: '{field}'"));
156                }
157            }
158            Outcome::next(input)
159        } else {
160            Outcome::fault("Expected JSON object".to_string())
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    //! Behavioural tests for the validation nodes, including boundary cases
    //! (inclusive range ends, wildcard pattern, non-object JSON).

    use super::*;

    #[tokio::test]
    async fn required_node_passes_some() {
        let node = RequiredNode::<i32>::new("count");
        let mut bus = Bus::new();
        let result = node.run(Some(42), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(42)));
    }

    #[tokio::test]
    async fn required_node_faults_none() {
        let node = RequiredNode::<i32>::new("count");
        let mut bus = Bus::new();
        let result = node.run(None, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn required_node_fault_names_the_field() {
        let node = RequiredNode::<i32>::new("count");
        let mut bus = Bus::new();
        let result = node.run(None, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(ref msg) if msg.contains("'count'")));
    }

    #[tokio::test]
    async fn range_validator_in_range() {
        let node = RangeValidator::new(1, 100, "age");
        let mut bus = Bus::new();
        let result = node.run(25, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(25)));
    }

    #[tokio::test]
    async fn range_validator_bounds_are_inclusive() {
        let node = RangeValidator::new(1, 100, "age");
        let mut bus = Bus::new();
        let at_min = node.run(1, &(), &mut bus).await;
        assert!(matches!(at_min, Outcome::Next(1)));
        let at_max = node.run(100, &(), &mut bus).await;
        assert!(matches!(at_max, Outcome::Next(100)));
    }

    #[tokio::test]
    async fn range_validator_out_of_range() {
        let node = RangeValidator::new(1, 100, "age");
        let mut bus = Bus::new();
        let result = node.run(200, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn range_validator_below_min() {
        let node = RangeValidator::new(1, 100, "age");
        let mut bus = Bus::new();
        let result = node.run(0, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn pattern_validator_matches() {
        let node = PatternValidator::new("@", "email");
        let mut bus = Bus::new();
        let result = node.run("user@example.com".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn pattern_validator_no_match() {
        let node = PatternValidator::new("@", "email");
        let mut bus = Bus::new();
        let result = node.run("invalid-email".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn pattern_validator_wildcard_accepts_anything() {
        let node = PatternValidator::new("*", "name");
        let mut bus = Bus::new();
        let result = node.run("no asterisk here".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn schema_validator_passes() {
        let node = SchemaValidator::new(vec!["name".into(), "age".into()]);
        let mut bus = Bus::new();
        let input = serde_json::json!({"name": "Alice", "age": 30});
        let result = node.run(input, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn schema_validator_missing_field() {
        let node = SchemaValidator::new(vec!["name".into(), "age".into()]);
        let mut bus = Bus::new();
        let input = serde_json::json!({"name": "Alice"});
        let result = node.run(input, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn schema_validator_rejects_non_object() {
        let node = SchemaValidator::new(vec!["name".into()]);
        let mut bus = Bus::new();
        let input = serde_json::json!(["name"]);
        let result = node.run(input, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }
}