use std::collections::HashMap;
use crate::dataset::BinnedDataset;
use crate::Result;
/// A single hyperparameter value: either a number or a named category.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
    /// A numeric value stored as an `f32`.
    Numeric(f32),
    /// A categorical value stored as its string label.
    Categorical(String),
}

impl ParamValue {
    /// Returns the wrapped number, or `0.0` when the value is categorical.
    ///
    /// The `0.0` fallback cannot be told apart from a genuine `Numeric(0.0)`;
    /// check [`ParamValue::is_numeric`] first when the distinction matters.
    pub fn as_numeric(&self) -> f32 {
        match *self {
            Self::Numeric(number) => number,
            Self::Categorical(_) => 0.0,
        }
    }

    /// Returns the category label, or `None` when the value is numeric.
    pub fn as_categorical(&self) -> Option<&str> {
        if let Self::Categorical(label) = self {
            Some(label.as_str())
        } else {
            None
        }
    }

    /// Returns `true` for the [`ParamValue::Numeric`] variant.
    pub fn is_numeric(&self) -> bool {
        match self {
            Self::Numeric(_) => true,
            Self::Categorical(_) => false,
        }
    }

    /// Returns `true` for the [`ParamValue::Categorical`] variant.
    pub fn is_categorical(&self) -> bool {
        match self {
            Self::Categorical(_) => true,
            Self::Numeric(_) => false,
        }
    }
}
impl From<f32> for ParamValue {
fn from(v: f32) -> Self {
Self::Numeric(v)
}
}
impl From<String> for ParamValue {
fn from(s: String) -> Self {
Self::Categorical(s)
}
}
impl From<&str> for ParamValue {
fn from(s: &str) -> Self {
Self::Categorical(s.to_string())
}
}
/// Contract a model type implements so it can be trained, queried and
/// reconfigured from a map of named [`ParamValue`]s.
///
/// `train`, `predict`, `num_trees`, `apply_params`, `valid_params` and
/// `default_config` are required. Every other method ships with a default
/// that either does nothing, returns a fixed value, or reports that the
/// feature is unsupported.
pub trait TunableModel: Clone + Send + Sync + Sized {
    /// Configuration type accepted by the training entry points.
    type Config: Clone + Send + Sync;

    /// Builds a model from `dataset` using `config`.
    ///
    /// # Errors
    /// Implementor-defined.
    fn train(dataset: &BinnedDataset, config: &Self::Config) -> Result<Self>;

    /// Builds a model with access to a separate validation set.
    ///
    /// The default implementation ignores `val_data` and `val_targets` and
    /// simply forwards to [`TunableModel::train`].
    ///
    /// # Errors
    /// Whatever the underlying `train` call returns.
    fn train_with_validation(
        train_data: &BinnedDataset,
        val_data: &BinnedDataset,
        val_targets: &[f32],
        config: &Self::Config,
    ) -> Result<Self> {
        // Explicitly discard the validation inputs: the default has no use for them.
        let _ = val_data;
        let _ = val_targets;
        Self::train(train_data, config)
    }

    /// Returns the model's predictions for `dataset`.
    fn predict(&self, dataset: &BinnedDataset) -> Vec<f32>;

    /// Returns the number of trees in the model.
    fn num_trees(&self) -> usize;

    /// Applies the entries of `params` to `config`.
    fn apply_params(config: &mut Self::Config, params: &HashMap<String, ParamValue>);

    /// Returns the parameter names this model type accepts
    /// (presumably the keys `apply_params` understands — confirm with implementors).
    fn valid_params() -> &'static [&'static str];

    /// Returns the starting configuration for this model type.
    fn default_config() -> Self::Config;

    /// Reports whether `config` targets a GPU. Defaults to `false`.
    fn is_gpu_config(config: &Self::Config) -> bool {
        let _ = config;
        false
    }

    /// Returns the learning rate held by `config`. Defaults to `0.1`
    /// regardless of the configuration passed in.
    fn get_learning_rate(config: &Self::Config) -> f32 {
        let _ = config;
        0.1
    }

    /// Hook for setting validation split and early-stopping options on
    /// `config`. The default is a no-op.
    fn configure_validation(
        config: &mut Self::Config,
        validation_ratio: f32,
        early_stopping_rounds: usize,
    ) {
        let _ = config;
        let _ = validation_ratio;
        let _ = early_stopping_rounds;
    }

    /// Hook for setting the number of rounds on `config`. The default is a no-op.
    fn set_num_rounds(config: &mut Self::Config, num_rounds: usize) {
        let _ = config;
        let _ = num_rounds;
    }

    /// Hook for writing the model to `path` in rkyv format.
    ///
    /// # Errors
    /// The default always returns a `TreeBoostError::Config` stating that
    /// serialization is not supported.
    fn save_rkyv(&self, path: &std::path::Path) -> Result<()> {
        let _ = path;
        serialization_unsupported()
    }

    /// Hook for writing the model to `path` in bincode format.
    ///
    /// # Errors
    /// The default always returns a `TreeBoostError::Config` stating that
    /// serialization is not supported.
    fn save_bincode(&self, path: &std::path::Path) -> Result<()> {
        let _ = path;
        serialization_unsupported()
    }

    /// Reports whether this model type supports conformal prediction.
    /// Defaults to `false`.
    fn supports_conformal() -> bool {
        false
    }

    /// Returns the model's conformal quantile, if it has one.
    /// Defaults to `None`.
    fn conformal_quantile(&self) -> Option<f32> {
        None
    }

    /// Hook for setting conformal-calibration options on `config`.
    /// The default is a no-op.
    fn configure_conformal(config: &mut Self::Config, calibration_ratio: f32, quantile: f32) {
        let _ = config;
        let _ = calibration_ratio;
        let _ = quantile;
    }
}

/// Builds the error shared by the default `save_*` implementations of
/// [`TunableModel`].
fn serialization_unsupported() -> Result<()> {
    Err(crate::TreeBoostError::Config(
        "Model serialization not supported for this model type".to_string(),
    ))
}
/// Extension methods that turn a map of raw numeric parameters into a map of
/// typed [`ParamValue`]s.
pub trait ParamMapExt {
    /// Converts the map into `ParamValue`s without consulting a parameter space.
    fn to_param_values(&self) -> HashMap<String, ParamValue>;

    /// Converts the map into `ParamValue`s, using `space` to decide how each
    /// entry is represented.
    fn to_param_values_with_space(
        &self,
        space: &super::config::ParameterSpace,
    ) -> HashMap<String, ParamValue>;
}
impl ParamMapExt for HashMap<String, f32> {
    /// Wraps every entry as [`ParamValue::Numeric`], keeping the same keys.
    fn to_param_values(&self) -> HashMap<String, ParamValue> {
        let mut converted = HashMap::with_capacity(self.len());
        for (name, &raw) in self {
            converted.insert(name.clone(), ParamValue::Numeric(raw));
        }
        converted
    }

    /// Converts each entry to a [`ParamValue`], consulting `space` by name.
    ///
    /// An entry becomes [`ParamValue::Categorical`] only when `space` knows
    /// the parameter *and* its bounds yield a category for the entry's value
    /// used as an index. Every other entry, including names missing from
    /// `space`, stays [`ParamValue::Numeric`].
    fn to_param_values_with_space(
        &self,
        space: &super::config::ParameterSpace,
    ) -> HashMap<String, ParamValue> {
        let mut converted = HashMap::with_capacity(self.len());
        for (name, &raw) in self {
            let value = match space.get(name) {
                // The `as usize` cast truncates toward zero and saturates:
                // negative values and NaN both become index 0.
                Some(param_def) => match param_def.bounds.get_categorical_value(raw as usize) {
                    Some(cat_value) => ParamValue::Categorical(cat_value.to_string()),
                    None => ParamValue::Numeric(raw),
                },
                None => ParamValue::Numeric(raw),
            };
            converted.insert(name.clone(), value);
        }
        converted
    }
}
/// Unit tests for [`ParamValue`] and the [`ParamMapExt`] conversions.
#[cfg(test)]
mod tests {
    use super::*;

    /// A numeric value reports itself as numeric and round-trips its payload.
    #[test]
    fn test_param_value_numeric() {
        // 2.5 rather than 3.14: literals that approximate PI trip Clippy's
        // deny-by-default `approx_constant` lint.
        let v = ParamValue::Numeric(2.5);
        assert!(v.is_numeric());
        assert!(!v.is_categorical());
        assert!((v.as_numeric() - 2.5).abs() < 1e-6);
        assert!(v.as_categorical().is_none());
    }

    /// A categorical value exposes its label and falls back to 0.0 numerically.
    #[test]
    fn test_param_value_categorical() {
        let v = ParamValue::Categorical("PureTree".to_string());
        assert!(v.is_categorical());
        assert!(!v.is_numeric());
        assert_eq!(v.as_categorical(), Some("PureTree"));
        assert!((v.as_numeric() - 0.0).abs() < 1e-6);
    }

    /// `From` conversions pick the variant matching the source type.
    #[test]
    fn test_param_value_from() {
        let v1: ParamValue = 2.5f32.into();
        assert!(v1.is_numeric());
        let v2: ParamValue = "LinearThenTree".into();
        assert!(v2.is_categorical());
        let v3: ParamValue = String::from("RandomForest").into();
        assert!(v3.is_categorical());
    }

    /// Without a parameter space every entry is converted to a numeric value.
    #[test]
    fn test_param_map_ext() {
        let mut map = HashMap::new();
        map.insert("learning_rate".to_string(), 0.1f32);
        map.insert("max_depth".to_string(), 6.0f32);
        let param_values = map.to_param_values();
        assert_eq!(param_values.len(), 2);
        assert!(param_values["learning_rate"].is_numeric());
        assert!((param_values["learning_rate"].as_numeric() - 0.1).abs() < 1e-6);
    }

    /// With a parameter space, categorical parameters are decoded from their
    /// index while continuous parameters stay numeric.
    #[test]
    fn test_param_map_ext_with_space() {
        use crate::tuner::config::{ParamBounds, ParameterSpace};
        let space = ParameterSpace::new()
            .with_param("learning_rate", ParamBounds::continuous(0.01, 0.5), 0.1)
            .with_param(
                "mode",
                ParamBounds::categorical(vec![
                    "PureTree".to_string(),
                    "LinearThenTree".to_string(),
                    "RandomForest".to_string(),
                ]),
                0.0,
            );
        let mut params = HashMap::new();
        params.insert("learning_rate".to_string(), 0.15f32);
        // Index 1 selects the second category, "LinearThenTree".
        params.insert("mode".to_string(), 1.0f32);
        let param_values = params.to_param_values_with_space(&space);
        assert!(param_values["learning_rate"].is_numeric());
        assert!((param_values["learning_rate"].as_numeric() - 0.15).abs() < 1e-6);
        assert!(param_values["mode"].is_categorical());
        assert_eq!(
            param_values["mode"].as_categorical(),
            Some("LinearThenTree")
        );
    }

    /// Parameters missing from the space are passed through as numeric values.
    #[test]
    fn test_param_map_ext_with_space_unknown_param() {
        use crate::tuner::config::{ParamBounds, ParameterSpace};
        let space = ParameterSpace::new().with_param(
            "learning_rate",
            ParamBounds::continuous(0.01, 0.5),
            0.1,
        );
        let mut params = HashMap::new();
        params.insert("learning_rate".to_string(), 0.15f32);
        params.insert("unknown_param".to_string(), 42.0f32);
        let param_values = params.to_param_values_with_space(&space);
        assert!(param_values["learning_rate"].is_numeric());
        assert!(param_values["unknown_param"].is_numeric());
        assert!((param_values["unknown_param"].as_numeric() - 42.0).abs() < 1e-6);
    }
}