Skip to main content

spider_pipeline/
validation.rs

//! Item pipeline for validating scraped items.
//!
//! This module provides [`ValidationPipeline`], a configurable pipeline that
//! validates items using declarative field rules and custom validator closures.
5
6use crate::pipeline::Pipeline;
7use async_trait::async_trait;
8use log::{debug, warn};
9use serde_json::Value;
10use spider_util::{error::PipelineError, item::ScrapedItem};
11use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15type ValidatorFn<I> = dyn Fn(&I, &Value) -> Result<(), String> + Send + Sync + 'static;
16
/// JSON value type matcher used by `ValidationRule::Type` for field validation.
///
/// Each variant corresponds to one of the six JSON value kinds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JsonType {
    /// JSON `null`.
    Null,
    /// JSON `true` or `false`.
    Bool,
    /// Any JSON number (integer or floating point).
    Number,
    /// JSON string.
    String,
    /// JSON array.
    Array,
    /// JSON object.
    Object,
}
27
28/// Declarative rules for validating fields in an item.
29#[derive(Debug, Clone)]
30pub enum ValidationRule {
31    Required,
32    NonEmptyString,
33    Type(JsonType),
34    MinLen(usize),
35    MaxLen(usize),
36    MinNumber(f64),
37    MaxNumber(f64),
38}
39
40/// A pipeline that validates items and drops invalid entries.
41pub struct ValidationPipeline<I: ScrapedItem> {
42    rules: HashMap<String, Vec<ValidationRule>>,
43    validators: Vec<Arc<ValidatorFn<I>>>,
44    _phantom: PhantomData<I>,
45}
46
47impl<I: ScrapedItem> ValidationPipeline<I> {
48    /// Creates a new empty `ValidationPipeline`.
49    pub fn new() -> Self {
50        Self {
51            rules: HashMap::new(),
52            validators: Vec::new(),
53            _phantom: PhantomData,
54        }
55    }
56
57    /// Adds a field rule for the given top-level field name.
58    pub fn with_rule(mut self, field: impl Into<String>, rule: ValidationRule) -> Self {
59        self.rules.entry(field.into()).or_default().push(rule);
60        self
61    }
62
63    /// Adds a custom validator closure.
64    pub fn with_validator<F>(mut self, validator: F) -> Self
65    where
66        F: Fn(&I, &Value) -> Result<(), String> + Send + Sync + 'static,
67    {
68        self.validators.push(Arc::new(validator));
69        self
70    }
71
72    fn validate_type(value: &Value, expected: &JsonType) -> bool {
73        match expected {
74            JsonType::Null => value.is_null(),
75            JsonType::Bool => value.is_boolean(),
76            JsonType::Number => value.is_number(),
77            JsonType::String => value.is_string(),
78            JsonType::Array => value.is_array(),
79            JsonType::Object => value.is_object(),
80        }
81    }
82
83    fn validate_item(&self, json: &Value) -> Result<(), String> {
84        let map = json
85            .as_object()
86            .ok_or_else(|| "Item must be a JSON object for validation.".to_string())?;
87
88        for (field, rules) in &self.rules {
89            let value = map.get(field);
90            for rule in rules {
91                match rule {
92                    ValidationRule::Required => {
93                        if value.is_none() {
94                            return Err(format!("Missing required field '{}'.", field));
95                        }
96                    }
97                    ValidationRule::NonEmptyString => {
98                        if let Some(v) = value {
99                            match v.as_str() {
100                                Some(s) if !s.trim().is_empty() => {}
101                                Some(_) => {
102                                    return Err(format!(
103                                        "Field '{}' must be a non-empty string.",
104                                        field
105                                    ));
106                                }
107                                None => {
108                                    return Err(format!("Field '{}' must be a string.", field));
109                                }
110                            }
111                        }
112                    }
113                    ValidationRule::Type(expected) => {
114                        if let Some(v) = value
115                            && !Self::validate_type(v, expected)
116                        {
117                            return Err(format!(
118                                "Field '{}' has invalid type. Expected {:?}.",
119                                field, expected
120                            ));
121                        }
122                    }
123                    ValidationRule::MinLen(min) => {
124                        if let Some(v) = value {
125                            if let Some(s) = v.as_str() {
126                                if s.len() < *min {
127                                    return Err(format!(
128                                        "Field '{}' length {} is less than {}.",
129                                        field,
130                                        s.len(),
131                                        min
132                                    ));
133                                }
134                            } else if let Some(arr) = v.as_array() {
135                                if arr.len() < *min {
136                                    return Err(format!(
137                                        "Field '{}' array length {} is less than {}.",
138                                        field,
139                                        arr.len(),
140                                        min
141                                    ));
142                                }
143                            } else {
144                                return Err(format!(
145                                    "Field '{}' must be string or array for MinLen.",
146                                    field
147                                ));
148                            }
149                        }
150                    }
151                    ValidationRule::MaxLen(max) => {
152                        if let Some(v) = value {
153                            if let Some(s) = v.as_str() {
154                                if s.len() > *max {
155                                    return Err(format!(
156                                        "Field '{}' length {} is greater than {}.",
157                                        field,
158                                        s.len(),
159                                        max
160                                    ));
161                                }
162                            } else if let Some(arr) = v.as_array() {
163                                if arr.len() > *max {
164                                    return Err(format!(
165                                        "Field '{}' array length {} is greater than {}.",
166                                        field,
167                                        arr.len(),
168                                        max
169                                    ));
170                                }
171                            } else {
172                                return Err(format!(
173                                    "Field '{}' must be string or array for MaxLen.",
174                                    field
175                                ));
176                            }
177                        }
178                    }
179                    ValidationRule::MinNumber(min) => {
180                        if let Some(v) = value {
181                            let num = v.as_f64().ok_or_else(|| {
182                                format!("Field '{}' must be numeric for MinNumber.", field)
183                            })?;
184                            if num < *min {
185                                return Err(format!(
186                                    "Field '{}' number {} is less than {}.",
187                                    field, num, min
188                                ));
189                            }
190                        }
191                    }
192                    ValidationRule::MaxNumber(max) => {
193                        if let Some(v) = value {
194                            let num = v.as_f64().ok_or_else(|| {
195                                format!("Field '{}' must be numeric for MaxNumber.", field)
196                            })?;
197                            if num > *max {
198                                return Err(format!(
199                                    "Field '{}' number {} is greater than {}.",
200                                    field, num, max
201                                ));
202                            }
203                        }
204                    }
205                }
206            }
207        }
208
209        Ok(())
210    }
211}
212
213impl<I: ScrapedItem> Default for ValidationPipeline<I> {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
219#[async_trait]
220impl<I: ScrapedItem> Pipeline<I> for ValidationPipeline<I> {
221    fn name(&self) -> &str {
222        "ValidationPipeline"
223    }
224
225    async fn process_item(&self, item: I) -> Result<Option<I>, PipelineError> {
226        debug!("ValidationPipeline processing item.");
227        let json = item.to_json_value();
228
229        if let Err(err) = self.validate_item(&json) {
230            warn!("Validation failed, dropping item: {}", err);
231            return Ok(None);
232        }
233
234        for validator in &self.validators {
235            if let Err(err) = validator(&item, &json) {
236                warn!("Custom validation failed, dropping item: {}", err);
237                return Ok(None);
238            }
239        }
240
241        Ok(Some(item))
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use spider_util::item::ScrapedItem;
    use std::any::Any;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestItem {
        title: String,
        price: f64,
    }

    impl ScrapedItem for TestItem {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn box_clone(&self) -> Box<dyn ScrapedItem + Send + Sync> {
            Box::new(self.clone())
        }

        fn to_json_value(&self) -> Value {
            serde_json::to_value(self).expect("serialize test item")
        }
    }

    /// Builds a `TestItem` with the given title and price.
    fn item(title: &str, price: f64) -> TestItem {
        TestItem {
            title: title.to_string(),
            price,
        }
    }

    /// Runs `pipeline` on `item`, panicking if the pipeline itself errors.
    async fn run(pipeline: &ValidationPipeline<TestItem>, item: TestItem) -> Option<TestItem> {
        pipeline
            .process_item(item)
            .await
            .expect("pipeline should not fail")
    }

    #[tokio::test]
    async fn passes_valid_item() {
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("title", ValidationRule::Required)
            .with_rule("title", ValidationRule::NonEmptyString)
            .with_rule("price", ValidationRule::MinNumber(1.0))
            .with_rule("price", ValidationRule::MaxNumber(100.0));

        assert!(run(&pipeline, item("Book", 20.0)).await.is_some());
    }

    #[tokio::test]
    async fn drops_missing_required_field() {
        let pipeline =
            ValidationPipeline::<TestItem>::new().with_rule("missing", ValidationRule::Required);

        assert!(run(&pipeline, item("Book", 20.0)).await.is_none());
    }

    #[tokio::test]
    async fn drops_on_custom_validator_error() {
        let pipeline =
            ValidationPipeline::<TestItem>::new().with_validator(|_item, json| {
                match json.get("title").and_then(Value::as_str) {
                    Some("Book") => Ok(()),
                    _ => Err("title mismatch".to_string()),
                }
            });

        assert!(run(&pipeline, item("Other", 20.0)).await.is_none());
    }

    #[tokio::test]
    async fn drops_on_invalid_type_rule() {
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("title", ValidationRule::Type(JsonType::Number));

        assert!(run(&pipeline, item("Book", 20.0)).await.is_none());
    }

    #[tokio::test]
    async fn handles_multiple_rules() {
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("title", ValidationRule::MinLen(2))
            .with_rule("title", ValidationRule::MaxLen(10))
            .with_validator(|_, _| Ok(()));

        let out = run(&pipeline, item("ok", 5.0)).await;
        assert_eq!(
            out.expect("item should pass").to_json_value(),
            json!({"title":"ok","price":5.0})
        );
    }

    #[tokio::test]
    async fn drops_whitespace_only_title() {
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("title", ValidationRule::NonEmptyString);

        assert!(run(&pipeline, item("   ", 5.0)).await.is_none());
    }

    #[tokio::test]
    async fn drops_number_above_max() {
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("price", ValidationRule::MaxNumber(100.0));

        assert!(run(&pipeline, item("Book", 150.0)).await.is_none());
        assert!(run(&pipeline, item("Book", 100.0)).await.is_some());
    }

    #[tokio::test]
    async fn drops_string_longer_than_max_len() {
        let pipeline =
            ValidationPipeline::<TestItem>::new().with_rule("title", ValidationRule::MaxLen(3));

        assert!(run(&pipeline, item("Books", 5.0)).await.is_none());
    }

    #[tokio::test]
    async fn drops_len_rule_on_non_string_field() {
        let pipeline =
            ValidationPipeline::<TestItem>::new().with_rule("price", ValidationRule::MinLen(1));

        assert!(run(&pipeline, item("Book", 5.0)).await.is_none());
    }

    #[tokio::test]
    async fn skips_non_required_rules_for_absent_field() {
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("subtitle", ValidationRule::MinLen(5))
            .with_rule("subtitle", ValidationRule::Type(JsonType::String))
            .with_rule("subtitle", ValidationRule::NonEmptyString);

        assert!(run(&pipeline, item("Book", 5.0)).await.is_some());
    }

    #[tokio::test]
    async fn custom_validators_skipped_when_rules_fail() {
        let calls = Arc::new(AtomicUsize::new(0));
        let seen = Arc::clone(&calls);
        let pipeline = ValidationPipeline::<TestItem>::new()
            .with_rule("missing", ValidationRule::Required)
            .with_validator(move |_, _| {
                seen.fetch_add(1, Ordering::SeqCst);
                Ok(())
            });

        assert!(run(&pipeline, item("Book", 5.0)).await.is_none());
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }
}