1use crate::{
2 util::unpack_table_ref, ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error,
3 PrimaryKey, Relation, RelationType,
4};
5use sea_query::{ColumnSpec, TableCreateStatement};
6use std::collections::{BTreeMap, HashMap};
7
/// Converts `sea_query` `CREATE TABLE` statements into the entity model held by an
/// [`EntityWriter`].
///
/// This is a stateless unit struct; all of the work is done by
/// [`EntityTransformer::transform`].
#[derive(Debug, Clone)]
pub struct EntityTransformer;
10
11impl EntityTransformer {
12 pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13 let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14 let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15 let mut entities = BTreeMap::new();
16 for table_create in table_create_stmts.into_iter() {
17 let table_name = match table_create.get_table_name() {
18 Some(table_ref) => match table_ref {
19 sea_query::TableRef::Table(t)
20 | sea_query::TableRef::SchemaTable(_, t)
21 | sea_query::TableRef::DatabaseSchemaTable(_, _, t)
22 | sea_query::TableRef::TableAlias(t, _)
23 | sea_query::TableRef::SchemaTableAlias(_, t, _)
24 | sea_query::TableRef::DatabaseSchemaTableAlias(_, _, t, _) => t.to_string(),
25 _ => unimplemented!(),
26 },
27 None => {
28 return Err(Error::TransformError(
29 "Table name should not be empty".into(),
30 ))
31 }
32 };
33 let mut primary_keys: Vec<PrimaryKey> = Vec::new();
34 let columns: Vec<Column> = table_create
35 .get_columns()
36 .iter()
37 .map(|col_def| {
38 let primary_key = col_def
39 .get_column_spec()
40 .iter()
41 .any(|spec| matches!(spec, ColumnSpec::PrimaryKey));
42 if primary_key {
43 primary_keys.push(PrimaryKey {
44 name: col_def.get_column_name(),
45 });
46 }
47 col_def.into()
48 })
49 .map(|mut col: Column| {
50 col.unique = table_create
51 .get_indexes()
52 .iter()
53 .filter(|index| index.is_unique_key())
54 .map(|index| index.get_index_spec().get_column_names())
55 .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
56 .count()
57 > 0;
58 col
59 })
60 .inspect(|col| {
61 if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
62 {
63 enums.insert(
64 name.to_string(),
65 ActiveEnum {
66 enum_name: name.clone(),
67 values: variants.clone(),
68 },
69 );
70 }
71 })
72 .collect();
73 let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
74 let relations: Vec<Relation> = table_create
75 .get_foreign_key_create_stmts()
76 .iter()
77 .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
78 .map(|tbl_fk| {
79 let ref_tbl = unpack_table_ref(tbl_fk.get_ref_table().unwrap());
80 if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
81 if *count == 0 {
82 *count = 1;
83 }
84 *count += 1;
85 } else {
86 ref_table_counts.insert(ref_tbl, 0);
87 };
88 tbl_fk.into()
89 })
90 .collect::<Vec<_>>()
91 .into_iter()
92 .rev()
93 .map(|mut rel: Relation| {
94 rel.self_referencing = rel.ref_table == table_name;
95 if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
96 rel.num_suffix = *count;
97 if *count > 0 {
98 *count -= 1;
99 }
100 }
101 rel
102 })
103 .rev()
104 .collect();
105 primary_keys.extend(
106 table_create
107 .get_indexes()
108 .iter()
109 .filter(|index| index.is_primary_key())
110 .flat_map(|index| {
111 index
112 .get_index_spec()
113 .get_column_names()
114 .into_iter()
115 .map(|name| PrimaryKey { name })
116 .collect::<Vec<_>>()
117 }),
118 );
119 let entity = Entity {
120 table_name: table_name.clone(),
121 columns,
122 relations: relations.clone(),
123 conjunct_relations: vec![],
124 primary_keys,
125 };
126 entities.insert(table_name.clone(), entity.clone());
127 for mut rel in relations.into_iter() {
128 if rel.self_referencing {
130 continue;
131 }
132 if rel.num_suffix > 0 {
135 continue;
136 }
137 let ref_table = rel.ref_table;
138 let mut unique = true;
139 for column in rel.columns.iter() {
140 if !entity
141 .columns
142 .iter()
143 .filter(|col| col.unique)
144 .any(|col| col.name.as_str() == column)
145 {
146 unique = false;
147 break;
148 }
149 }
150 if rel.columns.len() == entity.primary_keys.len() {
151 let mut count_pk = 0;
152 for primary_key in entity.primary_keys.iter() {
153 if rel.columns.contains(&primary_key.name) {
154 count_pk += 1;
155 }
156 }
157 if count_pk == entity.primary_keys.len() {
158 unique = true;
159 }
160 }
161 let rel_type = if unique {
162 RelationType::HasOne
163 } else {
164 RelationType::HasMany
165 };
166 rel.rel_type = rel_type;
167 rel.ref_table = table_name.to_string();
168 rel.columns = Vec::new();
169 rel.ref_columns = Vec::new();
170 if let Some(vec) = inverse_relations.get_mut(&ref_table) {
171 vec.push(rel);
172 } else {
173 inverse_relations.insert(ref_table, vec![rel]);
174 }
175 }
176 }
177 for (tbl_name, relations) in inverse_relations.into_iter() {
178 if let Some(entity) = entities.get_mut(&tbl_name) {
179 for relation in relations.into_iter() {
180 let duplicate_relation = entity
181 .relations
182 .iter()
183 .any(|rel| rel.ref_table == relation.ref_table);
184 if !duplicate_relation {
185 entity.relations.push(relation);
186 }
187 }
188 }
189 }
190 for table_name in entities.clone().keys() {
191 let relations = match entities.get(table_name) {
192 Some(entity) => {
193 let is_conjunct_relation =
194 entity.relations.len() == 2 && entity.primary_keys.len() == 2;
195 if !is_conjunct_relation {
196 continue;
197 }
198 entity.relations.clone()
199 }
200 None => unreachable!(),
201 };
202 for (i, rel) in relations.iter().enumerate() {
203 let another_rel = relations.get((i == 0) as usize).unwrap();
204 if let Some(entity) = entities.get_mut(&rel.ref_table) {
205 let conjunct_relation = ConjunctRelation {
206 via: table_name.clone(),
207 to: another_rel.ref_table.clone(),
208 };
209 entity.conjunct_relations.push(conjunct_relation);
210 }
211 }
212 }
213 Ok(EntityWriter {
214 entities: entities
215 .into_values()
216 .map(|mut v| {
217 let duplicated_to: Vec<_> = v
219 .conjunct_relations
220 .iter()
221 .fold(HashMap::new(), |mut acc, conjunct_relation| {
222 acc.entry(conjunct_relation.to.clone())
223 .and_modify(|c| *c += 1)
224 .or_insert(1);
225 acc
226 })
227 .into_iter()
228 .filter(|(_, v)| v > &1)
229 .map(|(k, _)| k)
230 .collect();
231 v.conjunct_relations
232 .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
233
234 v.relations.iter_mut().for_each(|relation| {
237 if v.conjunct_relations
238 .iter()
239 .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
240 {
241 relation.impl_related = false;
242 }
243 });
244
245 v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
247 v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
248 v
249 })
250 .collect(),
251 enums,
252 })
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use proc_macro2::TokenStream;
    use sea_orm::{DbBackend, Schema};
    use std::{
        error::Error,
        io::{self, BufRead, BufReader, Read},
    };

    #[test]
    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::duplicated_many_to_many_paths::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_saved_bills::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
                ),
                (
                    "users_saved_bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
                ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many_multiple::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn self_referencing() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::self_referencing::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/self_referencing/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/self_referencing/users.rs"),
                ),
            ],
        )
    }

    /// Runs [`EntityTransformer::transform`] on `table_create_stmts` and checks, for every
    /// `(table_name, expected_file_content)` pair in `files`, that the compact code
    /// generated for that table matches the expected file. Comparison is done on the token
    /// streams rendered as strings.
    ///
    /// Panics (via `expect`) if a listed table was not produced by the transformer.
    fn validate_compact_entities(
        table_create_stmts: Vec<TableCreateStatement>,
        files: Vec<(&str, &str)>,
    ) -> Result<(), Box<dyn Error>> {
        let entities: HashMap<_, _> = EntityTransformer::transform(table_create_stmts)?
            .entities
            .into_iter()
            .map(|entity| (entity.table_name.clone(), entity))
            .collect();

        for (entity_name, file_content) in files {
            let entity = entities
                .get(entity_name)
                .expect("Forget to add entity to the list");

            let expected = parse_from_file(file_content.as_bytes())?;
            // `parse_from_file` drops everything up to the first `;` of the expected file,
            // so the first generated block is dropped too. NOTE(review): presumably that
            // block is the `use` prelude — confirm against `gen_compact_code_blocks`.
            let generated: TokenStream = EntityWriter::gen_compact_code_blocks(
                entity,
                &crate::WithSerde::None,
                &crate::DateTimeCrate::Chrono,
                &None,
                false,
                false,
                &Default::default(),
                &Default::default(),
                false,
                true,
            )
            .into_iter()
            .skip(1)
            .collect();

            assert_eq!(expected.to_string(), generated.to_string());
        }

        Ok(())
    }

    /// Reads an expected entity file and lexes it into a [`TokenStream`], skipping
    /// everything up to and including the first `;`.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if reading fails, if the content is not valid UTF-8, or (as
    /// `InvalidData`) if the content does not lex as Rust tokens.
    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
    where
        R: io::Read,
    {
        let mut reader = BufReader::new(inner);
        reader.read_until(b';', &mut Vec::new())?;

        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        content
            .parse::<TokenStream>()
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
    }
}