use crate::ir::{Class, Enum, FieldType, IR, TaggedEnum};
use std::collections::HashSet;
pub struct SchemaFormatter<'a> {
ir: &'a IR,
}
impl<'a> SchemaFormatter<'a> {
pub fn new(ir: &'a IR) -> Self {
Self { ir }
}
pub fn render(&mut self, output_type: &FieldType) -> String {
let mut result = String::new();
let (enums, tagged_enums, _classes) = self.collect_dependencies(output_type);
for enum_name in enums {
if let Some(e) = self.ir.find_enum(&enum_name) {
result.push_str(&self.render_enum(e));
result.push_str("\n\n");
}
}
for tagged_enum_name in tagged_enums {
if let Some(te) = self.ir.find_tagged_enum(&tagged_enum_name) {
result.push_str(&self.render_tagged_enum(te));
result.push_str("\n\n");
}
}
match output_type {
FieldType::Bool => {
result.push_str("Answer with ONLY a JSON boolean value (true or false), nothing else. Do not wrap in an object.");
}
FieldType::String => {
result.push_str("Answer with ONLY a JSON string value, nothing else. Do not wrap in an object.");
}
FieldType::Int => {
result.push_str("Answer with ONLY a JSON integer value, nothing else. Do not wrap in an object.");
}
FieldType::Float => {
result.push_str("Answer with ONLY a JSON float value, nothing else. Do not wrap in an object.");
}
_ => {
result.push_str("Answer in JSON using this schema:\n");
result.push_str(&self.render_type(output_type, 0));
}
}
result
}
fn collect_dependencies(&self, field_type: &FieldType) -> (Vec<String>, Vec<String>, Vec<String>) {
let mut enums = Vec::new();
let mut tagged_enums = Vec::new();
let mut classes = Vec::new();
let mut visited = HashSet::new();
self.collect_deps_recursive(field_type, &mut enums, &mut tagged_enums, &mut classes, &mut visited);
(enums, tagged_enums, classes)
}
fn collect_deps_recursive(
&self,
field_type: &FieldType,
enums: &mut Vec<String>,
tagged_enums: &mut Vec<String>,
classes: &mut Vec<String>,
visited: &mut HashSet<String>,
) {
match field_type {
FieldType::Enum(name) => {
if !visited.contains(name) {
visited.insert(name.clone());
enums.push(name.clone());
}
}
FieldType::TaggedEnum(name) => {
if !visited.contains(name) {
visited.insert(name.clone());
tagged_enums.push(name.clone());
if let Some(te) = self.ir.find_tagged_enum(name) {
for variant in &te.variants {
for field in &variant.fields {
self.collect_deps_recursive(&field.field_type, enums, tagged_enums, classes, visited);
}
}
}
}
}
FieldType::Class(name) => {
if !visited.contains(name) {
visited.insert(name.clone());
if let Some(class) = self.ir.find_class(name) {
classes.push(name.clone());
for field in &class.fields {
self.collect_deps_recursive(&field.field_type, enums, tagged_enums, classes, visited);
}
} else if self.ir.find_enum(name).is_some() {
enums.push(name.clone());
} else if self.ir.find_tagged_enum(name).is_some() {
tagged_enums.push(name.clone());
} else {
classes.push(name.clone());
}
}
}
FieldType::List(inner) => {
self.collect_deps_recursive(inner, enums, tagged_enums, classes, visited);
}
FieldType::Map(k, v) => {
self.collect_deps_recursive(k, enums, tagged_enums, classes, visited);
self.collect_deps_recursive(v, enums, tagged_enums, classes, visited);
}
FieldType::Union(types) => {
for t in types {
self.collect_deps_recursive(t, enums, tagged_enums, classes, visited);
}
}
_ => {}
}
}
fn render_enum(&self, e: &Enum) -> String {
let mut result = String::new();
if let Some(desc) = &e.description {
result.push_str(&format!("{} ({})\n", e.name, desc));
} else {
result.push_str(&format!("{}\n", e.name));
}
result.push_str(&"-".repeat(e.name.len()));
result.push('\n');
for value in &e.values {
result.push_str(&format!("- {}\n", value));
}
result.trim_end().to_string()
}
fn render_tagged_enum(&self, te: &TaggedEnum) -> String {
let mut result = String::new();
if let Some(desc) = &te.description {
result.push_str(&format!("{} ({})\n", te.name, desc));
} else {
result.push_str(&format!("{}\n", te.name));
}
result.push_str(&"-".repeat(te.name.len()));
result.push('\n');
result.push_str(&format!("Variants (set \"{}\" field to pick one):\n", te.tag_field));
for variant in &te.variants {
let fields_str = if variant.fields.is_empty() {
"{}".to_string()
} else {
let field_strs: Vec<String> = variant.fields.iter().map(|f| {
let optional_marker = if f.optional { "?" } else { "" };
format!("{}{}: {}", f.name, optional_marker, self.render_type(&f.field_type, 0))
}).collect();
format!("{{ {} }}", field_strs.join(", "))
};
if let Some(desc) = &variant.description {
result.push_str(&format!("- {}: {} // {}\n", variant.name, fields_str, desc));
} else {
result.push_str(&format!("- {}: {}\n", variant.name, fields_str));
}
}
result.trim_end().to_string()
}
fn render_type(&self, field_type: &FieldType, indent: usize) -> String {
let indent_str = " ".repeat(indent);
match field_type {
FieldType::String => "string".to_string(),
FieldType::Int => "int".to_string(),
FieldType::Float => "float".to_string(),
FieldType::Bool => "bool".to_string(),
FieldType::Enum(name) => name.clone(),
FieldType::TaggedEnum(name) => name.clone(),
FieldType::Class(name) => {
if let Some(class) = self.ir.find_class(name) {
self.render_class(class, indent)
} else if self.ir.find_enum(name).is_some() {
name.clone()
} else if self.ir.find_tagged_enum(name).is_some() {
name.clone()
} else {
name.clone()
}
}
FieldType::List(inner) => {
format!("[ // array with one or more items\n{}{}\n{}]",
" ".repeat(indent + 1),
self.render_type(inner, indent + 1),
indent_str
)
}
FieldType::Map(k, v) => {
format!("map<{}, {}>",
self.render_type(k, indent),
self.render_type(v, indent)
)
}
FieldType::Union(types) => {
types.iter()
.map(|t| self.render_type(t, indent))
.collect::<Vec<_>>()
.join(" or ")
}
}
}
fn render_class(&self, class: &Class, indent: usize) -> String {
let indent_str = " ".repeat(indent);
let field_indent = " ".repeat(indent + 1);
let mut result = String::from("{\n");
if let Some(desc) = &class.description {
result.push_str(&format!("{}// {}\n", field_indent, desc));
}
for field in &class.fields {
let optional_marker = if field.optional { "?" } else { "" };
let field_line = if let Some(desc) = &field.description {
format!(
"{}{}{}: {}, // {}\n",
field_indent,
field.name,
optional_marker,
self.render_type(&field.field_type, indent + 1),
desc
)
} else {
format!(
"{}{}{}: {},\n",
field_indent,
field.name,
optional_marker,
self.render_type(&field.field_type, indent + 1)
)
};
result.push_str(&field_line);
}
result.push_str(&format!("{}}}", indent_str));
result
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::*;

    /// Builds a `Field`; `description` is copied into an owned string if present.
    fn make_field(
        name: &str,
        field_type: FieldType,
        optional: bool,
        description: Option<&str>,
    ) -> Field {
        Field {
            name: name.to_string(),
            field_type,
            optional,
            description: description.map(|d| d.to_string()),
        }
    }

    /// Builds an undocumented `Class` from its fields.
    fn make_class(name: &str, fields: Vec<Field>) -> Class {
        Class {
            name: name.to_string(),
            description: None,
            fields,
        }
    }

    /// Builds an undocumented `Enum` from its value names.
    fn make_enum(name: &str, values: &[&str]) -> Enum {
        Enum {
            name: name.to_string(),
            description: None,
            values: values.iter().map(|v| v.to_string()).collect(),
        }
    }

    /// Renders `FieldType::Class(root)` against `ir` with a fresh formatter.
    fn render_root(ir: &IR, root: &str) -> String {
        let mut formatter = SchemaFormatter::new(ir);
        formatter.render(&FieldType::Class(root.to_string()))
    }

    #[test]
    fn test_simple_schema() {
        let mut ir = IR::new();
        ir.classes.push(make_class(
            "Person",
            vec![
                make_field("name", FieldType::String, false, None),
                make_field("age", FieldType::Int, false, None),
            ],
        ));

        let output = render_root(&ir, "Person");

        assert!(output.contains("Answer in JSON using this schema:"));
        assert!(output.contains("name: string"));
        assert!(output.contains("age: int"));
    }

    #[test]
    fn test_enum_schema() {
        let mut ir = IR::new();
        ir.enums
            .push(make_enum("Month", &["January", "February", "March"]));
        ir.classes.push(make_class(
            "Person",
            vec![
                make_field("name", FieldType::String, false, None),
                make_field(
                    "birthMonth",
                    FieldType::Enum("Month".to_string()),
                    false,
                    None,
                ),
            ],
        ));

        let output = render_root(&ir, "Person");

        assert!(output.contains("Month\n----"));
        assert!(output.contains("- January"));
        assert!(output.contains("birthMonth: Month"));
    }

    #[test]
    fn test_enum_referenced_as_class() {
        let mut ir = IR::new();
        ir.enums
            .push(make_enum("Status", &["Active", "Inactive", "Pending"]));
        // The enum is deliberately referenced through `FieldType::Class`.
        ir.classes.push(make_class(
            "Task",
            vec![
                make_field("title", FieldType::String, false, None),
                make_field("status", FieldType::Class("Status".to_string()), false, None),
            ],
        ));

        let output = render_root(&ir, "Task");

        assert!(output.contains("Status\n------"), "Enum header should be present");
        assert!(output.contains("- Active"), "Enum variant Active should be listed");
        assert!(output.contains("- Inactive"), "Enum variant Inactive should be listed");
        assert!(output.contains("- Pending"), "Enum variant Pending should be listed");
        assert!(output.contains("status: Status"), "Field should reference Status enum");
    }

    #[test]
    fn test_optional_fields_marked() {
        let mut ir = IR::new();
        ir.classes.push(make_class(
            "User",
            vec![
                make_field("name", FieldType::String, false, None),
                make_field("email", FieldType::String, true, None),
                make_field("age", FieldType::Int, true, Some("User's age")),
            ],
        ));

        let output = render_root(&ir, "User");

        assert!(output.contains("name: string"), "Required field should not have ?");
        assert!(!output.contains("name?:"), "Required field should not have ?");
        assert!(output.contains("email?: string"), "Optional field should have ?");
        assert!(output.contains("age?: int"), "Optional field with description should have ?");
    }
}