// entrenar/config/infer/schema.rs
use std::collections::HashMap;

use super::stats::ColumnStats;
use super::types::FeatureType;

8#[derive(Debug, Clone, Default)]
10pub struct InferredSchema {
11 pub features: HashMap<String, FeatureType>,
13 pub stats: HashMap<String, ColumnStats>,
15}
16
17impl InferredSchema {
18 pub fn features_of_type(&self, feature_type: FeatureType) -> Vec<&str> {
20 self.features
21 .iter()
22 .filter(|(_, &t)| t == feature_type)
23 .map(|(name, _)| name.as_str())
24 .collect()
25 }
26
27 pub fn targets(&self) -> Vec<&str> {
29 self.features
30 .iter()
31 .filter(|(_, &t)| {
32 matches!(
33 t,
34 FeatureType::BinaryTarget
35 | FeatureType::MultiClassTarget
36 | FeatureType::RegressionTarget
37 )
38 })
39 .map(|(name, _)| name.as_str())
40 .collect()
41 }
42
43 pub fn inputs(&self) -> Vec<&str> {
45 self.features
46 .iter()
47 .filter(|(_, &t)| {
48 !matches!(
49 t,
50 FeatureType::BinaryTarget
51 | FeatureType::MultiClassTarget
52 | FeatureType::RegressionTarget
53 )
54 })
55 .map(|(name, _)| name.as_str())
56 .collect()
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Seven features: four inputs (two numeric, one categorical, one text)
    /// plus one target of each kind; `stats` is left empty.
    fn make_schema() -> InferredSchema {
        let columns = [
            ("age", FeatureType::Numeric),
            ("income", FeatureType::Numeric),
            ("category", FeatureType::Categorical),
            ("text", FeatureType::Text),
            ("is_spam", FeatureType::BinaryTarget),
            ("label", FeatureType::MultiClassTarget),
            ("price", FeatureType::RegressionTarget),
        ];
        InferredSchema {
            features: columns
                .iter()
                .map(|&(name, kind)| (name.to_string(), kind))
                .collect(),
            ..InferredSchema::default()
        }
    }

    /// Sorts a name list so assertions do not depend on `HashMap` order.
    fn sorted(mut names: Vec<&str>) -> Vec<&str> {
        names.sort_unstable();
        names
    }

    #[test]
    fn test_inferred_schema_default() {
        let InferredSchema { features, stats } = InferredSchema::default();
        assert!(features.is_empty());
        assert!(stats.is_empty());
    }

    #[test]
    fn test_features_of_type_numeric() {
        let schema = make_schema();
        let numeric = sorted(schema.features_of_type(FeatureType::Numeric));
        assert_eq!(numeric, ["age", "income"]);
    }

    #[test]
    fn test_features_of_type_categorical() {
        let schema = make_schema();
        let categorical = schema.features_of_type(FeatureType::Categorical);
        assert_eq!(categorical, ["category"]);
    }

    #[test]
    fn test_features_of_type_text() {
        let schema = make_schema();
        let text = schema.features_of_type(FeatureType::Text);
        assert_eq!(text, ["text"]);
    }

    #[test]
    fn test_features_of_type_empty() {
        let schema = make_schema();
        assert!(schema.features_of_type(FeatureType::Embedding).is_empty());
    }

    #[test]
    fn test_targets() {
        let schema = make_schema();
        assert_eq!(sorted(schema.targets()), ["is_spam", "label", "price"]);
    }

    #[test]
    fn test_targets_empty() {
        assert!(InferredSchema::default().targets().is_empty());
    }

    #[test]
    fn test_inputs() {
        let schema = make_schema();
        assert_eq!(
            sorted(schema.inputs()),
            ["age", "category", "income", "text"]
        );
    }

    #[test]
    fn test_inputs_excludes_targets() {
        let schema = make_schema();
        let inputs = schema.inputs();
        for target in ["is_spam", "label", "price"].iter() {
            assert!(!inputs.contains(target));
        }
    }

    #[test]
    fn test_inputs_empty() {
        assert!(InferredSchema::default().inputs().is_empty());
    }

    #[test]
    fn test_inferred_schema_clone() {
        let original = make_schema();
        let copy = original.clone();
        assert_eq!(copy.features.len(), original.features.len());
    }

    #[test]
    fn test_inferred_schema_debug() {
        let rendered = format!("{:?}", make_schema());
        assert!(rendered.contains("InferredSchema"));
        assert!(rendered.contains("features"));
    }
}