use alloc::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::string::{String, ToString};
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::discourse::{ListStyle, sentence_word_counts};
use crate::salience::Salience;
use crate::style::{LengthDistribution, SalienceBias, StyleProfile};
/// A rendered document plus the per-paragraph and per-sentence metadata
/// collected while assembling it (see `RenderedDocument::from_paragraphs`).
#[derive(Debug, Clone, Default, PartialEq)]
#[non_exhaustive]
pub struct RenderedDocument {
/// Paragraph texts joined with a blank line (`"\n\n"`).
pub text: String,
/// Paragraphs in the order they were rendered.
pub paragraphs: Vec<RenderedParagraph>,
/// Every sentence of every paragraph, flattened in document order.
pub sentences: Vec<RenderedSentence>,
/// One entry for each sentence whose event metadata carried a connective.
pub connectives_used: Vec<UsedConnective>,
/// One entry for each sentence whose event metadata carried a list style.
pub list_styles_used: Vec<UsedListStyle>,
}
/// A single rendered paragraph and the sentences it was split into.
#[derive(Debug, Clone, Default, PartialEq)]
#[non_exhaustive]
pub struct RenderedParagraph {
/// The paragraph text exactly as rendered.
pub text: String,
/// The paragraph's sentences, in order of appearance.
pub sentences: Vec<RenderedSentence>,
}
/// One sentence of a rendered paragraph with its position and metadata.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct RenderedSentence {
/// The sentence text with surrounding whitespace trimmed.
pub text: String,
/// Word count: taken from `sentence_word_counts` at this sentence's index
/// when available, otherwise the number of whitespace-separated tokens.
pub word_count: usize,
/// Connective recorded in the event metadata at this sentence's index, if any.
pub opening_connective: Option<String>,
/// Zero-based index of the containing paragraph within the document.
pub paragraph_index: usize,
/// Zero-based index of this sentence within its paragraph.
pub sentence_index_in_paragraph: usize,
}
/// Records that a connective was attached to a particular sentence.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct UsedConnective {
/// The connective string as stored in the event metadata.
pub connective: String,
/// Zero-based index of the paragraph containing the sentence.
pub paragraph_index: usize,
/// Zero-based index of the sentence within that paragraph.
pub sentence_index_in_paragraph: usize,
}
/// Records that a list style was attached to a particular sentence.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct UsedListStyle {
/// The list style as stored in the event metadata.
pub list_style: ListStyle,
/// Zero-based index of the paragraph containing the sentence.
pub paragraph_index: usize,
/// Zero-based index of the sentence within that paragraph.
pub sentence_index_in_paragraph: usize,
}
impl RenderedDocument {
pub(crate) fn from_paragraphs(rendered: Vec<ParagraphRender>) -> Self {
let mut paragraphs = Vec::with_capacity(rendered.len());
let mut all_sentences: Vec<RenderedSentence> = Vec::new();
let mut connectives_used: Vec<UsedConnective> = Vec::new();
let mut list_styles_used: Vec<UsedListStyle> = Vec::new();
for (p_idx, p) in rendered.iter().enumerate() {
let mut sentences: Vec<RenderedSentence> = Vec::new();
let counts = sentence_word_counts(&p.text);
let split_sentences = split_sentences(&p.text);
for (s_idx, sentence_text) in split_sentences.iter().enumerate() {
let meta = p.events.get(s_idx);
let opening_connective = meta.and_then(|m| m.connective.clone());
if let Some(c) = &opening_connective {
connectives_used.push(UsedConnective {
connective: c.clone(),
paragraph_index: p_idx,
sentence_index_in_paragraph: s_idx,
});
}
if let Some(ls) = meta.and_then(|m| m.list_style) {
list_styles_used.push(UsedListStyle {
list_style: ls,
paragraph_index: p_idx,
sentence_index_in_paragraph: s_idx,
});
}
let word_count = counts
.get(s_idx)
.copied()
.unwrap_or_else(|| sentence_text.split_whitespace().count());
let s = RenderedSentence {
text: sentence_text.clone(),
word_count,
opening_connective,
paragraph_index: p_idx,
sentence_index_in_paragraph: s_idx,
};
sentences.push(s.clone());
all_sentences.push(s);
}
paragraphs.push(RenderedParagraph {
text: p.text.clone(),
sentences,
});
}
let text = paragraphs
.iter()
.map(|p| p.text.as_str())
.collect::<Vec<_>>()
.join("\n\n");
Self {
text,
paragraphs,
sentences: all_sentences,
connectives_used,
list_styles_used,
}
}
}
/// Raw result of rendering one paragraph, consumed by
/// `RenderedDocument::from_paragraphs`.
pub(crate) struct ParagraphRender {
/// The paragraph's rendered text.
pub(crate) text: String,
/// Per-event metadata; `from_paragraphs` looks entries up by sentence index,
/// so entry `i` is treated as the metadata of the i-th split sentence.
pub(crate) events: Vec<EventMeta>,
}
/// Metadata recorded for one rendered event; both fields default to `None`.
#[derive(Default)]
pub(crate) struct EventMeta {
/// Connective attached to the event, if any.
pub(crate) connective: Option<String>,
/// List style attached to the event, if any.
pub(crate) list_style: Option<ListStyle>,
}
/// Splits `text` into trimmed sentences.
///
/// A sentence ends at the first whitespace character that directly follows a
/// run of `.`, `!` or `?` (any other non-whitespace character in between
/// cancels the pending break, so `3.14` or `."` do not split). Whatever
/// remains after the last break is returned as a final sentence when it is
/// not blank. Empty or all-whitespace input yields an empty vector.
fn split_sentences(text: &str) -> Vec<String> {
    let mut sentences: Vec<String> = Vec::new();
    // Byte offset where the sentence currently being scanned begins.
    let mut start = 0usize;
    let mut pending_break = false;

    for (offset, ch) in text.char_indices() {
        match ch {
            '.' | '!' | '?' => pending_break = true,
            space if space.is_whitespace() => {
                if pending_break {
                    // Include the whitespace in the slice; `trim` drops it.
                    let end = offset + space.len_utf8();
                    let piece = text[start..end].trim();
                    if !piece.is_empty() {
                        sentences.push(piece.to_string());
                    }
                    start = end;
                    pending_break = false;
                }
            }
            _ => pending_break = false,
        }
    }

    let tail = text[start..].trim();
    if !tail.is_empty() {
        sentences.push(tail.to_string());
    }
    sentences
}
/// Settings for the refinement loop (`run_refine_loop`).
#[derive(Clone)]
#[non_exhaustive]
pub struct RefineConfig {
/// Whether refinement is switched on; `is_off` returns the negation.
/// NOTE(review): `run_refine_loop` itself never reads this flag, so the
/// caller is presumably expected to check it first — confirm.
pub enabled: bool,
/// Upper bound on the loop's iteration counter.
pub max_iterations: u8,
/// A candidate whose score gain over the current best is below this value
/// ends the loop without being accepted.
pub min_improvement: f32,
/// Weights handed to `score_document` when scoring documents.
pub weights: RefineWeights,
/// Diagnosers run, in order, against each document under consideration.
pub diagnosers: Vec<Arc<dyn Diagnoser>>,
}
impl core::fmt::Debug for RefineConfig {
    /// Formats every scalar field as-is and summarises `diagnosers` by its
    /// length (reported as `diagnosers_count`), since trait objects are not
    /// required to implement `Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            enabled,
            max_iterations,
            min_improvement,
            weights,
            diagnosers,
        } = self;

        let mut out = f.debug_struct("RefineConfig");
        out.field("enabled", enabled);
        out.field("max_iterations", max_iterations);
        out.field("min_improvement", min_improvement);
        out.field("weights", weights);
        out.field("diagnosers_count", &diagnosers.len());
        out.finish()
    }
}
impl RefineConfig {
    /// Preset with refinement disabled and no diagnosers registered.
    ///
    /// The numeric limits (3 iterations, 0.01 minimum improvement) and the
    /// default weights are still populated.
    pub fn off() -> Self {
        let weights = RefineWeights::default();
        Self {
            enabled: false,
            max_iterations: 3,
            min_improvement: 0.01,
            weights,
            diagnosers: Vec::new(),
        }
    }

    /// Preset with refinement enabled and the crate's default diagnoser set;
    /// limits and weights are the same as in [`RefineConfig::off`].
    pub fn balanced() -> Self {
        Self {
            enabled: true,
            diagnosers: crate::refine_diagnosers::default_set(),
            ..Self::off()
        }
    }

    /// Returns `true` when refinement is disabled.
    pub fn is_off(&self) -> bool {
        !self.enabled
    }

    /// Returns the configuration with `max_iterations` replaced by `n`.
    pub fn with_max_iterations(self, n: u8) -> Self {
        Self {
            max_iterations: n,
            ..self
        }
    }

    /// Returns the configuration with `min_improvement` replaced by `m`.
    pub fn with_min_improvement(self, m: f32) -> Self {
        Self {
            min_improvement: m,
            ..self
        }
    }

    /// Returns the configuration with `weights` replaced by `w`.
    pub fn with_weights(self, w: RefineWeights) -> Self {
        Self { weights: w, ..self }
    }

    /// Returns the configuration with `d` appended to the diagnoser list.
    pub fn with_diagnoser(mut self, d: Arc<dyn Diagnoser>) -> Self {
        self.diagnosers.push(d);
        self
    }
}
impl Default for RefineConfig {
fn default() -> Self {
Self::off()
}
}
/// Per-component weights passed to `score_document` by the refinement loop.
///
/// How each weight enters the score is defined by `score_document`, which is
/// not part of this file; the field notes below are inferred from the names
/// only — TODO confirm against the scorer.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct RefineWeights {
/// Presumably the weight of the repetition component.
pub repetition: f32,
/// Presumably the weight of the sentence-rhythm component.
pub rhythm: f32,
/// Presumably the weight of the connective-usage component.
pub connective: f32,
/// Presumably the weight of the paragraph-opener component.
pub paragraph_opener: f32,
/// Presumably the weight of the list-style diversity component.
pub list_style_diversity: f32,
/// Presumably the weight of the RST-balance component.
pub rst_balance: f32,
/// Presumably the weight of the style-profile match component.
pub profile_match: f32,
}
impl Default for RefineWeights {
    /// Uniform weighting: every component weight is `1.0`.
    fn default() -> Self {
        const UNIT: f32 = 1.0;
        Self {
            repetition: UNIT,
            rhythm: UNIT,
            connective: UNIT,
            paragraph_opener: UNIT,
            list_style_diversity: UNIT,
            rst_balance: UNIT,
            profile_match: UNIT,
        }
    }
}
/// One finding reported by a [`Diagnoser`].
#[derive(Debug, Clone)]
pub struct Diagnostic {
/// Static identifier of the reporting diagnoser (presumably its `name()`).
pub diagnoser: &'static str,
/// Severity of the finding. `diagnosis_signature` multiplies it by 1000 and
/// casts to `u32`, so differences below 0.001 are invisible to stall detection.
pub severity: f32,
/// Constraints this finding asks the next render pass to honour; merged
/// across findings by `aggregate_constraints`.
pub constraints: Vec<RefineConstraint>,
}
/// A pluggable check that inspects a rendered document and reports findings.
///
/// Implementors must be `Send + Sync`; `RefineConfig` stores them as
/// `Arc<dyn Diagnoser>`.
pub trait Diagnoser: Send + Sync {
/// Static name identifying this diagnoser.
fn name(&self) -> &'static str;
/// Examines `document`, optionally against `profile`, and returns zero or
/// more findings. The refinement loop treats a document as clean when all
/// diagnosers together return no findings.
fn diagnose(
&self,
document: &RenderedDocument,
profile: Option<&StyleProfile>,
) -> Vec<Diagnostic>;
}
/// Result of `run_refine_loop`.
#[derive(Debug, Clone)]
pub struct RefineOutcome {
/// Text of the best document seen (the initial one if no candidate was accepted).
pub text: String,
/// Value of the loop counter on exit: incremented for every accepted
/// candidate and every failed render, but not for a rejected candidate.
pub iterations_run: u8,
/// Score of the document whose text is returned.
pub final_score: f32,
/// `true` when the diagnosers reported nothing for the returned document.
pub converged_clean: bool,
}
/// Repeatedly re-renders a document under constraints derived from its
/// diagnostics, keeping the best-scoring version.
///
/// Flow:
/// 1. Score `initial` and run every diagnoser in `config.diagnosers`. With no
///    findings the function returns at once (`iterations_run == 0`,
///    `converged_clean == true`).
/// 2. While the counter is below `config.max_iterations`: merge the
///    constraints of the current best's findings, reset `*session` to a clone
///    of `initial_session_state`, apply the constraints, and call
///    `render_with_session`.
/// 3. A failed render is swallowed: overrides are cleared, the counter is
///    bumped, and the loop either retries or returns the current best once
///    the budget is used up.
/// 4. A successful candidate is rejected (ending the loop) when its diagnosis
///    signature equals the previous one after the first pass, or when its
///    score gain is below `config.min_improvement`. Otherwise it becomes the
///    new best, and the function returns early if it has no findings.
///
/// In the visible code this function always returns `Ok`; render errors never
/// propagate. `config.enabled` is not consulted here.
///
/// NOTE(review): when a candidate is rejected, `session` is left holding the
/// state produced by that rejected render (with refine overrides cleared),
/// not the state that produced the returned text — confirm callers expect this.
pub(crate) fn run_refine_loop<F>(
config: &RefineConfig,
profile: Option<&StyleProfile>,
initial: RenderedDocument,
initial_session_state: crate::session::Session,
session: &mut crate::session::Session,
mut render_with_session: F,
) -> Result<RefineOutcome, crate::error::ProsaicError>
where
F: FnMut(&mut crate::session::Session) -> Result<RenderedDocument, crate::error::ProsaicError>,
{
use crate::refine_score::score_document;
let mut best = initial;
let mut best_score = score_document(&best, &config.weights, profile);
let mut best_diagnostics = run_all_diagnosers(&config.diagnosers, &best, profile);
// Nothing to fix: hand back the initial document untouched.
if best_diagnostics.is_empty() {
return Ok(RefineOutcome {
text: best.text,
iterations_run: 0,
final_score: best_score,
converged_clean: true,
});
}
let mut iter = 0_u8;
let mut prev_diag_signature = diagnosis_signature(&best_diagnostics);
while iter < config.max_iterations {
let constraints = aggregate_constraints(&best_diagnostics);
// Findings that carry no constraints give the renderer nothing to act on.
if constraints.is_empty() {
break;
}
// Every pass starts from the same pristine session snapshot.
*session = initial_session_state.clone();
apply_constraints_to_session(session, &constraints);
let candidate = match render_with_session(session) {
Ok(d) => d,
Err(e) => {
// Best effort: a failed render costs one iteration but is not fatal.
session.clear_refine_overrides();
iter += 1;
if iter >= config.max_iterations {
return Ok(RefineOutcome {
text: best.text,
iterations_run: iter,
final_score: best_score,
converged_clean: false,
});
}
// The error value is deliberately discarded.
let _ = e; continue;
}
};
session.clear_refine_overrides();
let candidate_score = score_document(&candidate, &config.weights, profile);
let candidate_diagnostics = run_all_diagnosers(&config.diagnosers, &candidate, profile);
let candidate_signature = diagnosis_signature(&candidate_diagnostics);
// Stall detection: same findings as the last accepted document. Skipped
// while `iter == 0`, so the first pass is judged on score alone.
if candidate_signature == prev_diag_signature && iter > 0 {
break;
}
// Reject candidates that do not improve the score enough.
// NOTE(review): `iter` is not incremented on either rejection path, so
// `iterations_run` does not count the rejected pass.
if candidate_score - best_score < config.min_improvement {
break;
}
best = candidate;
best_score = candidate_score;
best_diagnostics = candidate_diagnostics;
prev_diag_signature = candidate_signature;
iter += 1;
if best_diagnostics.is_empty() {
return Ok(RefineOutcome {
text: best.text,
iterations_run: iter,
final_score: best_score,
converged_clean: true,
});
}
}
Ok(RefineOutcome {
text: best.text,
iterations_run: iter,
final_score: best_score,
converged_clean: best_diagnostics.is_empty(),
})
}
/// Runs every diagnoser against `document` and concatenates their findings,
/// preserving diagnoser order and each diagnoser's own output order.
fn run_all_diagnosers(
    diagnosers: &[Arc<dyn Diagnoser>],
    document: &RenderedDocument,
    profile: Option<&StyleProfile>,
) -> Vec<Diagnostic> {
    diagnosers
        .iter()
        .flat_map(|diagnoser| diagnoser.diagnose(document, profile))
        .collect()
}
/// Flattens the constraints of all `diagnostics` into one list, in order.
///
/// Only the two blacklist variants are de-duplicated (a connective or list
/// style is blacklisted at most once); every other constraint is kept even
/// when repeated.
fn aggregate_constraints(diagnostics: &[Diagnostic]) -> Vec<RefineConstraint> {
    let mut merged: Vec<RefineConstraint> = Vec::new();

    for constraint in diagnostics.iter().flat_map(|d| d.constraints.iter()) {
        let duplicate = match constraint {
            RefineConstraint::BlacklistConnective(wanted) => merged.iter().any(|seen| {
                matches!(seen, RefineConstraint::BlacklistConnective(have) if have == wanted)
            }),
            RefineConstraint::BlacklistListStyle(wanted) => merged.iter().any(|seen| {
                matches!(seen, RefineConstraint::BlacklistListStyle(have) if have == wanted)
            }),
            // Non-blacklist constraints are never considered duplicates.
            _ => false,
        };
        if !duplicate {
            merged.push(constraint.clone());
        }
    }

    merged
}
/// Folds `constraints` into per-kind buckets and writes them to `session`.
///
/// * Blacklists and recency primes accumulate in input order.
/// * Salience-bias and length-distribution overrides are last-writer-wins.
/// * Forced variant tiers keep one entry per template key; a later constraint
///   for the same key replaces the tier but keeps the original position.
///
/// All five session setters are invoked on every call, so a kind with no
/// constraints is written as empty / `None`.
fn apply_constraints_to_session(
    session: &mut crate::session::Session,
    constraints: &[RefineConstraint],
) {
    let mut banned_connectives = Vec::new();
    let mut banned_list_styles = Vec::new();
    let mut primed_connectives: Vec<String> = Vec::new();
    let mut primed_list_styles: Vec<ListStyle> = Vec::new();
    let mut bias: Option<SalienceBias> = None;
    let mut lengths: Option<LengthDistribution> = None;
    let mut forced_tiers: Vec<(String, Salience)> = Vec::new();

    for constraint in constraints {
        match constraint {
            RefineConstraint::BlacklistConnective(connective) => {
                banned_connectives.push(connective.clone());
            }
            RefineConstraint::BlacklistListStyle(style) => {
                banned_list_styles.push(*style);
            }
            RefineConstraint::PrimeRecencyWindow {
                connectives,
                list_styles,
            } => {
                primed_connectives.extend_from_slice(connectives);
                primed_list_styles.extend_from_slice(list_styles);
            }
            RefineConstraint::OverrideSalienceBias(requested) => {
                bias = Some(*requested);
            }
            RefineConstraint::ForceVariantTier { template_key, tier } => {
                // Upsert keyed on the template key.
                match forced_tiers.iter().position(|(key, _)| key == template_key) {
                    Some(slot) => forced_tiers[slot].1 = *tier,
                    None => forced_tiers.push((template_key.clone(), *tier)),
                }
            }
            RefineConstraint::TightenLengthDistribution(distribution) => {
                lengths = Some(distribution.clone());
            }
        }
    }

    session.set_refine_blacklists(banned_connectives, banned_list_styles);
    session.prime_refine_recency(&primed_connectives, &primed_list_styles);
    session.set_refine_salience_bias(bias);
    session.set_refine_length_distribution(lengths);
    session.set_refine_force_variant_tiers(forced_tiers);
}
/// Builds an order-independent fingerprint of a set of findings.
///
/// Each finding contributes `(diagnoser, severity * 1000 as u32)`; the pairs
/// are then sorted so two runs that report the same findings in a different
/// order compare equal. The float-to-integer cast saturates, so negative or
/// NaN severities map to `0`.
fn diagnosis_signature(diagnostics: &[Diagnostic]) -> Vec<(&'static str, u32)> {
    let mut signature: Vec<(&'static str, u32)> = Vec::with_capacity(diagnostics.len());
    for finding in diagnostics {
        let quantised = (finding.severity * 1000.0) as u32;
        signature.push((finding.diagnoser, quantised));
    }
    signature.sort_unstable();
    signature
}
/// An instruction a diagnoser attaches to a finding, telling the next render
/// pass what to do differently. `apply_constraints_to_session` translates a
/// list of these into session state.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RefineConstraint {
/// Adds the connective to the list passed to `set_refine_blacklists`.
BlacklistConnective(String),
/// Adds the list style to the list passed to `set_refine_blacklists`.
BlacklistListStyle(ListStyle),
/// Appends the given connectives and list styles to the values passed to
/// `prime_refine_recency`.
PrimeRecencyWindow {
connectives: Vec<String>,
list_styles: Vec<ListStyle>,
},
/// Sets the salience-bias override; when several are present the last wins.
OverrideSalienceBias(SalienceBias),
/// Forces a tier for one template key; a later constraint for the same key
/// replaces the earlier tier.
ForceVariantTier {
template_key: String,
tier: Salience,
},
/// Sets the length-distribution override; when several are present the last wins.
TightenLengthDistribution(LengthDistribution),
}
// Unit tests for the refine configuration, sentence splitting, constraint
// application and document aggregation defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// `Default` must be the disabled preset.
#[test]
fn refine_config_off_is_default() {
let c = RefineConfig::default();
assert!(c.is_off());
assert!(!c.enabled);
}
// The balanced preset is enabled and keeps the 3-iteration limit.
#[test]
fn refine_config_balanced_is_enabled() {
let c = RefineConfig::balanced();
assert!(!c.is_off());
assert_eq!(c.max_iterations, 3);
}
// Builder override of the iteration limit.
#[test]
fn refine_config_with_max_iterations_overrides_default() {
let c = RefineConfig::balanced().with_max_iterations(7);
assert_eq!(c.max_iterations, 7);
}
// Spot-checks that default weights are 1.0 (first and last field only).
#[test]
fn weights_default_is_uniform() {
let w = RefineWeights::default();
assert_eq!(w.repetition, 1.0);
assert_eq!(w.profile_match, 1.0);
}
// Terminator followed by whitespace ends a sentence.
#[test]
fn split_sentences_handles_terminators() {
let s = split_sentences("First sentence. Second one. Third.");
assert_eq!(s, vec!["First sentence.", "Second one.", "Third."]);
}
// Text after the last break is returned even without a terminator.
#[test]
fn split_sentences_handles_no_trailing_terminator() {
let s = split_sentences("First. Trailing");
assert_eq!(s, vec!["First.", "Trailing"]);
}
// Empty input yields no sentences.
#[test]
fn split_sentences_handles_empty() {
let s = split_sentences("");
assert!(s.is_empty());
}
// Connective blacklists reach the session in input order.
#[test]
fn apply_constraints_blacklist_connective_writes_session_blacklist() {
let mut session = crate::session::Session::new();
let constraints = vec![
RefineConstraint::BlacklistConnective("Additionally,".to_string()),
RefineConstraint::BlacklistConnective("Furthermore,".to_string()),
];
super::apply_constraints_to_session(&mut session, &constraints);
assert_eq!(
session.refine_blacklist_connectives,
vec!["Additionally,".to_string(), "Furthermore,".to_string()]
);
}
// List-style blacklists reach the session.
#[test]
fn apply_constraints_blacklist_list_style_writes_session_blacklist() {
let mut session = crate::session::Session::new();
let constraints = vec![RefineConstraint::BlacklistListStyle(ListStyle::Including)];
super::apply_constraints_to_session(&mut session, &constraints);
assert_eq!(
session.refine_blacklist_list_styles,
vec![ListStyle::Including]
);
}
// NOTE(review): despite its name this test never inspects the primed
// recency state. It only checks that priming leaves the connective blacklist
// empty, and `baseline_session_clone` is bound and then discarded. Consider
// asserting on the session's recency history once its accessor is known.
#[test]
fn apply_constraints_prime_recency_pushes_phantom_history() {
let mut session = crate::session::Session::new();
let constraints = vec![RefineConstraint::PrimeRecencyWindow {
connectives: vec!["Additionally,".to_string(), "Furthermore,".to_string()],
list_styles: vec![ListStyle::Including, ListStyle::Bracketed],
}];
super::apply_constraints_to_session(&mut session, &constraints);
let baseline_session_clone = session.clone();
super::apply_constraints_to_session(
&mut session,
&[RefineConstraint::PrimeRecencyWindow {
connectives: vec!["Additionally,".to_string()],
list_styles: vec![],
}],
);
assert!(session.refine_blacklist_connectives.is_empty());
let _ = baseline_session_clone;
}
// A single salience-bias override is written through.
#[test]
fn apply_constraints_override_salience_bias_writes_session_override() {
let mut session = crate::session::Session::new();
let constraints = vec![RefineConstraint::OverrideSalienceBias(SalienceBias::Lower)];
super::apply_constraints_to_session(&mut session, &constraints);
assert_eq!(session.refine_salience_bias, Some(SalienceBias::Lower));
}
// With several salience-bias overrides the last one wins.
#[test]
fn apply_constraints_override_salience_bias_last_writer_wins() {
let mut session = crate::session::Session::new();
let constraints = vec![
RefineConstraint::OverrideSalienceBias(SalienceBias::Lower),
RefineConstraint::OverrideSalienceBias(SalienceBias::Higher),
];
super::apply_constraints_to_session(&mut session, &constraints);
assert_eq!(session.refine_salience_bias, Some(SalienceBias::Higher));
}
// Forced tiers are stored per template key; unknown keys report `None`.
#[test]
fn apply_constraints_force_variant_tier_writes_session_map() {
let mut session = crate::session::Session::new();
let constraints = vec![
RefineConstraint::ForceVariantTier {
template_key: "evt.modified".to_string(),
tier: Salience::High,
},
RefineConstraint::ForceVariantTier {
template_key: "evt.touched".to_string(),
tier: Salience::Low,
},
];
super::apply_constraints_to_session(&mut session, &constraints);
assert_eq!(
session.refine_forced_tier_for("evt.modified"),
Some(Salience::High)
);
assert_eq!(
session.refine_forced_tier_for("evt.touched"),
Some(Salience::Low)
);
assert_eq!(session.refine_forced_tier_for("evt.unset"), None);
}
// A repeated template key replaces the tier instead of adding an entry.
#[test]
fn apply_constraints_force_variant_tier_replaces_for_same_key() {
let mut session = crate::session::Session::new();
let constraints = vec![
RefineConstraint::ForceVariantTier {
template_key: "evt.modified".to_string(),
tier: Salience::High,
},
RefineConstraint::ForceVariantTier {
template_key: "evt.modified".to_string(),
tier: Salience::Low,
},
];
super::apply_constraints_to_session(&mut session, &constraints);
assert_eq!(
session.refine_forced_tier_for("evt.modified"),
Some(Salience::Low)
);
assert_eq!(session.refine_force_variant_tier.len(), 1);
}
// The length-distribution override is written through unchanged.
#[test]
fn apply_constraints_tighten_length_distribution_writes_session_override() {
let mut session = crate::session::Session::new();
let target = LengthDistribution {
short: 0.5,
medium: 0.3,
long: 0.2,
short_max_words: 7,
medium_max_words: 15,
};
let constraints = vec![RefineConstraint::TightenLengthDistribution(target.clone())];
super::apply_constraints_to_session(&mut session, &constraints);
assert_eq!(session.refine_length_distribution, Some(target));
}
// NOTE(review): the name says "clear then reapply", but the test only
// verifies that `clear_refine_overrides` resets the three override fields;
// nothing is re-applied afterwards.
#[test]
fn apply_constraints_clear_then_reapply_resets_override_fields() {
let mut session = crate::session::Session::new();
super::apply_constraints_to_session(
&mut session,
&[
RefineConstraint::OverrideSalienceBias(SalienceBias::Lower),
RefineConstraint::TightenLengthDistribution(LengthDistribution {
short: 0.7,
medium: 0.2,
long: 0.1,
short_max_words: 5,
medium_max_words: 12,
}),
RefineConstraint::ForceVariantTier {
template_key: "k".to_string(),
tier: Salience::High,
},
],
);
assert!(session.refine_salience_bias.is_some());
assert!(session.refine_length_distribution.is_some());
assert!(!session.refine_force_variant_tier.is_empty());
session.clear_refine_overrides();
assert!(session.refine_salience_bias.is_none());
assert!(session.refine_length_distribution.is_none());
assert!(session.refine_force_variant_tier.is_empty());
}
// Two paragraphs, three sentences, one connective on paragraph 0 / sentence 1;
// document text joins the paragraphs with a blank line.
#[test]
fn rendered_document_from_paragraphs_aggregates_correctly() {
let para1 = ParagraphRender {
text: "Foo was modified. It was renamed.".to_string(),
events: vec![
EventMeta {
connective: None,
list_style: None,
},
EventMeta {
connective: Some("Additionally,".to_string()),
list_style: None,
},
],
};
let para2 = ParagraphRender {
text: "Bar was deleted.".to_string(),
events: vec![EventMeta::default()],
};
let doc = RenderedDocument::from_paragraphs(vec![para1, para2]);
assert_eq!(doc.paragraphs.len(), 2);
assert_eq!(doc.sentences.len(), 3);
assert_eq!(doc.connectives_used.len(), 1);
assert_eq!(doc.connectives_used[0].connective, "Additionally,");
assert_eq!(doc.connectives_used[0].paragraph_index, 0);
assert_eq!(doc.connectives_used[0].sentence_index_in_paragraph, 1);
assert_eq!(
doc.text,
"Foo was modified. It was renamed.\n\nBar was deleted."
);
}
}