// Source file: aspect_std/validation.rs

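//! Validation aspect for pre-condition checking.
//!
//! Rules run before the wrapped function executes; validation after the
//! function has run is not implemented.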
2
3use aspect_core::{Aspect, AspectError, JoinPoint, ProceedingJoinPoint};
4use std::any::Any;
5use std::sync::Arc;
6
7/// Validation rule trait.
8///
9/// Implement this trait to create custom validation rules that can be
10/// composed and applied to functions.
11pub trait ValidationRule: Send + Sync {
12    /// Validate the input.
13    ///
14    /// Returns `Ok(())` if validation passes, or `Err(message)` if it fails.
15    fn validate(&self, ctx: &JoinPoint) -> Result<(), String>;
16
17    /// Get a description of this validation rule.
18    fn description(&self) -> &str {
19        "validation rule"
20    }
21}
22
23/// Validation aspect for enforcing constraints.
24///
25/// Allows composing multiple validation rules that are checked before
26/// function execution.
27///
28/// # Example
29///
30/// ```rust,ignore
31/// use aspect_std::{ValidationAspect, ValidationRule};
32/// use aspect_macros::aspect;
33///
34/// struct AgeValidator;
35/// impl ValidationRule for AgeValidator {
36///     fn validate(&self, ctx: &JoinPoint) -> Result<(), String> {
37///         // Validation logic
38///         Ok(())
39///     }
40/// }
41///
42/// let validator = ValidationAspect::new()
43///     .add_rule(Box::new(AgeValidator));
44///
45/// #[aspect(validator)]
46/// fn set_age(age: i32) -> Result<(), String> {
47///     Ok(())
48/// }
49/// ```
50pub struct ValidationAspect {
51    rules: Vec<Box<dyn ValidationRule>>,
52}
53
54impl ValidationAspect {
55    /// Create a new validation aspect.
56    pub fn new() -> Self {
57        Self { rules: Vec::new() }
58    }
59
60    /// Add a validation rule.
61    pub fn add_rule(mut self, rule: Box<dyn ValidationRule>) -> Self {
62        self.rules.push(rule);
63        self
64    }
65
66    /// Run all validation rules.
67    fn validate(&self, ctx: &JoinPoint) -> Result<(), AspectError> {
68        for rule in self.rules.iter() {
69            if let Err(msg) = rule.validate(ctx) {
70                return Err(AspectError::execution(format!(
71                    "Validation failed for {}: {}",
72                    ctx.function_name, msg
73                )));
74            }
75        }
76        Ok(())
77    }
78}
79
80impl Default for ValidationAspect {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl Aspect for ValidationAspect {
87    fn around(&self, pjp: ProceedingJoinPoint) -> Result<Box<dyn Any>, AspectError> {
88        // Validate before execution
89        self.validate(pjp.context())?;
90
91        // Execute the function
92        let result = pjp.proceed();
93
94        // Note: After-validation is not supported in this simplified version
95        // as it would require cloning the context
96        result
97    }
98}
99
// Common, ready-made validation rules
101
102/// Validates that a value is not empty.
103pub struct NotEmptyValidator {
104    field_name: String,
105    getter: Arc<dyn Fn(&JoinPoint) -> Option<String> + Send + Sync>,
106}
107
108impl NotEmptyValidator {
109    /// Create a new not-empty validator.
110    ///
111    /// # Arguments
112    /// * `field_name` - Name of the field being validated
113    /// * `getter` - Function to extract the value from JoinPoint
114    pub fn new<F>(field_name: &str, getter: F) -> Self
115    where
116        F: Fn(&JoinPoint) -> Option<String> + Send + Sync + 'static,
117    {
118        Self {
119            field_name: field_name.to_string(),
120            getter: Arc::new(getter),
121        }
122    }
123}
124
125impl ValidationRule for NotEmptyValidator {
126    fn validate(&self, ctx: &JoinPoint) -> Result<(), String> {
127        if let Some(value) = (self.getter)(ctx) {
128            if value.is_empty() {
129                return Err(format!("{} cannot be empty", self.field_name));
130            }
131        }
132        Ok(())
133    }
134
135    fn description(&self) -> &str {
136        "not empty"
137    }
138}
139
140/// Validates that a numeric value is within a range.
141pub struct RangeValidator {
142    field_name: String,
143    min: i64,
144    max: i64,
145    getter: Arc<dyn Fn(&JoinPoint) -> Option<i64> + Send + Sync>,
146}
147
148impl RangeValidator {
149    /// Create a new range validator.
150    ///
151    /// # Arguments
152    /// * `field_name` - Name of the field being validated
153    /// * `min` - Minimum allowed value (inclusive)
154    /// * `max` - Maximum allowed value (inclusive)
155    /// * `getter` - Function to extract the value from JoinPoint
156    pub fn new<F>(field_name: &str, min: i64, max: i64, getter: F) -> Self
157    where
158        F: Fn(&JoinPoint) -> Option<i64> + Send + Sync + 'static,
159    {
160        Self {
161            field_name: field_name.to_string(),
162            min,
163            max,
164            getter: Arc::new(getter),
165        }
166    }
167}
168
169impl ValidationRule for RangeValidator {
170    fn validate(&self, ctx: &JoinPoint) -> Result<(), String> {
171        if let Some(value) = (self.getter)(ctx) {
172            if value < self.min || value > self.max {
173                return Err(format!(
174                    "{} must be between {} and {}, got {}",
175                    self.field_name, self.min, self.max, value
176                ));
177            }
178        }
179        Ok(())
180    }
181
182    fn description(&self) -> &str {
183        "range check"
184    }
185}
186
187/// Custom validation rule using a closure.
188pub struct CustomValidator {
189    description: String,
190    validator: Arc<dyn Fn(&JoinPoint) -> Result<(), String> + Send + Sync>,
191}
192
193impl CustomValidator {
194    /// Create a custom validator from a closure.
195    ///
196    /// # Arguments
197    /// * `description` - Description of this validation
198    /// * `validator` - Validation function
199    pub fn new<F>(description: &str, validator: F) -> Self
200    where
201        F: Fn(&JoinPoint) -> Result<(), String> + Send + Sync + 'static,
202    {
203        Self {
204            description: description.to_string(),
205            validator: Arc::new(validator),
206        }
207    }
208}
209
210impl ValidationRule for CustomValidator {
211    fn validate(&self, ctx: &JoinPoint) -> Result<(), String> {
212        (self.validator)(ctx)
213    }
214
215    fn description(&self) -> &str {
216        &self.description
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the throwaway join point shared by every test in this module.
    fn sample_ctx() -> JoinPoint {
        JoinPoint {
            function_name: "test",
            module_path: "test",
            location: aspect_core::Location {
                file: "test.rs",
                line: 1,
            },
        }
    }

    #[test]
    fn test_validation_aspect_creation() {
        let validator = ValidationAspect::new();
        assert_eq!(validator.rules.len(), 0);
    }

    #[test]
    fn test_validation_aspect_rule_evaluation() {
        let ctx = sample_ctx();

        // No rules: nothing can fail.
        assert!(ValidationAspect::new().validate(&ctx).is_ok());

        let passing =
            ValidationAspect::new().add_rule(Box::new(CustomValidator::new("ok", |_ctx| Ok(()))));
        assert_eq!(passing.rules.len(), 1);
        assert!(passing.validate(&ctx).is_ok());

        // One failing rule is enough to reject the call.
        let failing = ValidationAspect::new()
            .add_rule(Box::new(CustomValidator::new("ok", |_ctx| Ok(()))))
            .add_rule(Box::new(CustomValidator::new("bad", |_ctx| {
                Err("nope".to_string())
            })));
        assert_eq!(failing.rules.len(), 2);
        assert!(failing.validate(&ctx).is_err());
    }

    #[test]
    fn test_custom_validator() {
        let validator = CustomValidator::new("test", |_ctx| Ok(()));
        assert!(validator.validate(&sample_ctx()).is_ok());
    }

    #[test]
    fn test_custom_validator_failure() {
        let validator = CustomValidator::new("test", |_ctx| Err("validation failed".to_string()));
        assert!(validator.validate(&sample_ctx()).is_err());
    }

    #[test]
    fn test_not_empty_validator() {
        let validator = NotEmptyValidator::new("username", |_ctx| Some("alice".to_string()));
        assert!(validator.validate(&sample_ctx()).is_ok());
    }

    #[test]
    fn test_not_empty_validator_failure() {
        let validator = NotEmptyValidator::new("username", |_ctx| Some("".to_string()));

        let result = validator.validate(&sample_ctx());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("cannot be empty"));
    }

    #[test]
    fn test_not_empty_validator_missing_value() {
        // A getter that finds no value is treated as "nothing to check".
        let validator = NotEmptyValidator::new("username", |_ctx| None);
        assert!(validator.validate(&sample_ctx()).is_ok());
    }

    #[test]
    fn test_range_validator() {
        let validator = RangeValidator::new("age", 0, 120, |_ctx| Some(25));
        assert!(validator.validate(&sample_ctx()).is_ok());
    }

    #[test]
    fn test_range_validator_failure() {
        let validator = RangeValidator::new("age", 0, 120, |_ctx| Some(150));

        let result = validator.validate(&sample_ctx());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("must be between"));
    }

    #[test]
    fn test_range_validator_bounds_are_inclusive() {
        let ctx = sample_ctx();

        let at_min = RangeValidator::new("age", 0, 120, |_ctx| Some(0));
        assert!(at_min.validate(&ctx).is_ok());

        let at_max = RangeValidator::new("age", 0, 120, |_ctx| Some(120));
        assert!(at_max.validate(&ctx).is_ok());

        let below_min = RangeValidator::new("age", 0, 120, |_ctx| Some(-1));
        assert!(below_min.validate(&ctx).is_err());

        let above_max = RangeValidator::new("age", 0, 120, |_ctx| Some(121));
        assert!(above_max.validate(&ctx).is_err());
    }

    #[test]
    fn test_range_validator_missing_value() {
        // A getter that finds no value is treated as "nothing to check".
        let validator = RangeValidator::new("age", 0, 120, |_ctx| None);
        assert!(validator.validate(&sample_ctx()).is_ok());
    }

    #[test]
    fn test_rule_descriptions() {
        let custom = CustomValidator::new("my rule", |_ctx| Ok(()));
        assert_eq!(custom.description(), "my rule");

        let not_empty = NotEmptyValidator::new("username", |_ctx| None);
        assert_eq!(not_empty.description(), "not empty");

        let range = RangeValidator::new("age", 0, 120, |_ctx| None);
        assert_eq!(range.description(), "range check");
    }
}