Skip to main content

entrenar/config/infer/
schema.rs

1//! Inferred schema representation
2
3use std::collections::HashMap;
4
5use super::stats::ColumnStats;
6use super::types::FeatureType;
7
8/// Inferred schema for a dataset
9#[derive(Debug, Clone, Default)]
10pub struct InferredSchema {
11    /// Feature name -> inferred type
12    pub features: HashMap<String, FeatureType>,
13    /// Column statistics used for inference
14    pub stats: HashMap<String, ColumnStats>,
15}
16
17impl InferredSchema {
18    /// Get features of a specific type
19    pub fn features_of_type(&self, feature_type: FeatureType) -> Vec<&str> {
20        self.features
21            .iter()
22            .filter(|(_, &t)| t == feature_type)
23            .map(|(name, _)| name.as_str())
24            .collect()
25    }
26
27    /// Get target columns
28    pub fn targets(&self) -> Vec<&str> {
29        self.features
30            .iter()
31            .filter(|(_, &t)| {
32                matches!(
33                    t,
34                    FeatureType::BinaryTarget
35                        | FeatureType::MultiClassTarget
36                        | FeatureType::RegressionTarget
37                )
38            })
39            .map(|(name, _)| name.as_str())
40            .collect()
41    }
42
43    /// Get input feature columns (non-targets)
44    pub fn inputs(&self) -> Vec<&str> {
45        self.features
46            .iter()
47            .filter(|(_, &t)| {
48                !matches!(
49                    t,
50                    FeatureType::BinaryTarget
51                        | FeatureType::MultiClassTarget
52                        | FeatureType::RegressionTarget
53                )
54            })
55            .map(|(name, _)| name.as_str())
56            .collect()
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Seven-column fixture: four input columns plus one column per target kind.
    fn make_schema() -> InferredSchema {
        let columns = [
            ("age", FeatureType::Numeric),
            ("income", FeatureType::Numeric),
            ("category", FeatureType::Categorical),
            ("text", FeatureType::Text),
            ("is_spam", FeatureType::BinaryTarget),
            ("label", FeatureType::MultiClassTarget),
            ("price", FeatureType::RegressionTarget),
        ];
        let mut schema = InferredSchema::default();
        for (name, feature_type) in columns {
            schema.features.insert(name.to_string(), feature_type);
        }
        schema
    }

    /// Sorts a name list so assertions do not depend on `HashMap` iteration order.
    fn sorted(mut names: Vec<&str>) -> Vec<&str> {
        names.sort_unstable();
        names
    }

    #[test]
    fn test_inferred_schema_default() {
        let empty = InferredSchema::default();
        assert!(empty.features.is_empty());
        assert!(empty.stats.is_empty());
    }

    #[test]
    fn test_features_of_type_numeric() {
        let schema = make_schema();
        assert_eq!(
            sorted(schema.features_of_type(FeatureType::Numeric)),
            ["age", "income"]
        );
    }

    #[test]
    fn test_features_of_type_categorical() {
        let schema = make_schema();
        assert_eq!(
            sorted(schema.features_of_type(FeatureType::Categorical)),
            ["category"]
        );
    }

    #[test]
    fn test_features_of_type_text() {
        let schema = make_schema();
        assert_eq!(sorted(schema.features_of_type(FeatureType::Text)), ["text"]);
    }

    #[test]
    fn test_features_of_type_empty() {
        let schema = make_schema();
        assert!(schema.features_of_type(FeatureType::Embedding).is_empty());
    }

    #[test]
    fn test_targets() {
        let schema = make_schema();
        assert_eq!(sorted(schema.targets()), ["is_spam", "label", "price"]);
    }

    #[test]
    fn test_targets_empty() {
        assert!(InferredSchema::default().targets().is_empty());
    }

    #[test]
    fn test_inputs() {
        let schema = make_schema();
        assert_eq!(
            sorted(schema.inputs()),
            ["age", "category", "income", "text"]
        );
    }

    #[test]
    fn test_inputs_excludes_targets() {
        let schema = make_schema();
        let inputs = schema.inputs();
        for target in ["is_spam", "label", "price"] {
            assert!(!inputs.contains(&target));
        }
    }

    #[test]
    fn test_inputs_empty() {
        assert!(InferredSchema::default().inputs().is_empty());
    }

    #[test]
    fn test_inferred_schema_clone() {
        let original = make_schema();
        let copy = original.clone();
        assert_eq!(copy.features.len(), original.features.len());
    }

    #[test]
    fn test_inferred_schema_debug() {
        let rendered = format!("{:?}", make_schema());
        for needle in ["InferredSchema", "features"] {
            assert!(rendered.contains(needle));
        }
    }
}