use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Tuning parameters for [`DiversityAnalyzer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiversityConfig {
    /// Combined diversity score below which diversification is suggested.
    pub min_diversity: f32,
    /// Average pairwise cosine similarity above which mode collapse is flagged.
    pub mode_collapse_threshold: f32,
    /// Dimensionality of the internal fallback hash embeddings.
    pub embedding_dim: usize,
    /// N-gram size used when computing the unique-n-gram ratio.
    pub ngram_size: usize,
    /// Maximum number of samples retained in the rolling history.
    pub window_size: usize,
    /// When true and no embeddings are supplied, cheap internal embeddings are
    /// computed for semantic scoring; when false a neutral 0.5 score is used.
    pub semantic_diversity: bool,
    /// Weight of the lexical score in the combined diversity score.
    pub lexical_weight: f32,
    /// Weight of the semantic score in the combined diversity score.
    pub semantic_weight: f32,
}
impl Default for DiversityConfig {
fn default() -> Self {
Self {
min_diversity: 0.5,
mode_collapse_threshold: 0.9,
embedding_dim: 768,
ngram_size: 3,
window_size: 100,
semantic_diversity: true,
lexical_weight: 0.4,
semantic_weight: 0.6,
}
}
}
/// Metrics produced by `DiversityAnalyzer::calculate_diversity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiversityResult {
    /// Weighted blend of lexical and semantic diversity (per config weights).
    pub diversity_score: f32,
    /// Lexical (token-statistics) diversity score.
    pub lexical_diversity: f32,
    /// Semantic (embedding-geometry) diversity score.
    pub semantic_diversity: f32,
    /// Unique tokens / total tokens, over lowercased whitespace tokens.
    pub type_token_ratio: f32,
    /// Unique n-grams / total n-grams.
    pub unique_ngram_ratio: f32,
    /// Variance of embedding components around the mean embedding.
    pub embedding_variance: f32,
    /// Count of distinct (lowercased) tokens.
    pub unique_tokens: usize,
    /// Count of all tokens.
    pub total_tokens: usize,
    /// Per-category scores; always left empty in this file — presumably
    /// filled in by callers elsewhere. TODO confirm.
    pub category_diversity: HashMap<String, f32>,
}
impl Default for DiversityResult {
fn default() -> Self {
Self {
diversity_score: 0.0,
lexical_diversity: 0.0,
semantic_diversity: 0.0,
type_token_ratio: 0.0,
unique_ngram_ratio: 0.0,
embedding_variance: 0.0,
unique_tokens: 0,
total_tokens: 0,
category_diversity: HashMap::new(),
}
}
}
/// Outcome of `DiversityAnalyzer::detect_mode_collapse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModeCollapseResult {
    /// True when any collapse signal fired (high similarity, dominant
    /// cluster, or many repeated patterns).
    pub has_mode_collapse: bool,
    /// Blended severity in [0, 1]; 0.0 when no collapse was detected.
    pub collapse_severity: f32,
    /// Mean pairwise cosine similarity across sample embeddings.
    pub average_similarity: f32,
    /// Estimated fraction of samples in the dominant cluster.
    pub dominant_cluster_percentage: f32,
    /// N-grams repeated across (or within) samples.
    pub repeated_patterns: Vec<RepeatedPattern>,
    /// Human-readable explanation of the result.
    pub diagnosis: String,
}
impl Default for ModeCollapseResult {
fn default() -> Self {
Self {
has_mode_collapse: false,
collapse_severity: 0.0,
average_similarity: 0.0,
dominant_cluster_percentage: 0.0,
repeated_patterns: Vec::new(),
diagnosis: "No mode collapse detected".to_string(),
}
}
}
/// An n-gram (3–5 tokens in this file) observed at least twice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RepeatedPattern {
    /// The repeated token sequence, space-joined.
    pub pattern: String,
    /// Number of recorded occurrences (equals `occurrences.len()`).
    pub count: usize,
    /// Sample indices where the pattern occurred; an index may repeat when
    /// the pattern occurs multiple times within one sample.
    pub occurrences: Vec<usize>,
}
/// A single piece of advice for increasing generation diversity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiversificationSuggestion {
    /// Which generation knob to adjust.
    pub suggestion_type: SuggestionType,
    /// Human-readable description of the action.
    pub message: String,
    /// Urgency; higher is more important (1–3 in this file).
    pub priority: u8,
    /// Suggested parameter names and value ranges,
    /// e.g. "temperature" -> "1.0-1.5".
    pub parameters: HashMap<String, String>,
}
/// Kinds of diversification actions that can be suggested.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SuggestionType {
    IncreaseTemperature,
    AdjustTopP,
    AdjustTopK,
    DiverseBeamSearch,
    PromptVariation,
    SeedVariation,
    RepetitionPenalty,
    // NOTE(review): the two variants below are never constructed in this
    // file — presumably emitted by external callers; confirm before removal.
    NucleusSampling,
    EmbeddingNoise,
}
/// Analyzes collections of generated text for diversity and mode collapse.
///
/// Rolling state is kept behind `Arc<RwLock<...>>`, so a shared analyzer can
/// be used from multiple threads.
pub struct DiversityAnalyzer {
    // Immutable analysis parameters.
    config: DiversityConfig,
    // Rolling window of recent samples, bounded by `config.window_size`.
    history: Arc<RwLock<Vec<HistorySample>>>,
    // NOTE(review): initialized in `new` but never read or written elsewhere
    // in this file — looks like dead state or a hook for future n-gram
    // caching; confirm before removing.
    ngram_cache: Arc<RwLock<HashMap<String, HashSet<String>>>>,
}
/// One retained sample in the rolling history window.
#[derive(Clone)]
struct HistorySample {
    // Raw sample text.
    text: String,
    // Caller-supplied embedding, if any; rolling diversity only uses stored
    // embeddings when every sample in the window has one.
    embedding: Option<Vec<f32>>,
    // NOTE(review): recorded but never read in this file — possibly intended
    // for age-based eviction; confirm.
    timestamp: std::time::Instant,
}
impl DiversityAnalyzer {
pub fn new(config: DiversityConfig) -> Self {
Self {
config,
history: Arc::new(RwLock::new(Vec::new())),
ngram_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn default_config() -> Self {
Self::new(DiversityConfig::default())
}
pub fn calculate_diversity(
&self,
samples: &[String],
embeddings: Option<&[Vec<f32>]>,
) -> DiversityResult {
if samples.is_empty() {
return DiversityResult::default();
}
let lexical = self.calculate_lexical_diversity(samples);
let semantic = if let Some(emb) = embeddings {
self.calculate_semantic_diversity(emb)
} else if self.config.semantic_diversity {
let simple_emb: Vec<Vec<f32>> = samples
.iter()
.map(|s| self.compute_simple_embedding(s))
.collect();
self.calculate_semantic_diversity(&simple_emb)
} else {
SemanticDiversityResult {
diversity_score: 0.5,
variance: 0.0,
average_distance: 0.0,
}
};
let (ttr, unique, total) = self.calculate_type_token_ratio(samples);
let ngram_ratio = self.calculate_ngram_diversity(samples);
let diversity_score = self.config.lexical_weight * lexical.diversity_score
+ self.config.semantic_weight * semantic.diversity_score;
DiversityResult {
diversity_score,
lexical_diversity: lexical.diversity_score,
semantic_diversity: semantic.diversity_score,
type_token_ratio: ttr,
unique_ngram_ratio: ngram_ratio,
embedding_variance: semantic.variance,
unique_tokens: unique,
total_tokens: total,
category_diversity: HashMap::new(),
}
}
pub fn detect_mode_collapse(
&self,
samples: &[String],
embeddings: Option<&[Vec<f32>]>,
) -> ModeCollapseResult {
if samples.len() < 2 {
return ModeCollapseResult::default();
}
let emb = match embeddings {
Some(e) => e.to_vec(),
None => samples
.iter()
.map(|s| self.compute_simple_embedding(s))
.collect(),
};
let mut total_sim = 0.0;
let mut count = 0;
for i in 0..emb.len() {
for j in (i + 1)..emb.len() {
total_sim += cosine_similarity(&emb[i], &emb[j]);
count += 1;
}
}
let avg_similarity = if count > 0 {
total_sim / count as f32
} else {
0.0
};
let repeated_patterns = self.find_repeated_patterns(samples);
let dominant_percentage = self.estimate_dominant_cluster(&emb);
let has_collapse = avg_similarity > self.config.mode_collapse_threshold
|| dominant_percentage > 0.7
|| repeated_patterns.len() > samples.len() / 4;
let collapse_severity = if has_collapse {
((avg_similarity - self.config.mode_collapse_threshold)
/ (1.0 - self.config.mode_collapse_threshold))
.clamp(0.0, 1.0)
* 0.5
+ dominant_percentage * 0.3
+ (repeated_patterns.len() as f32 / samples.len() as f32).min(1.0) * 0.2
} else {
0.0
};
let diagnosis = if has_collapse {
if avg_similarity > self.config.mode_collapse_threshold {
format!(
"High similarity detected (avg: {:.2}). Samples are too similar.",
avg_similarity
)
} else if dominant_percentage > 0.7 {
format!(
"Dominant cluster contains {:.0}% of samples.",
dominant_percentage * 100.0
)
} else {
format!(
"Found {} repeated patterns indicating lack of diversity.",
repeated_patterns.len()
)
}
} else {
"No mode collapse detected".to_string()
};
ModeCollapseResult {
has_mode_collapse: has_collapse,
collapse_severity,
average_similarity: avg_similarity,
dominant_cluster_percentage: dominant_percentage,
repeated_patterns,
diagnosis,
}
}
pub fn suggest_diversification(
&self,
diversity_result: &DiversityResult,
mode_collapse: Option<&ModeCollapseResult>,
) -> Vec<DiversificationSuggestion> {
let mut suggestions = Vec::new();
if diversity_result.diversity_score < self.config.min_diversity {
suggestions.push(DiversificationSuggestion {
suggestion_type: SuggestionType::IncreaseTemperature,
message: "Increase temperature parameter to add more randomness".to_string(),
priority: 3,
parameters: [("temperature".to_string(), "1.0-1.5".to_string())]
.into_iter()
.collect(),
});
}
if diversity_result.lexical_diversity < 0.4 {
suggestions.push(DiversificationSuggestion {
suggestion_type: SuggestionType::RepetitionPenalty,
message: "Apply repetition penalty to avoid repeated phrases".to_string(),
priority: 2,
parameters: [("repetition_penalty".to_string(), "1.1-1.3".to_string())]
.into_iter()
.collect(),
});
}
if diversity_result.semantic_diversity < 0.4 {
suggestions.push(DiversificationSuggestion {
suggestion_type: SuggestionType::DiverseBeamSearch,
message: "Use diverse beam search for more varied outputs".to_string(),
priority: 2,
parameters: [
("num_beam_groups".to_string(), "4".to_string()),
("diversity_penalty".to_string(), "0.5".to_string()),
]
.into_iter()
.collect(),
});
}
if let Some(collapse) = mode_collapse {
if collapse.has_mode_collapse {
suggestions.push(DiversificationSuggestion {
suggestion_type: SuggestionType::SeedVariation,
message: "Use different random seeds for each generation".to_string(),
priority: 3,
parameters: HashMap::new(),
});
suggestions.push(DiversificationSuggestion {
suggestion_type: SuggestionType::AdjustTopP,
message: "Adjust top-p (nucleus) sampling parameter".to_string(),
priority: 2,
parameters: [("top_p".to_string(), "0.9-0.95".to_string())]
.into_iter()
.collect(),
});
if collapse.collapse_severity > 0.5 {
suggestions.push(DiversificationSuggestion {
suggestion_type: SuggestionType::PromptVariation,
message: "Add variations to input prompts".to_string(),
priority: 3,
parameters: HashMap::new(),
});
}
}
}
if diversity_result.type_token_ratio < 0.3 {
suggestions.push(DiversificationSuggestion {
suggestion_type: SuggestionType::AdjustTopK,
message: "Increase top-k to sample from larger vocabulary".to_string(),
priority: 1,
parameters: [("top_k".to_string(), "50-100".to_string())]
.into_iter()
.collect(),
});
}
suggestions.sort_by_key(|b| std::cmp::Reverse(b.priority));
suggestions
}
fn calculate_lexical_diversity(&self, samples: &[String]) -> LexicalDiversityResult {
let mut all_tokens = Vec::new();
let mut all_bigrams = HashSet::new();
let mut all_trigrams = HashSet::new();
for sample in samples {
let tokens: Vec<&str> = sample.split_whitespace().collect();
all_tokens.extend(tokens.iter().map(|s| s.to_lowercase()));
for i in 0..tokens.len() {
if i + 1 < tokens.len() {
all_bigrams.insert(format!("{} {}", tokens[i], tokens[i + 1]));
}
if i + 2 < tokens.len() {
all_trigrams.insert(format!(
"{} {} {}",
tokens[i],
tokens[i + 1],
tokens[i + 2]
));
}
}
}
let unique_tokens: HashSet<String> = all_tokens.iter().cloned().collect();
let ttr = if all_tokens.is_empty() {
0.0
} else {
unique_tokens.len() as f32 / all_tokens.len() as f32
};
let mut token_counts: HashMap<String, usize> = HashMap::new();
for token in &all_tokens {
*token_counts.entry(token.clone()).or_insert(0) += 1;
}
let hapax_count = token_counts.values().filter(|&&c| c == 1).count();
let hapax_ratio = if unique_tokens.is_empty() {
0.0
} else {
hapax_count as f32 / unique_tokens.len() as f32
};
let diversity_score = (ttr * 0.4 + hapax_ratio * 0.3 + 0.3).min(1.0);
LexicalDiversityResult {
diversity_score,
ttr,
hapax_ratio,
unique_bigrams: all_bigrams.len(),
unique_trigrams: all_trigrams.len(),
}
}
fn calculate_semantic_diversity(&self, embeddings: &[Vec<f32>]) -> SemanticDiversityResult {
if embeddings.is_empty() {
return SemanticDiversityResult::default();
}
let dim = embeddings[0].len();
let n = embeddings.len() as f32;
let mut mean = vec![0.0f32; dim];
for emb in embeddings {
for (i, val) in emb.iter().enumerate() {
mean[i] += val / n;
}
}
let mut variance = 0.0f32;
for emb in embeddings {
for (i, val) in emb.iter().enumerate() {
variance += (val - mean[i]).powi(2);
}
}
variance /= n * dim as f32;
let mut total_distance = 0.0;
let mut count = 0;
for i in 0..embeddings.len() {
for j in (i + 1)..embeddings.len() {
let sim = cosine_similarity(&embeddings[i], &embeddings[j]);
total_distance += 1.0 - sim; count += 1;
}
}
let avg_distance = if count > 0 {
total_distance / count as f32
} else {
0.0
};
let diversity_score = (avg_distance * 0.6 + variance.sqrt() * 0.4).min(1.0);
SemanticDiversityResult {
diversity_score,
variance,
average_distance: avg_distance,
}
}
fn calculate_type_token_ratio(&self, samples: &[String]) -> (f32, usize, usize) {
let mut all_tokens = Vec::new();
for sample in samples {
let tokens: Vec<String> = sample
.split_whitespace()
.map(|s| s.to_lowercase())
.collect();
all_tokens.extend(tokens);
}
let unique: HashSet<String> = all_tokens.iter().cloned().collect();
let unique_count = unique.len();
let total_count = all_tokens.len();
let ttr = if total_count == 0 {
0.0
} else {
unique_count as f32 / total_count as f32
};
(ttr, unique_count, total_count)
}
fn calculate_ngram_diversity(&self, samples: &[String]) -> f32 {
let mut all_ngrams = HashSet::new();
let mut total_ngrams = 0;
for sample in samples {
let tokens: Vec<&str> = sample.split_whitespace().collect();
for i in 0..tokens.len().saturating_sub(self.config.ngram_size - 1) {
let ngram: String = tokens[i..i + self.config.ngram_size].join(" ");
all_ngrams.insert(ngram);
total_ngrams += 1;
}
}
if total_ngrams == 0 {
return 0.0;
}
all_ngrams.len() as f32 / total_ngrams as f32
}
fn find_repeated_patterns(&self, samples: &[String]) -> Vec<RepeatedPattern> {
let mut patterns: HashMap<String, Vec<usize>> = HashMap::new();
for (idx, sample) in samples.iter().enumerate() {
let tokens: Vec<&str> = sample.split_whitespace().collect();
for n in 3..=5 {
for i in 0..tokens.len().saturating_sub(n - 1) {
let ngram: String = tokens[i..i + n].join(" ");
patterns.entry(ngram).or_default().push(idx);
}
}
}
patterns
.into_iter()
.filter(|(_, indices)| indices.len() >= 2)
.map(|(pattern, occurrences)| RepeatedPattern {
pattern,
count: occurrences.len(),
occurrences,
})
.collect()
}
fn estimate_dominant_cluster(&self, embeddings: &[Vec<f32>]) -> f32 {
if embeddings.len() < 3 {
return 1.0;
}
let dim = embeddings[0].len();
let n = embeddings.len() as f32;
let mut centroid = vec![0.0f32; dim];
for emb in embeddings {
for (i, val) in emb.iter().enumerate() {
centroid[i] += val / n;
}
}
let threshold = 0.8;
let close_count = embeddings
.iter()
.filter(|emb| cosine_similarity(emb, ¢roid) > threshold)
.count();
close_count as f32 / embeddings.len() as f32
}
fn compute_simple_embedding(&self, text: &str) -> Vec<f32> {
let mut embedding = vec![0.0f32; self.config.embedding_dim];
let text_lower = text.to_lowercase();
let words: Vec<&str> = text_lower.split_whitespace().collect();
for (i, word) in words.iter().enumerate() {
for (j, c) in word.chars().enumerate() {
let idx =
((c as usize * 31 + j * 17 + i * 13) % self.config.embedding_dim) as usize;
embedding[idx] += 1.0;
}
}
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for val in &mut embedding {
*val /= norm;
}
}
embedding
}
pub fn add_to_history(&self, text: String, embedding: Option<Vec<f32>>) {
let mut history = self.history.write();
while history.len() >= self.config.window_size {
history.remove(0);
}
history.push(HistorySample {
text,
embedding,
timestamp: std::time::Instant::now(),
});
}
pub fn get_rolling_diversity(&self) -> DiversityResult {
let history = self.history.read();
if history.is_empty() {
return DiversityResult::default();
}
let texts: Vec<String> = history.iter().map(|s| s.text.clone()).collect();
let embeddings: Option<Vec<Vec<f32>>> = if history.iter().all(|s| s.embedding.is_some()) {
Some(history.iter().filter_map(|s| s.embedding.clone()).collect())
} else {
None
};
self.calculate_diversity(&texts, embeddings.as_deref())
}
pub fn clear_history(&self) {
let mut history = self.history.write();
history.clear();
}
}
/// Internal result of lexical analysis.
struct LexicalDiversityResult {
    // Blended lexical score: 0.4*TTR + 0.3*hapax + 0.3 baseline, capped at 1.
    diversity_score: f32,
    // Type-token ratio over all tokens.
    ttr: f32,
    // Fraction of unique tokens occurring exactly once (hapax legomena).
    hapax_ratio: f32,
    // NOTE(review): `ttr`, `hapax_ratio` and the two n-gram counts below are
    // populated but never read by callers in this file; only
    // `diversity_score` is consumed. Confirm before pruning.
    unique_bigrams: usize,
    unique_trigrams: usize,
}
/// Internal result of embedding-based analysis.
#[derive(Default)]
struct SemanticDiversityResult {
    // 0.6 * mean pairwise cosine distance + 0.4 * std-dev, capped at 1.
    diversity_score: f32,
    // Variance of embedding components around the mean embedding.
    variance: f32,
    // Mean pairwise cosine distance (1 - similarity).
    average_distance: f32,
}
/// Cosine similarity of two vectors.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or when either vector
/// has zero norm (the quantity is undefined in those cases).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned `String` samples from string literals.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_diversity_calculation() {
        let analyzer = DiversityAnalyzer::default_config();
        let samples = owned(&[
            "The quick brown fox jumps over the lazy dog.",
            "A fast red cat leaps across the sleepy hound.",
            "The swift grey wolf runs past the tired sheep.",
        ]);
        let result = analyzer.calculate_diversity(&samples, None);
        assert!(result.diversity_score > 0.0);
        assert!(result.lexical_diversity > 0.0);
    }

    #[test]
    fn test_mode_collapse_detection_similar() {
        let analyzer = DiversityAnalyzer::default_config();
        // Four identical samples must trip the collapse detector.
        let samples = vec!["The cat sat on the mat.".to_string(); 4];
        let result = analyzer.detect_mode_collapse(&samples, None);
        assert!(result.has_mode_collapse);
        assert!(result.average_similarity > 0.9);
    }

    #[test]
    fn test_mode_collapse_detection_diverse() {
        let analyzer = DiversityAnalyzer::default_config();
        let samples = owned(&[
            "The weather is sunny today.",
            "I enjoy programming in Rust.",
            "Machine learning is fascinating.",
            "The ocean waves are calming.",
        ]);
        let result = analyzer.detect_mode_collapse(&samples, None);
        assert!(result.collapse_severity < 0.8);
    }

    #[test]
    fn test_diversification_suggestions() {
        let analyzer = DiversityAnalyzer::default_config();
        // Everything below the configured thresholds -> suggestions expected.
        let low_diversity = DiversityResult {
            diversity_score: 0.2,
            lexical_diversity: 0.3,
            semantic_diversity: 0.2,
            type_token_ratio: 0.2,
            ..Default::default()
        };
        let suggestions = analyzer.suggest_diversification(&low_diversity, None);
        assert!(!suggestions.is_empty());
    }

    #[test]
    fn test_type_token_ratio() {
        let analyzer = DiversityAnalyzer::default_config();
        // 10 tokens total, 5 distinct types -> TTR of 0.5.
        let samples = owned(&["one two three four five", "one one one one one"]);
        let (ttr, unique, total) = analyzer.calculate_type_token_ratio(&samples);
        assert_eq!(total, 10);
        assert_eq!(unique, 5);
        assert!((ttr - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_repeated_patterns() {
        let analyzer = DiversityAnalyzer::default_config();
        let samples = owned(&[
            "the quick brown fox",
            "the quick brown cat",
            "the quick brown dog",
        ]);
        let patterns = analyzer.find_repeated_patterns(&samples);
        assert!(!patterns.is_empty());
        assert!(patterns.iter().any(|p| p.pattern == "the quick brown"));
    }

    #[test]
    fn test_history_tracking() {
        let analyzer = DiversityAnalyzer::new(DiversityConfig {
            window_size: 5,
            ..Default::default()
        });
        for i in 0..10 {
            analyzer.add_to_history(format!("Sample text number {}", i), None);
        }
        // The rolling window is capped at `window_size` entries.
        assert_eq!(analyzer.history.read().len(), 5);
    }

    #[test]
    fn test_rolling_diversity() {
        let analyzer = DiversityAnalyzer::default_config();
        for text in [
            "First unique sentence about cats.",
            "Second different statement about dogs.",
            "Third varied text about birds.",
        ]
        .iter()
        {
            analyzer.add_to_history(text.to_string(), None);
        }
        assert!(analyzer.get_rolling_diversity().diversity_score > 0.0);
    }

    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &unit_x) - 1.0).abs() < 0.001);
        let orthogonal = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&unit_x, &orthogonal) - 0.0).abs() < 0.001);
        let opposite = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &opposite) - (-1.0)).abs() < 0.001);
    }
}