1use serde::de::DeserializeOwned;
4use std::collections::HashMap;
5use std::marker::PhantomData;
6
7use super::metadata::{
8 ForeignKeyConfig, PivotConfig, PolymorphicConfig, RelationshipMetadata, RelationshipType,
9};
10use crate::error::ModelResult;
11use crate::model::Model;
12
13pub trait InferableModel: Model {
15 fn relationship_hints() -> Vec<RelationshipHint> {
17 Vec::new()
18 }
19
20 fn foreign_key_convention() -> ForeignKeyConvention {
22 ForeignKeyConvention::Underscore
23 }
24
25 fn table_naming_convention() -> TableNamingConvention {
27 TableNamingConvention::Plural
28 }
29}
30
31#[derive(Debug, Clone)]
33pub struct RelationshipHint {
34 pub field_name: String,
36
37 pub relationship_type: RelationshipType,
39
40 pub related_model: String,
42
43 pub custom_foreign_key: Option<String>,
45
46 pub eager_load: bool,
48}
49
/// Naming convention applied when deriving a foreign key column name from a
/// singularized table or model name.
///
/// `Eq` and `Hash` are derived so conventions can be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ForeignKeyConvention {
    /// `<name>_id`, e.g. `user_id`.
    Underscore,
    /// camelCase name followed by `Id`, e.g. `userId`.
    CamelCase,
    /// PascalCase name followed by `ID`, e.g. `UserID`.
    PascalCase,
    /// Custom pattern; every `{model}` placeholder is replaced by the
    /// singularized name.
    Custom(&'static str),
}
62
/// Naming convention applied when deriving a table name from a lowercased
/// model name.
///
/// `Eq` and `Hash` are derived so conventions can be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TableNamingConvention {
    /// Pluralized model name, e.g. `user` -> `users`.
    Plural,
    /// Model name used as-is (lowercased).
    Singular,
    /// Custom pattern; every `{model}` placeholder is replaced by the
    /// lowercased model name.
    Custom(&'static str),
}
73
74pub struct RelationshipInferenceEngine<Parent>
76where
77 Parent: InferableModel + DeserializeOwned + Send + Sync,
78{
79 parent_model: PhantomData<Parent>,
80
81 inference_cache: HashMap<String, RelationshipMetadata>,
83}
84
85impl<Parent> RelationshipInferenceEngine<Parent>
86where
87 Parent: InferableModel + DeserializeOwned + Send + Sync,
88{
89 pub fn new() -> Self {
91 Self {
92 parent_model: PhantomData,
93 inference_cache: HashMap::new(),
94 }
95 }
96
97 pub fn infer_relationship<Related>(
99 &mut self,
100 field_name: &str,
101 relationship_type: RelationshipType,
102 ) -> ModelResult<RelationshipMetadata>
103 where
104 Related: InferableModel + DeserializeOwned + Send + Sync,
105 {
106 let cache_key = format!("{}::{}", field_name, std::any::type_name::<Related>());
107
108 if let Some(cached) = self.inference_cache.get(&cache_key) {
110 return Ok(cached.clone());
111 }
112
113 let metadata =
114 self.infer_relationship_metadata::<Related>(field_name, relationship_type)?;
115
116 self.inference_cache.insert(cache_key, metadata.clone());
118
119 Ok(metadata)
120 }
121
122 pub fn infer_all_relationships(&mut self) -> ModelResult<Vec<RelationshipMetadata>> {
124 let hints = Parent::relationship_hints();
125 let mut relationships = Vec::new();
126
127 for hint in hints {
128 let metadata = self.infer_from_hint(&hint)?;
129 relationships.push(metadata);
130 }
131
132 Ok(relationships)
133 }
134
135 fn infer_relationship_metadata<Related>(
137 &self,
138 field_name: &str,
139 relationship_type: RelationshipType,
140 ) -> ModelResult<RelationshipMetadata>
141 where
142 Related: InferableModel + DeserializeOwned + Send + Sync,
143 {
144 let parent_table = Parent::table_name();
145 let related_table = Related::table_name();
146 let related_model_name = std::any::type_name::<Related>()
147 .split("::")
148 .last()
149 .unwrap_or(std::any::type_name::<Related>());
150
151 let foreign_key_config = match relationship_type {
152 RelationshipType::HasOne | RelationshipType::HasMany => {
153 let foreign_key = self.infer_foreign_key_name(Parent::table_name())?;
155 ForeignKeyConfig::simple(foreign_key, related_table.to_string())
156 }
157 RelationshipType::BelongsTo => {
158 let foreign_key = self.infer_foreign_key_name(Related::table_name())?;
160 ForeignKeyConfig::simple(foreign_key, parent_table.to_string())
161 }
162 RelationshipType::ManyToMany => {
163 let pivot_table = self.infer_pivot_table_name(parent_table, related_table);
165 let local_key = self.infer_foreign_key_name(parent_table)?;
166 let foreign_key = self.infer_foreign_key_name(related_table)?;
167
168 return Ok(RelationshipMetadata::new(
169 relationship_type,
170 field_name.to_string(),
171 related_table.to_string(),
172 related_model_name.to_string(),
173 ForeignKeyConfig::simple(local_key.clone(), pivot_table.clone()),
174 )
175 .with_pivot(PivotConfig::new(pivot_table, local_key, foreign_key)));
176 }
177 RelationshipType::MorphOne
178 | RelationshipType::MorphMany
179 | RelationshipType::MorphTo => {
180 let (type_column, id_column) = self.infer_polymorphic_columns(field_name);
182
183 return Ok(RelationshipMetadata::new(
184 relationship_type,
185 field_name.to_string(),
186 related_table.to_string(),
187 related_model_name.to_string(),
188 ForeignKeyConfig::simple(id_column.clone(), related_table.to_string()),
189 )
190 .with_polymorphic(PolymorphicConfig::new(
191 field_name.to_string(),
192 type_column,
193 id_column,
194 )));
195 }
196 };
197
198 Ok(RelationshipMetadata::new(
199 relationship_type,
200 field_name.to_string(),
201 related_table.to_string(),
202 related_model_name.to_string(),
203 foreign_key_config,
204 ))
205 }
206
207 fn infer_from_hint(&self, hint: &RelationshipHint) -> ModelResult<RelationshipMetadata> {
209 let foreign_key = hint.custom_foreign_key.clone().unwrap_or_else(|| {
210 match self.infer_foreign_key_name(&hint.related_model.to_lowercase()) {
211 Ok(fk) => fk,
212 Err(_) => format!("{}_id", hint.related_model.to_lowercase()),
213 }
214 });
215
216 let related_table = self.infer_table_name(&hint.related_model);
217
218 let mut metadata = RelationshipMetadata::new(
219 hint.relationship_type,
220 hint.field_name.clone(),
221 related_table,
222 hint.related_model.clone(),
223 ForeignKeyConfig::simple(foreign_key, hint.related_model.to_lowercase()),
224 );
225
226 metadata.eager_load = hint.eager_load;
227
228 Ok(metadata)
229 }
230
231 pub fn infer_foreign_key_name(&self, table_or_model: &str) -> ModelResult<String> {
233 let convention = Parent::foreign_key_convention();
234
235 match convention {
236 ForeignKeyConvention::Underscore => {
237 let singular = self.singularize_table_name(table_or_model);
238 Ok(format!("{}_id", singular))
239 }
240 ForeignKeyConvention::CamelCase => {
241 let singular = self.singularize_table_name(table_or_model);
242 Ok(format!("{}Id", self.to_camel_case(&singular)))
243 }
244 ForeignKeyConvention::PascalCase => {
245 let singular = self.singularize_table_name(table_or_model);
246 Ok(format!("{}ID", self.to_pascal_case(&singular)))
247 }
248 ForeignKeyConvention::Custom(pattern) => {
249 let singular = self.singularize_table_name(table_or_model);
250 Ok(pattern.replace("{model}", &singular))
251 }
252 }
253 }
254
255 fn infer_pivot_table_name(&self, table1: &str, table2: &str) -> String {
257 let mut tables = [table1, table2];
258 tables.sort();
259 tables.join("_")
260 }
261
262 fn infer_polymorphic_columns(&self, field_name: &str) -> (String, String) {
264 let base = if field_name.ends_with("able") {
266 field_name.to_string()
267 } else {
268 format!("{}_able", field_name)
269 };
270
271 (format!("{}_type", base), format!("{}_id", base))
272 }
273
274 pub fn infer_table_name(&self, model_name: &str) -> String {
276 let convention = Parent::table_naming_convention();
277 let base_name = model_name.to_lowercase();
278
279 match convention {
280 TableNamingConvention::Plural => self.pluralize_name(&base_name),
281 TableNamingConvention::Singular => base_name,
282 TableNamingConvention::Custom(pattern) => pattern.replace("{model}", &base_name),
283 }
284 }
285
286 pub fn pluralize_name(&self, name: &str) -> String {
288 if name.ends_with('y')
289 && !name.ends_with("ay")
290 && !name.ends_with("ey")
291 && !name.ends_with("iy")
292 && !name.ends_with("oy")
293 && !name.ends_with("uy")
294 {
295 format!("{}ies", &name[..name.len() - 1])
296 } else if name.ends_with('s')
297 || name.ends_with("sh")
298 || name.ends_with("ch")
299 || name.ends_with('x')
300 || name.ends_with('z')
301 {
302 format!("{}es", name)
303 } else {
304 format!("{}s", name)
305 }
306 }
307
308 pub fn singularize_table_name(&self, name: &str) -> String {
310 if name.ends_with("ies") {
311 format!("{}y", &name[..name.len() - 3])
312 } else if name.ends_with("ses")
313 || name.ends_with("ches")
314 || name.ends_with("shes")
315 || name.ends_with("xes")
316 || name.ends_with("zes")
317 {
318 name[..name.len() - 2].to_string()
319 } else if name.ends_with('s') && name.len() > 1 {
320 name[..name.len() - 1].to_string()
321 } else {
322 name.to_string()
323 }
324 }
325
326 pub fn to_camel_case(&self, s: &str) -> String {
328 let parts: Vec<&str> = s.split('_').collect();
329 if parts.is_empty() {
330 return s.to_string();
331 }
332
333 let mut result = parts[0].to_lowercase();
334 for part in &parts[1..] {
335 if !part.is_empty() {
336 let mut chars = part.chars();
337 if let Some(first) = chars.next() {
338 result.push(first.to_uppercase().next().unwrap());
339 result.extend(chars.flat_map(|c| c.to_lowercase()));
340 }
341 }
342 }
343
344 result
345 }
346
347 pub fn to_pascal_case(&self, s: &str) -> String {
349 let camel = self.to_camel_case(s);
350 let mut chars = camel.chars();
351 if let Some(first) = chars.next() {
352 first.to_uppercase().collect::<String>() + &chars.collect::<String>()
353 } else {
354 camel
355 }
356 }
357}
358
359impl<Parent> Default for RelationshipInferenceEngine<Parent>
360where
361 Parent: InferableModel + DeserializeOwned + Send + Sync,
362{
363 fn default() -> Self {
364 Self::new()
365 }
366}
367
368pub struct TypeInferenceHelper;
370
371impl TypeInferenceHelper {
372 pub fn infer_from_type_name(type_name: &str) -> Option<RelationshipType> {
374 if type_name.contains("Option<") {
375 if type_name.contains("Vec<") {
377 None } else {
379 Some(RelationshipType::HasOne) }
381 } else if type_name.contains("Vec<") {
382 Some(RelationshipType::HasMany) } else if type_name.contains("MorphOne<") {
384 Some(RelationshipType::MorphOne)
385 } else if type_name.contains("MorphMany<") {
386 Some(RelationshipType::MorphMany)
387 } else if type_name.contains("MorphTo<") {
388 Some(RelationshipType::MorphTo)
389 } else {
390 None }
392 }
393
394 pub fn infer_from_field_name(field_name: &str) -> Option<RelationshipType> {
396 if field_name.ends_with("_id") {
397 Some(RelationshipType::BelongsTo)
398 } else if field_name.ends_with("_ids") {
399 Some(RelationshipType::ManyToMany)
400 } else if field_name.ends_with("able") || field_name.contains("morph") {
401 Some(RelationshipType::MorphTo) } else {
403 None
404 }
405 }
406
407 pub fn suggest_relationship_type(
409 field_name: &str,
410 type_name: &str,
411 is_collection: bool,
412 is_optional: bool,
413 ) -> RelationshipType {
414 if let Some(rt) = Self::infer_from_field_name(field_name) {
416 return rt;
417 }
418
419 if let Some(rt) = Self::infer_from_type_name(type_name) {
421 return rt;
422 }
423
424 match (is_collection, is_optional) {
426 (true, _) => RelationshipType::HasMany,
427 (false, true) => RelationshipType::HasOne,
428 (false, false) => RelationshipType::BelongsTo, }
430 }
431}
432
/// Builds a `Vec<RelationshipHint>` from a comma-separated list of tuples.
///
/// Each tuple is `(field, relationship_type, related_model, eager_load)` with
/// an optional fifth element naming a custom foreign key. Tuples with and
/// without the custom key may be mixed freely in one invocation, and a
/// trailing comma is accepted.
#[macro_export]
macro_rules! relationship_hints {
    // Internal rule: no custom foreign key was supplied for this tuple.
    (@fk) => {
        None
    };

    // Internal rule: a custom foreign key was supplied for this tuple.
    (@fk $fk:expr) => {
        Some($fk.to_string())
    };

    ($(($field:expr, $type:expr, $related:expr, $eager:expr $(, $fk:expr)?)),* $(,)?) => {
        vec![
            $(
                $crate::relationships::inference::RelationshipHint {
                    field_name: $field.to_string(),
                    relationship_type: $type,
                    related_model: $related.to_string(),
                    custom_foreign_key: $crate::relationship_hints!(@fk $($fk)?),
                    eager_load: $eager,
                }
            ),*
        ]
    };
}
464
#[cfg(test)]
mod tests {
    use super::super::metadata::RelationshipType;
    use super::*;

    /// Parent-side fixture model backed by the `users` table.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestUser {
        id: Option<i64>,
        name: String,
        email: String,
    }

    impl Model for TestUser {
        type PrimaryKey = i64;

        fn table_name() -> &'static str {
            "users"
        }

        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }

        fn set_primary_key(&mut self, key: Self::PrimaryKey) {
            self.id = Some(key);
        }

        fn to_fields(&self) -> std::collections::HashMap<String, serde_json::Value> {
            vec![
                ("id".to_string(), serde_json::json!(self.id)),
                (
                    "name".to_string(),
                    serde_json::Value::String(self.name.clone()),
                ),
                (
                    "email".to_string(),
                    serde_json::Value::String(self.email.clone()),
                ),
            ]
            .into_iter()
            .collect()
        }

        fn from_row(row: &sqlx::postgres::PgRow) -> crate::error::ModelResult<Self> {
            use sqlx::Row;

            // Missing or unreadable columns fall back to `None` / defaults.
            let id: Option<i64> = row.try_get("id").ok();
            let name: String = row.try_get("name").unwrap_or_default();
            let email: String = row.try_get("email").unwrap_or_default();

            Ok(Self { id, name, email })
        }
    }

    impl InferableModel for TestUser {
        fn relationship_hints() -> Vec<RelationshipHint> {
            relationship_hints![
                ("posts", RelationshipType::HasMany, "Post", false),
                ("profile", RelationshipType::HasOne, "Profile", true),
                ("roles", RelationshipType::ManyToMany, "Role", false)
            ]
        }
    }

    /// Child-side fixture model backed by the `posts` table.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestPost {
        id: Option<i64>,
        title: String,
        user_id: Option<i64>,
    }

    impl Model for TestPost {
        type PrimaryKey = i64;

        fn table_name() -> &'static str {
            "posts"
        }

        fn primary_key(&self) -> Option<Self::PrimaryKey> {
            self.id
        }

        fn set_primary_key(&mut self, key: Self::PrimaryKey) {
            self.id = Some(key);
        }

        fn to_fields(&self) -> std::collections::HashMap<String, serde_json::Value> {
            vec![
                ("id".to_string(), serde_json::json!(self.id)),
                (
                    "title".to_string(),
                    serde_json::Value::String(self.title.clone()),
                ),
                ("user_id".to_string(), serde_json::json!(self.user_id)),
            ]
            .into_iter()
            .collect()
        }

        fn from_row(row: &sqlx::postgres::PgRow) -> crate::error::ModelResult<Self> {
            use sqlx::Row;

            // Missing or unreadable columns fall back to `None` / defaults.
            let id: Option<i64> = row.try_get("id").ok();
            let title: String = row.try_get("title").unwrap_or_default();
            let user_id: Option<i64> = row.try_get("user_id").ok();

            Ok(Self { id, title, user_id })
        }
    }

    impl InferableModel for TestPost {
        fn relationship_hints() -> Vec<RelationshipHint> {
            relationship_hints![("user", RelationshipType::BelongsTo, "User", true)]
        }
    }

    /// Hints declared on `TestUser` come back in order with name, type and
    /// eager-load flag intact.
    #[test]
    fn test_inference_engine_creation() {
        let mut engine = RelationshipInferenceEngine::<TestUser>::new();
        let relationships = engine.infer_all_relationships().unwrap();

        assert_eq!(relationships.len(), 3);

        let posts = &relationships[0];
        assert_eq!(posts.name, "posts");
        assert_eq!(posts.relationship_type, RelationshipType::HasMany);

        let profile = &relationships[1];
        assert_eq!(profile.name, "profile");
        assert_eq!(profile.relationship_type, RelationshipType::HasOne);
        assert!(profile.eager_load);
    }

    /// Singular and plural inputs both produce a singular `_id` column.
    #[test]
    fn test_foreign_key_inference() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();
        let cases = [("user", "user_id"), ("posts", "post_id")];

        for (input, expected) in cases.iter() {
            let fk = engine.infer_foreign_key_name(input).unwrap();
            assert_eq!(fk, *expected);
        }
    }

    #[test]
    fn test_pluralization() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();
        let cases = [
            ("user", "users"),
            ("post", "posts"),
            ("category", "categories"),
            ("box", "boxes"),
        ];

        for (singular, plural) in cases.iter() {
            assert_eq!(engine.pluralize_name(singular), *plural);
        }
    }

    #[test]
    fn test_singularization() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();
        let cases = [
            ("users", "user"),
            ("posts", "post"),
            ("categories", "category"),
            ("boxes", "box"),
        ];

        for (plural, singular) in cases.iter() {
            assert_eq!(engine.singularize_table_name(plural), *singular);
        }
    }

    #[test]
    fn test_case_conversion() {
        let engine = RelationshipInferenceEngine::<TestUser>::new();

        for (input, camel, pascal) in [("user_id", "userId", "UserId"), ("user", "user", "User")].iter()
        {
            assert_eq!(engine.to_camel_case(input), *camel);
            assert_eq!(engine.to_pascal_case(input), *pascal);
        }
    }

    #[test]
    fn test_type_inference_helper() {
        let from_type = [
            ("Option<Post>", RelationshipType::HasOne),
            ("Vec<Post>", RelationshipType::HasMany),
        ];
        for (type_name, expected) in from_type.iter() {
            assert_eq!(
                TypeInferenceHelper::infer_from_type_name(type_name),
                Some(*expected)
            );
        }

        let from_field = [
            ("user_id", RelationshipType::BelongsTo),
            ("role_ids", RelationshipType::ManyToMany),
        ];
        for (field_name, expected) in from_field.iter() {
            assert_eq!(
                TypeInferenceHelper::infer_from_field_name(field_name),
                Some(*expected)
            );
        }
    }

    #[test]
    fn test_relationship_type_suggestion() {
        // (field, type, is_collection, is_optional, expected)
        let cases = [
            ("posts", "Vec<Post>", true, false, RelationshipType::HasMany),
            ("user_id", "i64", false, false, RelationshipType::BelongsTo),
            (
                "profile",
                "Option<Profile>",
                false,
                true,
                RelationshipType::HasOne,
            ),
        ];

        for (field, type_name, is_collection, is_optional, expected) in cases.iter() {
            let rt = TypeInferenceHelper::suggest_relationship_type(
                field,
                type_name,
                *is_collection,
                *is_optional,
            );
            assert_eq!(rt, *expected);
        }
    }
}