Skip to main content

schema_model/model/
schema.rs

1use crate::model::enum_type::EnumType;
2use crate::model::function::Function;
3use crate::model::other_sql::OtherSql;
4use crate::model::procedure::Procedure;
5use crate::model::table::Table;
6use crate::model::types::{DatabaseType, RelationType};
7use crate::model::view::View;
8use std::collections::HashMap;
9
10#[derive(Debug)]
11pub struct Schema {
12    schema_name: Option<String>,
13    case_sensitive_text: bool,
14    tables: Vec<Table>,
15    views: Vec<View>,
16    functions: Vec<Function>,
17    procedures: Vec<Procedure>,
18    other_sql: Vec<OtherSql>,
19    // Case-insensitive map: store lowercase name -> index in tables vec
20    table_map: HashMap<String, usize>,
21    enum_types: HashMap<String, EnumType>,
22}
23
24impl Default for Schema {
25    fn default() -> Self {
26        Self {
27            schema_name: None,
28            case_sensitive_text: true,
29            tables: Vec::new(),
30            views: Vec::new(),
31            functions: Vec::new(),
32            procedures: Vec::new(),
33            other_sql: Vec::new(),
34            table_map: HashMap::new(),
35            enum_types: HashMap::new(),
36        }
37    }
38}
39
40impl Schema {
41    pub fn new<S: Into<String>>(schema_name: Option<S>) -> Self {
42        Self {
43            schema_name: schema_name.map(|s| s.into()),
44            ..Default::default()
45        }
46    }
47
48    pub fn schema_name(&self) -> Option<&str> {
49        self.schema_name.as_deref()
50    }
51
52    pub fn case_sensitive_text(&self) -> bool {
53        self.case_sensitive_text
54    }
55
56    pub fn set_case_sensitive_text(&mut self, value: bool) {
57        self.case_sensitive_text = value;
58    }
59
60    pub fn tables(&self) -> &[Table] {
61        &self.tables
62    }
63
64    pub fn get_table(&self, name: &str) -> &Table {
65        let idx = self.table_index(name);
66        &self.tables[idx]
67    }
68
69    pub(crate) fn get_table_mut(&mut self, name: &str) -> &mut Table {
70        let idx = self.table_index(name);
71        &mut self.tables[idx]
72    }
73
74    fn table_index(&self, name: &str) -> usize {
75        let name_lower = name.to_lowercase();
76        *self.table_map.get(&name_lower)
77            .unwrap_or_else(|| panic!("Unable to locate a table with the name '{}'", name))
78    }
79
80    pub fn get_optional_table(&self, name: &str) -> Option<&Table> {
81        let name_lower = name.to_lowercase();
82        self.table_map.get(&name_lower).map(|&idx| &self.tables[idx])
83    }
84
85    pub fn all_views(&self) -> &[View] {
86        &self.views
87    }
88
89    pub fn views(&self, database_type: DatabaseType) -> Vec<View> {
90        self.views
91            .iter()
92            .filter(|view| view.database_type().is_none() || view.database_type().unwrap() == database_type)
93            .cloned()
94            .collect()
95    }
96
97    pub fn enum_types(&self) -> impl Iterator<Item = &EnumType> {
98        self.enum_types.values()
99    }
100
101    pub fn get_enum_type(&self, type_name: &str) -> &EnumType {
102        self.enum_types
103            .get(type_name)
104            .unwrap_or_else(|| panic!("Unable to locate an enum type with the name '{}'", type_name))
105    }
106
107    pub fn functions(&self) -> &[Function] {
108        &self.functions
109    }
110
111    pub fn procedures(&self) -> &[Procedure] {
112        &self.procedures
113    }
114
115    pub fn other_sql(&self) -> &[OtherSql] {
116        &self.other_sql
117    }
118
119    pub fn validate(&self) -> Vec<String> {
120        let mut errors: Vec<String> = Vec::new();
121        for table in &self.tables {
122            for relation in table.relations() {
123                if relation.relation_type() == RelationType::SetNull {
124                    let from_table_name = relation.from_table_name().to_string();
125                    let from_column_name = relation.from_column_name().to_string();
126                    let from_table = self.get_table(&from_table_name);
127                    if from_table.column(&from_column_name).is_required() {
128                        errors.push(format!(
129                            "ERROR: {}.{} is required. The {}.{} relation specifies setnull, which is not allowed",
130                            from_table_name,
131                            from_column_name,
132                            relation.to_table_name(),
133                            relation.to_column_name()
134                        ));
135                    }
136                }
137            }
138        }
139        errors
140    }
141
142    // pub fn build_reverse_relations(&mut self) {
143    //     // We need mutable access to parent tables too, so handle indices carefully.
144    //     // First, collect the relations to add per parent table to avoid multiple mutable borrows.
145    //     let mut to_add: HashMap<usize, Vec<Relation>> = HashMap::new();
146    //     for (child_idx, table) in self.tables.iter().enumerate() {
147    //         if !table.relations().is_empty() {
148    //             for relation in table.relations() {
149    //                 let parent_name = relation.to_table_name();
150    //                 if let Some(&parent_idx) = self.table_map.get(&parent_name.to_lowercase()) {
151    //                     let reverse = Relation::new(
152    //                         relation.to_table_name().to_string(),
153    //                         relation.to_column_name().to_string(),
154    //                         relation.from_table_name().to_string(),
155    //                         relation.from_column_name().to_string(),
156    //                         relation.relation_type(),
157    //                         false,
158    //                     );
159    //                     to_add.entry(parent_idx).or_default().push(reverse);
160    //                 } else {
161    //                     // Parent not found; ignore or log in real implementation
162    //                     let _ = child_idx; // keep variable used
163    //                 }
164    //             }
165    //         }
166    //     }
167    //     for (idx, rels) in to_add {
168    //         if let Some(parent) = self.tables.get_mut(idx) {
169    //             parent.reverse_relations_mut().extend(rels);
170    //         }
171    //     }
172    // }
173
174    pub(crate) fn add_table(&mut self, table: Table) {
175        let idx = self.tables.len();
176        self.table_map.insert(table.name().to_lowercase(), idx);
177        self.tables.push(table);
178    }
179
180    pub(crate) fn add_view(&mut self, view: View) {
181        self.views.push(view);
182    }
183
184    pub(crate) fn add_enum_type(&mut self, enum_type: EnumType) {
185        self.enum_types
186            .insert(enum_type.name().to_string(), enum_type);
187    }
188
189    pub(crate) fn add_functions(&mut self, functions: Vec<Function>) {
190        self.functions.extend(functions);
191    }
192
193    pub(crate) fn add_procedures(&mut self, procedures: Vec<Procedure>) {
194        self.procedures.extend(procedures);
195    }
196
197    pub(crate) fn add_other_sql(&mut self, other_sql: OtherSql) {
198        self.other_sql.push(other_sql);
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::column::Column;
    use crate::model::column_type::ColumnType;
    use crate::model::relation::Relation;
    use crate::model::types::LockEscalation;

    fn make_schema() -> Schema {
        Schema::new(Some("schema"))
    }

    /// Builds a table with default options; only columns and relations vary between tests.
    fn make_table(
        schema: &str,
        name: &str,
        columns: Vec<Column>,
        relations: Vec<Relation>,
    ) -> Table {
        Table::new(
            Some(schema),
            name,
            Option::<&str>::None,
            LockEscalation::Auto,
            false,
            columns,
            Vec::new(),
            Vec::new(),
            relations,
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
        )
    }

    /// Builds an `Int` column; the last `Column::new` argument is the required flag.
    fn int_column(name: &str, required: bool) -> Column {
        Column::new(Some("s"), name, ColumnType::Int, 0, 0, required)
    }

    #[test]
    fn add_and_get_table_case_insensitive() {
        let mut schema = make_schema();
        schema.add_table(make_table("schema", "Table1", vec![int_column("id", true)], Vec::new()));
        schema.add_table(make_table("schema", "Table2", Vec::new(), Vec::new()));

        // Lookups must ignore the caller's casing.
        assert_eq!(schema.get_table("table2").name(), "Table2");
        assert_eq!(schema.get_table("TABLE1").name(), "Table1");
        assert_eq!(schema.get_optional_table("tAbLe2").unwrap().name(), "Table2");

        // `tables()` preserves insertion order.
        let names: Vec<_> = schema.tables().iter().map(|t| t.name()).collect();
        assert_eq!(names, vec!["Table1", "Table2"]);
    }

    #[test]
    fn get_optional_table_returns_none_for_unknown_name() {
        let schema = make_schema();
        assert!(schema.get_optional_table("missing").is_none());
    }

    #[test]
    #[should_panic(expected = "Unable to locate a table with the name 'missing'")]
    fn get_table_panics_for_unknown_name() {
        let schema = make_schema();
        let _ = schema.get_table("missing");
    }

    #[test]
    fn views_filtered_by_database_type() {
        let mut s = make_schema();
        s.add_view(View::new(Some("s"), "v1", "sql1", Some(DatabaseType::Postgresql)));
        s.add_view(View::new(Some("s"), "v2", "sql2", Some(DatabaseType::SqlServer)));
        s.add_view(View::new(Some("s"), "v3", "sql3", Option::<DatabaseType>::None));

        // Untagged views (v3) apply to every database type.
        let pg: Vec<String> = s
            .views(DatabaseType::Postgresql)
            .iter()
            .map(|v| v.name().to_string())
            .collect();
        assert_eq!(pg, vec!["v1", "v3"]);

        let mssql: Vec<String> = s
            .views(DatabaseType::SqlServer)
            .iter()
            .map(|v| v.name().to_string())
            .collect();
        assert_eq!(mssql, vec!["v2", "v3"]);

        assert_eq!(s.all_views().len(), 3);
    }

    #[test]
    fn validate_setnull_error_when_required() {
        let mut s = make_schema();
        s.add_table(make_table("s", "parent", vec![int_column("id", true)], Vec::new()));
        s.add_table(make_table(
            "s",
            "child",
            vec![int_column("pid", true)],
            vec![Relation::new(
                "parent",
                "id",
                "child",
                "pid",
                RelationType::SetNull,
                false,
            )],
        ));

        let errors = s.validate();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("setnull"));
    }

    #[test]
    fn validate_setnull_ok_when_columns_optional() {
        let mut s = make_schema();
        s.add_table(make_table("s", "parent", vec![int_column("id", false)], Vec::new()));
        s.add_table(make_table(
            "s",
            "child",
            vec![int_column("pid", false)],
            vec![Relation::new(
                "parent",
                "id",
                "child",
                "pid",
                RelationType::SetNull,
                false,
            )],
        ));

        assert!(s.validate().is_empty());
    }

    #[test]
    fn validate_ignores_non_setnull_relations() {
        let mut s = make_schema();
        s.add_table(make_table("s", "parent", vec![int_column("id", true)], Vec::new()));
        s.add_table(make_table(
            "s",
            "child",
            vec![int_column("pid", true)],
            vec![Relation::new(
                "parent",
                "id",
                "child",
                "pid",
                RelationType::Cascade,
                false,
            )],
        ));

        assert!(s.validate().is_empty());
    }

    // #[test]
    // fn build_reverse_relations_creates_back_refs() {
    //     let mut s = make_schema();
    //     let mut parent = Table::new(
    //         Some("s"),
    //         "p",
    //         Option::<&str>::None,
    //         crate::model::types::LockEscalation::Auto,
    //         false,
    //         vec![Column::new(Some("s"), "id", ColumnType::Int, 0, 0, true)],
    //         Vec::new(),
    //         Vec::new(),
    //         Vec::new(),
    //         Vec::new(),
    //         Vec::new(),
    //         Vec::new(),
    //         Vec::new(),
    //         Vec::new(),
    //     );
    //     let mut child = Table::new(
    //         Some("s"),
    //         "c",
    //         Option::<&str>::None,
    //         crate::model::types::LockEscalation::Auto,
    //         false,
    //         vec![Column::new(Some("s"), "pid", ColumnType::Int, 0, 0, false)],
    //         Vec::new(),
    //         Vec::new(),
    //         vec![Relation::new(
    //             "p",
    //             "id",
    //             "c",
    //             "pid",
    //             RelationType::Cascade,
    //             false,
    //         )],
    //         Vec::new(),
    //         Vec::new(),
    //         Vec::new(),
    //         Vec::new(),
    //         Vec::new(),
    //     );
    //     s.add_table(parent);
    //     s.add_table(child);
    //
    //     s.build_reverse_relations();
    //     let p_ref = s.get_table("p");
    //     assert_eq!(p_ref.reverse_relations().len(), 1);
    //     let rr = &p_ref.reverse_relations()[0];
    //     assert_eq!(rr.from_table_name(), "c");
    //     assert_eq!(rr.to_table_name(), "p");
    // }
}