use crate::registry::SchemaRegistry;
use crate::types::{
EnumDefinition, PrimitiveType, SchemaDefinition, SchemaType,
StructDefinition, VariantData,
};
use std::fmt::Write;
/// Renders one schema definition as a standalone GraphQL SDL fragment.
///
/// If the definition carries a description it is emitted first (at column 0), followed by
/// whatever `generate_type_definition` produces for the definition's name and type.
pub fn generate(definition: &SchemaDefinition) -> String {
    let mut sdl = String::new();
    if let Some(desc) = &definition.description {
        write_description(&mut sdl, desc, 0);
    }
    sdl + &generate_type_definition(&definition.name, &definition.schema_type)
}
pub fn generate_bundle(registry: &SchemaRegistry) -> String {
let mut output = String::new();
writeln!(output, "# Auto-generated GraphQL Schema Definition").unwrap();
if let Some(ref title) = registry.config().title {
writeln!(output, "# {}", title).unwrap();
}
if let Some(ref description) = registry.config().description {
writeln!(output, "# {}", description).unwrap();
}
writeln!(output).unwrap();
writeln!(output, "# Custom scalar types").unwrap();
writeln!(output, "scalar DateTime").unwrap();
writeln!(output, "scalar Date").unwrap();
writeln!(output, "scalar Time").unwrap();
writeln!(output, "scalar UUID").unwrap();
writeln!(output, "scalar JSON").unwrap();
writeln!(output, "scalar BigInt").unwrap();
writeln!(output).unwrap();
let definitions: Vec<_> = match registry.topological_sort() {
Ok(sorted) => sorted,
Err(_) => registry.definitions().map(|(_, def)| def).collect(),
};
for definition in definitions {
if let Some(ref desc) = definition.description {
write_description(&mut output, desc, 0);
}
let type_def = generate_type_definition(&definition.name, &definition.schema_type);
output.push_str(&type_def);
writeln!(output).unwrap();
}
output
}
/// Renders the SDL for a single named schema type.
///
/// `name` is converted to PascalCase first. Then:
/// - structs become an object `type` (plus an `input` companion, see `generate_graphql_type`);
/// - enums become a GraphQL `enum` or a `union` of variant object types;
/// - newtypes produce only a `# Type alias` comment, no type declaration;
/// - primitives produce nothing;
/// - every other kind of schema type is declared as an opaque custom `scalar`.
///
/// NOTE(review): newtypes and primitives emit no declaration. If other definitions refer
/// to them by name via `SchemaType::Reference`, the resulting SDL references an undeclared
/// type — confirm whether such references can occur.
fn generate_type_definition(name: &str, schema_type: &SchemaType) -> String {
    let name = to_pascal_case(name);
    // Exhaustive on purpose: adding a `SchemaType` variant should force a decision here
    // instead of silently falling into the `scalar` branch.
    match schema_type {
        SchemaType::Struct(def) => generate_graphql_type(&name, def),
        SchemaType::Enum(def) => generate_enum_type(&name, def),
        SchemaType::Newtype(def) => {
            format!("# Type alias: {} = {}\n", name, get_graphql_type(&def.inner_type, true))
        }
        SchemaType::Primitive(_) => String::new(),
        SchemaType::Option(_)
        | SchemaType::Array(_)
        | SchemaType::Set(_)
        | SchemaType::Map(_)
        | SchemaType::Tuple(_)
        | SchemaType::Reference(_)
        | SchemaType::Unit
        | SchemaType::Any => format!("scalar {}\n", name),
    }
}
/// Renders a struct as a GraphQL object `type`, immediately followed by its `input` companion.
///
/// Tuple structs get positional field names `_0`, `_1`, ... and no descriptions or deprecation
/// markers. Named structs get camelCase field names, per-field descriptions, and a bare
/// `@deprecated` directive on deprecated fields. The input companion is whatever
/// `generate_input_type` returns for the same struct.
fn generate_graphql_type(name: &str, def: &StructDefinition) -> String {
    let mut sdl = String::new();
    writeln!(sdl, "type {} {{", name).unwrap();

    if def.is_tuple_struct {
        // Tuple fields have no meaningful names, so expose them positionally.
        for (position, field) in def.fields.values().enumerate() {
            let gql_type = get_graphql_type(&field.schema_type, field.required);
            writeln!(sdl, " _{}: {}", position, gql_type).unwrap();
        }
    } else {
        for (field_name, field) in &def.fields {
            if let Some(desc) = &field.description {
                write_description(&mut sdl, desc, 2);
            }
            let directive = if field.deprecated { " @deprecated" } else { "" };
            writeln!(
                sdl,
                " {}: {}{}",
                to_camel_case(field_name),
                get_graphql_type(&field.schema_type, field.required),
                directive
            )
            .unwrap();
        }
    }

    writeln!(sdl, "}}").unwrap();
    sdl + &generate_input_type(name, def)
}
/// Renders the `input <Name>Input` companion for a named struct.
///
/// The result starts with a blank line so it can be appended directly after the object type.
/// Tuple structs get no input type, and an empty string is returned for them. Field names
/// are camelCased and types come from `get_graphql_input_type`. Descriptions and deprecation
/// markers are not emitted here.
///
/// NOTE(review): in this file only structs receive an `...Input` companion. A field whose type
/// is a `Reference` to an enum definition therefore maps to an `<Enum>Input` that is never
/// declared — confirm whether that case arises.
fn generate_input_type(name: &str, def: &StructDefinition) -> String {
    if def.is_tuple_struct {
        return String::new();
    }

    let mut sdl = String::from("\n");
    writeln!(sdl, "input {}Input {{", name).unwrap();
    for (field_name, field) in &def.fields {
        writeln!(
            sdl,
            " {}: {}",
            to_camel_case(field_name),
            get_graphql_input_type(&field.schema_type, field.required)
        )
        .unwrap();
    }
    writeln!(sdl, "}}").unwrap();
    sdl
}
/// Renders an enum definition as GraphQL SDL.
///
/// Simple enums (as decided by `EnumDefinition::is_simple_enum`) become a GraphQL `enum`. Its
/// values are the SCREAMING_SNAKE_CASE variant names, with descriptions and a bare
/// `@deprecated` where set.
///
/// Any other enum becomes one object type per variant, named `<Name>_<Variant>`. Each of these
/// has a `_type: String!` field plus the variant payload:
/// - `value` for newtype variants;
/// - `_0`, `_1`, ... for tuple variants;
/// - camelCased fields for struct variants.
///
/// A `union <Name>` over all of these types closes the output.
fn generate_enum_type(name: &str, def: &EnumDefinition) -> String {
    let mut sdl = String::new();

    if def.is_simple_enum() {
        writeln!(sdl, "enum {} {{", name).unwrap();
        for variant in &def.variants {
            if let Some(desc) = &variant.description {
                write_description(&mut sdl, desc, 2);
            }
            let directive = if variant.deprecated { " @deprecated" } else { "" };
            writeln!(sdl, " {}{}", to_screaming_snake_case(&variant.name), directive).unwrap();
        }
        writeln!(sdl, "}}").unwrap();
        return sdl;
    }

    let mut members = Vec::new();
    for variant in &def.variants {
        let member = format!("{}_{}", name, variant.name);

        // Every variant object shares the same header and discriminator field.
        writeln!(sdl, "type {} {{", member).unwrap();
        writeln!(sdl, " _type: String!").unwrap();
        match &variant.data {
            VariantData::Unit => {}
            VariantData::Newtype(inner) => {
                writeln!(sdl, " value: {}", get_graphql_type(inner, true)).unwrap();
            }
            VariantData::Tuple(types) => {
                for (position, ty) in types.iter().enumerate() {
                    writeln!(sdl, " _{}: {}", position, get_graphql_type(ty, true)).unwrap();
                }
            }
            VariantData::Struct(fields) => {
                for (field_name, field) in fields {
                    writeln!(
                        sdl,
                        " {}: {}",
                        to_camel_case(field_name),
                        get_graphql_type(&field.schema_type, field.required)
                    )
                    .unwrap();
                }
            }
        }
        writeln!(sdl, "}}").unwrap();
        writeln!(sdl).unwrap();

        members.push(member);
    }

    writeln!(sdl, "union {} = {}", name, members.join(" | ")).unwrap();
    sdl
}
/// Maps a schema type to a GraphQL *output* type reference, as used on object `type` fields.
///
/// `required = true` appends the non-null marker `!`. An `Option` is always nullable whatever
/// `required` says. Named references resolve to the PascalCase type name.
fn get_graphql_type(schema_type: &SchemaType, required: bool) -> String {
    graphql_type_ref(schema_type, required, TypePosition::Output)
}

/// Maps a schema type to a GraphQL *input* type reference, as used on `input` fields.
///
/// Identical to [`get_graphql_type`] except that named references resolve to the companion
/// `<Name>Input` type.
///
/// NOTE(review): in this file only struct definitions get an `...Input` companion. A reference
/// to an enum definition maps to an `<Enum>Input` that is never declared — confirm.
fn get_graphql_input_type(schema_type: &SchemaType, required: bool) -> String {
    graphql_type_ref(schema_type, required, TypePosition::Input)
}

/// Which side of the schema a type reference is rendered for.
#[derive(Clone, Copy)]
enum TypePosition {
    /// Object type fields: references use the plain type name.
    Output,
    /// Input object fields: references use the `<Name>Input` companion.
    Input,
}

/// Shared implementation of [`get_graphql_type`] / [`get_graphql_input_type`].
///
/// Keeping a single mapping guarantees that output and input types stay consistent. They
/// differ only in how `SchemaType::Reference` is named.
fn graphql_type_ref(schema_type: &SchemaType, required: bool, position: TypePosition) -> String {
    let base_type = match schema_type {
        // These wrappers decide nullability themselves, so return before the `!` suffix is applied.
        SchemaType::Option(inner) => return graphql_type_ref(inner, false, position),
        SchemaType::Newtype(def) => {
            return graphql_type_ref(&def.inner_type, required, position);
        }
        SchemaType::Primitive(prim) => get_primitive_type(prim),
        // List elements are non-null unless the element type is itself an `Option`.
        SchemaType::Array(inner) => format!("[{}]", graphql_type_ref(inner, true, position)),
        SchemaType::Set(inner) => format!("[{}]", graphql_type_ref(inner, true, position)),
        // A one-element tuple is transparent; wider tuples have no GraphQL counterpart.
        SchemaType::Tuple(types) if types.len() == 1 => {
            graphql_type_ref(&types[0], true, position)
        }
        SchemaType::Tuple(_) => "JSON".to_string(),
        // Maps, anonymous inline structs and untyped values are carried as opaque JSON.
        SchemaType::Map(_) | SchemaType::Struct(_) | SchemaType::Any => "JSON".to_string(),
        // Inline (anonymous) enums have no named GraphQL enum to point at.
        SchemaType::Enum(_) => "String".to_string(),
        SchemaType::Unit => "Boolean".to_string(),
        SchemaType::Reference(name) => match position {
            TypePosition::Output => to_pascal_case(name),
            TypePosition::Input => format!("{}Input", to_pascal_case(name)),
        },
    };
    if required {
        format!("{}!", base_type)
    } else {
        base_type
    }
}
/// Returns the GraphQL type name used for a Rust primitive.
///
/// GraphQL's built-in `Int` is a signed 32-bit integer. Only integer types that fit in it
/// losslessly map to `Int`; every wider or platform-sized integer (including `u32`) uses the
/// custom `BigInt` scalar. Text-like primitives map to `String`.
///
/// NOTE(review): `Bytes` maps to `String`, which presumably assumes an encoded form such as
/// base64 — confirm against the serializer.
fn get_primitive_type(prim: &PrimitiveType) -> String {
    let graphql_name = match prim {
        PrimitiveType::Bool => "Boolean",
        PrimitiveType::I8
        | PrimitiveType::I16
        | PrimitiveType::I32
        | PrimitiveType::U8
        | PrimitiveType::U16 => "Int",
        PrimitiveType::I64
        | PrimitiveType::I128
        | PrimitiveType::Isize
        | PrimitiveType::U32
        | PrimitiveType::U64
        | PrimitiveType::U128
        | PrimitiveType::Usize => "BigInt",
        PrimitiveType::F32 | PrimitiveType::F64 => "Float",
        PrimitiveType::Char | PrimitiveType::String | PrimitiveType::Bytes => "String",
    };
    graphql_name.to_string()
}
/// Appends `description` to `output` as a GraphQL description string, followed by a newline.
///
/// Every emitted line is prefixed with `indent` repetitions of the indent string.
/// - A description containing a newline becomes a `"""` block string, emitted line by line.
///   The only escape a block string supports is `\"""`, so embedded triple quotes are escaped.
///   Otherwise they would close the string early.
/// - Any other description becomes a regular `"..."` string, escaped with [`escape_string`].
fn write_description(output: &mut String, description: &str, indent: usize) {
    let indent_str = " ".repeat(indent);
    if description.contains('\n') {
        writeln!(output, "{}\"\"\"", indent_str).unwrap();
        for line in description.lines() {
            let line = line.replace("\"\"\"", "\\\"\"\"");
            writeln!(output, "{}{}", indent_str, line).unwrap();
        }
        writeln!(output, "{}\"\"\"", indent_str).unwrap();
    } else {
        writeln!(output, "{}\"{}\"", indent_str, escape_string(description)).unwrap();
    }
}

/// Escapes `s` for use inside a regular (single-quoted-by-`"`) GraphQL string value.
///
/// Backslash, double quote, `\n`, `\r` and `\t` use their short escapes. Any other control
/// character is written as `\uXXXX`, because raw control characters are not valid in older
/// GraphQL StringValue grammar. All other characters pass through unchanged.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            // All `char::is_control` code points are <= U+009F, so 4 hex digits always suffice.
            c if c.is_control() => {
                write!(escaped, "\\u{:04X}", c as u32).unwrap();
            }
            c => escaped.push(c),
        }
    }
    escaped
}
/// Converts `s` to PascalCase.
///
/// `_`, `-` and space act as word separators and are dropped. The first character of each
/// word is uppercased and the remaining characters are kept as-is, so existing camelCase
/// humps survive.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split(|c: char| c == '_' || c == '-' || c == ' ') {
        let mut chars = word.chars();
        if let Some(head) = chars.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}
/// Converts `s` to camelCase: the PascalCase form with its first character lowercased.
fn to_camel_case(s: &str) -> String {
    let mut camel = to_pascal_case(s);
    if let Some(first) = camel.chars().next() {
        let lowered: String = first.to_lowercase().collect();
        camel.replace_range(..first.len_utf8(), &lowered);
    }
    camel
}
/// Converts `s` to SCREAMING_SNAKE_CASE.
///
/// - `-` and space become `_`; existing `_` is kept.
/// - An uppercase letter that directly follows a non-separator, non-uppercase character starts
///   a new word and is preceded by `_`.
/// - Runs of capitals such as acronyms are not split (`HTTPServer` -> `HTTPSERVER`).
fn to_screaming_snake_case(s: &str) -> String {
    /// Classification of the previously consumed character.
    #[derive(Clone, Copy, PartialEq)]
    enum Prev {
        /// Start of input, or a separator (`_`, `-`, space).
        Boundary,
        /// An uppercase letter.
        Upper,
        /// Anything else (lowercase letters, digits, ...).
        Other,
    }

    let mut screaming = String::with_capacity(s.len() + 4);
    let mut prev = Prev::Boundary;
    for ch in s.chars() {
        prev = match ch {
            '-' | ' ' => {
                screaming.push('_');
                Prev::Boundary
            }
            _ if ch.is_uppercase() => {
                if prev == Prev::Other {
                    screaming.push('_');
                }
                screaming.extend(ch.to_uppercase());
                Prev::Upper
            }
            _ => {
                screaming.extend(ch.to_uppercase());
                if ch == '_' {
                    Prev::Boundary
                } else {
                    Prev::Other
                }
            }
        };
    }
    screaming
}
#[cfg(test)]
mod tests {
    //! Unit tests for the GraphQL SDL generator: type mapping, struct/enum rendering,
    //! descriptions, escaping and case conversion.
    use super::*;
    use crate::types::{EnumRepresentation, EnumVariant, StructField};
    use indexmap::IndexMap;

    #[test]
    fn test_primitive_types() {
        assert_eq!(get_primitive_type(&PrimitiveType::Bool), "Boolean");
        assert_eq!(get_primitive_type(&PrimitiveType::I32), "Int");
        assert_eq!(get_primitive_type(&PrimitiveType::I64), "BigInt");
        assert_eq!(get_primitive_type(&PrimitiveType::F64), "Float");
        assert_eq!(get_primitive_type(&PrimitiveType::String), "String");
    }

    #[test]
    fn test_graphql_type_required() {
        let string_type = SchemaType::Primitive(PrimitiveType::String);
        assert_eq!(get_graphql_type(&string_type, true), "String!");
        assert_eq!(get_graphql_type(&string_type, false), "String");
    }

    #[test]
    fn test_option_type() {
        // `Option` is nullable regardless of the `required` flag.
        let opt = SchemaType::Option(Box::new(SchemaType::Primitive(PrimitiveType::String)));
        assert_eq!(get_graphql_type(&opt, true), "String");
        assert_eq!(get_graphql_type(&opt, false), "String");
    }

    #[test]
    fn test_array_type() {
        let arr = SchemaType::Array(Box::new(SchemaType::Primitive(PrimitiveType::I32)));
        assert_eq!(get_graphql_type(&arr, true), "[Int!]!");
        assert_eq!(get_graphql_type(&arr, false), "[Int!]");
    }

    #[test]
    fn test_array_of_optional() {
        // Element nullability comes from the element type, list nullability from `required`.
        let arr = SchemaType::Array(Box::new(SchemaType::Option(Box::new(
            SchemaType::Primitive(PrimitiveType::String),
        ))));
        assert_eq!(get_graphql_type(&arr, true), "[String]!");
    }

    #[test]
    fn test_reference_type() {
        let ref_type = SchemaType::Reference("user".to_string());
        assert_eq!(get_graphql_type(&ref_type, true), "User!");
    }

    #[test]
    fn test_input_type_reference() {
        // Output positions use the type name; input positions use the `...Input` companion.
        let ref_type = SchemaType::Reference("user_profile".to_string());
        assert_eq!(get_graphql_type(&ref_type, false), "UserProfile");
        assert_eq!(get_graphql_input_type(&ref_type, true), "UserProfileInput!");
    }

    #[test]
    fn test_simple_struct() {
        let def = StructDefinition::new()
            .with_field(
                "name",
                StructField::new(SchemaType::Primitive(PrimitiveType::String), "name"),
            )
            .with_field(
                "age",
                StructField::new(
                    SchemaType::Option(Box::new(SchemaType::Primitive(PrimitiveType::I32))),
                    "age",
                ),
            );
        let output = generate_graphql_type("User", &def);
        assert!(output.contains("type User {"));
        assert!(output.contains("name: String!"));
        assert!(output.contains("age: Int"));
        assert!(output.contains("input UserInput {"));
    }

    #[test]
    fn test_tuple_struct_has_positional_fields_and_no_input() {
        let mut def = StructDefinition::new().with_field(
            "0",
            StructField::new(SchemaType::Primitive(PrimitiveType::String), "0"),
        );
        def.is_tuple_struct = true;
        let output = generate_graphql_type("Wrapper", &def);
        assert!(output.contains("type Wrapper {"));
        assert!(output.contains("_0: String!"));
        assert!(!output.contains("input WrapperInput"));
    }

    #[test]
    fn test_simple_enum() {
        let def = EnumDefinition::new(EnumRepresentation::External)
            .with_variant(EnumVariant::unit("Active"))
            .with_variant(EnumVariant::unit("Inactive"));
        let output = generate_enum_type("Status", &def);
        assert!(output.contains("enum Status {"));
        assert!(output.contains("ACTIVE"));
        assert!(output.contains("INACTIVE"));
    }

    #[test]
    fn test_complex_enum() {
        let mut fields = IndexMap::new();
        fields.insert(
            "reason".to_string(),
            StructField::new(SchemaType::Primitive(PrimitiveType::String), "reason"),
        );
        let def = EnumDefinition::new(EnumRepresentation::External)
            .with_variant(EnumVariant::unit("Active"))
            .with_variant(EnumVariant::struct_variant("Suspended", fields));
        let output = generate_enum_type("Status", &def);
        assert!(output.contains("type Status_Active {"));
        assert!(output.contains("type Status_Suspended {"));
        assert!(output.contains("union Status = Status_Active | Status_Suspended"));
    }

    #[test]
    fn test_case_conversions() {
        assert_eq!(to_pascal_case("user_name"), "UserName");
        assert_eq!(to_pascal_case("user-name"), "UserName");
        assert_eq!(to_camel_case("user_name"), "userName");
        assert_eq!(to_screaming_snake_case("userName"), "USER_NAME");
        assert_eq!(to_screaming_snake_case("UserName"), "USER_NAME");
    }

    #[test]
    fn test_screaming_snake_separators() {
        assert_eq!(to_screaming_snake_case("foo-bar baz"), "FOO_BAR_BAZ");
    }

    #[test]
    fn test_escape_string_basic() {
        assert_eq!(escape_string("a\"b\\c"), "a\\\"b\\\\c");
    }

    #[test]
    fn test_multiline_description_uses_block_string() {
        let mut out = String::new();
        write_description(&mut out, "line one\nline two", 0);
        assert_eq!(out, "\"\"\"\nline one\nline two\n\"\"\"\n");
    }

    #[test]
    fn test_deprecated_field() {
        let mut field = StructField::new(SchemaType::Primitive(PrimitiveType::String), "old");
        field.deprecated = true;
        let def = StructDefinition::new().with_field("old", field);
        let output = generate_graphql_type("Test", &def);
        assert!(output.contains("@deprecated"));
    }

    #[test]
    fn test_full_definition() {
        let def = SchemaDefinition::new(
            "User",
            SchemaType::Struct(StructDefinition::new().with_field(
                "id",
                StructField::new(SchemaType::Primitive(PrimitiveType::U64), "id"),
            )),
        )
        .with_description("A user in the system");
        let output = generate(&def);
        assert!(output.contains("\"A user in the system\""));
        assert!(output.contains("type User"));
    }
}