//! Merges fields inherited from parent entities into child entities.
//!
//! Source path: `pecto_core/inheritance.rs`.

use std::collections::{HashMap, HashSet};

use crate::model::{EntityField, ProjectSpec};

4/// Merge inherited fields into entities that have `bases` populated.
5///
6/// For each entity with parent classes in `bases`, if those parents are also
7/// entities in the same spec, prepend the parent's fields (skipping duplicates).
8/// Handles inheritance chains (A -> B -> C) across all capabilities.
9///
10/// Works for all languages:
11/// - Python: `class UserCreate(UserBase)` where UserBase has email, is_active
12/// - Java: `class Pet extends BaseEntity` where BaseEntity has id
13/// - C#: `class Product : BaseEntity` where BaseEntity has Id
14/// - TypeScript: `class User extends BaseEntity` where BaseEntity has id
15pub fn merge_inherited_fields(spec: &mut ProjectSpec) {
16    // Build maps: entity name → fields, entity name → bases
17    let mut field_map: HashMap<String, Vec<EntityField>> = HashMap::new();
18    let mut bases_map: HashMap<String, Vec<String>> = HashMap::new();
19
20    for capability in &spec.capabilities {
21        for entity in &capability.entities {
22            field_map.insert(entity.name.clone(), entity.fields.clone());
23            if !entity.bases.is_empty() {
24                bases_map.insert(entity.name.clone(), entity.bases.clone());
25            }
26        }
27    }
28
29    // Resolve full inherited fields for each entity (with chain support)
30    let mut resolved: HashMap<String, Vec<EntityField>> = HashMap::new();
31
32    for capability in &spec.capabilities {
33        for entity in &capability.entities {
34            if entity.bases.is_empty() {
35                continue;
36            }
37
38            let inherited =
39                collect_inherited_fields(&entity.name, &entity.fields, &field_map, &bases_map);
40
41            if !inherited.is_empty() {
42                resolved.insert(entity.name.clone(), inherited);
43            }
44        }
45    }
46
47    // Apply resolved fields
48    for capability in &mut spec.capabilities {
49        for entity in &mut capability.entities {
50            if let Some(mut inherited) = resolved.remove(&entity.name) {
51                inherited.append(&mut entity.fields);
52                entity.fields = inherited;
53            }
54        }
55    }
56}
57
58/// Collect fields from all ancestors, walking up the chain top-down.
59/// Returns fields to prepend (parent fields not already in the child).
60/// Fields from the most distant ancestor come first (e.g., Base.id before Named.name).
61fn collect_inherited_fields(
62    entity_name: &str,
63    own_fields: &[EntityField],
64    field_map: &HashMap<String, Vec<EntityField>>,
65    bases_map: &HashMap<String, Vec<String>>,
66) -> Vec<EntityField> {
67    // First, build the ordered ancestor chain from most distant to most immediate
68    let mut ancestor_chain: Vec<String> = Vec::new();
69    let mut visited = vec![entity_name.to_string()];
70    collect_ancestors(entity_name, bases_map, &mut ancestor_chain, &mut visited);
71
72    // Now collect fields top-down (most distant ancestor first)
73    let mut inherited = Vec::new();
74    for ancestor in &ancestor_chain {
75        if let Some(ancestor_fields) = field_map.get(ancestor) {
76            for field in ancestor_fields {
77                let already_exists = inherited.iter().any(|f: &EntityField| f.name == field.name)
78                    || own_fields.iter().any(|f| f.name == field.name);
79                if !already_exists {
80                    inherited.push(field.clone());
81                }
82            }
83        }
84    }
85
86    inherited
87}
88
89/// Recursively collect ancestors in top-down order (most distant first).
90fn collect_ancestors(
91    entity_name: &str,
92    bases_map: &HashMap<String, Vec<String>>,
93    chain: &mut Vec<String>,
94    visited: &mut Vec<String>,
95) {
96    let Some(bases) = bases_map.get(entity_name) else {
97        return;
98    };
99
100    for base in bases {
101        if visited.contains(base) {
102            continue;
103        }
104        visited.push(base.clone());
105
106        // Recurse to get grandparent first
107        collect_ancestors(base, bases_map, chain, visited);
108
109        // Then add this parent
110        chain.push(base.clone());
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::*;

    fn make_entity(name: &str, fields: Vec<(&str, &str)>, bases: Vec<&str>) -> Entity {
        Entity {
            name: name.to_string(),
            table: name.to_lowercase(),
            fields: fields
                .into_iter()
                .map(|(n, t)| EntityField {
                    name: n.to_string(),
                    field_type: t.to_string(),
                    constraints: Vec::new(),
                })
                .collect(),
            bases: bases.into_iter().map(|b| b.to_string()).collect(),
        }
    }

    /// Build a spec whose only capability holds `entities`.
    fn single_capability_spec(entities: Vec<Entity>) -> ProjectSpec {
        let mut spec = ProjectSpec::new("test".to_string());
        let mut cap = Capability::new("models".to_string(), "models.py".to_string());
        cap.entities = entities;
        spec.capabilities.push(cap);
        spec
    }

    /// Names of `entity`'s fields, in order.
    fn field_names(entity: &Entity) -> Vec<&str> {
        entity.fields.iter().map(|f| f.name.as_str()).collect()
    }

    #[test]
    fn test_simple_inheritance() {
        let mut spec = single_capability_spec(vec![
            make_entity(
                "UserBase",
                vec![("email", "str"), ("is_active", "bool")],
                vec![],
            ),
            make_entity("UserCreate", vec![("password", "str")], vec!["UserBase"]),
        ]);

        merge_inherited_fields(&mut spec);

        let user_create = &spec.capabilities[0].entities[1];
        assert_eq!(user_create.name, "UserCreate");
        // Inherited fields come first, then the entity's own.
        assert_eq!(field_names(user_create), ["email", "is_active", "password"]);
    }

    #[test]
    fn test_chain_inheritance() {
        // A -> B -> C: C should get fields from both A and B.
        let mut spec = single_capability_spec(vec![
            make_entity("Base", vec![("id", "int")], vec![]),
            make_entity("Named", vec![("name", "str")], vec!["Base"]),
            make_entity("Pet", vec![("breed", "str")], vec!["Named"]),
        ]);

        merge_inherited_fields(&mut spec);

        let entities = &spec.capabilities[0].entities;
        assert_eq!(field_names(&entities[1]), ["id", "name"]);
        assert_eq!(field_names(&entities[2]), ["id", "name", "breed"]);
    }

    #[test]
    fn test_field_override() {
        // Child redefines a parent field -> the child's version is kept.
        let mut spec = single_capability_spec(vec![
            make_entity("Base", vec![("email", "str"), ("name", "str")], vec![]),
            make_entity(
                "User",
                vec![("email", "EmailStr"), ("age", "int")],
                vec!["Base"],
            ),
        ]);

        merge_inherited_fields(&mut spec);

        let user = &spec.capabilities[0].entities[1];
        // name (inherited) + email (own) + age (own)
        assert_eq!(field_names(user), ["name", "email", "age"]);
        let email = user.fields.iter().find(|f| f.name == "email").unwrap();
        assert_eq!(email.field_type, "EmailStr");
    }

    #[test]
    fn test_no_bases_unchanged() {
        let mut spec = single_capability_spec(vec![make_entity(
            "User",
            vec![("id", "int"), ("name", "str")],
            vec![],
        )]);

        merge_inherited_fields(&mut spec);

        assert_eq!(field_names(&spec.capabilities[0].entities[0]), ["id", "name"]);
    }

    #[test]
    fn test_non_entity_base_ignored() {
        // A base that is not an entity in the spec (e.g. a framework class)
        // contributes nothing.
        let mut spec = single_capability_spec(vec![make_entity(
            "User",
            vec![("id", "int")],
            vec!["BaseModel"],
        )]);

        merge_inherited_fields(&mut spec);

        assert_eq!(field_names(&spec.capabilities[0].entities[0]), ["id"]);
    }

    #[test]
    fn test_cross_capability_inheritance() {
        // Parent entity in one capability, child in another (e.g. different files).
        let mut spec = ProjectSpec::new("test".to_string());

        let mut cap1 = Capability::new("base-model".to_string(), "base.py".to_string());
        cap1.entities = vec![make_entity(
            "BaseEntity",
            vec![("id", "int"), ("created_at", "datetime")],
            vec![],
        )];
        spec.capabilities.push(cap1);

        let mut cap2 = Capability::new("user-model".to_string(), "user.py".to_string());
        cap2.entities = vec![make_entity(
            "User",
            vec![("name", "str")],
            vec!["BaseEntity"],
        )];
        spec.capabilities.push(cap2);

        merge_inherited_fields(&mut spec);

        let user = &spec.capabilities[1].entities[0];
        assert_eq!(field_names(user), ["id", "created_at", "name"]);
    }

    #[test]
    fn test_diamond_inheritance_shares_ancestor_once() {
        // Child(Left, Right), both extending Base: Base's fields appear once.
        let mut spec = single_capability_spec(vec![
            make_entity("Base", vec![("id", "int")], vec![]),
            make_entity("Left", vec![("left", "str")], vec!["Base"]),
            make_entity("Right", vec![("right", "str")], vec!["Base"]),
            make_entity("Child", vec![("own", "str")], vec!["Left", "Right"]),
        ]);

        merge_inherited_fields(&mut spec);

        let child = &spec.capabilities[0].entities[3];
        assert_eq!(field_names(child), ["id", "left", "right", "own"]);
    }

    #[test]
    fn test_cyclic_bases_terminate() {
        // Malformed X(Y), Y(X): must terminate without duplicating fields.
        let mut spec = single_capability_spec(vec![
            make_entity("X", vec![("x", "int")], vec!["Y"]),
            make_entity("Y", vec![("y", "int")], vec!["X"]),
        ]);

        merge_inherited_fields(&mut spec);

        let entities = &spec.capabilities[0].entities;
        assert_eq!(field_names(&entities[0]), ["y", "x"]);
        assert_eq!(field_names(&entities[1]), ["x", "y"]);
    }
}