use std::collections::{HashMap, HashSet};
use super::{GraphQLScalar, RelationType, to_pascal_case, to_camel_case};
/// In-memory model of a GraphQL schema that can be rendered as SDL text.
///
/// Every collection is rendered in insertion order by [`GraphQLSchema::to_sdl`].
#[derive(Debug, Clone)]
pub struct GraphQLSchema {
    /// Object types, rendered as `type Name { ... }`.
    pub types: Vec<GraphQLType>,
    /// Fields of the root `Query` type.
    pub queries: Vec<QueryDefinition>,
    /// Fields of the root `Mutation` type; the type is omitted when this is empty.
    pub mutations: Vec<MutationDefinition>,
    /// Relationship fields, appended to the object type named by each `from_type`.
    pub relationships: Vec<Relationship>,
    /// Input object types, rendered as `input Name { ... }`.
    pub input_types: Vec<GraphQLInputType>,
    /// Enum types, rendered as `enum Name { ... }`.
    pub enum_types: Vec<GraphQLEnumType>,
}

impl GraphQLSchema {
    /// Creates a schema with no types, operations, relationships, inputs or enums.
    pub fn new() -> Self {
        Self {
            types: Vec::new(),
            queries: Vec::new(),
            mutations: Vec::new(),
            relationships: Vec::new(),
            input_types: Vec::new(),
            enum_types: Vec::new(),
        }
    }

    /// Appends an object type.
    pub fn add_type(&mut self, type_def: GraphQLType) {
        self.types.push(type_def);
    }

    /// Appends a field to the root `Query` type.
    pub fn add_query(&mut self, query: QueryDefinition) {
        self.queries.push(query);
    }

    /// Appends a field to the root `Mutation` type.
    pub fn add_mutation(&mut self, mutation: MutationDefinition) {
        self.mutations.push(mutation);
    }

    /// Appends a relationship field.
    pub fn add_relationship(&mut self, relationship: Relationship) {
        self.relationships.push(relationship);
    }

    /// Returns the first object type named `name`, if any.
    pub fn get_type(&self, name: &str) -> Option<&GraphQLType> {
        self.types.iter().find(|t| t.name == name)
    }

    /// Returns every relationship whose `from_type` is `type_name`, in insertion order.
    pub fn get_relationships_for(&self, type_name: &str) -> Vec<&Relationship> {
        self.relationships
            .iter()
            .filter(|r| r.from_type == type_name)
            .collect()
    }

    /// Renders the schema as GraphQL SDL.
    ///
    /// Sections are emitted in this order:
    /// 1. the custom scalar declarations,
    /// 2. enums,
    /// 3. object types, each followed by its relationship fields,
    /// 4. input types,
    /// 5. `type Query`,
    /// 6. `type Mutation`, only when at least one mutation exists.
    ///
    /// Rendering details:
    /// - Deprecated object fields carry a `@deprecated` directive.
    /// - A description that cannot be written safely as a `"""` block string
    ///   falls back to an escaped `"..."` string.
    /// - A `!` is never appended to a type that already renders as non-null.
    ///
    /// Argument default values are not rendered.
    pub fn to_sdl(&self) -> String {
        let mut sdl = String::new();

        sdl.push_str("# Custom Scalars\n");
        for scalar in ["DateTime", "Date", "Time", "JSON", "Decimal", "BigInt"] {
            sdl.push_str("scalar ");
            sdl.push_str(scalar);
            sdl.push('\n');
        }
        sdl.push('\n');

        for enum_type in &self.enum_types {
            sdl.push_str(&format!("enum {} {{\n", enum_type.name));
            for value in &enum_type.values {
                sdl.push_str(&format!(" {}\n", value));
            }
            sdl.push_str("}\n\n");
        }

        for type_def in &self.types {
            if let Some(desc) = &type_def.description {
                sdl.push_str(&Self::description_literal(desc));
                sdl.push('\n');
            }
            sdl.push_str(&format!("type {} {{\n", type_def.name));
            for field in &type_def.fields {
                let directive = Self::deprecation_directive(
                    field.deprecated,
                    field.deprecation_reason.as_deref(),
                );
                Self::push_field_line(
                    &mut sdl,
                    &field.name,
                    &Self::type_ref(field.graphql_type.to_string(), field.nullable),
                    field.description.as_deref(),
                    &directive,
                );
            }
            for rel in self.get_relationships_for(&type_def.name) {
                let type_str = if rel.relation_type.is_list() {
                    format!("[{}!]!", rel.to_type)
                } else {
                    format!("{}!", rel.to_type)
                };
                sdl.push_str(&format!(" {}: {}\n", rel.field_name, type_str));
            }
            sdl.push_str("}\n\n");
        }

        for input_type in &self.input_types {
            sdl.push_str(&format!("input {} {{\n", input_type.name));
            for field in &input_type.fields {
                // `@deprecated` on input fields only exists in the October 2021
                // spec and later, so it is deliberately not emitted here.
                Self::push_field_line(
                    &mut sdl,
                    &field.name,
                    &Self::type_ref(field.graphql_type.to_string(), field.nullable),
                    field.description.as_deref(),
                    "",
                );
            }
            sdl.push_str("}\n\n");
        }

        sdl.push_str("type Query {\n");
        for query in &self.queries {
            let args = Self::argument_list(query.arguments.iter().map(|a| {
                Self::typed_name(&a.name, a.graphql_type.to_string(), a.nullable)
            }));
            let return_type = if query.returns_list {
                format!("[{}!]!", query.return_type)
            } else {
                query.return_type.clone()
            };
            sdl.push_str(&format!(" {}{}: {}\n", query.name, args, return_type));
        }
        sdl.push_str("}\n\n");

        if !self.mutations.is_empty() {
            sdl.push_str("type Mutation {\n");
            for mutation in &self.mutations {
                let args = Self::argument_list(mutation.arguments.iter().map(|a| {
                    Self::typed_name(&a.name, a.graphql_type.to_string(), a.nullable)
                }));
                sdl.push_str(&format!(
                    " {}{}: {}\n",
                    mutation.name, args, mutation.return_type
                ));
            }
            sdl.push_str("}\n");
        }

        sdl
    }

    /// Appends one field line (`name: Type` followed by `directive`) to `sdl`.
    /// When `description` is present, a description line is written first.
    fn push_field_line(
        sdl: &mut String,
        name: &str,
        rendered_type: &str,
        description: Option<&str>,
        directive: &str,
    ) {
        if let Some(desc) = description {
            sdl.push(' ');
            sdl.push_str(&Self::description_literal(desc));
            sdl.push('\n');
        }
        sdl.push_str(&format!(" {}: {}{}\n", name, rendered_type, directive));
    }

    /// Renders `name: Type`, with `!` added for non-nullable entries.
    fn typed_name(name: &str, rendered_type: String, nullable: bool) -> String {
        format!("{}: {}", name, Self::type_ref(rendered_type, nullable))
    }

    /// Wraps already-rendered arguments in parentheses.
    ///
    /// No arguments renders as an empty string, so argument-less fields get no `()`.
    fn argument_list(arguments: impl Iterator<Item = String>) -> String {
        let rendered: Vec<String> = arguments.collect();
        if rendered.is_empty() {
            String::new()
        } else {
            format!("({})", rendered.join(", "))
        }
    }

    /// Appends `!` to a rendered type reference unless it is nullable.
    ///
    /// A type that already ends in `!` (e.g. a `FieldType::NonNull`) is also
    /// left alone, which avoids the invalid `T!!`.
    fn type_ref(rendered: String, nullable: bool) -> String {
        if nullable || rendered.ends_with('!') {
            rendered
        } else {
            rendered + "!"
        }
    }

    /// Renders the `@deprecated` directive for a field.
    ///
    /// Returns ` @deprecated` or ` @deprecated(reason: "...")`, or an empty
    /// string for fields that are not deprecated.
    fn deprecation_directive(deprecated: bool, reason: Option<&str>) -> String {
        match (deprecated, reason) {
            (false, _) => String::new(),
            (true, None) => " @deprecated".to_string(),
            (true, Some(reason)) => {
                format!(" @deprecated(reason: {})", Self::quote_string(reason))
            }
        }
    }

    /// Renders `text` as a `"""` block string when that is lossless.
    /// Otherwise it renders `text` as an escaped single-line string.
    ///
    /// A block string cannot contain `"""`. A trailing `"` or `\` would merge
    /// with the closing delimiter: `""""` closes early, and `\"""` is an escape.
    fn description_literal(text: &str) -> String {
        let block_safe =
            !text.contains("\"\"\"") && !text.ends_with('"') && !text.ends_with('\\');
        if block_safe {
            format!("\"\"\"{}\"\"\"", text)
        } else {
            Self::quote_string(text)
        }
    }

    /// Quotes `text` as a GraphQL string value.
    ///
    /// `"`, `\`, and control characters are escaped.
    fn quote_string(text: &str) -> String {
        let mut out = String::with_capacity(text.len() + 2);
        out.push('"');
        for c in text.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if c.is_control() => out.push_str(&format!("\\u{:04X}", c as u32)),
                c => out.push(c),
            }
        }
        out.push('"');
        out
    }
}

impl Default for GraphQLSchema {
    /// Same as [`GraphQLSchema::new`]: an empty schema.
    fn default() -> Self {
        Self::new()
    }
}
/// A GraphQL object type, optionally generated from a database table.
#[derive(Debug, Clone)]
pub struct GraphQLType {
    /// Type name as it appears in SDL.
    pub name: String,
    /// Fields in declaration order.
    pub fields: Vec<GraphQLField>,
    /// Optional description for the type.
    pub description: Option<String>,
    /// Source table, when the type was generated from one.
    pub table_name: Option<String>,
}

impl GraphQLType {
    /// Creates a type called `name` with no fields, description or source table.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        GraphQLType {
            name,
            fields: vec![],
            description: None,
            table_name: None,
        }
    }

    /// Appends `field` after the existing fields.
    pub fn add_field(&mut self, field: GraphQLField) {
        self.fields.extend(std::iter::once(field));
    }

    /// Builder: sets the type's description.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Builder: records the table this type was generated from.
    pub fn from_table(self, table_name: impl Into<String>) -> Self {
        Self {
            table_name: Some(table_name.into()),
            ..self
        }
    }

    /// Looks up the first field whose GraphQL name equals `name`.
    pub fn get_field(&self, name: &str) -> Option<&GraphQLField> {
        for candidate in &self.fields {
            if candidate.name == name {
                return Some(candidate);
            }
        }
        None
    }
}
/// A field of an object or input type.
#[derive(Debug, Clone)]
pub struct GraphQLField {
    /// GraphQL field name.
    pub name: String,
    /// Type reference, before the `nullable` flag is applied.
    pub graphql_type: FieldType,
    /// `false` marks the field as non-null; defaults to `true`.
    pub nullable: bool,
    /// Optional field description.
    pub description: Option<String>,
    /// Database column the field was generated from, if any.
    pub column_name: Option<String>,
    /// Deprecation flag, set by [`GraphQLField::deprecated`].
    pub deprecated: bool,
    /// Reason recorded alongside the deprecation flag.
    pub deprecation_reason: Option<String>,
}

impl GraphQLField {
    /// Creates a nullable, non-deprecated field with no description or column.
    pub fn new(name: impl Into<String>, graphql_type: FieldType) -> Self {
        let name = name.into();
        GraphQLField {
            name,
            graphql_type,
            nullable: true,
            description: None,
            column_name: None,
            deprecated: false,
            deprecation_reason: None,
        }
    }

    /// Builder: sets whether the field may be null.
    pub fn nullable(self, nullable: bool) -> Self {
        Self { nullable, ..self }
    }

    /// Builder: sets the field description.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Builder: records the source column.
    pub fn from_column(self, column_name: impl Into<String>) -> Self {
        Self {
            column_name: Some(column_name.into()),
            ..self
        }
    }

    /// Builder: marks the field deprecated with the given reason.
    pub fn deprecated(self, reason: impl Into<String>) -> Self {
        Self {
            deprecated: true,
            deprecation_reason: Some(reason.into()),
            ..self
        }
    }
}
/// A GraphQL type reference.
///
/// It is a scalar, a named object or input type, or a list or non-null
/// wrapper around another reference.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FieldType {
    /// A scalar type.
    Scalar(GraphQLScalar),
    /// A named (object or input) type, referenced by name.
    Object(String),
    /// `[Inner]`.
    List(Box<FieldType>),
    /// `Inner!`.
    NonNull(Box<FieldType>),
}

impl FieldType {
    /// Wraps a scalar.
    pub fn scalar(scalar: GraphQLScalar) -> Self {
        Self::Scalar(scalar)
    }

    /// References a named type.
    pub fn object(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self::Object(name)
    }

    /// A list whose elements are `inner`.
    pub fn list(inner: FieldType) -> Self {
        Self::List(inner.into())
    }

    /// A non-null wrapper around `inner`.
    pub fn non_null(inner: FieldType) -> Self {
        Self::NonNull(inner.into())
    }
}

impl std::fmt::Display for FieldType {
    /// Writes the reference in SDL notation, e.g. `[User]!`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Scalar(scalar) => write!(f, "{}", scalar.to_sdl()),
            Self::Object(type_name) => f.write_str(type_name),
            Self::List(element) => write!(f, "[{}]", element),
            Self::NonNull(wrapped) => write!(f, "{}!", wrapped),
        }
    }
}
/// A GraphQL input object type, rendered by `GraphQLSchema::to_sdl` as
/// `input Name { ... }`.
#[derive(Debug, Clone)]
pub struct GraphQLInputType {
/// Input type name.
pub name: String,
/// Input fields; each field's `nullable` flag decides whether a trailing `!` is rendered.
pub fields: Vec<GraphQLField>,
}
/// A GraphQL enum type, rendered by `GraphQLSchema::to_sdl` as
/// `enum Name { ... }` with one value per line.
#[derive(Debug, Clone)]
pub struct GraphQLEnumType {
/// Enum type name.
pub name: String,
/// Enum values, written verbatim and in order.
pub values: Vec<String>,
}
/// A field on the root `Query` type.
#[derive(Debug, Clone)]
pub struct QueryDefinition {
    /// Field name.
    pub name: String,
    /// Arguments in declaration order.
    pub arguments: Vec<ArgumentDefinition>,
    /// Returned type name.
    ///
    /// Rendered verbatim for single results, or wrapped as `[T!]!` when
    /// `returns_list` is set.
    pub return_type: String,
    /// Whether the field returns a list of `return_type`.
    pub returns_list: bool,
    /// Source table, if the query was generated from one.
    pub table_name: Option<String>,
}

impl QueryDefinition {
    /// Creates a single-result query with no arguments and no source table.
    pub fn new(name: impl Into<String>, return_type: impl Into<String>) -> Self {
        let name = name.into();
        let return_type = return_type.into();
        QueryDefinition {
            name,
            arguments: vec![],
            return_type,
            returns_list: false,
            table_name: None,
        }
    }

    /// Builder: appends an argument.
    pub fn arg(self, arg: ArgumentDefinition) -> Self {
        let mut query = self;
        query.arguments.push(arg);
        query
    }

    /// Builder: sets whether the query returns a list.
    pub fn returns_list(self, list: bool) -> Self {
        Self {
            returns_list: list,
            ..self
        }
    }

    /// Builder: records the source table.
    pub fn from_table(self, table: impl Into<String>) -> Self {
        Self {
            table_name: Some(table.into()),
            ..self
        }
    }
}
/// A field on the root `Mutation` type.
#[derive(Debug, Clone)]
pub struct MutationDefinition {
    /// Field name.
    pub name: String,
    /// Arguments in declaration order.
    pub arguments: Vec<ArgumentDefinition>,
    /// Returned type name, rendered verbatim.
    pub return_type: String,
    /// Source table, if the mutation was generated from one.
    pub table_name: Option<String>,
    /// Which kind of write the mutation performs.
    pub kind: MutationKind,
}

impl MutationDefinition {
    /// Creates a mutation with no arguments and no source table.
    pub fn new(name: impl Into<String>, return_type: impl Into<String>, kind: MutationKind) -> Self {
        let name = name.into();
        let return_type = return_type.into();
        MutationDefinition {
            name,
            arguments: vec![],
            return_type,
            table_name: None,
            kind,
        }
    }

    /// Builder: appends an argument.
    pub fn arg(self, arg: ArgumentDefinition) -> Self {
        let mut mutation = self;
        mutation.arguments.push(arg);
        mutation
    }

    /// Builder: records the source table.
    pub fn from_table(self, table: impl Into<String>) -> Self {
        Self {
            table_name: Some(table.into()),
            ..self
        }
    }
}
/// Which write operation a generated mutation performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MutationKind {
/// Used for the generated `create<Type>` mutation, which takes a `Create<Type>Input`.
Create,
/// Used for the generated `update<Type>` mutation, which takes an `id` and an `Update<Type>Input`.
Update,
/// Used for the generated `delete<Type>` mutation, which takes an `id` and returns `Boolean`.
Delete,
}
/// An argument of a query or mutation field.
#[derive(Debug, Clone)]
pub struct ArgumentDefinition {
    /// Argument name.
    pub name: String,
    /// Type reference, before the `nullable` flag is applied.
    pub graphql_type: FieldType,
    /// Optional unless cleared with [`ArgumentDefinition::required`].
    pub nullable: bool,
    /// Optional default value, stored as JSON.
    pub default_value: Option<serde_json::Value>,
}

impl ArgumentDefinition {
    /// Creates an optional argument without a default value.
    pub fn new(name: impl Into<String>, graphql_type: FieldType) -> Self {
        let name = name.into();
        ArgumentDefinition {
            name,
            graphql_type,
            nullable: true,
            default_value: None,
        }
    }

    /// Builder: makes the argument non-null.
    pub fn required(self) -> Self {
        Self {
            nullable: false,
            ..self
        }
    }

    /// Builder: sets the default value.
    pub fn default(self, value: serde_json::Value) -> Self {
        Self {
            default_value: Some(value),
            ..self
        }
    }
}
/// A relationship field that links two object types through a pair of columns.
#[derive(Debug, Clone)]
pub struct Relationship {
    /// Relationship name; the default field name is derived from it.
    pub name: String,
    /// Object type that receives the relationship field.
    pub from_type: String,
    /// Object type the field points to.
    pub to_type: String,
    /// Column on the `from_type` side of the join (defaults to `id`).
    pub from_column: String,
    /// Column on the `to_type` side of the join (defaults to `id`).
    pub to_column: String,
    /// Cardinality of the relationship.
    pub relation_type: RelationType,
    /// GraphQL field name (defaults to `to_camel_case(name)`).
    pub field_name: String,
}

impl Relationship {
    /// Creates a relationship that joins `id` to `id`.
    ///
    /// The field name is derived from `name` with `to_camel_case`.
    pub fn new(
        name: impl Into<String>,
        from_type: impl Into<String>,
        to_type: impl Into<String>,
        relation_type: RelationType,
    ) -> Self {
        let name = name.into();
        // Derive the field name before `name` is moved into the struct; this
        // removes the redundant `name.clone()` of the previous version.
        let field_name = to_camel_case(&name);
        Self {
            name,
            from_type: from_type.into(),
            to_type: to_type.into(),
            from_column: "id".to_string(),
            to_column: "id".to_string(),
            relation_type,
            field_name,
        }
    }

    /// Builder: sets the join columns on the `from_type` and `to_type` sides.
    pub fn columns(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
        self.from_column = from.into();
        self.to_column = to.into();
        self
    }

    /// Builder: overrides the GraphQL field name.
    pub fn field(mut self, name: impl Into<String>) -> Self {
        self.field_name = name.into();
        self
    }
}
#[derive(Debug)]
pub struct SchemaIntrospector {
excluded_tables: Vec<String>,
excluded_columns: HashMap<String, Vec<String>>,
type_names: HashMap<String, String>,
}
impl SchemaIntrospector {
pub fn new() -> Self {
Self {
excluded_tables: vec![
"pg_catalog".to_string(),
"information_schema".to_string(),
],
excluded_columns: HashMap::new(),
type_names: HashMap::new(),
}
}
pub fn exclude_table(&mut self, table: impl Into<String>) {
self.excluded_tables.push(table.into());
}
pub fn exclude_column(&mut self, table: impl Into<String>, column: impl Into<String>) {
self.excluded_columns
.entry(table.into())
.or_default()
.push(column.into());
}
pub fn set_type_name(&mut self, table: impl Into<String>, type_name: impl Into<String>) {
self.type_names.insert(table.into(), type_name.into());
}
pub fn build_schema(&self, tables: &[TableDefinition]) -> GraphQLSchema {
let mut schema = GraphQLSchema::new();
for table in tables {
if self.excluded_tables.contains(&table.name) {
continue;
}
let type_def = self.generate_type(table);
let type_name = type_def.name.clone();
schema.add_type(type_def);
schema.add_query(
QueryDefinition::new(to_camel_case(&table.name), &type_name)
.arg(ArgumentDefinition::new("id", FieldType::scalar(GraphQLScalar::ID)).required())
.from_table(&table.name)
);
schema.add_query(
QueryDefinition::new(format!("{}s", to_camel_case(&table.name)), &type_name)
.arg(ArgumentDefinition::new("limit", FieldType::scalar(GraphQLScalar::Int)))
.arg(ArgumentDefinition::new("offset", FieldType::scalar(GraphQLScalar::Int)))
.arg(ArgumentDefinition::new("where", FieldType::object(format!("{}Filter", type_name))))
.returns_list(true)
.from_table(&table.name)
);
schema.add_mutation(
MutationDefinition::new(format!("create{}", type_name), &type_name, MutationKind::Create)
.arg(ArgumentDefinition::new("input", FieldType::object(format!("Create{}Input", type_name))).required())
.from_table(&table.name)
);
schema.add_mutation(
MutationDefinition::new(format!("update{}", type_name), &type_name, MutationKind::Update)
.arg(ArgumentDefinition::new("id", FieldType::scalar(GraphQLScalar::ID)).required())
.arg(ArgumentDefinition::new("input", FieldType::object(format!("Update{}Input", type_name))).required())
.from_table(&table.name)
);
schema.add_mutation(
MutationDefinition::new(format!("delete{}", type_name), "Boolean".to_string(), MutationKind::Delete)
.arg(ArgumentDefinition::new("id", FieldType::scalar(GraphQLScalar::ID)).required())
.from_table(&table.name)
);
let filter_type = self.generate_filter_type(table);
schema.input_types.push(filter_type);
let create_input = self.generate_create_input(table, &type_name);
schema.input_types.push(create_input);
let update_input = self.generate_update_input(table, &type_name);
schema.input_types.push(update_input);
}
for table in tables {
for fk in &table.foreign_keys {
let from_type = self.get_type_name(&table.name);
let to_type = self.get_type_name(&fk.referenced_table);
schema.add_relationship(
Relationship::new(&fk.name, &from_type, &to_type, RelationType::ManyToOne)
.columns(&fk.column, &fk.referenced_column)
.field(to_camel_case(&fk.name))
);
let reverse_name = format!("{}s", to_camel_case(&table.name));
schema.add_relationship(
Relationship::new(&reverse_name, &to_type, &from_type, RelationType::OneToMany)
.columns(&fk.referenced_column, &fk.column)
.field(&reverse_name)
);
}
}
schema
}
fn generate_type(&self, table: &TableDefinition) -> GraphQLType {
let type_name = self.get_type_name(&table.name);
let mut type_def = GraphQLType::new(&type_name)
.from_table(&table.name);
let excluded = self.excluded_columns.get(&table.name);
for column in &table.columns {
if let Some(excluded) = excluded {
if excluded.contains(&column.name) {
continue;
}
}
let scalar = GraphQLScalar::from_sql_type(&column.data_type);
let field = GraphQLField::new(
to_camel_case(&column.name),
FieldType::scalar(scalar),
)
.nullable(column.nullable)
.from_column(&column.name);
type_def.add_field(field);
}
type_def
}
fn generate_filter_type(&self, table: &TableDefinition) -> GraphQLInputType {
let type_name = self.get_type_name(&table.name);
let mut input = GraphQLInputType {
name: format!("{}Filter", type_name),
fields: Vec::new(),
};
for column in &table.columns {
let scalar = GraphQLScalar::from_sql_type(&column.data_type);
let filter_type_name = format!("{}Filter", scalar.to_sdl());
input.fields.push(GraphQLField::new(
to_camel_case(&column.name),
FieldType::object(filter_type_name),
));
}
input.fields.push(GraphQLField::new(
"AND",
FieldType::list(FieldType::object(format!("{}Filter", type_name))),
));
input.fields.push(GraphQLField::new(
"OR",
FieldType::list(FieldType::object(format!("{}Filter", type_name))),
));
input
}
fn generate_create_input(&self, table: &TableDefinition, type_name: &str) -> GraphQLInputType {
let mut input = GraphQLInputType {
name: format!("Create{}Input", type_name),
fields: Vec::new(),
};
for column in &table.columns {
if column.is_primary_key && column.data_type.to_lowercase().contains("serial") {
continue;
}
let scalar = GraphQLScalar::from_sql_type(&column.data_type);
input.fields.push(GraphQLField::new(
to_camel_case(&column.name),
FieldType::scalar(scalar),
).nullable(column.nullable || column.has_default));
}
input
}
fn generate_update_input(&self, table: &TableDefinition, type_name: &str) -> GraphQLInputType {
let mut input = GraphQLInputType {
name: format!("Update{}Input", type_name),
fields: Vec::new(),
};
for column in &table.columns {
if column.is_primary_key {
continue;
}
let scalar = GraphQLScalar::from_sql_type(&column.data_type);
input.fields.push(GraphQLField::new(
to_camel_case(&column.name),
FieldType::scalar(scalar),
));
}
input
}
fn get_type_name(&self, table_name: &str) -> String {
self.type_names
.get(table_name)
.cloned()
.unwrap_or_else(|| to_pascal_case(table_name))
}
}
impl Default for SchemaIntrospector {
fn default() -> Self {
Self::new()
}
}
/// Metadata for one database table; input to `SchemaIntrospector::build_schema`.
#[derive(Debug, Clone)]
pub struct TableDefinition {
    /// Table name.
    pub name: String,
    /// Database schema; `"public"` unless assigned directly.
    pub schema: String,
    /// Columns in declaration order.
    pub columns: Vec<ColumnDefinition>,
    /// Foreign keys owned by this table.
    pub foreign_keys: Vec<ForeignKeyDefinition>,
}

impl TableDefinition {
    /// Creates an empty table named `name` in the `public` schema.
    pub fn new(name: impl Into<String>) -> Self {
        TableDefinition {
            name: name.into(),
            schema: String::from("public"),
            columns: vec![],
            foreign_keys: vec![],
        }
    }

    /// Builder: appends a column.
    pub fn column(self, column: ColumnDefinition) -> Self {
        let mut table = self;
        table.columns.push(column);
        table
    }

    /// Builder: appends a foreign key.
    pub fn foreign_key(self, fk: ForeignKeyDefinition) -> Self {
        let mut table = self;
        table.foreign_keys.push(fk);
        table
    }
}
/// One table column, as consumed by `SchemaIntrospector`.
#[derive(Debug, Clone)]
pub struct ColumnDefinition {
    /// Column name.
    pub name: String,
    /// SQL type name.
    pub data_type: String,
    /// Whether the column accepts NULL; defaults to `true`.
    pub nullable: bool,
    /// Whether the column is the primary key.
    pub is_primary_key: bool,
    /// Whether the column has a database default.
    pub has_default: bool,
}

impl ColumnDefinition {
    /// Creates a nullable, non-key column without a default.
    pub fn new(name: impl Into<String>, data_type: impl Into<String>) -> Self {
        let (name, data_type) = (name.into(), data_type.into());
        ColumnDefinition {
            name,
            data_type,
            nullable: true,
            is_primary_key: false,
            has_default: false,
        }
    }

    /// Builder: sets nullability.
    pub fn nullable(self, nullable: bool) -> Self {
        Self { nullable, ..self }
    }

    /// Builder: marks the column as the primary key, which also makes it non-nullable.
    pub fn primary_key(self) -> Self {
        Self {
            is_primary_key: true,
            nullable: false,
            ..self
        }
    }

    /// Builder: marks the column as having a default value.
    pub fn with_default(self) -> Self {
        Self {
            has_default: true,
            ..self
        }
    }
}
/// A foreign key from `column` of the owning table to
/// `referenced_table.referenced_column`.
#[derive(Debug, Clone)]
pub struct ForeignKeyDefinition {
    /// Constraint name; also used to name the forward relationship.
    pub name: String,
    /// Referencing column on the owning table.
    pub column: String,
    /// Name of the referenced table.
    pub referenced_table: String,
    /// Referenced column on `referenced_table`.
    pub referenced_column: String,
}

impl ForeignKeyDefinition {
    /// Creates a foreign-key description from its four parts.
    pub fn new(
        name: impl Into<String>,
        column: impl Into<String>,
        referenced_table: impl Into<String>,
        referenced_column: impl Into<String>,
    ) -> Self {
        let name = name.into();
        let column = column.into();
        let referenced_table = referenced_table.into();
        let referenced_column = referenced_column.into();
        ForeignKeyDefinition {
            name,
            column,
            referenced_table,
            referenced_column,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: `users` and `posts`, where `posts.user_id` references
    /// `users.id` through the `author` foreign key.
    fn create_test_tables() -> Vec<TableDefinition> {
        let users = TableDefinition::new("users")
            .column(ColumnDefinition::new("id", "serial").primary_key())
            .column(ColumnDefinition::new("name", "varchar(255)").nullable(false))
            .column(ColumnDefinition::new("email", "varchar(255)").nullable(false))
            .column(ColumnDefinition::new("created_at", "timestamp").with_default());
        let posts = TableDefinition::new("posts")
            .column(ColumnDefinition::new("id", "serial").primary_key())
            .column(ColumnDefinition::new("title", "varchar(255)").nullable(false))
            .column(ColumnDefinition::new("content", "text"))
            .column(ColumnDefinition::new("user_id", "integer").nullable(false))
            .foreign_key(ForeignKeyDefinition::new("author", "user_id", "users", "id"));
        vec![users, posts]
    }

    /// Schema built from the fixture with a default introspector.
    fn fixture_schema() -> GraphQLSchema {
        SchemaIntrospector::new().build_schema(&create_test_tables())
    }

    #[test]
    fn test_introspector_build_schema() {
        let schema = fixture_schema();
        assert_eq!(schema.types.len(), 2);
        for type_name in ["Users", "Posts"] {
            assert!(schema.get_type(type_name).is_some(), "missing type {}", type_name);
        }
    }

    #[test]
    fn test_schema_to_sdl() {
        let sdl = fixture_schema().to_sdl();
        for needle in ["type Users", "type Posts", "type Query", "type Mutation"] {
            assert!(sdl.contains(needle), "SDL is missing `{}`", needle);
        }
    }

    #[test]
    fn test_type_generation() {
        let table = TableDefinition::new("users")
            .column(ColumnDefinition::new("id", "serial").primary_key())
            .column(ColumnDefinition::new("name", "varchar").nullable(false));
        let type_def = SchemaIntrospector::new().generate_type(&table);
        assert_eq!(type_def.name, "Users");
        let field_names: Vec<&str> = type_def.fields.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(field_names, ["id", "name"]);
    }

    #[test]
    fn test_relationship_generation() {
        let schema = fixture_schema();

        let from_posts = schema.get_relationships_for("Posts");
        assert_eq!(from_posts.len(), 1);
        assert_eq!(from_posts[0].to_type, "Users");
        assert_eq!(from_posts[0].relation_type, RelationType::ManyToOne);

        let from_users = schema.get_relationships_for("Users");
        assert_eq!(from_users.len(), 1);
        assert_eq!(from_users[0].to_type, "Posts");
        assert_eq!(from_users[0].relation_type, RelationType::OneToMany);
    }

    #[test]
    fn test_excluded_columns() {
        let mut introspector = SchemaIntrospector::new();
        introspector.exclude_column("users", "password_hash");
        let table = TableDefinition::new("users")
            .column(ColumnDefinition::new("id", "serial").primary_key())
            .column(ColumnDefinition::new("password_hash", "varchar"));
        let type_def = introspector.generate_type(&table);
        assert_eq!(type_def.fields.len(), 1);
        assert!(type_def.get_field("passwordHash").is_none());
    }

    #[test]
    fn test_type_name_override() {
        let mut introspector = SchemaIntrospector::new();
        introspector.set_type_name("users", "User");
        let table = TableDefinition::new("users")
            .column(ColumnDefinition::new("id", "serial").primary_key());
        assert_eq!(introspector.generate_type(&table).name, "User");
    }

    #[test]
    fn test_field_type_display() {
        let user = || FieldType::object("User");
        let cases = [
            (FieldType::scalar(GraphQLScalar::String), "String"),
            (user(), "User"),
            (FieldType::list(user()), "[User]"),
            (FieldType::non_null(FieldType::list(user())), "[User]!"),
        ];
        for (field_type, expected) in cases {
            assert_eq!(field_type.to_string(), expected);
        }
    }

    #[test]
    fn test_graphql_schema_default() {
        let schema = GraphQLSchema::default();
        assert!(schema.types.is_empty());
        assert!(schema.queries.is_empty());
        assert!(schema.mutations.is_empty());
    }
}