// lemma/formatting/mod.rs

//! Lemma source code formatting.
//!
//! Formats parsed documents into canonical Lemma source text.
//! Value and constraint formatting is delegated to [`AsLemmaSource`] (in `parsing::ast`),
//! which emits valid, round-trippable Lemma syntax. The regular `Display` impls on AST
//! types are meant for human-readable output (error messages, evaluation) and are **not**
//! used for values or constraints here; they are only relied on for simple tokens such as
//! operators and fact/rule references, whose `Display` output is already valid Lemma.
//! This module handles layout: alignment, line wrapping, and section ordering.

use crate::parsing::ast::{
    expression_precedence, AsLemmaSource, Expression, ExpressionKind, FactValue, LemmaDoc,
    LemmaFact, LemmaRule, TypeDef,
};
use crate::{parse, LemmaError, ResourceLimits};

/// Soft column limit used when laying out rules.
///
/// Rule expressions and `unless` clauses are wrapped once they would run past this
/// width. Facts and other constructs are left on one line even if they exceed it.
pub const MAX_COLS: usize = 60;

19// =============================================================================
20// Public entry points
21// =============================================================================
22
23/// Format a sequence of parsed documents into canonical Lemma source.
24///
25/// Documents are separated by two blank lines.
26/// The result ends with a single newline.
27#[must_use]
28pub fn format_docs(docs: &[LemmaDoc]) -> String {
29    let mut out = String::new();
30    for (index, doc) in docs.iter().enumerate() {
31        if index > 0 {
32            out.push_str("\n\n");
33        }
34        out.push_str(&format_document(doc, MAX_COLS));
35    }
36    if !out.ends_with('\n') {
37        out.push('\n');
38    }
39    out
40}
41
42/// Parse a source string and format it to canonical Lemma source.
43///
44/// Returns an error if the source does not parse.
45pub fn format_source(source: &str, attribute: &str) -> Result<String, LemmaError> {
46    let limits = ResourceLimits::default();
47    let docs = parse(source, attribute, &limits)?;
48    Ok(format_docs(&docs))
49}
50
51// =============================================================================
52// Document
53// =============================================================================
54
55fn format_document(doc: &LemmaDoc, max_cols: usize) -> String {
56    let mut out = String::new();
57    out.push_str("doc ");
58    out.push_str(&doc.name);
59    out.push('\n');
60
61    if let Some(ref commentary) = doc.commentary {
62        out.push_str("\"\"\"\n");
63        out.push_str(commentary);
64        out.push_str("\n\"\"\"\n");
65    }
66
67    let named_types: Vec<_> = doc
68        .types
69        .iter()
70        .filter(|t| !matches!(t, TypeDef::Inline { .. }))
71        .collect();
72    if !named_types.is_empty() {
73        out.push('\n');
74        for (index, type_def) in named_types.iter().enumerate() {
75            if index > 0 {
76                out.push('\n');
77            }
78            out.push_str(&format!("{}", AsLemmaSource(*type_def)));
79            out.push('\n');
80        }
81    }
82
83    if !doc.facts.is_empty() {
84        format_sorted_facts(&doc.facts, &mut out);
85    }
86
87    if !doc.rules.is_empty() {
88        out.push('\n');
89        for (index, rule) in doc.rules.iter().enumerate() {
90            if index > 0 {
91                out.push('\n');
92            }
93            out.push_str(&format_rule(rule, max_cols));
94        }
95    }
96
97    out
98}
99
100// =============================================================================
101// Type definitions — delegated to AsLemmaSource<TypeDef>
102// =============================================================================
103
104// =============================================================================
105// Facts
106// =============================================================================
107
108/// Format a fact, optionally with the reference name padded to `align_width` characters
109/// for column-aligned `=` signs within a group.
110/// When `align_width` is 0 or less than the reference length, no padding is added.
111fn format_fact(fact: &LemmaFact, align_width: usize) -> String {
112    let ref_str = format!("{}", fact.reference);
113    let padded = if align_width > ref_str.len() {
114        format!("{:width$}", ref_str, width = align_width)
115    } else {
116        ref_str
117    };
118    format!("fact {} = {}", padded, AsLemmaSource(&fact.value))
119}
120
121/// Compute the maximum fact reference width across a slice of facts.
122fn max_ref_width(facts: &[&LemmaFact]) -> usize {
123    facts
124        .iter()
125        .map(|f| format!("{}", f.reference).len())
126        .max()
127        .unwrap_or(0)
128}
129
130/// Group facts into two sections separated by a blank line:
131///
132/// 1. Regular facts (literals, type declarations) — original order, aligned
133/// 2. Document references, each followed by its cross-doc overrides — original order, aligned per sub-group
134fn format_sorted_facts(facts: &[LemmaFact], out: &mut String) {
135    let mut regular: Vec<&LemmaFact> = Vec::new();
136    let mut doc_refs: Vec<&LemmaFact> = Vec::new();
137    let mut overrides: Vec<&LemmaFact> = Vec::new();
138
139    for fact in facts {
140        if !fact.reference.is_local() {
141            overrides.push(fact);
142        } else if matches!(&fact.value, FactValue::DocumentReference(_)) {
143            doc_refs.push(fact);
144        } else {
145            regular.push(fact);
146        }
147    }
148
149    // Helper: emit an aligned group of facts
150    let emit_group = |facts: &[&LemmaFact], out: &mut String| {
151        let width = max_ref_width(facts);
152        for fact in facts {
153            out.push_str(&format_fact(fact, width));
154            out.push('\n');
155        }
156    };
157
158    // Group 1: Regular facts (literals + type declarations), original order, aligned
159    if !regular.is_empty() {
160        out.push('\n');
161        emit_group(&regular, out);
162    }
163
164    // Group 2: Doc references, each followed by its overrides
165    if !doc_refs.is_empty() {
166        out.push('\n');
167        for (i, doc_fact) in doc_refs.iter().enumerate() {
168            if i > 0 {
169                out.push('\n');
170            }
171            let ref_name = &doc_fact.reference.fact;
172            let mut sub_group: Vec<&LemmaFact> = vec![doc_fact];
173            for ovr in &overrides {
174                if ovr.reference.segments.first().map(|s| s.as_str()) == Some(ref_name.as_str()) {
175                    sub_group.push(ovr);
176                }
177            }
178            emit_group(&sub_group, out);
179        }
180    }
181
182    // Any overrides that didn't match a doc ref (shouldn't happen in valid Lemma, but be safe)
183    let matched_prefixes: Vec<&str> = doc_refs.iter().map(|f| f.reference.fact.as_str()).collect();
184    let unmatched: Vec<&LemmaFact> = overrides
185        .iter()
186        .filter(|o| {
187            o.reference
188                .segments
189                .first()
190                .map(|s| !matched_prefixes.contains(&s.as_str()))
191                .unwrap_or(true)
192        })
193        .copied()
194        .collect();
195    if !unmatched.is_empty() {
196        out.push('\n');
197        emit_group(&unmatched, out);
198    }
199}
200
201// =============================================================================
202// Rules
203// =============================================================================
204
205/// Format a rule with optional line wrapping: long unless lines get "then" on
206/// the next line; long expressions break at arithmetic operators.
207fn format_rule(rule: &LemmaRule, max_cols: usize) -> String {
208    let expr_indent = "  ";
209    let body = format_expr_wrapped(&rule.expression, max_cols, expr_indent, 10);
210    let mut out = String::new();
211    out.push_str("rule ");
212    out.push_str(&rule.name);
213    out.push_str(" = ");
214    out.push_str(&body);
215
216    for unless_clause in &rule.unless_clauses {
217        let condition_str = format_expr_wrapped(&unless_clause.condition, max_cols, "    ", 10);
218        let result_str = format_expr_wrapped(&unless_clause.result, max_cols, "    ", 10);
219        let line = format!("  unless {} then {}", condition_str, result_str);
220        if line.len() > max_cols {
221            out.push_str("\n  unless ");
222            out.push_str(&condition_str);
223            out.push_str("\n    then ");
224            out.push_str(&result_str);
225        } else {
226            out.push_str("\n  unless ");
227            out.push_str(&condition_str);
228            out.push_str(" then ");
229            out.push_str(&result_str);
230        }
231    }
232    out.push('\n');
233    out
234}
235
236// =============================================================================
237// Expressions — produce valid Lemma source with precedence-based parens
238// =============================================================================
239
240/// Format an expression as valid Lemma source (flat, no wrapping).
241///
242/// Uses `AsLemmaSource<Value>` for literals and the AST types' `Display` impls
243/// for operators, rule references, etc. (those `Display` impls already emit
244/// valid Lemma syntax for these simple tokens).
245fn format_expr(expr: &Expression, parent_prec: u8) -> String {
246    let my_prec = expression_precedence(&expr.kind);
247
248    let needs_parens = parent_prec < 10 && my_prec < parent_prec;
249
250    let inner = match &expr.kind {
251        ExpressionKind::Literal(lit) => format!("{}", AsLemmaSource(lit)),
252        ExpressionKind::FactReference(r) => format!("{}", r),
253        ExpressionKind::RuleReference(rr) => format!("{}", rr),
254        ExpressionKind::UnresolvedUnitLiteral(..) => {
255            // Expression::Display already normalizes the decimal.
256            format!("{}", expr)
257        }
258        ExpressionKind::Arithmetic(left, op, right) => {
259            let left_str = format_expr(left, my_prec);
260            let right_str = format_expr(right, my_prec);
261            format!("{} {} {}", left_str, op.symbol(), right_str)
262        }
263        ExpressionKind::Comparison(left, op, right) => {
264            let left_str = format_expr(left, my_prec);
265            let right_str = format_expr(right, my_prec);
266            format!("{} {} {}", left_str, op.symbol(), right_str)
267        }
268        ExpressionKind::UnitConversion(value, target) => {
269            let value_str = format_expr(value, my_prec);
270            format!("{} in {}", value_str, target)
271        }
272        ExpressionKind::LogicalNegation(inner_expr, _) => {
273            let inner_str = format_expr(inner_expr, my_prec);
274            format!("not {}", inner_str)
275        }
276        ExpressionKind::LogicalAnd(left, right) => {
277            let left_str = format_expr(left, my_prec);
278            let right_str = format_expr(right, my_prec);
279            format!("{} and {}", left_str, right_str)
280        }
281        ExpressionKind::LogicalOr(left, right) => {
282            let left_str = format_expr(left, my_prec);
283            let right_str = format_expr(right, my_prec);
284            format!("{} or {}", left_str, right_str)
285        }
286        ExpressionKind::MathematicalComputation(op, operand) => {
287            let operand_str = format_expr(operand, my_prec);
288            format!("{} {}", op, operand_str)
289        }
290        ExpressionKind::Veto(veto) => match &veto.message {
291            Some(msg) => format!("veto {}", crate::parsing::ast::quote_lemma_text(msg)),
292            None => "veto".to_string(),
293        },
294    };
295
296    if needs_parens {
297        format!("({})", inner)
298    } else {
299        inner
300    }
301}
302
// =============================================================================
// Expression wrapping (soft line breaking at max_cols)
// =============================================================================

/// Prefix every line except the first with `indent`.
///
/// Lines are split with [`str::lines`]. A trailing newline on `s` is kept as a
/// trailing newline on the result. An empty `s` yields an empty string.
fn indent_after_first_line(s: &str, indent: &str) -> String {
    let separator = format!("\n{}", indent);
    let mut indented = s.lines().collect::<Vec<&str>>().join(separator.as_str());
    if s.ends_with('\n') {
        indented.push('\n');
    }
    indented
}

327/// Format an expression with optional wrapping at arithmetic operators when over max_cols.
328/// `parent_prec` is used to add parentheses when needed (pass 10 for top level).
329fn format_expr_wrapped(
330    expr: &Expression,
331    max_cols: usize,
332    indent: &str,
333    parent_prec: u8,
334) -> String {
335    let my_prec = expression_precedence(&expr.kind);
336
337    let wrap_in_parens = |s: String| {
338        if parent_prec < 10 && my_prec < parent_prec {
339            format!("({})", s)
340        } else {
341            s
342        }
343    };
344
345    match &expr.kind {
346        ExpressionKind::Arithmetic(left, op, right) => {
347            let left_str = format_expr_wrapped(left.as_ref(), max_cols, indent, my_prec);
348            let right_str = format_expr_wrapped(right.as_ref(), max_cols, indent, my_prec);
349            let op_symbol = op.symbol();
350            let single_line = format!("{} {} {}", left_str, op_symbol, right_str);
351            if single_line.len() <= max_cols && !single_line.contains('\n') {
352                return wrap_in_parens(single_line);
353            }
354            let continued_right = indent_after_first_line(&right_str, indent);
355            let continuation = format!("{}{} {}", indent, op_symbol, continued_right);
356            let multi_line = format!("{}\n{}", left_str, continuation);
357            wrap_in_parens(multi_line)
358        }
359        _ => {
360            let s = format_expr(expr, parent_prec);
361            wrap_in_parens(s)
362        }
363    }
364}
365
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ast::{
        AsLemmaSource, BooleanValue, DateTimeValue, DurationUnit, TimeValue, TimezoneValue, Value,
    };
    use rust_decimal::prelude::FromStr;
    use rust_decimal::Decimal;

    /// Helper: format a Value as canonical Lemma source via AsLemmaSource.
    fn fmt_value(v: &Value) -> String {
        format!("{}", AsLemmaSource(v))
    }

    #[test]
    fn test_format_value_text_is_quoted() {
        let v = Value::Text("light".to_string());
        assert_eq!(fmt_value(&v), "\"light\"");
    }

    #[test]
    fn test_format_value_text_escapes_quotes() {
        let v = Value::Text("say \"hello\"".to_string());
        assert_eq!(fmt_value(&v), "\"say \\\"hello\\\"\"");
    }

    #[test]
    fn test_format_value_number() {
        let v = Value::Number(Decimal::from_str("42.50").unwrap());
        assert_eq!(fmt_value(&v), "42.5");
    }

    #[test]
    fn test_format_value_number_integer() {
        let v = Value::Number(Decimal::from_str("100.00").unwrap());
        assert_eq!(fmt_value(&v), "100");
    }

    #[test]
    fn test_format_value_boolean() {
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::True)), "true");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Yes)), "yes");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::No)), "no");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Accept)), "accept");
        assert_eq!(fmt_value(&Value::Boolean(BooleanValue::Reject)), "reject");
    }

    #[test]
    fn test_format_value_scale() {
        let v = Value::Scale(Decimal::from_str("99.50").unwrap(), "eur".to_string());
        assert_eq!(fmt_value(&v), "99.5 eur");
    }

    #[test]
    fn test_format_value_duration() {
        let v = Value::Duration(Decimal::from(40), DurationUnit::Hour);
        assert_eq!(fmt_value(&v), "40 hours");
    }

    #[test]
    fn test_format_value_ratio_percent() {
        let v = Value::Ratio(
            Decimal::from_str("0.10").unwrap(),
            Some("percent".to_string()),
        );
        assert_eq!(fmt_value(&v), "10%");
    }

    #[test]
    fn test_format_value_ratio_permille() {
        let v = Value::Ratio(
            Decimal::from_str("0.005").unwrap(),
            Some("permille".to_string()),
        );
        assert_eq!(fmt_value(&v), "5%%");
    }

    #[test]
    fn test_format_value_ratio_bare() {
        let v = Value::Ratio(Decimal::from_str("0.25").unwrap(), None);
        assert_eq!(fmt_value(&v), "0.25");
    }

    #[test]
    fn test_format_value_date_only() {
        let v = Value::Date(DateTimeValue {
            year: 2024,
            month: 1,
            day: 15,
            hour: 0,
            minute: 0,
            second: 0,
            timezone: None,
        });
        assert_eq!(fmt_value(&v), "2024-01-15");
    }

    #[test]
    fn test_format_value_datetime_with_tz() {
        let v = Value::Date(DateTimeValue {
            year: 2024,
            month: 1,
            day: 15,
            hour: 14,
            minute: 30,
            second: 0,
            timezone: Some(TimezoneValue {
                offset_hours: 0,
                offset_minutes: 0,
            }),
        });
        assert_eq!(fmt_value(&v), "2024-01-15T14:30:00Z");
    }

    #[test]
    fn test_format_value_time() {
        let v = Value::Time(TimeValue {
            hour: 14,
            minute: 30,
            second: 45,
            timezone: None,
        });
        assert_eq!(fmt_value(&v), "14:30:45");
    }

    #[test]
    fn test_format_source_round_trips_text() {
        let source = r#"doc test

fact name = "Alice"

rule greeting = "hello"
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(formatted.contains("\"Alice\""), "fact text must be quoted");
        assert!(formatted.contains("\"hello\""), "rule text must be quoted");
    }

    #[test]
    fn test_format_source_preserves_percent() {
        let source = r#"doc test

fact rate = 10 percent

rule tax = rate * 21%
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("10%"),
            "fact percent must use shorthand %, got: {}",
            formatted
        );
    }

    #[test]
    fn test_format_groups_facts_preserving_order() {
        // The facts are deliberately mixed. The formatter keeps all regular facts
        // together, in original order, aligned.
        let source = r#"doc test

fact income = [number -> minimum 0]
fact filing_status = [filing_status_type -> default "single"]
fact country = "NL"
fact deductions = [number -> minimum 0]
fact name = [text]

rule total = income
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        let fact_section = formatted
            .split("rule total")
            .next()
            .unwrap()
            .split("doc test\n")
            .nth(1)
            .unwrap();
        let lines: Vec<&str> = fact_section.lines().filter(|l| !l.is_empty()).collect();
        // All regular facts in one group, original order, aligned
        assert_eq!(lines[0], "fact income        = [number -> minimum 0]");
        assert_eq!(
            lines[1],
            "fact filing_status = [filing_status_type -> default \"single\"]"
        );
        assert_eq!(lines[2], "fact country       = \"NL\"");
        assert_eq!(lines[3], "fact deductions    = [number -> minimum 0]");
        assert_eq!(lines[4], "fact name          = [text]");
    }

    #[test]
    fn test_format_groups_doc_refs_with_overrides() {
        let source = r#"doc test

fact retail.quantity = 5
fact wholesale = doc order/wholesale
fact retail = doc order/retail
fact wholesale.quantity = 100
fact base_price = 50

rule total = base_price
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        let fact_section = formatted
            .split("rule total")
            .next()
            .unwrap()
            .split("doc test\n")
            .nth(1)
            .unwrap();
        let lines: Vec<&str> = fact_section.lines().filter(|l| !l.is_empty()).collect();
        // Group 1: regular facts
        assert_eq!(lines[0], "fact base_price = 50");
        // Group 2: doc refs in original order, each followed by its overrides,
        // aligned per sub-group
        assert_eq!(lines[1], "fact wholesale          = doc order/wholesale");
        assert_eq!(lines[2], "fact wholesale.quantity = 100");
        assert_eq!(lines[3], "fact retail          = doc order/retail");
        assert_eq!(lines[4], "fact retail.quantity = 5");
    }

    #[test]
    fn test_format_source_weather_clothing_text_quoted() {
        let source = r#"doc weather_clothing

type clothing_style = text
  -> option "light"
  -> option "warm"

fact temperature = [number]

rule clothing_layer = "light"
  unless temperature < 5 then "warm"
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("\"light\""),
            "text in rule must be quoted, got: {}",
            formatted
        );
        assert!(
            formatted.contains("\"warm\""),
            "text in unless must be quoted, got: {}",
            formatted
        );
    }

    // NOTE: Default value type validation (e.g. rejecting "10 $$" as a number
    // default) is tested at the planning level in engine.rs, not here. The
    // formatter only parses — it does not validate types. Planning catches
    // invalid defaults for both primitives and named types.

    #[test]
    fn test_format_text_option_round_trips() {
        let source = r#"doc test

type status = text
  -> option "active"
  -> option "inactive"

fact s = [status]

rule out = s
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("option \"active\""),
            "text option must be quoted, got: {}",
            formatted
        );
        assert!(
            formatted.contains("option \"inactive\""),
            "text option must be quoted, got: {}",
            formatted
        );
        // Round-trip
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(reparsed.is_ok(), "formatted output should re-parse");
    }

    #[test]
    fn test_format_help_round_trips() {
        let source = r#"doc test
fact quantity = [number -> help "Number of items to order"]
rule total = quantity
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("help \"Number of items to order\""),
            "help must be quoted, got: {}",
            formatted
        );
        // Round-trip
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(reparsed.is_ok(), "formatted output should re-parse");
    }

    #[test]
    fn test_format_scale_type_def_round_trips() {
        let source = r#"doc test

type money = scale
  -> unit eur 1.00
  -> unit usd 1.10
  -> decimals 2
  -> minimum 0

fact price = [money]

rule total = price
"#;
        let formatted = format_source(source, "test.lemma").unwrap();
        assert!(
            formatted.contains("unit eur 1.00"),
            "scale unit should not be quoted, got: {}",
            formatted
        );
        // Round-trip
        let reparsed = format_source(&formatted, "test.lemma");
        assert!(
            reparsed.is_ok(),
            "formatted output should re-parse, got: {:?}",
            reparsed
        );
    }
}