use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
/// Cardinality label attached to a [`Relationship`].
///
/// Serialized in snake_case: `"one_to_one"`, `"one_to_many"`, `"many_to_many"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RelationshipType {
OneToOne,
/// Assigned by `EntityInference::infer_entities` to every relationship it
/// derives from a foreign-key field.
OneToMany,
ManyToMany,
}
/// Everything inferred about one entity from a batch of sample payloads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityDefinition {
/// Entity name, e.g. `User` for payloads whose `id` looks like `"user_1"`.
pub name: String,
/// Inferred schema. For object payloads, `EntityInference::infer_entities`
/// builds an object mapping each observed top-level field name to a
/// JSON-Schema-like fragment such as `{"type": "string", "format": "email"}`.
pub schema: Value,
/// Name of the field chosen as primary key, if any candidate was found.
pub primary_key: Option<String>,
/// Foreign-key field name -> name of the entity it is taken to reference.
pub foreign_keys: HashMap<String, String>,
/// Relationships derived from `foreign_keys`, one per foreign-key field.
pub relationships: Vec<Relationship>,
/// The raw payloads assigned to this entity, in input order.
pub examples: Vec<Value>,
}
/// A link from the entity that owns this value to another entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relationship {
/// Name of the referenced entity.
pub target: String,
/// Cardinality label for the link.
pub relationship_type: RelationshipType,
/// Field on the owning entity that holds the reference, when known.
pub field: Option<String>,
}
/// Stateless namespace for the heuristics that turn sample JSON payloads into
/// [`EntityDefinition`]s; the entry point is `EntityInference::infer_entities`.
pub struct EntityInference;
impl EntityInference {
pub fn infer_entities(payloads: Vec<Value>) -> Vec<EntityDefinition> {
let mut entities: HashMap<String, EntityDefinition> = HashMap::new();
for payload in payloads {
if let Some(entity_name) = Self::infer_entity_name(&payload) {
let entity =
entities.entry(entity_name.clone()).or_insert_with(|| EntityDefinition {
name: entity_name.clone(),
schema: json!({}),
primary_key: None,
foreign_keys: HashMap::new(),
relationships: Vec::new(),
examples: Vec::new(),
});
let merged = Self::merge_schema(&entity.schema, &payload);
entity.schema = merged;
entity.examples.push(payload);
}
}
let entity_names: Vec<String> = entities.keys().cloned().collect();
for entity_name in &entity_names {
if let Some(entity) = entities.get_mut(entity_name) {
entity.primary_key = Self::infer_primary_key(&entity.schema);
}
}
let entity_schemas: Vec<(String, Value)> = entities
.iter()
.map(|(name, entity)| (name.clone(), entity.schema.clone()))
.collect();
let entity_names_set: std::collections::HashSet<String> =
entities.keys().cloned().collect();
for (entity_name, schema) in entity_schemas {
if let Some(entity) = entities.get_mut(&entity_name) {
entity.foreign_keys = Self::infer_foreign_keys(&schema, &entity_names_set);
}
}
let foreign_keys_map: HashMap<String, HashMap<String, String>> = entities
.iter()
.map(|(name, entity)| (name.clone(), entity.foreign_keys.clone()))
.collect();
for (entity_name, foreign_keys) in foreign_keys_map {
if let Some(entity) = entities.get_mut(&entity_name) {
entity.relationships = foreign_keys
.iter()
.map(|(field_name, target_entity)| Relationship {
target: target_entity.clone(),
relationship_type: RelationshipType::OneToMany,
field: Some(field_name.clone()),
})
.collect();
}
}
entities.into_values().collect()
}
fn infer_entity_name(payload: &Value) -> Option<String> {
if let Some(obj) = payload.as_object() {
if let Some(id) = obj.get("id") {
if let Some(id_str) = id.as_str() {
if let Some(prefix) = id_str.split('_').next() {
return Some(Self::capitalize(prefix));
}
}
}
if let Some(typ) = obj.get("type") {
if let Some(typ_str) = typ.as_str() {
return Some(Self::capitalize(typ_str));
}
}
if let Some((first_key, _)) = obj.iter().next() {
if first_key.len() > 1 {
return Some(Self::capitalize(first_key));
}
}
}
Some("Entity".to_string())
}
fn merge_schema(existing: &Value, new: &Value) -> Value {
match (existing, new) {
(Value::Object(existing_obj), Value::Object(new_obj)) => {
let mut merged = existing_obj.clone();
for (key, new_val) in new_obj {
let key_str = key.as_str();
if let Some(existing_val) = merged.get(key_str) {
merged.insert(key.clone(), Self::merge_schema(existing_val, new_val));
} else {
merged.insert(key.clone(), Self::infer_field_schema(new_val));
}
}
Value::Object(merged)
}
(_, new) => Self::infer_field_schema(new),
}
}
fn infer_field_schema(value: &Value) -> Value {
match value {
Value::Null => json!({"type": "null"}),
Value::Bool(_) => json!({"type": "boolean"}),
Value::Number(n) => {
if n.is_i64() || n.is_u64() {
json!({"type": "integer"})
} else {
json!({"type": "number"})
}
}
Value::String(s) => {
let mut schema = json!({"type": "string"});
if s.contains('@') && s.contains('.') {
schema["format"] = json!("email");
} else if s.len() == 36 && s.contains('-') {
schema["format"] = json!("uuid");
} else if s.starts_with("http://") || s.starts_with("https://") {
schema["format"] = json!("uri");
}
schema
}
Value::Array(arr) => {
if let Some(first) = arr.first() {
json!({
"type": "array",
"items": Self::infer_field_schema(first)
})
} else {
json!({"type": "array"})
}
}
Value::Object(obj) => {
let mut properties = serde_json::Map::new();
for (key, val) in obj {
properties.insert(key.clone(), Self::infer_field_schema(val));
}
json!({
"type": "object",
"properties": properties
})
}
}
}
fn infer_primary_key(schema: &Value) -> Option<String> {
let properties = if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
props
} else if let Some(obj) = schema.as_object() {
obj
} else {
return None;
};
let primary_key_candidates = ["id", "uuid", "_id", "key", "identifier"];
for candidate in &primary_key_candidates {
if properties.contains_key(*candidate) {
return Some(candidate.to_string());
}
}
for key in properties.keys() {
if key.to_lowercase().ends_with("_id") || key.to_lowercase().ends_with("id") {
return Some(key.clone());
}
}
None
}
fn infer_foreign_keys(
schema: &Value,
entity_names: &std::collections::HashSet<String>,
) -> HashMap<String, String> {
let mut foreign_keys = HashMap::new();
if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
for (field_name, _field_schema) in properties {
if field_name.ends_with("_id")
|| field_name.ends_with("Id")
|| field_name.ends_with("_uuid")
{
let base_name = field_name
.trim_end_matches("_id")
.trim_end_matches("Id")
.trim_end_matches("_uuid");
let target_entity = Self::capitalize(base_name);
if entity_names.contains(&target_entity) {
foreign_keys.insert(field_name.clone(), target_entity);
}
}
}
}
foreign_keys
}
fn capitalize(s: &str) -> String {
let mut chars = s.chars();
match chars.next() {
None => String::new(),
Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two payloads whose ids share the `user_` prefix collapse into a single
    /// `User` entity, and that entity gets a primary key.
    #[test]
    fn test_infer_entities() {
        let alice = json!({"id": "user_1", "name": "Alice", "email": "alice@example.com"});
        let bob = json!({"id": "user_2", "name": "Bob", "email": "bob@example.com"});

        let inferred = EntityInference::infer_entities(vec![alice, bob]);

        assert_eq!(inferred.len(), 1);
        let user = &inferred[0];
        assert_eq!(user.name, "User");
        assert!(user.primary_key.is_some());
    }
}