rusty_ai/trees/params.rs
1use std::error::Error;
2
/// Hyperparameters common to decision trees.
#[derive(Clone, Debug)]
pub struct TreeParams {
    /// Fewest samples a node must contain before it may be split.
    pub min_samples_split: u16,
    /// Optional cap on the tree depth; `None` means no cap has been configured.
    pub max_depth: Option<u16>,
}

impl Default for TreeParams {
    /// Builds the default configuration: `min_samples_split` of 2 and no `max_depth`.
    fn default() -> Self {
        Self {
            min_samples_split: 2,
            max_depth: None,
        }
    }
}

impl TreeParams {
    /// Builds a `TreeParams` holding the default values (same result as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Updates the minimum number of samples a node needs in order to be split.
    ///
    /// # Arguments
    ///
    /// * `min_samples_split` - The new minimum; must be at least 2.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving the stored value untouched, when `min_samples_split`
    /// is 0 or 1.
    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
        match min_samples_split {
            0 | 1 => Err("The minimum number of samples to split must be greater than 1.".into()),
            accepted => {
                self.min_samples_split = accepted;
                Ok(())
            }
        }
    }

    /// Updates the maximum depth of the decision tree.
    ///
    /// # Arguments
    ///
    /// * `max_depth` - The new depth cap, or `None` to store no cap.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving the stored value untouched, when `max_depth`
    /// is `Some(0)`.
    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
        // `None` and every positive depth are accepted; only an explicit zero is rejected.
        if matches!(max_depth, Some(0)) {
            return Err("The maximum depth must be greater than 0.".into());
        }
        self.max_depth = max_depth;
        Ok(())
    }

    /// Reads back the configured minimum number of samples required to split a node.
    pub fn min_samples_split(&self) -> u16 {
        self.min_samples_split
    }

    /// Reads back the configured maximum depth, if any.
    pub fn max_depth(&self) -> Option<u16> {
        self.max_depth
    }
}
70
71/// Struct representing the parameters for a decision tree classifier.
72#[derive(Clone, Debug)]
73pub struct TreeClassifierParams {
74 pub base_params: TreeParams,
75 pub criterion: String,
76}
77
78impl Default for TreeClassifierParams {
79 /// Creates a new instance of `TreeClassifierParams` with default values.
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl TreeClassifierParams {
86 /// Creates a new instance of `TreeClassifierParams` with default values.
87 pub fn new() -> Self {
88 Self {
89 base_params: TreeParams::new(),
90 criterion: "gini".to_string(),
91 }
92 }
93
94 /// Sets the minimum number of samples required to split a node.
95 ///
96 /// # Arguments
97 ///
98 /// * `min_samples_split` - The minimum number of samples to split.
99 ///
100 /// # Errors
101 ///
102 /// Returns an error if `min_samples_split` is less than 2.
103 pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
104 self.base_params.set_min_samples_split(min_samples_split)
105 }
106
107 /// Sets the maximum depth of the decision tree.
108 ///
109 /// # Arguments
110 ///
111 /// * `max_depth` - The maximum depth of the tree.
112 ///
113 /// # Errors
114 ///
115 /// Returns an error if `max_depth` is less than 1.
116 pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
117 self.base_params.set_max_depth(max_depth)
118 }
119
120 /// Sets the criterion used for splitting nodes in the decision tree.
121 ///
122 /// # Arguments
123 ///
124 /// * `criterion` - The criterion for splitting nodes.
125 ///
126 /// # Errors
127 ///
128 /// Returns an error if `criterion` is not "gini" or "entropy".
129 pub fn set_criterion(&mut self, criterion: String) -> Result<(), Box<dyn Error>> {
130 if !["gini", "entropy"].contains(&criterion.as_str()) {
131 return Err("The criterion must be either 'gini' or 'entropy'.".into());
132 }
133 self.criterion = criterion;
134 Ok(())
135 }
136
137 /// Returns the minimum number of samples required to split a node.
138 pub fn min_samples_split(&self) -> u16 {
139 self.base_params.min_samples_split
140 }
141
142 /// Returns the maximum depth of the decision tree.
143 pub fn max_depth(&self) -> Option<u16> {
144 self.base_params.max_depth
145 }
146
147 /// Returns the criterion used for splitting nodes in the decision tree.
148 pub fn criterion(&self) -> &str {
149 &self.criterion
150 }
151}