// gbrt_rs/objective/mod.rs
//! Objective functions for gradient boosting.
//!
//! This module defines the core trait and infrastructure for objective functions (loss functions)
//! used in gradient boosting. Objective functions compute the loss and its derivatives (gradient
//! and Hessian), which guide the boosting process.
//!
//! # Module Organization
//!
//! - [`regression`]: Regression objectives (MSE, MAE, Huber)
//! - [`binary_classification`]: Binary classification objectives (Log Loss)
//! - [`factory`]: Factory functions for creating objective instances
//!
//! # Key Components
//!
//! - [`Objective`]: Core trait that all objective functions must implement
//! - [`ObjectiveConfig`]: Configuration for objective functions with builder pattern
//! - [`ObjectiveError`]: Error type for objective-related failures

19mod regression;
20mod binary_classification;
21mod factory;
22
23// Re-export commonly used objective types.
24pub use regression::{RegressionObjective, MSEObjective, MAEObjective, HuberObjective};
25pub use binary_classification::{BinaryClassificationObjective, LogLossObjective};
26pub use factory::{ObjectiveFactory, create_objective, ObjectiveType};
27
28use crate::data::FeatureMatrix;
29use thiserror::Error;
30use serde::{Deserialize, Serialize};
31
32/// Errors that can occur during objective function operations.
33///
34/// This error type covers all failure modes when working with objective functions,
35/// including invalid input data, computation failures, configuration issues,
36/// and serialization problems.
37#[derive(Error, Debug)]
38pub enum ObjectiveError {
39 /// Input data is invalid for the objective function.
40 ///
41 /// This includes mismatched array lengths, non-finite values,
42 /// or target values outside the valid range (e.g., non-binary targets for classification).
43 #[error("Invalid input data: {0}")]
44 InvalidInput(String),
45
46 /// An error occurred during loss/gradient computation.
47 ///
48 /// This typically indicates numerical instability or overflow/underflow issues.
49 #[error("Objective computation error: {0}")]
50 ComputationError(String),
51
52 /// Objective configuration is invalid or missing required parameters.
53 #[error("Configuration error: {0}")]
54 ConfigError(String),
55
56 /// Serialization or deserialization of objective configuration failed.
57 #[error("Serialization error: {0}")]
58 SerializationError(String),
59}
60
61/// Result type for objective function operations.
62///
63/// This is a convenience type alias for `Result<T, ObjectiveError>`.
64pub type ObjectiveResult<T> = std::result::Result<T, ObjectiveError>;
65
66/// Core trait for objective functions in gradient boosting.
67///
68/// Objective functions (also called loss functions) measure how well the model's predictions
69/// match the true target values. They provide the gradient and Hessian (second derivative)
70/// that guide the boosting algorithm in fitting subsequent trees.
71///
72/// # Implementation Requirements
73///
74/// All objective functions must implement the required methods. The [`gradient_hessian`](Objective::gradient_hessian)
75/// method has a default implementation that calls [`gradient`](Objective::gradient) and [`hessian`](Objective::hessian)
76/// separately, but can be overridden for more efficient combined computation.
77///
78/// # Mathematical Background
79///
80/// - **Gradient**: First derivative of the loss with respect to predictions, ∇L(y, ŷ)
81/// - **Hessian**: Second derivative of the loss, ∇²L(y, ŷ)
82/// - These are used to approximate the loss function via second-order Taylor expansion
83pub trait Objective: Send + Sync + std::fmt::Debug {
84 /// Computes the loss value for given predictions and targets.
85 ///
86 /// # Arguments
87 ///
88 /// * `y_true` - Slice of true target values
89 /// * `y_pred` - Slice of predicted values (in the transformed space if applicable)
90 ///
91 /// # Returns
92 ///
93 /// The average loss value across all samples
94 ///
95 /// # Errors
96 ///
97 /// Returns [`ObjectiveError::InvalidInput`] if inputs have mismatched lengths or contain invalid values
98 /// Returns [`ObjectiveError::ComputationError`] if numerical computation fails.
99 fn loss(&self, y_true: &[f64], y_pred: &[f64]) -> ObjectiveResult<f64>;
100
101 /// Computes the gradient of the loss with respect to predictions.
102 ///
103 /// The gradient indicates how predictions should be adjusted to reduce loss.
104 ///
105 /// # Arguments
106 ///
107 /// * `y_true` - Slice of true target values
108 /// * `y_pred` - Slice of predicted values
109 ///
110 /// # Returns
111 ///
112 /// A vector of gradient values, one per sample
113 ///
114 /// # Errors
115 ///
116 /// Returns [`ObjectiveError::InvalidInput`] if inputs have mismatched lengths
117 /// Returns [`ObjectiveError::ComputationError`] if gradient computation fails
118 fn gradient(&self, y_true: &[f64], y_pred: &[f64]) -> ObjectiveResult<Vec<f64>>;
119
120 /// Computes the Hessian (second derivative) of the loss.
121 ///
122 /// The Hessian represents the curvature of the loss function and is used
123 /// to weight the importance of different samples in tree splitting.
124 ///
125 /// # Arguments
126 ///
127 /// * `y_true` - Slice of true target values
128 /// * `y_pred` - Slice of predicted values
129 ///
130 /// # Returns
131 ///
132 /// A vector of Hessian values, one per sample
133 ///
134 /// # Errors
135 ///
136 /// Returns [`ObjectiveError::InvalidInput`] if inputs have mismatched lengths
137 /// Returns [`ObjectiveError::ComputationError`] if Hessian computation fails
138 fn hessian(&self, y_true: &[f64], y_pred: &[f64]) -> ObjectiveResult<Vec<f64>>;
139
140 /// Computes both gradient and Hessian in one pass for efficiency.
141 ///
142 /// This method has a default implementation that calls [`gradient`](Objective::gradient)
143 /// and [`hessian`](Objective::hessian) separately. Implementations should override
144 /// this for better performance when both can be computed simultaneously.
145 ///
146 /// # Arguments
147 ///
148 /// * `y_true` - Slice of true target values
149 /// * `y_pred` - Slice of predicted values
150 ///
151 /// # Returns
152 ///
153 /// A tuple of (gradient_values, hessian_values)
154 ///
155 /// # Errors
156 ///
157 /// See [`gradient`](Objective::gradient) and [`hessian`](Objective::hessian)
158 fn gradient_hessian(&self, y_true: &[f64], y_pred: &[f64]) -> ObjectiveResult<(Vec<f64>, Vec<f64>)> {
159 let gradient = self.gradient(y_true, y_pred)?;
160 let hessian = self.hessian(y_true, y_pred)?;
161 Ok((gradient, hessian))
162 }
163
164 /// Transforms raw predictions to the target space.
165 ///
166 /// For example, logistic loss applies a sigmoid transformation to produce
167 /// probabilities for binary classification.
168 ///
169 /// # Arguments
170 ///
171 /// * `y_pred` - Slice of raw prediction values
172 ///
173 /// # Returns
174 ///
175 /// Transformed predictions (e.g., probabilities for classification)
176 fn transform(&self, y_pred: &[f64]) -> Vec<f64>;
177
178 /// Returns the name of the objective function.
179 ///
180 /// This is used for logging, serialization, and configuration matching.
181 ///
182 /// # Returns
183 ///
184 /// A string slice identifying the objective (e.g., "mse", "logloss")
185 fn name(&self) -> &str;
186
187 /// Checks if this objective is for regression tasks.
188 ///
189 /// # Returns
190 ///
191 /// `true` if the objective is for regression (e.g., MSE, MAE)
192 fn is_regression(&self) -> bool;
193
194 /// Checks if this objective is for classification tasks.
195 ///
196 /// # Returns
197 ///
198 /// `true` if the objective is for classification (e.g., Log Loss)
199 fn is_classification(&self) -> bool;
200
201 /// Computes the default initial prediction for this objective.
202 ///
203 /// The initial prediction is used as the starting point for boosting.
204 /// For regression, this is often the mean of targets. For classification,
205 /// it's typically the logit of the class proportion.
206 ///
207 /// # Arguments
208 ///
209 /// * `y_true` - Slice of true target values
210 ///
211 /// # Returns
212 ///
213 /// The initial prediction value
214 ///
215 /// # Errors
216 ///
217 /// Returns [`ObjectiveError::InvalidInput`] if `y_true` is empty or invalid
218 fn initial_prediction(&self, y_true: &[f64]) -> ObjectiveResult<f64>;
219
220 /// Validates that target values are appropriate for this objective.
221 ///
222 /// This checks for common issues like non-finite values, out-of-range targets,
223 /// or invalid class labels.
224 ///
225 /// # Arguments
226 ///
227 /// * `y_true` - Slice of true target values to validate
228 ///
229 /// # Returns
230 ///
231 /// `Ok(())` if validation passes
232 ///
233 /// # Errors
234 ///
235 /// Returns [`ObjectiveError::InvalidInput`] if validation fails
236 fn validate_targets(&self, y_true: &[f64]) -> ObjectiveResult<()>;
237}
238
239/// Configuration for objective functions.
240///
241/// This struct provides a flexible way to configure objective functions with
242/// a name and optional parameters. It supports serialization/deserialization
243/// for model persistence and uses a builder pattern for ergonomic configuration.
244#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct ObjectiveConfig {
246 /// The name of the objective function (e.g., "mse", "huber", "logloss")
247 pub name: String,
248 /// Optional parameters for the objective (e.g., "delta" for Huber loss)
249 pub params: std::collections::HashMap<String, f64>,
250}
251
252impl Default for ObjectiveConfig {
253 /// Returns the default configuration using Mean Squared Error (MSE).
254 fn default() -> Self {
255 Self {
256 name: "mse".to_string(),
257 params: std::collections::HashMap::new(),
258 }
259 }
260}
261
262impl ObjectiveConfig {
263 /// Creates a new objective configuration with the given name and no parameters.
264 ///
265 /// # Arguments
266 ///
267 /// * `name` - The name of the objective function
268 ///
269 /// # Returns
270 ///
271 /// A new `ObjectiveConfig` instance
272 pub fn new(name: &str) -> Self {
273 Self {
274 name: name.to_string(),
275 params: std::collections::HashMap::new(),
276 }
277 }
278
279 /// Adds a parameter to the configuration (builder pattern).
280 ///
281 /// This method consumes `self` and returns a modified configuration,
282 /// enabling method chaining for ergonomic configuration.
283 ///
284 /// # Arguments
285 ///
286 /// * `key` - Parameter name
287 /// * `value` - Parameter value
288 ///
289 /// # Returns
290 ///
291 /// Self with the parameter added
292 pub fn with_param(mut self, key: &str, value: f64) -> Self {
293 self.params.insert(key.to_string(), value);
294 self
295 }
296
297 /// Retrieves a parameter value by key.
298 ///
299 /// # Arguments
300 ///
301 /// * `key` - Parameter name to look up
302 ///
303 /// # Returns
304 ///
305 /// `Some(f64)` if the parameter exists, `None` otherwise
306 pub fn get_param(&self, key: &str) -> Option<f64> {
307 self.params.get(key).copied()
308 }
309}
310