liquid_ml/dataframe/
schema.rs1use crate::error::LiquidError;
4use deepsize::DeepSizeOf;
5use serde::{Deserialize, Serialize};
6use sorer::{dataframe::Column, schema::DataType};
7use std::collections::HashMap;
8
9#[derive(
11    Serialize, Deserialize, PartialEq, Clone, Debug, Default, DeepSizeOf,
12)]
13pub struct Schema {
14    pub schema: Vec<DataType>,
16    pub col_names: HashMap<String, usize>,
19}
20
21impl Schema {
24    pub fn new() -> Self {
26        Schema {
27            ..Default::default()
28        }
29    }
30
31    pub fn add_column(
37        &mut self,
38        data_type: DataType,
39        col_name: Option<String>,
40    ) -> Result<(), LiquidError> {
41        if let Some(name) = col_name {
42            if !self.col_names.contains_key(&name) {
43                self.col_names.insert(name, self.schema.len());
44            } else {
45                return Err(LiquidError::NameAlreadyExists);
46            }
47        }
48        self.schema.push(data_type);
49        Ok(())
50    }
51
52    pub fn col_type(&self, idx: usize) -> Result<&DataType, LiquidError> {
55        match self.schema.get(idx) {
56            Some(data_type) => Ok(data_type),
57            None => Err(LiquidError::ColIndexOutOfBounds),
58        }
59    }
60
61    pub fn col_idx(&self, col_name: &str) -> Option<usize> {
63        match self.col_names.get(col_name) {
64            Some(x) => Some(*x),
65            None => None,
66        }
67    }
68
69    pub fn col_name(
71        &self,
72        col_idx: usize,
73    ) -> Result<Option<&str>, LiquidError> {
74        if col_idx >= self.width() {
75            return Err(LiquidError::ColIndexOutOfBounds);
76        }
77        match self.col_names.iter().find(|(_, &v)| v == col_idx) {
78            Some((col_name, _)) => Ok(Some(col_name)),
79            None => Ok(None),
80        }
81    }
82
83    pub fn width(&self) -> usize {
85        self.schema.len()
86    }
87
88    fn char_to_data_type(c: char) -> DataType {
89        match c {
90            'B' => DataType::Bool,
91            'I' => DataType::Int,
92            'F' => DataType::Float,
93            'S' => DataType::String,
94            _ => panic!("Tried to make a bad Schema"),
95        }
96    }
97}
98
99impl From<&str> for Schema {
100    fn from(types: &str) -> Self {
111        let mut schema = Vec::new();
112        for c in types.chars() {
113            schema.push(Schema::char_to_data_type(c));
114        }
115        Schema {
116            schema,
117            col_names: HashMap::new(),
118        }
119    }
120}
121
122impl From<Vec<DataType>> for Schema {
123    fn from(types: Vec<DataType>) -> Self {
125        Schema {
126            schema: types,
127            col_names: HashMap::new(),
128        }
129    }
130}
131
132impl From<&Vec<Column>> for Schema {
133    fn from(columns: &Vec<Column>) -> Self {
135        let mut schema = Vec::new();
136        for c in columns {
137            match c {
138                Column::Bool(_) => schema.push(DataType::Bool),
139                Column::Int(_) => schema.push(DataType::Int),
140                Column::Float(_) => schema.push(DataType::Float),
141                Column::String(_) => schema.push(DataType::String),
142            };
143        }
144        Schema {
145            schema,
146            col_names: HashMap::new(),
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// The column types shared by the conversion tests, in column order.
    fn sample_types() -> Vec<DataType> {
        vec![
            DataType::Int,
            DataType::Int,
            DataType::Float,
            DataType::Bool,
            DataType::String,
        ]
    }

    /// Asserts that `s` has exactly the column types in `expected`, and that
    /// indexing one past the last column is rejected.
    fn assert_types(s: &Schema, expected: &[DataType]) {
        assert_eq!(s.width(), expected.len());
        for (idx, data_type) in expected.iter().enumerate() {
            assert_eq!(data_type, s.col_type(idx).unwrap());
        }
        assert!(matches!(
            s.col_type(expected.len()),
            Err(LiquidError::ColIndexOutOfBounds)
        ));
    }

    #[test]
    fn test_from_data_types() {
        let empty: Vec<DataType> = vec![];
        assert_types(&Schema::from(empty), &[]);

        let data_types = sample_types();
        let s = Schema::from(data_types.clone());
        assert_types(&s, &data_types);
    }

    #[test]
    fn test_from_str() {
        assert_types(&Schema::from(""), &[]);
        assert_types(&Schema::from("IIFBS"), &sample_types());
    }

    #[test]
    #[should_panic(expected = "Tried to make a bad Schema")]
    fn test_from_str_bad_char() {
        let _ = Schema::from("IX");
    }

    #[test]
    fn test_col_getters_setters() {
        let mut s = Schema::new();
        assert_eq!(s.width(), 0);
        s.add_column(DataType::String, None).unwrap();
        assert_eq!(s.width(), 1);
        s.add_column(DataType::Int, Some(String::from("foo")))
            .unwrap();
        assert_eq!(s.width(), 2);
        assert_eq!(s.col_idx("foo"), Some(1));
        assert_eq!(s.col_idx("bar"), None);

        // Unnamed columns report no name; out-of-range indices are errors.
        assert_eq!(s.col_name(0).unwrap(), None);
        assert_eq!(s.col_name(1).unwrap(), Some("foo"));
        assert!(matches!(
            s.col_name(2),
            Err(LiquidError::ColIndexOutOfBounds)
        ));
    }

    #[test]
    fn test_duplicate_col_name_rejected() {
        let mut s = Schema::new();
        s.add_column(DataType::Int, Some(String::from("foo")))
            .unwrap();
        assert!(matches!(
            s.add_column(DataType::Bool, Some(String::from("foo"))),
            Err(LiquidError::NameAlreadyExists)
        ));
        // A rejected column must not be appended or remap the existing name.
        assert_eq!(s.width(), 1);
        assert_eq!(s.col_idx("foo"), Some(0));
        assert_eq!(s.col_type(0).unwrap(), &DataType::Int);
    }
}
207}