// gbrt_rs/objective/mod.rs

1//! Objective functions for gradient boosting.
2//!
3//! This module defines the core trait and infrastructure for objective functions (loss functions)
4//! used in gradient boosting. Objective functions compute the loss and its derivatives (gradient
5//! and Hessian) which guide the boosting process.
6//!
7//! # Module Organization
8//!
9//! - [`regression`]: Regression objectives (MSE, MAE, Huber)
10//! - [`binary_classification`]: Binary classification objectives (Log Loss)
11//! - [`factory`]: Factory functions for creating objective instances
12//!
13//! # Key Components
14//!
15//! - [`Objective`]: Core trait that all objective functions must implement
16//! - [`ObjectiveConfig`]: Configuration for objective functions with builder pattern
17//! - [`ObjectiveError`]: Error type for objective-related failures
18
19mod regression;
20mod binary_classification;
21mod factory;
22
23// Re-export commonly used objective types.
24pub use regression::{RegressionObjective, MSEObjective, MAEObjective, HuberObjective};
25pub use binary_classification::{BinaryClassificationObjective, LogLossObjective};
26pub use factory::{ObjectiveFactory, create_objective, ObjectiveType};
27
28use crate::data::FeatureMatrix;
29use thiserror::Error;
30use serde::{Deserialize, Serialize};
31
/// Errors that can occur during objective function operations.
///
/// This error type covers all failure modes when working with objective functions,
/// including invalid input data, computation failures, configuration issues,
/// and serialization problems. Each variant carries a human-readable `String`
/// describing the specific failure.
#[derive(Error, Debug)]
pub enum ObjectiveError {
    /// Input data is invalid for the objective function.
    ///
    /// This includes mismatched array lengths, non-finite values,
    /// or target values outside the valid range (e.g., non-binary targets for classification).
    #[error("Invalid input data: {0}")]
    InvalidInput(String),
    
    /// An error occurred during loss/gradient computation.
    ///
    /// This typically indicates numerical instability or overflow/underflow issues.
    #[error("Objective computation error: {0}")]
    ComputationError(String),
    
    /// Objective configuration is invalid or missing required parameters
    /// (e.g., an unknown objective name or an out-of-range parameter value).
    #[error("Configuration error: {0}")]
    ConfigError(String),
   
    /// Serialization or deserialization of objective configuration failed.
    #[error("Serialization error: {0}")]
    SerializationError(String),
}
60
/// Result type for objective function operations.
///
/// This is a convenience type alias for `Result<T, ObjectiveError>`,
/// used as the return type throughout the [`Objective`] trait.
pub type ObjectiveResult<T> = std::result::Result<T, ObjectiveError>;
65
66/// Core trait for objective functions in gradient boosting.
67///
68/// Objective functions (also called loss functions) measure how well the model's predictions
69/// match the true target values. They provide the gradient and Hessian (second derivative)
70/// that guide the boosting algorithm in fitting subsequent trees.
71///
72/// # Implementation Requirements
73///
74/// All objective functions must implement the required methods. The [`gradient_hessian`](Objective::gradient_hessian)
75/// method has a default implementation that calls [`gradient`](Objective::gradient) and [`hessian`](Objective::hessian)
76/// separately, but can be overridden for more efficient combined computation.
77///
78/// # Mathematical Background
79///
80/// - **Gradient**: First derivative of the loss with respect to predictions, ∇L(y, ŷ)
81/// - **Hessian**: Second derivative of the loss, ∇²L(y, ŷ)
82/// - These are used to approximate the loss function via second-order Taylor expansion
83pub trait Objective: Send + Sync + std::fmt::Debug {
84    /// Computes the loss value for given predictions and targets.
85    ///
86    /// # Arguments
87    ///
88    /// * `y_true` - Slice of true target values
89    /// * `y_pred` - Slice of predicted values (in the transformed space if applicable)
90    ///
91    /// # Returns
92    ///
93    /// The average loss value across all samples
94    ///
95    /// # Errors
96    ///
97    /// Returns [`ObjectiveError::InvalidInput`] if inputs have mismatched lengths or contain invalid values
98    /// Returns [`ObjectiveError::ComputationError`] if numerical computation fails.
99    fn loss(&self, y_true: &[f64], y_pred: &[f64]) -> ObjectiveResult<f64>;
100    
101    /// Computes the gradient of the loss with respect to predictions.
102    ///
103    /// The gradient indicates how predictions should be adjusted to reduce loss.
104    ///
105    /// # Arguments
106    ///
107    /// * `y_true` - Slice of true target values
108    /// * `y_pred` - Slice of predicted values
109    ///
110    /// # Returns
111    ///
112    /// A vector of gradient values, one per sample
113    ///
114    /// # Errors
115    ///
116    /// Returns [`ObjectiveError::InvalidInput`] if inputs have mismatched lengths
117    /// Returns [`ObjectiveError::ComputationError`] if gradient computation fails 
118    fn gradient(&self, y_true: &[f64], y_pred: &[f64]) -> ObjectiveResult<Vec<f64>>;
119    
120    /// Computes the Hessian (second derivative) of the loss.
121    ///
122    /// The Hessian represents the curvature of the loss function and is used
123    /// to weight the importance of different samples in tree splitting.
124    ///
125    /// # Arguments
126    ///
127    /// * `y_true` - Slice of true target values
128    /// * `y_pred` - Slice of predicted values
129    ///
130    /// # Returns
131    ///
132    /// A vector of Hessian values, one per sample
133    ///
134    /// # Errors
135    ///
136    /// Returns [`ObjectiveError::InvalidInput`] if inputs have mismatched lengths
137    /// Returns [`ObjectiveError::ComputationError`] if Hessian computation fails 
138    fn hessian(&self, y_true: &[f64], y_pred: &[f64]) -> ObjectiveResult<Vec<f64>>;
139    
140    /// Computes both gradient and Hessian in one pass for efficiency.
141    ///
142    /// This method has a default implementation that calls [`gradient`](Objective::gradient)
143    /// and [`hessian`](Objective::hessian) separately. Implementations should override
144    /// this for better performance when both can be computed simultaneously.
145    ///
146    /// # Arguments
147    ///
148    /// * `y_true` - Slice of true target values
149    /// * `y_pred` - Slice of predicted values
150    ///
151    /// # Returns
152    ///
153    /// A tuple of (gradient_values, hessian_values)
154    ///
155    /// # Errors
156    ///
157    /// See [`gradient`](Objective::gradient) and [`hessian`](Objective::hessian) 
158    fn gradient_hessian(&self, y_true: &[f64], y_pred: &[f64]) -> ObjectiveResult<(Vec<f64>, Vec<f64>)> {
159        let gradient = self.gradient(y_true, y_pred)?;
160        let hessian = self.hessian(y_true, y_pred)?;
161        Ok((gradient, hessian))
162    }
163    
164    /// Transforms raw predictions to the target space.
165    ///
166    /// For example, logistic loss applies a sigmoid transformation to produce
167    /// probabilities for binary classification.
168    ///
169    /// # Arguments
170    ///
171    /// * `y_pred` - Slice of raw prediction values
172    ///
173    /// # Returns
174    ///
175    /// Transformed predictions (e.g., probabilities for classification) 
176    fn transform(&self, y_pred: &[f64]) -> Vec<f64>;
177    
178    /// Returns the name of the objective function.
179    ///
180    /// This is used for logging, serialization, and configuration matching.
181    ///
182    /// # Returns
183    ///
184    /// A string slice identifying the objective (e.g., "mse", "logloss") 
185    fn name(&self) -> &str;
186    
187    /// Checks if this objective is for regression tasks.
188    ///
189    /// # Returns
190    ///
191    /// `true` if the objective is for regression (e.g., MSE, MAE) 
192    fn is_regression(&self) -> bool;
193    
194    /// Checks if this objective is for classification tasks.
195    ///
196    /// # Returns
197    ///
198    /// `true` if the objective is for classification (e.g., Log Loss) 
199    fn is_classification(&self) -> bool;
200    
201    /// Computes the default initial prediction for this objective.
202    ///
203    /// The initial prediction is used as the starting point for boosting.
204    /// For regression, this is often the mean of targets. For classification,
205    /// it's typically the logit of the class proportion.
206    ///
207    /// # Arguments
208    ///
209    /// * `y_true` - Slice of true target values
210    ///
211    /// # Returns
212    ///
213    /// The initial prediction value
214    ///
215    /// # Errors
216    ///
217    /// Returns [`ObjectiveError::InvalidInput`] if `y_true` is empty or invalid 
218    fn initial_prediction(&self, y_true: &[f64]) -> ObjectiveResult<f64>;
219    
220    /// Validates that target values are appropriate for this objective.
221    ///
222    /// This checks for common issues like non-finite values, out-of-range targets,
223    /// or invalid class labels.
224    ///
225    /// # Arguments
226    ///
227    /// * `y_true` - Slice of true target values to validate
228    ///
229    /// # Returns
230    ///
231    /// `Ok(())` if validation passes
232    ///
233    /// # Errors
234    ///
235    /// Returns [`ObjectiveError::InvalidInput`] if validation fails 
236    fn validate_targets(&self, y_true: &[f64]) -> ObjectiveResult<()>;
237}
238
/// Configuration for objective functions.
///
/// This struct provides a flexible way to configure objective functions with
/// a name and optional parameters. It supports serialization/deserialization
/// for model persistence and uses a builder pattern for ergonomic configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObjectiveConfig {
    /// The name of the objective function (e.g., "mse", "huber", "logloss").
    /// Presumably matched by the factory module when instantiating objectives — see `factory`.
    pub name: String,
    /// Optional numeric parameters for the objective, keyed by parameter name
    /// (e.g., "delta" for Huber loss).
    pub params: std::collections::HashMap<String, f64>,
}
251
252impl Default for ObjectiveConfig {
253    /// Returns the default configuration using Mean Squared Error (MSE).
254    fn default() -> Self {
255        Self {
256            name: "mse".to_string(),
257            params: std::collections::HashMap::new(),
258        }
259    }
260}
261
262impl ObjectiveConfig {
263    /// Creates a new objective configuration with the given name and no parameters.
264    ///
265    /// # Arguments
266    ///
267    /// * `name` - The name of the objective function
268    ///
269    /// # Returns
270    ///
271    /// A new `ObjectiveConfig` instance
272    pub fn new(name: &str) -> Self {
273        Self {
274            name: name.to_string(),
275            params: std::collections::HashMap::new(),
276        }
277    }
278    
279    /// Adds a parameter to the configuration (builder pattern).
280    ///
281    /// This method consumes `self` and returns a modified configuration,
282    /// enabling method chaining for ergonomic configuration.
283    ///
284    /// # Arguments
285    ///
286    /// * `key` - Parameter name
287    /// * `value` - Parameter value
288    ///
289    /// # Returns
290    ///
291    /// Self with the parameter added
292    pub fn with_param(mut self, key: &str, value: f64) -> Self {
293        self.params.insert(key.to_string(), value);
294        self
295    }
296
297    /// Retrieves a parameter value by key.
298    ///
299    /// # Arguments
300    ///
301    /// * `key` - Parameter name to look up
302    ///
303    /// # Returns
304    ///
305    /// `Some(f64)` if the parameter exists, `None` otherwise
306    pub fn get_param(&self, key: &str) -> Option<f64> {
307        self.params.get(key).copied()
308    }
309}
310