liquid_ml/dataframe/
schema.rs1use crate::error::LiquidError;
4use deepsize::DeepSizeOf;
5use serde::{Deserialize, Serialize};
6use sorer::{dataframe::Column, schema::DataType};
7use std::collections::HashMap;
8
9#[derive(
11 Serialize, Deserialize, PartialEq, Clone, Debug, Default, DeepSizeOf,
12)]
13pub struct Schema {
14 pub schema: Vec<DataType>,
16 pub col_names: HashMap<String, usize>,
19}
20
21impl Schema {
24 pub fn new() -> Self {
26 Schema {
27 ..Default::default()
28 }
29 }
30
31 pub fn add_column(
37 &mut self,
38 data_type: DataType,
39 col_name: Option<String>,
40 ) -> Result<(), LiquidError> {
41 if let Some(name) = col_name {
42 if !self.col_names.contains_key(&name) {
43 self.col_names.insert(name, self.schema.len());
44 } else {
45 return Err(LiquidError::NameAlreadyExists);
46 }
47 }
48 self.schema.push(data_type);
49 Ok(())
50 }
51
52 pub fn col_type(&self, idx: usize) -> Result<&DataType, LiquidError> {
55 match self.schema.get(idx) {
56 Some(data_type) => Ok(data_type),
57 None => Err(LiquidError::ColIndexOutOfBounds),
58 }
59 }
60
61 pub fn col_idx(&self, col_name: &str) -> Option<usize> {
63 match self.col_names.get(col_name) {
64 Some(x) => Some(*x),
65 None => None,
66 }
67 }
68
69 pub fn col_name(
71 &self,
72 col_idx: usize,
73 ) -> Result<Option<&str>, LiquidError> {
74 if col_idx >= self.width() {
75 return Err(LiquidError::ColIndexOutOfBounds);
76 }
77 match self.col_names.iter().find(|(_, &v)| v == col_idx) {
78 Some((col_name, _)) => Ok(Some(col_name)),
79 None => Ok(None),
80 }
81 }
82
83 pub fn width(&self) -> usize {
85 self.schema.len()
86 }
87
88 fn char_to_data_type(c: char) -> DataType {
89 match c {
90 'B' => DataType::Bool,
91 'I' => DataType::Int,
92 'F' => DataType::Float,
93 'S' => DataType::String,
94 _ => panic!("Tried to make a bad Schema"),
95 }
96 }
97}
98
99impl From<&str> for Schema {
100 fn from(types: &str) -> Self {
111 let mut schema = Vec::new();
112 for c in types.chars() {
113 schema.push(Schema::char_to_data_type(c));
114 }
115 Schema {
116 schema,
117 col_names: HashMap::new(),
118 }
119 }
120}
121
122impl From<Vec<DataType>> for Schema {
123 fn from(types: Vec<DataType>) -> Self {
125 Schema {
126 schema: types,
127 col_names: HashMap::new(),
128 }
129 }
130}
131
132impl From<&Vec<Column>> for Schema {
133 fn from(columns: &Vec<Column>) -> Self {
135 let mut schema = Vec::new();
136 for c in columns {
137 match c {
138 Column::Bool(_) => schema.push(DataType::Bool),
139 Column::Int(_) => schema.push(DataType::Int),
140 Column::Float(_) => schema.push(DataType::Float),
141 Column::String(_) => schema.push(DataType::String),
142 };
143 }
144 Schema {
145 schema,
146 col_names: HashMap::new(),
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Converting a `Vec<DataType>` keeps every type at its original index.
    #[test]
    fn test_from_data_types() {
        // An empty list of types yields a schema without columns.
        let empty = Schema::from(Vec::<DataType>::new());
        assert_eq!(empty.width(), 0);

        let data_types = vec![
            DataType::Int,
            DataType::Int,
            DataType::Float,
            DataType::Bool,
            DataType::String,
        ];
        let s = Schema::from(data_types.clone());
        for (idx, data_type) in data_types.iter().enumerate() {
            assert_eq!(data_type, s.col_type(idx).unwrap());
        }
        assert_eq!(s.width(), data_types.len());
    }

    /// Converting a type string maps each character to its `DataType`.
    #[test]
    fn test_from_str() {
        // An empty string yields a schema without columns.
        let empty = Schema::from("");
        assert_eq!(empty.width(), 0);

        let s = Schema::from("IIFBS");
        let data_types = vec![
            DataType::Int,
            DataType::Int,
            DataType::Float,
            DataType::Bool,
            DataType::String,
        ];
        for (idx, data_type) in data_types.iter().enumerate() {
            assert_eq!(data_type, s.col_type(idx).unwrap());
        }
        assert_eq!(s.width(), data_types.len());
    }

    /// Adding columns grows the width, and a named column can be found
    /// by its name.
    #[test]
    fn test_col_getters_setters() {
        let mut s = Schema::new();
        assert_eq!(s.width(), 0);

        // An unnamed column still counts towards the width.
        s.add_column(DataType::String, None).unwrap();
        assert_eq!(s.width(), 1);

        let name = String::from("foo");
        s.add_column(DataType::Int, Some(name)).unwrap();
        assert_eq!(s.width(), 2);
        assert_eq!(s.col_idx("foo"), Some(1));
    }
}