use serde::{Deserialize, Serialize};
/// Linkage criterion stored in [`ClusteringConfig::linkage`].
///
/// `Ward` is the `#[default]` variant. The variant names match the usual
/// agglomerative-clustering linkage criteria; how each one is actually
/// interpreted is decided by the clustering code, which is not part of this
/// file (NOTE(review): confirm against that implementation).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkageMethod {
/// Ward linkage; returned by `LinkageMethod::default()`.
#[default]
Ward,
/// Complete linkage.
Complete,
/// Average linkage.
Average,
/// Single linkage.
Single,
}
/// Settings for the clustering stage.
///
/// Every constructor in this file leaves exactly one of `num_clusters` /
/// `distance_threshold` set to `Some`; the type itself does not enforce that,
/// so hand-built or deserialized values may set both or neither.
/// Field descriptions marked "presumably" are inferred from the names only;
/// the code that consumes this struct is not visible here.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ClusteringConfig {
/// Target number of clusters. `Some(10)` by default; `None` when built with
/// `with_distance_threshold`.
pub num_clusters: Option<usize>,
/// Distance cut-off, `None` by default. Presumably an alternative stopping
/// rule to `num_clusters` — TODO confirm scale and semantics with the consumer.
pub distance_threshold: Option<f32>,
/// Linkage criterion; `LinkageMethod::Ward` by default.
pub linkage: LinkageMethod,
/// Minimum cluster size, 5 by default. Presumably smaller clusters are
/// dropped or merged — confirm with the consumer.
pub min_cluster_size: usize,
/// Parallelism switch, `true` by default.
pub parallel: bool,
/// Checkpoint interval, 100 by default. The unit (iterations, merges, ...)
/// is not visible in this file — TODO confirm.
pub checkpoint_interval: usize,
/// Verbosity switch for this stage, `false` by default.
pub verbose: bool,
}
impl Default for ClusteringConfig {
fn default() -> Self {
Self {
num_clusters: Some(10),
distance_threshold: None,
linkage: LinkageMethod::Ward,
min_cluster_size: 5,
parallel: true,
checkpoint_interval: 100,
verbose: false,
}
}
}
impl ClusteringConfig {
    /// Default configuration with the cluster-count target replaced by
    /// `num_clusters`. `distance_threshold` keeps its default.
    pub fn with_num_clusters(num_clusters: usize) -> Self {
        let base = Self::default();
        Self {
            num_clusters: Some(num_clusters),
            ..base
        }
    }

    /// Default configuration switched to threshold mode: `num_clusters` is
    /// cleared and `distance_threshold` is set to `threshold`.
    ///
    /// The value is stored as given; no range or NaN check is performed.
    pub fn with_distance_threshold(threshold: f32) -> Self {
        let base = Self::default();
        Self {
            distance_threshold: Some(threshold),
            num_clusters: None,
            ..base
        }
    }
}
/// Settings for the c-TF-IDF keyword stage.
///
/// Field descriptions marked "presumably" are inferred from the names only;
/// the code that consumes this struct is not visible here.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CtfidfConfig {
/// Number of keywords, 10 by default. Presumably per topic — confirm.
pub num_keywords: usize,
/// Minimum document frequency, 2 by default. Presumably an absolute count
/// below which a term is ignored — confirm.
pub min_df: usize,
/// Maximum document-frequency ratio, 0.95 by default. Presumably a fraction
/// in `[0, 1]` above which a term is ignored — confirm.
pub max_df_ratio: f32,
/// N-gram sizes as a pair; `(1, 1)` by default. `with_ngrams(max_n)` stores
/// `(1, max_n)`, so the second element is the upper size; the first is
/// presumably the lower size.
pub ngram_range: (usize, usize),
/// Sublinear term-frequency switch, `true` by default.
pub sublinear_tf: bool,
/// Minimum term length, 2 by default. Unit (bytes vs. chars) is not
/// visible in this file — TODO confirm.
pub min_term_length: usize,
/// Maximum term length, 50 by default. Same unit caveat as
/// `min_term_length`.
pub max_term_length: usize,
}
impl Default for CtfidfConfig {
fn default() -> Self {
Self {
num_keywords: 10,
min_df: 2,
max_df_ratio: 0.95,
ngram_range: (1, 1),
sublinear_tf: true,
min_term_length: 2,
max_term_length: 50,
}
}
}
impl CtfidfConfig {
    /// Default configuration with `num_keywords` replaced by the given value.
    pub fn with_num_keywords(num_keywords: usize) -> Self {
        Self {
            num_keywords,
            ..Default::default()
        }
    }

    /// Default configuration whose n-gram range runs from 1 up to `max_n`.
    ///
    /// `max_n` is clamped to at least 1: `with_ngrams(0)` would otherwise
    /// store the inverted range `(1, 0)`. Since this constructor cannot report
    /// an error, a zero is treated as "unigrams only" — the same range as
    /// [`CtfidfConfig::default`]. Values of 1 or more are stored unchanged.
    pub fn with_ngrams(max_n: usize) -> Self {
        // The lower bound is fixed at 1, so the upper bound must not drop below it.
        let upper = max_n.max(1);
        Self {
            ngram_range: (1, upper),
            ..Default::default()
        }
    }
}
/// Settings for the topic-description (summarization) stage.
///
/// Field descriptions marked "presumably" are inferred from the names only;
/// the code that consumes this struct is not visible here.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SummarizationConfig {
/// Which description template to use; `DescriptionTemplateType::Keywords`
/// by default.
pub template: DescriptionTemplateType,
/// Optional user-supplied template text, `None` by default. Because of
/// `#[serde(default)]`, input that omits this key deserializes to `None`.
/// Presumably only read when `template` is `Custom` — confirm.
#[serde(default)]
pub custom_template: Option<String>,
/// Number of representative documents, 3 by default. Presumably per topic
/// — confirm.
pub num_representative_docs: usize,
/// Maximum description length, 200 by default. Unit (bytes, chars, words)
/// is not visible in this file — TODO confirm.
pub max_description_length: usize,
/// Whether to include coherence in the description; `false` by default.
pub include_coherence: bool,
}
impl Default for SummarizationConfig {
fn default() -> Self {
Self {
template: DescriptionTemplateType::Keywords,
custom_template: None,
num_representative_docs: 3,
max_description_length: 200,
include_coherence: false,
}
}
}
/// Alias for [`DescriptionTemplateType`]; both names refer to the same enum.
///
/// NOTE(review): presumably kept so existing callers can keep using the
/// shorter name — confirm before removing either spelling.
pub type DescriptionTemplate = DescriptionTemplateType;
/// Kind of template used to produce a topic description.
///
/// The enum is `#[repr(u8)]` with explicit discriminants 0..=3, so `as u8`
/// casts give stable values. The derived serde implementations do not use
/// these discriminant values. What each template actually renders is decided
/// by code outside this file.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum DescriptionTemplateType {
/// Keyword-based description; returned by `DescriptionTemplateType::default()`.
#[default]
Keywords = 0,
/// Label-style description.
Label = 1,
/// Extractive description.
Extractive = 2,
/// Custom template. Presumably paired with
/// `SummarizationConfig::custom_template` — confirm.
Custom = 3,
}
/// Top-level configuration bundling the clustering, c-TF-IDF and
/// summarization settings with a few pipeline-wide options.
///
/// Field descriptions marked "presumably" are inferred from the names only;
/// the code that consumes this struct is not visible here.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TopicConfig {
/// Clustering-stage settings.
pub clustering: ClusteringConfig,
/// c-TF-IDF-stage settings.
pub ctfidf: CtfidfConfig,
/// Summarization-stage settings.
pub summarization: SummarizationConfig,
/// Number of hierarchy levels, 3 by default.
pub hierarchy_levels: usize,
/// Minimum topic size, 5 by default. This is a separate value from
/// `clustering.min_cluster_size`; presumably topics below it are dropped
/// — confirm how the two interact.
pub min_topic_size: usize,
/// Whether to compute coherence; `true` by default.
pub compute_coherence: bool,
/// Pipeline-level verbosity switch, `false` by default. This is a separate
/// flag from `clustering.verbose`.
pub verbose: bool,
}
impl Default for TopicConfig {
fn default() -> Self {
Self {
clustering: ClusteringConfig::default(),
ctfidf: CtfidfConfig::default(),
summarization: SummarizationConfig::default(),
hierarchy_levels: 3,
min_topic_size: 5,
compute_coherence: true,
verbose: false,
}
}
}
impl TopicConfig {
    /// Default pipeline whose clustering stage targets `num_clusters` clusters.
    ///
    /// The clustering settings come from
    /// [`ClusteringConfig::with_num_clusters`]; everything else is taken from
    /// [`TopicConfig::default`].
    pub fn with_num_clusters(num_clusters: usize) -> Self {
        Self {
            clustering: ClusteringConfig::with_num_clusters(num_clusters),
            ..Default::default()
        }
    }

    /// Preset for small corpora: 5 target clusters, a minimum cluster size of
    /// 2, and a c-TF-IDF `min_df` of 1. All other settings are defaults.
    ///
    /// NOTE(review): `min_topic_size` is not adjusted and keeps its default
    /// while `min_cluster_size` drops to 2 — confirm that is intended.
    pub fn for_small_corpus() -> Self {
        Self {
            clustering: ClusteringConfig {
                num_clusters: Some(5),
                min_cluster_size: 2,
                ..Default::default()
            },
            ctfidf: CtfidfConfig {
                min_df: 1,
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Preset for large corpora: 50 target clusters, a minimum cluster size
    /// of 10, a checkpoint interval of 500, and c-TF-IDF bounds of
    /// `min_df = 5` / `max_df_ratio = 0.8`. All other settings are defaults.
    ///
    /// NOTE(review): `min_topic_size` is not adjusted and keeps its default
    /// while `min_cluster_size` rises to 10 — confirm that is intended.
    pub fn for_large_corpus() -> Self {
        Self {
            clustering: ClusteringConfig {
                num_clusters: Some(50),
                min_cluster_size: 10,
                checkpoint_interval: 500,
                ..Default::default()
            },
            ctfidf: CtfidfConfig {
                min_df: 5,
                max_df_ratio: 0.8,
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Builder step: returns the configuration with
    /// `clustering.checkpoint_interval` set to `interval`.
    ///
    /// The value is stored as given, including 0 (NOTE(review): confirm how
    /// the consumer treats a zero interval).
    #[must_use = "builder methods return the updated configuration instead of mutating in place"]
    pub fn with_checkpointing(mut self, interval: usize) -> Self {
        self.clustering.checkpoint_interval = interval;
        self
    }

    /// Builder step: returns the configuration with the pipeline-level
    /// `verbose` flag turned on.
    ///
    /// Only `self.verbose` is changed; `clustering.verbose` is left as it was
    /// (NOTE(review): confirm whether the clustering flag should follow).
    #[must_use = "builder methods return the updated configuration instead of mutating in place"]
    pub fn verbose(mut self) -> Self {
        self.verbose = true;
        self
    }
}
/// Configuration for a topic model: the pipeline settings plus the embedding
/// dimensionality.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TopicModelConfig {
/// Pipeline settings (clustering, c-TF-IDF, summarization, ...).
pub topic_config: TopicConfig,
/// Embedding dimensionality, 768 by default. Presumably the length of the
/// document embedding vectors the model expects — confirm with the consumer.
pub embedding_dim: usize,
}
impl Default for TopicModelConfig {
fn default() -> Self {
Self {
topic_config: TopicConfig::default(),
embedding_dim: 768, }
}
}
impl TopicModelConfig {
    /// Default configuration with `embedding_dim` replaced by the given
    /// value; `topic_config` keeps its default.
    pub fn with_embedding_dim(embedding_dim: usize) -> Self {
        let base = Self::default();
        Self {
            embedding_dim,
            ..base
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default clustering settings target ten clusters with Ward linkage
    /// and parallelism enabled.
    #[test]
    fn test_clustering_config_default() {
        let ClusteringConfig {
            num_clusters,
            linkage,
            parallel,
            ..
        } = ClusteringConfig::default();

        assert_eq!(num_clusters, Some(10));
        assert_eq!(linkage, LinkageMethod::Ward);
        assert!(parallel);
    }

    /// The default c-TF-IDF settings use ten keywords, unigrams only, and
    /// sublinear term frequency.
    #[test]
    fn test_ctfidf_config_default() {
        let CtfidfConfig {
            num_keywords,
            ngram_range,
            sublinear_tf,
            ..
        } = CtfidfConfig::default();

        assert_eq!(num_keywords, 10);
        assert_eq!(ngram_range, (1, 1));
        assert!(sublinear_tf);
    }

    /// The small-corpus preset lowers the cluster target, the minimum cluster
    /// size, and the minimum document frequency.
    #[test]
    fn test_topic_config_small_corpus() {
        let preset = TopicConfig::for_small_corpus();
        let clustering = &preset.clustering;

        assert_eq!(clustering.num_clusters, Some(5));
        assert_eq!(clustering.min_cluster_size, 2);
        assert_eq!(preset.ctfidf.min_df, 1);
    }

    /// The large-corpus preset raises the cluster target and the checkpoint
    /// interval.
    #[test]
    fn test_topic_config_large_corpus() {
        let preset = TopicConfig::for_large_corpus();
        let clustering = &preset.clustering;

        assert_eq!(clustering.num_clusters, Some(50));
        assert_eq!(clustering.checkpoint_interval, 500);
    }

    /// A default configuration survives a JSON round trip with its cluster
    /// target intact.
    #[test]
    fn test_config_serialization() {
        let original = TopicConfig::default();

        let encoded = serde_json::to_string(&original).expect("serialization failed");
        let decoded =
            serde_json::from_str::<TopicConfig>(&encoded).expect("deserialization failed");

        assert_eq!(
            original.clustering.num_clusters,
            decoded.clustering.num_clusters
        );
    }
}