// rusty_ai/trees/params.rs

use std::error::Error;

/// Parameters for a decision tree.
///
/// The setter methods on this type validate their input. Both fields are
/// public, so values written directly to a field are not validated.
#[derive(Clone, Debug)]
pub struct TreeParams {
    /// Minimum number of samples required to split a node.
    /// [`TreeParams::new`] initialises this to `2`.
    pub min_samples_split: u16,
    /// Maximum depth of the tree. [`TreeParams::new`] initialises this to
    /// `None`; presumably that means "no depth limit" — confirm against the
    /// tree-building code.
    pub max_depth: Option<u16>,
}
10impl Default for TreeParams {
11    /// Creates a new instance of `TreeParams` with default values.
12    fn default() -> Self {
13        Self::new()
14    }
15}
17impl TreeParams {
18    /// Creates a new instance of `TreeParams` with default values.
19    pub fn new() -> Self {
20        Self {
21            min_samples_split: 2,
22            max_depth: None,
23        }
24    }
25
26    /// Sets the minimum number of samples required to split a node.
27    ///
28    /// # Arguments
29    ///
30    /// * `min_samples_split` - The minimum number of samples to split.
31    ///
32    /// # Errors
33    ///
34    /// Returns an error if `min_samples_split` is less than 2.
35    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
36        if min_samples_split < 2 {
37            return Err("The minimum number of samples to split must be greater than 1.".into());
38        }
39        self.min_samples_split = min_samples_split;
40        Ok(())
41    }
42
43    /// Sets the maximum depth of the decision tree.
44    ///
45    /// # Arguments
46    ///
47    /// * `max_depth` - The maximum depth of the tree.
48    ///
49    /// # Errors
50    ///
51    /// Returns an error if `max_depth` is less than 1.
52    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
53        if max_depth.is_some_and(|depth| depth < 1) {
54            return Err("The maximum depth must be greater than 0.".into());
55        }
56        self.max_depth = max_depth;
57        Ok(())
58    }
59
60    /// Returns the minimum number of samples required to split a node.
61    pub fn min_samples_split(&self) -> u16 {
62        self.min_samples_split
63    }
64
65    /// Returns the maximum depth of the decision tree.
66    pub fn max_depth(&self) -> Option<u16> {
67        self.max_depth
68    }
69}
71/// Struct representing the parameters for a decision tree classifier.
72#[derive(Clone, Debug)]
73pub struct TreeClassifierParams {
74    pub base_params: TreeParams,
75    pub criterion: String,
76}
78impl Default for TreeClassifierParams {
79    /// Creates a new instance of `TreeClassifierParams` with default values.
80    fn default() -> Self {
81        Self::new()
82    }
83}
85impl TreeClassifierParams {
86    /// Creates a new instance of `TreeClassifierParams` with default values.
87    pub fn new() -> Self {
88        Self {
89            base_params: TreeParams::new(),
90            criterion: "gini".to_string(),
91        }
92    }
93
94    /// Sets the minimum number of samples required to split a node.
95    ///
96    /// # Arguments
97    ///
98    /// * `min_samples_split` - The minimum number of samples to split.
99    ///
100    /// # Errors
101    ///
102    /// Returns an error if `min_samples_split` is less than 2.
103    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
104        self.base_params.set_min_samples_split(min_samples_split)
105    }
106
107    /// Sets the maximum depth of the decision tree.
108    ///
109    /// # Arguments
110    ///
111    /// * `max_depth` - The maximum depth of the tree.
112    ///
113    /// # Errors
114    ///
115    /// Returns an error if `max_depth` is less than 1.
116    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
117        self.base_params.set_max_depth(max_depth)
118    }
119
120    /// Sets the criterion used for splitting nodes in the decision tree.
121    ///
122    /// # Arguments
123    ///
124    /// * `criterion` - The criterion for splitting nodes.
125    ///
126    /// # Errors
127    ///
128    /// Returns an error if `criterion` is not "gini" or "entropy".
129    pub fn set_criterion(&mut self, criterion: String) -> Result<(), Box<dyn Error>> {
130        if !["gini", "entropy"].contains(&criterion.as_str()) {
131            return Err("The criterion must be either 'gini' or 'entropy'.".into());
132        }
133        self.criterion = criterion;
134        Ok(())
135    }
136
137    /// Returns the minimum number of samples required to split a node.
138    pub fn min_samples_split(&self) -> u16 {
139        self.base_params.min_samples_split
140    }
141
142    /// Returns the maximum depth of the decision tree.
143    pub fn max_depth(&self) -> Option<u16> {
144        self.base_params.max_depth
145    }
146
147    /// Returns the criterion used for splitting nodes in the decision tree.
148    pub fn criterion(&self) -> &str {
149        &self.criterion
150    }
151}