1use alloc::sync::Arc;
25
26#[cfg(not(feature = "std"))]
27use alloc::string::{String, ToString};
28#[cfg(not(feature = "std"))]
29use alloc::vec::Vec;
30
31use crate::discourse::{ListStyle, sentence_word_counts};
32use crate::salience::Salience;
33use crate::style::{LengthDistribution, SalienceBias, StyleProfile};
34
/// A fully rendered document together with per-sentence metadata.
///
/// Built by `RenderedDocument::from_paragraphs` and handed to every
/// [`Diagnoser`] during refinement.
#[derive(Debug, Clone, Default, PartialEq)]
#[non_exhaustive]
pub struct RenderedDocument {
    /// Full document text. As built by `from_paragraphs`, this is the
    /// paragraph texts joined with `"\n\n"`.
    pub text: String,
    /// Rendered paragraphs, in document order.
    pub paragraphs: Vec<RenderedParagraph>,
    /// Every sentence of every paragraph, flattened in document order.
    pub sentences: Vec<RenderedSentence>,
    /// Every opening connective recorded for a sentence, in document order.
    pub connectives_used: Vec<UsedConnective>,
    /// Every list style recorded for a sentence, in document order.
    pub list_styles_used: Vec<UsedListStyle>,
}
67
/// One rendered paragraph and the sentences it was split into.
#[derive(Debug, Clone, Default, PartialEq)]
#[non_exhaustive]
pub struct RenderedParagraph {
    /// The paragraph text as produced by the renderer.
    pub text: String,
    /// Sentences of this paragraph, in order.
    pub sentences: Vec<RenderedSentence>,
}
76
/// A single rendered sentence with its position and opening connective.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct RenderedSentence {
    /// Sentence text. `from_paragraphs` trims surrounding whitespace.
    pub text: String,
    /// Number of words in the sentence.
    pub word_count: usize,
    /// Connective the renderer recorded as opening this sentence, if any.
    pub opening_connective: Option<String>,
    /// Zero-based index of the containing paragraph within the document.
    pub paragraph_index: usize,
    /// Zero-based index of this sentence within its paragraph.
    pub sentence_index_in_paragraph: usize,
}
91
/// Records that a connective opened a particular sentence.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct UsedConnective {
    /// The connective text, e.g. `"Additionally,"`.
    pub connective: String,
    /// Zero-based index of the paragraph containing the sentence.
    pub paragraph_index: usize,
    /// Zero-based index of the sentence within that paragraph.
    pub sentence_index_in_paragraph: usize,
}
99
/// Records that a list style was used in a particular sentence.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct UsedListStyle {
    /// The list style the renderer reported for the sentence.
    pub list_style: ListStyle,
    /// Zero-based index of the paragraph containing the sentence.
    pub paragraph_index: usize,
    /// Zero-based index of the sentence within that paragraph.
    pub sentence_index_in_paragraph: usize,
}
107
108impl RenderedDocument {
109 pub(crate) fn from_paragraphs(rendered: Vec<ParagraphRender>) -> Self {
114 let mut paragraphs = Vec::with_capacity(rendered.len());
115 let mut all_sentences: Vec<RenderedSentence> = Vec::new();
116 let mut connectives_used: Vec<UsedConnective> = Vec::new();
117 let mut list_styles_used: Vec<UsedListStyle> = Vec::new();
118
119 for (p_idx, p) in rendered.iter().enumerate() {
120 let mut sentences: Vec<RenderedSentence> = Vec::new();
121 let counts = sentence_word_counts(&p.text);
122 let split_sentences = split_sentences(&p.text);
126 for (s_idx, sentence_text) in split_sentences.iter().enumerate() {
127 let meta = p.events.get(s_idx);
128 let opening_connective = meta.and_then(|m| m.connective.clone());
129 if let Some(c) = &opening_connective {
130 connectives_used.push(UsedConnective {
131 connective: c.clone(),
132 paragraph_index: p_idx,
133 sentence_index_in_paragraph: s_idx,
134 });
135 }
136 if let Some(ls) = meta.and_then(|m| m.list_style) {
137 list_styles_used.push(UsedListStyle {
138 list_style: ls,
139 paragraph_index: p_idx,
140 sentence_index_in_paragraph: s_idx,
141 });
142 }
143 let word_count = counts
144 .get(s_idx)
145 .copied()
146 .unwrap_or_else(|| sentence_text.split_whitespace().count());
147 let s = RenderedSentence {
148 text: sentence_text.clone(),
149 word_count,
150 opening_connective,
151 paragraph_index: p_idx,
152 sentence_index_in_paragraph: s_idx,
153 };
154 sentences.push(s.clone());
155 all_sentences.push(s);
156 }
157 paragraphs.push(RenderedParagraph {
158 text: p.text.clone(),
159 sentences,
160 });
161 }
162
163 let text = paragraphs
164 .iter()
165 .map(|p| p.text.as_str())
166 .collect::<Vec<_>>()
167 .join("\n\n");
168
169 Self {
170 text,
171 paragraphs,
172 sentences: all_sentences,
173 connectives_used,
174 list_styles_used,
175 }
176 }
177}
178
/// Raw per-paragraph renderer output consumed by
/// [`RenderedDocument::from_paragraphs`].
pub(crate) struct ParagraphRender {
    /// Paragraph text. `from_paragraphs` splits it into sentences with
    /// `split_sentences`.
    pub(crate) text: String,
    /// Per-sentence metadata, aligned by position. `events[i]` is read as the
    /// metadata of the `i`-th sentence that `split_sentences` finds in `text`.
    /// Missing trailing entries are treated as "no metadata".
    pub(crate) events: Vec<EventMeta>,
}
186
/// Renderer metadata for one sentence.
///
/// `Default` means "no connective, no list style".
#[derive(Default)]
pub(crate) struct EventMeta {
    /// Connective that opens the sentence, if any.
    pub(crate) connective: Option<String>,
    /// List style used by the sentence, if any.
    pub(crate) list_style: Option<ListStyle>,
}
192
/// Splits `text` into trimmed sentences.
///
/// A sentence boundary is any whitespace character that immediately follows
/// `.`, `!` or `?`. The boundary character is dropped and every piece is
/// trimmed; pieces that are empty after trimming are discarded. Whatever
/// follows the last boundary becomes the final sentence, even without a
/// terminator.
///
/// This is a deliberately simple heuristic:
/// - A terminator followed by a non-whitespace character (`3.14`, `"Hi."`)
///   does not end a sentence.
/// - Abbreviations followed by a space (`e.g. `) do end one.
fn split_sentences(text: &str) -> Vec<String> {
    /// Trims `piece` and keeps it only if something is left.
    fn keep(sentences: &mut Vec<String>, piece: &str) {
        let piece = piece.trim();
        if !piece.is_empty() {
            sentences.push(String::from(piece));
        }
    }

    let mut sentences = Vec::new();
    let mut start = 0;
    let mut after_terminator = false;
    for (idx, ch) in text.char_indices() {
        if after_terminator && ch.is_whitespace() {
            keep(&mut sentences, &text[start..idx]);
            start = idx + ch.len_utf8();
        }
        after_terminator = matches!(ch, '.' | '!' | '?');
    }
    keep(&mut sentences, &text[start..]);
    sentences
}
222
/// Configuration for the optional post-render refinement loop
/// (`run_refine_loop`).
///
/// Start from [`RefineConfig::off`] (the `Default`) or
/// [`RefineConfig::balanced`] and adjust with the `with_*` builders. `Debug`
/// is implemented by hand because [`Diagnoser`] trait objects are not
/// `Debug`.
#[derive(Clone)]
#[non_exhaustive]
pub struct RefineConfig {
    /// Whether refinement runs at all; see [`RefineConfig::is_off`].
    ///
    /// NOTE(review): `run_refine_loop` does not read this flag itself, so the
    /// caller is presumably expected to check it before entering the loop.
    pub enabled: bool,
    /// Upper bound on re-render attempts. Failed renders also count.
    pub max_iterations: u8,
    /// A candidate is rejected when its score exceeds the current best by
    /// less than this amount.
    pub min_improvement: f32,
    /// Component weights forwarded to the document scorer.
    pub weights: RefineWeights,
    /// Diagnosers run, in order, against each rendered document.
    pub diagnosers: Vec<Arc<dyn Diagnoser>>,
}
238
239impl core::fmt::Debug for RefineConfig {
240 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
241 f.debug_struct("RefineConfig")
242 .field("enabled", &self.enabled)
243 .field("max_iterations", &self.max_iterations)
244 .field("min_improvement", &self.min_improvement)
245 .field("weights", &self.weights)
246 .field("diagnosers_count", &self.diagnosers.len())
247 .finish()
248 }
249}
250
251impl RefineConfig {
252 pub fn off() -> Self {
254 Self {
255 enabled: false,
256 max_iterations: 3,
257 min_improvement: 0.01,
258 weights: RefineWeights::default(),
259 diagnosers: Vec::new(),
260 }
261 }
262
263 pub fn balanced() -> Self {
266 Self {
267 enabled: true,
268 max_iterations: 3,
269 min_improvement: 0.01,
270 weights: RefineWeights::default(),
271 diagnosers: crate::refine_diagnosers::default_set(),
272 }
273 }
274
275 pub fn is_off(&self) -> bool {
276 !self.enabled
277 }
278
279 pub fn with_max_iterations(mut self, n: u8) -> Self {
280 self.max_iterations = n;
281 self
282 }
283
284 pub fn with_min_improvement(mut self, m: f32) -> Self {
285 self.min_improvement = m;
286 self
287 }
288
289 pub fn with_weights(mut self, w: RefineWeights) -> Self {
290 self.weights = w;
291 self
292 }
293
294 pub fn with_diagnoser(mut self, d: Arc<dyn Diagnoser>) -> Self {
295 self.diagnosers.push(d);
296 self
297 }
298}
299
300impl Default for RefineConfig {
301 fn default() -> Self {
302 Self::off()
303 }
304}
305
/// Per-component weights for the refinement score.
///
/// The weights are forwarded to the crate's document scorer
/// (`refine_score::score_document`), which defines how each one is applied.
/// Every component defaults to `1.0`.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct RefineWeights {
    /// Weight of the repetition component.
    pub repetition: f32,
    /// Weight of the rhythm component.
    pub rhythm: f32,
    /// Weight of the connective component.
    pub connective: f32,
    /// Weight of the paragraph-opener component.
    pub paragraph_opener: f32,
    /// Weight of the list-style-diversity component.
    pub list_style_diversity: f32,
    /// Weight of the RST-balance component.
    pub rst_balance: f32,
    /// Weight of the style-profile-match component.
    pub profile_match: f32,
}
322
323impl Default for RefineWeights {
324 fn default() -> Self {
325 Self {
326 repetition: 1.0,
327 rhythm: 1.0,
328 connective: 1.0,
329 paragraph_opener: 1.0,
330 list_style_diversity: 1.0,
331 rst_balance: 1.0,
332 profile_match: 1.0,
333 }
334 }
335}
336
/// One finding reported by a [`Diagnoser`].
#[derive(Debug, Clone)]
pub struct Diagnostic {
    /// Identifier of the diagnoser that produced this finding (presumably its
    /// [`Diagnoser::name`]). It is part of the refine loop's convergence
    /// signature.
    pub diagnoser: &'static str,
    /// How severe the problem is. The refine loop compares severities at a
    /// resolution of 1/1000 (truncated to `u32`). No range is enforced here.
    pub severity: f32,
    /// Constraints to apply on the next re-render. Blacklist constraints are
    /// de-duplicated across all diagnostics before being applied.
    pub constraints: Vec<RefineConstraint>,
}
346
/// Inspects a rendered document and reports problems as [`Diagnostic`]s.
///
/// Implementations are stored as `Arc<dyn Diagnoser>` inside
/// [`RefineConfig`]. The `Send + Sync` bound lets that shared configuration
/// cross threads.
pub trait Diagnoser: Send + Sync {
    /// Identifier for this diagnoser.
    fn name(&self) -> &'static str;
    /// Returns zero or more findings for `document`.
    ///
    /// `profile` is the active style profile, if any. When every diagnoser
    /// returns an empty list, the refine loop treats the document as clean
    /// and stops.
    fn diagnose(
        &self,
        document: &RenderedDocument,
        profile: Option<&StyleProfile>,
    ) -> Vec<Diagnostic>;
}
358
/// Result of running the refinement loop.
#[derive(Debug, Clone)]
pub struct RefineOutcome {
    /// Text of the best document found. This is the initial render when no
    /// candidate was accepted.
    pub text: String,
    /// Loop iterations consumed, including failed re-renders.
    pub iterations_run: u8,
    /// Score of the returned document.
    pub final_score: f32,
    /// `true` when no diagnoser reported anything for the returned document.
    pub converged_clean: bool,
}
374
375pub(crate) fn run_refine_loop<F>(
384 config: &RefineConfig,
385 profile: Option<&StyleProfile>,
386 initial: RenderedDocument,
387 initial_session_state: crate::session::Session,
388 session: &mut crate::session::Session,
389 mut render_with_session: F,
390) -> Result<RefineOutcome, crate::error::ProsaicError>
391where
392 F: FnMut(&mut crate::session::Session) -> Result<RenderedDocument, crate::error::ProsaicError>,
393{
394 use crate::refine_score::score_document;
395
396 let mut best = initial;
397 let mut best_score = score_document(&best, &config.weights, profile);
398 let mut best_diagnostics = run_all_diagnosers(&config.diagnosers, &best, profile);
399
400 if best_diagnostics.is_empty() {
401 return Ok(RefineOutcome {
402 text: best.text,
403 iterations_run: 0,
404 final_score: best_score,
405 converged_clean: true,
406 });
407 }
408
409 let mut iter = 0_u8;
410 let mut prev_diag_signature = diagnosis_signature(&best_diagnostics);
411
412 while iter < config.max_iterations {
413 let constraints = aggregate_constraints(&best_diagnostics);
414 if constraints.is_empty() {
415 break;
416 }
417
418 *session = initial_session_state.clone();
420 apply_constraints_to_session(session, &constraints);
421
422 let candidate = match render_with_session(session) {
423 Ok(d) => d,
424 Err(e) => {
425 session.clear_refine_overrides();
428 iter += 1;
429 if iter >= config.max_iterations {
430 return Ok(RefineOutcome {
431 text: best.text,
432 iterations_run: iter,
433 final_score: best_score,
434 converged_clean: false,
435 });
436 }
437 let _ = e; continue;
439 }
440 };
441 session.clear_refine_overrides();
442
443 let candidate_score = score_document(&candidate, &config.weights, profile);
444 let candidate_diagnostics = run_all_diagnosers(&config.diagnosers, &candidate, profile);
445 let candidate_signature = diagnosis_signature(&candidate_diagnostics);
446
447 if candidate_signature == prev_diag_signature && iter > 0 {
449 break;
450 }
451
452 if candidate_score - best_score < config.min_improvement {
454 break;
455 }
456
457 best = candidate;
458 best_score = candidate_score;
459 best_diagnostics = candidate_diagnostics;
460 prev_diag_signature = candidate_signature;
461 iter += 1;
462
463 if best_diagnostics.is_empty() {
464 return Ok(RefineOutcome {
465 text: best.text,
466 iterations_run: iter,
467 final_score: best_score,
468 converged_clean: true,
469 });
470 }
471 }
472
473 Ok(RefineOutcome {
474 text: best.text,
475 iterations_run: iter,
476 final_score: best_score,
477 converged_clean: best_diagnostics.is_empty(),
478 })
479}
480
481fn run_all_diagnosers(
482 diagnosers: &[Arc<dyn Diagnoser>],
483 document: &RenderedDocument,
484 profile: Option<&StyleProfile>,
485) -> Vec<Diagnostic> {
486 let mut out = Vec::new();
487 for d in diagnosers {
488 out.extend(d.diagnose(document, profile));
489 }
490 out
491}
492
493fn aggregate_constraints(diagnostics: &[Diagnostic]) -> Vec<RefineConstraint> {
494 let mut out = Vec::new();
495 for d in diagnostics {
496 for c in &d.constraints {
497 let already = out
500 .iter()
501 .any(|existing: &RefineConstraint| match (existing, c) {
502 (
503 RefineConstraint::BlacklistConnective(a),
504 RefineConstraint::BlacklistConnective(b),
505 ) => a == b,
506 (
507 RefineConstraint::BlacklistListStyle(a),
508 RefineConstraint::BlacklistListStyle(b),
509 ) => a == b,
510 _ => false,
511 });
512 if !already {
513 out.push(c.clone());
514 }
515 }
516 }
517 out
518}
519
520fn apply_constraints_to_session(
521 session: &mut crate::session::Session,
522 constraints: &[RefineConstraint],
523) {
524 let mut blacklist_connectives = Vec::new();
525 let mut blacklist_list_styles = Vec::new();
526 let mut prime_connectives: Vec<String> = Vec::new();
527 let mut prime_list_styles: Vec<ListStyle> = Vec::new();
528 let mut salience_bias_override: Option<SalienceBias> = None;
529 let mut length_distribution_override: Option<LengthDistribution> = None;
530 let mut force_variant_tier: Vec<(String, Salience)> = Vec::new();
531
532 for c in constraints {
533 match c {
534 RefineConstraint::BlacklistConnective(s) => blacklist_connectives.push(s.clone()),
535 RefineConstraint::BlacklistListStyle(s) => blacklist_list_styles.push(*s),
536 RefineConstraint::PrimeRecencyWindow {
537 connectives,
538 list_styles,
539 } => {
540 prime_connectives.extend(connectives.iter().cloned());
541 prime_list_styles.extend(list_styles.iter().copied());
542 }
543 RefineConstraint::OverrideSalienceBias(bias) => {
544 salience_bias_override = Some(*bias);
549 }
550 RefineConstraint::ForceVariantTier { template_key, tier } => {
551 if let Some(existing) = force_variant_tier
555 .iter_mut()
556 .find(|(k, _)| k == template_key)
557 {
558 existing.1 = *tier;
559 } else {
560 force_variant_tier.push((template_key.clone(), *tier));
561 }
562 }
563 RefineConstraint::TightenLengthDistribution(d) => {
564 length_distribution_override = Some(d.clone());
565 }
566 }
567 }
568
569 session.set_refine_blacklists(blacklist_connectives, blacklist_list_styles);
570 session.prime_refine_recency(&prime_connectives, &prime_list_styles);
571 session.set_refine_salience_bias(salience_bias_override);
572 session.set_refine_length_distribution(length_distribution_override);
573 session.set_refine_force_variant_tiers(force_variant_tier);
574}
575
576fn diagnosis_signature(diagnostics: &[Diagnostic]) -> Vec<(&'static str, u32)> {
577 let mut sig: Vec<(&'static str, u32)> = diagnostics
578 .iter()
579 .map(|d| (d.diagnoser, (d.severity * 1000.0) as u32))
580 .collect();
581 sig.sort();
582 sig
583}
584
/// A request, emitted by a [`Diagnoser`], to steer the next re-render.
///
/// The refine loop merges the constraints of all diagnostics and writes them
/// into the session as refine overrides before re-rendering.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RefineConstraint {
    /// Forbid this connective. Duplicates across diagnostics are merged.
    BlacklistConnective(String),
    /// Forbid this list style. Duplicates across diagnostics are merged.
    BlacklistListStyle(ListStyle),
    /// Entries passed to the session's `prime_refine_recency`.
    ///
    /// Presumably this seeds recency history so the entries are treated as
    /// recently used. Lists from several constraints are concatenated.
    PrimeRecencyWindow {
        /// Connectives to prime.
        connectives: Vec<String>,
        /// List styles to prime.
        list_styles: Vec<ListStyle>,
    },
    /// Override the salience bias for the re-render. The last one wins.
    OverrideSalienceBias(SalienceBias),
    /// Force the variant tier used for one template. The last one per key
    /// wins.
    ForceVariantTier {
        /// Template identifier, e.g. `"evt.modified"`.
        template_key: String,
        /// Tier to force for that template.
        tier: Salience,
    },
    /// Override the target sentence-length distribution. The last one wins.
    TightenLengthDistribution(LengthDistribution),
}
612
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn refine_config_off_is_default() {
        let c = RefineConfig::default();
        assert!(c.is_off());
        assert!(!c.enabled);
    }

    #[test]
    fn refine_config_balanced_is_enabled() {
        let c = RefineConfig::balanced();
        assert!(!c.is_off());
        assert_eq!(c.max_iterations, 3);
    }

    #[test]
    fn refine_config_with_max_iterations_overrides_default() {
        let c = RefineConfig::balanced().with_max_iterations(7);
        assert_eq!(c.max_iterations, 7);
    }

    #[test]
    fn refine_config_builders_override_only_their_field() {
        let mut weights = RefineWeights::default();
        weights.rhythm = 2.0;
        let c = RefineConfig::off()
            .with_min_improvement(0.5)
            .with_weights(weights);
        assert!(c.is_off());
        assert_eq!(c.max_iterations, 3);
        assert_eq!(c.min_improvement, 0.5);
        assert_eq!(c.weights.rhythm, 2.0);
        assert_eq!(c.weights.repetition, 1.0);
        assert!(c.diagnosers.is_empty());
    }

    #[test]
    fn weights_default_is_uniform() {
        let w = RefineWeights::default();
        assert_eq!(w.repetition, 1.0);
        assert_eq!(w.profile_match, 1.0);
    }

    #[test]
    fn split_sentences_handles_terminators() {
        let s = split_sentences("First sentence. Second one. Third.");
        assert_eq!(s, vec!["First sentence.", "Second one.", "Third."]);
    }

    #[test]
    fn split_sentences_handles_no_trailing_terminator() {
        let s = split_sentences("First. Trailing");
        assert_eq!(s, vec!["First.", "Trailing"]);
    }

    #[test]
    fn split_sentences_handles_empty() {
        let s = split_sentences("");
        assert!(s.is_empty());
    }

    #[test]
    fn split_sentences_handles_other_terminators_and_whitespace_runs() {
        let s = split_sentences("Really?  Yes!\nDone.");
        assert_eq!(s, vec!["Really?", "Yes!", "Done."]);
    }

    #[test]
    fn split_sentences_keeps_decimal_points_inside_sentences() {
        let s = split_sentences("Pi is 3.14 roughly. Ok.");
        assert_eq!(s, vec!["Pi is 3.14 roughly.", "Ok."]);
    }

    #[test]
    fn apply_constraints_blacklist_connective_writes_session_blacklist() {
        let mut session = crate::session::Session::new();
        let constraints = vec![
            RefineConstraint::BlacklistConnective("Additionally,".to_string()),
            RefineConstraint::BlacklistConnective("Furthermore,".to_string()),
        ];
        super::apply_constraints_to_session(&mut session, &constraints);
        assert_eq!(
            session.refine_blacklist_connectives,
            vec!["Additionally,".to_string(), "Furthermore,".to_string()]
        );
    }

    #[test]
    fn apply_constraints_blacklist_list_style_writes_session_blacklist() {
        let mut session = crate::session::Session::new();
        let constraints = vec![RefineConstraint::BlacklistListStyle(ListStyle::Including)];
        super::apply_constraints_to_session(&mut session, &constraints);
        assert_eq!(
            session.refine_blacklist_list_styles,
            vec![ListStyle::Including]
        );
    }

    #[test]
    fn apply_constraints_prime_recency_pushes_phantom_history() {
        // Priming (applied twice) must not leak into the hard blacklist.
        // NOTE(review): this does not yet assert on the primed recency
        // history itself; add that once the Session field is settled.
        let mut session = crate::session::Session::new();
        let constraints = vec![RefineConstraint::PrimeRecencyWindow {
            connectives: vec!["Additionally,".to_string(), "Furthermore,".to_string()],
            list_styles: vec![ListStyle::Including, ListStyle::Bracketed],
        }];
        super::apply_constraints_to_session(&mut session, &constraints);
        super::apply_constraints_to_session(
            &mut session,
            &[RefineConstraint::PrimeRecencyWindow {
                connectives: vec!["Additionally,".to_string()],
                list_styles: vec![],
            }],
        );
        assert!(session.refine_blacklist_connectives.is_empty());
    }

    #[test]
    fn apply_constraints_override_salience_bias_writes_session_override() {
        let mut session = crate::session::Session::new();
        let constraints = vec![RefineConstraint::OverrideSalienceBias(SalienceBias::Lower)];
        super::apply_constraints_to_session(&mut session, &constraints);
        assert_eq!(session.refine_salience_bias, Some(SalienceBias::Lower));
    }

    #[test]
    fn apply_constraints_override_salience_bias_last_writer_wins() {
        let mut session = crate::session::Session::new();
        let constraints = vec![
            RefineConstraint::OverrideSalienceBias(SalienceBias::Lower),
            RefineConstraint::OverrideSalienceBias(SalienceBias::Higher),
        ];
        super::apply_constraints_to_session(&mut session, &constraints);
        assert_eq!(session.refine_salience_bias, Some(SalienceBias::Higher));
    }

    #[test]
    fn apply_constraints_force_variant_tier_writes_session_map() {
        let mut session = crate::session::Session::new();
        let constraints = vec![
            RefineConstraint::ForceVariantTier {
                template_key: "evt.modified".to_string(),
                tier: Salience::High,
            },
            RefineConstraint::ForceVariantTier {
                template_key: "evt.touched".to_string(),
                tier: Salience::Low,
            },
        ];
        super::apply_constraints_to_session(&mut session, &constraints);
        assert_eq!(
            session.refine_forced_tier_for("evt.modified"),
            Some(Salience::High)
        );
        assert_eq!(
            session.refine_forced_tier_for("evt.touched"),
            Some(Salience::Low)
        );
        assert_eq!(session.refine_forced_tier_for("evt.unset"), None);
    }

    #[test]
    fn apply_constraints_force_variant_tier_replaces_for_same_key() {
        let mut session = crate::session::Session::new();
        let constraints = vec![
            RefineConstraint::ForceVariantTier {
                template_key: "evt.modified".to_string(),
                tier: Salience::High,
            },
            RefineConstraint::ForceVariantTier {
                template_key: "evt.modified".to_string(),
                tier: Salience::Low,
            },
        ];
        super::apply_constraints_to_session(&mut session, &constraints);
        assert_eq!(
            session.refine_forced_tier_for("evt.modified"),
            Some(Salience::Low)
        );
        assert_eq!(session.refine_force_variant_tier.len(), 1);
    }

    #[test]
    fn apply_constraints_tighten_length_distribution_writes_session_override() {
        let mut session = crate::session::Session::new();
        let target = LengthDistribution {
            short: 0.5,
            medium: 0.3,
            long: 0.2,
            short_max_words: 7,
            medium_max_words: 15,
        };
        let constraints = vec![RefineConstraint::TightenLengthDistribution(target.clone())];
        super::apply_constraints_to_session(&mut session, &constraints);
        assert_eq!(session.refine_length_distribution, Some(target));
    }

    #[test]
    fn apply_constraints_clear_then_reapply_resets_override_fields() {
        let mut session = crate::session::Session::new();
        super::apply_constraints_to_session(
            &mut session,
            &[
                RefineConstraint::OverrideSalienceBias(SalienceBias::Lower),
                RefineConstraint::TightenLengthDistribution(LengthDistribution {
                    short: 0.7,
                    medium: 0.2,
                    long: 0.1,
                    short_max_words: 5,
                    medium_max_words: 12,
                }),
                RefineConstraint::ForceVariantTier {
                    template_key: "k".to_string(),
                    tier: Salience::High,
                },
            ],
        );
        assert!(session.refine_salience_bias.is_some());
        assert!(session.refine_length_distribution.is_some());
        assert!(!session.refine_force_variant_tier.is_empty());
        session.clear_refine_overrides();
        assert!(session.refine_salience_bias.is_none());
        assert!(session.refine_length_distribution.is_none());
        assert!(session.refine_force_variant_tier.is_empty());
    }

    #[test]
    fn aggregate_constraints_dedups_blacklists_only() {
        let diagnostics = vec![
            Diagnostic {
                diagnoser: "connective",
                severity: 0.5,
                constraints: vec![
                    RefineConstraint::BlacklistConnective("Additionally,".to_string()),
                    RefineConstraint::BlacklistListStyle(ListStyle::Including),
                    RefineConstraint::OverrideSalienceBias(SalienceBias::Lower),
                ],
            },
            Diagnostic {
                diagnoser: "list_style",
                severity: 0.25,
                constraints: vec![
                    RefineConstraint::BlacklistConnective("Additionally,".to_string()),
                    RefineConstraint::BlacklistConnective("Furthermore,".to_string()),
                    RefineConstraint::BlacklistListStyle(ListStyle::Including),
                    RefineConstraint::OverrideSalienceBias(SalienceBias::Lower),
                ],
            },
        ];
        let merged = super::aggregate_constraints(&diagnostics);
        // 3 from the first diagnostic, plus "Furthermore," and the second
        // (non-deduplicated) salience override from the second.
        assert_eq!(merged.len(), 5);
        assert!(matches!(
            &merged[0],
            RefineConstraint::BlacklistConnective(c) if c == "Additionally,"
        ));
        let overrides = merged
            .iter()
            .filter(|c| matches!(c, RefineConstraint::OverrideSalienceBias(_)))
            .count();
        assert_eq!(overrides, 2);
    }

    #[test]
    fn diagnosis_signature_sorts_and_quantizes() {
        let diagnostics = vec![
            Diagnostic {
                diagnoser: "rhythm",
                severity: 0.25,
                constraints: vec![],
            },
            Diagnostic {
                diagnoser: "repetition",
                severity: 0.5,
                constraints: vec![],
            },
        ];
        assert_eq!(
            super::diagnosis_signature(&diagnostics),
            vec![("repetition", 500_u32), ("rhythm", 250_u32)]
        );
    }

    #[test]
    fn rendered_document_from_paragraphs_aggregates_correctly() {
        let para1 = ParagraphRender {
            text: "Foo was modified. It was renamed.".to_string(),
            events: vec![
                EventMeta {
                    connective: None,
                    list_style: None,
                },
                EventMeta {
                    connective: Some("Additionally,".to_string()),
                    list_style: None,
                },
            ],
        };
        let para2 = ParagraphRender {
            text: "Bar was deleted.".to_string(),
            events: vec![EventMeta::default()],
        };

        let doc = RenderedDocument::from_paragraphs(vec![para1, para2]);
        assert_eq!(doc.paragraphs.len(), 2);
        assert_eq!(doc.sentences.len(), 3);
        assert_eq!(doc.connectives_used.len(), 1);
        assert_eq!(doc.connectives_used[0].connective, "Additionally,");
        assert_eq!(doc.connectives_used[0].paragraph_index, 0);
        assert_eq!(doc.connectives_used[0].sentence_index_in_paragraph, 1);
        assert_eq!(
            doc.text,
            "Foo was modified. It was renamed.\n\nBar was deleted."
        );
    }

    #[test]
    fn rendered_document_from_paragraphs_records_list_styles() {
        let para = ParagraphRender {
            text: "Changes include foo, bar, and baz. Done.".to_string(),
            events: vec![
                EventMeta {
                    connective: None,
                    list_style: Some(ListStyle::Including),
                },
                EventMeta::default(),
            ],
        };
        let doc = RenderedDocument::from_paragraphs(vec![para]);
        assert_eq!(doc.list_styles_used.len(), 1);
        assert_eq!(doc.list_styles_used[0].list_style, ListStyle::Including);
        assert_eq!(doc.list_styles_used[0].paragraph_index, 0);
        assert_eq!(doc.list_styles_used[0].sentence_index_in_paragraph, 0);
        assert!(doc.connectives_used.is_empty());
        assert_eq!(doc.sentences[1].text, "Done.");
    }

    #[test]
    fn rendered_document_from_paragraphs_tolerates_missing_events() {
        let para = ParagraphRender {
            text: "One. Two.".to_string(),
            events: vec![],
        };
        let doc = RenderedDocument::from_paragraphs(vec![para]);
        assert_eq!(doc.sentences.len(), 2);
        assert!(doc.connectives_used.is_empty());
        assert!(doc.list_styles_used.is_empty());
        assert!(doc.sentences.iter().all(|s| s.opening_connective.is_none()));
        assert_eq!(doc.paragraphs[0].sentences[1].sentence_index_in_paragraph, 1);
    }
}