1use crate::cli::FormatArgs;
4use crate::config::SCHEMA_FILE_PATH;
5use crate::error::{CliError, CliResult};
6use crate::output::{self, success};
7
8pub async fn run(args: FormatArgs) -> CliResult<()> {
10 output::header("Format Schema");
11
12 let cwd = std::env::current_dir()?;
13 let schema_path = args.schema.unwrap_or_else(|| cwd.join(SCHEMA_FILE_PATH));
14
15 if !schema_path.exists() {
16 return Err(
17 CliError::Config(format!("Schema file not found: {}", schema_path.display())).into(),
18 );
19 }
20
21 output::kv("Schema", &schema_path.display().to_string());
22 output::newline();
23
24 output::step(1, 3, "Reading schema...");
26 let schema_content = std::fs::read_to_string(&schema_path)?;
27
28 let schema = parse_schema(&schema_content)?;
30
31 output::step(2, 3, "Formatting...");
33 let formatted = format_schema(&schema);
34
35 let changed = formatted != schema_content;
37
38 if args.check {
39 if changed {
41 output::newline();
42 output::error("Schema is not formatted correctly!");
43 output::info("Run `prax format` to fix formatting.");
44 return Err(CliError::Format("Schema needs formatting".to_string()).into());
45 } else {
46 output::newline();
47 success("Schema is already formatted!");
48 return Ok(());
49 }
50 }
51
52 output::step(3, 3, "Writing formatted schema...");
54
55 if changed {
56 std::fs::write(&schema_path, &formatted)?;
57 output::newline();
58 success("Schema formatted successfully!");
59 } else {
60 output::newline();
61 success("Schema is already formatted!");
62 }
63
64 Ok(())
65}
66
67fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
68 prax_schema::validate_schema(content)
71 .map_err(|e| CliError::Schema(format!("Syntax error: {}", e)))
72}
73
74fn format_schema(schema: &prax_schema::ast::Schema) -> String {
76 let mut output = String::new();
77
78 output.push_str("datasource db {\n");
81 output.push_str(" provider = \"postgresql\"\n");
82 output.push_str(" url = env(\"DATABASE_URL\")\n");
83 output.push_str("}\n");
84 let mut first_section = false;
85
86 if !first_section {
88 output.push('\n');
89 }
90 output.push_str("generator client {\n");
91 output.push_str(" provider = \"prax-client-rust\"\n");
92 output.push_str(" output = \"./src/generated\"\n");
93 output.push_str("}\n");
94 first_section = false;
95
96 for enum_def in schema.enums.values() {
98 if !first_section {
99 output.push('\n');
100 }
101 format_enum(&mut output, enum_def);
102 first_section = false;
103 }
104
105 for model in schema.models.values() {
107 if !first_section {
108 output.push('\n');
109 }
110 format_model(&mut output, model);
111 first_section = false;
112 }
113
114 for view in schema.views.values() {
116 if !first_section {
117 output.push('\n');
118 }
119 format_view(&mut output, view);
120 first_section = false;
121 }
122
123 for composite in schema.types.values() {
125 if !first_section {
126 output.push('\n');
127 }
128 format_composite(&mut output, composite);
129 first_section = false;
130 }
131
132 output
133}
134
135fn format_enum(output: &mut String, enum_def: &prax_schema::ast::Enum) {
136 if let Some(doc) = &enum_def.documentation {
138 for line in doc.text.lines() {
139 output.push_str(&format!("/// {}\n", line));
140 }
141 }
142
143 output.push_str(&format!("enum {} {{\n", enum_def.name()));
144
145 for variant in &enum_def.variants {
146 if let Some(doc) = &variant.documentation {
148 for line in doc.text.lines() {
149 output.push_str(&format!(" /// {}\n", line));
150 }
151 }
152
153 output.push_str(&format!(" {}", variant.name()));
154
155 for attr in &variant.attributes {
157 output.push_str(&format!(" {}", format_attribute(attr)));
158 }
159
160 output.push('\n');
161 }
162
163 for attr in &enum_def.attributes {
165 output.push_str(&format!("\n {}", format_attribute(attr)));
166 }
167
168 output.push_str("}\n");
169}
170
171fn format_model(output: &mut String, model: &prax_schema::ast::Model) {
172 if let Some(doc) = &model.documentation {
174 for line in doc.text.lines() {
175 output.push_str(&format!("/// {}\n", line));
176 }
177 }
178
179 output.push_str(&format!("model {} {{\n", model.name()));
180
181 let max_name_len = model
183 .fields
184 .values()
185 .map(|f| f.name().len())
186 .max()
187 .unwrap_or(0);
188
189 let max_type_len = model
190 .fields
191 .values()
192 .map(|f| format_field_type(&f.field_type, f.modifier).len())
193 .max()
194 .unwrap_or(0);
195
196 for field in model.fields.values() {
197 if let Some(doc) = &field.documentation {
199 for line in doc.text.lines() {
200 output.push_str(&format!(" /// {}\n", line));
201 }
202 }
203
204 let type_str = format_field_type(&field.field_type, field.modifier);
205
206 let padded_name = format!("{:width$}", field.name(), width = max_name_len);
208 let padded_type = format!("{:width$}", type_str, width = max_type_len);
209
210 output.push_str(&format!(" {} {}", padded_name, padded_type));
211
212 for attr in &field.attributes {
214 output.push_str(&format!(" {}", format_attribute(attr)));
215 }
216
217 output.push('\n');
218 }
219
220 let model_attrs: Vec<_> = model.attributes.iter().collect();
222 if !model_attrs.is_empty() {
223 output.push('\n');
224 for attr in model_attrs {
225 output.push_str(&format!(" {}\n", format_attribute(attr)));
226 }
227 }
228
229 output.push_str("}\n");
230}
231
232fn format_view(output: &mut String, view: &prax_schema::ast::View) {
233 if let Some(doc) = &view.documentation {
235 for line in doc.text.lines() {
236 output.push_str(&format!("/// {}\n", line));
237 }
238 }
239
240 output.push_str(&format!("view {} {{\n", view.name()));
241
242 let max_name_len = view
244 .fields
245 .values()
246 .map(|f| f.name().len())
247 .max()
248 .unwrap_or(0);
249
250 let max_type_len = view
251 .fields
252 .values()
253 .map(|f| format_field_type(&f.field_type, f.modifier).len())
254 .max()
255 .unwrap_or(0);
256
257 for field in view.fields.values() {
258 let type_str = format_field_type(&field.field_type, field.modifier);
259 let padded_name = format!("{:width$}", field.name(), width = max_name_len);
260 let padded_type = format!("{:width$}", type_str, width = max_type_len);
261
262 output.push_str(&format!(" {} {}", padded_name, padded_type));
263
264 for attr in &field.attributes {
265 output.push_str(&format!(" {}", format_attribute(attr)));
266 }
267
268 output.push('\n');
269 }
270
271 let view_attrs: Vec<_> = view.attributes.iter().collect();
273 if !view_attrs.is_empty() {
274 output.push('\n');
275 for attr in view_attrs {
276 output.push_str(&format!(" {}\n", format_attribute(attr)));
277 }
278 }
279
280 output.push_str("}\n");
281}
282
283fn format_composite(output: &mut String, composite: &prax_schema::ast::CompositeType) {
284 if let Some(doc) = &composite.documentation {
286 for line in doc.text.lines() {
287 output.push_str(&format!("/// {}\n", line));
288 }
289 }
290
291 output.push_str(&format!("type {} {{\n", composite.name()));
292
293 let max_name_len = composite
295 .fields
296 .values()
297 .map(|f| f.name().len())
298 .max()
299 .unwrap_or(0);
300
301 let max_type_len = composite
302 .fields
303 .values()
304 .map(|f| format_field_type(&f.field_type, f.modifier).len())
305 .max()
306 .unwrap_or(0);
307
308 for field in composite.fields.values() {
309 let type_str = format_field_type(&field.field_type, field.modifier);
310 let padded_name = format!("{:width$}", field.name(), width = max_name_len);
311 let padded_type = format!("{:width$}", type_str, width = max_type_len);
312
313 output.push_str(&format!(" {} {}", padded_name, padded_type));
314
315 for attr in &field.attributes {
316 output.push_str(&format!(" {}", format_attribute(attr)));
317 }
318
319 output.push('\n');
320 }
321
322 output.push_str("}\n");
323}
324
325fn format_field_type(
326 field_type: &prax_schema::ast::FieldType,
327 modifier: prax_schema::ast::TypeModifier,
328) -> String {
329 use prax_schema::ast::{FieldType, ScalarType, TypeModifier};
330
331 let base = match field_type {
332 FieldType::Scalar(scalar) => match scalar {
333 ScalarType::Int => "Int",
334 ScalarType::BigInt => "BigInt",
335 ScalarType::Float => "Float",
336 ScalarType::String => "String",
337 ScalarType::Boolean => "Boolean",
338 ScalarType::DateTime => "DateTime",
339 ScalarType::Date => "Date",
340 ScalarType::Time => "Time",
341 ScalarType::Json => "Json",
342 ScalarType::Bytes => "Bytes",
343 ScalarType::Decimal => "Decimal",
344 ScalarType::Uuid => "Uuid",
345 ScalarType::Cuid => "Cuid",
346 ScalarType::Cuid2 => "Cuid2",
347 ScalarType::NanoId => "NanoId",
348 ScalarType::Ulid => "Ulid",
349 ScalarType::Vector(_) => "Vector",
350 ScalarType::HalfVector(_) => "HalfVector",
351 ScalarType::SparseVector(_) => "SparseVector",
352 ScalarType::Bit(_) => "Bit",
353 }
354 .to_string(),
355 FieldType::Model(name) => name.to_string(),
356 FieldType::Enum(name) => name.to_string(),
357 FieldType::Composite(name) => name.to_string(),
358 FieldType::Unsupported(name) => format!("Unsupported(\"{}\")", name),
359 };
360
361 match modifier {
362 TypeModifier::Optional => format!("{}?", base),
363 TypeModifier::List => format!("{}[]", base),
364 TypeModifier::OptionalList => format!("{}[]?", base),
365 TypeModifier::Required => base,
366 }
367}
368
369fn format_attribute(attr: &prax_schema::ast::Attribute) -> String {
370 let prefix = if attr.is_model_attribute() { "@@" } else { "@" };
372
373 if attr.args.is_empty() {
374 format!("{}{}", prefix, attr.name())
375 } else {
376 let args: Vec<String> = attr
377 .args
378 .iter()
379 .map(|arg| {
380 if let Some(name) = &arg.name {
381 format!("{}: {}", name.as_str(), format_attribute_value(&arg.value))
382 } else {
383 format_attribute_value(&arg.value)
384 }
385 })
386 .collect();
387
388 format!("{}{}({})", prefix, attr.name(), args.join(", "))
389 }
390}
391
392fn format_attribute_value(value: &prax_schema::ast::AttributeValue) -> String {
393 use prax_schema::ast::AttributeValue;
394
395 match value {
396 AttributeValue::String(s) => format!("\"{}\"", s),
397 AttributeValue::Int(i) => i.to_string(),
398 AttributeValue::Float(f) => f.to_string(),
399 AttributeValue::Boolean(b) => b.to_string(),
400 AttributeValue::Ident(id) => id.to_string(),
401 AttributeValue::Function(name, args) => {
402 if args.is_empty() {
403 format!("{}()", name)
404 } else {
405 let arg_strs: Vec<String> = args.iter().map(format_attribute_value).collect();
406 format!("{}({})", name, arg_strs.join(", "))
407 }
408 }
409 AttributeValue::Array(items) => {
410 let item_strs: Vec<String> = items.iter().map(format_attribute_value).collect();
411 format!("[{}]", item_strs.join(", "))
412 }
413 AttributeValue::FieldRef(field) => field.to_string(),
414 AttributeValue::FieldRefList(fields) => {
415 format!(
416 "[{}]",
417 fields
418 .iter()
419 .map(|f| f.to_string())
420 .collect::<Vec<_>>()
421 .join(", ")
422 )
423 }
424 }
425}