use super::extraction::TextBlock;
use crate::pdf::error::{PdfError, Result};
/// Upper bound on the number of k-means refinement passes run by `cluster_font_sizes`.
const KMEANS_MAX_ITERATIONS: usize = 100;
/// Centroid movement (in font-size units) below which every centroid is considered
/// stable and the k-means refinement stops early.
const KMEANS_CONVERGENCE_THRESHOLD: f32 = 0.01;
/// A group of text blocks that were assigned to the same font-size centroid.
#[derive(Debug, Clone)]
pub struct FontSizeCluster {
/// Font size at the centre of the cluster, as computed by `cluster_font_sizes`.
pub centroid: f32,
/// Copies of the text blocks assigned to this cluster.
pub members: Vec<TextBlock>,
}
/// Groups text blocks into at most `k` clusters by font size with one-dimensional k-means.
///
/// Initial centroids are seeded from the distinct finite font sizes (sizes closer than
/// 0.05 to each other count as the same size). The centroids are then refined until the
/// assignments stop changing, every centroid moves less than
/// `KMEANS_CONVERGENCE_THRESHOLD`, or `KMEANS_MAX_ITERATIONS` passes have run. Clusters
/// that end up without members are dropped and the result is ordered by centroid,
/// largest first.
///
/// Blocks whose font size is NaN or infinite take no part in seeding or refinement, so
/// they cannot distort a centroid. They are still handed to the final block assignment.
///
/// # Errors
///
/// Returns `PdfError::TextExtractionFailed` when `k` is zero or when no block has a
/// finite font size. An empty `blocks` slice is not an error and yields an empty vector.
pub fn cluster_font_sizes(blocks: &[TextBlock], k: usize) -> Result<Vec<FontSizeCluster>> {
    if blocks.is_empty() {
        return Ok(Vec::new());
    }
    if k == 0 {
        return Err(PdfError::TextExtractionFailed("K must be greater than 0".to_string()));
    }
    let actual_k = k.min(blocks.len());

    // Only finite sizes may drive the clustering: a single NaN or infinity would turn
    // the mean of whichever cluster it lands in into NaN/infinity.
    let finite_sizes: Vec<f32> = blocks
        .iter()
        .map(|b| b.font_size)
        .filter(|fs| fs.is_finite())
        .collect();

    // Distinct sizes, largest first, used only to seed the centroids.
    let mut distinct_sizes = finite_sizes.clone();
    distinct_sizes.sort_by(|a, b| b.total_cmp(a));
    distinct_sizes.dedup_by(|a, b| (*a - *b).abs() < 0.05);
    if distinct_sizes.is_empty() {
        return Err(PdfError::TextExtractionFailed(
            "No finite font sizes available for clustering".to_string(),
        ));
    }

    let mut centroids = seed_centroids(&distinct_sizes, actual_k);

    let mut prev_assignments: Vec<usize> = Vec::new();
    for iteration in 0..KMEANS_MAX_ITERATIONS {
        let (size_clusters, assignments) =
            assign_sizes_to_centroids_tracked(&finite_sizes, &centroids);

        // The first pass has no previous assignment to compare against, so it always
        // proceeds to the centroid update.
        if iteration > 0 && assignments == prev_assignments {
            break;
        }
        prev_assignments = assignments;

        // A centroid that attracted no sizes keeps its previous position.
        let new_centroids: Vec<f32> = size_clusters
            .iter()
            .zip(&centroids)
            .map(|(members, &previous)| {
                if members.is_empty() {
                    previous
                } else {
                    members.iter().sum::<f32>() / members.len() as f32
                }
            })
            .collect();

        let converged = centroids
            .iter()
            .zip(&new_centroids)
            .all(|(old, new)| (old - new).abs() < KMEANS_CONVERGENCE_THRESHOLD);
        centroids = new_centroids;
        if converged {
            break;
        }
    }

    // Move each member list into its cluster instead of cloning it a second time.
    let mut result: Vec<FontSizeCluster> = assign_blocks_to_centroids(blocks, &centroids)
        .into_iter()
        .zip(centroids)
        .filter(|(members, _)| !members.is_empty())
        .map(|(members, centroid)| FontSizeCluster { centroid, members })
        .collect();
    result.sort_by(|a, b| b.centroid.total_cmp(&a.centroid));
    Ok(result)
}

/// Picks `k` starting centroids, largest first, from `distinct_desc`.
///
/// `distinct_desc` must be non-empty and sorted in descending order, and `k` must be at
/// least 1. The returned vector always holds exactly `k` values.
fn seed_centroids(distinct_desc: &[f32], k: usize) -> Vec<f32> {
    if distinct_desc.len() >= k {
        // Sample evenly across the sorted sizes. `step >= 1` because `len >= k`, and
        // `(k - 1) * step < len`, so every index is in bounds.
        let step = distinct_desc.len() / k;
        return (0..k).map(|i| distinct_desc[i * step]).collect();
    }

    // Fewer distinct sizes than clusters (which implies `k >= 2`): keep every distinct
    // size and pad with values interpolated between the largest and smallest size.
    let max_font = distinct_desc[0];
    let min_font = distinct_desc[distinct_desc.len() - 1];
    let range = max_font - min_font;
    let mut centroids = distinct_desc.to_vec();
    while centroids.len() < k {
        let t = centroids.len() as f32 / (k - 1) as f32;
        centroids.push(max_font - t * range);
    }
    centroids.sort_by(|a, b| b.total_cmp(a));
    centroids
}
/// Assigns heading levels to font-size clusters relative to the body-text cluster.
///
/// The body cluster is the one whose members hold the most text (summed byte length of
/// each block's `text`). Any other cluster is a heading candidate only when its centroid
/// is at least `body_centroid * min_heading_ratio` **and** at least
/// `body_centroid + min_heading_gap`; both minimums must be met, so the stricter of the
/// two bounds is the effective threshold. Candidates are ranked by centroid, largest
/// first, and the top six receive levels 1 through 6.
///
/// # Returns
///
/// One `(centroid, level)` pair per input cluster, in input order. `level` is `None` for
/// the body cluster, for clusters below the threshold, and for candidates ranked below
/// sixth. An empty input yields an empty vector and a single cluster is always treated
/// as body text.
pub fn assign_heading_levels_smart(
    clusters: &[FontSizeCluster],
    min_heading_ratio: f32,
    min_heading_gap: f32,
) -> Vec<(f32, Option<u8>)> {
    // Deepest heading level that is ever handed out.
    const MAX_HEADING_LEVEL: u8 = 6;

    // With zero or one cluster there is nothing to compare against: no headings.
    if clusters.len() <= 1 {
        return clusters.iter().map(|c| (c.centroid, None)).collect();
    }

    // Body text is identified by volume of text rather than by number of blocks.
    let body_idx = clusters
        .iter()
        .enumerate()
        .max_by_key(|(_, c)| c.members.iter().map(|block| block.text.len()).sum::<usize>())
        .map_or(0, |(i, _)| i);
    let body_centroid = clusters[body_idx].centroid;

    // Both the relative and the absolute minimum must hold, hence `max`.
    let min_heading_size = body_centroid * min_heading_ratio;
    let min_heading_abs = body_centroid + min_heading_gap;
    let heading_threshold = min_heading_size.max(min_heading_abs);

    let mut heading_candidates: Vec<(usize, f32)> = clusters
        .iter()
        .enumerate()
        .filter(|(i, c)| *i != body_idx && c.centroid >= heading_threshold)
        .map(|(i, c)| (i, c.centroid))
        .collect();
    // Stable sort: candidates with equal centroids keep their input order.
    heading_candidates.sort_by(|a, b| b.1.total_cmp(&a.1));

    // Hand out levels 1..=6 to the largest candidates; everything else stays `None`.
    let mut levels: Vec<Option<u8>> = vec![None; clusters.len()];
    for (level, &(idx, _)) in (1..=MAX_HEADING_LEVEL).zip(&heading_candidates) {
        levels[idx] = Some(level);
    }

    clusters
        .iter()
        .zip(levels)
        .map(|(cluster, level)| (cluster.centroid, level))
        .collect()
}
/// Assigns every font size to its nearest centroid.
///
/// Returns the sizes grouped per centroid (in the same order as `centroids`) together
/// with, for each input size in order, the index of the centroid it was assigned to.
/// When two centroids are equally close the earlier one wins, and a size whose distance
/// to every centroid is NaN or infinite falls back to index 0.
///
/// # Panics
///
/// Panics if `centroids` is empty while `font_sizes` is not.
fn assign_sizes_to_centroids_tracked(font_sizes: &[f32], centroids: &[f32]) -> (Vec<Vec<f32>>, Vec<usize>) {
    let mut clusters: Vec<Vec<f32>> = vec![Vec::new(); centroids.len()];
    let assignments: Vec<usize> = font_sizes
        .iter()
        .map(|&size| {
            // Strict `<` keeps the earliest centroid on ties and ignores NaN distances.
            let (nearest, _) = centroids.iter().enumerate().fold(
                (0usize, f32::INFINITY),
                |(best, best_distance), (i, &centroid)| {
                    let distance = (size - centroid).abs();
                    if distance < best_distance {
                        (i, distance)
                    } else {
                        (best, best_distance)
                    }
                },
            );
            clusters[nearest].push(size);
            nearest
        })
        .collect();
    (clusters, assignments)
}
fn assign_blocks_to_centroids(blocks: &[TextBlock], centroids: &[f32]) -> Vec<Vec<TextBlock>> {
let mut clusters: Vec<Vec<TextBlock>> = vec![Vec::new(); centroids.len()];
for block in blocks {
let mut min_distance = f32::INFINITY;
let mut best_cluster = 0;
for (i, ¢roid) in centroids.iter().enumerate() {
let distance = (block.font_size - centroid).abs();
if distance < min_distance {
min_distance = distance;
best_cluster = i;
}
}
clusters[best_cluster].push(block.clone());
}
clusters
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pdf::hierarchy::bounding_box::BoundingBox;

    /// Builds a block with the given text whose bounding box is 100 wide and
    /// `font_size` tall, anchored at the origin.
    fn make_block(text: &str, font_size: f32) -> TextBlock {
        let bbox = BoundingBox {
            left: 0.0,
            top: 0.0,
            right: 100.0,
            bottom: font_size,
        };
        TextBlock {
            text: text.to_string(),
            bbox,
            font_size,
        }
    }

    /// Clusters `blocks` into two groups and returns the centroid of the first
    /// cluster that received no heading level.
    fn first_unlevelled_centroid(blocks: &[TextBlock]) -> Option<f32> {
        let clusters = cluster_font_sizes(blocks, 2).unwrap();
        assign_heading_levels_smart(&clusters, 1.15, 1.5)
            .into_iter()
            .find_map(|(centroid, level)| level.is_none().then_some(centroid))
    }

    /// Many short 8pt blocks must not outweigh a few long 12pt paragraphs when
    /// picking the body cluster.
    #[test]
    fn test_body_cluster_by_text_content_not_member_count() {
        let short_labels = (0..10).map(|i| make_block(&format!("Hdr{i}"), 8.0));
        let paragraphs =
            (0..3).map(|_| make_block("This is a longer body text paragraph with content.", 12.0));
        let blocks: Vec<TextBlock> = short_labels.chain(paragraphs).collect();

        let body_centroid = first_unlevelled_centroid(&blocks);
        assert!(body_centroid.is_some(), "should have a body cluster");
        let bc = body_centroid.unwrap();
        assert!((bc - 12.0).abs() < 1.0, "body centroid should be near 12pt, got {bc}");
    }

    /// With the same number of blocks per size, the size carrying more text is the body.
    #[test]
    fn test_body_cluster_equal_members_picks_more_content() {
        let blocks: Vec<TextBlock> = [
            ("AB", 18.0),
            ("CD", 18.0),
            ("This is much longer body text content here.", 12.0),
            ("Another long paragraph of body text for the doc.", 12.0),
        ]
        .iter()
        .map(|&(text, font_size)| make_block(text, font_size))
        .collect();

        let body_centroid = first_unlevelled_centroid(&blocks);
        assert!(body_centroid.is_some());
        let bc = body_centroid.unwrap();
        assert!(
            (bc - 12.0).abs() < 1.0,
            "body should be 12pt cluster (more text), got {bc}"
        );
    }
}