// lemma/formatting/mod.rs

1//! Lemma source code formatting.
2//!
3//! Formats parsed specs into canonical Lemma source text. Uses `AsLemmaSource`
4//! and `Expression::Display` for syntax; this module handles layout only.
5
6use crate::parsing::ast::{
7    expression_precedence, AsLemmaSource, Expression, ExpressionKind, FactValue, LemmaFact,
8    LemmaRule, LemmaSpec, TypeDef,
9};
10use crate::{parse, Error, ResourceLimits};
11
/// Soft line-length limit, compared against the byte length of rendered text.
///
/// Rule formatting uses it to decide when an `unless` clause gets its `then`
/// on a separate line and when an arithmetic expression is broken at an
/// operator. Facts and other constructs are never broken, even when they
/// exceed this limit.
pub const MAX_COLS: usize = 60;
15
16// =============================================================================
17// Public entry points
18// =============================================================================
19
20/// Format a sequence of parsed specs into canonical Lemma source.
21///
22/// Specs are separated by two blank lines.
23/// The result ends with a single newline.
24#[must_use]
25pub fn format_specs(specs: &[LemmaSpec]) -> String {
26    let mut out = String::new();
27    for (index, spec) in specs.iter().enumerate() {
28        if index > 0 {
29            out.push_str("\n\n");
30        }
31        out.push_str(&format_spec(spec, MAX_COLS));
32    }
33    if !out.ends_with('\n') {
34        out.push('\n');
35    }
36    out
37}
38
39/// Parse a source string and format it to canonical Lemma source.
40///
41/// Returns an error if the source does not parse.
42pub fn format_source(source: &str, attribute: &str) -> Result<String, Error> {
43    let limits = ResourceLimits::default();
44    let result = parse(source, attribute, &limits)?;
45    Ok(format_specs(&result.specs))
46}
47
48// =============================================================================
49// Spec
50// =============================================================================
51
52pub(crate) fn format_spec(spec: &LemmaSpec, max_cols: usize) -> String {
53    let mut out = String::new();
54    out.push_str("spec ");
55    out.push_str(&spec.name);
56    if let Some(ref af) = spec.effective_from {
57        out.push(' ');
58        out.push_str(&af.to_string());
59    }
60    out.push('\n');
61
62    if let Some(ref commentary) = spec.commentary {
63        out.push_str("\"\"\"\n");
64        out.push_str(commentary);
65        out.push_str("\n\"\"\"\n");
66    }
67
68    for meta in &spec.meta_fields {
69        out.push_str(&format!(
70            "meta {}: {}\n",
71            meta.key,
72            AsLemmaSource(&meta.value)
73        ));
74    }
75
76    let named_types: Vec<_> = spec
77        .types
78        .iter()
79        .filter(|t| !matches!(t, TypeDef::Inline { .. }))
80        .collect();
81    if !named_types.is_empty() {
82        out.push('\n');
83        for (index, type_def) in named_types.iter().enumerate() {
84            if index > 0 {
85                out.push('\n');
86            }
87            out.push_str(&format!("{}", AsLemmaSource(*type_def)));
88            out.push('\n');
89        }
90    }
91
92    if !spec.facts.is_empty() {
93        format_sorted_facts(&spec.facts, &mut out);
94    }
95
96    if !spec.rules.is_empty() {
97        out.push('\n');
98        for (index, rule) in spec.rules.iter().enumerate() {
99            if index > 0 {
100                out.push('\n');
101            }
102            out.push_str(&format_rule(rule, max_cols));
103        }
104    }
105
106    out
107}
108
109// =============================================================================
110// Facts
111// =============================================================================
112
113/// Format a fact, optionally with the reference name padded to `align_width` characters
114/// for column-aligned `=` signs within a group.
115/// When `align_width` is 0 or less than the reference length, no padding is added.
116fn format_fact(fact: &LemmaFact, align_width: usize) -> String {
117    let ref_str = format!("{}", fact.reference);
118    let padding = if align_width > ref_str.len() {
119        " ".repeat(align_width - ref_str.len())
120    } else {
121        String::new()
122    };
123    format!(
124        "fact {}{} : {}",
125        ref_str,
126        padding,
127        AsLemmaSource(&fact.value)
128    )
129}
130
131/// Compute the maximum fact reference width across a slice of facts.
132fn max_ref_width(facts: &[&LemmaFact]) -> usize {
133    facts
134        .iter()
135        .map(|f| format!("{}", f.reference).len())
136        .max()
137        .unwrap_or(0)
138}
139
140/// Group facts into two sections separated by a blank line:
141///
142/// 1. Regular facts (literals, type declarations) — original order, aligned
143/// 2. Spec references, each followed by its cross-spec overrides — original order, aligned per sub-group
144fn format_sorted_facts(facts: &[LemmaFact], out: &mut String) {
145    let mut regular: Vec<&LemmaFact> = Vec::new();
146    let mut spec_refs: Vec<&LemmaFact> = Vec::new();
147    let mut overrides: Vec<&LemmaFact> = Vec::new();
148
149    for fact in facts {
150        if !fact.reference.is_local() {
151            overrides.push(fact);
152        } else if matches!(&fact.value, FactValue::SpecReference(_)) {
153            spec_refs.push(fact);
154        } else {
155            regular.push(fact);
156        }
157    }
158
159    // Helper: emit an aligned group of facts
160    let emit_group = |facts: &[&LemmaFact], out: &mut String| {
161        let width = max_ref_width(facts);
162        for fact in facts {
163            out.push_str(&format_fact(fact, width));
164            out.push('\n');
165        }
166    };
167
168    // Group 1: Regular facts (literals + type declarations), original order, aligned
169    if !regular.is_empty() {
170        out.push('\n');
171        emit_group(&regular, out);
172    }
173
174    // Group 2: Spec references, each followed by its overrides
175    if !spec_refs.is_empty() {
176        out.push('\n');
177        for (i, spec_fact) in spec_refs.iter().enumerate() {
178            if i > 0 {
179                out.push('\n');
180            }
181            let ref_name = &spec_fact.reference.name;
182            let mut sub_group: Vec<&LemmaFact> = vec![spec_fact];
183            for ovr in &overrides {
184                if ovr.reference.segments.first().map(|s| s.as_str()) == Some(ref_name.as_str()) {
185                    sub_group.push(ovr);
186                }
187            }
188            emit_group(&sub_group, out);
189        }
190    }
191
192    // Any overrides that didn't match a spec ref (shouldn't happen in valid Lemma, but be safe)
193    let matched_prefixes: Vec<&str> = spec_refs
194        .iter()
195        .map(|f| f.reference.name.as_str())
196        .collect();
197    let unmatched: Vec<&LemmaFact> = overrides
198        .iter()
199        .filter(|o| {
200            o.reference
201                .segments
202                .first()
203                .map(|s| !matched_prefixes.contains(&s.as_str()))
204                .unwrap_or(true)
205        })
206        .copied()
207        .collect();
208    if !unmatched.is_empty() {
209        out.push('\n');
210        emit_group(&unmatched, out);
211    }
212}
213
214// =============================================================================
215// Rules
216// =============================================================================
217
218/// Format a rule with optional line wrapping: long unless lines get "then" on
219/// the next line; long expressions break at arithmetic operators.
220fn format_rule(rule: &LemmaRule, max_cols: usize) -> String {
221    let expr_indent = "  ";
222    let body = format_expr_wrapped(&rule.expression, max_cols, expr_indent, 10);
223    let mut out = String::new();
224    out.push_str("rule ");
225    out.push_str(&rule.name);
226    out.push_str(":\n");
227    out.push_str(expr_indent);
228    out.push_str(&body);
229
230    for unless_clause in &rule.unless_clauses {
231        let condition_str = format_expr_wrapped(&unless_clause.condition, max_cols, "    ", 10);
232        let result_str = format_expr_wrapped(&unless_clause.result, max_cols, "    ", 10);
233        let line = format!("  unless {} then {}", condition_str, result_str);
234        if line.len() > max_cols {
235            out.push_str("\n  unless ");
236            out.push_str(&condition_str);
237            out.push_str("\n    then ");
238            out.push_str(&result_str);
239        } else {
240            out.push_str("\n  unless ");
241            out.push_str(&condition_str);
242            out.push_str(" then ");
243            out.push_str(&result_str);
244        }
245    }
246    out.push('\n');
247    out
248}
249
250// =============================================================================
251// Expression wrapping (soft line breaking at max_cols)
252// =============================================================================
253
/// Prefix every line of `s` except the first with `indent`.
///
/// Lines are split with `str::lines`, so `\r\n` endings come back as `\n`,
/// and blank continuation lines receive the indent as well. A trailing
/// newline in `s` is preserved.
fn indent_after_first_line(s: &str, indent: &str) -> String {
    let mut remaining = s.lines();
    let mut result = String::with_capacity(s.len());

    // The first line is copied through untouched.
    if let Some(head) = remaining.next() {
        result.push_str(head);
    }
    for continuation in remaining {
        result.push('\n');
        result.push_str(indent);
        result.push_str(continuation);
    }

    // `lines()` drops a final newline; put it back.
    if s.ends_with('\n') {
        result.push('\n');
    }
    result
}
273
274/// Format an expression with optional wrapping at arithmetic operators when over max_cols.
275/// `parent_prec` is used to add parentheses when needed (pass 10 for top level).
276fn format_expr_wrapped(
277    expr: &Expression,
278    max_cols: usize,
279    indent: &str,
280    parent_prec: u8,
281) -> String {
282    let my_prec = expression_precedence(&expr.kind);
283
284    let wrap_in_parens = |s: String| {
285        if parent_prec < 10 && my_prec < parent_prec {
286            format!("({})", s)
287        } else {
288            s
289        }
290    };
291
292    match &expr.kind {
293        ExpressionKind::Arithmetic(left, op, right) => {
294            let left_str = format_expr_wrapped(left.as_ref(), max_cols, indent, my_prec);
295            let right_str = format_expr_wrapped(right.as_ref(), max_cols, indent, my_prec);
296            let single_line = format!("{} {} {}", left_str, op, right_str);
297            if single_line.len() <= max_cols && !single_line.contains('\n') {
298                return wrap_in_parens(single_line);
299            }
300            let continued_right = indent_after_first_line(&right_str, indent);
301            let continuation = format!("{}{} {}", indent, op, continued_right);
302            let multi_line = format!("{}\n{}", left_str, continuation);
303            wrap_in_parens(multi_line)
304        }
305        _ => {
306            let s = expr.to_string();
307            wrap_in_parens(s)
308        }
309    }
310}
311
312// =============================================================================
313// Tests
314// =============================================================================
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ast::{
        AsLemmaSource, BooleanValue, DateTimeValue, DurationUnit, TimeValue, TimezoneValue, Value,
    };
    use rust_decimal::prelude::FromStr;
    use rust_decimal::Decimal;

    /// Helper: format a Value as canonical Lemma source via AsLemmaSource.
    fn fmt_value(v: &Value) -> String {
        format!("{}", AsLemmaSource(v))
    }

    // -------------------------------------------------------------------------
    // Value rendering through `AsLemmaSource`
    // -------------------------------------------------------------------------

    #[test]
    fn test_format_value_text_is_quoted() {
        let v = Value::Text("light".to_string());
        assert_eq!(fmt_value(&v), "\"light\"");
    }

    #[test]
    fn test_format_value_text_escapes_quotes() {
        let v = Value::Text("say \"hello\"".to_string());
        assert_eq!(fmt_value(&v), "\"say \\\"hello\\\"\"");
    }

    #[test]
    fn test_format_value_number() {
        let v = Value::Number(Decimal::from_str("42.50").unwrap());
        assert_eq!(fmt_value(&v), "42.50");
    }

    /// A number whose fractional part is all zeros is rendered without it.
    #[test]
    fn test_format_value_number_integer() {
        let v = Value::Number(Decimal::from_str("100.00").unwrap());
        assert_eq!(fmt_value(&v), "100");
    }

    /// Each boolean spelling is rendered as its own keyword.
    #[test]
    fn test_format_value_boolean() {
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::True)), "true");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Yes)), "yes");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::No)), "no");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Accept)), "accept");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Reject)), "reject");
    }

    #[test]
    fn test_format_value_scale() {
        let v = Value::Scale(Decimal::from_str("99.50").unwrap(), "eur".to_string());
        assert_eq!(fmt_value(&v), "99.50 eur");
    }

    #[test]
    fn test_format_value_duration() {
        let v = Value::Duration(Decimal::from(40), DurationUnit::Hour);
        assert_eq!(fmt_value(&v), "40 hours");
    }

    /// A ratio tagged "percent" is rendered with the `%` shorthand.
    #[test]
    fn test_format_value_ratio_percent() {
        let v = Value::Ratio(
            Decimal::from_str("0.10").unwrap(),
            Some("percent".to_string()),
        );
        assert_eq!(fmt_value(&v), "10%");
    }

    /// A ratio tagged "permille" is rendered with the `%%` shorthand.
    #[test]
    fn test_format_value_ratio_permille() {
        let v = Value::Ratio(
            Decimal::from_str("0.005").unwrap(),
            Some("permille".to_string()),
        );
        assert_eq!(fmt_value(&v), "5%%");
    }

    #[test]
    fn test_format_value_ratio_bare() {
        let v = Value::Ratio(Decimal::from_str("0.25").unwrap(), None);
        assert_eq!(fmt_value(&v), "0.25");
    }

    /// A date with an all-zero time and no timezone is rendered as a plain date.
    #[test]
    fn test_format_value_date_only() {
        let v = Value::Date(DateTimeValue {
            year: 2024,
            month: 1,
            day: 15,
            hour: 0,
            minute: 0,
            second: 0,
            microsecond: 0,
            timezone: None,
        });
        assert_eq!(fmt_value(&v), "2024-01-15");
    }

    /// A zero timezone offset is rendered as `Z`.
    #[test]
    fn test_format_value_datetime_with_tz() {
        let v = Value::Date(DateTimeValue {
            year: 2024,
            month: 1,
            day: 15,
            hour: 14,
            minute: 30,
            second: 0,
            microsecond: 0,
            timezone: Some(TimezoneValue {
                offset_hours: 0,
                offset_minutes: 0,
            }),
        });
        assert_eq!(fmt_value(&v), "2024-01-15T14:30:00Z");
    }

    #[test]
    fn test_format_value_time() {
        let v = Value::Time(TimeValue {
            hour: 14,
            minute: 30,
            second: 45,
            timezone: None,
        });
        assert_eq!(fmt_value(&v), "14:30:45");
    }

    // -------------------------------------------------------------------------
    // End-to-end formatting through `format_source`
    // -------------------------------------------------------------------------

    #[test]
    fn test_format_source_round_trips_text() {
        let source = r#"spec test

fact name: "Alice"

rule greeting: "hello"
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(formatted.contains("\"Alice\""), "fact text must be quoted");
        assert!(formatted.contains("\"hello\""), "rule text must be quoted");
    }

    /// `10 percent` written out in a fact is normalised to the `10%` shorthand.
    #[test]
    fn test_format_source_preserves_percent() {
        let source = r#"spec test

fact rate: 10 percent

rule tax: rate * 21%
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("10%"),
            "fact percent must use shorthand %, got: {}",
            formatted
        );
    }

    #[test]
    fn test_format_groups_facts_preserving_order() {
        // Facts are deliberately mixed: the formatter keeps all regular facts together
        // in original order, aligned
        let source = r#"spec test

fact income: [number -> minimum 0]
fact filing_status: [filing_status_type -> default "single"]
fact country: "NL"
fact deductions: [number -> minimum 0]
fact name: [text]

rule total: income
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        // Isolate the text between the spec header and the first rule.
        let fact_section = formatted
            .split("rule total")
            .next()
            .unwrap()
            .split("spec test\n")
            .nth(1)
            .unwrap();
        let lines: Vec<&str> = fact_section.lines().filter(|l| !l.is_empty()).collect();
        // All regular facts in one group, original order, aligned
        assert_eq!(lines[0], "fact income        : [number -> minimum 0]");
        assert_eq!(
            lines[1],
            "fact filing_status : [filing_status_type -> default \"single\"]"
        );
        assert_eq!(lines[2], "fact country       : \"NL\"");
        assert_eq!(lines[3], "fact deductions    : [number -> minimum 0]");
        assert_eq!(lines[4], "fact name          : [text]");
    }

    #[test]
    fn test_format_groups_spec_refs_with_overrides() {
        // Overrides are deliberately listed away from (and before) the spec
        // references they belong to.
        let source = r#"spec test

fact retail.quantity: 5
fact wholesale: spec order/wholesale
fact retail: spec order/retail
fact wholesale.quantity: 100
fact base_price: 50

rule total: base_price
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        // Isolate the text between the spec header and the first rule.
        let fact_section = formatted
            .split("rule total")
            .next()
            .unwrap()
            .split("spec test\n")
            .nth(1)
            .unwrap();
        let lines: Vec<&str> = fact_section.lines().filter(|l| !l.is_empty()).collect();
        // Group 1: Regular facts
        assert_eq!(lines[0], "fact base_price : 50");
        // Group 2: Spec refs in original order, each with its overrides, aligned
        assert_eq!(lines[1], "fact wholesale          : spec order/wholesale");
        assert_eq!(lines[2], "fact wholesale.quantity : 100");
        assert_eq!(lines[3], "fact retail          : spec order/retail");
        assert_eq!(lines[4], "fact retail.quantity : 5");
    }

    #[test]
    fn test_format_source_weather_clothing_text_quoted() {
        let source = r#"spec weather_clothing

type clothing_style: text
  -> option "light"
  -> option "warm"

fact temperature: [number]

rule clothing_layer: "light"
  unless temperature < 5 then "warm"
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("\"light\""),
            "text in rule must be quoted, got: {}",
            formatted
        );
        assert!(
            formatted.contains("\"warm\""),
            "text in unless must be quoted, got: {}",
            formatted
        );
    }

    // NOTE: Default value type validation (e.g. rejecting "10 $$" as a number
    // default) is tested at the planning level in engine.rs, not here. The
    // formatter only parses — it does not validate types. Planning catches
    // invalid defaults for both primitives and named types.

    #[test]
    fn test_format_text_option_round_trips() {
        let source = r#"spec test

type status: text
  -> option "active"
  -> option "inactive"

fact s: [status]

rule out: s
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("option \"active\""),
            "text option must be quoted, got: {}",
            formatted
        );
        assert!(
            formatted.contains("option \"inactive\""),
            "text option must be quoted, got: {}",
            formatted
        );
        // Round-trip
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(reparsed.is_ok(), "formatted output should re-parse");
    }

    #[test]
    fn test_format_help_round_trips() {
        let source = r#"spec test
fact quantity: [number -> help "Number of items to order"]
rule total: quantity
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("help \"Number of items to order\""),
            "help must be quoted, got: {}",
            formatted
        );
        // Round-trip
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(reparsed.is_ok(), "formatted output should re-parse");
    }

    #[test]
    fn test_format_scale_type_def_round_trips() {
        let source = r#"spec test

type money: scale
  -> unit eur 1.00
  -> unit usd 1.10
  -> decimals 2
  -> minimum 0

fact price: [money]

rule total: price
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("unit eur 1.00"),
            "scale unit should not be quoted, got: {}",
            formatted
        );
        // Round-trip
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(
            reparsed.is_ok(),
            "formatted output should re-parse, got: {:?}",
            reparsed
        );
    }

    /// Formatting already-formatted output must not change it.
    #[test]
    fn test_format_expression_display_stable_round_trip() {
        let source = r#"spec test
fact a: 1.00
rule r: a + 2.00 * 3
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        let again = format_source(&formatted, "test.lemma").unwrap();
        assert_eq!(
            formatted, again,
            "AST Display-based format must be idempotent under parse/format"
        );
    }
}