sea_orm_codegen/entity/transformer.rs

1use crate::{
2    util::unpack_table_ref, ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error,
3    PrimaryKey, Relation, RelationType,
4};
5use sea_query::{ColumnSpec, TableCreateStatement};
6use std::collections::{BTreeMap, HashMap};
7
/// Stateless converter that turns `sea_query` `CREATE TABLE` statements into the
/// entity model consumed by the code writer; see `EntityTransformer::transform`.
#[derive(Debug, Clone)]
pub struct EntityTransformer;
10
11impl EntityTransformer {
12    pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13        let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14        let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15        let mut entities = BTreeMap::new();
16        for table_create in table_create_stmts.into_iter() {
17            let table_name = match table_create.get_table_name() {
18                Some(table_ref) => match table_ref {
19                    sea_query::TableRef::Table(t)
20                    | sea_query::TableRef::SchemaTable(_, t)
21                    | sea_query::TableRef::DatabaseSchemaTable(_, _, t)
22                    | sea_query::TableRef::TableAlias(t, _)
23                    | sea_query::TableRef::SchemaTableAlias(_, t, _)
24                    | sea_query::TableRef::DatabaseSchemaTableAlias(_, _, t, _) => t.to_string(),
25                    _ => unimplemented!(),
26                },
27                None => {
28                    return Err(Error::TransformError(
29                        "Table name should not be empty".into(),
30                    ))
31                }
32            };
33            let mut primary_keys: Vec<PrimaryKey> = Vec::new();
34            let columns: Vec<Column> = table_create
35                .get_columns()
36                .iter()
37                .map(|col_def| {
38                    let primary_key = col_def
39                        .get_column_spec()
40                        .iter()
41                        .any(|spec| matches!(spec, ColumnSpec::PrimaryKey));
42                    if primary_key {
43                        primary_keys.push(PrimaryKey {
44                            name: col_def.get_column_name(),
45                        });
46                    }
47                    col_def.into()
48                })
49                .map(|mut col: Column| {
50                    col.unique = table_create
51                        .get_indexes()
52                        .iter()
53                        .filter(|index| index.is_unique_key())
54                        .map(|index| index.get_index_spec().get_column_names())
55                        .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
56                        .count()
57                        > 0;
58                    col
59                })
60                .inspect(|col| {
61                    if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
62                    {
63                        enums.insert(
64                            name.to_string(),
65                            ActiveEnum {
66                                enum_name: name.clone(),
67                                values: variants.clone(),
68                            },
69                        );
70                    }
71                })
72                .collect();
73            let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
74            let relations: Vec<Relation> = table_create
75                .get_foreign_key_create_stmts()
76                .iter()
77                .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
78                .map(|tbl_fk| {
79                    let ref_tbl = unpack_table_ref(tbl_fk.get_ref_table().unwrap());
80                    if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
81                        if *count == 0 {
82                            *count = 1;
83                        }
84                        *count += 1;
85                    } else {
86                        ref_table_counts.insert(ref_tbl, 0);
87                    };
88                    tbl_fk.into()
89                })
90                .collect::<Vec<_>>()
91                .into_iter()
92                .rev()
93                .map(|mut rel: Relation| {
94                    rel.self_referencing = rel.ref_table == table_name;
95                    if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
96                        rel.num_suffix = *count;
97                        if *count > 0 {
98                            *count -= 1;
99                        }
100                    }
101                    rel
102                })
103                .rev()
104                .collect();
105            primary_keys.extend(
106                table_create
107                    .get_indexes()
108                    .iter()
109                    .filter(|index| index.is_primary_key())
110                    .flat_map(|index| {
111                        index
112                            .get_index_spec()
113                            .get_column_names()
114                            .into_iter()
115                            .map(|name| PrimaryKey { name })
116                            .collect::<Vec<_>>()
117                    }),
118            );
119            let entity = Entity {
120                table_name: table_name.clone(),
121                columns,
122                relations: relations.clone(),
123                conjunct_relations: vec![],
124                primary_keys,
125            };
126            entities.insert(table_name.clone(), entity.clone());
127            for mut rel in relations.into_iter() {
128                // This will produce a duplicated relation
129                if rel.self_referencing {
130                    continue;
131                }
132                // This will cause compile error on the many side,
133                // got relation variant but without Related<T> implemented
134                if rel.num_suffix > 0 {
135                    continue;
136                }
137                let ref_table = rel.ref_table;
138                let mut unique = true;
139                for column in rel.columns.iter() {
140                    if !entity
141                        .columns
142                        .iter()
143                        .filter(|col| col.unique)
144                        .any(|col| col.name.as_str() == column)
145                    {
146                        unique = false;
147                        break;
148                    }
149                }
150                if rel.columns.len() == entity.primary_keys.len() {
151                    let mut count_pk = 0;
152                    for primary_key in entity.primary_keys.iter() {
153                        if rel.columns.contains(&primary_key.name) {
154                            count_pk += 1;
155                        }
156                    }
157                    if count_pk == entity.primary_keys.len() {
158                        unique = true;
159                    }
160                }
161                let rel_type = if unique {
162                    RelationType::HasOne
163                } else {
164                    RelationType::HasMany
165                };
166                rel.rel_type = rel_type;
167                rel.ref_table = table_name.to_string();
168                rel.columns = Vec::new();
169                rel.ref_columns = Vec::new();
170                if let Some(vec) = inverse_relations.get_mut(&ref_table) {
171                    vec.push(rel);
172                } else {
173                    inverse_relations.insert(ref_table, vec![rel]);
174                }
175            }
176        }
177        for (tbl_name, relations) in inverse_relations.into_iter() {
178            if let Some(entity) = entities.get_mut(&tbl_name) {
179                for relation in relations.into_iter() {
180                    let duplicate_relation = entity
181                        .relations
182                        .iter()
183                        .any(|rel| rel.ref_table == relation.ref_table);
184                    if !duplicate_relation {
185                        entity.relations.push(relation);
186                    }
187                }
188            }
189        }
190        for table_name in entities.clone().keys() {
191            let relations = match entities.get(table_name) {
192                Some(entity) => {
193                    let is_conjunct_relation =
194                        entity.relations.len() == 2 && entity.primary_keys.len() == 2;
195                    if !is_conjunct_relation {
196                        continue;
197                    }
198                    entity.relations.clone()
199                }
200                None => unreachable!(),
201            };
202            for (i, rel) in relations.iter().enumerate() {
203                let another_rel = relations.get((i == 0) as usize).unwrap();
204                if let Some(entity) = entities.get_mut(&rel.ref_table) {
205                    let conjunct_relation = ConjunctRelation {
206                        via: table_name.clone(),
207                        to: another_rel.ref_table.clone(),
208                    };
209                    entity.conjunct_relations.push(conjunct_relation);
210                }
211            }
212        }
213        Ok(EntityWriter {
214            entities: entities
215                .into_values()
216                .map(|mut v| {
217                    // Filter duplicated conjunct relations
218                    let duplicated_to: Vec<_> = v
219                        .conjunct_relations
220                        .iter()
221                        .fold(HashMap::new(), |mut acc, conjunct_relation| {
222                            acc.entry(conjunct_relation.to.clone())
223                                .and_modify(|c| *c += 1)
224                                .or_insert(1);
225                            acc
226                        })
227                        .into_iter()
228                        .filter(|(_, v)| v > &1)
229                        .map(|(k, _)| k)
230                        .collect();
231                    v.conjunct_relations
232                        .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
233
234                    // Skip `impl Related ... { fn to() ... }` implementation block,
235                    // if the same related entity is being referenced by a conjunct relation
236                    v.relations.iter_mut().for_each(|relation| {
237                        if v.conjunct_relations
238                            .iter()
239                            .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
240                        {
241                            relation.impl_related = false;
242                        }
243                    });
244
245                    // Sort relation vectors
246                    v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
247                    v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
248                    v
249                })
250                .collect(),
251            enums,
252        })
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use proc_macro2::TokenStream;
    use sea_orm::{DbBackend, Schema};
    use std::{
        error::Error,
        io::{self, BufRead, BufReader},
    };

    #[test]
    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::duplicated_many_to_many_paths::*;
        let schema = Schema::new(DbBackend::Postgres);

        let tables = vec![
            schema.create_table_from_entity(bills::Entity),
            schema.create_table_from_entity(users::Entity),
            schema.create_table_from_entity(users_saved_bills::Entity),
            schema.create_table_from_entity(users_votes::Entity),
        ];
        let expected = vec![
            (
                "bills",
                include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
            ),
            (
                "users",
                include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
            ),
            (
                "users_saved_bills",
                include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
            ),
            (
                "users_votes",
                include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
            ),
        ];
        validate_compact_entities(tables, expected)
    }

    #[test]
    fn many_to_many() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many::*;
        let schema = Schema::new(DbBackend::Postgres);

        let tables = vec![
            schema.create_table_from_entity(bills::Entity),
            schema.create_table_from_entity(users::Entity),
            schema.create_table_from_entity(users_votes::Entity),
        ];
        let expected = vec![
            ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
            ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
            (
                "users_votes",
                include_str!("../tests_cfg/many_to_many/users_votes.rs"),
            ),
        ];
        validate_compact_entities(tables, expected)
    }

    #[test]
    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many_multiple::*;
        let schema = Schema::new(DbBackend::Postgres);

        let tables = vec![
            schema.create_table_from_entity(bills::Entity),
            schema.create_table_from_entity(users::Entity),
            schema.create_table_from_entity(users_votes::Entity),
        ];
        let expected = vec![
            (
                "bills",
                include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
            ),
            (
                "users",
                include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
            ),
            (
                "users_votes",
                include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
            ),
        ];
        validate_compact_entities(tables, expected)
    }

    #[test]
    fn self_referencing() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::self_referencing::*;
        let schema = Schema::new(DbBackend::Postgres);

        let tables = vec![
            schema.create_table_from_entity(bills::Entity),
            schema.create_table_from_entity(users::Entity),
        ];
        let expected = vec![
            (
                "bills",
                include_str!("../tests_cfg/self_referencing/bills.rs"),
            ),
            (
                "users",
                include_str!("../tests_cfg/self_referencing/users.rs"),
            ),
        ];
        validate_compact_entities(tables, expected)
    }

    /// Runs the transformer over `table_create_stmts`, then for every
    /// `(table name, file content)` pair in `files` asserts that the compact code
    /// generated for that entity (minus its first code block) equals the file
    /// content parsed by `parse_from_file`.
    fn validate_compact_entities(
        table_create_stmts: Vec<TableCreateStatement>,
        files: Vec<(&str, &str)>,
    ) -> Result<(), Box<dyn Error>> {
        let writer = EntityTransformer::transform(table_create_stmts)?;
        let mut by_table = HashMap::new();
        for entity in writer.entities {
            by_table.insert(entity.table_name.clone(), entity);
        }

        for (entity_name, file_content) in files {
            let entity = by_table
                .get(entity_name)
                .expect("Forget to add entity to the list");

            let expected = parse_from_file(file_content.as_bytes())?.to_string();

            let code_blocks = EntityWriter::gen_compact_code_blocks(
                entity,
                &crate::WithSerde::None,
                &crate::DateTimeCrate::Chrono,
                &None,
                false,
                false,
                &Default::default(),
                &Default::default(),
                false,
                true,
            );
            // The first block is skipped; the remaining ones are concatenated.
            let mut generated = TokenStream::new();
            for block in code_blocks.into_iter().skip(1) {
                generated.extend(block);
            }

            assert_eq!(expected, generated.to_string());
        }

        Ok(())
    }

    /// Parses an expected entity file into a `TokenStream`, discarding everything
    /// up to and including the first `;` byte before parsing the remainder.
    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
    where
        R: io::Read,
    {
        let mut reader = BufReader::new(inner);

        let mut skipped = Vec::new();
        reader.read_until(b';', &mut skipped)?;

        // `read_line` appends to the buffer, so looping until EOF gathers the rest.
        let mut content = String::new();
        while reader.read_line(&mut content)? != 0 {}

        Ok(content.parse().unwrap())
    }
}