Skip to main content

lemma/formatting/
mod.rs

//! Lemma source code formatting.
//!
//! Formats parsed specs into canonical Lemma source text. Uses `AsLemmaSource`
//! and `Expression::Display` for syntax; this module handles layout only.
5
6use crate::parsing::ast::{
7    expression_precedence, AsLemmaSource, DataValue, Expression, ExpressionKind, LemmaData,
8    LemmaRule, LemmaSpec,
9};
10use crate::{parse, Error, ResourceLimits};
11
/// Soft line-length limit used by the formatter.
///
/// Rule expressions and `unless` clauses are wrapped when they would exceed
/// this limit; data lines and other constructs are never broken, even when
/// they are longer.
pub const MAX_COLS: usize = 60;
15
// =============================================================================
// Public entry points
// =============================================================================
19
20/// Format a sequence of parsed specs into canonical Lemma source.
21///
22/// Specs are separated by two blank lines.
23/// The result ends with a single newline.
24#[must_use]
25pub fn format_specs(specs: &[LemmaSpec]) -> String {
26    let mut out = String::new();
27    for (index, spec) in specs.iter().enumerate() {
28        if index > 0 {
29            out.push_str("\n\n");
30        }
31        out.push_str(&format_spec(spec, MAX_COLS));
32    }
33    if !out.ends_with('\n') {
34        out.push('\n');
35    }
36    out
37}
38
39/// Parse a source string and format it to canonical Lemma source.
40///
41/// Returns an error if the source does not parse.
42pub fn format_source(source: &str, attribute: &str) -> Result<String, Error> {
43    let limits = ResourceLimits::default();
44    let result = parse(source, attribute, &limits)?;
45    Ok(format_specs(&result.specs))
46}
47
// =============================================================================
// Spec
// =============================================================================
51
52pub(crate) fn format_spec(spec: &LemmaSpec, max_cols: usize) -> String {
53    let mut out = String::new();
54    out.push_str("spec ");
55    out.push_str(&spec.name);
56    if let crate::parsing::ast::EffectiveDate::DateTimeValue(ref af) = spec.effective_from {
57        out.push(' ');
58        out.push_str(&af.to_string());
59    }
60    out.push('\n');
61
62    if let Some(ref commentary) = spec.commentary {
63        out.push_str("\"\"\"\n");
64        out.push_str(commentary);
65        out.push_str("\n\"\"\"\n");
66    }
67
68    for meta in &spec.meta_fields {
69        out.push_str(&format!(
70            "meta {}: {}\n",
71            meta.key,
72            AsLemmaSource(&meta.value)
73        ));
74    }
75
76    if !spec.data.is_empty() {
77        format_sorted_data(&spec.data, &mut out);
78    }
79
80    if !spec.rules.is_empty() {
81        out.push('\n');
82        for (index, rule) in spec.rules.iter().enumerate() {
83            if index > 0 {
84                out.push('\n');
85            }
86            out.push_str(&format_rule(rule, max_cols));
87        }
88    }
89
90    out
91}
92
// =============================================================================
// Data
// =============================================================================
96
97/// Format a data, optionally with the reference name padded to `align_width` characters
98/// for column-aligned `=` signs within a group.
99/// When `align_width` is 0 or less than the reference length, no padding is added.
100fn format_data(data: &LemmaData, align_width: usize) -> String {
101    let ref_str = format!("{}", data.reference);
102    let padding = if align_width > ref_str.len() {
103        " ".repeat(align_width - ref_str.len())
104    } else {
105        String::new()
106    };
107    match &data.value {
108        DataValue::TypeDeclaration {
109            base,
110            constraints,
111            from,
112        } if from.is_some() && constraints.is_none() => {
113            format!(
114                "data {}{} from {}",
115                ref_str,
116                padding,
117                from.as_ref().unwrap()
118            )
119        }
120        _ => {
121            format!(
122                "data {}{} : {}",
123                ref_str,
124                padding,
125                AsLemmaSource(&data.value)
126            )
127        }
128    }
129}
130
131/// Compute the maximum data reference width across a slice of data.
132fn max_ref_width(data: &[&LemmaData]) -> usize {
133    data.iter()
134        .map(|f| format!("{}", f.reference).len())
135        .max()
136        .unwrap_or(0)
137}
138
139fn format_with_statement(data: &LemmaData) -> String {
140    let alias = &data.reference.name;
141    if let DataValue::SpecReference(spec_ref) = &data.value {
142        let spec_name = &spec_ref.name;
143        let last_segment = spec_name.rsplit('/').next().unwrap_or(spec_name);
144        if alias == last_segment {
145            format!("with {}", spec_ref)
146        } else {
147            format!("with {}: {}", alias, spec_ref)
148        }
149    } else {
150        unreachable!("BUG: format_with_statement called on non-SpecReference data")
151    }
152}
153
154/// Group data into two sections separated by a blank line:
155///
156/// 1. Regular data (literals, type declarations) — original order, aligned
157/// 2. With statements (spec refs), each followed by their literal bindings — original order
158fn format_sorted_data(data: &[LemmaData], out: &mut String) {
159    let mut regular: Vec<&LemmaData> = Vec::new();
160    let mut spec_refs: Vec<&LemmaData> = Vec::new();
161    let mut overrides: Vec<&LemmaData> = Vec::new();
162
163    for data in data {
164        if !data.reference.is_local() {
165            overrides.push(data);
166        } else if matches!(&data.value, DataValue::SpecReference(_)) {
167            spec_refs.push(data);
168        } else {
169            regular.push(data);
170        }
171    }
172
173    let emit_group = |data: &[&LemmaData], out: &mut String| {
174        let width = max_ref_width(data);
175        for data in data {
176            out.push_str(&format_data(data, width));
177            out.push('\n');
178        }
179    };
180
181    if !regular.is_empty() {
182        out.push('\n');
183        emit_group(&regular, out);
184    }
185
186    if !spec_refs.is_empty() {
187        out.push('\n');
188
189        let has_overrides = |spec_data: &LemmaData| -> bool {
190            let ref_name = &spec_data.reference.name;
191            overrides.iter().any(|o| {
192                o.reference.segments.first().map(|s| s.as_str()) == Some(ref_name.as_str())
193            })
194        };
195
196        let is_bare = |spec_data: &LemmaData| -> bool {
197            if let DataValue::SpecReference(sr) = &spec_data.value {
198                let last = sr.name.rsplit('/').next().unwrap_or(&sr.name);
199                spec_data.reference.name == last
200                    && sr.effective.is_none()
201                    && !has_overrides(spec_data)
202            } else {
203                false
204            }
205        };
206
207        let mut i = 0;
208        while i < spec_refs.len() {
209            if i > 0 {
210                out.push('\n');
211            }
212            if is_bare(spec_refs[i]) {
213                // Collect consecutive bare refs into a comma-separated line
214                let mut group_names = Vec::new();
215                while i < spec_refs.len() && is_bare(spec_refs[i]) {
216                    if let DataValue::SpecReference(sr) = &spec_refs[i].value {
217                        group_names.push(sr.to_string());
218                    }
219                    i += 1;
220                }
221                if group_names.len() == 1 {
222                    out.push_str(&format!("with {}", group_names[0]));
223                } else {
224                    out.push_str(&format!("with {}", group_names.join(", ")));
225                }
226                out.push('\n');
227            } else {
228                let spec_data = spec_refs[i];
229                out.push_str(&format_with_statement(spec_data));
230                out.push('\n');
231                let ref_name = &spec_data.reference.name;
232                let binding_overrides: Vec<&LemmaData> = overrides
233                    .iter()
234                    .filter(|o| {
235                        o.reference.segments.first().map(|s| s.as_str()) == Some(ref_name.as_str())
236                    })
237                    .copied()
238                    .collect();
239                if !binding_overrides.is_empty() {
240                    let width = max_ref_width(&binding_overrides);
241                    for ovr in &binding_overrides {
242                        out.push_str(&format_data(ovr, width));
243                        out.push('\n');
244                    }
245                }
246                i += 1;
247            }
248        }
249    }
250
251    let matched_prefixes: Vec<&str> = spec_refs
252        .iter()
253        .map(|f| f.reference.name.as_str())
254        .collect();
255    let unmatched: Vec<&LemmaData> = overrides
256        .iter()
257        .filter(|o| {
258            o.reference
259                .segments
260                .first()
261                .map(|s| !matched_prefixes.contains(&s.as_str()))
262                .unwrap_or(true)
263        })
264        .copied()
265        .collect();
266    if !unmatched.is_empty() {
267        out.push('\n');
268        emit_group(&unmatched, out);
269    }
270}
271
// =============================================================================
// Rules
// =============================================================================
275
276/// Format a rule with optional line wrapping: long unless lines get "then" on
277/// the next line; long expressions break at arithmetic operators.
278fn format_rule(rule: &LemmaRule, max_cols: usize) -> String {
279    let expr_indent = "  ";
280    let body = format_expr_wrapped(&rule.expression, max_cols, expr_indent, 10);
281    let mut out = String::new();
282    out.push_str("rule ");
283    out.push_str(&rule.name);
284    out.push_str(":\n");
285    out.push_str(expr_indent);
286    out.push_str(&body);
287
288    for unless_clause in &rule.unless_clauses {
289        let condition_str = format_expr_wrapped(&unless_clause.condition, max_cols, "    ", 10);
290        let result_str = format_expr_wrapped(&unless_clause.result, max_cols, "    ", 10);
291        let line = format!("  unless {} then {}", condition_str, result_str);
292        if line.len() > max_cols {
293            out.push_str("\n  unless ");
294            out.push_str(&condition_str);
295            out.push_str("\n    then ");
296            out.push_str(&result_str);
297        } else {
298            out.push_str("\n  unless ");
299            out.push_str(&condition_str);
300            out.push_str(" then ");
301            out.push_str(&result_str);
302        }
303    }
304    out.push('\n');
305    out
306}
307
// =============================================================================
// Expression wrapping (soft line breaking at max_cols)
// =============================================================================
311
/// Return `s` with `indent` prepended to every line except the first.
///
/// Lines are split with [`str::lines`], so `"\r\n"` line endings come out as
/// `"\n"`. A trailing newline on `s` is preserved.
fn indent_after_first_line(s: &str, indent: &str) -> String {
    let mut lines = s.lines();
    let mut indented = String::from(lines.next().unwrap_or(""));
    for line in lines {
        indented.push('\n');
        indented.push_str(indent);
        indented.push_str(line);
    }
    if s.ends_with('\n') {
        indented.push('\n');
    }
    indented
}
331
332/// Format an expression with optional wrapping at arithmetic operators when over max_cols.
333/// `parent_prec` is used to add parentheses when needed (pass 10 for top level).
334fn format_expr_wrapped(
335    expr: &Expression,
336    max_cols: usize,
337    indent: &str,
338    parent_prec: u8,
339) -> String {
340    let my_prec = expression_precedence(&expr.kind);
341
342    let wrap_in_parens = |s: String| {
343        if parent_prec < 10 && my_prec < parent_prec {
344            format!("({})", s)
345        } else {
346            s
347        }
348    };
349
350    match &expr.kind {
351        ExpressionKind::Arithmetic(left, op, right) => {
352            let left_str = format_expr_wrapped(left.as_ref(), max_cols, indent, my_prec);
353            let right_str = format_expr_wrapped(right.as_ref(), max_cols, indent, my_prec);
354            let single_line = format!("{} {} {}", left_str, op, right_str);
355            if single_line.len() <= max_cols && !single_line.contains('\n') {
356                return wrap_in_parens(single_line);
357            }
358            let continued_right = indent_after_first_line(&right_str, indent);
359            let continuation = format!("{}{} {}", indent, op, continued_right);
360            let multi_line = format!("{}\n{}", left_str, continuation);
361            wrap_in_parens(multi_line)
362        }
363        _ => {
364            let s = expr.to_string();
365            wrap_in_parens(s)
366        }
367    }
368}
369
// =============================================================================
// Tests
// =============================================================================
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ast::{
        AsLemmaSource, BooleanValue, DateTimeValue, DurationUnit, TimeValue, TimezoneValue, Value,
    };
    use rust_decimal::prelude::FromStr;
    use rust_decimal::Decimal;

    /// Helper: format a Value as canonical Lemma source via AsLemmaSource.
    fn fmt_value(v: &Value) -> String {
        format!("{}", AsLemmaSource(v))
    }

    #[test]
    fn test_format_value_text_is_quoted() {
        let v = Value::Text("light".to_string());
        assert_eq!(fmt_value(&v), "\"light\"");
    }

    #[test]
    fn test_format_value_text_escapes_quotes() {
        let v = Value::Text("say \"hello\"".to_string());
        assert_eq!(fmt_value(&v), "\"say \\\"hello\\\"\"");
    }

    #[test]
    fn test_format_value_number() {
        let v = Value::Number(Decimal::from_str("42.50").unwrap());
        assert_eq!(fmt_value(&v), "42.50");
    }

    #[test]
    fn test_format_value_number_integer() {
        let v = Value::Number(Decimal::from_str("100.00").unwrap());
        assert_eq!(fmt_value(&v), "100");
    }

    #[test]
    fn test_format_value_boolean() {
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::True)), "true");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Yes)), "yes");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::No)), "no");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Accept)), "accept");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Reject)), "reject");
    }

    #[test]
    fn test_format_value_scale() {
        let v = Value::Scale(Decimal::from_str("99.50").unwrap(), "eur".to_string());
        assert_eq!(fmt_value(&v), "99.50 eur");
    }

    #[test]
    fn test_format_value_duration() {
        let v = Value::Duration(Decimal::from(40), DurationUnit::Hour);
        assert_eq!(fmt_value(&v), "40 hours");
    }

    #[test]
    fn test_format_value_ratio_percent() {
        let v = Value::Ratio(
            Decimal::from_str("0.10").unwrap(),
            Some("percent".to_string()),
        );
        assert_eq!(fmt_value(&v), "10%");
    }

    #[test]
    fn test_format_value_ratio_permille() {
        let v = Value::Ratio(
            Decimal::from_str("0.005").unwrap(),
            Some("permille".to_string()),
        );
        assert_eq!(fmt_value(&v), "5%%");
    }

    #[test]
    fn test_format_value_ratio_bare() {
        let v = Value::Ratio(Decimal::from_str("0.25").unwrap(), None);
        assert_eq!(fmt_value(&v), "0.25");
    }

    #[test]
    fn test_format_value_date_only() {
        let v = Value::Date(DateTimeValue {
            year: 2024,
            month: 1,
            day: 15,
            hour: 0,
            minute: 0,
            second: 0,
            microsecond: 0,
            timezone: None,
        });
        assert_eq!(fmt_value(&v), "2024-01-15");
    }

    #[test]
    fn test_format_value_datetime_with_tz() {
        let v = Value::Date(DateTimeValue {
            year: 2024,
            month: 1,
            day: 15,
            hour: 14,
            minute: 30,
            second: 0,
            microsecond: 0,
            timezone: Some(TimezoneValue {
                offset_hours: 0,
                offset_minutes: 0,
            }),
        });
        assert_eq!(fmt_value(&v), "2024-01-15T14:30:00Z");
    }

    #[test]
    fn test_format_value_time() {
        let v = Value::Time(TimeValue {
            hour: 14,
            minute: 30,
            second: 45,
            timezone: None,
        });
        assert_eq!(fmt_value(&v), "14:30:45");
    }

    #[test]
    fn test_format_source_round_trips_text() {
        let source = r#"spec test

data name: "Alice"

rule greeting: "hello"
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(formatted.contains("\"Alice\""), "data text must be quoted");
        assert!(formatted.contains("\"hello\""), "rule text must be quoted");
    }

    #[test]
    fn test_format_source_preserves_percent() {
        let source = r#"spec test

data rate: 10 percent

rule tax: rate * 21%
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("10%"),
            "data percent must use shorthand %, got: {}",
            formatted
        );
    }

    #[test]
    fn test_format_groups_data_preserving_order() {
        // Data are deliberately mixed: the formatter keeps all regular data together
        // in original order, aligned
        let source = r#"spec test

data income: number -> minimum 0
data filing_status: filing_status_type -> default "single"
data country: "NL"
data deductions: number -> minimum 0
data name: text

rule total: income
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        let data_section = formatted
            .split("rule total")
            .next()
            .unwrap()
            .split("spec test\n")
            .nth(1)
            .unwrap();
        let lines: Vec<&str> = data_section.lines().filter(|l| !l.is_empty()).collect();
        // All regular data in one group, original order, aligned
        assert_eq!(lines[0], "data income        : number -> minimum 0");
        assert_eq!(
            lines[1],
            "data filing_status : filing_status_type -> default \"single\""
        );
        assert_eq!(lines[2], "data country       : \"NL\"");
        assert_eq!(lines[3], "data deductions    : number -> minimum 0");
        assert_eq!(lines[4], "data name          : text");
    }

    #[test]
    fn test_format_groups_spec_refs_with_overrides() {
        let source = r#"spec test

data retail.quantity: 5
with order/wholesale
with order/retail
data wholesale.quantity: 100
data base_price: 50

rule total: base_price
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        let data_section = formatted
            .split("rule total")
            .next()
            .unwrap()
            .split("spec test\n")
            .nth(1)
            .unwrap();
        let lines: Vec<&str> = data_section.lines().filter(|l| !l.is_empty()).collect();
        // Section 1: regular local data (here a single literal) comes first,
        // even though it appears last in the source.
        assert_eq!(lines[0], "data base_price : 50");
        // Section 2: spec refs in original order, each immediately followed by
        // its own bindings (non-local data under that alias), aligned per reference.
        assert_eq!(lines[1], "with order/wholesale");
        assert_eq!(lines[2], "data wholesale.quantity : 100");
        assert_eq!(lines[3], "with order/retail");
        assert_eq!(lines[4], "data retail.quantity : 5");
    }

    #[test]
    fn test_format_source_weather_clothing_text_quoted() {
        let source = r#"spec weather_clothing

data clothing_style: text
  -> option "light"
  -> option "warm"

data temperature: number

rule clothing_layer: "light"
  unless temperature < 5 then "warm"
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("\"light\""),
            "text in rule must be quoted, got: {}",
            formatted
        );
        assert!(
            formatted.contains("\"warm\""),
            "text in unless must be quoted, got: {}",
            formatted
        );
    }

    // NOTE: Default value type validation (e.g. rejecting "10 $$" as a number
    // default) is tested at the planning level in engine.rs, not here. The
    // formatter only parses — it does not validate types. Planning catches
    // invalid defaults for both primitives and named types.

    #[test]
    fn test_format_text_option_round_trips() {
        let source = r#"spec test

data status: text
  -> option "active"
  -> option "inactive"

data s: status

rule out: s
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("option \"active\""),
            "text option must be quoted, got: {}",
            formatted
        );
        assert!(
            formatted.contains("option \"inactive\""),
            "text option must be quoted, got: {}",
            formatted
        );
        // Round-trip: the formatted output must itself be valid Lemma source.
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(reparsed.is_ok(), "formatted output should re-parse");
    }

    #[test]
    fn test_format_help_round_trips() {
        let source = r#"spec test
data quantity: number -> help "Number of items to order"
rule total: quantity
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("help \"Number of items to order\""),
            "help must be quoted, got: {}",
            formatted
        );
        // Round-trip: the formatted output must itself be valid Lemma source.
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(reparsed.is_ok(), "formatted output should re-parse");
    }

    #[test]
    fn test_format_scale_type_def_round_trips() {
        let source = r#"spec test

data money: scale
  -> unit eur 1.00
  -> unit usd 1.10
  -> decimals 2
  -> minimum 0

data price: money

rule total: price
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("unit eur 1.00"),
            "scale unit should not be quoted, got: {}",
            formatted
        );
        // Round-trip: the formatted output must itself be valid Lemma source.
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(
            reparsed.is_ok(),
            "formatted output should re-parse, got: {:?}",
            reparsed
        );
    }

    #[test]
    fn test_format_expression_display_stable_round_trip() {
        let source = r#"spec test
data a: 1.00
rule r: a + 2.00 * 3
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        let again = format_source(&formatted, "test.lemma").unwrap();
        assert_eq!(
            formatted, again,
            "AST Display-based format must be idempotent under parse/format"
        );
    }
}