Skip to main content

lemma/formatting/
mod.rs

1//! Lemma source code formatting.
2//!
3//! Formats parsed specs into canonical Lemma source text.
4//! Value and constraint formatting is delegated to [`AsLemmaSource`] (in `parsing::ast`),
5//! which emits valid, round-trippable Lemma syntax. The regular `Display` impls on AST
6//! types are for human-readable output (error messages, evaluation); they are **not** used
7//! here. This module handles layout: alignment, line wrapping, and section ordering.
8
9use crate::parsing::ast::{
10    expression_precedence, AsLemmaSource, Expression, ExpressionKind, FactValue, LemmaFact,
11    LemmaRule, LemmaSpec, TypeDef,
12};
13use crate::{parse, Error, ResourceLimits};
14
/// Soft line length limit, in columns, applied by the public entry points.
///
/// Only rule bodies react to it: an `unless` clause that would exceed the limit gets
/// its `then` on the following line, and an over-long arithmetic expression is broken
/// before an operator. Facts and other constructs are never broken, even when they
/// exceed this width.
pub const MAX_COLS: usize = 60;
18
19// =============================================================================
20// Public entry points
21// =============================================================================
22
23/// Format a sequence of parsed specs into canonical Lemma source.
24///
25/// Specs are separated by two blank lines.
26/// The result ends with a single newline.
27#[must_use]
28pub fn format_specs(specs: &[LemmaSpec]) -> String {
29    let mut out = String::new();
30    for (index, spec) in specs.iter().enumerate() {
31        if index > 0 {
32            out.push_str("\n\n");
33        }
34        out.push_str(&format_spec(spec, MAX_COLS));
35    }
36    if !out.ends_with('\n') {
37        out.push('\n');
38    }
39    out
40}
41
42/// Parse a source string and format it to canonical Lemma source.
43///
44/// Returns an error if the source does not parse.
45pub fn format_source(source: &str, attribute: &str) -> Result<String, Error> {
46    let limits = ResourceLimits::default();
47    let result = parse(source, attribute, &limits)?;
48    Ok(format_specs(&result.specs))
49}
50
51// =============================================================================
52// Spec
53// =============================================================================
54
55fn format_spec(spec: &LemmaSpec, max_cols: usize) -> String {
56    let mut out = String::new();
57    out.push_str("spec ");
58    out.push_str(&spec.name);
59    if let Some(ref af) = spec.effective_from {
60        out.push(' ');
61        out.push_str(&af.to_string());
62    }
63    out.push('\n');
64
65    if let Some(ref commentary) = spec.commentary {
66        out.push_str("\"\"\"\n");
67        out.push_str(commentary);
68        out.push_str("\n\"\"\"\n");
69    }
70
71    for meta in &spec.meta_fields {
72        out.push_str(&format!(
73            "meta {}: {}\n",
74            meta.key,
75            AsLemmaSource(&meta.value)
76        ));
77    }
78
79    let named_types: Vec<_> = spec
80        .types
81        .iter()
82        .filter(|t| !matches!(t, TypeDef::Inline { .. }))
83        .collect();
84    if !named_types.is_empty() {
85        out.push('\n');
86        for (index, type_def) in named_types.iter().enumerate() {
87            if index > 0 {
88                out.push('\n');
89            }
90            out.push_str(&format!("{}", AsLemmaSource(*type_def)));
91            out.push('\n');
92        }
93    }
94
95    if !spec.facts.is_empty() {
96        format_sorted_facts(&spec.facts, &mut out);
97    }
98
99    if !spec.rules.is_empty() {
100        out.push('\n');
101        for (index, rule) in spec.rules.iter().enumerate() {
102            if index > 0 {
103                out.push('\n');
104            }
105            out.push_str(&format_rule(rule, max_cols));
106        }
107    }
108
109    out
110}
111
112// =============================================================================
113// Type definitions — delegated to AsLemmaSource<TypeDef>
114// =============================================================================
115
116// =============================================================================
117// Facts
118// =============================================================================
119
120/// Format a fact, optionally with the reference name padded to `align_width` characters
121/// for column-aligned `=` signs within a group.
122/// When `align_width` is 0 or less than the reference length, no padding is added.
123fn format_fact(fact: &LemmaFact, align_width: usize) -> String {
124    let ref_str = format!("{}", fact.reference);
125    let padding = if align_width > ref_str.len() {
126        " ".repeat(align_width - ref_str.len())
127    } else {
128        String::new()
129    };
130    format!(
131        "fact {}{} : {}",
132        ref_str,
133        padding,
134        AsLemmaSource(&fact.value)
135    )
136}
137
138/// Compute the maximum fact reference width across a slice of facts.
139fn max_ref_width(facts: &[&LemmaFact]) -> usize {
140    facts
141        .iter()
142        .map(|f| format!("{}", f.reference).len())
143        .max()
144        .unwrap_or(0)
145}
146
147/// Group facts into two sections separated by a blank line:
148///
149/// 1. Regular facts (literals, type declarations) — original order, aligned
150/// 2. Spec references, each followed by its cross-spec overrides — original order, aligned per sub-group
151fn format_sorted_facts(facts: &[LemmaFact], out: &mut String) {
152    let mut regular: Vec<&LemmaFact> = Vec::new();
153    let mut spec_refs: Vec<&LemmaFact> = Vec::new();
154    let mut overrides: Vec<&LemmaFact> = Vec::new();
155
156    for fact in facts {
157        if !fact.reference.is_local() {
158            overrides.push(fact);
159        } else if matches!(&fact.value, FactValue::SpecReference(_)) {
160            spec_refs.push(fact);
161        } else {
162            regular.push(fact);
163        }
164    }
165
166    // Helper: emit an aligned group of facts
167    let emit_group = |facts: &[&LemmaFact], out: &mut String| {
168        let width = max_ref_width(facts);
169        for fact in facts {
170            out.push_str(&format_fact(fact, width));
171            out.push('\n');
172        }
173    };
174
175    // Group 1: Regular facts (literals + type declarations), original order, aligned
176    if !regular.is_empty() {
177        out.push('\n');
178        emit_group(&regular, out);
179    }
180
181    // Group 2: Spec references, each followed by its overrides
182    if !spec_refs.is_empty() {
183        out.push('\n');
184        for (i, spec_fact) in spec_refs.iter().enumerate() {
185            if i > 0 {
186                out.push('\n');
187            }
188            let ref_name = &spec_fact.reference.name;
189            let mut sub_group: Vec<&LemmaFact> = vec![spec_fact];
190            for ovr in &overrides {
191                if ovr.reference.segments.first().map(|s| s.as_str()) == Some(ref_name.as_str()) {
192                    sub_group.push(ovr);
193                }
194            }
195            emit_group(&sub_group, out);
196        }
197    }
198
199    // Any overrides that didn't match a spec ref (shouldn't happen in valid Lemma, but be safe)
200    let matched_prefixes: Vec<&str> = spec_refs
201        .iter()
202        .map(|f| f.reference.name.as_str())
203        .collect();
204    let unmatched: Vec<&LemmaFact> = overrides
205        .iter()
206        .filter(|o| {
207            o.reference
208                .segments
209                .first()
210                .map(|s| !matched_prefixes.contains(&s.as_str()))
211                .unwrap_or(true)
212        })
213        .copied()
214        .collect();
215    if !unmatched.is_empty() {
216        out.push('\n');
217        emit_group(&unmatched, out);
218    }
219}
220
221// =============================================================================
222// Rules
223// =============================================================================
224
225/// Format a rule with optional line wrapping: long unless lines get "then" on
226/// the next line; long expressions break at arithmetic operators.
227fn format_rule(rule: &LemmaRule, max_cols: usize) -> String {
228    let expr_indent = "  ";
229    let body = format_expr_wrapped(&rule.expression, max_cols, expr_indent, 10);
230    let mut out = String::new();
231    out.push_str("rule ");
232    out.push_str(&rule.name);
233    out.push_str(":\n");
234    out.push_str(expr_indent);
235    out.push_str(&body);
236
237    for unless_clause in &rule.unless_clauses {
238        let condition_str = format_expr_wrapped(&unless_clause.condition, max_cols, "    ", 10);
239        let result_str = format_expr_wrapped(&unless_clause.result, max_cols, "    ", 10);
240        let line = format!("  unless {} then {}", condition_str, result_str);
241        if line.len() > max_cols {
242            out.push_str("\n  unless ");
243            out.push_str(&condition_str);
244            out.push_str("\n    then ");
245            out.push_str(&result_str);
246        } else {
247            out.push_str("\n  unless ");
248            out.push_str(&condition_str);
249            out.push_str(" then ");
250            out.push_str(&result_str);
251        }
252    }
253    out.push('\n');
254    out
255}
256
257// =============================================================================
258// Expressions — produce valid Lemma source with precedence-based parens
259// =============================================================================
260
261/// Format an expression as valid Lemma source (flat, no wrapping).
262///
263/// Uses `AsLemmaSource<Value>` for literals and the AST types' `Display` impls
264/// for operators, rule references, etc. (those `Display` impls already emit
265/// valid Lemma syntax for these simple tokens).
266fn format_expr(expr: &Expression, parent_prec: u8) -> String {
267    let my_prec = expression_precedence(&expr.kind);
268
269    let needs_parens = parent_prec < 10 && my_prec < parent_prec;
270
271    let inner = match &expr.kind {
272        ExpressionKind::Literal(lit) => format!("{}", AsLemmaSource(lit)),
273        ExpressionKind::Reference(r) => format!("{}", r),
274        ExpressionKind::UnresolvedUnitLiteral(..) => {
275            // Expression::Display already normalizes the decimal.
276            format!("{}", expr)
277        }
278        ExpressionKind::Arithmetic(left, op, right) => {
279            let left_str = format_expr(left, my_prec);
280            let right_str = format_expr(right, my_prec);
281            format!("{} {} {}", left_str, op.symbol(), right_str)
282        }
283        ExpressionKind::Comparison(left, op, right) => {
284            let left_str = format_expr(left, my_prec);
285            let right_str = format_expr(right, my_prec);
286            format!("{} {} {}", left_str, op.symbol(), right_str)
287        }
288        ExpressionKind::UnitConversion(value, target) => {
289            let value_str = format_expr(value, my_prec);
290            format!("{} in {}", value_str, target)
291        }
292        ExpressionKind::LogicalNegation(inner_expr, _) => {
293            let inner_str = format_expr(inner_expr, my_prec);
294            format!("not {}", inner_str)
295        }
296        ExpressionKind::LogicalAnd(left, right) => {
297            let left_str = format_expr(left, my_prec);
298            let right_str = format_expr(right, my_prec);
299            format!("{} and {}", left_str, right_str)
300        }
301        ExpressionKind::MathematicalComputation(op, operand) => {
302            let operand_str = format_expr(operand, my_prec);
303            format!("{} {}", op, operand_str)
304        }
305        ExpressionKind::Veto(veto) => match &veto.message {
306            Some(msg) => format!("veto {}", crate::parsing::ast::quote_lemma_text(msg)),
307            None => "veto".to_string(),
308        },
309        ExpressionKind::Now => "now".to_string(),
310        ExpressionKind::DateRelative(kind, date_expr, tolerance) => {
311            let date_str = format_expr(date_expr, my_prec);
312            match tolerance {
313                Some(tol) => format!("{} {} {}", date_str, kind, format_expr(tol, my_prec)),
314                None => format!("{} {}", date_str, kind),
315            }
316        }
317        ExpressionKind::DateCalendar(kind, unit, date_expr) => {
318            let date_str = format_expr(date_expr, my_prec);
319            format!("{} {} {}", date_str, kind, unit)
320        }
321    };
322
323    if needs_parens {
324        format!("({})", inner)
325    } else {
326        inner
327    }
328}
329
330// =============================================================================
331// Expression wrapping (soft line breaking at max_cols)
332// =============================================================================
333
/// Prefix every line of `s` except the first with `indent`.
///
/// Lines are split with [`str::lines`], so line endings are normalized to `\n`.
/// A trailing newline in `s` is preserved as a single trailing `\n`; an empty
/// input yields an empty string.
fn indent_after_first_line(s: &str, indent: &str) -> String {
    // Joining with "\n<indent>" puts the indent in front of every line but the first.
    let separator = format!("\n{}", indent);
    let mut indented = s.lines().collect::<Vec<&str>>().join(&separator);
    if s.ends_with('\n') {
        indented.push('\n');
    }
    indented
}
353
354/// Format an expression with optional wrapping at arithmetic operators when over max_cols.
355/// `parent_prec` is used to add parentheses when needed (pass 10 for top level).
356fn format_expr_wrapped(
357    expr: &Expression,
358    max_cols: usize,
359    indent: &str,
360    parent_prec: u8,
361) -> String {
362    let my_prec = expression_precedence(&expr.kind);
363
364    let wrap_in_parens = |s: String| {
365        if parent_prec < 10 && my_prec < parent_prec {
366            format!("({})", s)
367        } else {
368            s
369        }
370    };
371
372    match &expr.kind {
373        ExpressionKind::Arithmetic(left, op, right) => {
374            let left_str = format_expr_wrapped(left.as_ref(), max_cols, indent, my_prec);
375            let right_str = format_expr_wrapped(right.as_ref(), max_cols, indent, my_prec);
376            let op_symbol = op.symbol();
377            let single_line = format!("{} {} {}", left_str, op_symbol, right_str);
378            if single_line.len() <= max_cols && !single_line.contains('\n') {
379                return wrap_in_parens(single_line);
380            }
381            let continued_right = indent_after_first_line(&right_str, indent);
382            let continuation = format!("{}{} {}", indent, op_symbol, continued_right);
383            let multi_line = format!("{}\n{}", left_str, continuation);
384            wrap_in_parens(multi_line)
385        }
386        _ => {
387            let s = format_expr(expr, parent_prec);
388            wrap_in_parens(s)
389        }
390    }
391}
392
393// =============================================================================
394// Tests
395// =============================================================================
396
// Unit tests for the formatter: value rendering through `AsLemmaSource`, fact
// grouping and alignment, and parse -> format -> re-parse round trips.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ast::{
        AsLemmaSource, BooleanValue, DateTimeValue, DurationUnit, TimeValue, TimezoneValue, Value,
    };
    use rust_decimal::prelude::FromStr;
    use rust_decimal::Decimal;

    /// Helper: format a Value as canonical Lemma source via AsLemmaSource.
    fn fmt_value(v: &Value) -> String {
        format!("{}", AsLemmaSource(v))
    }

    // Text values are emitted wrapped in double quotes.
    #[test]
    fn test_format_value_text_is_quoted() {
        let v = Value::Text("light".to_string());
        assert_eq!(fmt_value(&v), "\"light\"");
    }

    // Embedded double quotes are backslash-escaped.
    #[test]
    fn test_format_value_text_escapes_quotes() {
        let v = Value::Text("say \"hello\"".to_string());
        assert_eq!(fmt_value(&v), "\"say \\\"hello\\\"\"");
    }

    // Trailing zeros after the decimal point are dropped.
    #[test]
    fn test_format_value_number() {
        let v = Value::Number(Decimal::from_str("42.50").unwrap());
        assert_eq!(fmt_value(&v), "42.5");
    }

    // A number with an all-zero fraction is printed as an integer.
    #[test]
    fn test_format_value_number_integer() {
        let v = Value::Number(Decimal::from_str("100.00").unwrap());
        assert_eq!(fmt_value(&v), "100");
    }

    // Each boolean variant keeps its own keyword.
    #[test]
    fn test_format_value_boolean() {
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::True)), "true");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Yes)), "yes");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::No)), "no");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Accept)), "accept");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Reject)), "reject");
    }

    // Scale values print the normalized number followed by the unit.
    #[test]
    fn test_format_value_scale() {
        let v = Value::Scale(Decimal::from_str("99.50").unwrap(), "eur".to_string());
        assert_eq!(fmt_value(&v), "99.5 eur");
    }

    // Durations print the amount followed by the unit name ("hours" here).
    #[test]
    fn test_format_value_duration() {
        let v = Value::Duration(Decimal::from(40), DurationUnit::Hour);
        assert_eq!(fmt_value(&v), "40 hours");
    }

    // A ratio of 0.10 tagged "percent" is rendered with the `%` shorthand.
    #[test]
    fn test_format_value_ratio_percent() {
        let v = Value::Ratio(
            Decimal::from_str("0.10").unwrap(),
            Some("percent".to_string()),
        );
        assert_eq!(fmt_value(&v), "10%");
    }

    // A ratio of 0.005 tagged "permille" is rendered with the `%%` shorthand.
    #[test]
    fn test_format_value_ratio_permille() {
        let v = Value::Ratio(
            Decimal::from_str("0.005").unwrap(),
            Some("permille".to_string()),
        );
        assert_eq!(fmt_value(&v), "5%%");
    }

    // A ratio without a unit is printed as the bare number.
    #[test]
    fn test_format_value_ratio_bare() {
        let v = Value::Ratio(Decimal::from_str("0.25").unwrap(), None);
        assert_eq!(fmt_value(&v), "0.25");
    }

    // A date whose time fields are all zero and that has no timezone prints
    // as a plain date.
    #[test]
    fn test_format_value_date_only() {
        let v = Value::Date(DateTimeValue {
            year: 2024,
            month: 1,
            day: 15,
            hour: 0,
            minute: 0,
            second: 0,
            microsecond: 0,
            timezone: None,
        });
        assert_eq!(fmt_value(&v), "2024-01-15");
    }

    // A date-time with a zero offset prints with a `Z` suffix.
    #[test]
    fn test_format_value_datetime_with_tz() {
        let v = Value::Date(DateTimeValue {
            year: 2024,
            month: 1,
            day: 15,
            hour: 14,
            minute: 30,
            second: 0,
            microsecond: 0,
            timezone: Some(TimezoneValue {
                offset_hours: 0,
                offset_minutes: 0,
            }),
        });
        assert_eq!(fmt_value(&v), "2024-01-15T14:30:00Z");
    }

    // Time values print as HH:MM:SS.
    #[test]
    fn test_format_value_time() {
        let v = Value::Time(TimeValue {
            hour: 14,
            minute: 30,
            second: 45,
            timezone: None,
        });
        assert_eq!(fmt_value(&v), "14:30:45");
    }

    // Text literals stay quoted in both fact values and rule bodies.
    #[test]
    fn test_format_source_round_trips_text() {
        let source = r#"spec test

fact name: "Alice"

rule greeting: "hello"
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(formatted.contains("\"Alice\""), "fact text must be quoted");
        assert!(formatted.contains("\"hello\""), "rule text must be quoted");
    }

    // `10 percent` in a fact is normalized to the `10%` shorthand.
    #[test]
    fn test_format_source_preserves_percent() {
        let source = r#"spec test

fact rate: 10 percent

rule tax: rate * 21%
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("10%"),
            "fact percent must use shorthand %, got: {}",
            formatted
        );
    }

    #[test]
    fn test_format_groups_facts_preserving_order() {
        // Facts are deliberately mixed: the formatter keeps all regular facts together
        // in original order, aligned
        let source = r#"spec test

fact income: [number -> minimum 0]
fact filing_status: [filing_status_type -> default "single"]
fact country: "NL"
fact deductions: [number -> minimum 0]
fact name: [text]

rule total: income
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        // Isolate the text between the spec header and the first rule.
        let fact_section = formatted
            .split("rule total")
            .next()
            .unwrap()
            .split("spec test\n")
            .nth(1)
            .unwrap();
        let lines: Vec<&str> = fact_section.lines().filter(|l| !l.is_empty()).collect();
        // All regular facts in one group, original order, aligned
        assert_eq!(lines[0], "fact income        : [number -> minimum 0]");
        assert_eq!(
            lines[1],
            "fact filing_status : [filing_status_type -> default \"single\"]"
        );
        assert_eq!(lines[2], "fact country       : \"NL\"");
        assert_eq!(lines[3], "fact deductions    : [number -> minimum 0]");
        assert_eq!(lines[4], "fact name          : [text]");
    }

    // Overrides are written in the source before and after the spec references they
    // belong to; the formatter must move each one directly below its spec reference.
    #[test]
    fn test_format_groups_spec_refs_with_overrides() {
        let source = r#"spec test

fact retail.quantity: 5
fact wholesale: spec order/wholesale
fact retail: spec order/retail
fact wholesale.quantity: 100
fact base_price: 50

rule total: base_price
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        // Isolate the text between the spec header and the first rule.
        let fact_section = formatted
            .split("rule total")
            .next()
            .unwrap()
            .split("spec test\n")
            .nth(1)
            .unwrap();
        let lines: Vec<&str> = fact_section.lines().filter(|l| !l.is_empty()).collect();
        // Group 1: Literals
        assert_eq!(lines[0], "fact base_price : 50");
        // Group 2: Spec refs in original order, each with its overrides, aligned
        // per sub-group (note the different padding for wholesale vs. retail)
        assert_eq!(lines[1], "fact wholesale          : spec order/wholesale");
        assert_eq!(lines[2], "fact wholesale.quantity : 100");
        assert_eq!(lines[3], "fact retail          : spec order/retail");
        assert_eq!(lines[4], "fact retail.quantity : 5");
    }

    // Text stays quoted in type options, rule bodies and unless results.
    #[test]
    fn test_format_source_weather_clothing_text_quoted() {
        let source = r#"spec weather_clothing

type clothing_style: text
  -> option "light"
  -> option "warm"

fact temperature: [number]

rule clothing_layer: "light"
  unless temperature < 5 then "warm"
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("\"light\""),
            "text in rule must be quoted, got: {}",
            formatted
        );
        assert!(
            formatted.contains("\"warm\""),
            "text in unless must be quoted, got: {}",
            formatted
        );
    }

    // NOTE: Default value type validation (e.g. rejecting "10 $$" as a number
    // default) is tested at the planning level in engine.rs, not here. The
    // formatter only parses — it does not validate types. Planning catches
    // invalid defaults for both primitives and named types.

    // Text options of a named type stay quoted and the output parses again.
    #[test]
    fn test_format_text_option_round_trips() {
        let source = r#"spec test

type status: text
  -> option "active"
  -> option "inactive"

fact s: [status]

rule out: s
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("option \"active\""),
            "text option must be quoted, got: {}",
            formatted
        );
        assert!(
            formatted.contains("option \"inactive\""),
            "text option must be quoted, got: {}",
            formatted
        );
        // Round-trip
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(reparsed.is_ok(), "formatted output should re-parse");
    }

    // `help` text on an inline type stays quoted and the output parses again.
    #[test]
    fn test_format_help_round_trips() {
        let source = r#"spec test
fact quantity: [number -> help "Number of items to order"]
rule total: quantity
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("help \"Number of items to order\""),
            "help must be quoted, got: {}",
            formatted
        );
        // Round-trip
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(reparsed.is_ok(), "formatted output should re-parse");
    }

    // Scale unit definitions keep their unquoted `unit <name> <factor>` form
    // (including the written factor `1.00`) and the output parses again.
    #[test]
    fn test_format_scale_type_def_round_trips() {
        let source = r#"spec test

type money: scale
  -> unit eur 1.00
  -> unit usd 1.10
  -> decimals 2
  -> minimum 0

fact price: [money]

rule total: price
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("unit eur 1.00"),
            "scale unit should not be quoted, got: {}",
            formatted
        );
        // Round-trip
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(
            reparsed.is_ok(),
            "formatted output should re-parse, got: {:?}",
            reparsed
        );
    }
}
721}