Skip to main content

prosaic_core/
refine_score.rs

1//! Composite scorer for the retrospective refine pass.
2//!
3//! Computes a single weighted-sum quality score over a [`RenderedDocument`]
4//! for a given [`RefineWeights`] and optional [`StyleProfile`]. The
5//! scorer is a pure function — no mutation, no side effects. Higher
6//! scores are better; the iteration controller compares candidate scores
7//! to decide whether each refinement iteration is improving the output.
8//!
9//! Each component lands in `[0.0, 1.0]` so the weighted sum stays within
10//! a predictable range. Components:
11//!
//! - **Repetition compliance** — 1 - mean Jaccard similarity between the
//!   word sets of adjacent sentences. Higher = less lexical overlap
//!   between neighbouring sentences.
//! - **Rhythm compliance** — standard deviation of sentence word counts,
//!   divided by 6 and clamped to `[0, 1]`. Higher = more sentence-length
//!   variance (0 = perfectly flat cadence).
16//! - **Connective family balance** — 1 - dominant-family share. Higher
17//!   = no single family dominates document-scope emissions.
18//! - **Paragraph opener diversity** — distinct openers / total
19//!   paragraphs that opened with a connective. Higher = more variety.
20//! - **List-style diversity** — distinct styles / total styles emitted.
21//! - **RST relation balance** — 1 - dominant-relation share.
22//! - **Profile match** — 1 - L1 distance between observed and profile
23//!   target length distribution; gated on profile presence.
24
25#[cfg(not(feature = "std"))]
26use alloc::string::String;
27#[cfg(not(feature = "std"))]
28use alloc::vec::Vec;
29
30use crate::discourse::ListStyle;
31use crate::refine::{RefineWeights, RenderedDocument};
32use crate::rst::RstRelation;
33use crate::style::StyleProfile;
34
35/// Compute the composite score for `document` under `weights` and
36/// `profile`. Returns a value in `[0.0, sum_of_weights]`. Higher is
37/// better.
38pub fn score_document(
39    document: &RenderedDocument,
40    weights: &RefineWeights,
41    profile: Option<&StyleProfile>,
42) -> f32 {
43    weights.repetition * repetition_compliance(document)
44        + weights.rhythm * rhythm_compliance(document)
45        + weights.connective * connective_family_balance(document)
46        + weights.paragraph_opener * paragraph_opener_diversity(document)
47        + weights.list_style_diversity * list_style_diversity(document)
48        + weights.rst_balance * rst_relation_balance(document)
49        + weights.profile_match * profile_match(document, profile)
50}
51
52fn repetition_compliance(document: &RenderedDocument) -> f32 {
53    if document.sentences.len() < 2 {
54        return 1.0;
55    }
56    // Approximation: 1 minus the average pairwise Jaccard similarity over
57    // adjacent sentences. Adjacent-pair similarity is what discourse
58    // repetition perceives most strongly.
59    let mut total_sim = 0.0_f32;
60    let mut pairs = 0_usize;
61    for window in document.sentences.windows(2) {
62        let a = tokenize(&window[0].text);
63        let b = tokenize(&window[1].text);
64        if a.is_empty() || b.is_empty() {
65            continue;
66        }
67        let intersection: usize = a.iter().filter(|w| b.contains(w)).count();
68        let union: usize = a
69            .iter()
70            .chain(b.iter())
71            .collect::<alloc::collections::BTreeSet<_>>()
72            .len();
73        if union > 0 {
74            total_sim += intersection as f32 / union as f32;
75            pairs += 1;
76        }
77    }
78    if pairs == 0 {
79        return 1.0;
80    }
81    1.0 - (total_sim / pairs as f32).clamp(0.0, 1.0)
82}
83
84fn rhythm_compliance(document: &RenderedDocument) -> f32 {
85    if document.sentences.len() < 3 {
86        return 1.0;
87    }
88    let lengths: Vec<f32> = document
89        .sentences
90        .iter()
91        .map(|s| s.word_count as f32)
92        .collect();
93    let n = lengths.len() as f32;
94    let mean = lengths.iter().sum::<f32>() / n;
95    let variance = lengths
96        .iter()
97        .map(|x| {
98            let d = x - mean;
99            d * d
100        })
101        .sum::<f32>()
102        / n;
103    let stdev = approx_sqrt(variance);
104    // Normalize: stdev of 0 → score 0 (perfectly flat); stdev ≥ 6 → score 1.
105    (stdev / 6.0_f32).clamp(0.0, 1.0)
106}
107
/// Newton-Raphson `sqrt` approximation. Used in place of `f32::sqrt` to
/// keep the refine module no_std-compatible (the std `sqrt` impl isn't
/// available in `core` on stable).
///
/// Returns `0.0` for `x <= 0.0`, and `x` itself for NaN or `+inf`. For
/// finite positive `x` the result matches `f32::sqrt` to within a few
/// ulps across the whole `f32` range.
fn approx_sqrt(x: f32) -> f32 {
    if x <= 0.0 {
        return 0.0;
    }
    if !x.is_finite() {
        // NaN stays NaN; +inf is its own square root.
        return x;
    }
    // Seed by halving the biased exponent in the bit pattern. That lands
    // within a few percent of the true root for any magnitude, including
    // subnormals. Seeding with `x` itself needed dozens of iterations for
    // large inputs.
    let mut g = f32::from_bits((x.to_bits() >> 1) + 0x1fbd_1df5);
    // Newton converges quadratically from such a seed. Six steps are ample
    // even for the poorer seeds produced by subnormal inputs.
    for _ in 0..6 {
        g = 0.5 * (g + x / g);
    }
    g
}
121
122fn connective_family_balance(document: &RenderedDocument) -> f32 {
123    if document.connectives_used.is_empty() {
124        return 1.0;
125    }
126    let total = document.connectives_used.len() as f32;
127    let mut count = alloc::collections::BTreeMap::<&'static str, usize>::new();
128    for u in &document.connectives_used {
129        if let Some(family) = family_for(&u.connective) {
130            *count.entry(family).or_insert(0) += 1;
131        }
132    }
133    if count.is_empty() {
134        return 1.0;
135    }
136    let dominant = count.values().copied().max().unwrap_or(0) as f32;
137    (1.0 - dominant / total).clamp(0.0, 1.0)
138}
139
140fn paragraph_opener_diversity(document: &RenderedDocument) -> f32 {
141    let openers: Vec<&String> = document
142        .paragraphs
143        .iter()
144        .filter_map(|p| {
145            p.sentences
146                .first()
147                .and_then(|s| s.opening_connective.as_ref())
148        })
149        .collect();
150    if openers.is_empty() {
151        return 1.0;
152    }
153    let distinct: alloc::collections::BTreeSet<&String> = openers.iter().copied().collect();
154    (distinct.len() as f32 / openers.len() as f32).clamp(0.0, 1.0)
155}
156
157fn list_style_diversity(document: &RenderedDocument) -> f32 {
158    if document.list_styles_used.is_empty() {
159        return 1.0;
160    }
161    let distinct: alloc::collections::BTreeSet<ListStyle> = document
162        .list_styles_used
163        .iter()
164        .map(|u| u.list_style)
165        .collect();
166    (distinct.len() as f32 / document.list_styles_used.len() as f32).clamp(0.0, 1.0)
167}
168
169fn rst_relation_balance(document: &RenderedDocument) -> f32 {
170    if document.connectives_used.is_empty() {
171        return 1.0;
172    }
173    let mut count = alloc::collections::BTreeMap::<RstRelation, usize>::new();
174    let mut classified_total = 0_usize;
175    for u in &document.connectives_used {
176        if let Some(rst) = rst_for(&u.connective) {
177            *count.entry(rst).or_insert(0) += 1;
178            classified_total += 1;
179        }
180    }
181    if classified_total == 0 {
182        return 1.0;
183    }
184    let dominant = count.values().copied().max().unwrap_or(0) as f32;
185    (1.0 - dominant / classified_total as f32).clamp(0.0, 1.0)
186}
187
188fn profile_match(document: &RenderedDocument, profile: Option<&StyleProfile>) -> f32 {
189    let Some(profile) = profile else {
190        return 1.0;
191    };
192    if profile.sentence_length.is_neutral() || document.sentences.is_empty() {
193        return 1.0;
194    }
195    let dist = &profile.sentence_length;
196    let mut counts = [0_usize; 3];
197    for sentence in &document.sentences {
198        let bucket = if sentence.word_count <= dist.short_max_words as usize {
199            0
200        } else if sentence.word_count <= dist.medium_max_words as usize {
201            1
202        } else {
203            2
204        };
205        counts[bucket] += 1;
206    }
207    let total = document.sentences.len() as f32;
208    let observed = [
209        counts[0] as f32 / total,
210        counts[1] as f32 / total,
211        counts[2] as f32 / total,
212    ];
213    let target_sum = dist.short + dist.medium + dist.long;
214    if target_sum <= 0.0 {
215        return 1.0;
216    }
217    let target = [
218        dist.short / target_sum,
219        dist.medium / target_sum,
220        dist.long / target_sum,
221    ];
222    let l1 = (observed[0] - target[0]).abs()
223        + (observed[1] - target[1]).abs()
224        + (observed[2] - target[2]).abs();
225    // L1 distance ranges 0..=2 for normalized distributions.
226    (1.0 - l1 / 2.0).clamp(0.0, 1.0)
227}
228
/// Split `text` into lowercase alphanumeric word tokens.
///
/// Each whitespace-separated word loses its non-alphanumeric characters
/// and is lowercased. Tokens of two characters or fewer (e.g. "a", "of",
/// bare punctuation) are discarded. Length is measured in Unicode scalar
/// values, not UTF-8 bytes, so short non-ASCII words such as "né" or
/// "日本" are filtered the same way as short ASCII ones.
fn tokenize(text: &str) -> Vec<String> {
    text.split_whitespace()
        .map(|w| {
            w.chars()
                .filter(|c| c.is_alphanumeric())
                .flat_map(char::to_lowercase)
                .collect::<String>()
        })
        // `String::len` counts bytes; count characters instead.
        .filter(|cleaned| cleaned.chars().count() > 2)
        .collect()
}
245
/// Classify `connective` into a connective family by its leading text.
///
/// Families are checked in table order and the first prefix match wins.
/// Returns `None` for connectives that match no known prefix.
fn family_for(connective: &str) -> Option<&'static str> {
    const FAMILIES: [(&str, &[&str]); 3] = [
        ("continuation", &["Additionally,", "Furthermore,", "It also"]),
        ("similarity", &["Similarly,", "Likewise,"]),
        ("contrast", &["Meanwhile,", "However,", "On the other hand,"]),
    ];
    FAMILIES
        .iter()
        .find(|(_, prefixes)| prefixes.iter().any(|p| connective.starts_with(p)))
        .map(|&(family, _)| family)
}
264
265fn rst_for(connective: &str) -> Option<RstRelation> {
266    for c in &["Additionally,", "Furthermore,", "It also"] {
267        if connective.starts_with(c) {
268            return Some(RstRelation::Elaboration);
269        }
270    }
271    for c in &["Similarly,", "Likewise,"] {
272        if connective.starts_with(c) {
273            return Some(RstRelation::Sequence);
274        }
275    }
276    for c in &["Meanwhile,", "However,", "On the other hand,"] {
277        if connective.starts_with(c) {
278            return Some(RstRelation::Contrast);
279        }
280    }
281    None
282}
283
#[cfg(test)]
mod tests {
    // Unit tests for the component scorers and the composite score. Most
    // tests assert the direction of a change (monotonicity) rather than
    // exact values.
    use super::*;
    use crate::refine::{EventMeta, ParagraphRender, RenderedDocument};

    /// Build a [`RenderedDocument`] via `RenderedDocument::from_paragraphs`.
    fn doc_from(paragraphs: Vec<ParagraphRender>) -> RenderedDocument {
        RenderedDocument::from_paragraphs(paragraphs)
    }

    /// A paragraph with `text` and a single event carrying the optional
    /// connective and list style.
    fn one_paragraph(
        text: &str,
        connective: Option<&str>,
        list_style: Option<ListStyle>,
    ) -> ParagraphRender {
        ParagraphRender {
            text: text.to_string(),
            events: vec![EventMeta {
                connective: connective.map(|s| s.to_string()),
                list_style,
            }],
        }
    }

    /// Default refine weights, used by the composite-score tests.
    fn weights() -> RefineWeights {
        RefineWeights::default()
    }

    // ── Pure-function determinism ────────────────────────────────────────

    #[test]
    fn empty_document_scores_at_max() {
        // No sentences = no detected failures. Score should sum to all
        // weights at full value.
        let doc = doc_from(vec![]);
        let s = score_document(&doc, &weights(), None);
        let max = weights().repetition
            + weights().rhythm
            + weights().connective
            + weights().paragraph_opener
            + weights().list_style_diversity
            + weights().rst_balance
            + weights().profile_match;
        assert!((s - max).abs() < 0.001);
    }

    #[test]
    fn score_is_deterministic() {
        let doc = doc_from(vec![
            one_paragraph("First short sentence.", None, None),
            one_paragraph(
                "Additionally, second longer sentence with more words.",
                Some("Additionally,"),
                None,
            ),
        ]);
        let a = score_document(&doc, &weights(), None);
        let b = score_document(&doc, &weights(), None);
        assert_eq!(a, b);
    }

    // ── Monotonicity ─────────────────────────────────────────────────────

    #[test]
    fn rhythm_compliance_higher_with_more_variance() {
        // Every `flat` sentence has the same whitespace word count (one
        // `x…` token plus nine `word`s); only the first token's length
        // varies.
        let flat = doc_from(
            (0..6)
                .map(|i| {
                    one_paragraph(
                        &format!(
                            "{} word word word word word word word word word.",
                            "x".repeat(i + 1)
                        ),
                        None,
                        None,
                    )
                })
                .collect(),
        );
        let varied = doc_from(vec![
            one_paragraph("Short.", None, None),
            one_paragraph("A medium length sentence here for context.", None, None),
            one_paragraph(
                "And a much longer sentence with several clauses extending well beyond average length.",
                None,
                None,
            ),
            one_paragraph("Tiny.", None, None),
            one_paragraph(
                "Another medium length sentence with reasonable word count.",
                None,
                None,
            ),
            one_paragraph(
                "Yet another extended one with more words to really push the variance up.",
                None,
                None,
            ),
        ]);
        assert!(rhythm_compliance(&varied) > rhythm_compliance(&flat));
    }

    #[test]
    fn paragraph_opener_diversity_higher_with_distinct_openers() {
        let monotone = doc_from(
            (0..4)
                .map(|_| {
                    one_paragraph(
                        "Additionally, opener text here.",
                        Some("Additionally,"),
                        None,
                    )
                })
                .collect(),
        );
        let diverse = doc_from(vec![
            one_paragraph("Additionally, opener.", Some("Additionally,"), None),
            one_paragraph("Furthermore, opener.", Some("Furthermore,"), None),
            one_paragraph("However, opener.", Some("However,"), None),
            one_paragraph("Similarly, opener.", Some("Similarly,"), None),
        ]);
        assert!(paragraph_opener_diversity(&diverse) > paragraph_opener_diversity(&monotone));
    }

    #[test]
    fn list_style_diversity_higher_with_distinct_styles() {
        let monotone = doc_from(
            (0..4)
                .map(|_| one_paragraph("Sentence with list.", None, Some(ListStyle::Including)))
                .collect(),
        );
        let diverse = doc_from(vec![
            one_paragraph("Sentence.", None, Some(ListStyle::Including)),
            one_paragraph("Sentence.", None, Some(ListStyle::SuchAs)),
            one_paragraph("Sentence.", None, Some(ListStyle::Dash)),
            one_paragraph("Sentence.", None, Some(ListStyle::Bracketed)),
        ]);
        assert!(list_style_diversity(&diverse) > list_style_diversity(&monotone));
    }

    #[test]
    fn rst_relation_balance_higher_when_balanced() {
        let imbalanced = doc_from(
            (0..5)
                .map(|_| one_paragraph("Additionally, sentence.", Some("Additionally,"), None))
                .collect(),
        );
        let balanced = doc_from(vec![
            one_paragraph("Additionally, sentence.", Some("Additionally,"), None),
            one_paragraph("However, sentence.", Some("However,"), None),
            one_paragraph("Similarly, sentence.", Some("Similarly,"), None),
            one_paragraph("Furthermore, sentence.", Some("Furthermore,"), None),
            one_paragraph("Likewise, sentence.", Some("Likewise,"), None),
        ]);
        assert!(rst_relation_balance(&balanced) > rst_relation_balance(&imbalanced));
    }

    #[test]
    fn profile_match_higher_when_distribution_aligns() {
        let target = crate::style::LengthDistribution {
            short: 1.0,
            medium: 0.0,
            long: 0.0,
            short_max_words: 8,
            medium_max_words: 18,
        };
        let p = StyleProfile::builder("short-only")
            .sentence_length(target)
            .build()
            .unwrap();
        let aligned = doc_from(
            (0..6)
                .map(|_| {
                    one_paragraph("Short text here.", None, None) // 3 words → short
                })
                .collect(),
        );
        let misaligned = doc_from(
            (0..6)
                .map(|_| {
                    one_paragraph(
                        "A long sentence with many many words far above the short threshold count.",
                        None,
                        None,
                    )
                })
                .collect(),
        );
        assert!(profile_match(&aligned, Some(&p)) > profile_match(&misaligned, Some(&p)));
    }

    #[test]
    fn full_score_increases_when_one_dimension_strictly_improves() {
        // Same paragraph count and shape, but `diverse_openers` opens each
        // paragraph with a different connective, which raises
        // paragraph-opener diversity. Other components (connective-family
        // balance, RST-relation balance, repetition) can also differ
        // between the two documents, so this only asserts that the full
        // score rises, not by how much.
        let mono_openers = doc_from(vec![
            one_paragraph("Additionally, foo.", Some("Additionally,"), None),
            one_paragraph("Additionally, bar.", Some("Additionally,"), None),
            one_paragraph("Additionally, baz.", Some("Additionally,"), None),
            one_paragraph("Additionally, qux.", Some("Additionally,"), None),
        ]);
        let diverse_openers = doc_from(vec![
            one_paragraph("Additionally, foo.", Some("Additionally,"), None),
            one_paragraph("Furthermore, bar.", Some("Furthermore,"), None),
            one_paragraph("However, baz.", Some("However,"), None),
            one_paragraph("Similarly, qux.", Some("Similarly,"), None),
        ]);
        assert!(
            score_document(&diverse_openers, &weights(), None)
                > score_document(&mono_openers, &weights(), None)
        );
    }

    // ── Tokenizer guard ──────────────────────────────────────────────────

    #[test]
    fn tokenize_drops_short_and_punct() {
        // "a," is dropped as too short; trailing punctuation is stripped.
        let toks = tokenize("a, foo bar! the. baz?");
        assert_eq!(
            toks,
            vec![
                "foo".to_string(),
                "bar".to_string(),
                "the".to_string(),
                "baz".to_string()
            ]
        );
    }
}