Skip to main content

pecto_typescript/extractors/
entity.rs

1use super::common::*;
2use crate::context::ParsedFile;
3use pecto_core::model::*;
4
5/// Extract entities from TypeScript files: TypeORM @Entity, Mongoose schemas.
6pub fn extract(file: &ParsedFile) -> Option<Capability> {
7    let full_text = &file.source;
8
9    // Quick check
10    if !full_text.contains("@Entity")
11        && !full_text.contains("@Column")
12        && !full_text.contains("new Schema(")
13        && !full_text.contains("mongoose.Schema")
14    {
15        return None;
16    }
17
18    let mut entities = Vec::new();
19
20    // TypeORM: @Entity() class with @Column decorators
21    if full_text.contains("@Entity") {
22        extract_typeorm_entities(full_text, &mut entities);
23    }
24
25    // Mongoose: new Schema({...})
26    if full_text.contains("Schema(") {
27        extract_mongoose_schemas(full_text, &mut entities);
28    }
29
30    if entities.is_empty() {
31        return None;
32    }
33
34    let file_stem = file
35        .path
36        .rsplit('/')
37        .next()
38        .unwrap_or(&file.path)
39        .split('.')
40        .next()
41        .unwrap_or("unknown");
42    let capability_name = format!("{}-entity", to_kebab_case(file_stem));
43
44    let mut capability = Capability::new(capability_name, file.path.clone());
45    capability.entities = entities;
46    Some(capability)
47}
48
49fn extract_typeorm_entities(source: &str, entities: &mut Vec<Entity>) {
50    // Find @Entity() class blocks
51    let mut remaining = source;
52    while let Some(entity_pos) = remaining.find("@Entity(") {
53        remaining = &remaining[entity_pos..];
54
55        // Find class name and base class
56        let (class_name, bases) = remaining
57            .find("class ")
58            .map(|pos| {
59                let after = &remaining[pos + 6..];
60                let line_end = after.find('{').unwrap_or(after.len());
61                let class_line = &after[..line_end];
62                let name = class_line
63                    .split([' ', '{', '\n'])
64                    .next()
65                    .unwrap_or("Unknown")
66                    .trim()
67                    .to_string();
68                let bases = if let Some(ext_pos) = class_line.find("extends ") {
69                    let after_ext = &class_line[ext_pos + 8..];
70                    let base = after_ext
71                        .split([' ', '{', '\n', ','])
72                        .next()
73                        .unwrap_or("")
74                        .trim()
75                        .to_string();
76                    if base.is_empty() {
77                        Vec::new()
78                    } else {
79                        vec![base]
80                    }
81                } else {
82                    Vec::new()
83                };
84                (name, bases)
85            })
86            .unwrap_or_else(|| ("Unknown".to_string(), Vec::new()));
87
88        // Find table name from @Entity('tablename') or default to class name
89        let table_name = remaining
90            .find("@Entity(")
91            .and_then(|pos| {
92                let after = &remaining[pos + 8..];
93                let arg = after.split(')').next()?;
94                if arg.contains('"') || arg.contains('\'') {
95                    Some(clean_string_literal(arg.trim()))
96                } else {
97                    None
98                }
99            })
100            .unwrap_or_else(|| class_name.to_lowercase());
101
102        // Extract fields from @Column, @PrimaryGeneratedColumn, @ManyToOne, etc.
103        let mut fields = Vec::new();
104
105        // Find class body (between { and matching })
106        if let Some(class_start) = remaining.find('{') {
107            let class_body = &remaining[class_start..];
108            let mut depth = 0;
109            let mut end = class_body.len();
110            for (i, c) in class_body.chars().enumerate() {
111                match c {
112                    '{' => depth += 1,
113                    '}' => {
114                        depth -= 1;
115                        if depth == 0 {
116                            end = i;
117                            break;
118                        }
119                    }
120                    _ => {}
121                }
122            }
123            let body = &class_body[1..end];
124            extract_typeorm_fields(body, &mut fields);
125        }
126
127        entities.push(Entity {
128            name: class_name,
129            table: table_name,
130            fields,
131            bases,
132        });
133
134        // Move past this entity
135        remaining = &remaining[1..];
136        if let Some(next) = remaining.find("class ") {
137            remaining = &remaining[next..];
138        } else {
139            break;
140        }
141    }
142}
143
144fn extract_typeorm_fields(body: &str, fields: &mut Vec<EntityField>) {
145    let decorators = [
146        "@PrimaryGeneratedColumn",
147        "@PrimaryColumn",
148        "@Column",
149        "@ManyToOne",
150        "@OneToMany",
151        "@ManyToMany",
152        "@OneToOne",
153        "@JoinColumn",
154    ];
155
156    for line in body.lines() {
157        let trimmed = line.trim();
158        let has_decorator = decorators.iter().any(|d| trimmed.starts_with(d));
159        if !has_decorator {
160            continue;
161        }
162
163        let mut constraints = Vec::new();
164
165        if trimmed.starts_with("@PrimaryGeneratedColumn") {
166            constraints.push("@PrimaryGeneratedColumn".to_string());
167        } else if trimmed.starts_with("@PrimaryColumn") {
168            constraints.push("@PrimaryColumn".to_string());
169        } else if trimmed.starts_with("@Column") {
170            constraints.push("@Column".to_string());
171            if trimmed.contains("nullable: false") || trimmed.contains("nullable:false") {
172                constraints.push("required".to_string());
173            }
174            if trimmed.contains("unique: true") || trimmed.contains("unique:true") {
175                constraints.push("unique".to_string());
176            }
177        } else if trimmed.starts_with("@ManyToOne") {
178            constraints.push("@ManyToOne".to_string());
179        } else if trimmed.starts_with("@OneToMany") {
180            constraints.push("@OneToMany".to_string());
181        } else if trimmed.starts_with("@ManyToMany") {
182            constraints.push("@ManyToMany".to_string());
183        } else if trimmed.starts_with("@OneToOne") {
184            constraints.push("@OneToOne".to_string());
185        }
186
187        // Next non-decorator line should be the field: "name: string;"
188        // We look at the line after the decorator in a simplified way
189        // (the field name and type are often on the next line or same line after decorator)
190    }
191
192    // Simpler approach: find all "fieldName: Type" patterns after decorators
193    let lines: Vec<&str> = body.lines().collect();
194    let mut i = 0;
195    while i < lines.len() {
196        let trimmed = lines[i].trim();
197        if decorators.iter().any(|d| trimmed.starts_with(d)) {
198            let mut constraints = Vec::new();
199            if trimmed.contains("PrimaryGeneratedColumn") || trimmed.contains("PrimaryColumn") {
200                constraints.push("primary_key".to_string());
201            }
202            if trimmed.contains("ManyToOne") {
203                constraints.push("@ManyToOne".to_string());
204            }
205            if trimmed.contains("OneToMany") {
206                constraints.push("@OneToMany".to_string());
207            }
208            if trimmed.contains("ManyToMany") {
209                constraints.push("@ManyToMany".to_string());
210            }
211            if trimmed.contains("nullable: false") {
212                constraints.push("required".to_string());
213            }
214            if trimmed.contains("unique: true") {
215                constraints.push("unique".to_string());
216            }
217
218            // Look at next line for field declaration
219            if i + 1 < lines.len() {
220                let next = lines[i + 1].trim();
221                if next.contains(':') && !next.starts_with('@') && !next.starts_with("//") {
222                    let parts: Vec<&str> = next.splitn(2, ':').collect();
223                    if parts.len() == 2 {
224                        let name = parts[0]
225                            .trim()
226                            .trim_start_matches("readonly ")
227                            .trim()
228                            .to_string();
229                        let field_type = parts[1].trim().trim_end_matches(';').trim().to_string();
230                        if !name.is_empty() && !name.starts_with("//") {
231                            fields.push(EntityField {
232                                name,
233                                field_type,
234                                constraints,
235                            });
236                        }
237                    }
238                }
239            }
240        }
241        i += 1;
242    }
243}
244
245fn extract_mongoose_schemas(source: &str, entities: &mut Vec<Entity>) {
246    // Find: const XxxSchema = new Schema({ ... })
247    // or: const XxxSchema = new mongoose.Schema({ ... })
248    let mut remaining = source;
249    while let Some(pos) = remaining.find("Schema(") {
250        // Look back for variable name
251        let before = &remaining[..pos];
252        let schema_name = before
253            .rsplit([' ', '\t', '='])
254            .find(|s| !s.is_empty() && *s != "new" && *s != "mongoose.")
255            .map(|s| {
256                s.trim()
257                    .replace("Schema", "")
258                    .replace("const ", "")
259                    .replace("let ", "")
260            })
261            .unwrap_or_else(|| "Unknown".to_string());
262
263        let name = if schema_name.is_empty() || schema_name == "new" {
264            "Unknown".to_string()
265        } else {
266            schema_name
267        };
268
269        entities.push(Entity {
270            name: name.clone(),
271            table: name.to_lowercase(),
272            fields: Vec::new(), // Mongoose schema fields are complex to parse from text
273            bases: Vec::new(),
274        });
275
276        remaining = &remaining[pos + 7..];
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ParsedFile;

    /// Builds a [`ParsedFile`] for `path` from in-memory `source`; panics if parsing fails.
    fn parsed(source: &str, path: &str) -> ParsedFile {
        ParsedFile::parse(source.to_owned(), path.to_owned()).unwrap()
    }

    /// A decorated TypeORM class yields one entity with its table name and its fields.
    #[test]
    fn test_typeorm_entity() {
        let source = r#"
import { Entity, PrimaryGeneratedColumn, Column, ManyToOne } from 'typeorm';

@Entity('users')
export class User {
    @PrimaryGeneratedColumn()
    id: number;

    @Column({ nullable: false, unique: true })
    email: string;

    @Column()
    name: string;

    @ManyToOne(() => Organization)
    organization: Organization;
}
"#;

        let capability = extract(&parsed(source, "entities/user.entity.ts")).unwrap();
        let entity = &capability.entities[0];
        let found = entity.fields.len();

        assert_eq!(entity.name, "User");
        assert_eq!(entity.table, "users");
        assert!(found >= 3, "Should find fields, found {}", found);
    }

    /// A plain service class produces no entity capability.
    #[test]
    fn test_no_entity() {
        let service = parsed(
            r#"
export class UserService {
    findAll() { return []; }
}
"#,
            "user.service.ts",
        );

        let outcome = extract(&service);
        assert!(outcome.is_none());
    }
}