use crate::error::Result;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Tunable thresholds and cache settings for `CoherenceValidator`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceConfig {
/// Cosine-similarity cutoff used by `validate_semantic_consistency`: a pair scoring
/// below it marks both segments inconsistent, and the average must reach it to pass.
pub similarity_threshold: f32,
/// Cutoff intended for contradiction detection.
/// NOTE(review): presumably an embedding-similarity bound — confirm against
/// `detect_contradictions`.
pub contradiction_threshold: f32,
/// Minimum overall flow score for `check_logical_flow` to report `has_logical_flow`.
pub logical_flow_threshold: f32,
/// Length of the vectors produced by the built-in hashed word/bigram embedding.
pub embedding_dim: usize,
/// When true, built-in embeddings are memoised per input text.
pub enable_caching: bool,
/// Upper bound on cached embeddings; once reached, new entries are not stored.
pub max_cache_size: usize,
/// NOTE(review): not read by any code visible in this file — purpose unconfirmed.
pub use_approximate: bool,
}
impl Default for CoherenceConfig {
fn default() -> Self {
Self {
similarity_threshold: 0.7,
contradiction_threshold: 0.3,
logical_flow_threshold: 0.6,
embedding_dim: 768,
enable_caching: true,
max_cache_size: 1000,
use_approximate: false,
}
}
}
/// Outcome of `CoherenceValidator::validate_semantic_consistency`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticConsistencyResult {
/// True when no pair fell below the similarity threshold and the average met it.
pub is_consistent: bool,
/// Overall score; the validator sets it to the same value as `average_similarity`.
pub consistency_score: f32,
/// One `(i, j, similarity)` entry per compared pair, with `i < j`.
pub segment_similarities: Vec<(usize, usize, f32)>,
/// Indices that took part in at least one below-threshold pair, in discovery order.
pub inconsistent_segments: Vec<usize>,
/// Mean of all pairwise similarities (1.0 when there were no pairs).
pub average_similarity: f32,
/// Sample standard deviation of the pairwise similarities (0.0 with fewer than two pairs).
pub similarity_std_dev: f32,
}
impl Default for SemanticConsistencyResult {
fn default() -> Self {
Self {
is_consistent: true,
consistency_score: 1.0,
segment_similarities: Vec::new(),
inconsistent_segments: Vec::new(),
average_similarity: 1.0,
similarity_std_dev: 0.0,
}
}
}
/// Outcome of `CoherenceValidator::detect_contradictions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContradictionResult {
/// True when at least one contradiction was recorded.
pub has_contradictions: bool,
/// Number of entries in `contradictions`.
pub contradiction_count: usize,
/// Every contradiction found, grouped by the check that produced it.
pub contradictions: Vec<Contradiction>,
/// `contradiction_count` divided by the number of segment pairs, capped at 1.0.
pub contradiction_score: f32,
}
impl Default for ContradictionResult {
fn default() -> Self {
Self {
has_contradictions: false,
contradiction_count: 0,
contradictions: Vec::new(),
contradiction_score: 0.0,
}
}
}
/// A single suspected contradiction between two segments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Contradiction {
/// Index of the first segment of the pair.
pub segment_a: usize,
/// Index of the second segment of the pair.
pub segment_b: usize,
/// Copy of the text at `segment_a`.
pub text_a: String,
/// Copy of the text at `segment_b`.
pub text_b: String,
/// Strength of the finding; each check assigns its own value.
pub severity: f32,
/// Category of the contradiction.
pub contradiction_type: ContradictionType,
/// Human-readable reason produced by the check that raised it.
pub explanation: String,
}
/// Category assigned to a `Contradiction`.
///
/// NOTE(review): only `Logical` and `Numeric` are constructed by the code visible in
/// this file; the remaining variants are presumably used elsewhere or reserved.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ContradictionType {
/// Assigned by the negation-based and embedding-based checks.
Logical,
Temporal,
/// Assigned when two similarly worded segments cite very different numbers.
Numeric,
AttributeMismatch,
Causal,
Contextual,
}
/// Outcome of `CoherenceValidator::check_logical_flow`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogicalFlowResult {
/// True when `flow_score` reaches the configured logical-flow threshold.
pub has_logical_flow: bool,
/// Combined score, clamped to the range 0.0..=1.0.
pub flow_score: f32,
/// Problems detected at individual transitions.
pub violations: Vec<CoherenceViolation>,
/// Cosine similarity of each adjacent pair of segments, in order.
pub transition_scores: Vec<f32>,
/// Free-text advice; one entry is added per detected topic shift.
pub suggestions: Vec<String>,
}
impl Default for LogicalFlowResult {
fn default() -> Self {
Self {
has_logical_flow: true,
flow_score: 1.0,
violations: Vec::new(),
transition_scores: Vec::new(),
suggestions: Vec::new(),
}
}
}
/// A flow problem found at one transition between adjacent segments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceViolation {
/// Index of the later segment of the offending transition.
pub segment_index: usize,
/// Kind of problem detected.
pub violation_type: ViolationType,
/// Weight of the problem; it feeds the penalty applied to the flow score.
pub severity: f32,
/// Human-readable summary of the problem.
pub description: String,
/// Optional advice on how to repair the transition.
pub suggestion: Option<String>,
}
/// Category assigned to a `CoherenceViolation`.
///
/// NOTE(review): only `TopicShift` and `MissingTransition` are constructed by the code
/// visible in this file; the remaining variants are presumably used elsewhere or reserved.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ViolationType {
/// Adjacent segments whose similarity is below 0.4.
TopicShift,
/// Adjacent segments with similarity in 0.4..0.6 where the later one has no transition marker.
MissingTransition,
BrokenReference,
IllogicalSequence,
IncompleteThought,
NonSequitur,
}
/// Heuristic checker for semantic consistency, contradictions and logical flow
/// across a list of text segments.
pub struct CoherenceValidator {
/// Thresholds and cache settings supplied at construction.
config: CoherenceConfig,
/// Memoised built-in embeddings, keyed by the exact input text.
embedding_cache: Arc<RwLock<HashMap<String, Vec<f32>>>>,
/// Lower-case words treated as signalling negation.
negation_patterns: Vec<String>,
/// Lower-case discourse markers treated as signalling a transition.
transition_markers: Vec<String>,
}
impl CoherenceValidator {
/// Creates a validator with the given configuration, an empty embedding cache, and
/// the built-in English negation and transition word lists.
pub fn new(config: CoherenceConfig) -> Self {
    // Words whose presence marks a segment as negated.
    const NEGATIONS: &[&str] = &[
        "not", "never", "no", "none", "neither", "nothing", "without", "isn't", "aren't",
        "wasn't", "weren't", "don't", "doesn't", "didn't", "won't", "wouldn't", "couldn't",
        "shouldn't",
    ];
    // Discourse markers that signal an explicit transition between segments.
    const TRANSITIONS: &[&str] = &[
        "however", "therefore", "furthermore", "moreover", "consequently", "thus", "hence",
        "additionally", "nonetheless", "meanwhile", "finally", "first", "second", "then",
        "next",
    ];

    let owned = |words: &[&str]| -> Vec<String> { words.iter().map(|w| w.to_string()).collect() };

    Self {
        config,
        embedding_cache: Arc::new(RwLock::new(HashMap::new())),
        negation_patterns: owned(NEGATIONS),
        transition_markers: owned(TRANSITIONS),
    }
}
/// Creates a validator that uses `CoherenceConfig::default()`.
pub fn default_config() -> Self {
    let config = CoherenceConfig::default();
    Self::new(config)
}
/// Scores how semantically similar every pair of segments is.
///
/// * `segments` - the texts to compare.
/// * `embeddings` - optional precomputed vectors; when present they are compared
///   as-is (indices in the result refer to positions in this slice), otherwise the
///   built-in hashed embedding is computed for each segment.
///
/// Returns the default (fully consistent) result for fewer than two segments.
/// Otherwise every pair `(i, j)` with `i < j` is scored by cosine similarity; a pair
/// below `config.similarity_threshold` marks both members as inconsistent. The result
/// is consistent only when no pair was flagged and the mean similarity reaches the
/// threshold.
pub fn validate_semantic_consistency(
    &self,
    segments: &[String],
    embeddings: Option<&[Vec<f32>]>,
) -> Result<SemanticConsistencyResult> {
    // Zero or one segment has no pairs to compare; both cases are the default result.
    if segments.len() < 2 {
        return Ok(SemanticConsistencyResult::default());
    }

    // Borrow caller-supplied vectors instead of deep-copying them; only the
    // fallback path needs owned storage.
    let fallback: Vec<Vec<f32>>;
    let vectors: &[Vec<f32>] = match embeddings {
        Some(provided) => provided,
        None => {
            fallback = segments
                .iter()
                .map(|s| self.compute_simple_embedding(s))
                .collect();
            fallback.as_slice()
        }
    };

    let total = vectors.len();
    let mut similarities: Vec<(usize, usize, f32)> = Vec::new();
    // `flagged` gives O(1) de-duplication while `inconsistent` keeps discovery order.
    let mut flagged = vec![false; total];
    let mut inconsistent: Vec<usize> = Vec::new();

    for i in 0..total {
        for j in (i + 1)..total {
            let sim = cosine_similarity(&vectors[i], &vectors[j]);
            similarities.push((i, j, sim));
            if sim < self.config.similarity_threshold {
                for idx in [i, j] {
                    if !flagged[idx] {
                        flagged[idx] = true;
                        inconsistent.push(idx);
                    }
                }
            }
        }
    }

    let all_sims: Vec<f32> = similarities.iter().map(|&(_, _, s)| s).collect();
    // An empty pair list (e.g. fewer than two supplied embeddings) counts as perfect.
    let avg = if all_sims.is_empty() {
        1.0
    } else {
        all_sims.iter().sum::<f32>() / all_sims.len() as f32
    };
    let std_dev = compute_std_dev(&all_sims, avg);

    let consistency_score = avg;
    let is_consistent =
        inconsistent.is_empty() && consistency_score >= self.config.similarity_threshold;

    Ok(SemanticConsistencyResult {
        is_consistent,
        consistency_score,
        segment_similarities: similarities,
        inconsistent_segments: inconsistent,
        average_similarity: avg,
        similarity_std_dev: std_dev,
    })
}
/// Looks for contradictions between every pair of segments.
///
/// * `segments` - the texts to compare.
/// * `embeddings` - optional precomputed vectors, expected one per segment. Segments
///   without a matching vector are skipped by the embedding-based check instead of
///   causing an out-of-bounds panic.
///
/// Three checks run in order and their findings are appended in that order:
/// negation-based, numeric, and (only when `embeddings` is given) embedding-based.
/// The embedding check reports a pair when exactly one side contains a negation and
/// their cosine similarity is below `config.contradiction_threshold`.
///
/// Returns the default (empty) result for fewer than two segments.
pub fn detect_contradictions(
    &self,
    segments: &[String],
    embeddings: Option<&[Vec<f32>]>,
) -> Result<ContradictionResult> {
    let count = segments.len();
    if count < 2 {
        return Ok(ContradictionResult::default());
    }

    let mut contradictions = Vec::new();

    // Pass 1: one side negated, the other not, with shared content words.
    for i in 0..count {
        for j in (i + 1)..count {
            if let Some(found) =
                self.check_negation_contradiction(i, j, &segments[i], &segments[j])
            {
                contradictions.push(found);
            }
        }
    }

    // Pass 2: similar wording but very different numbers.
    for i in 0..count {
        for j in (i + 1)..count {
            if let Some(found) =
                self.check_numeric_contradiction(i, j, &segments[i], &segments[j])
            {
                contradictions.push(found);
            }
        }
    }

    // Pass 3: dissimilar embeddings combined with a negation mismatch.
    if let Some(emb) = embeddings {
        // Only indices that have both a segment and an embedding can be compared.
        let covered = count.min(emb.len());
        // Negation status is per-segment, so compute it once rather than per pair.
        let negated: Vec<bool> = segments[..covered]
            .iter()
            .map(|s| self.contains_negation(s))
            .collect();
        for i in 0..covered {
            for j in (i + 1)..covered {
                if negated[i] == negated[j] {
                    continue;
                }
                let sim = cosine_similarity(&emb[i], &emb[j]);
                if sim < self.config.contradiction_threshold {
                    contradictions.push(Contradiction {
                        segment_a: i,
                        segment_b: j,
                        text_a: segments[i].clone(),
                        text_b: segments[j].clone(),
                        severity: 1.0 - sim,
                        contradiction_type: ContradictionType::Logical,
                        explanation: "Semantic analysis suggests contradiction".to_string(),
                    });
                }
            }
        }
    }

    let has_contradictions = !contradictions.is_empty();
    let contradiction_count = contradictions.len();
    // `count >= 2` here, so there is always at least one pair to divide by. A pair
    // can be reported by more than one check, hence the cap at 1.0.
    let max_pairs = count * (count - 1) / 2;
    let contradiction_score = (contradiction_count as f32 / max_pairs as f32).min(1.0);

    Ok(ContradictionResult {
        has_contradictions,
        contradiction_count,
        contradictions,
        contradiction_score,
    })
}
/// Scores how smoothly each segment leads into the next.
///
/// * `segments` - the texts, in reading order.
/// * `embeddings` - optional precomputed vectors. They are used when there is at
///   least one per segment; a shorter slice is ignored and the built-in hashed
///   embedding is computed instead, rather than indexing out of bounds.
///
/// For each adjacent pair the cosine similarity is recorded. Similarity below 0.4 is
/// reported as a `TopicShift` (with a suggestion); similarity in 0.4..0.6 is reported
/// as a `MissingTransition` when the later segment has no transition marker.
///
/// The final score is `mean similarity - 0.5 * (sum of severities / segment count)
/// + 0.3 * (share of segments containing a transition marker)`, clamped to 0.0..=1.0,
/// and is compared against `config.logical_flow_threshold`.
///
/// Returns the default (perfect flow) result for fewer than two segments.
pub fn check_logical_flow(
    &self,
    segments: &[String],
    embeddings: Option<&[Vec<f32>]>,
) -> Result<LogicalFlowResult> {
    // Adjacent similarity below this is reported as an abrupt topic shift.
    const TOPIC_SHIFT_BELOW: f32 = 0.4;
    // Adjacent similarity below this (but not a topic shift) calls for a transition word.
    const WEAK_LINK_BELOW: f32 = 0.6;

    let count = segments.len();
    if count < 2 {
        return Ok(LogicalFlowResult::default());
    }

    // Borrow usable caller embeddings; fall back to computed ones otherwise.
    let fallback: Vec<Vec<f32>>;
    let vectors: &[Vec<f32>] = match embeddings {
        Some(provided) if provided.len() >= count => &provided[..count],
        _ => {
            fallback = segments
                .iter()
                .map(|s| self.compute_simple_embedding(s))
                .collect();
            fallback.as_slice()
        }
    };

    let mut violations = Vec::new();
    let mut transition_scores = Vec::with_capacity(count - 1);
    let mut suggestions = Vec::new();

    for (i, pair) in vectors.windows(2).enumerate() {
        let next = i + 1;
        let sim = cosine_similarity(&pair[0], &pair[1]);
        transition_scores.push(sim);

        if sim < TOPIC_SHIFT_BELOW {
            violations.push(CoherenceViolation {
                segment_index: next,
                violation_type: ViolationType::TopicShift,
                severity: 1.0 - sim,
                description: format!("Abrupt topic shift between segments {} and {}", i, next),
                suggestion: Some("Add a transition sentence".to_string()),
            });
            suggestions.push(format!(
                "Consider adding a transition between segments {} and {}",
                i, next
            ));
        } else if sim < WEAK_LINK_BELOW && !self.has_transition_marker(&segments[next]) {
            violations.push(CoherenceViolation {
                segment_index: next,
                violation_type: ViolationType::MissingTransition,
                severity: 0.3,
                description: format!("Missing transition marker at segment {}", next),
                suggestion: Some("Add a transition word".to_string()),
            });
        }
    }

    // `count >= 2` guarantees at least one transition, so these divisions are safe.
    let avg_transition =
        transition_scores.iter().sum::<f32>() / transition_scores.len() as f32;
    let violation_penalty =
        violations.iter().map(|v| v.severity).sum::<f32>() / count as f32;
    let marker_hits = segments
        .iter()
        .filter(|s| self.has_transition_marker(s))
        .count() as f32;
    let marker_bonus = (marker_hits / count as f32) * 0.3;

    let flow_score = (avg_transition - violation_penalty * 0.5 + marker_bonus).clamp(0.0, 1.0);
    let has_logical_flow = flow_score >= self.config.logical_flow_threshold;

    Ok(LogicalFlowResult {
        has_logical_flow,
        flow_score,
        violations,
        transition_scores,
        suggestions,
    })
}
/// Builds a cheap, deterministic embedding for `text` by feature hashing.
///
/// Each lower-cased, whitespace-separated word adds 1.0 to the bucket chosen by its
/// 64-bit FNV-1a hash, and each adjacent byte pair inside the word adds 0.5 to the
/// bucket chosen by that pair's hash. The vector is then scaled to unit length
/// (left all-zero when the text has no words).
///
/// Returns a vector of `config.embedding_dim` entries, or an empty vector when the
/// dimension is zero (which previously caused a modulo-by-zero panic). When caching
/// is enabled the result is served from and stored in the cache, keyed by the exact
/// text; once the cache holds `config.max_cache_size` entries nothing new is stored.
fn compute_simple_embedding(&self, text: &str) -> Vec<f32> {
    // 64-bit FNV-1a. The state is `u64` rather than `usize` so the constants also
    // fit on 32-bit targets.
    fn fnv1a(bytes: &[u8]) -> u64 {
        const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const PRIME: u64 = 0x100_0000_01b3;
        bytes
            .iter()
            .fold(OFFSET_BASIS, |h, &b| (h ^ u64::from(b)).wrapping_mul(PRIME))
    }

    let dim = self.config.embedding_dim;
    if dim == 0 {
        // There are no buckets to hash into.
        return Vec::new();
    }
    // The remainder is always below `dim`, so narrowing back to `usize` is lossless.
    let bucket = |bytes: &[u8]| (fnv1a(bytes) % dim as u64) as usize;

    if self.config.enable_caching {
        if let Some(hit) = self.embedding_cache.read().get(text) {
            return hit.clone();
        }
    }

    let mut embedding = vec![0.0f32; dim];
    let lowered = text.to_lowercase();
    for word in lowered.split_whitespace() {
        let bytes = word.as_bytes();
        embedding[bucket(bytes)] += 1.0;
        // Byte bigrams let words with shared fragments overlap partially.
        for pair in bytes.windows(2) {
            embedding[bucket(pair)] += 0.5;
        }
    }

    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for val in &mut embedding {
            *val /= norm;
        }
    }

    if self.config.enable_caching {
        let mut cache = self.embedding_cache.write();
        // There is no eviction: a full cache simply stops accepting entries.
        if cache.len() < self.config.max_cache_size {
            cache.insert(text.to_string(), embedding.clone());
        }
    }
    embedding
}
/// Reports whether `text` contains a negation word.
///
/// Matching is done on whole, lower-cased words rather than substrings, so "know",
/// "another" or "economy" no longer count as negations merely because they contain
/// "no" or "not". A word counts when it equals one of `negation_patterns`, ends in
/// "n't", or is "cannot" (the last two keep contractions and "cannot" detected, which
/// substring matching used to catch only incidentally).
fn contains_negation(&self, text: &str) -> bool {
    // Fold the typographic apostrophe so "isn’t" matches the pattern "isn't".
    let normalized = text.to_lowercase().replace('\u{2019}', "'");
    normalized
        // Apostrophes stay inside words so contractions survive tokenisation.
        .split(|c: char| !(c.is_alphanumeric() || c == '\''))
        // Drop quote marks wrapped around a word, e.g. 'not'.
        .map(|token| token.trim_matches('\''))
        .any(|token| {
            token.ends_with("n't")
                || token == "cannot"
                || self.negation_patterns.iter().any(|p| p.as_str() == token)
        })
}
/// Flags a pair of segments when exactly one of them is negated and they share at
/// least two distinct content words.
///
/// * `idx_a`, `idx_b` - positions of the two segments, copied into the result.
/// * `text_a`, `text_b` - the segment texts.
///
/// Content words are lower-cased, stripped of surrounding punctuation, longer than
/// three bytes, and not themselves negation words. Each shared word is counted once,
/// so a single repeated word can no longer satisfy the two-word requirement, and
/// "running." matches "running".
///
/// Returns `None` when both or neither segment is negated, or when fewer than two
/// content words are shared.
fn check_negation_contradiction(
    &self,
    idx_a: usize,
    idx_b: usize,
    text_a: &str,
    text_b: &str,
) -> Option<Contradiction> {
    // Distinct content words of an already lower-cased text, in first-seen order.
    fn content_words<'t>(lowered: &'t str, negations: &[String]) -> Vec<&'t str> {
        let mut words: Vec<&'t str> = Vec::new();
        for raw in lowered.split_whitespace() {
            let word = raw.trim_matches(|c: char| !c.is_alphanumeric());
            let is_negation = negations.iter().any(|n| n.as_str() == word);
            if word.len() > 3 && !is_negation && !words.contains(&word) {
                words.push(word);
            }
        }
        words
    }

    // A contradiction needs one negated and one non-negated statement.
    if self.contains_negation(text_a) == self.contains_negation(text_b) {
        return None;
    }

    let lowered_a = text_a.to_lowercase();
    let lowered_b = text_b.to_lowercase();
    let words_a = content_words(&lowered_a, &self.negation_patterns);
    let words_b: std::collections::HashSet<&str> =
        content_words(&lowered_b, &self.negation_patterns)
            .into_iter()
            .collect();

    let shared: Vec<&str> = words_a
        .into_iter()
        .filter(|w| words_b.contains(w))
        .collect();
    if shared.len() < 2 {
        return None;
    }

    Some(Contradiction {
        segment_a: idx_a,
        segment_b: idx_b,
        text_a: text_a.to_string(),
        text_b: text_b.to_string(),
        severity: 0.7,
        contradiction_type: ContradictionType::Logical,
        explanation: format!(
            "Possible negation contradiction on topics: {}",
            shared.join(", ")
        ),
    })
}
/// Flags a pair of segments that read almost the same but cite very different numbers.
///
/// * `idx_a`, `idx_b` - positions of the two segments, copied into the result.
/// * `text_a`, `text_b` - the segment texts.
///
/// A contradiction is reported only when each text contains exactly one number, the
/// numbers differ by more than half of the larger magnitude, and the remaining words
/// (digits and periods removed) have a Jaccard similarity above 0.5.
fn check_numeric_contradiction(
    &self,
    idx_a: usize,
    idx_b: usize,
    text_a: &str,
    text_b: &str,
) -> Option<Contradiction> {
    let found_a = extract_numbers(text_a);
    let found_b = extract_numbers(text_b);

    // Only the unambiguous one-number-each case is compared.
    let (num_a, num_b) = match (found_a.as_slice(), found_b.as_slice()) {
        ([a], [b]) => (*a, *b),
        _ => return None,
    };

    let gap = (num_a - num_b).abs();
    let scale = num_a.abs().max(num_b.abs());
    if !(scale > 0.0 && gap / scale > 0.5) {
        return None;
    }

    // Compare the wording with the numbers themselves taken out.
    let without_numbers =
        |text: &str| -> String { text.chars().filter(|c| !(c.is_numeric() || *c == '.')).collect() };
    let wording_overlap =
        jaccard_similarity(&without_numbers(text_a), &without_numbers(text_b));

    (wording_overlap > 0.5).then(|| Contradiction {
        segment_a: idx_a,
        segment_b: idx_b,
        text_a: text_a.to_string(),
        text_b: text_b.to_string(),
        severity: 0.6,
        contradiction_type: ContradictionType::Numeric,
        explanation: format!("Numeric inconsistency: {} vs {}", num_a, num_b),
    })
}
/// Reports whether `text` contains a transition marker.
///
/// Matching is done on whole, lower-cased words rather than substrings, so
/// "strengthen" no longer counts because it contains "then", nor "enthusiasm" because
/// it contains "thus". A word counts when it equals one of `transition_markers`, or
/// equals a marker followed by "ly" (which keeps "firstly" and "secondly" detected).
fn has_transition_marker(&self, text: &str) -> bool {
    let lowered = text.to_lowercase();
    lowered
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        .any(|token| {
            let stem = token.strip_suffix("ly");
            self.transition_markers
                .iter()
                .any(|m| m.as_str() == token || stem == Some(m.as_str()))
        })
}
/// Removes every memoised embedding from the cache.
pub fn clear_cache(&self) {
    self.embedding_cache.write().clear();
}
}
/// Cosine of the angle between two vectors.
///
/// Returns 0.0 when the slices are empty, differ in length, or either has zero
/// magnitude; otherwise the dot product divided by the product of the magnitudes.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean length of a vector.
    fn magnitude(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (len_a, len_b) = (magnitude(a), magnitude(b));
    if len_a == 0.0 || len_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (len_a * len_b)
}
/// Sample standard deviation of `values` around the supplied `mean`.
///
/// Uses the `n - 1` divisor and returns 0.0 when fewer than two values are given.
fn compute_std_dev(values: &[f32], mean: f32) -> f32 {
    let n = values.len();
    if n < 2 {
        return 0.0;
    }
    let sum_of_squares = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>();
    let degrees_of_freedom = (n - 1) as f32;
    (sum_of_squares / degrees_of_freedom).sqrt()
}
/// Extracts the decimal numbers that appear in `text`, in order.
///
/// A number is a run of ASCII digits and periods, optionally starting with `-`.
/// Trailing periods are treated as sentence punctuation and ignored, so "3.14." yields
/// 3.14 instead of being dropped. Runs that still fail to parse as `f64` (for example
/// "1.2.3" or a lone "-") are skipped. Non-ASCII numerals such as "²" act as
/// separators, because `f64` parsing does not accept them.
fn extract_numbers(text: &str) -> Vec<f64> {
    // Parses the pending run, if it is a valid number, and resets the buffer.
    fn flush(pending: &mut String, out: &mut Vec<f64>) {
        let candidate = pending.trim_end_matches('.');
        if let Ok(value) = candidate.parse::<f64>() {
            out.push(value);
        }
        pending.clear();
    }

    let mut numbers = Vec::new();
    let mut current = String::new();
    for c in text.chars() {
        // A minus sign only belongs to a number when it starts the run.
        let leading_minus = c == '-' && current.is_empty();
        if c.is_ascii_digit() || c == '.' || leading_minus {
            current.push(c);
        } else if !current.is_empty() {
            flush(&mut current, &mut numbers);
        }
    }
    if !current.is_empty() {
        flush(&mut current, &mut numbers);
    }
    numbers
}
/// Jaccard similarity of the whitespace-separated word sets of `a` and `b`.
///
/// Returns the number of distinct shared words divided by the number of distinct
/// words overall, or 0.0 when neither text contains a word. Comparison is
/// case-sensitive.
fn jaccard_similarity(a: &str, b: &str) -> f32 {
    use std::collections::HashSet;

    let vocab_a: HashSet<&str> = a.split_whitespace().collect();
    let vocab_b: HashSet<&str> = b.split_whitespace().collect();

    let shared = vocab_a.iter().filter(|w| vocab_b.contains(*w)).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|
    let combined = vocab_a.len() + vocab_b.len() - shared;

    if combined == 0 {
        0.0
    } else {
        shared as f32 / combined as f32
    }
}
// Unit tests for the validator's public entry points and the private helpers.
#[cfg(test)]
mod tests {
use super::*;
// A single segment has nothing to compare against, so it is trivially consistent.
#[test]
fn test_semantic_consistency_single_segment() {
let validator = CoherenceValidator::default_config();
let segments = vec!["This is a test.".to_string()];
let result = validator
.validate_semantic_consistency(&segments, None)
.unwrap();
assert!(result.is_consistent);
assert_eq!(result.consistency_score, 1.0);
}
// Two near-paraphrases should score well with the built-in hashed embedding.
#[test]
fn test_semantic_consistency_similar_segments() {
let validator = CoherenceValidator::default_config();
let segments = vec![
"The cat sat on the mat.".to_string(),
"The cat was sitting on the mat.".to_string(),
];
let result = validator
.validate_semantic_consistency(&segments, None)
.unwrap();
assert!(result.consistency_score > 0.5);
}
// Same statement with and without "not" must be reported as a contradiction.
#[test]
fn test_contradiction_detection_negation() {
let validator = CoherenceValidator::default_config();
let segments = vec![
"The system is running properly.".to_string(),
"The system is not running properly.".to_string(),
];
let result = validator.detect_contradictions(&segments, None).unwrap();
assert!(result.has_contradictions);
assert!(result.contradiction_count > 0);
}
// Identical wording with widely different numbers must be reported.
#[test]
fn test_contradiction_detection_numeric() {
let validator = CoherenceValidator::default_config();
let segments = vec![
"The temperature was 25 degrees.".to_string(),
"The temperature was 75 degrees.".to_string(),
];
let result = validator.detect_contradictions(&segments, None).unwrap();
assert!(result.has_contradictions);
}
// Three segments yield transition scores and a positive flow score.
#[test]
fn test_logical_flow() {
let validator = CoherenceValidator::default_config();
let segments = vec![
"First, we need to analyze the data.".to_string(),
"Then, we process the results.".to_string(),
"Finally, we generate the report.".to_string(),
];
let result = validator.check_logical_flow(&segments, None).unwrap();
assert!(result.flow_score > 0.0);
assert!(!result.transition_scores.is_empty());
}
// Identical vectors score 1, orthogonal vectors score 0.
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
let c = vec![0.0, 1.0, 0.0];
assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.001);
}
// Decimals and negative integers are both recognised.
#[test]
fn test_extract_numbers() {
let numbers = extract_numbers("The value is 42.5 and -10");
assert_eq!(numbers.len(), 2);
assert!((numbers[0] - 42.5).abs() < 0.001);
assert!((numbers[1] - (-10.0)).abs() < 0.001);
}
// Two of three distinct words are shared, so the similarity is 2/3.
#[test]
fn test_jaccard_similarity() {
let sim = jaccard_similarity("hello world", "hello there world");
assert!(sim > 0.5);
}
// Embedding the same text twice must not break the cache, and clearing empties it.
#[test]
fn test_cache_operations() {
let validator = CoherenceValidator::new(CoherenceConfig {
enable_caching: true,
..Default::default()
});
let _ = validator.compute_simple_embedding("test text");
let _ = validator.compute_simple_embedding("test text");
validator.clear_cache();
let cache = validator.embedding_cache.read();
assert!(cache.is_empty());
}
}