use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Identifier of a topic node, wrapping a raw `u32`.
///
/// The derived equality, ordering and hashing all operate on the wrapped value.
/// Ids strictly below a given document count are reported as leaves by
/// [`TopicId::is_leaf`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct TopicId(pub u32);
impl TopicId {
#[inline]
pub const fn new(id: u32) -> Self {
Self(id)
}
#[inline]
pub const fn as_u32(&self) -> u32 {
self.0
}
#[inline]
pub const fn is_leaf(&self, num_documents: u32) -> bool {
self.0 < num_documents
}
}
impl From<u32> for TopicId {
fn from(id: u32) -> Self {
Self(id)
}
}
impl From<TopicId> for u32 {
fn from(id: TopicId) -> Self {
id.0
}
}
impl std::fmt::Display for TopicId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Topic({})", self.0)
}
}
/// A node in a topic hierarchy: either a leaf (no children) or a merged topic
/// that has child topics.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Topic {
    /// This topic's identifier.
    pub id: TopicId,
    /// Parent topic, or `None` for a root (see `Topic::is_root`).
    pub parent_id: Option<TopicId>,
    /// Child topic ids; empty for a leaf. Topics built with `Topic::new_merged`
    /// have exactly two children.
    pub children: Vec<TopicId>,
    /// Level of this topic in the hierarchy. `new` and `new_leaf` use 0;
    /// `new_merged` stores the caller-supplied value (presumably the merge
    /// height — TODO confirm with callers).
    pub level: usize,
    /// `(word, weight)` pairs. `keyword_summary` takes the first `n` entries in
    /// stored order, so callers presumably keep them sorted by descending
    /// weight — this is not enforced here.
    pub keywords: Vec<(String, f32)>,
    /// Human-readable label; when empty, `Display` falls back to the first
    /// five keywords.
    pub description: String,
    /// Optional centroid vector. Stored as `Arc<[f32]>` so clones of a `Topic`
    /// share the buffer. It is (de)serialized as an optional plain list through
    /// `serialize_arc_slice` / `deserialize_arc_slice`, and `default` makes a
    /// missing field deserialize as `None`.
    #[serde(
        serialize_with = "serialize_arc_slice",
        deserialize_with = "deserialize_arc_slice",
        default
    )]
    pub centroid: Option<Arc<[f32]>>,
    /// Number of documents covered by this topic (1 for a topic built by
    /// `new_leaf`, 0 for `new`).
    pub document_count: usize,
    /// Optional coherence score. `None` from every constructor and set via
    /// `Topic::with_coherence`.
    pub coherence: Option<f32>,
}
fn serialize_arc_slice<S>(
value: &Option<Arc<[f32]>>,
serializer: S,
) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match value {
Some(arc) => {
let vec: Vec<f32> = arc.iter().copied().collect();
serializer.serialize_some(&vec)
}
None => serializer.serialize_none(),
}
}
fn deserialize_arc_slice<'de, D>(
deserializer: D,
) -> std::result::Result<Option<Arc<[f32]>>, D::Error>
where
D: serde::Deserializer<'de>,
{
let opt: Option<Vec<f32>> = Option::deserialize(deserializer)?;
Ok(opt.map(|v| Arc::from(v.into_boxed_slice())))
}
impl Topic {
    /// Creates a root topic with the given id and every other field empty.
    ///
    /// The result has no parent, no children, level 0, no keywords, an empty
    /// description, no centroid, a document count of 0 and no coherence.
    pub fn new(id: TopicId) -> Self {
        Self {
            id,
            parent_id: None,
            children: Vec::new(),
            level: 0,
            keywords: Vec::new(),
            description: String::new(),
            centroid: None,
            document_count: 0,
            coherence: None,
        }
    }

    /// Creates a leaf topic (no children, level 0, no parent) covering a single
    /// document.
    ///
    /// `centroid` is moved into a shared `Arc<[f32]>`.
    pub fn new_leaf(id: TopicId, keywords: Vec<(String, f32)>, centroid: Vec<f32>) -> Self {
        // Start from the empty topic so the complete field list lives only in `new`.
        Self {
            keywords,
            centroid: Some(centroid.into()),
            document_count: 1,
            ..Self::new(id)
        }
    }

    /// Creates an internal topic whose children are `left` and `right`, in that
    /// order.
    ///
    /// The new topic has no parent yet. `level`, `keywords`, `centroid` and
    /// `document_count` are stored exactly as given; this constructor does not
    /// derive them from the children.
    pub fn new_merged(
        id: TopicId,
        left: TopicId,
        right: TopicId,
        level: usize,
        keywords: Vec<(String, f32)>,
        centroid: Vec<f32>,
        document_count: usize,
    ) -> Self {
        Self {
            children: vec![left, right],
            level,
            keywords,
            centroid: Some(centroid.into()),
            document_count,
            ..Self::new(id)
        }
    }

    /// Returns `true` when this topic has no children.
    ///
    /// This is a structural check. It is unrelated to [`TopicId::is_leaf`],
    /// which compares an id against a document count.
    #[inline]
    pub fn is_leaf(&self) -> bool {
        self.children.is_empty()
    }

    /// Returns `true` when no parent has been assigned.
    #[inline]
    pub fn is_root(&self) -> bool {
        self.parent_id.is_none()
    }

    /// Joins the first `n` keywords (without their weights) with `", "`.
    ///
    /// Returns an empty string when `n` is 0 or there are no keywords.
    pub fn keyword_summary(&self, n: usize) -> String {
        self.keywords
            .iter()
            .take(n)
            .map(|(word, _)| word.as_str())
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Builder: replaces the description.
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = description.into();
        self
    }

    /// Builder: replaces the keyword list.
    pub fn with_keywords(mut self, keywords: Vec<(String, f32)>) -> Self {
        self.keywords = keywords;
        self
    }

    /// Builder: replaces the centroid. The `Arc` is stored as-is, so the buffer
    /// stays shared with the caller.
    pub fn with_centroid(mut self, centroid: Arc<[f32]>) -> Self {
        self.centroid = Some(centroid);
        self
    }

    /// Builder: replaces the document count.
    pub fn with_document_count(mut self, count: usize) -> Self {
        self.document_count = count;
        self
    }

    /// Builder: sets this topic's `parent_id`.
    ///
    /// Only this topic is modified; the parent's `children` list is not updated.
    pub fn with_parent(mut self, parent_id: TopicId) -> Self {
        self.parent_id = Some(parent_id);
        self
    }

    /// Builder: sets the coherence score.
    pub fn with_coherence(mut self, coherence: f32) -> Self {
        self.coherence = Some(coherence);
        self
    }

    /// Borrows the centroid as a plain slice, or returns `None` if none is set.
    pub fn centroid(&self) -> Option<&[f32]> {
        self.centroid.as_deref()
    }
}
impl std::fmt::Display for Topic {
    /// Formats as `"<id>: <label>"`. The label is the description when one is
    /// set; otherwise it is the first five keywords joined with `", "`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Number of keywords shown when there is no description.
        const FALLBACK_KEYWORDS: usize = 5;

        if !self.description.is_empty() {
            return write!(f, "{}: {}", self.id, self.description);
        }
        write!(f, "{}: {}", self.id, self.keyword_summary(FALLBACK_KEYWORDS))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_topic_id_creation() {
        let id = TopicId::new(42);
        assert_eq!(id.as_u32(), 42);
        assert_eq!(format!("{}", id), "Topic(42)");
    }

    #[test]
    fn test_topic_id_leaf_check() {
        let id = TopicId::new(5);
        assert!(id.is_leaf(10));
        assert!(!id.is_leaf(5));
        assert!(!id.is_leaf(3));
    }

    #[test]
    fn test_topic_leaf() {
        let keywords = vec![
            ("machine".to_string(), 0.9),
            ("learning".to_string(), 0.8),
            ("algorithm".to_string(), 0.7),
        ];
        let centroid = vec![0.1, 0.2, 0.3];
        let topic = Topic::new_leaf(TopicId::new(0), keywords, centroid);
        assert!(topic.is_leaf());
        assert!(topic.is_root());
        assert_eq!(topic.document_count, 1);
        assert_eq!(topic.keyword_summary(2), "machine, learning");
    }

    #[test]
    fn test_topic_merged() {
        let keywords = vec![("data".to_string(), 0.85), ("science".to_string(), 0.75)];
        let centroid = vec![0.15, 0.25, 0.35];
        let topic = Topic::new_merged(
            TopicId::new(10),
            TopicId::new(1),
            TopicId::new(2),
            1,
            keywords,
            centroid,
            5,
        );
        assert!(!topic.is_leaf());
        assert!(topic.is_root());
        assert_eq!(topic.children.len(), 2);
        assert_eq!(topic.document_count, 5);
    }

    #[test]
    fn test_topic_with_description() {
        let topic = Topic::new_leaf(TopicId::new(0), vec![], vec![])
            .with_description("Machine Learning and AI");
        assert_eq!(topic.description, "Machine Learning and AI");
        assert_eq!(format!("{}", topic), "Topic(0): Machine Learning and AI");
    }

    #[test]
    fn test_topic_display_falls_back_to_keywords() {
        let keywords = ["k1", "k2", "k3", "k4", "k5", "k6"]
            .iter()
            .map(|w| (w.to_string(), 1.0))
            .collect();
        let topic = Topic::new_leaf(TopicId::new(3), keywords, vec![]);
        // Without a description, Display shows only the first five keywords.
        assert_eq!(topic.to_string(), "Topic(3): k1, k2, k3, k4, k5");
    }

    #[test]
    fn test_topic_serialization() {
        let keywords = vec![("test".to_string(), 0.5)];
        // Exactly representable values, so the f32 round-trip can be compared with `==`.
        let topic = Topic::new_leaf(TopicId::new(1), keywords, vec![0.5, -1.25]);
        let json = serde_json::to_string(&topic).expect("serialization failed");
        let deserialized: Topic = serde_json::from_str(&json).expect("deserialization failed");
        assert_eq!(topic.id, deserialized.id);
        assert_eq!(topic.keywords, deserialized.keywords);
        assert_eq!(topic.document_count, deserialized.document_count);
        // Exercises the custom `serialize_arc_slice` / `deserialize_arc_slice` pair.
        assert_eq!(topic.centroid(), deserialized.centroid());
    }

    #[test]
    fn test_topic_serialization_without_centroid() {
        let topic = Topic::new(TopicId::new(2));
        let json = serde_json::to_string(&topic).expect("serialization failed");
        let deserialized: Topic = serde_json::from_str(&json).expect("deserialization failed");
        assert_eq!(deserialized.id, TopicId::new(2));
        assert!(deserialized.centroid().is_none());
    }

    #[test]
    fn test_missing_centroid_field_defaults_to_none() {
        // The `default` serde attribute on `centroid` must accept a payload that omits it.
        let json = r#"{"id":4,"parent_id":null,"children":[],"level":0,"keywords":[],"description":"","document_count":0,"coherence":null}"#;
        let topic: Topic = serde_json::from_str(json).expect("deserialization failed");
        assert_eq!(topic.id, TopicId::new(4));
        assert!(topic.centroid().is_none());
    }
}