gbrt_rs/utils/validation.rs
1//! Data and parameter validation utilities for gradient boosting models.
2//!
3//! This module provides comprehensive validation functions for:
4//! - Input data quality (finite values, dimensional consistency)
5//! - Parameter constraints (ranges, positivity)
6//! - Model hyperparameter validation
7//!
8//! # Design
9//!
10//! The module offers both **struct-based validators** with builder patterns for
11//! complex validation scenarios and **standalone functions** for quick checks.
12//! All validation functions return [`ValidationResult<T>`], providing detailed
13//! error messages when validation fails.
14//!
15//! # Key Components
16//!
17//! - [`DataValidator`]: Validates data arrays, features, and targets
18//! - [`ParameterValidator`]: Validates model parameters against constraints
19//! - Standalone functions: Lightweight validation helpers for common checks
20//!
21//! # Error Handling
22//!
23//! Validation failures return [`ValidationError`] with specific variants describing
24//! the exact nature of the validation failure, enabling precise error reporting.
25
26use std::collections::HashMap;
27use thiserror::Error;
28
29/// Errors that can occur during validation.
30///
31/// Each variant provides detailed information about the specific validation
32/// failure, including parameter names, expected constraints, and actual values.
33#[derive(Error, Debug)]
34pub enum ValidationError {
35 /// General validation failure with a descriptive message.
36 #[error("Validation failed: {0}")]
37 ValidationFailed(String),
38
39 /// Parameter value is outside the allowed range.
40 #[error("Invalid parameter: {name} = {value} (expected: {constraint})")]
41 InvalidParameter {
42 /// Name of the parameter that failed validation.
43 name: String,
44 /// Actual value provided.
45 value: f64,
46 /// Expected constraint or range.
47 constraint: String,
48 },
49
50 /// Array dimensions don't match expected values.
51 #[error("Dimension mismatch: expected {expected}, got {actual}")]
52 DimensionMismatch {
53 /// Expected dimension or size.
54 expected: String,
55 /// Actual dimension or size found.
56 actual: String,
57 },
58
59 /// Input data is empty.
60 #[error("Empty data")]
61 EmptyData,
62
63 /// Non-finite value (NaN, Inf, -Inf) found in data.
64 #[error("Non-finite value found")]
65 NonFiniteValue,
66
67 /// Probability value is outside [0, 1] range.
68 #[error("Invalid probability: {value} (must be in [0, 1])")]
69 /// The invalid probability value.
70 InvalidProbability { value: f64 },
71}
72
73/// Result type for validation operations.
74pub type ValidationResult<T> = std::result::Result<T, ValidationError>;
75
/// Validator for input data integrity.
///
/// Provides methods to check data quality including finiteness, positivity,
/// dimensional consistency, and target value validity.
///
/// The type is a field-less unit struct, so it is free to copy and is
/// printable with `{:?}` for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct DataValidator;
81
82impl DataValidator {
83 /// Creates a new data validator.
84 pub fn new() -> Self {
85 Self
86 }
87
88 /// Checks that all values in a slice are finite (not NaN, Inf, or -Inf).
89 ///
90 /// # Arguments
91 ///
92 /// * `values` - Slice of values to check
93 ///
94 /// # Returns
95 ///
96 /// `Ok(())` if all values are finite, `Err(ValidationError::NonFiniteValue)` otherwise
97 pub fn check_finite(&self, values: &[f64]) -> ValidationResult<()> {
98 check_finite(values)
99 }
100
101 /// Checks that all values are positive (> 0.0).
102 ///
103 /// # Arguments
104 ///
105 /// * `values` - Slice of values to check
106 ///
107 /// # Returns
108 ///
109 /// `Ok(())` if all values are positive, `Err` otherwise
110 pub fn check_positive(&self, values: &[f64]) -> ValidationResult<()> {
111 check_positive(values)
112 }
113
114 /// Checks that all values are valid probabilities (in [0, 1]).
115 ///
116 /// # Arguments
117 ///
118 /// * `values` - Slice of values to check
119 ///
120 /// # Returns
121 ///
122 /// `Ok(())` if all values are valid probabilities, `Err` otherwise
123 pub fn check_probability(&self, values: &[f64]) -> ValidationResult<()> {
124 check_probability(values)
125 }
126
127 /// Validates feature matrix dimensions and optional feature names.
128 ///
129 /// # Arguments
130 ///
131 /// * `n_samples` - Number of samples (must be > 0)
132 /// * `n_features` - Number of features (must be > 0)
133 /// * `feature_names` - Optional slice of feature names (length must match n_features)
134 ///
135 /// # Returns
136 ///
137 /// `Ok(())` if dimensions are valid, `Err` otherwise
138 pub fn validate_features(
139 &self,
140 n_samples: usize,
141 n_features: usize,
142 feature_names: Option<&[String]>
143 ) -> ValidationResult<()> {
144 validate_features(n_samples, n_features, feature_names)
145 }
146
147 /// Validates target values for regression or classification.
148 ///
149 /// # Arguments
150 ///
151 /// * `targets` - Target values to validate
152 /// * `is_classification` - Whether this is a classification problem
153 ///
154 /// # Returns
155 ///
156 /// `Ok(())` if targets are valid, `Err` otherwise
157 ///
158 /// # Behavior
159 ///
160 /// - For classification: targets must be 0.0 or 1.0
161 /// - For regression: targets must be finite (no other constraints)
162 pub fn validate_targets(&self, targets: &[f64], is_classification: bool) -> ValidationResult<()> {
163 validate_targets(targets, is_classification)
164 }
165
166 /// Asserts that two slices have the same length.
167 ///
168 /// # Arguments
169 ///
170 /// * `a` - First slice
171 /// * `b` - Second slice
172 ///
173 /// # Returns
174 ///
175 /// `Ok(())` if lengths match, `Err(ValidationError::DimensionMismatch)` otherwise
176 pub fn assert_same_length(&self, a: &[f64], b: &[f64]) -> ValidationResult<()> {
177 assert_same_length(a, b)
178 }
179
180 /// Asserts that a slice is not empty.
181 ///
182 /// # Arguments
183 ///
184 /// * `values` - Slice to check
185 ///
186 /// # Returns
187 ///
188 /// `Ok(())` if non-empty, `Err(ValidationError::EmptyData)` otherwise
189 pub fn assert_non_empty(&self, values: &[f64]) -> ValidationResult<()> {
190 assert_non_empty(values)
191 }
192}
193
194impl Default for DataValidator {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
/// Validator for model parameters with configurable constraints.
///
/// Maintains a map of parameter names to allowed (min, max) ranges,
/// enabling reusable validation logic for hyperparameters.
///
/// The validator can be cloned and printed with `{:?}` for diagnostics.
#[derive(Debug, Clone)]
pub struct ParameterValidator {
    /// Map of parameter name to its `(min, max)` range constraint.
    constraints: HashMap<String, (f64, f64)>,
}
208
209impl ParameterValidator {
210 /// Creates a new parameter validator with no constraints.
211 pub fn new() -> Self {
212 Self {
213 constraints: HashMap::new(),
214 }
215 }
216
217 /// Adds a parameter range constraint (builder pattern).
218 ///
219 /// # Arguments
220 ///
221 /// * `name` - Parameter name
222 /// * `min` - Minimum allowed value (inclusive)
223 /// * `max` - Maximum allowed value (inclusive)
224 ///
225 /// # Returns
226 ///
227 /// Self with the new constraint added
228 pub fn add_constraint(mut self, name: &str, min: f64, max: f64) -> Self {
229 self.constraints.insert(name.to_string(), (min, max));
230 self
231 }
232
233 /// Validates a single parameter against its constraints.
234 ///
235 /// # Arguments
236 ///
237 /// * `name` - Parameter name
238 /// * `value` - Parameter value to validate
239 ///
240 /// # Returns
241 ///
242 /// `Ok(())` if valid or no constraint exists, `Err` if out of range
243 pub fn validate_parameter(&self, name: &str, value: f64) -> ValidationResult<()> {
244 if let Some(&(min, max)) = self.constraints.get(name) {
245 if value < min || value > max {
246 return Err(ValidationError::InvalidParameter {
247 name: name.to_string(),
248 value,
249 constraint: format!("[{}, {}]", min, max),
250 });
251 }
252 }
253 Ok(())
254 }
255
256 /// Validates multiple parameters against all registered constraints.
257 ///
258 /// # Arguments
259 ///
260 /// * `params` - HashMap of parameter name → value
261 ///
262 /// # Returns
263 ///
264 /// `Ok(())` if all parameters satisfy constraints, `Err` otherwise
265 pub fn validate_parameters(&self, params: &HashMap<String, f64>) -> ValidationResult<()> {
266 for (name, &value) in params {
267 self.validate_parameter(name, value)?;
268 }
269 Ok(())
270 }
271
272 /// Validates common gradient boosting hyperparameters.
273 ///
274 /// Checks all parameters against sensible ranges for a typical GBM model.
275 ///
276 /// # Arguments
277 ///
278 /// * `n_estimators` - Number of trees (must be > 0)
279 /// * `learning_rate` - Step size (must be in (0, 1])
280 /// * `max_depth` - Tree depth (must be > 0)
281 /// * `subsample` - Subsampling rate (must be in [0, 1])
282 ///
283 /// # Returns
284 ///
285 /// `Ok(())` if all parameters are valid
286 pub fn validate_model_params(
287 &self,
288 n_estimators: usize,
289 learning_rate: f64,
290 max_depth: usize,
291 subsample: f64,
292 ) -> ValidationResult<()> {
293 let mut params = HashMap::new();
294 params.insert("n_estimators".to_string(), n_estimators as f64);
295 params.insert("learning_rate".to_string(), learning_rate);
296 params.insert("max_depth".to_string(), max_depth as f64);
297 params.insert("subsample".to_string(), subsample);
298
299 let validator = ParameterValidator::new()
300 .add_constraint("n_estimators", 1.0, 10000.0)
301 .add_constraint("learning_rate", 1e-10, 1.0)
302 .add_constraint("max_depth", 1.0, 100.0)
303 .add_constraint("subsample", 0.0, 1.0);
304
305 validator.validate_parameters(¶ms)
306 }
307}
308
309impl Default for ParameterValidator {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
315// Standalone validation functions
316
317/// Checks that all values in a slice are finite (not NaN, Inf, or -Inf).
318///
319/// # Arguments
320///
321/// * `values` - Slice to check
322///
323/// # Returns
324///
325/// `Ok(())` if all values are finite
326pub fn check_finite(values: &[f64]) -> ValidationResult<()> {
327 if values.iter().any(|&x| !x.is_finite()) {
328 return Err(ValidationError::NonFiniteValue);
329 }
330 Ok(())
331}
332
333/// Checks that all values are positive (> 0.0).
334///
335/// # Arguments
336///
337/// * `values` - Slice to check
338///
339/// # Returns
340///
341/// `Ok(())` if all values are positive
342pub fn check_positive(values: &[f64]) -> ValidationResult<()> {
343 if values.iter().any(|&x| x <= 0.0) {
344 return Err(ValidationError::ValidationFailed(
345 "All values must be positive".to_string()
346 ));
347 }
348 Ok(())
349}
350
351/// Checks that all values are valid probabilities (in [0, 1]).
352///
353/// # Arguments
354///
355/// * `values` - Slice to check
356///
357/// # Returns
358///
359/// `Ok(())` if all values are probabilities
360pub fn check_probability(values: &[f64]) -> ValidationResult<()> {
361 for &value in values {
362 if value < 0.0 || value > 1.0 {
363 return Err(ValidationError::InvalidProbability { value });
364 }
365 }
366 Ok(())
367}
368
369/// Validates feature matrix dimensions and optional feature names.
370///
371/// # Arguments
372///
373/// * `n_samples` - Number of samples (must be > 0)
374/// * `n_features` - Number of features (must be > 0)
375/// * `feature_names` - Optional feature names (length must match n_features if provided)
376///
377/// # Returns
378///
379/// `Ok(())` if dimensions are valid
380pub fn validate_features(
381 n_samples: usize,
382 n_features: usize,
383 feature_names: Option<&[String]>
384) -> ValidationResult<()> {
385 if n_samples == 0 {
386 return Err(ValidationError::EmptyData);
387 }
388
389 if n_features == 0 {
390 return Err(ValidationError::ValidationFailed(
391 "Number of features must be positive".to_string()
392 ));
393 }
394
395 if let Some(names) = feature_names {
396 if names.len() != n_features {
397 return Err(ValidationError::DimensionMismatch {
398 expected: n_features.to_string(),
399 actual: names.len().to_string(),
400 });
401 }
402 }
403
404 Ok(())
405}
406
407/// Validates target values for regression or classification.
408///
409/// For classification, ensures targets are binary (0 or 1). For regression,
410/// ensures all values are finite.
411///
412/// # Arguments
413///
414/// * `targets` - Target values
415/// * `is_classification` - Whether this is a classification problem
416///
417/// # Returns
418///
419/// `Ok(())` if targets are valid
420pub fn validate_targets(targets: &[f64], is_classification: bool) -> ValidationResult<()> {
421 if targets.is_empty() {
422 return Err(ValidationError::EmptyData);
423 }
424
425 check_finite(targets)?;
426
427 if is_classification {
428 // For classification, targets should be 0 or 1
429 for &target in targets {
430 if target != 0.0 && target != 1.0 {
431 return Err(ValidationError::ValidationFailed(
432 format!("Classification targets must be 0 or 1, got {}", target)
433 ));
434 }
435 }
436 }
437
438 Ok(())
439}
440
441/// Asserts that two slices have the same length.
442///
443/// # Arguments
444///
445/// * `a` - First slice
446/// * `b` - Second slice
447///
448/// # Returns
449///
450/// `Ok(())` if lengths match
451pub fn assert_same_length(a: &[f64], b: &[f64]) -> ValidationResult<()> {
452 if a.len() != b.len() {
453 return Err(ValidationError::DimensionMismatch {
454 expected: a.len().to_string(),
455 actual: b.len().to_string(),
456 });
457 }
458 Ok(())
459}
460
461/// Asserts that a slice is not empty.
462///
463/// # Arguments
464///
465/// * `values` - Slice to check
466///
467/// # Returns
468///
469/// `Ok(())` if non-empty
470pub fn assert_non_empty(values: &[f64]) -> ValidationResult<()> {
471 if values.is_empty() {
472 return Err(ValidationError::EmptyData);
473 }
474 Ok(())
475}
476
477/// Validates that a value falls within a specified range.
478///
479/// # Arguments
480///
481/// * `value` - Value to check
482/// * `min` - Minimum allowed (inclusive)
483/// * `max` - Maximum allowed (inclusive)
484/// * `name` - Name for error messages
485///
486/// # Returns
487///
488/// `Ok(())` if value is in range
489pub fn validate_range(value: f64, min: f64, max: f64, name: &str) -> ValidationResult<()> {
490 if value < min || value > max {
491 return Err(ValidationError::InvalidParameter {
492 name: name.to_string(),
493 value,
494 constraint: format!("[{}, {}]", min, max),
495 });
496 }
497 Ok(())
498}
499
500/// Validates that a value is positive (> 0).
501///
502/// # Arguments
503///
504/// * `value` - Value to check
505/// * `name` - Name for error messages
506///
507/// # Returns
508///
509/// `Ok(())` if value is positive
510pub fn validate_positive(value: f64, name: &str) -> ValidationResult<()> {
511 if value <= 0.0 {
512 return Err(ValidationError::InvalidParameter {
513 name: name.to_string(),
514 value,
515 constraint: "> 0".to_string(),
516 });
517 }
518 Ok(())
519}
520
521/// Validates that a value is non-negative (≥ 0).
522///
523/// # Arguments
524///
525/// * `value` - Value to check
526/// * `name` - Name for error messages
527///
528/// # Returns
529///
530/// `Ok(())` if value is non-negative
531pub fn validate_non_negative(value: f64, name: &str) -> ValidationResult<()> {
532 if value < 0.0 {
533 return Err(ValidationError::InvalidParameter {
534 name: name.to_string(),
535 value,
536 constraint: ">= 0".to_string(),
537 });
538 }
539 Ok(())
540}
541
542/// Validates common gradient boosting model hyperparameters.
543///
544/// Checks all parameters against sensible ranges for typical GBM models.
545///
546/// # Arguments
547///
548/// * `n_estimators` - Number of trees
549/// * `learning_rate` - Step size
550/// * `max_depth` - Maximum tree depth
551/// * `subsample` - Subsampling rate
552///
553/// # Returns
554///
555/// `Ok(())` if all parameters are valid
556pub fn validate_model_params(
557 n_estimators: usize,
558 learning_rate: f64,
559 max_depth: usize,
560 subsample: f64,
561) -> ValidationResult<()> {
562 if n_estimators == 0 {
563 return Err(ValidationError::ValidationFailed(
564 "n_estimators must be positive".to_string()
565 ));
566 }
567
568 validate_range(learning_rate, 1e-10, 1.0, "learning_rate")?;
569
570 if max_depth == 0 {
571 return Err(ValidationError::ValidationFailed(
572 "max_depth must be positive".to_string()
573 ));
574 }
575
576 validate_range(subsample, 0.0, 1.0, "subsample")?;
577
578 Ok(())
579}
580