use crate::error::{Result, TextError};
use std::collections::{HashMap, HashSet};
/// Coarse relation that a sentence bears to the sentence preceding it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DiscourseRelation {
    /// The sentence supplies a cause or reason (cues such as "because").
    Cause,
    /// The sentence states a consequence (cues such as "therefore").
    Effect,
    /// The sentence contrasts with the previous one (cues such as "however").
    Contrast,
    /// The sentence adds to or expands on the previous one ("furthermore").
    Elaboration,
    /// The sentence is ordered in time relative to the previous one ("then").
    Temporal,
    /// The sentence states a condition ("if", "unless").
    Conditional,
    /// The sentence gives an example ("for example").
    Exemplification,
    /// The sentence summarises or concludes ("in conclusion").
    Summary,
    /// Explicit "no relation" label. Note that `detect_discourse_relation`
    /// reports the absence of a relation as `Option::None`, not this variant.
    None,
}
impl std::fmt::Display for DiscourseRelation {
    /// Writes the upper-case label (e.g. `CONTRAST`).
    ///
    /// Uses `Formatter::pad`, so width, fill, alignment and precision flags
    /// (`{:>12}`, `{:-<10}`) are honoured, which is useful in tabular output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Cause => "CAUSE",
            Self::Effect => "EFFECT",
            Self::Contrast => "CONTRAST",
            Self::Elaboration => "ELABORATION",
            Self::Temporal => "TEMPORAL",
            Self::Conditional => "CONDITIONAL",
            Self::Exemplification => "EXEMPLIFICATION",
            Self::Summary => "SUMMARY",
            Self::None => "NONE",
        };
        f.pad(label)
    }
}
/// Connective phrases ("cue words") used to recognise each
/// [`DiscourseRelation`].
///
/// Every entry must be lowercase: the matching code in this module lowercases
/// the text it searches but compares cues as stored. The same phrase may be
/// listed under more than one relation when it is genuinely ambiguous
/// (e.g. "given that", "while").
#[derive(Debug, Clone, Default)]
pub struct CueLexicon {
    /// Cues introducing a cause or reason ("because", "due to", ...).
    pub cause: Vec<String>,
    /// Cues introducing a consequence ("therefore", "as a result", ...).
    pub effect: Vec<String>,
    /// Cues introducing a contrast ("however", "although", ...).
    pub contrast: Vec<String>,
    /// Cues introducing an elaboration ("furthermore", "in fact", ...).
    pub elaboration: Vec<String>,
    /// Cues introducing temporal sequencing ("then", "afterwards", ...).
    pub temporal: Vec<String>,
    /// Cues introducing a condition ("if", "unless", ...).
    pub conditional: Vec<String>,
    /// Cues introducing an example ("for example", "such as", ...).
    pub exemplification: Vec<String>,
    /// Cues introducing a summary ("in conclusion", "overall", ...).
    pub summary: Vec<String>,
}
impl CueLexicon {
    /// Returns a lexicon of common English connectives.
    ///
    /// Each phrase is passed through `to_lowercase` on construction, so the
    /// lowercase invariant documented on [`CueLexicon`] holds.
    pub fn default_english() -> Self {
        let cue = |phrases: &[&str]| {
            phrases
                .iter()
                .map(|s| s.to_lowercase())
                .collect::<Vec<String>>()
        };
        Self {
            cause: cue(&[
                "because",
                "since",
                "as",
                "due to",
                "owing to",
                "given that",
                "in light of",
                "for the reason that",
                "as a result of",
            ]),
            effect: cue(&[
                "therefore",
                "thus",
                "hence",
                "consequently",
                "as a result",
                "as a consequence",
                "so",
                "accordingly",
                "for this reason",
                "it follows that",
                "this led to",
                "this caused",
            ]),
            // Each phrase appears once; a duplicate entry would be counted
            // twice by density-style scoring.
            contrast: cue(&[
                "however",
                "but",
                "yet",
                "although",
                "even though",
                "while",
                "whereas",
                "on the other hand",
                "in contrast",
                "nevertheless",
                "nonetheless",
                "despite",
                "in spite of",
                "conversely",
                "by contrast",
                "on the contrary",
                "that said",
                "still",
                "though",
            ]),
            elaboration: cue(&[
                "furthermore",
                "moreover",
                "in addition",
                "additionally",
                "also",
                "likewise",
                "similarly",
                "indeed",
                "in fact",
                "specifically",
                "notably",
                "particularly",
                "what is more",
                "besides",
                "more importantly",
            ]),
            temporal: cue(&[
                "then",
                "next",
                "after",
                "before",
                "when",
                "while",
                "once",
                "previously",
                "subsequently",
                "later",
                "earlier",
                "at the same time",
                "meanwhile",
                "in the meantime",
                "afterward",
                "afterwards",
                "first",
                "second",
                "finally",
                "initially",
            ]),
            conditional: cue(&[
                "if",
                "unless",
                "provided that",
                "as long as",
                "given that",
                "in case",
                "assuming that",
                "on condition that",
                "only if",
                "whenever",
            ]),
            exemplification: cue(&[
                "for example",
                "for instance",
                "such as",
                "e.g.",
                "to illustrate",
                "as an example",
                "as illustrated by",
                "consider",
                "take for example",
                "as shown by",
            ]),
            summary: cue(&[
                "in summary",
                "in conclusion",
                "to summarize",
                "to summarise",
                "in brief",
                "in short",
                "overall",
                "to conclude",
                "in closing",
                "all in all",
                "on balance",
                "in the end",
                "to sum up",
            ]),
        }
    }

    /// Yields each relation paired with its cue slice.
    ///
    /// The order is always: cause, effect, contrast, elaboration, temporal,
    /// conditional, exemplification, summary. Callers that keep the first of
    /// equally ranked matches therefore favour earlier relations.
    fn relation_cues(&self) -> impl Iterator<Item = (DiscourseRelation, &[String])> {
        [
            (DiscourseRelation::Cause, self.cause.as_slice()),
            (DiscourseRelation::Effect, self.effect.as_slice()),
            (DiscourseRelation::Contrast, self.contrast.as_slice()),
            (DiscourseRelation::Elaboration, self.elaboration.as_slice()),
            (DiscourseRelation::Temporal, self.temporal.as_slice()),
            (DiscourseRelation::Conditional, self.conditional.as_slice()),
            (
                DiscourseRelation::Exemplification,
                self.exemplification.as_slice(),
            ),
            (DiscourseRelation::Summary, self.summary.as_slice()),
        ]
        .into_iter()
    }
}
/// Checks whether `text_lower`, after skipping leading whitespace, opens with
/// `cue` as a complete word: the character right after the cue (if any) must
/// not be alphanumeric, so "so" does not match the start of "sorry".
fn starts_with_cue(text_lower: &str, cue: &str) -> bool {
    match text_lower.trim_start().strip_prefix(cue) {
        Some(rest) => !matches!(rest.chars().next(), Some(next) if next.is_alphanumeric()),
        None => false,
    }
}
/// Returns the lowercased first 80 characters of `text`, which is the region
/// searched for discourse cues.
///
/// When the 80-character cut falls inside a word, the trailing word fragment
/// is dropped. Otherwise a prefix such as "so" (from "solution") or "as" (from
/// "assume") could be mistaken for a complete cue word.
fn leading_window(text: &str) -> String {
    const WINDOW_CHARS: usize = 80;
    let mut chars = text.chars();
    let mut window: String = chars.by_ref().take(WINDOW_CHARS).collect();
    let cut_mid_word = matches!(chars.next(), Some(next) if next.is_alphanumeric());
    if cut_mid_word {
        // Keep everything up to and including the last non-alphanumeric char.
        // If the window ends on a separator, this keeps the whole window.
        let keep = window
            .char_indices()
            .rev()
            .find(|&(_, c)| !c.is_alphanumeric())
            .map_or(0, |(idx, c)| idx + c.len_utf8());
        window.truncate(keep);
    }
    window.to_lowercase()
}
/// Detects the discourse relation that `sentence2` bears to the preceding
/// `sentence1`, using the connectives in `cue_words`.
///
/// Cues are searched, case-insensitively, in the first 80 characters of
/// `sentence2`. They only count as whole words or phrases: a cue glued to an
/// adjacent letter or digit is ignored, so `"as"` does not fire inside
/// `"has"` and `"so"` does not fire inside `"also"`. Empty cues are ignored.
///
/// Competing matches are ranked as follows:
/// 1. a cue that opens `sentence2` beats one found later in it;
/// 2. within the same position tier, the longer cue (in bytes) wins;
/// 3. remaining ties go to the relation yielded first by the lexicon.
///
/// When no cue matches and `sentence1` contains the word `"if"`,
/// [`DiscourseRelation::Conditional`] is returned. Otherwise returns `None`.
pub fn detect_discourse_relation(
    sentence1: &str,
    sentence2: &str,
    cue_words: &CueLexicon,
) -> Option<DiscourseRelation> {
    /// `true` when `cue` occurs in `haystack` with no alphanumeric character
    /// immediately before or after it.
    fn contains_whole_cue(haystack: &str, cue: &str) -> bool {
        if cue.is_empty() {
            return false;
        }
        let mut prev_is_word = false;
        for (idx, ch) in haystack.char_indices() {
            if !prev_is_word {
                if let Some(rest) = haystack[idx..].strip_prefix(cue) {
                    if !matches!(rest.chars().next(), Some(next) if next.is_alphanumeric()) {
                        return true;
                    }
                }
            }
            prev_is_word = ch.is_alphanumeric();
        }
        false
    }

    let window2 = leading_window(sentence2);
    // Rank = (cue opens sentence2, cue length). Tuples compare
    // lexicographically, so position dominates and length breaks ties.
    let mut best: Option<(DiscourseRelation, (bool, usize))> = None;
    for (rel, cues) in cue_words.relation_cues() {
        for cue in cues.iter().filter(|cue| !cue.is_empty()) {
            let rank = if starts_with_cue(&window2, cue) {
                (true, cue.len())
            } else if contains_whole_cue(&window2, cue) {
                (false, cue.len())
            } else {
                continue;
            };
            // Strict `>` keeps the earlier relation on exact ties.
            let is_better = match &best {
                Some((_, best_rank)) => rank > *best_rank,
                None => true,
            };
            if is_better {
                best = Some((rel.clone(), rank));
            }
        }
    }
    if let Some((rel, _)) = best {
        return Some(rel);
    }
    // Fallback heuristic: a conditional clause in the first sentence.
    if contains_whole_cue(&sentence1.to_lowercase(), "if") {
        Some(DiscourseRelation::Conditional)
    } else {
        None
    }
}
/// One sentence in a [`RhetoricalStructure`] tree.
#[derive(Debug, Clone)]
pub struct RstNode {
/// Index of this node's sentence in the sentence list the tree was built from.
pub sentence_index: usize,
/// The sentence text.
pub text: String,
/// Relation linking this node to its parent. `None` for the root, and for
/// any sentence with no recorded relation.
pub relation_to_parent: Option<DiscourseRelation>,
/// Child nodes. `RhetoricalStructure::from_sentence_pairs` stores them in
/// sentence order and gives them no children of their own.
pub children: Vec<RstNode>,
}
/// A shallow rhetorical-structure tree over a sequence of sentences.
#[derive(Debug, Clone)]
pub struct RhetoricalStructure {
    /// Root node. When built by `from_sentence_pairs` it holds sentence 0,
    /// and every later sentence is one of its direct children.
    pub root: RstNode,
    /// Number of sentences the tree was built from.
    pub sentence_count: usize,
    /// The `(from, to, relation)` triples the tree was built from, unchanged.
    pub inter_sentence_relations: Vec<(usize, usize, DiscourseRelation)>,
}
impl RhetoricalStructure {
    /// Builds a one-level tree: sentence 0 is the root and sentences
    /// `1..` become its children in order.
    ///
    /// A child's `relation_to_parent` is the relation of the triple whose
    /// second index equals the child's index. If several triples target the
    /// same child, the last one wins. Returns `None` when `sentences` is empty.
    pub fn from_sentence_pairs(
        sentences: &[String],
        relations: Vec<(usize, usize, DiscourseRelation)>,
    ) -> Option<Self> {
        let (head, tail) = sentences.split_first()?;
        let relation_into: HashMap<usize, DiscourseRelation> = relations
            .iter()
            .map(|(_, target, relation)| (*target, relation.clone()))
            .collect();
        let children: Vec<RstNode> = tail
            .iter()
            .zip(1..)
            .map(|(text, sentence_index)| RstNode {
                sentence_index,
                text: text.clone(),
                relation_to_parent: relation_into.get(&sentence_index).cloned(),
                children: Vec::new(),
            })
            .collect();
        Some(Self {
            root: RstNode {
                sentence_index: 0,
                text: head.clone(),
                relation_to_parent: None,
                children,
            },
            sentence_count: sentences.len(),
            inter_sentence_relations: relations,
        })
    }

    /// Returns every node in pre-order (parent before children, children
    /// left to right). Uses an explicit stack, so deep trees cannot overflow
    /// the call stack.
    pub fn nodes_dfs(&self) -> Vec<&RstNode> {
        let mut pending: Vec<&RstNode> = vec![&self.root];
        let mut order = Vec::new();
        while let Some(current) = pending.pop() {
            order.push(current);
            // Push children reversed so the leftmost one is popped first.
            pending.extend(current.children.iter().rev());
        }
        order
    }
}
/// Returns the distinct lowercased words of `sentence` that are at least
/// three characters long.
///
/// Words are maximal runs of alphanumeric characters. Length is measured in
/// characters, not UTF-8 bytes, so short non-ASCII words such as "né" are
/// filtered exactly like ASCII ones.
fn word_set(sentence: &str) -> HashSet<String> {
    sentence
        .split(|c: char| !c.is_alphanumeric())
        .filter(|w| w.chars().count() >= 3)
        .map(str::to_lowercase)
        .collect()
}
/// Common English function words. `lexical_overlap` filters these out (via
/// `stop_set`) so that only content words contribute to sentence overlap.
/// Every entry has at least three characters; shorter words are already
/// dropped by `word_set`.
const STOP_WORDS: &[&str] = &[
"the", "and", "for", "are", "was", "were", "has", "have", "had", "not", "but", "that", "this",
"with", "from", "they", "will", "been", "its", "their", "there", "what", "also", "into",
"than", "then", "when", "more", "some", "such", "even", "both", "each", "said", "very", "just",
"over", "like", "about", "would", "could", "should", "which",
];
fn stop_set() -> HashSet<&'static str> {
STOP_WORDS.iter().copied().collect()
}
/// Jaccard similarity between the content words of `s1` and `s2`.
///
/// Content words are the words from `word_set` minus the stop words. Returns
/// `1.0` when neither sentence has any content word. Otherwise returns
/// |shared| / |combined|, which lies in `[0.0, 1.0]`.
fn lexical_overlap(s1: &str, s2: &str) -> f64 {
    let stops = stop_set();
    let content_words = |sentence: &str| -> HashSet<String> {
        let mut words = word_set(sentence);
        words.retain(|w| !stops.contains(w.as_str()));
        words
    };
    let (left, right) = (content_words(s1), content_words(s2));
    if left.is_empty() && right.is_empty() {
        return 1.0;
    }
    // At least one side is non-empty here, so the union is never zero.
    let shared = left.intersection(&right).count();
    let combined = left.union(&right).count();
    shared as f64 / combined as f64
}
/// Counts how many lexicon entries from `cue_words` occur in `text`.
///
/// Matching is case-insensitive and whole-word: a cue counts only when it is
/// not glued to an adjacent letter or digit. So `"as"` is not found inside
/// `"has"`/`"class"`, and `"so"` is not found inside `"also"`. Each entry
/// contributes at most one to the count. A phrase listed under several
/// relations is counted once per listing. Empty cues never match.
fn cue_density(text: &str, cue_words: &CueLexicon) -> usize {
    /// `true` when `cue` occurs in `haystack` with no alphanumeric character
    /// immediately before or after it.
    fn contains_whole_cue(haystack: &str, cue: &str) -> bool {
        if cue.is_empty() {
            return false;
        }
        let mut prev_is_word = false;
        for (idx, ch) in haystack.char_indices() {
            if !prev_is_word {
                if let Some(rest) = haystack[idx..].strip_prefix(cue) {
                    if !matches!(rest.chars().next(), Some(next) if next.is_alphanumeric()) {
                        return true;
                    }
                }
            }
            prev_is_word = ch.is_alphanumeric();
        }
        false
    }

    let lower = text.to_lowercase();
    cue_words
        .relation_cues()
        .flat_map(|(_, cues)| cues.iter())
        .filter(|cue| contains_whole_cue(&lower, cue.as_str()))
        .count()
}
/// Splits `text` into trimmed, non-empty sentences.
///
/// A sentence ends at `.`, `!` or `?` only when that terminator, together
/// with any closing quotes or brackets right after it, is followed by
/// whitespace or the end of the text. This keeps decimals (`3.14`), ellipses
/// (`...`), stacked marks (`?!`) and quoted endings (`"stop."`) intact,
/// instead of producing fragments such as `"."`, `"!"` or `"14."`. Trailing
/// text without a terminator becomes the last sentence.
fn split_sentences(text: &str) -> Vec<String> {
    let mut sentences = Vec::new();
    let mut buf = String::new();
    let mut chars = text.chars().peekable();
    while let Some(c) = chars.next() {
        buf.push(c);
        if !matches!(c, '.' | '!' | '?') {
            continue;
        }
        // Closing quotes/brackets belong to the sentence they terminate.
        while let Some(&next) = chars.peek() {
            if matches!(next, '"' | '\'' | ')' | ']' | '\u{201D}' | '\u{2019}') {
                buf.push(next);
                chars.next();
            } else {
                break;
            }
        }
        // Only a terminator followed by whitespace (or end of input) ends a
        // sentence; "3.14", "..." and "?!" keep accumulating.
        let at_boundary = !matches!(chars.peek(), Some(next) if !next.is_whitespace());
        if at_boundary {
            let sentence = buf.trim();
            if !sentence.is_empty() {
                sentences.push(sentence.to_string());
            }
            buf.clear();
        }
    }
    let tail = buf.trim();
    if !tail.is_empty() {
        sentences.push(tail.to_string());
    }
    sentences
}
/// Scores the local coherence of `text` using the built-in English cues.
///
/// Equivalent to [`coherence_score_with_lexicon`] with
/// [`CueLexicon::default_english`].
pub fn coherence_score(text: &str) -> f64 {
    let lexicon = CueLexicon::default_english();
    coherence_score_with_lexicon(text, &lexicon)
}
/// Scores how coherently adjacent sentences of `text` follow one another.
///
/// Every adjacent sentence pair contributes to two components:
/// - lexical cohesion: mean content-word Jaccard overlap (weight 0.6);
/// - connective use: fraction of pairs whose second sentence contains at
///   least one cue from `cue_words` (weight 0.4).
///
/// Texts with fewer than two sentences, including empty text, score `1.0`.
pub fn coherence_score_with_lexicon(text: &str, cue_words: &CueLexicon) -> f64 {
    const LEXICAL_WEIGHT: f64 = 0.6;
    const CUE_WEIGHT: f64 = 0.4;

    let sentences = split_sentences(text);
    if sentences.len() < 2 {
        return 1.0;
    }
    let pair_count = (sentences.len() - 1) as f64;
    let mut overlap_total = 0.0_f64;
    let mut cued_pairs = 0_usize;
    for pair in sentences.windows(2) {
        let (prev, next) = (pair[0].as_str(), pair[1].as_str());
        overlap_total += lexical_overlap(prev, next);
        if cue_density(next, cue_words) > 0 {
            cued_pairs += 1;
        }
    }
    LEXICAL_WEIGHT * (overlap_total / pair_count) + CUE_WEIGHT * (cued_pairs as f64 / pair_count)
}
/// Cue-based discourse analyser. It splits text into sentences, labels the
/// relation between each adjacent pair and scores overall coherence.
#[derive(Debug, Clone)]
pub struct DiscourseAnalyzer {
    cue_lexicon: CueLexicon,
}
impl Default for DiscourseAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}
impl DiscourseAnalyzer {
    /// Creates an analyser that uses [`CueLexicon::default_english`].
    pub fn new() -> Self {
        Self {
            cue_lexicon: CueLexicon::default_english(),
        }
    }

    /// Replaces the cue lexicon (builder style).
    pub fn with_lexicon(mut self, lex: CueLexicon) -> Self {
        self.cue_lexicon = lex;
        self
    }

    /// Detects the relation `s2` bears to `s1` with this analyser's lexicon.
    /// See [`detect_discourse_relation`].
    pub fn detect_relation(&self, s1: &str, s2: &str) -> Option<DiscourseRelation> {
        detect_discourse_relation(s1, s2, &self.cue_lexicon)
    }

    /// Runs the full pipeline: sentence splitting, pairwise relation
    /// detection, RST tree construction and coherence scoring.
    ///
    /// `relations` holds `(i, i + 1, relation)` only for adjacent pairs where
    /// a relation was detected.
    ///
    /// # Errors
    ///
    /// Returns `TextError::InvalidInput` when `text` is empty or consists only
    /// of whitespace. Such input contains no sentence to analyse.
    pub fn analyse(&self, text: &str) -> Result<DiscourseAnalysis> {
        if text.trim().is_empty() {
            return Err(TextError::InvalidInput(
                "Input text must not be empty".to_string(),
            ));
        }
        let sentences = split_sentences(text);
        let relations: Vec<(usize, usize, DiscourseRelation)> = sentences
            .windows(2)
            .enumerate()
            .filter_map(|(i, pair)| {
                self.detect_relation(&pair[0], &pair[1])
                    .map(|rel| (i, i + 1, rel))
            })
            .collect();
        // The tree keeps its own copy; the analysis also exposes the list.
        let rst = RhetoricalStructure::from_sentence_pairs(&sentences, relations.clone());
        let coherence = coherence_score_with_lexicon(text, &self.cue_lexicon);
        Ok(DiscourseAnalysis {
            sentences,
            relations,
            rst,
            coherence,
        })
    }
}
/// Result of [`DiscourseAnalyzer::analyse`].
#[derive(Debug, Clone)]
pub struct DiscourseAnalysis {
    /// Sentences of the input, in order.
    pub sentences: Vec<String>,
    /// `(i, i + 1, relation)` for each adjacent pair with a detected relation.
    pub relations: Vec<(usize, usize, DiscourseRelation)>,
    /// Tree built from `sentences` and `relations`; `None` if there were no
    /// sentences.
    pub rst: Option<RhetoricalStructure>,
    /// Coherence of the whole text, as computed by
    /// `coherence_score_with_lexicon`.
    pub coherence: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_contrast() {
        let lex = CueLexicon::default_english();
        let s1 = "The experiment was promising.";
        let s2 = "However, the results were inconclusive.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        assert_eq!(rel, Some(DiscourseRelation::Contrast));
    }

    #[test]
    fn test_detect_effect() {
        let lex = CueLexicon::default_english();
        let s1 = "The team worked very hard.";
        let s2 = "Therefore, they finished on time.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        assert_eq!(rel, Some(DiscourseRelation::Effect));
    }

    #[test]
    fn test_detect_cause() {
        let lex = CueLexicon::default_english();
        let s1 = "The project was delayed.";
        let s2 = "Because the supplier did not deliver the parts.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        assert_eq!(rel, Some(DiscourseRelation::Cause));
    }

    #[test]
    fn test_detect_temporal() {
        let lex = CueLexicon::default_english();
        let s1 = "She completed the analysis.";
        let s2 = "Then she wrote the report.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        assert_eq!(rel, Some(DiscourseRelation::Temporal));
    }

    #[test]
    fn test_detect_conditional() {
        let lex = CueLexicon::default_english();
        let s1 = "You will succeed.";
        let s2 = "If you follow the plan carefully.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        assert_eq!(rel, Some(DiscourseRelation::Conditional));
    }

    #[test]
    fn test_detect_elaboration() {
        let lex = CueLexicon::default_english();
        let s1 = "The new policy was announced.";
        let s2 = "Furthermore, it will take effect immediately.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        assert_eq!(rel, Some(DiscourseRelation::Elaboration));
    }

    #[test]
    fn test_detect_exemplification() {
        let lex = CueLexicon::default_english();
        let s1 = "Many animals live in the rainforest.";
        let s2 = "For example, jaguars and toucans are common there.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        assert_eq!(rel, Some(DiscourseRelation::Exemplification));
    }

    #[test]
    fn test_detect_summary() {
        let lex = CueLexicon::default_english();
        let s1 = "We reviewed all the evidence.";
        let s2 = "In conclusion, the hypothesis is supported.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        assert_eq!(rel, Some(DiscourseRelation::Summary));
    }

    #[test]
    fn test_detect_none() {
        let lex = CueLexicon::default_english();
        let s1 = "The cat sat on the mat.";
        let s2 = "The dog ran across the field.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        // Neither sentence contains any cue, so no relation may be reported.
        assert!(rel.is_none(), "expected no relation, got {:?}", rel);
    }

    #[test]
    fn test_relation_display_labels() {
        assert_eq!(DiscourseRelation::Contrast.to_string(), "CONTRAST");
        assert_eq!(DiscourseRelation::None.to_string(), "NONE");
    }

    #[test]
    fn test_split_sentences_basic() {
        let sents = split_sentences("One. Two! Three? Trailing fragment");
        assert_eq!(sents, vec!["One.", "Two!", "Three?", "Trailing fragment"]);
    }

    #[test]
    fn test_lexical_overlap_bounds() {
        assert_eq!(lexical_overlap("Alpha beta gamma.", "Gamma beta alpha!"), 1.0);
        assert_eq!(lexical_overlap("Alpha beta.", "Delta epsilon."), 0.0);
        // Only stop words on both sides: treated as perfectly overlapping.
        assert_eq!(lexical_overlap("The and.", "Was were."), 1.0);
    }

    #[test]
    fn test_coherence_score_coherent() {
        let text = "The researchers conducted an experiment. \
                    Therefore, they published their findings. \
                    Furthermore, the findings were widely cited.";
        let score = coherence_score(text);
        assert!(score > 0.0, "score should be positive: {}", score);
        assert!(score <= 1.0, "score should be <= 1.0: {}", score);
    }

    #[test]
    fn test_coherence_score_incoherent() {
        let text = "The price of gold rose sharply. \
                    Elephants live in Africa. \
                    Quantum mechanics is complex.";
        let score = coherence_score(text);
        assert!((0.0..=1.0).contains(&score), "score out of range: {}", score);
    }

    #[test]
    fn test_coherence_score_single_sentence() {
        let score = coherence_score("This is a single sentence.");
        assert_eq!(score, 1.0);
    }

    #[test]
    fn test_rst_tree_construction() {
        let sentences = vec![
            "Alice studied hard.".to_string(),
            "Therefore, she passed the exam.".to_string(),
            "However, she felt tired afterward.".to_string(),
        ];
        let relations = vec![
            (0, 1, DiscourseRelation::Effect),
            (1, 2, DiscourseRelation::Contrast),
        ];
        let tree = RhetoricalStructure::from_sentence_pairs(&sentences, relations);
        assert!(tree.is_some());
        let tree = tree.expect("already checked");
        assert_eq!(tree.sentence_count, 3);
        assert_eq!(tree.root.sentence_index, 0);
        assert_eq!(tree.root.children.len(), 2);
        let child_relations: Vec<Option<DiscourseRelation>> = tree
            .root
            .children
            .iter()
            .map(|c| c.relation_to_parent.clone())
            .collect();
        assert!(child_relations.contains(&Some(DiscourseRelation::Effect)));
        assert!(child_relations.contains(&Some(DiscourseRelation::Contrast)));
    }

    #[test]
    fn test_rst_empty_text_returns_none() {
        let tree = RhetoricalStructure::from_sentence_pairs(&[], Vec::new());
        assert!(tree.is_none());
    }

    #[test]
    fn test_analyser_full_pipeline() {
        let analyser = DiscourseAnalyzer::new();
        let text = "The company invested heavily in R&D. \
                    Therefore, its products improved significantly. \
                    However, costs also increased.";
        let analysis = analyser.analyse(text).expect("should succeed");
        assert_eq!(analysis.sentences.len(), 3);
        assert!(!analysis.relations.is_empty());
        assert!(analysis.rst.is_some());
        assert!(analysis.coherence >= 0.0 && analysis.coherence <= 1.0);
    }

    #[test]
    fn test_analyser_empty_input_error() {
        let analyser = DiscourseAnalyzer::new();
        assert!(analyser.analyse("").is_err());
    }

    #[test]
    fn test_dfs_traversal() {
        let sentences = vec!["S1".to_string(), "S2".to_string(), "S3".to_string()];
        let tree =
            RhetoricalStructure::from_sentence_pairs(&sentences, Vec::new()).expect("should build");
        let nodes = tree.nodes_dfs();
        assert_eq!(nodes.len(), 3);
        // Pre-order: root first, then children left to right.
        let order: Vec<usize> = nodes.iter().map(|n| n.sentence_index).collect();
        assert_eq!(order, vec![0, 1, 2]);
    }

    #[test]
    fn test_custom_lexicon() {
        let mut lex = CueLexicon::default();
        lex.effect.push("voila".to_string());
        let s1 = "We mixed the chemicals.";
        let s2 = "Voila, it worked.";
        let rel = detect_discourse_relation(s1, s2, &lex);
        assert_eq!(rel, Some(DiscourseRelation::Effect));
    }
}