Skip to main content

perpetual/
node.rs

1//! Node
2//!
3//! Internal structures for representing nodes in a decision tree.
4//! This includes `SplittableNode` used during training and `Node` used for inference.
5use crate::data::FloatData;
6use crate::splitter::{MissingInfo, NodeInfo, SplitInfo};
7use crate::utils::is_missing;
8use serde::de::{self, Visitor};
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10use std::cmp::Ordering;
11use std::fmt::{self, Debug, Write};
12
/// A tree node while the tree is being grown during training.
///
/// Nodes are created as leaves (`SplittableNode::new`, `SplittableNode::from_node_info`),
/// turned into split nodes by `update_children`, and converted to the inference-time
/// [`Node`] by `as_node`. Comparison and ordering use `gain_value` only.
#[derive(Debug, Deserialize, Serialize)]
pub struct SplittableNode {
    /// Identifier of this node within its tree.
    pub num: usize,
    /// Unscaled node weight; `as_node` multiplies it by the learning rate `eta`.
    pub weight_value: f32,
    /// Gain of this node; drives ordering and is subtracted in `get_split_gain`.
    pub gain_value: f32,
    /// Gradient sum (copied from `NodeInfo::grad` by `from_node_info`).
    pub gradient_sum: f32,
    /// Hessian sum (copied from `NodeInfo::cover` by `from_node_info`).
    pub hessian_sum: f32,
    /// Numeric split threshold; zero until the node is split.
    pub split_value: f64,
    /// Index of the feature the node splits on; zero until the node is split.
    pub split_feature: usize,
    /// Gain achieved by the split, as computed by `get_split_gain`.
    pub split_gain: f32,
    /// Index of the child that receives missing values.
    pub missing_node: usize,
    /// Index of the left child.
    pub left_child: usize,
    /// Index of the right child.
    pub right_child: usize,
    /// Start of this node's index range — presumably into the training row-index buffer; TODO confirm.
    pub start_idx: usize,
    /// End of this node's index range (inclusive vs. exclusive not visible here — TODO confirm).
    pub stop_idx: usize,
    /// Lower bound taken from `NodeInfo::bounds.0` — presumably a weight constraint; verify.
    pub lower_bound: f32,
    /// Upper bound taken from `NodeInfo::bounds.1` — presumably a weight constraint; verify.
    pub upper_bound: f32,
    /// `true` from construction until `update_children` splits the node.
    pub is_leaf: bool,
    /// Initialized to `false` by the constructors in this file; set elsewhere.
    pub is_missing_leaf: bool,
    /// Index of the parent node (left at 0 by `new`).
    pub parent_node: usize,
    /// Categorical split bitset: when bit `k % 8` of byte `k / 8` is set, category `k`
    /// goes to the left child. Serialized as a lowercase hex string.
    // NOTE(review): `clippy::box_collection` targets `Box<Vec<_>>`/`Box<String>`-style
    // types; this allow looks unnecessary for `Box<[u8]>`.
    #[allow(clippy::box_collection)]
    #[serde(serialize_with = "serialize_left_cats", deserialize_with = "deserialize_left_cats")]
    pub left_cats: Option<Box<[u8]>>,
    /// Optional per-node statistics (always populated by the constructors in this file).
    pub stats: Option<Box<NodeStats>>,
}
38
/// Statistics stored for each node when save_node_stats is enabled.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct NodeStats {
    /// Depth of the node in its tree.
    pub depth: usize,
    /// Which branch of its parent the node occupies (or `Root`).
    pub node_type: NodeType,
    /// Sample count taken from `NodeInfo::counts` (or `counts_sum` in `SplittableNode::new`).
    pub count: usize,
    /// Generalization score, if one was computed for this node.
    pub generalization: Option<f32>,
    /// Five auxiliary weights; scaled by `eta` when copied into an inference `Node`.
    /// NOTE(review): meaning of each of the five slots is not visible here — confirm.
    pub weights: [f32; 5],
}
48
/// A finished tree node used at inference time.
///
/// Produced from a [`SplittableNode`] by `SplittableNode::as_node`, or updated in
/// place by `Node::make_parent_node`.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct Node {
    /// Identifier of this node within its tree.
    pub num: usize,
    /// Node weight, already multiplied by `eta` when built via `as_node`.
    pub weight_value: f32,
    /// Hessian sum; shown as `cover` by the `Display` impl.
    pub hessian_sum: f32,
    /// Numeric threshold: values `< split_value` go to the left child.
    pub split_value: f64,
    /// Index of the feature the node splits on.
    pub split_feature: usize,
    /// Gain achieved by the split.
    pub split_gain: f32,
    /// Index of the child that receives missing values.
    pub missing_node: usize,
    /// Index of the left child.
    pub left_child: usize,
    /// Index of the right child.
    pub right_child: usize,
    /// Whether this node is a leaf.
    pub is_leaf: bool,
    /// Index of the parent node.
    pub parent_node: usize,
    /// Categorical split bitset: when bit `k % 8` of byte `k / 8` is set, category `k`
    /// goes to the left child; categories past the end go right. Serialized as hex.
    // NOTE(review): `clippy::box_collection` targets `Box<Vec<_>>`/`Box<String>`-style
    // types; this allow looks unnecessary for `Box<[u8]>`.
    #[allow(clippy::box_collection)]
    #[serde(serialize_with = "serialize_left_cats", deserialize_with = "deserialize_left_cats")]
    pub left_cats: Option<Box<[u8]>>,
    /// Optional statistics, kept only when node stats are saved.
    pub stats: Option<Box<NodeStats>>,
}
67
68impl Ord for SplittableNode {
69    fn cmp(&self, other: &Self) -> Ordering {
70        self.gain_value.total_cmp(&other.gain_value)
71    }
72}
73
74impl PartialOrd for SplittableNode {
75    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
76        Some(self.cmp(other))
77    }
78}
79
80impl PartialEq for SplittableNode {
81    fn eq(&self, other: &Self) -> bool {
82        self.gain_value == other.gain_value
83    }
84}
85
86impl Eq for SplittableNode {}
87
88impl Node {
89    /// Update all the info that is needed if this node is a
90    /// parent node, this consumes the SplitableNode.
91    pub fn make_parent_node(&mut self, split_node: SplittableNode, eta: f32) {
92        self.is_leaf = false;
93        self.missing_node = split_node.missing_node;
94        self.split_value = split_node.split_value;
95        self.split_feature = split_node.split_feature;
96        self.split_gain = split_node.split_gain;
97        self.left_child = split_node.left_child;
98        self.right_child = split_node.right_child;
99        self.parent_node = split_node.parent_node;
100        self.left_cats = split_node.left_cats;
101        // If we are keeping stats, update them from the split_node stats.
102        if let (Some(stats), Some(sn_stats)) = (&mut self.stats, split_node.stats) {
103            stats.generalization = sn_stats.generalization;
104            stats.weights = sn_stats.weights.map(|x| x * eta);
105        }
106    }
107    /// Get the path that should be traveled down, given a value.
108    pub fn get_child_idx(&self, v: &f64, missing: &f64) -> usize {
109        // Check for missing values FIRST
110        if is_missing(v, missing) {
111            return self.missing_node;
112        }
113
114        // Then check categorical splits
115        if let Some(left_cats) = &self.left_cats {
116            let cat_idx = *v as usize;
117            let byte_idx = cat_idx >> 3;
118            let bit_idx = cat_idx & 7;
119            if let Some(&byte) = left_cats.get(byte_idx) {
120                if (byte >> bit_idx) & 1 == 1 {
121                    return self.left_child;
122                } else {
123                    return self.right_child;
124                }
125            } else {
126                return self.right_child;
127            }
128        }
129
130        // Finally numerical splits
131        if v < &self.split_value {
132            self.left_child
133        } else {
134            self.right_child
135        }
136    }
137
138    pub fn has_missing_branch(&self) -> bool {
139        (self.missing_node != self.right_child) && (self.missing_node != self.left_child)
140    }
141}
142
/// Position of a node relative to its parent.
#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq)]
pub enum NodeType {
    /// The tree's root node (no parent).
    Root,
    /// A left child of its parent.
    Left,
    /// A right child of its parent.
    Right,
    /// The child that receives missing values.
    Missing,
}
150
151impl SplittableNode {
152    #[allow(clippy::too_many_arguments)]
153    pub fn from_node_info(
154        num: usize,
155        depth: usize,
156        start_idx: usize,
157        stop_idx: usize,
158        node_info: &NodeInfo,
159        generalization: Option<f32>,
160        node_type: NodeType,
161        parent_node: usize,
162    ) -> Self {
163        SplittableNode {
164            num,
165            weight_value: node_info.weight,
166            gain_value: node_info.gain,
167            gradient_sum: node_info.grad,
168            hessian_sum: node_info.cover,
169            split_value: f64::ZERO,
170            split_feature: 0,
171            split_gain: f32::ZERO,
172            missing_node: 0,
173            left_child: 0,
174            right_child: 0,
175            start_idx,
176            stop_idx,
177            lower_bound: node_info.bounds.0,
178            upper_bound: node_info.bounds.1,
179            is_leaf: true,
180            is_missing_leaf: false,
181            parent_node,
182            left_cats: None,
183            stats: Some(Box::new(NodeStats {
184                depth,
185                node_type,
186                count: node_info.counts,
187                generalization,
188                weights: node_info.weights,
189            })),
190        }
191    }
192
193    /// Create a default splitable node,
194    /// we default to the node being a leaf.
195    #[allow(clippy::too_many_arguments)]
196    #[allow(clippy::box_collection)]
197    pub fn new(
198        num: usize,
199        weight_value: f32,
200        gain_value: f32,
201        gradient_sum: f32,
202        hessian_sum: f32,
203        counts_sum: usize,
204        depth: usize,
205        start_idx: usize,
206        stop_idx: usize,
207        lower_bound: f32,
208        upper_bound: f32,
209        node_type: NodeType,
210        left_cats: Option<Box<[u8]>>,
211        weights: [f32; 5],
212    ) -> Self {
213        SplittableNode {
214            num,
215            weight_value,
216            gain_value,
217            gradient_sum,
218            hessian_sum,
219            split_value: f64::ZERO,
220            split_feature: 0,
221            split_gain: f32::ZERO,
222            missing_node: 0,
223            left_child: 0,
224            right_child: 0,
225            start_idx,
226            stop_idx,
227            lower_bound,
228            upper_bound,
229            is_leaf: true,
230            is_missing_leaf: false,
231            parent_node: 0,
232            left_cats,
233            stats: Some(Box::new(NodeStats {
234                depth,
235                node_type,
236                count: counts_sum,
237                generalization: None,
238                weights,
239            })),
240        }
241    }
242
243    pub fn update_children(
244        &mut self,
245        missing_child: usize,
246        left_child: usize,
247        right_child: usize,
248        split_info: &SplitInfo,
249    ) {
250        self.left_child = left_child;
251        self.right_child = right_child;
252        self.split_feature = split_info.split_feature;
253        self.split_gain = self.get_split_gain(&split_info.left_node, &split_info.right_node, &split_info.missing_node);
254        self.split_value = split_info.split_value;
255        self.missing_node = missing_child;
256        self.is_leaf = false;
257        self.left_cats = split_info.left_cats.as_ref().map(|bitset| {
258            let mut max_byte = 0;
259            for (i, &b) in bitset.iter().enumerate() {
260                if b != 0 {
261                    max_byte = i;
262                }
263            }
264            bitset[..=max_byte].to_vec().into_boxed_slice()
265        });
266    }
267
268    pub fn get_split_gain(
269        &self,
270        left_node_info: &NodeInfo,
271        right_node_info: &NodeInfo,
272        missing_node_info: &MissingInfo,
273    ) -> f32 {
274        let missing_split_gain = match &missing_node_info {
275            MissingInfo::Branch(ni) | MissingInfo::Leaf(ni) => ni.gain,
276            _ => 0.,
277        };
278        left_node_info.gain + right_node_info.gain + missing_split_gain - self.gain_value
279    }
280
281    pub fn as_node(&self, eta: f32, save_node_stats: bool) -> Node {
282        Node {
283            num: self.num,
284            weight_value: self.weight_value * eta,
285            hessian_sum: self.hessian_sum,
286            missing_node: self.missing_node,
287            split_value: self.split_value,
288            split_feature: self.split_feature,
289            split_gain: self.split_gain,
290            left_child: self.left_child,
291            right_child: self.right_child,
292            is_leaf: self.is_leaf,
293            parent_node: self.parent_node,
294            left_cats: self.left_cats.clone(),
295            stats: if save_node_stats {
296                if let Some(s) = &self.stats {
297                    let mut stats = s.clone();
298                    stats.weights = stats.weights.map(|x| x * eta);
299                    Some(stats)
300                } else {
301                    None
302                }
303            } else {
304                None
305            },
306        }
307    }
308}
309
310impl fmt::Display for Node {
311    // This trait requires `fmt` with this exact signature.
312    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
313        if self.is_leaf {
314            write!(f, "{}:leaf={},cover={}", self.num, self.weight_value, self.hessian_sum)
315        } else {
316            write!(
317                f,
318                "{}:[{} < {}] yes={},no={},missing={},gain={},cover={}",
319                self.num,
320                self.split_feature,
321                self.split_value,
322                self.left_child,
323                self.right_child,
324                self.missing_node,
325                self.split_gain,
326                self.hessian_sum
327            )
328        }
329    }
330}
331
332pub fn serialize_left_cats<S>(left_cats: &Option<Box<[u8]>>, serializer: S) -> Result<S::Ok, S::Error>
333where
334    S: Serializer,
335{
336    match left_cats {
337        Some(bytes) => {
338            let mut s = String::with_capacity(bytes.len() * 2);
339            for &b in bytes.as_ref() {
340                write!(&mut s, "{:02x}", b).map_err(serde::ser::Error::custom)?;
341            }
342            serializer.serialize_str(&s)
343        }
344        None => serializer.serialize_none(),
345    }
346}
347
348pub fn deserialize_left_cats<'de, D>(deserializer: D) -> Result<Option<Box<[u8]>>, D::Error>
349where
350    D: Deserializer<'de>,
351{
352    struct LeftCatsVisitor;
353
354    impl<'de> Visitor<'de> for LeftCatsVisitor {
355        type Value = Option<Box<[u8]>>;
356
357        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
358            formatter.write_str("a hex string, an array of bytes, or null")
359        }
360
361        fn visit_none<E>(self) -> Result<Self::Value, E>
362        where
363            E: de::Error,
364        {
365            Ok(None)
366        }
367
368        fn visit_unit<E>(self) -> Result<Self::Value, E>
369        where
370            E: de::Error,
371        {
372            Ok(None)
373        }
374
375        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
376        where
377            E: de::Error,
378        {
379            if !v.len().is_multiple_of(2) {
380                return Err(de::Error::custom("hex string must have even length"));
381            }
382            let bytes = (0..v.len())
383                .step_by(2)
384                .map(|i| u8::from_str_radix(&v[i..i + 2], 16).map_err(|e| de::Error::custom(e.to_string())))
385                .collect::<Result<Vec<u8>, E>>()?;
386            Ok(Some(bytes.into_boxed_slice()))
387        }
388
389        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
390        where
391            A: de::SeqAccess<'de>,
392        {
393            let mut bytes = Vec::new();
394            while let Some(byte) = seq.next_element()? {
395                bytes.push(byte);
396            }
397            Ok(Some(bytes.into_boxed_slice()))
398        }
399    }
400
401    deserializer.deserialize_any(LeftCatsVisitor)
402}