Skip to main content

pecto_python/extractors/
entity.rs

1use super::common::*;
2use crate::context::ParsedFile;
3use pecto_core::model::*;
4
5/// Extract entities from Python models: SQLAlchemy, Django ORM, Pydantic.
6pub fn extract(file: &ParsedFile) -> Option<Capability> {
7    let root = file.tree.root_node();
8    let source = file.source.as_bytes();
9    let full_text = &file.source;
10
11    // Quick check: does this file have model-like patterns?
12    if !full_text.contains("Column(")
13        && !full_text.contains("mapped_column(")
14        && !full_text.contains("models.Model")
15        && !full_text.contains("BaseModel")
16        && !full_text.contains("SQLModel")
17        && !full_text.contains("Base)")
18        && !full_text.contains("DeclarativeBase")
19    {
20        return None;
21    }
22
23    let mut entities = Vec::new();
24
25    // First pass: collect all model class names for same-file inheritance resolution
26    let mut known_model_classes: Vec<String> = Vec::new();
27
28    // Pre-scan to identify known model base classes in this file
29    for i in 0..root.named_child_count() {
30        let node = root.named_child(i).unwrap();
31        let class_node = if node.kind() == "class_definition" {
32            node
33        } else if node.kind() == "decorated_definition" {
34            match get_inner_definition(&node) {
35                Some(n) if n.kind() == "class_definition" => n,
36                _ => continue,
37            }
38        } else {
39            continue;
40        };
41
42        let name = get_def_name(&class_node, source);
43        let bases = get_class_bases(&class_node, source);
44
45        // Mark classes that directly extend known model frameworks
46        let is_direct_model = bases.iter().any(|b| {
47            b == "Base"
48                || b.contains("DeclarativeBase")
49                || b == "models.Model"
50                || b.starts_with("models.")
51                || b == "BaseModel"
52                || b == "SQLModel"
53        });
54
55        if is_direct_model || has_table_kwarg(&class_node, source) {
56            known_model_classes.push(name);
57        }
58    }
59
60    // Second pass: add classes that inherit from already-known model classes
61    // This handles chains like: SQLModel -> UserBase -> UserCreate
62    let mut changed = true;
63    while changed {
64        changed = false;
65        for i in 0..root.named_child_count() {
66            let node = root.named_child(i).unwrap();
67            let class_node = if node.kind() == "class_definition" {
68                node
69            } else if node.kind() == "decorated_definition" {
70                match get_inner_definition(&node) {
71                    Some(n) if n.kind() == "class_definition" => n,
72                    _ => continue,
73                }
74            } else {
75                continue;
76            };
77
78            let name = get_def_name(&class_node, source);
79            if known_model_classes.contains(&name) {
80                continue;
81            }
82            let bases = get_class_bases(&class_node, source);
83            if bases
84                .iter()
85                .any(|b| known_model_classes.iter().any(|k| k == b))
86            {
87                known_model_classes.push(name);
88                changed = true;
89            }
90        }
91    }
92
93    for i in 0..root.named_child_count() {
94        let node = root.named_child(i).unwrap();
95
96        let class_node = if node.kind() == "class_definition" {
97            node
98        } else if node.kind() == "decorated_definition" {
99            match get_inner_definition(&node) {
100                Some(n) if n.kind() == "class_definition" => n,
101                _ => continue,
102            }
103        } else {
104            continue;
105        };
106
107        let class_name = get_def_name(&class_node, source);
108
109        // Skip base class definitions themselves (e.g. class Base(DeclarativeBase))
110        if is_base_class_definition(&class_name) {
111            continue;
112        }
113
114        let bases = get_class_bases(&class_node, source);
115
116        // Check if class has table=True kwarg (SQLModel database model)
117        let has_table_true = has_table_kwarg(&class_node, source);
118
119        // Filter bases to only real class names (exclude keyword args like table=True)
120        let class_bases: Vec<String> = bases.iter().filter(|b| !b.contains('=')).cloned().collect();
121
122        // SQLAlchemy: class User(Base) or class User(DeclarativeBase)
123        if bases
124            .iter()
125            .any(|b| b == "Base" || b.contains("DeclarativeBase"))
126            && let Some(mut entity) = extract_sqlalchemy_entity(&class_node, source, &class_name)
127        {
128            entity.bases = class_bases;
129            entities.push(entity);
130        }
131        // SQLModel with table=True: database-backed model (e.g. class User(UserBase, table=True))
132        else if has_table_true
133            && full_text.contains("SQLModel")
134            && let Some(mut entity) = extract_sqlmodel_entity(&class_node, source, &class_name)
135        {
136            entity.bases = class_bases;
137            entities.push(entity);
138        }
139        // Django: class User(models.Model)
140        else if bases
141            .iter()
142            .any(|b| b == "models.Model" || b.starts_with("models."))
143            && let Some(mut entity) = extract_django_model(&class_node, source, &class_name)
144        {
145            entity.bases = class_bases;
146            entities.push(entity);
147        }
148        // Pydantic/SQLModel (non-table): direct base or inherits from known model in same file
149        else if bases.iter().any(|b| {
150            b == "BaseModel" || b == "SQLModel" || known_model_classes.iter().any(|k| k == b)
151        }) && let Some(mut entity) =
152            extract_pydantic_model(&class_node, source, &class_name)
153        {
154            entity.bases = class_bases;
155            entities.push(entity);
156        }
157    }
158
159    if entities.is_empty() {
160        return None;
161    }
162
163    let file_stem = file
164        .path
165        .rsplit('/')
166        .next()
167        .unwrap_or(&file.path)
168        .trim_end_matches(".py");
169    let capability_name = format!("{}-model", to_kebab_case(file_stem));
170
171    let mut capability = Capability::new(capability_name, file.path.clone());
172    capability.entities = entities;
173    Some(capability)
174}
175
176fn get_class_bases(class_node: &tree_sitter::Node, source: &[u8]) -> Vec<String> {
177    let mut bases = Vec::new();
178    if let Some(arg_list) = class_node.child_by_field_name("superclasses") {
179        for i in 0..arg_list.named_child_count() {
180            let arg = arg_list.named_child(i).unwrap();
181            bases.push(node_text(&arg, source));
182        }
183    }
184    bases
185}
186
187/// Check if a class definition has `table=True` in its keyword arguments.
188fn has_table_kwarg(class_node: &tree_sitter::Node, source: &[u8]) -> bool {
189    if let Some(arg_list) = class_node.child_by_field_name("superclasses") {
190        for i in 0..arg_list.named_child_count() {
191            let arg = arg_list.named_child(i).unwrap();
192            if arg.kind() == "keyword_argument" {
193                let text = node_text(&arg, source);
194                if text.contains("table") && text.contains("True") {
195                    return true;
196                }
197            }
198        }
199    }
200    false
201}
202
203/// Extract SQLModel entity with table=True: uses Pydantic-style typed fields + Relationship()
204fn extract_sqlmodel_entity(
205    class_node: &tree_sitter::Node,
206    source: &[u8],
207    class_name: &str,
208) -> Option<Entity> {
209    let body = class_node.child_by_field_name("body")?;
210    let mut fields = Vec::new();
211    let table_name = class_name.to_lowercase();
212
213    for i in 0..body.named_child_count() {
214        let stmt = body.named_child(i).unwrap();
215        if stmt.kind() != "expression_statement" {
216            continue;
217        }
218
219        let text = node_text(&stmt, source);
220
221        // Skip non-field lines
222        if !text.contains(':') || text.starts_with('#') {
223            continue;
224        }
225
226        // name: type = Field(...) or name: type = Relationship(...)
227        let parts: Vec<&str> = text.splitn(2, ':').collect();
228        if parts.len() != 2 {
229            continue;
230        }
231
232        let name = parts[0].trim().to_string();
233        if name.starts_with('_') || name == "model_config" {
234            continue;
235        }
236
237        let type_and_default = parts[1].trim();
238        let field_type = type_and_default
239            .split('=')
240            .next()
241            .unwrap_or("")
242            .trim()
243            .to_string();
244
245        // Skip Relationship fields (they describe associations, not columns)
246        if text.contains("Relationship(") {
247            fields.push(EntityField {
248                name,
249                field_type: format!("relationship({})", field_type),
250                constraints: vec!["relationship".to_string()],
251            });
252            continue;
253        }
254
255        let mut constraints = Vec::new();
256        if text.contains("Field(") {
257            if text.contains("primary_key=True") {
258                constraints.push("primary_key".to_string());
259            }
260            if text.contains("unique=True") {
261                constraints.push("unique".to_string());
262            }
263            if text.contains("index=True") {
264                constraints.push("indexed".to_string());
265            }
266            if text.contains("foreign_key=") {
267                constraints.push("relationship".to_string());
268            }
269            if text.contains("nullable=False") {
270                constraints.push("required".to_string());
271            }
272            if let Some(v) = extract_kwarg_value(&text, "max_length") {
273                constraints.push(format!("max_length={}", v));
274            }
275            if let Some(v) = extract_kwarg_value(&text, "min_length") {
276                constraints.push(format!("min_length={}", v));
277            }
278        }
279
280        // If no explicit nullable and not Optional/None default, it's required
281        if !text.contains("| None")
282            && !text.contains("Optional")
283            && !text.contains("= None")
284            && !constraints.contains(&"required".to_string())
285        {
286            constraints.push("required".to_string());
287        }
288
289        fields.push(EntityField {
290            name,
291            field_type,
292            constraints,
293        });
294    }
295
296    Some(Entity {
297        name: class_name.to_string(),
298        table: table_name,
299        fields,
300        bases: Vec::new(),
301    })
302}
303
304/// Extract SQLAlchemy model: Column(String), Column(Integer), relationship()
305fn extract_sqlalchemy_entity(
306    class_node: &tree_sitter::Node,
307    source: &[u8],
308    class_name: &str,
309) -> Option<Entity> {
310    let body = class_node.child_by_field_name("body")?;
311    let mut fields = Vec::new();
312    let mut table_name = class_name.to_lowercase();
313
314    for i in 0..body.named_child_count() {
315        let stmt = body.named_child(i).unwrap();
316        if stmt.kind() != "expression_statement" {
317            continue;
318        }
319
320        let text = node_text(&stmt, source);
321
322        // __tablename__ = "users"
323        if text.contains("__tablename__") {
324            if let Some(val) = extract_assignment_string(&text) {
325                table_name = val;
326            }
327            continue;
328        }
329
330        // field = Column(Type, ...) or field: Mapped[type] = mapped_column(...)
331        if (text.contains("Column(")
332            || text.contains("relationship(")
333            || text.contains("mapped_column("))
334            && let Some(field) = parse_sqlalchemy_field(&text)
335        {
336            fields.push(field);
337        }
338    }
339
340    Some(Entity {
341        name: class_name.to_string(),
342        table: table_name,
343        fields,
344        bases: Vec::new(),
345    })
346}
347
348fn parse_sqlalchemy_field(text: &str) -> Option<EntityField> {
349    // Handle both old-style `name = Column(...)` and new-style `name: Mapped[type] = mapped_column(...)`
350    let (name, rhs) = if text.contains("mapped_column(") {
351        // name: Mapped[type] = mapped_column(Type, ...)
352        let colon_parts: Vec<&str> = text.splitn(2, ':').collect();
353        if colon_parts.len() != 2 {
354            return None;
355        }
356        let field_name = colon_parts[0].trim().to_string();
357        let after_colon = colon_parts[1].trim();
358        // Find the `= mapped_column(` part
359        if let Some(eq_pos) = after_colon.find("= mapped_column(") {
360            (field_name, after_colon[eq_pos + 2..].trim().to_string())
361        } else if let Some(eq_pos) = after_colon.find("=mapped_column(") {
362            (field_name, after_colon[eq_pos + 1..].trim().to_string())
363        } else {
364            return None;
365        }
366    } else {
367        let parts: Vec<&str> = text.splitn(2, '=').collect();
368        if parts.len() != 2 {
369            return None;
370        }
371        (parts[0].trim().to_string(), parts[1].trim().to_string())
372    };
373
374    if rhs.starts_with("Column(") || rhs.starts_with("mapped_column(") {
375        let prefix_len = if rhs.starts_with("mapped_column(") {
376            14
377        } else {
378            7
379        };
380        let inner = &rhs[prefix_len..rhs.rfind(')')?];
381        let args: Vec<&str> = inner.split(',').map(|s| s.trim()).collect();
382        let field_type = args.first().unwrap_or(&"").to_string();
383
384        let mut constraints = Vec::new();
385        for arg in &args[1..] {
386            if arg.contains("primary_key=True") {
387                constraints.push("primary_key".to_string());
388            }
389            if arg.contains("nullable=False") {
390                constraints.push("required".to_string());
391            }
392            if arg.contains("unique=True") {
393                constraints.push("unique".to_string());
394            }
395            if arg.contains("index=True") {
396                constraints.push("indexed".to_string());
397            }
398        }
399
400        Some(EntityField {
401            name,
402            field_type,
403            constraints,
404        })
405    } else if rhs.starts_with("relationship(") {
406        let inner = &rhs[13..rhs.rfind(')')?];
407        let target = inner
408            .split(',')
409            .next()?
410            .trim()
411            .trim_matches('"')
412            .trim_matches('\'');
413        Some(EntityField {
414            name,
415            field_type: format!("relationship({})", target),
416            constraints: vec!["relationship".to_string()],
417        })
418    } else {
419        None
420    }
421}
422
/// Return `true` when `class_name` is a conventional SQLAlchemy base-class name.
///
/// A class called `Base` or `DeclarativeBase` (e.g. `class Base(DeclarativeBase): ...`) is
/// declarative infrastructure rather than a mapped table, so the extractor skips it.
fn is_base_class_definition(class_name: &str) -> bool {
    matches!(class_name, "Base" | "DeclarativeBase")
}
427
428/// Extract Django model: CharField, IntegerField, ForeignKey, etc.
429fn extract_django_model(
430    class_node: &tree_sitter::Node,
431    source: &[u8],
432    class_name: &str,
433) -> Option<Entity> {
434    let body = class_node.child_by_field_name("body")?;
435    let mut fields = Vec::new();
436    let table_name = class_name.to_lowercase();
437
438    for i in 0..body.named_child_count() {
439        let stmt = body.named_child(i).unwrap();
440        if stmt.kind() != "expression_statement" {
441            continue;
442        }
443
444        let text = node_text(&stmt, source);
445
446        // field = models.CharField(...) or field = CharField(...)
447        if (text.contains("Field(")
448            || text.contains("ForeignKey(")
449            || text.contains("ManyToManyField(")
450            || text.contains("OneToOneField("))
451            && let Some(field) = parse_django_field(&text)
452        {
453            fields.push(field);
454        }
455    }
456
457    Some(Entity {
458        name: class_name.to_string(),
459        table: table_name,
460        fields,
461        bases: Vec::new(),
462    })
463}
464
465fn parse_django_field(text: &str) -> Option<EntityField> {
466    let parts: Vec<&str> = text.splitn(2, '=').collect();
467    if parts.len() != 2 {
468        return None;
469    }
470
471    let name = parts[0].trim().to_string();
472    let rhs = parts[1].trim();
473
474    // Extract field type: models.CharField or CharField
475    let field_type = rhs.split('(').next()?.trim().replace("models.", "");
476
477    let mut constraints = Vec::new();
478    if rhs.contains("primary_key=True") {
479        constraints.push("primary_key".to_string());
480    }
481    if rhs.contains("blank=False")
482        || rhs.contains("null=False")
483        || !rhs.contains("blank=True") && !rhs.contains("null=True")
484    {
485        constraints.push("required".to_string());
486    }
487    if rhs.contains("unique=True") {
488        constraints.push("unique".to_string());
489    }
490    if rhs.contains("max_length=")
491        && let Some(ml) = extract_kwarg_value(rhs, "max_length")
492    {
493        constraints.push(format!("max_length={}", ml));
494    }
495    if field_type.contains("ForeignKey") || field_type.contains("OneToOne") {
496        constraints.push("relationship".to_string());
497    }
498    if field_type.contains("ManyToMany") {
499        constraints.push("many_to_many".to_string());
500    }
501
502    Some(EntityField {
503        name,
504        field_type,
505        constraints,
506    })
507}
508
509/// Extract Pydantic model fields from type annotations.
510fn extract_pydantic_model(
511    class_node: &tree_sitter::Node,
512    source: &[u8],
513    class_name: &str,
514) -> Option<Entity> {
515    let body = class_node.child_by_field_name("body")?;
516    let mut fields = Vec::new();
517
518    for i in 0..body.named_child_count() {
519        let stmt = body.named_child(i).unwrap();
520
521        let text = node_text(&stmt, source);
522
523        // name: str or name: str = Field(...)
524        if stmt.kind() == "expression_statement" && text.contains(':') {
525            let parts: Vec<&str> = text.splitn(2, ':').collect();
526            if parts.len() == 2 {
527                let name = parts[0].trim().to_string();
528                let type_and_default = parts[1].trim();
529                let field_type = type_and_default
530                    .split('=')
531                    .next()
532                    .unwrap_or("")
533                    .trim()
534                    .to_string();
535
536                if name.starts_with('_') || name == "model_config" || name == "Config" {
537                    continue;
538                }
539
540                let mut constraints = Vec::new();
541                if text.contains("Field(") {
542                    if let Some(v) = extract_kwarg_value(&text, "min_length") {
543                        constraints.push(format!("min_length={}", v));
544                    }
545                    if let Some(v) = extract_kwarg_value(&text, "max_length") {
546                        constraints.push(format!("max_length={}", v));
547                    }
548                    if text.contains("gt=") || text.contains("ge=") {
549                        constraints.push("min_value".to_string());
550                    }
551                    if text.contains("lt=") || text.contains("le=") {
552                        constraints.push("max_value".to_string());
553                    }
554                }
555
556                if !field_type.starts_with("Optional") && !text.contains("= None") {
557                    constraints.push("required".to_string());
558                }
559
560                fields.push(EntityField {
561                    name,
562                    field_type,
563                    constraints,
564                });
565            }
566        }
567    }
568
569    Some(Entity {
570        name: class_name.to_string(),
571        table: class_name.to_lowercase(),
572        fields,
573        bases: Vec::new(),
574    })
575}
576
577fn extract_assignment_string(text: &str) -> Option<String> {
578    let after_eq = text.split('=').nth(1)?.trim();
579    Some(clean_string_literal(after_eq))
580}
581
/// Return the trimmed source text of keyword argument `key` in a call expression.
///
/// For example, `extract_kwarg_value("Field(max_length=50)", "max_length")` is `Some("50")`.
///
/// * The match must start on a word boundary, so looking up `max_length` ignores
///   `db_max_length=...`, and `lt` never matches inside `default=`.
/// * `key==value` comparisons are ignored.
/// * The value runs up to the next `,` or closing bracket that is not nested inside
///   brackets or a string literal. So `max_length=len(x)` yields `len(x)`, and
///   `x="a,b"` yields `"a,b"`.
///
/// Returns `None` when no matching occurrence has a terminated value.
fn extract_kwarg_value<'a>(text: &'a str, key: &str) -> Option<&'a str> {
    /// Byte offset where a keyword value starting at the beginning of `s` ends.
    fn value_end(s: &str) -> Option<usize> {
        let mut depth = 0usize;
        let mut quote: Option<char> = None;
        let mut escaped = false;
        for (index, c) in s.char_indices() {
            if let Some(open_quote) = quote {
                if escaped {
                    escaped = false;
                } else if c == '\\' {
                    escaped = true;
                } else if c == open_quote {
                    quote = None;
                }
                continue;
            }
            match c {
                '"' | '\'' => quote = Some(c),
                '(' | '[' | '{' => depth += 1,
                ')' | ']' | '}' if depth == 0 => return Some(index),
                ')' | ']' | '}' => depth -= 1,
                ',' if depth == 0 => return Some(index),
                _ => {}
            }
        }
        None
    }

    let pattern = format!("{}=", key);
    text.match_indices(pattern.as_str()).find_map(|(pos, _)| {
        let glued_to_word = matches!(
            text[..pos].chars().next_back(),
            Some(c) if c.is_alphanumeric() || c == '_'
        );
        let remaining = &text[pos + pattern.len()..];
        if glued_to_word || remaining.starts_with('=') {
            return None;
        }
        value_end(remaining).map(|end| remaining[..end].trim())
    })
}
589
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ParsedFile;

    /// Parse `source` as a Python file located at `path`.
    fn parse_file(source: &str, path: &str) -> ParsedFile {
        ParsedFile::parse(source.to_string(), path.to_string()).unwrap()
    }

    /// Whether `constraints` contains exactly `wanted`.
    fn has(constraints: &[String], wanted: &str) -> bool {
        constraints.iter().any(|c| c == wanted)
    }

    #[test]
    fn test_sqlalchemy_model() {
        let source = r#"
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship

class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    email = Column(String, unique=True)
    posts = relationship("Post")
"#;
        let file = parse_file(source, "models/user.py");
        let capability = extract(&file).unwrap();

        let user = &capability.entities[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.table, "users");
        assert_eq!(user.fields.len(), 4);
        assert!(has(&user.fields[0].constraints, "primary_key"));
        assert!(has(&user.fields[2].constraints, "unique"));
    }

    #[test]
    fn test_django_model() {
        let source = r#"
from django.db import models

class Article(models.Model):
    title = models.CharField(max_length=200)
    content = models.TextField()
    author = models.ForeignKey("User", on_delete=models.CASCADE)
    tags = models.ManyToManyField("Tag")
"#;
        let file = parse_file(source, "models.py");
        let capability = extract(&file).unwrap();

        let article = &capability.entities[0];
        assert_eq!(article.name, "Article");
        assert_eq!(article.fields.len(), 4);
        let title_constraints = &article.fields[0].constraints;
        assert!(title_constraints.iter().any(|c| c.contains("max_length")));
        assert!(has(&article.fields[2].constraints, "relationship"));
    }

    #[test]
    fn test_pydantic_model() {
        let source = r#"
from pydantic import BaseModel, Field

class UserCreate(BaseModel):
    name: str = Field(min_length=2, max_length=50)
    email: str
    age: int = Field(gt=0, lt=150)
    bio: Optional[str] = None
"#;
        let file = parse_file(source, "schemas/user.py");
        let capability = extract(&file).unwrap();

        let schema = &capability.entities[0];
        assert_eq!(schema.name, "UserCreate");
        assert_eq!(schema.fields.len(), 4);
        let name_constraints = &schema.fields[0].constraints;
        assert!(name_constraints.iter().any(|c| c.contains("min_length")));
        // bio is Optional with default None → not required
        assert!(!has(&schema.fields[3].constraints, "required"));
    }

    #[test]
    fn test_no_model() {
        let source = r#"
class Helper:
    def do_something(self):
        pass
"#;
        let file = parse_file(source, "utils.py");
        assert!(extract(&file).is_none());
    }

    #[test]
    fn test_sqlmodel_table_entity() {
        let source = r#"
from sqlmodel import Field, SQLModel, Relationship

class UserBase(SQLModel):
    email: str = Field(unique=True, max_length=255)
    is_active: bool = True

class User(UserBase, table=True):
    id: int = Field(primary_key=True)
    hashed_password: str
    items: list["Item"] = Relationship(back_populates="owner")
"#;
        let file = parse_file(source, "models.py");
        let capability = extract(&file).unwrap();

        // UserBase is Pydantic-like (no table=True), User is a table entity
        assert!(capability.entities.len() >= 2);

        let user = capability
            .entities
            .iter()
            .find(|entity| entity.name == "User")
            .unwrap();
        assert_eq!(user.table, "user");
        assert!(user
            .fields
            .iter()
            .any(|f| f.name == "id" && has(&f.constraints, "primary_key")));
        assert!(user.fields.iter().any(|f| f.name == "hashed_password"));
        assert!(user
            .fields
            .iter()
            .any(|f| has(&f.constraints, "relationship")));
    }

    #[test]
    fn test_sqlalchemy_mapped_column() {
        let source = r#"
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import Integer, String, Float

class Trade(Base):
    __tablename__ = "trades"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    ticker: Mapped[str] = mapped_column(String(20))
    price: Mapped[float] = mapped_column(Float)
    status: Mapped[str] = mapped_column(String(20), unique=True)
"#;
        let file = parse_file(source, "models.py");
        let capability = extract(&file).unwrap();

        let trade = &capability.entities[0];
        assert_eq!(trade.name, "Trade");
        assert_eq!(trade.table, "trades");
        assert_eq!(trade.fields.len(), 4);
        assert!(has(&trade.fields[0].constraints, "primary_key"));
        assert_eq!(trade.fields[1].name, "ticker");
        assert_eq!(trade.fields[1].field_type, "String(20)");
        assert!(has(&trade.fields[3].constraints, "unique"));
    }

    #[test]
    fn test_sqlmodel_inheritance_chain() {
        // Tests same-file inheritance resolution: SQLModel -> UserBase -> UserCreate
        let source = r#"
from sqlmodel import Field, SQLModel

class UserBase(SQLModel):
    email: str = Field(max_length=255)

class UserCreate(UserBase):
    password: str = Field(min_length=8)

class ItemBase(SQLModel):
    title: str

class ItemCreate(ItemBase):
    pass
"#;
        let file = parse_file(source, "models.py");
        let capability = extract(&file).unwrap();

        let names: Vec<&str> = capability
            .entities
            .iter()
            .map(|entity| entity.name.as_str())
            .collect();
        let expectations = [
            ("UserBase", "UserBase"),
            ("UserCreate", "UserCreate (inherits UserBase)"),
            ("ItemBase", "ItemBase"),
            ("ItemCreate", "ItemCreate (inherits ItemBase)"),
        ];
        for (expected, label) in expectations {
            assert!(
                names.contains(&expected),
                "Should find {}, got: {:?}",
                label,
                names
            );
        }

        // UserCreate should have its own field (password)
        let user_create = capability
            .entities
            .iter()
            .find(|entity| entity.name == "UserCreate")
            .unwrap();
        assert_eq!(user_create.fields.len(), 1);
        assert_eq!(user_create.fields[0].name, "password");

        // ItemCreate has `pass` body → 0 fields (inherited fields not resolved)
        let item_create = capability
            .entities
            .iter()
            .find(|entity| entity.name == "ItemCreate")
            .unwrap();
        assert!(item_create.fields.is_empty());
    }

    #[test]
    fn test_base_class_not_entity() {
        let source = r#"
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase):
    pass
"#;
        let file = parse_file(source, "database.py");
        assert!(extract(&file).is_none());
    }
}