Skip to main content

rdocx_layout/
style_resolver.rs

1//! Style resolution: cascade styles and generate numbering markers.
2//!
3//! Ports the logic from `crates/rdocx/src/style.rs` since rdocx-layout
4//! depends on rdocx-oxml directly (not rdocx).
5
6use std::collections::{HashMap, HashSet};
7
8use rdocx_oxml::numbering::{CT_Numbering, ST_NumberFormat};
9use rdocx_oxml::properties::{CT_PPr, CT_RPr};
10use rdocx_oxml::styles::{CT_Style, CT_Styles, StyleType};
11
12/// A fully resolved paragraph with merged properties and numbering info.
13#[derive(Debug, Clone)]
14pub struct ResolvedParagraph {
15    /// Merged paragraph properties (style chain + direct formatting).
16    pub ppr: CT_PPr,
17    /// Resolved runs with merged run properties.
18    pub runs: Vec<ResolvedRun>,
19    /// Numbering marker info (if paragraph is part of a list).
20    pub numbering: Option<ResolvedNumbering>,
21}
22
23/// A run with fully resolved properties.
24#[derive(Debug, Clone)]
25pub struct ResolvedRun {
26    /// Merged run properties.
27    pub rpr: CT_RPr,
28    /// Run content items.
29    pub content: Vec<rdocx_oxml::text::RunContent>,
30}
31
32/// Resolved numbering marker for a list paragraph.
33#[derive(Debug, Clone)]
34pub struct ResolvedNumbering {
35    /// The text of the marker (e.g., "1.", "a)", bullet char).
36    pub marker_text: String,
37    /// Run properties for the marker.
38    pub marker_rpr: CT_RPr,
39}
40
/// Tracks numbering counters across paragraphs.
///
/// Each `(numId, ilvl)` pair has its own counter. A level that has never been
/// advanced, or that was reset because a shallower level of the same list
/// advanced, reports 0 from [`NumberingState::current`].
#[derive(Debug, Clone)]
pub struct NumberingState {
    /// (numId, ilvl) → current count
    counters: HashMap<(u32, u32), u32>,
}

impl Default for NumberingState {
    fn default() -> Self {
        Self::new()
    }
}

impl NumberingState {
    /// Deepest list level tracked (levels 0..=8, matching the `%1`–`%9`
    /// placeholders of a level text template).
    const MAX_LEVEL: u32 = 8;

    /// Create a state in which no list has started counting yet.
    pub fn new() -> Self {
        NumberingState {
            counters: HashMap::new(),
        }
    }

    /// Advance the counter for the given numId/ilvl and return the new value.
    ///
    /// The first advance of a level, or the first after it was reset, returns
    /// `start`. Each later advance returns the previous value plus one,
    /// saturating at `u32::MAX`. Every deeper level of the same `num_id` is
    /// reset, so it restarts from its own start value the next time it is
    /// advanced.
    pub fn advance(&mut self, num_id: u32, ilvl: u32, start: u32) -> u32 {
        // Insert `start` directly rather than seeding with `start - 1` and
        // incrementing. `start` may be 0, and `0 - 1` overflows a u32
        // (a panic in debug builds).
        let value = *self
            .counters
            .entry((num_id, ilvl))
            .and_modify(|count| *count = count.saturating_add(1))
            .or_insert(start);

        // Reset deeper levels. `saturating_add` keeps `ilvl + 1` from
        // overflowing for out-of-range levels (the range is then empty).
        for deeper in ilvl.saturating_add(1)..=Self::MAX_LEVEL {
            self.counters.remove(&(num_id, deeper));
        }

        value
    }

    /// Get the current count for a level without advancing it (0 if unset).
    pub fn current(&self, num_id: u32, ilvl: u32) -> u32 {
        self.counters.get(&(num_id, ilvl)).copied().unwrap_or(0)
    }
}
81
82/// Resolve paragraph properties by walking the style inheritance chain.
83pub fn resolve_paragraph_properties(style_id: Option<&str>, styles: &CT_Styles) -> CT_PPr {
84    let mut effective = CT_PPr::default();
85
86    // 1. Start from docDefaults
87    if let Some(ref defaults) = styles.doc_defaults
88        && let Some(ref ppr) = defaults.ppr
89    {
90        effective.merge_from(ppr);
91    }
92
93    // 2. Walk the basedOn chain
94    if let Some(sid) = style_id {
95        let chain = collect_style_chain(sid, styles);
96        // Apply from most-base to most-derived
97        for style in chain.iter().rev() {
98            if let Some(ref ppr) = style.ppr {
99                effective.merge_from(ppr);
100            }
101        }
102    } else {
103        // Apply the default paragraph style
104        if let Some(default_style) = styles.get_default(StyleType::Paragraph)
105            && let Some(ref ppr) = default_style.ppr
106        {
107            effective.merge_from(ppr);
108        }
109    }
110
111    effective
112}
113
114/// Resolve run properties by walking paragraph and character style chains.
115pub fn resolve_run_properties(
116    para_style_id: Option<&str>,
117    run_style_id: Option<&str>,
118    styles: &CT_Styles,
119) -> CT_RPr {
120    let mut effective = CT_RPr::default();
121
122    // 1. docDefaults run properties
123    if let Some(ref defaults) = styles.doc_defaults
124        && let Some(ref rpr) = defaults.rpr
125    {
126        effective.merge_from(rpr);
127    }
128
129    // 2. paragraph style's rpr (following basedOn chain)
130    let para_sid = para_style_id.or_else(|| {
131        styles
132            .get_default(StyleType::Paragraph)
133            .map(|s| s.style_id.as_str())
134    });
135    if let Some(sid) = para_sid {
136        let chain = collect_style_chain(sid, styles);
137        for style in chain.iter().rev() {
138            if let Some(ref rpr) = style.rpr {
139                effective.merge_from(rpr);
140            }
141        }
142    }
143
144    // 3. character style's rpr (following basedOn chain)
145    if let Some(sid) = run_style_id {
146        let chain = collect_style_chain(sid, styles);
147        for style in chain.iter().rev() {
148            if let Some(ref rpr) = style.rpr {
149                effective.merge_from(rpr);
150            }
151        }
152    }
153
154    effective
155}
156
157/// Generate the marker text for a numbered/bulleted list item.
158pub fn generate_marker(
159    num_id: u32,
160    ilvl: u32,
161    numbering: &CT_Numbering,
162    state: &mut NumberingState,
163) -> Option<ResolvedNumbering> {
164    let abs = numbering.get_abstract_num_for(num_id)?;
165    let lvl = abs.levels.iter().find(|l| l.ilvl == ilvl)?;
166
167    let num_fmt = lvl.num_fmt.unwrap_or(ST_NumberFormat::Decimal);
168    let start = lvl.start.unwrap_or(1);
169    let lvl_text = lvl.lvl_text.as_deref().unwrap_or("%1.");
170
171    let marker_text = if num_fmt == ST_NumberFormat::Bullet {
172        lvl_text.to_string()
173    } else {
174        let count = state.advance(num_id, ilvl, start);
175        format_lvl_text(lvl_text, num_id, ilvl, count, numbering, state)
176    };
177
178    let marker_rpr = lvl.rpr.clone().unwrap_or_default();
179
180    Some(ResolvedNumbering {
181        marker_text,
182        marker_rpr,
183    })
184}
185
186/// Format level text by substituting %1, %2, etc. with formatted counters.
187fn format_lvl_text(
188    template: &str,
189    num_id: u32,
190    current_ilvl: u32,
191    current_count: u32,
192    numbering: &CT_Numbering,
193    state: &NumberingState,
194) -> String {
195    let abs = match numbering.get_abstract_num_for(num_id) {
196        Some(a) => a,
197        None => return template.to_string(),
198    };
199
200    let mut result = template.to_string();
201    for lvl_idx in 0..=8u32 {
202        let placeholder = format!("%{}", lvl_idx + 1);
203        if result.contains(&placeholder) {
204            let count = if lvl_idx == current_ilvl {
205                current_count
206            } else {
207                state.current(num_id, lvl_idx)
208            };
209            let fmt = abs
210                .levels
211                .iter()
212                .find(|l| l.ilvl == lvl_idx)
213                .and_then(|l| l.num_fmt)
214                .unwrap_or(ST_NumberFormat::Decimal);
215            let formatted = format_number(count, fmt);
216            result = result.replace(&placeholder, &formatted);
217        }
218    }
219    result
220}
221
222/// Format a number according to ST_NumberFormat.
223fn format_number(n: u32, fmt: ST_NumberFormat) -> String {
224    match fmt {
225        ST_NumberFormat::Decimal => n.to_string(),
226        ST_NumberFormat::UpperRoman => to_roman(n, true),
227        ST_NumberFormat::LowerRoman => to_roman(n, false),
228        ST_NumberFormat::UpperLetter => to_letter(n, true),
229        ST_NumberFormat::LowerLetter => to_letter(n, false),
230        ST_NumberFormat::Ordinal => format!("{n}"),
231        ST_NumberFormat::Bullet | ST_NumberFormat::None => String::new(),
232    }
233}
234
/// Render `n` as a Roman numeral, upper- or lower-case.
///
/// Uses subtractive pairs (CM, CD, XC, XL, IX, IV). Values of 4000 and above
/// simply repeat `M`, and 0 renders as an empty string.
fn to_roman(n: u32, upper: bool) -> String {
    const NUMERALS: [(u32, &str); 13] = [
        (1000, "M"),
        (900, "CM"),
        (500, "D"),
        (400, "CD"),
        (100, "C"),
        (90, "XC"),
        (50, "L"),
        (40, "XL"),
        (10, "X"),
        (9, "IX"),
        (5, "V"),
        (4, "IV"),
        (1, "I"),
    ];

    let mut remaining = n;
    let mut roman = String::new();
    for (value, symbol) in NUMERALS {
        // Greedy: emit this symbol as many times as it fits, keep the rest.
        roman.push_str(&symbol.repeat((remaining / value) as usize));
        remaining %= value;
    }

    if upper {
        roman
    } else {
        roman.to_lowercase()
    }
}
260
/// Render `n` as a single letter: 1 → `a`, 26 → `z`, 27 → `a` again.
///
/// Values past 26 wrap around the alphabet. 0 renders as an empty string.
fn to_letter(n: u32, upper: bool) -> String {
    let Some(zero_based) = n.checked_sub(1) else {
        return String::new();
    };
    let first = if upper { b'A' } else { b'a' };
    // The remainder is always < 26, so the narrowing cast cannot truncate.
    let offset = (zero_based % 26) as u8;
    char::from(first + offset).to_string()
}
269
270/// Collect the style chain from the given style up through basedOn ancestors.
271fn collect_style_chain<'a>(style_id: &str, styles: &'a CT_Styles) -> Vec<&'a CT_Style> {
272    let mut chain = Vec::new();
273    let mut current_id = Some(style_id.to_string());
274    let mut seen = HashSet::new();
275
276    while let Some(ref sid) = current_id {
277        if !seen.insert(sid.clone()) {
278            break; // Prevent cycles
279        }
280        if let Some(style) = styles.get_by_id(sid) {
281            chain.push(style);
282            current_id = style.based_on.clone();
283        } else {
284            break;
285        }
286    }
287
288    chain
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use rdocx_oxml::units::{HalfPoint, Twips};

    /// Built-in default styles plus a `Heading2` paragraph style based on `Heading1`.
    fn test_styles() -> CT_Styles {
        let mut styles = CT_Styles::new_default();
        styles.styles.push(CT_Style {
            style_id: "Heading2".to_string(),
            style_type: StyleType::Paragraph,
            name: Some("heading 2".to_string()),
            based_on: Some("Heading1".to_string()),
            next_style: Some("Normal".to_string()),
            is_default: false,
            ppr: Some(CT_PPr {
                space_before: Some(Twips(40)),
                ..Default::default()
            }),
            rpr: Some(CT_RPr {
                sz: Some(HalfPoint(26)),
                color: Some("2E74B5".to_string()),
                ..Default::default()
            }),
        });
        styles
    }

    /// A minimal paragraph style `id`, based on `based_on`, that sets only
    /// `space_before`.
    fn paragraph_style(id: &str, based_on: &str, space_before: Twips) -> CT_Style {
        CT_Style {
            style_id: id.to_string(),
            style_type: StyleType::Paragraph,
            name: None,
            based_on: Some(based_on.to_string()),
            next_style: None,
            is_default: false,
            ppr: Some(CT_PPr {
                space_before: Some(space_before),
                ..Default::default()
            }),
            rpr: None,
        }
    }

    #[test]
    fn resolve_normal_paragraph() {
        let styles = test_styles();
        let ppr = resolve_paragraph_properties(Some("Normal"), &styles);
        assert_eq!(ppr.space_after, Some(Twips(160)));
    }

    #[test]
    fn resolve_heading1() {
        let styles = test_styles();
        let ppr = resolve_paragraph_properties(Some("Heading1"), &styles);
        assert_eq!(ppr.keep_next, Some(true));
        assert_eq!(ppr.space_before, Some(Twips(240)));
        assert_eq!(ppr.space_after, Some(Twips(0)));
    }

    #[test]
    fn resolve_heading2_inherits_heading1() {
        let styles = test_styles();
        let ppr = resolve_paragraph_properties(Some("Heading2"), &styles);
        assert_eq!(ppr.keep_next, Some(true));
        assert_eq!(ppr.space_before, Some(Twips(40)));
    }

    #[test]
    fn resolve_heading2_rpr() {
        let styles = test_styles();
        let rpr = resolve_run_properties(Some("Heading2"), None, &styles);
        assert_eq!(rpr.font_ascii, Some("Calibri".to_string()));
        assert_eq!(rpr.sz, Some(HalfPoint(26)));
        assert_eq!(rpr.bold, Some(true));
        assert_eq!(rpr.color, Some("2E74B5".to_string()));
    }

    #[test]
    fn style_chain_cycle_terminates() {
        let mut styles = CT_Styles::new_default();
        styles
            .styles
            .push(paragraph_style("CycleA", "CycleB", Twips(100)));
        styles
            .styles
            .push(paragraph_style("CycleB", "CycleA", Twips(200)));
        // Must not loop forever. The requested style is applied last and wins.
        let ppr = resolve_paragraph_properties(Some("CycleA"), &styles);
        assert_eq!(ppr.space_before, Some(Twips(100)));
    }

    #[test]
    fn dangling_based_on_stops_chain() {
        let mut styles = CT_Styles::new_default();
        styles
            .styles
            .push(paragraph_style("Orphan", "DoesNotExist", Twips(120)));
        let ppr = resolve_paragraph_properties(Some("Orphan"), &styles);
        assert_eq!(ppr.space_before, Some(Twips(120)));
    }

    #[test]
    fn numbering_decimal_marker() {
        let mut numbering = CT_Numbering::new();
        let num_id = numbering.add_numbered_list();

        let mut state = NumberingState::new();
        let marker1 = generate_marker(num_id, 0, &numbering, &mut state).unwrap();
        assert_eq!(marker1.marker_text, "1.");
        let marker2 = generate_marker(num_id, 0, &numbering, &mut state).unwrap();
        assert_eq!(marker2.marker_text, "2.");
    }

    #[test]
    fn numbering_bullet_marker() {
        let mut numbering = CT_Numbering::new();
        let num_id = numbering.add_bullet_list();

        let mut state = NumberingState::new();
        let marker = generate_marker(num_id, 0, &numbering, &mut state).unwrap();
        assert_eq!(marker.marker_text, "\u{2022}");
    }

    #[test]
    fn numbering_sub_level_reset() {
        let mut numbering = CT_Numbering::new();
        let num_id = numbering.add_numbered_list();

        let mut state = NumberingState::new();
        // Level 0: 1, 2
        generate_marker(num_id, 0, &numbering, &mut state);
        generate_marker(num_id, 0, &numbering, &mut state);
        // Level 1: a
        let sub = generate_marker(num_id, 1, &numbering, &mut state).unwrap();
        assert_eq!(sub.marker_text, "a.");
        // Back to level 0: 3 — this should reset level 1
        generate_marker(num_id, 0, &numbering, &mut state);
        let sub2 = generate_marker(num_id, 1, &numbering, &mut state).unwrap();
        assert_eq!(sub2.marker_text, "a."); // reset
    }

    #[test]
    fn numbering_unknown_num_id_has_no_marker() {
        let numbering = CT_Numbering::new();
        let mut state = NumberingState::new();
        assert!(generate_marker(42, 0, &numbering, &mut state).is_none());
    }

    #[test]
    fn roman_numeral_formatting() {
        assert_eq!(to_roman(1, true), "I");
        assert_eq!(to_roman(4, true), "IV");
        assert_eq!(to_roman(9, true), "IX");
        assert_eq!(to_roman(14, false), "xiv");
    }

    #[test]
    fn roman_numeral_edge_cases() {
        assert_eq!(to_roman(0, true), "");
        assert_eq!(to_roman(1994, true), "MCMXCIV");
        assert_eq!(to_roman(3999, false), "mmmcmxcix");
    }

    #[test]
    fn letter_formatting() {
        assert_eq!(to_letter(1, false), "a");
        assert_eq!(to_letter(26, false), "z");
        assert_eq!(to_letter(27, false), "a"); // wraps
        assert_eq!(to_letter(1, true), "A");
    }

    #[test]
    fn letter_formatting_edge_cases() {
        assert_eq!(to_letter(0, true), "");
        assert_eq!(to_letter(52, true), "Z");
    }
}