gbrt_rs/objective/factory.rs
1#![allow(dead_code)]
2
3//! Objective Function Factory and Registry
4//!
5//! This module provides factory patterns for creating objective functions used
6//! in gradient boosting. It supports both regression and binary classification
7//! tasks with configurable parameters.
8//!
9//! # Architecture
10//!
11//! - [`ObjectiveFactory`]: Static factory for creating objective instances
12//! - [`ObjectiveBuilder`]: Fluent builder for configuring objectives
13//! - [`ObjectiveRegistry`]: Validation and discovery of available objectives
14//! - [`ObjectiveType`]: Enum distinguishing regression vs classification
15//!
16//! # Supported Objectives
17//!
18//! **Regression:**
19//! - `mse` / `mean_squared_error`
20//! - `mae` / `mean_absolute_error`
21//! - `huber` (with `delta` parameter)
22//!
23//! **Binary Classification:**
//! - `log_loss` / `binary_crossentropy` / `binary_cross_entropy` (with `epsilon` parameter)
25//!
26
27use super::{
28 Objective, ObjectiveError, ObjectiveResult, ObjectiveConfig,
29 RegressionObjective, BinaryClassificationObjective,
30 MSEObjective, MAEObjective, HuberObjective, LogLossObjective
31};
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34
35/// Type of machine learning task for objective selection.
36///
37/// Used to categorize objectives and validate compatibility with
38/// gradient booster configuration.
39#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
40pub enum ObjectiveType {
41 /// Regression objectives predict continuous values.
42 Regression,
43 /// Binary classification objectives predict probabilities for two classes.
44 BinaryClassification,
45}
46
47impl std::fmt::Display for ObjectiveType {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 ObjectiveType::Regression => write!(f, "regression"),
51 ObjectiveType::BinaryClassification => write!(f, "binary_classification"),
52 }
53 }
54}
55
56/// Factory for creating objective function instances.
57///
58/// `ObjectiveFactory` provides static methods to instantiate objective
59/// functions with appropriate types and parameters. It handles validation
60/// and configuration of objective-specific hyperparameters.
61pub struct ObjectiveFactory;
62
63impl ObjectiveFactory {
64 /// Creates a regression objective function by name.
65 ///
66 /// # Supported Names
67 /// - `mse`, `mean_squared_error`
68 /// - `mae`, `mean_absolute_error`
69 /// - `huber` (creates with delta=1.0)
70 ///
71 /// # Parameters
72 /// - `name`: Objective function name (case-insensitive)
73 ///
74 /// # Returns
75 /// Boxed trait object implementing [`RegressionObjective`]
76 ///
77 /// # Errors
78 /// - `ObjectiveError::ConfigError` if `name` is not a valid regression objective
79 pub fn create_regression_objective(name: &str) -> ObjectiveResult<Box<dyn RegressionObjective>> {
80 match name.to_lowercase().as_str() {
81 "mse" | "mean_squared_error" => Ok(Box::new(MSEObjective::new())),
82 "mae" | "mean_absolute_error" => Ok(Box::new(MAEObjective::new())),
83 "huber" => Ok(Box::new(HuberObjective::new(1.0))),
84 _ => Err(ObjectiveError::ConfigError(
85 format!("Unknown regression objective: {}", name)
86 )),
87 }
88 }
89
90 /// Creates a binary classification objective function by name.
91 ///
92 /// # Supported Names
93 /// - `log_loss`, `binary_crossentropy`, `binary_cross_entropy`
94 ///
95 /// # Parameters
96 /// - `name`: Objective function name (case-insensitive)
97 ///
98 /// # Returns
99 /// Boxed trait object implementing [`BinaryClassificationObjective`]
100 ///
101 /// # Errors
102 /// - `ObjectiveError::ConfigError` if `name` is not a valid classification objective
103 pub fn create_binary_classification_objective(name: &str) -> ObjectiveResult<Box<dyn BinaryClassificationObjective>> {
104 match name.to_lowercase().as_str() {
105 "log_loss" | "binary_crossentropy" | "binary_cross_entropy" => {
106 Ok(Box::new(LogLossObjective::new()))
107 },
108 _ => Err(ObjectiveError::ConfigError(
109 format!("Unknown binary classification objective: {}", name)
110 )),
111 }
112 }
113
114 /// Creates an objective from a configuration struct.
115 ///
116 /// # Parameters
117 /// - `config`: [`ObjectiveConfig`] with name and parameters
118 ///
119 /// # Returns
120 /// Boxed trait object implementing [`Objective`]
121 ///
122 /// # Supported Configurations
123 /// - `mse`: No parameters
124 /// - `mae`: No parameters
125 /// - `huber`: `delta` parameter (default: 1.0)
126 /// - `log_loss`: `epsilon` parameter (default: 1e-15)
127 ///
128 /// # Errors
129 /// - `ObjectiveError::ConfigError` if objective name is unknown
130 /// - `ObjectiveError::ConfigError` if parameters are invalid
131 pub fn create_objective_from_config(config: &ObjectiveConfig) -> ObjectiveResult<Box<dyn Objective>> {
132 let name = config.name.to_lowercase();
133
134 match name.as_str() {
135 // Regression objectives
136 "mse" | "mean_squared_error" => Ok(Box::new(MSEObjective::new())),
137 "mae" | "mean_absolute_error" => Ok(Box::new(MAEObjective::new())),
138 "huber" => {
139 let delta = config.get_param("delta").unwrap_or(1.0);
140 Ok(Box::new(HuberObjective::new(delta)))
141 },
142
143 // Binary classification objectives
144 "log_loss" | "binary_crossentropy" | "binary_cross_entropy" => {
145 let epsilon = config.get_param("epsilon").unwrap_or(1e-15);
146 Ok(Box::new(LogLossObjective::new().with_epsilon(epsilon)))
147 },
148
149 _ => Err(ObjectiveError::ConfigError(
150 format!("Unknown objective: {}", config.name)
151 )),
152 }
153 }
154
155 /// Determines the task type for a given objective name.
156 ///
157 /// # Parameters
158 /// - `name`: Objective name (case-insensitive)
159 ///
160 /// # Returns
161 /// [`ObjectiveType`] indicating regression or classification
162 ///
163 /// # Errors
164 /// - `ObjectiveError::ConfigError` if objective name is unknown
165 pub fn get_objective_type(name: &str) -> ObjectiveResult<ObjectiveType> {
166 match name.to_lowercase().as_str() {
167 "mse" | "mean_squared_error" | "mae" | "mean_absolute_error" | "huber" => {
168 Ok(ObjectiveType::Regression)
169 },
170 "log_loss" | "binary_crossentropy" | "binary_cross_entropy" => {
171 Ok(ObjectiveType::BinaryClassification)
172 },
173 _ => Err(ObjectiveError::ConfigError(
174 format!("Unknown objective: {}", name)
175 )),
176 }
177 }
178
179 /// Returns a list of all available objective names.
180 ///
181 /// # Returns
182 /// Vector of lowercase objective names
183 pub fn available_objectives() -> Vec<&'static str> {
184 vec![
185 "mse",
186 "mae",
187 "huber",
188 "log_loss",
189 ]
190 }
191
192 /// Gets default parameters for an objective.
193 ///
194 /// # Parameters
195 /// - `name`: Objective name (case-insensitive)
196 ///
197 /// # Returns
198 /// HashMap of parameter names to default values
199 ///
200 /// # Errors
201 /// - `ObjectiveError::ConfigError` if objective name is unknown
202 pub fn get_default_params(name: &str) -> ObjectiveResult<HashMap<String, f64>> {
203 let mut params = HashMap::new();
204
205 match name.to_lowercase().as_str() {
206 "huber" => {
207 params.insert("delta".to_string(), 1.0);
208 },
209 "log_loss" => {
210 params.insert("epsilon".to_string(), 1e-15);
211 },
212 "mse" | "mae" => {
213 // No parameters for these
214 },
215 _ => return Err(ObjectiveError::ConfigError(
216 format!("Unknown objective: {}", name)
217 )),
218 }
219
220 Ok(params)
221 }
222}
223
224/// Convenience function to create an objective by name with default parameters.
225///
226/// # Parameters
227/// - `name`: Objective name (case-insensitive)
228///
229/// # Returns
230/// Boxed trait object implementing [`Objective`]
231pub fn create_objective(name: &str) -> ObjectiveResult<Box<dyn Objective>> {
232 let config = ObjectiveConfig::new(name);
233 ObjectiveFactory::create_objective_from_config(&config)
234}
235
/// Fluent builder for constructing [`ObjectiveConfig`].
///
/// Provides a convenient API for configuring objective functions with parameters.
/// Derives `Debug` and `Clone` so partially configured builders can be
/// inspected and reused as templates.
#[derive(Debug, Clone)]
pub struct ObjectiveBuilder {
    /// Objective name as supplied by the caller (not normalized here).
    name: String,
    /// Parameter name -> value pairs collected via `with_param`.
    params: HashMap<String, f64>,
}
243
244impl ObjectiveBuilder {
245 /// Creates a new builder for the specified objective.
246 ///
247 /// # Parameters
248 /// - `name`: Objective name (case-insensitive)
249 pub fn new(name: &str) -> Self {
250 Self {
251 name: name.to_string(),
252 params: HashMap::new(),
253 }
254 }
255
256 /// Adds a parameter to the objective configuration.
257 ///
258 /// # Parameters
259 /// - `key`: Parameter name (e.g., "delta" for Huber)
260 /// - `value`: Parameter value
261 pub fn with_param(mut self, key: &str, value: f64) -> Self {
262 self.params.insert(key.to_string(), value);
263 self
264 }
265
266 /// Builds the final ObjectiveConfig.
267 ///
268 /// # Returns
269 /// Configured ObjectiveConfig ready for factory consumption
270 pub fn build(self) -> ObjectiveConfig {
271 ObjectiveConfig {
272 name: self.name,
273 params: self.params,
274 }
275 }
276}
277
278/// Registry for objective discovery and validation.
279///
280/// Provides utilities for checking objective availability, exploring options,
281/// and validating configurations before instantiation.
282pub struct ObjectiveRegistry;
283
284impl ObjectiveRegistry {
285 /// Checks if an objective name is valid and supported.
286 ///
287 /// # Parameters
288 /// - `name`: Objective name (case-insensitive)
289 ///
290 /// # Returns
291 /// `true` if objective exists, `false` otherwise
292 pub fn is_valid_objective(name: &str) -> bool {
293 ObjectiveFactory::available_objectives().contains(&name)
294 }
295
296 /// Returns all regression objective names.
297 pub fn regression_objectives() -> Vec<&'static str> {
298 vec!["mse", "mae", "huber"]
299 }
300
301 /// Returns all classification objective names.
302 pub fn classification_objectives() -> Vec<&'static str> {
303 vec!["log_loss"]
304 }
305
306 /// Validates an objective configuration before instantiation.
307 ///
308 /// Checks that:
309 /// - Objective name is valid
310 /// - All parameters are recognized
311 /// - Parameter values are valid (positive for numerical params)
312 ///
313 /// # Parameters
314 /// - `config`: ObjectiveConfig to validate
315 ///
316 /// # Returns
317 /// `Ok(())` if configuration is valid
318 ///
319 /// # Errors
320 /// - `ObjectiveError::ConfigError` if objective name is unknown
321 /// - `ObjectiveError::ConfigError` if parameters are invalid
322 pub fn validate_config(config: &ObjectiveConfig) -> ObjectiveResult<()> {
323 if !Self::is_valid_objective(&config.name) {
324 return Err(ObjectiveError::ConfigError(
325 format!("Invalid objective name: {}", config.name)
326 ));
327 }
328
329 // Validate parameters
330 let default_params = ObjectiveFactory::get_default_params(&config.name)?;
331 for (key, value) in &config.params {
332 if !default_params.contains_key(key) {
333 return Err(ObjectiveError::ConfigError(
334 format!("Unknown parameter '{}' for objective '{}'", key, config.name)
335 ));
336 }
337
338 // Validate parameter values
339 match key.as_str() {
340 "delta" | "alpha" | "epsilon" => {
341 if *value <= 0.0 {
342 return Err(ObjectiveError::ConfigError(
343 format!("Parameter '{}' must be positive, got {}", key, value)
344 ));
345 }
346 },
347 _ => {}
348 }
349 }
350
351 Ok(())
352 }
353}
354