// gbrt_rs/objective/factory.rs

1#![allow(dead_code)]

//! Objective Function Factory and Registry
//!
//! This module provides factory patterns for creating objective functions used
//! in gradient boosting. It supports both regression and binary classification
//! tasks with configurable parameters.
//!
//! # Architecture
//!
//! - [`ObjectiveFactory`]: Static factory for creating objective instances
//! - [`ObjectiveBuilder`]: Fluent builder for configuring objectives
//! - [`ObjectiveRegistry`]: Validation and discovery of available objectives
//! - [`ObjectiveType`]: Enum distinguishing regression vs classification
//!
//! # Supported Objectives
//!
//! Objective names are matched case-insensitively by the factory.
//!
//! **Regression:**
//! - `mse` / `mean_squared_error`
//! - `mae` / `mean_absolute_error`
//! - `huber` (with `delta` parameter, default 1.0)
//!
//! **Binary Classification:**
//! - `log_loss` / `binary_crossentropy` / `binary_cross_entropy`
//!   (with `epsilon` parameter, default 1e-15)

27use super::{
28    Objective, ObjectiveError, ObjectiveResult, ObjectiveConfig,
29    RegressionObjective, BinaryClassificationObjective,
30    MSEObjective, MAEObjective, HuberObjective, LogLossObjective
31};
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34
35/// Type of machine learning task for objective selection.
36/// 
37/// Used to categorize objectives and validate compatibility with
38/// gradient booster configuration.
39#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
40pub enum ObjectiveType {
41    /// Regression objectives predict continuous values.
42    Regression,
43    /// Binary classification objectives predict probabilities for two classes.
44    BinaryClassification,
45}
46
47impl std::fmt::Display for ObjectiveType {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        match self {
50            ObjectiveType::Regression => write!(f, "regression"),
51            ObjectiveType::BinaryClassification => write!(f, "binary_classification"),
52        }
53    }
54}
55
56/// Factory for creating objective function instances.
57/// 
58/// `ObjectiveFactory` provides static methods to instantiate objective
59/// functions with appropriate types and parameters. It handles validation
60/// and configuration of objective-specific hyperparameters.
61pub struct ObjectiveFactory;
62
63impl ObjectiveFactory {
64    /// Creates a regression objective function by name.
65    /// 
66    /// # Supported Names
67    /// - `mse`, `mean_squared_error`
68    /// - `mae`, `mean_absolute_error`
69    /// - `huber` (creates with delta=1.0)
70    /// 
71    /// # Parameters
72    /// - `name`: Objective function name (case-insensitive)
73    /// 
74    /// # Returns
75    /// Boxed trait object implementing [`RegressionObjective`]
76    /// 
77    /// # Errors
78    /// - `ObjectiveError::ConfigError` if `name` is not a valid regression objective    
79    pub fn create_regression_objective(name: &str) -> ObjectiveResult<Box<dyn RegressionObjective>> {
80        match name.to_lowercase().as_str() {
81            "mse" | "mean_squared_error" => Ok(Box::new(MSEObjective::new())),
82            "mae" | "mean_absolute_error" => Ok(Box::new(MAEObjective::new())),
83            "huber" => Ok(Box::new(HuberObjective::new(1.0))),
84            _ => Err(ObjectiveError::ConfigError(
85                format!("Unknown regression objective: {}", name)
86            )),
87        }
88    }
89    
90    /// Creates a binary classification objective function by name.
91    /// 
92    /// # Supported Names
93    /// - `log_loss`, `binary_crossentropy`, `binary_cross_entropy`
94    /// 
95    /// # Parameters
96    /// - `name`: Objective function name (case-insensitive)
97    /// 
98    /// # Returns
99    /// Boxed trait object implementing [`BinaryClassificationObjective`]
100    /// 
101    /// # Errors
102    /// - `ObjectiveError::ConfigError` if `name` is not a valid classification objective    
103    pub fn create_binary_classification_objective(name: &str) -> ObjectiveResult<Box<dyn BinaryClassificationObjective>> {
104        match name.to_lowercase().as_str() {
105            "log_loss" | "binary_crossentropy" | "binary_cross_entropy" => {
106                Ok(Box::new(LogLossObjective::new()))
107            },
108            _ => Err(ObjectiveError::ConfigError(
109                format!("Unknown binary classification objective: {}", name)
110            )),
111        }
112    }
113    
114    /// Creates an objective from a configuration struct.
115    /// 
116    /// # Parameters
117    /// - `config`: [`ObjectiveConfig`] with name and parameters
118    /// 
119    /// # Returns
120    /// Boxed trait object implementing [`Objective`]
121    /// 
122    /// # Supported Configurations
123    /// - `mse`: No parameters
124    /// - `mae`: No parameters
125    /// - `huber`: `delta` parameter (default: 1.0)
126    /// - `log_loss`: `epsilon` parameter (default: 1e-15)
127    /// 
128    /// # Errors
129    /// - `ObjectiveError::ConfigError` if objective name is unknown
130    /// - `ObjectiveError::ConfigError` if parameters are invalid    
131    pub fn create_objective_from_config(config: &ObjectiveConfig) -> ObjectiveResult<Box<dyn Objective>> {
132        let name = config.name.to_lowercase();
133        
134        match name.as_str() {
135            // Regression objectives
136            "mse" | "mean_squared_error" => Ok(Box::new(MSEObjective::new())),
137            "mae" | "mean_absolute_error" => Ok(Box::new(MAEObjective::new())),
138            "huber" => {
139                let delta = config.get_param("delta").unwrap_or(1.0);
140                Ok(Box::new(HuberObjective::new(delta)))
141            },
142            
143            // Binary classification objectives
144            "log_loss" | "binary_crossentropy" | "binary_cross_entropy" => {
145                let epsilon = config.get_param("epsilon").unwrap_or(1e-15);
146                Ok(Box::new(LogLossObjective::new().with_epsilon(epsilon)))
147            },
148            
149            _ => Err(ObjectiveError::ConfigError(
150                format!("Unknown objective: {}", config.name)
151            )),
152        }
153    }
154    
155    /// Determines the task type for a given objective name.
156    /// 
157    /// # Parameters
158    /// - `name`: Objective name (case-insensitive)
159    /// 
160    /// # Returns
161    /// [`ObjectiveType`] indicating regression or classification
162    /// 
163    /// # Errors
164    /// - `ObjectiveError::ConfigError` if objective name is unknown    
165    pub fn get_objective_type(name: &str) -> ObjectiveResult<ObjectiveType> {
166        match name.to_lowercase().as_str() {
167            "mse" | "mean_squared_error" | "mae" | "mean_absolute_error" | "huber" => {
168                Ok(ObjectiveType::Regression)
169            },
170            "log_loss" | "binary_crossentropy" | "binary_cross_entropy" => {
171                Ok(ObjectiveType::BinaryClassification)
172            },
173            _ => Err(ObjectiveError::ConfigError(
174                format!("Unknown objective: {}", name)
175            )),
176        }
177    }
178    
179    /// Returns a list of all available objective names.
180    /// 
181    /// # Returns
182    /// Vector of lowercase objective names
183    pub fn available_objectives() -> Vec<&'static str> {
184        vec![
185            "mse",
186            "mae", 
187            "huber",
188            "log_loss",
189        ]
190    }
191    
192    /// Gets default parameters for an objective.
193    /// 
194    /// # Parameters
195    /// - `name`: Objective name (case-insensitive)
196    /// 
197    /// # Returns
198    /// HashMap of parameter names to default values
199    /// 
200    /// # Errors
201    /// - `ObjectiveError::ConfigError` if objective name is unknown
202    pub fn get_default_params(name: &str) -> ObjectiveResult<HashMap<String, f64>> {
203        let mut params = HashMap::new();
204        
205        match name.to_lowercase().as_str() {
206            "huber" => {
207                params.insert("delta".to_string(), 1.0);
208            },
209            "log_loss" => {
210                params.insert("epsilon".to_string(), 1e-15);
211            },
212            "mse" | "mae" => {
213                // No parameters for these
214            },
215            _ => return Err(ObjectiveError::ConfigError(
216                format!("Unknown objective: {}", name)
217            )),
218        }
219        
220        Ok(params)
221    }
222}
223
224/// Convenience function to create an objective by name with default parameters.
225/// 
226/// # Parameters
227/// - `name`: Objective name (case-insensitive)
228/// 
229/// # Returns
230/// Boxed trait object implementing [`Objective`]
231pub fn create_objective(name: &str) -> ObjectiveResult<Box<dyn Objective>> {
232    let config = ObjectiveConfig::new(name);
233    ObjectiveFactory::create_objective_from_config(&config)
234}
235
236/// Fluent builder for constructing [`ObjectiveConfig`].
237/// 
238/// Provides a convenient API for configuring objective functions with parameters.
239pub struct ObjectiveBuilder {
240    name: String,
241    params: HashMap<String, f64>,
242}
243
244impl ObjectiveBuilder {
245    /// Creates a new builder for the specified objective.
246    /// 
247    /// # Parameters
248    /// - `name`: Objective name (case-insensitive)
249    pub fn new(name: &str) -> Self {
250        Self {
251            name: name.to_string(),
252            params: HashMap::new(),
253        }
254    }
255    
256    /// Adds a parameter to the objective configuration.
257    /// 
258    /// # Parameters
259    /// - `key`: Parameter name (e.g., "delta" for Huber)
260    /// - `value`: Parameter value
261    pub fn with_param(mut self, key: &str, value: f64) -> Self {
262        self.params.insert(key.to_string(), value);
263        self
264    }
265
266    /// Builds the final ObjectiveConfig.
267    /// 
268    /// # Returns
269    /// Configured ObjectiveConfig ready for factory consumption
270    pub fn build(self) -> ObjectiveConfig {
271        ObjectiveConfig {
272            name: self.name,
273            params: self.params,
274        }
275    }
276}
277
278/// Registry for objective discovery and validation.
279/// 
280/// Provides utilities for checking objective availability, exploring options,
281/// and validating configurations before instantiation.
282pub struct ObjectiveRegistry;
283
284impl ObjectiveRegistry {
285    /// Checks if an objective name is valid and supported.
286    /// 
287    /// # Parameters
288    /// - `name`: Objective name (case-insensitive)
289    /// 
290    /// # Returns
291    /// `true` if objective exists, `false` otherwise
292    pub fn is_valid_objective(name: &str) -> bool {
293        ObjectiveFactory::available_objectives().contains(&name)
294    }
295    
296    /// Returns all regression objective names.
297    pub fn regression_objectives() -> Vec<&'static str> {
298        vec!["mse", "mae", "huber"]
299    }
300    
301    /// Returns all classification objective names.
302    pub fn classification_objectives() -> Vec<&'static str> {
303        vec!["log_loss"]
304    }
305    
306    /// Validates an objective configuration before instantiation.
307    /// 
308    /// Checks that:
309    /// - Objective name is valid
310    /// - All parameters are recognized
311    /// - Parameter values are valid (positive for numerical params)
312    /// 
313    /// # Parameters
314    /// - `config`: ObjectiveConfig to validate
315    /// 
316    /// # Returns
317    /// `Ok(())` if configuration is valid
318    /// 
319    /// # Errors
320    /// - `ObjectiveError::ConfigError` if objective name is unknown
321    /// - `ObjectiveError::ConfigError` if parameters are invalid 
322    pub fn validate_config(config: &ObjectiveConfig) -> ObjectiveResult<()> {
323        if !Self::is_valid_objective(&config.name) {
324            return Err(ObjectiveError::ConfigError(
325                format!("Invalid objective name: {}", config.name)
326            ));
327        }
328        
329        // Validate parameters
330        let default_params = ObjectiveFactory::get_default_params(&config.name)?;
331        for (key, value) in &config.params {
332            if !default_params.contains_key(key) {
333                return Err(ObjectiveError::ConfigError(
334                    format!("Unknown parameter '{}' for objective '{}'", key, config.name)
335                ));
336            }
337            
338            // Validate parameter values
339            match key.as_str() {
340                "delta" | "alpha" | "epsilon" => {
341                    if *value <= 0.0 {
342                        return Err(ObjectiveError::ConfigError(
343                            format!("Parameter '{}' must be positive, got {}", key, value)
344                        ));
345                    }
346                },
347                _ => {}
348            }
349        }
350        
351        Ok(())
352    }
353}
354