// elif_orm/relationships/inference.rs

//! Relationship type inference: utilities for inferring relationship types, foreign keys and table names from naming conventions.
2
3use serde::de::DeserializeOwned;
4use std::collections::HashMap;
5use std::marker::PhantomData;
6
7use super::metadata::{
8    ForeignKeyConfig, PivotConfig, PolymorphicConfig, RelationshipMetadata, RelationshipType,
9};
10use crate::error::ModelResult;
11use crate::model::Model;
12
13/// Trait for models that can have their relationships inferred
14pub trait InferableModel: Model {
15    /// Get relationship inference hints for this model
16    fn relationship_hints() -> Vec<RelationshipHint> {
17        Vec::new()
18    }
19
20    /// Get foreign key naming convention for this model
21    fn foreign_key_convention() -> ForeignKeyConvention {
22        ForeignKeyConvention::Underscore
23    }
24
25    /// Get table naming convention for this model
26    fn table_naming_convention() -> TableNamingConvention {
27        TableNamingConvention::Plural
28    }
29}
30
31/// Hint for relationship inference
32#[derive(Debug, Clone)]
33pub struct RelationshipHint {
34    /// The field name in the model
35    pub field_name: String,
36
37    /// The expected relationship type
38    pub relationship_type: RelationshipType,
39
40    /// The related model type name
41    pub related_model: String,
42
43    /// Custom foreign key if different from convention
44    pub custom_foreign_key: Option<String>,
45
46    /// Whether this should be eagerly loaded by default
47    pub eager_load: bool,
48}
49
/// Foreign key naming conventions.
///
/// Each convention is applied to the *singularized* table (or model) name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ForeignKeyConvention {
    /// `{model}_id`, e.g. `user_id`.
    Underscore,
    /// camelCase name followed by `Id`, e.g. `userId`.
    CamelCase,
    /// PascalCase name followed by `ID`, e.g. `UserID`.
    PascalCase,
    /// Custom pattern: every `{model}` placeholder is replaced with the singular
    /// name, e.g. `"fk_{model}"` becomes `fk_user`.
    Custom(&'static str),
}
62
/// Table naming conventions used when a table name is derived from a model name.
///
/// The model name is lowercased before the convention is applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TableNamingConvention {
    /// Plural form (`users`, `posts`, `categories`).
    Plural,
    /// Singular form (`user`, `post`).
    Singular,
    /// Custom pattern: every `{model}` placeholder is replaced with the lowercased
    /// model name, e.g. `"tbl_{model}"` becomes `tbl_user`.
    Custom(&'static str),
}
73
74/// Relationship type inference engine
75pub struct RelationshipInferenceEngine<Parent>
76where
77    Parent: InferableModel + DeserializeOwned + Send + Sync,
78{
79    parent_model: PhantomData<Parent>,
80
81    /// Cache of inferred relationships
82    inference_cache: HashMap<String, RelationshipMetadata>,
83}
84
85impl<Parent> RelationshipInferenceEngine<Parent>
86where
87    Parent: InferableModel + DeserializeOwned + Send + Sync,
88{
89    /// Create a new inference engine
90    pub fn new() -> Self {
91        Self {
92            parent_model: PhantomData,
93            inference_cache: HashMap::new(),
94        }
95    }
96
97    /// Infer relationship metadata for a given field name and related model
98    pub fn infer_relationship<Related>(
99        &mut self,
100        field_name: &str,
101        relationship_type: RelationshipType,
102    ) -> ModelResult<RelationshipMetadata>
103    where
104        Related: InferableModel + DeserializeOwned + Send + Sync,
105    {
106        let cache_key = format!("{}::{}", field_name, std::any::type_name::<Related>());
107
108        // Check cache first
109        if let Some(cached) = self.inference_cache.get(&cache_key) {
110            return Ok(cached.clone());
111        }
112
113        let metadata =
114            self.infer_relationship_metadata::<Related>(field_name, relationship_type)?;
115
116        // Cache the result
117        self.inference_cache.insert(cache_key, metadata.clone());
118
119        Ok(metadata)
120    }
121
122    /// Infer all relationships for the parent model using hints
123    pub fn infer_all_relationships(&mut self) -> ModelResult<Vec<RelationshipMetadata>> {
124        let hints = Parent::relationship_hints();
125        let mut relationships = Vec::new();
126
127        for hint in hints {
128            let metadata = self.infer_from_hint(&hint)?;
129            relationships.push(metadata);
130        }
131
132        Ok(relationships)
133    }
134
135    /// Infer relationship metadata from model structure and naming conventions
136    fn infer_relationship_metadata<Related>(
137        &self,
138        field_name: &str,
139        relationship_type: RelationshipType,
140    ) -> ModelResult<RelationshipMetadata>
141    where
142        Related: InferableModel + DeserializeOwned + Send + Sync,
143    {
144        let parent_table = Parent::table_name();
145        let related_table = Related::table_name();
146        let related_model_name = std::any::type_name::<Related>()
147            .split("::")
148            .last()
149            .unwrap_or(std::any::type_name::<Related>());
150
151        let foreign_key_config = match relationship_type {
152            RelationshipType::HasOne | RelationshipType::HasMany => {
153                // Related table has foreign key pointing to parent
154                let foreign_key = self.infer_foreign_key_name(Parent::table_name())?;
155                ForeignKeyConfig::simple(foreign_key, related_table.to_string())
156            }
157            RelationshipType::BelongsTo => {
158                // Parent table has foreign key pointing to related
159                let foreign_key = self.infer_foreign_key_name(Related::table_name())?;
160                ForeignKeyConfig::simple(foreign_key, parent_table.to_string())
161            }
162            RelationshipType::ManyToMany => {
163                // Pivot table with both foreign keys
164                let pivot_table = self.infer_pivot_table_name(parent_table, related_table);
165                let local_key = self.infer_foreign_key_name(parent_table)?;
166                let foreign_key = self.infer_foreign_key_name(related_table)?;
167
168                return Ok(RelationshipMetadata::new(
169                    relationship_type,
170                    field_name.to_string(),
171                    related_table.to_string(),
172                    related_model_name.to_string(),
173                    ForeignKeyConfig::simple(local_key.clone(), pivot_table.clone()),
174                )
175                .with_pivot(PivotConfig::new(pivot_table, local_key, foreign_key)));
176            }
177            RelationshipType::MorphOne
178            | RelationshipType::MorphMany
179            | RelationshipType::MorphTo => {
180                // Polymorphic relationships
181                let (type_column, id_column) = self.infer_polymorphic_columns(field_name);
182
183                return Ok(RelationshipMetadata::new(
184                    relationship_type,
185                    field_name.to_string(),
186                    related_table.to_string(),
187                    related_model_name.to_string(),
188                    ForeignKeyConfig::simple(id_column.clone(), related_table.to_string()),
189                )
190                .with_polymorphic(PolymorphicConfig::new(
191                    field_name.to_string(),
192                    type_column,
193                    id_column,
194                )));
195            }
196        };
197
198        Ok(RelationshipMetadata::new(
199            relationship_type,
200            field_name.to_string(),
201            related_table.to_string(),
202            related_model_name.to_string(),
203            foreign_key_config,
204        ))
205    }
206
207    /// Infer relationship metadata from a hint
208    fn infer_from_hint(&self, hint: &RelationshipHint) -> ModelResult<RelationshipMetadata> {
209        let foreign_key = hint.custom_foreign_key.clone().unwrap_or_else(|| {
210            match self.infer_foreign_key_name(&hint.related_model.to_lowercase()) {
211                Ok(fk) => fk,
212                Err(_) => format!("{}_id", hint.related_model.to_lowercase()),
213            }
214        });
215
216        let related_table = self.infer_table_name(&hint.related_model);
217
218        let mut metadata = RelationshipMetadata::new(
219            hint.relationship_type,
220            hint.field_name.clone(),
221            related_table,
222            hint.related_model.clone(),
223            ForeignKeyConfig::simple(foreign_key, hint.related_model.to_lowercase()),
224        );
225
226        metadata.eager_load = hint.eager_load;
227
228        Ok(metadata)
229    }
230
231    /// Infer foreign key name based on convention
232    pub fn infer_foreign_key_name(&self, table_or_model: &str) -> ModelResult<String> {
233        let convention = Parent::foreign_key_convention();
234
235        match convention {
236            ForeignKeyConvention::Underscore => {
237                let singular = self.singularize_table_name(table_or_model);
238                Ok(format!("{}_id", singular))
239            }
240            ForeignKeyConvention::CamelCase => {
241                let singular = self.singularize_table_name(table_or_model);
242                Ok(format!("{}Id", self.to_camel_case(&singular)))
243            }
244            ForeignKeyConvention::PascalCase => {
245                let singular = self.singularize_table_name(table_or_model);
246                Ok(format!("{}ID", self.to_pascal_case(&singular)))
247            }
248            ForeignKeyConvention::Custom(pattern) => {
249                let singular = self.singularize_table_name(table_or_model);
250                Ok(pattern.replace("{model}", &singular))
251            }
252        }
253    }
254
255    /// Infer pivot table name for many-to-many relationships
256    fn infer_pivot_table_name(&self, table1: &str, table2: &str) -> String {
257        let mut tables = [table1, table2];
258        tables.sort();
259        tables.join("_")
260    }
261
262    /// Infer polymorphic column names
263    fn infer_polymorphic_columns(&self, field_name: &str) -> (String, String) {
264        // Standard Laravel-style naming: commentable_type, commentable_id
265        let base = if field_name.ends_with("able") {
266            field_name.to_string()
267        } else {
268            format!("{}_able", field_name)
269        };
270
271        (format!("{}_type", base), format!("{}_id", base))
272    }
273
274    /// Infer table name from model name
275    pub fn infer_table_name(&self, model_name: &str) -> String {
276        let convention = Parent::table_naming_convention();
277        let base_name = model_name.to_lowercase();
278
279        match convention {
280            TableNamingConvention::Plural => self.pluralize_name(&base_name),
281            TableNamingConvention::Singular => base_name,
282            TableNamingConvention::Custom(pattern) => pattern.replace("{model}", &base_name),
283        }
284    }
285
286    /// Simple pluralization (English-centric)
287    pub fn pluralize_name(&self, name: &str) -> String {
288        if name.ends_with('y')
289            && !name.ends_with("ay")
290            && !name.ends_with("ey")
291            && !name.ends_with("iy")
292            && !name.ends_with("oy")
293            && !name.ends_with("uy")
294        {
295            format!("{}ies", &name[..name.len() - 1])
296        } else if name.ends_with('s')
297            || name.ends_with("sh")
298            || name.ends_with("ch")
299            || name.ends_with('x')
300            || name.ends_with('z')
301        {
302            format!("{}es", name)
303        } else {
304            format!("{}s", name)
305        }
306    }
307
308    /// Simple singularization (English-centric)  
309    pub fn singularize_table_name(&self, name: &str) -> String {
310        if name.ends_with("ies") {
311            format!("{}y", &name[..name.len() - 3])
312        } else if name.ends_with("ses")
313            || name.ends_with("ches")
314            || name.ends_with("shes")
315            || name.ends_with("xes")
316            || name.ends_with("zes")
317        {
318            name[..name.len() - 2].to_string()
319        } else if name.ends_with('s') && name.len() > 1 {
320            name[..name.len() - 1].to_string()
321        } else {
322            name.to_string()
323        }
324    }
325
326    /// Convert to camelCase
327    pub fn to_camel_case(&self, s: &str) -> String {
328        let parts: Vec<&str> = s.split('_').collect();
329        if parts.is_empty() {
330            return s.to_string();
331        }
332
333        let mut result = parts[0].to_lowercase();
334        for part in &parts[1..] {
335            if !part.is_empty() {
336                let mut chars = part.chars();
337                if let Some(first) = chars.next() {
338                    result.push(first.to_uppercase().next().unwrap());
339                    result.extend(chars.flat_map(|c| c.to_lowercase()));
340                }
341            }
342        }
343
344        result
345    }
346
347    /// Convert to PascalCase
348    pub fn to_pascal_case(&self, s: &str) -> String {
349        let camel = self.to_camel_case(s);
350        let mut chars = camel.chars();
351        if let Some(first) = chars.next() {
352            first.to_uppercase().collect::<String>() + &chars.collect::<String>()
353        } else {
354            camel
355        }
356    }
357}
358
359impl<Parent> Default for RelationshipInferenceEngine<Parent>
360where
361    Parent: InferableModel + DeserializeOwned + Send + Sync,
362{
363    fn default() -> Self {
364        Self::new()
365    }
366}
367
368/// Utility for inferring relationship types from field types
369pub struct TypeInferenceHelper;
370
371impl TypeInferenceHelper {
372    /// Infer relationship type from Rust type information
373    pub fn infer_from_type_name(type_name: &str) -> Option<RelationshipType> {
374        if type_name.contains("Option<") {
375            // Single optional relationship
376            if type_name.contains("Vec<") {
377                None // Shouldn't have Option<Vec<T>>
378            } else {
379                Some(RelationshipType::HasOne) // Default to HasOne for Option<T>
380            }
381        } else if type_name.contains("Vec<") {
382            Some(RelationshipType::HasMany) // Collection relationship
383        } else if type_name.contains("MorphOne<") {
384            Some(RelationshipType::MorphOne)
385        } else if type_name.contains("MorphMany<") {
386            Some(RelationshipType::MorphMany)
387        } else if type_name.contains("MorphTo<") {
388            Some(RelationshipType::MorphTo)
389        } else {
390            None // Can't infer from basic types
391        }
392    }
393
394    /// Check if a field name suggests a specific relationship type
395    pub fn infer_from_field_name(field_name: &str) -> Option<RelationshipType> {
396        if field_name.ends_with("_id") {
397            Some(RelationshipType::BelongsTo)
398        } else if field_name.ends_with("_ids") {
399            Some(RelationshipType::ManyToMany)
400        } else if field_name.ends_with("able") || field_name.contains("morph") {
401            Some(RelationshipType::MorphTo) // Default for polymorphic
402        } else {
403            None
404        }
405    }
406
407    /// Suggest relationship type based on multiple hints
408    pub fn suggest_relationship_type(
409        field_name: &str,
410        type_name: &str,
411        is_collection: bool,
412        is_optional: bool,
413    ) -> RelationshipType {
414        // Try field name inference first
415        if let Some(rt) = Self::infer_from_field_name(field_name) {
416            return rt;
417        }
418
419        // Try type name inference
420        if let Some(rt) = Self::infer_from_type_name(type_name) {
421            return rt;
422        }
423
424        // Fall back to collection/optional hints
425        match (is_collection, is_optional) {
426            (true, _) => RelationshipType::HasMany,
427            (false, true) => RelationshipType::HasOne,
428            (false, false) => RelationshipType::BelongsTo, // Required single relationship
429        }
430    }
431}
432
/// Build a `Vec<RelationshipHint>` from tuples.
///
/// Each entry is either `(field, relationship_type, related_model, eager_load)` or
/// `(field, relationship_type, related_model, eager_load, custom_foreign_key)`, and both
/// forms may be mixed in one invocation. `field`, `related_model` and the optional
/// foreign key are converted with `.to_string()`.
///
/// ```ignore
/// relationship_hints![
///     ("posts", RelationshipType::HasMany, "Post", false),
///     ("author", RelationshipType::BelongsTo, "User", true, "written_by"),
/// ]
/// ```
#[macro_export]
macro_rules! relationship_hints {
    // Internal: turn an optional foreign-key expression into an `Option<String>`.
    (@fk) => {
        None
    };
    (@fk $fk:expr) => {
        Some($fk.to_string())
    };

    ($(($field:expr, $type:expr, $related:expr, $eager:expr $(, $fk:expr)?)),* $(,)?) => {
        vec![
            $(
                $crate::relationships::inference::RelationshipHint {
                    field_name: $field.to_string(),
                    relationship_type: $type,
                    related_model: $related.to_string(),
                    custom_foreign_key: $crate::relationship_hints!(@fk $($fk)?),
                    eager_load: $eager,
                }
            ),*
        ]
    };
}
464
#[cfg(test)]
mod tests {
    use super::super::metadata::RelationshipType;
    use super::*;

    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestUser {
        id: Option<i64>,
        name: String,
        email: String,
    }

    impl Model for TestUser {
        type PrimaryKey = i64;

        fn table_name() -> &'static str {
            "users"
        }

        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }

        fn set_primary_key(&mut self, key: Self::PrimaryKey) {
            self.id = Some(key);
        }

        fn to_fields(&self) -> std::collections::HashMap<String, serde_json::Value> {
            let mut fields = std::collections::HashMap::new();
            fields.insert("id".to_string(), serde_json::json!(self.id));
            fields.insert(
                "name".to_string(),
                serde_json::Value::String(self.name.clone()),
            );
            fields.insert(
                "email".to_string(),
                serde_json::Value::String(self.email.clone()),
            );
            fields
        }

        fn from_row(row: &sqlx::postgres::PgRow) -> crate::error::ModelResult<Self> {
            use sqlx::Row;
            Ok(Self {
                id: row.try_get("id").ok(),
                name: row.try_get("name").unwrap_or_default(),
                email: row.try_get("email").unwrap_or_default(),
            })
        }
    }

    impl InferableModel for TestUser {
        fn relationship_hints() -> Vec<RelationshipHint> {
            relationship_hints![
                ("posts", RelationshipType::HasMany, "Post", false),
                ("profile", RelationshipType::HasOne, "Profile", true),
                ("roles", RelationshipType::ManyToMany, "Role", false)
            ]
        }
    }

    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestPost {
        id: Option<i64>,
        title: String,
        user_id: Option<i64>,
    }

    impl Model for TestPost {
        type PrimaryKey = i64;

        fn table_name() -> &'static str {
            "posts"
        }

        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }

        fn set_primary_key(&mut self, key: Self::PrimaryKey) {
            self.id = Some(key);
        }

        fn to_fields(&self) -> std::collections::HashMap<String, serde_json::Value> {
            let mut fields = std::collections::HashMap::new();
            fields.insert("id".to_string(), serde_json::json!(self.id));
            fields.insert(
                "title".to_string(),
                serde_json::Value::String(self.title.clone()),
            );
            fields.insert("user_id".to_string(), serde_json::json!(self.user_id));
            fields
        }

        fn from_row(row: &sqlx::postgres::PgRow) -> crate::error::ModelResult<Self> {
            use sqlx::Row;
            Ok(Self {
                id: row.try_get("id").ok(),
                title: row.try_get("title").unwrap_or_default(),
                user_id: row.try_get("user_id").ok(),
            })
        }
    }

    impl InferableModel for TestPost {
        fn relationship_hints() -> Vec<RelationshipHint> {
            relationship_hints![("user", RelationshipType::BelongsTo, "User", true)]
        }
    }

    #[test]
    fn test_inference_engine_creation() {
        let mut engine = RelationshipInferenceEngine::<TestUser>::new();

        // Test that we can create and use the engine
        let relationships = engine.infer_all_relationships().unwrap();

        assert_eq!(relationships.len(), 3);
        assert_eq!(relationships[0].name, "posts");
        assert_eq!(
            relationships[0].relationship_type,
            RelationshipType::HasMany
        );
        assert_eq!(relationships[1].name, "profile");
        assert_eq!(relationships[1].relationship_type, RelationshipType::HasOne);
        assert!(relationships[1].eager_load);
    }

    #[test]
    fn test_foreign_key_inference() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();

        let fk = engine.infer_foreign_key_name("user").unwrap();
        assert_eq!(fk, "user_id");

        let fk = engine.infer_foreign_key_name("posts").unwrap();
        assert_eq!(fk, "post_id");
    }

    #[test]
    fn test_pluralization() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();

        assert_eq!(engine.pluralize_name("user"), "users");
        assert_eq!(engine.pluralize_name("post"), "posts");
        assert_eq!(engine.pluralize_name("category"), "categories");
        assert_eq!(engine.pluralize_name("box"), "boxes");
    }

    #[test]
    fn test_singularization() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();

        assert_eq!(engine.singularize_table_name("users"), "user");
        assert_eq!(engine.singularize_table_name("posts"), "post");
        assert_eq!(engine.singularize_table_name("categories"), "category");
        assert_eq!(engine.singularize_table_name("boxes"), "box");
    }

    #[test]
    fn test_case_conversion() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();

        assert_eq!(engine.to_camel_case("user_id"), "userId");
        assert_eq!(engine.to_camel_case("user"), "user");
        assert_eq!(engine.to_pascal_case("user_id"), "UserId");
        assert_eq!(engine.to_pascal_case("user"), "User");
    }

    #[test]
    fn test_type_inference_helper() {
        assert_eq!(
            TypeInferenceHelper::infer_from_type_name("Option<Post>"),
            Some(RelationshipType::HasOne)
        );
        assert_eq!(
            TypeInferenceHelper::infer_from_type_name("Vec<Post>"),
            Some(RelationshipType::HasMany)
        );
        assert_eq!(
            TypeInferenceHelper::infer_from_field_name("user_id"),
            Some(RelationshipType::BelongsTo)
        );
        assert_eq!(
            TypeInferenceHelper::infer_from_field_name("role_ids"),
            Some(RelationshipType::ManyToMany)
        );
    }

    #[test]
    fn test_relationship_type_suggestion() {
        let rt = TypeInferenceHelper::suggest_relationship_type("posts", "Vec<Post>", true, false);
        assert_eq!(rt, RelationshipType::HasMany);

        let rt = TypeInferenceHelper::suggest_relationship_type("user_id", "i64", false, false);
        assert_eq!(rt, RelationshipType::BelongsTo);

        let rt = TypeInferenceHelper::suggest_relationship_type(
            "profile",
            "Option<Profile>",
            false,
            true,
        );
        assert_eq!(rt, RelationshipType::HasOne);
    }

    #[test]
    fn test_naming_helpers() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();

        assert_eq!(engine.pluralize_name("day"), "days");
        assert_eq!(engine.pluralize_name("church"), "churches");
        assert_eq!(engine.singularize_table_name("churches"), "church");
        assert_eq!(engine.infer_table_name("Post"), "posts");
        assert_eq!(
            engine.infer_foreign_key_name("categories").unwrap(),
            "category_id"
        );
        assert_eq!(engine.to_camel_case("blog_post_id"), "blogPostId");
    }

    #[test]
    fn test_pivot_and_polymorphic_naming() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();

        // Pivot tables join both table names in alphabetical order, regardless of argument order.
        assert_eq!(engine.infer_pivot_table_name("users", "roles"), "roles_users");
        assert_eq!(engine.infer_pivot_table_name("roles", "users"), "roles_users");

        assert_eq!(
            engine.infer_polymorphic_columns("commentable"),
            ("commentable_type".to_string(), "commentable_id".to_string())
        );
        assert_eq!(
            engine.infer_polymorphic_columns("image"),
            ("image_able_type".to_string(), "image_able_id".to_string())
        );
    }

    #[test]
    fn test_infer_relationship_repeatable() {
        let mut engine = RelationshipInferenceEngine::<TestUser>::new();

        let first = engine
            .infer_relationship::<TestPost>("posts", RelationshipType::HasMany)
            .unwrap();
        let second = engine
            .infer_relationship::<TestPost>("posts", RelationshipType::HasMany)
            .unwrap();

        assert_eq!(first.name, "posts");
        assert_eq!(first.relationship_type, RelationshipType::HasMany);
        assert_eq!(second.name, first.name);
        assert_eq!(second.relationship_type, first.relationship_type);
    }
}