1use crate::data::FloatData;
6use crate::splitter::{MissingInfo, NodeInfo, SplitInfo};
7use crate::utils::is_missing;
8use serde::de::{self, Visitor};
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10use std::cmp::Ordering;
11use std::fmt::{self, Debug, Write};
12
13#[derive(Debug, Deserialize, Serialize)]
14pub struct SplittableNode {
15 pub num: usize,
16 pub weight_value: f32,
17 pub gain_value: f32,
18 pub gradient_sum: f32,
19 pub hessian_sum: f32,
20 pub split_value: f64,
21 pub split_feature: usize,
22 pub split_gain: f32,
23 pub missing_node: usize,
24 pub left_child: usize,
25 pub right_child: usize,
26 pub start_idx: usize,
27 pub stop_idx: usize,
28 pub lower_bound: f32,
29 pub upper_bound: f32,
30 pub is_leaf: bool,
31 pub is_missing_leaf: bool,
32 pub parent_node: usize,
33 #[allow(clippy::box_collection)]
34 #[serde(serialize_with = "serialize_left_cats", deserialize_with = "deserialize_left_cats")]
35 pub left_cats: Option<Box<[u8]>>,
36 pub stats: Option<Box<NodeStats>>,
37}
38
39#[derive(Deserialize, Serialize, Clone, Debug)]
41pub struct NodeStats {
42 pub depth: usize,
43 pub node_type: NodeType,
44 pub count: usize,
45 pub generalization: Option<f32>,
46 pub weights: [f32; 5],
47}
48
49#[derive(Deserialize, Serialize, Clone, Debug)]
50pub struct Node {
51 pub num: usize,
52 pub weight_value: f32,
53 pub hessian_sum: f32,
54 pub split_value: f64,
55 pub split_feature: usize,
56 pub split_gain: f32,
57 pub missing_node: usize,
58 pub left_child: usize,
59 pub right_child: usize,
60 pub is_leaf: bool,
61 pub parent_node: usize,
62 #[allow(clippy::box_collection)]
63 #[serde(serialize_with = "serialize_left_cats", deserialize_with = "deserialize_left_cats")]
64 pub left_cats: Option<Box<[u8]>>,
65 pub stats: Option<Box<NodeStats>>,
66}
67
68impl Ord for SplittableNode {
69 fn cmp(&self, other: &Self) -> Ordering {
70 self.gain_value.total_cmp(&other.gain_value)
71 }
72}
73
impl PartialOrd for SplittableNode {
    // Delegates to `Ord::cmp` so the partial order always agrees with the total order
    // (never returns `None`).
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
79
80impl PartialEq for SplittableNode {
81 fn eq(&self, other: &Self) -> bool {
82 self.gain_value == other.gain_value
83 }
84}
85
86impl Eq for SplittableNode {}
87
88impl Node {
89 pub fn make_parent_node(&mut self, split_node: SplittableNode, eta: f32) {
92 self.is_leaf = false;
93 self.missing_node = split_node.missing_node;
94 self.split_value = split_node.split_value;
95 self.split_feature = split_node.split_feature;
96 self.split_gain = split_node.split_gain;
97 self.left_child = split_node.left_child;
98 self.right_child = split_node.right_child;
99 self.parent_node = split_node.parent_node;
100 self.left_cats = split_node.left_cats;
101 if let (Some(stats), Some(sn_stats)) = (&mut self.stats, split_node.stats) {
103 stats.generalization = sn_stats.generalization;
104 stats.weights = sn_stats.weights.map(|x| x * eta);
105 }
106 }
107 pub fn get_child_idx(&self, v: &f64, missing: &f64) -> usize {
109 if is_missing(v, missing) {
111 return self.missing_node;
112 }
113
114 if let Some(left_cats) = &self.left_cats {
116 let cat_idx = *v as usize;
117 let byte_idx = cat_idx >> 3;
118 let bit_idx = cat_idx & 7;
119 if let Some(&byte) = left_cats.get(byte_idx) {
120 if (byte >> bit_idx) & 1 == 1 {
121 return self.left_child;
122 } else {
123 return self.right_child;
124 }
125 } else {
126 return self.right_child;
127 }
128 }
129
130 if v < &self.split_value {
132 self.left_child
133 } else {
134 self.right_child
135 }
136 }
137
138 pub fn has_missing_branch(&self) -> bool {
139 (self.missing_node != self.right_child) && (self.missing_node != self.left_child)
140 }
141}
142
143#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq)]
144pub enum NodeType {
145 Root,
146 Left,
147 Right,
148 Missing,
149}
150
151impl SplittableNode {
152 #[allow(clippy::too_many_arguments)]
153 pub fn from_node_info(
154 num: usize,
155 depth: usize,
156 start_idx: usize,
157 stop_idx: usize,
158 node_info: &NodeInfo,
159 generalization: Option<f32>,
160 node_type: NodeType,
161 parent_node: usize,
162 ) -> Self {
163 SplittableNode {
164 num,
165 weight_value: node_info.weight,
166 gain_value: node_info.gain,
167 gradient_sum: node_info.grad,
168 hessian_sum: node_info.cover,
169 split_value: f64::ZERO,
170 split_feature: 0,
171 split_gain: f32::ZERO,
172 missing_node: 0,
173 left_child: 0,
174 right_child: 0,
175 start_idx,
176 stop_idx,
177 lower_bound: node_info.bounds.0,
178 upper_bound: node_info.bounds.1,
179 is_leaf: true,
180 is_missing_leaf: false,
181 parent_node,
182 left_cats: None,
183 stats: Some(Box::new(NodeStats {
184 depth,
185 node_type,
186 count: node_info.counts,
187 generalization,
188 weights: node_info.weights,
189 })),
190 }
191 }
192
193 #[allow(clippy::too_many_arguments)]
196 #[allow(clippy::box_collection)]
197 pub fn new(
198 num: usize,
199 weight_value: f32,
200 gain_value: f32,
201 gradient_sum: f32,
202 hessian_sum: f32,
203 counts_sum: usize,
204 depth: usize,
205 start_idx: usize,
206 stop_idx: usize,
207 lower_bound: f32,
208 upper_bound: f32,
209 node_type: NodeType,
210 left_cats: Option<Box<[u8]>>,
211 weights: [f32; 5],
212 ) -> Self {
213 SplittableNode {
214 num,
215 weight_value,
216 gain_value,
217 gradient_sum,
218 hessian_sum,
219 split_value: f64::ZERO,
220 split_feature: 0,
221 split_gain: f32::ZERO,
222 missing_node: 0,
223 left_child: 0,
224 right_child: 0,
225 start_idx,
226 stop_idx,
227 lower_bound,
228 upper_bound,
229 is_leaf: true,
230 is_missing_leaf: false,
231 parent_node: 0,
232 left_cats,
233 stats: Some(Box::new(NodeStats {
234 depth,
235 node_type,
236 count: counts_sum,
237 generalization: None,
238 weights,
239 })),
240 }
241 }
242
243 pub fn update_children(
244 &mut self,
245 missing_child: usize,
246 left_child: usize,
247 right_child: usize,
248 split_info: &SplitInfo,
249 ) {
250 self.left_child = left_child;
251 self.right_child = right_child;
252 self.split_feature = split_info.split_feature;
253 self.split_gain = self.get_split_gain(&split_info.left_node, &split_info.right_node, &split_info.missing_node);
254 self.split_value = split_info.split_value;
255 self.missing_node = missing_child;
256 self.is_leaf = false;
257 self.left_cats = split_info.left_cats.as_ref().map(|bitset| {
258 let mut max_byte = 0;
259 for (i, &b) in bitset.iter().enumerate() {
260 if b != 0 {
261 max_byte = i;
262 }
263 }
264 bitset[..=max_byte].to_vec().into_boxed_slice()
265 });
266 }
267
268 pub fn get_split_gain(
269 &self,
270 left_node_info: &NodeInfo,
271 right_node_info: &NodeInfo,
272 missing_node_info: &MissingInfo,
273 ) -> f32 {
274 let missing_split_gain = match &missing_node_info {
275 MissingInfo::Branch(ni) | MissingInfo::Leaf(ni) => ni.gain,
276 _ => 0.,
277 };
278 left_node_info.gain + right_node_info.gain + missing_split_gain - self.gain_value
279 }
280
281 pub fn as_node(&self, eta: f32, save_node_stats: bool) -> Node {
282 Node {
283 num: self.num,
284 weight_value: self.weight_value * eta,
285 hessian_sum: self.hessian_sum,
286 missing_node: self.missing_node,
287 split_value: self.split_value,
288 split_feature: self.split_feature,
289 split_gain: self.split_gain,
290 left_child: self.left_child,
291 right_child: self.right_child,
292 is_leaf: self.is_leaf,
293 parent_node: self.parent_node,
294 left_cats: self.left_cats.clone(),
295 stats: if save_node_stats {
296 if let Some(s) = &self.stats {
297 let mut stats = s.clone();
298 stats.weights = stats.weights.map(|x| x * eta);
299 Some(stats)
300 } else {
301 None
302 }
303 } else {
304 None
305 },
306 }
307 }
308}
309
310impl fmt::Display for Node {
311 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
313 if self.is_leaf {
314 write!(f, "{}:leaf={},cover={}", self.num, self.weight_value, self.hessian_sum)
315 } else {
316 write!(
317 f,
318 "{}:[{} < {}] yes={},no={},missing={},gain={},cover={}",
319 self.num,
320 self.split_feature,
321 self.split_value,
322 self.left_child,
323 self.right_child,
324 self.missing_node,
325 self.split_gain,
326 self.hessian_sum
327 )
328 }
329 }
330}
331
332pub fn serialize_left_cats<S>(left_cats: &Option<Box<[u8]>>, serializer: S) -> Result<S::Ok, S::Error>
333where
334 S: Serializer,
335{
336 match left_cats {
337 Some(bytes) => {
338 let mut s = String::with_capacity(bytes.len() * 2);
339 for &b in bytes.as_ref() {
340 write!(&mut s, "{:02x}", b).map_err(serde::ser::Error::custom)?;
341 }
342 serializer.serialize_str(&s)
343 }
344 None => serializer.serialize_none(),
345 }
346}
347
348pub fn deserialize_left_cats<'de, D>(deserializer: D) -> Result<Option<Box<[u8]>>, D::Error>
349where
350 D: Deserializer<'de>,
351{
352 struct LeftCatsVisitor;
353
354 impl<'de> Visitor<'de> for LeftCatsVisitor {
355 type Value = Option<Box<[u8]>>;
356
357 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
358 formatter.write_str("a hex string, an array of bytes, or null")
359 }
360
361 fn visit_none<E>(self) -> Result<Self::Value, E>
362 where
363 E: de::Error,
364 {
365 Ok(None)
366 }
367
368 fn visit_unit<E>(self) -> Result<Self::Value, E>
369 where
370 E: de::Error,
371 {
372 Ok(None)
373 }
374
375 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
376 where
377 E: de::Error,
378 {
379 if !v.len().is_multiple_of(2) {
380 return Err(de::Error::custom("hex string must have even length"));
381 }
382 let bytes = (0..v.len())
383 .step_by(2)
384 .map(|i| u8::from_str_radix(&v[i..i + 2], 16).map_err(|e| de::Error::custom(e.to_string())))
385 .collect::<Result<Vec<u8>, E>>()?;
386 Ok(Some(bytes.into_boxed_slice()))
387 }
388
389 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
390 where
391 A: de::SeqAccess<'de>,
392 {
393 let mut bytes = Vec::new();
394 while let Some(byte) = seq.next_element()? {
395 bytes.push(byte);
396 }
397 Ok(Some(bytes.into_boxed_slice()))
398 }
399 }
400
401 deserializer.deserialize_any(LeftCatsVisitor)
402}