// gbrt_rs/utils/validation.rs

//! Data and parameter validation utilities for gradient boosting models.
//!
//! This module provides validation functions for:
//! - Input data quality (finite values, dimensional consistency)
//! - Parameter constraints (ranges, positivity)
//! - Model hyperparameter validation
//!
//! # Design
//!
//! The module offers both **struct-based validators** (including a builder-style
//! API on [`ParameterValidator`] for registering reusable constraints) and
//! **standalone functions** for quick one-off checks.
//! All validation functions return [`ValidationResult<T>`], providing detailed
//! error messages when validation fails.
//!
//! # Key Components
//!
//! - [`DataValidator`]: Validates data arrays, features, and targets
//! - [`ParameterValidator`]: Validates model parameters against constraints
//! - Standalone functions: Lightweight validation helpers for common checks
//!
//! # Error Handling
//!
//! Validation failures return [`ValidationError`], whose variants describe the
//! exact nature of the failure, enabling precise error reporting.
25
26use std::collections::HashMap;
27use thiserror::Error;
28
29/// Errors that can occur during validation.
30///
31/// Each variant provides detailed information about the specific validation
32/// failure, including parameter names, expected constraints, and actual values.
33#[derive(Error, Debug)]
34pub enum ValidationError {
35    /// General validation failure with a descriptive message.
36    #[error("Validation failed: {0}")]
37    ValidationFailed(String),
38    
39    /// Parameter value is outside the allowed range.
40    #[error("Invalid parameter: {name} = {value} (expected: {constraint})")]
41    InvalidParameter {
42        /// Name of the parameter that failed validation.
43        name: String,
44        /// Actual value provided.
45        value: f64,
46        /// Expected constraint or range.
47        constraint: String,
48    },
49    
50    /// Array dimensions don't match expected values.
51    #[error("Dimension mismatch: expected {expected}, got {actual}")]
52    DimensionMismatch {
53        /// Expected dimension or size.
54        expected: String,
55        /// Actual dimension or size found.
56        actual: String,
57    },
58    
59    /// Input data is empty.
60    #[error("Empty data")]
61    EmptyData,
62    
63    /// Non-finite value (NaN, Inf, -Inf) found in data.
64    #[error("Non-finite value found")]
65    NonFiniteValue,
66    
67    /// Probability value is outside [0, 1] range. 
68    #[error("Invalid probability: {value} (must be in [0, 1])")]
69    /// The invalid probability value.
70    InvalidProbability { value: f64 },
71}
72
73/// Result type for validation operations.
74pub type ValidationResult<T> = std::result::Result<T, ValidationError>;
75
/// Validator for input data integrity.
///
/// A stateless, zero-sized handle whose methods check data quality: finiteness,
/// positivity, probability ranges, dimensional consistency and target validity.
/// Each method delegates to the standalone function of the same name in this
/// module.
#[derive(Debug, Clone, Copy)]
pub struct DataValidator;
81
82impl DataValidator {
83    /// Creates a new data validator.
84    pub fn new() -> Self {
85        Self
86    }
87    
88    /// Checks that all values in a slice are finite (not NaN, Inf, or -Inf).
89    ///
90    /// # Arguments
91    ///
92    /// * `values` - Slice of values to check
93    ///
94    /// # Returns
95    ///
96    /// `Ok(())` if all values are finite, `Err(ValidationError::NonFiniteValue)` otherwise
97    pub fn check_finite(&self, values: &[f64]) -> ValidationResult<()> {
98        check_finite(values)
99    }
100    
101    /// Checks that all values are positive (> 0.0).
102    ///
103    /// # Arguments
104    ///
105    /// * `values` - Slice of values to check
106    ///
107    /// # Returns
108    ///
109    /// `Ok(())` if all values are positive, `Err` otherwise 
110    pub fn check_positive(&self, values: &[f64]) -> ValidationResult<()> {
111        check_positive(values)
112    }
113    
114    /// Checks that all values are valid probabilities (in [0, 1]).
115    ///
116    /// # Arguments
117    ///
118    /// * `values` - Slice of values to check
119    ///
120    /// # Returns
121    ///
122    /// `Ok(())` if all values are valid probabilities, `Err` otherwise 
123    pub fn check_probability(&self, values: &[f64]) -> ValidationResult<()> {
124        check_probability(values)
125    }
126    
127    /// Validates feature matrix dimensions and optional feature names.
128    ///
129    /// # Arguments
130    ///
131    /// * `n_samples` - Number of samples (must be > 0)
132    /// * `n_features` - Number of features (must be > 0)
133    /// * `feature_names` - Optional slice of feature names (length must match n_features)
134    ///
135    /// # Returns
136    ///
137    /// `Ok(())` if dimensions are valid, `Err` otherwise 
138    pub fn validate_features(
139        &self, 
140        n_samples: usize, 
141        n_features: usize,
142        feature_names: Option<&[String]>
143    ) -> ValidationResult<()> {
144        validate_features(n_samples, n_features, feature_names)
145    }
146    
147    /// Validates target values for regression or classification.
148    ///
149    /// # Arguments
150    ///
151    /// * `targets` - Target values to validate
152    /// * `is_classification` - Whether this is a classification problem
153    ///
154    /// # Returns
155    ///
156    /// `Ok(())` if targets are valid, `Err` otherwise
157    ///
158    /// # Behavior
159    ///
160    /// - For classification: targets must be 0.0 or 1.0
161    /// - For regression: targets must be finite (no other constraints) 
162    pub fn validate_targets(&self, targets: &[f64], is_classification: bool) -> ValidationResult<()> {
163        validate_targets(targets, is_classification)
164    }
165    
166    /// Asserts that two slices have the same length.
167    ///
168    /// # Arguments
169    ///
170    /// * `a` - First slice
171    /// * `b` - Second slice
172    ///
173    /// # Returns
174    ///
175    /// `Ok(())` if lengths match, `Err(ValidationError::DimensionMismatch)` otherwise 
176    pub fn assert_same_length(&self, a: &[f64], b: &[f64]) -> ValidationResult<()> {
177        assert_same_length(a, b)
178    }
179    
180    /// Asserts that a slice is not empty.
181    ///
182    /// # Arguments
183    ///
184    /// * `values` - Slice to check
185    ///
186    /// # Returns
187    ///
188    /// `Ok(())` if non-empty, `Err(ValidationError::EmptyData)` otherwise
189    pub fn assert_non_empty(&self, values: &[f64]) -> ValidationResult<()> {
190        assert_non_empty(values)
191    }
192}
193
194impl Default for DataValidator {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
/// Validator for model parameters with configurable constraints.
///
/// Maintains a map from parameter name to an inclusive `(min, max)` range,
/// enabling reusable validation logic for hyperparameters. Build one with
/// [`ParameterValidator::new`] followed by chained
/// [`add_constraint`](ParameterValidator::add_constraint) calls.
#[derive(Debug, Clone)]
pub struct ParameterValidator {
    /// Parameter name -> inclusive `(min, max)` range.
    constraints: HashMap<String, (f64, f64)>,
}
208
209impl ParameterValidator {
210    /// Creates a new parameter validator with no constraints.
211    pub fn new() -> Self {
212        Self {
213            constraints: HashMap::new(),
214        }
215    }
216    
217    /// Adds a parameter range constraint (builder pattern).
218    ///
219    /// # Arguments
220    ///
221    /// * `name` - Parameter name
222    /// * `min` - Minimum allowed value (inclusive)
223    /// * `max` - Maximum allowed value (inclusive)
224    ///
225    /// # Returns
226    ///
227    /// Self with the new constraint added
228    pub fn add_constraint(mut self, name: &str, min: f64, max: f64) -> Self {
229        self.constraints.insert(name.to_string(), (min, max));
230        self
231    }
232    
233    /// Validates a single parameter against its constraints.
234    ///
235    /// # Arguments
236    ///
237    /// * `name` - Parameter name
238    /// * `value` - Parameter value to validate
239    ///
240    /// # Returns
241    ///
242    /// `Ok(())` if valid or no constraint exists, `Err` if out of range 
243    pub fn validate_parameter(&self, name: &str, value: f64) -> ValidationResult<()> {
244        if let Some(&(min, max)) = self.constraints.get(name) {
245            if value < min || value > max {
246                return Err(ValidationError::InvalidParameter {
247                    name: name.to_string(),
248                    value,
249                    constraint: format!("[{}, {}]", min, max),
250                });
251            }
252        }
253        Ok(())
254    }
255    
256    /// Validates multiple parameters against all registered constraints.
257    ///
258    /// # Arguments
259    ///
260    /// * `params` - HashMap of parameter name → value
261    ///
262    /// # Returns
263    ///
264    /// `Ok(())` if all parameters satisfy constraints, `Err` otherwise 
265    pub fn validate_parameters(&self, params: &HashMap<String, f64>) -> ValidationResult<()> {
266        for (name, &value) in params {
267            self.validate_parameter(name, value)?;
268        }
269        Ok(())
270    }
271    
272    /// Validates common gradient boosting hyperparameters.
273    ///
274    /// Checks all parameters against sensible ranges for a typical GBM model.
275    ///
276    /// # Arguments
277    ///
278    /// * `n_estimators` - Number of trees (must be > 0)
279    /// * `learning_rate` - Step size (must be in (0, 1])
280    /// * `max_depth` - Tree depth (must be > 0)
281    /// * `subsample` - Subsampling rate (must be in [0, 1])
282    ///
283    /// # Returns
284    ///
285    /// `Ok(())` if all parameters are valid 
286    pub fn validate_model_params(
287        &self,
288        n_estimators: usize,
289        learning_rate: f64,
290        max_depth: usize,
291        subsample: f64,
292    ) -> ValidationResult<()> {
293        let mut params = HashMap::new();
294        params.insert("n_estimators".to_string(), n_estimators as f64);
295        params.insert("learning_rate".to_string(), learning_rate);
296        params.insert("max_depth".to_string(), max_depth as f64);
297        params.insert("subsample".to_string(), subsample);
298        
299        let validator = ParameterValidator::new()
300            .add_constraint("n_estimators", 1.0, 10000.0)
301            .add_constraint("learning_rate", 1e-10, 1.0)
302            .add_constraint("max_depth", 1.0, 100.0)
303            .add_constraint("subsample", 0.0, 1.0);
304        
305        validator.validate_parameters(&params)
306    }
307}
308
309impl Default for ParameterValidator {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
// Standalone validation functions
316
317/// Checks that all values in a slice are finite (not NaN, Inf, or -Inf).
318///
319/// # Arguments
320///
321/// * `values` - Slice to check
322///
323/// # Returns
324///
325/// `Ok(())` if all values are finite
326pub fn check_finite(values: &[f64]) -> ValidationResult<()> {
327    if values.iter().any(|&x| !x.is_finite()) {
328        return Err(ValidationError::NonFiniteValue);
329    }
330    Ok(())
331}
332
333/// Checks that all values are positive (> 0.0).
334///
335/// # Arguments
336///
337/// * `values` - Slice to check
338///
339/// # Returns
340///
341/// `Ok(())` if all values are positive
342pub fn check_positive(values: &[f64]) -> ValidationResult<()> {
343    if values.iter().any(|&x| x <= 0.0) {
344        return Err(ValidationError::ValidationFailed(
345            "All values must be positive".to_string()
346        ));
347    }
348    Ok(())
349}
350
351/// Checks that all values are valid probabilities (in [0, 1]).
352///
353/// # Arguments
354///
355/// * `values` - Slice to check
356///
357/// # Returns
358///
359/// `Ok(())` if all values are probabilities
360pub fn check_probability(values: &[f64]) -> ValidationResult<()> {
361    for &value in values {
362        if value < 0.0 || value > 1.0 {
363            return Err(ValidationError::InvalidProbability { value });
364        }
365    }
366    Ok(())
367}
368
369/// Validates feature matrix dimensions and optional feature names.
370///
371/// # Arguments
372///
373/// * `n_samples` - Number of samples (must be > 0)
374/// * `n_features` - Number of features (must be > 0)
375/// * `feature_names` - Optional feature names (length must match n_features if provided)
376///
377/// # Returns
378///
379/// `Ok(())` if dimensions are valid
380pub fn validate_features(
381    n_samples: usize, 
382    n_features: usize,
383    feature_names: Option<&[String]>
384) -> ValidationResult<()> {
385    if n_samples == 0 {
386        return Err(ValidationError::EmptyData);
387    }
388    
389    if n_features == 0 {
390        return Err(ValidationError::ValidationFailed(
391            "Number of features must be positive".to_string()
392        ));
393    }
394    
395    if let Some(names) = feature_names {
396        if names.len() != n_features {
397            return Err(ValidationError::DimensionMismatch {
398                expected: n_features.to_string(),
399                actual: names.len().to_string(),
400            });
401        }
402    }
403    
404    Ok(())
405}
406
407/// Validates target values for regression or classification.
408///
409/// For classification, ensures targets are binary (0 or 1). For regression,
410/// ensures all values are finite.
411///
412/// # Arguments
413///
414/// * `targets` - Target values
415/// * `is_classification` - Whether this is a classification problem
416///
417/// # Returns
418///
419/// `Ok(())` if targets are valid
420pub fn validate_targets(targets: &[f64], is_classification: bool) -> ValidationResult<()> {
421    if targets.is_empty() {
422        return Err(ValidationError::EmptyData);
423    }
424    
425    check_finite(targets)?;
426    
427    if is_classification {
428        // For classification, targets should be 0 or 1
429        for &target in targets {
430            if target != 0.0 && target != 1.0 {
431                return Err(ValidationError::ValidationFailed(
432                    format!("Classification targets must be 0 or 1, got {}", target)
433                ));
434            }
435        }
436    }
437    
438    Ok(())
439}
440
441/// Asserts that two slices have the same length.
442///
443/// # Arguments
444///
445/// * `a` - First slice
446/// * `b` - Second slice
447///
448/// # Returns
449///
450/// `Ok(())` if lengths match
451pub fn assert_same_length(a: &[f64], b: &[f64]) -> ValidationResult<()> {
452    if a.len() != b.len() {
453        return Err(ValidationError::DimensionMismatch {
454            expected: a.len().to_string(),
455            actual: b.len().to_string(),
456        });
457    }
458    Ok(())
459}
460
461/// Asserts that a slice is not empty.
462///
463/// # Arguments
464///
465/// * `values` - Slice to check
466///
467/// # Returns
468///
469/// `Ok(())` if non-empty
470pub fn assert_non_empty(values: &[f64]) -> ValidationResult<()> {
471    if values.is_empty() {
472        return Err(ValidationError::EmptyData);
473    }
474    Ok(())
475}
476
477/// Validates that a value falls within a specified range.
478///
479/// # Arguments
480///
481/// * `value` - Value to check
482/// * `min` - Minimum allowed (inclusive)
483/// * `max` - Maximum allowed (inclusive)
484/// * `name` - Name for error messages
485///
486/// # Returns
487///
488/// `Ok(())` if value is in range
489pub fn validate_range(value: f64, min: f64, max: f64, name: &str) -> ValidationResult<()> {
490    if value < min || value > max {
491        return Err(ValidationError::InvalidParameter {
492            name: name.to_string(),
493            value,
494            constraint: format!("[{}, {}]", min, max),
495        });
496    }
497    Ok(())
498}
499
500/// Validates that a value is positive (> 0).
501///
502/// # Arguments
503///
504/// * `value` - Value to check
505/// * `name` - Name for error messages
506///
507/// # Returns
508///
509/// `Ok(())` if value is positive
510pub fn validate_positive(value: f64, name: &str) -> ValidationResult<()> {
511    if value <= 0.0 {
512        return Err(ValidationError::InvalidParameter {
513            name: name.to_string(),
514            value,
515            constraint: "> 0".to_string(),
516        });
517    }
518    Ok(())
519}
520
521/// Validates that a value is non-negative (≥ 0).
522///
523/// # Arguments
524///
525/// * `value` - Value to check
526/// * `name` - Name for error messages
527///
528/// # Returns
529///
530/// `Ok(())` if value is non-negative
531pub fn validate_non_negative(value: f64, name: &str) -> ValidationResult<()> {
532    if value < 0.0 {
533        return Err(ValidationError::InvalidParameter {
534            name: name.to_string(),
535            value,
536            constraint: ">= 0".to_string(),
537        });
538    }
539    Ok(())
540}
541
542/// Validates common gradient boosting model hyperparameters.
543///
544/// Checks all parameters against sensible ranges for typical GBM models.
545///
546/// # Arguments
547///
548/// * `n_estimators` - Number of trees
549/// * `learning_rate` - Step size
550/// * `max_depth` - Maximum tree depth
551/// * `subsample` - Subsampling rate
552///
553/// # Returns
554///
555/// `Ok(())` if all parameters are valid
556pub fn validate_model_params(
557    n_estimators: usize,
558    learning_rate: f64,
559    max_depth: usize,
560    subsample: f64,
561) -> ValidationResult<()> {
562    if n_estimators == 0 {
563        return Err(ValidationError::ValidationFailed(
564            "n_estimators must be positive".to_string()
565        ));
566    }
567    
568    validate_range(learning_rate, 1e-10, 1.0, "learning_rate")?;
569    
570    if max_depth == 0 {
571        return Err(ValidationError::ValidationFailed(
572            "max_depth must be positive".to_string()
573        ));
574    }
575    
576    validate_range(subsample, 0.0, 1.0, "subsample")?;
577    
578    Ok(())
579}
580