use super::extraction::TextBlock;
use crate::pdf::error::{PdfError, Result};
/// Maximum number of refinement rounds run by [`cluster_font_sizes`]; the loop exits
/// earlier as soon as an iteration is judged converged.
const KMEANS_MAX_ITERATIONS: usize = 100;
/// An iteration is converged when every centroid moved by less than this amount.
/// Expressed in the same units as `TextBlock::font_size` (presumably points — confirm).
const KMEANS_CONVERGENCE_THRESHOLD: f32 = 0.01;
/// A group of text blocks whose font sizes lie closest to a shared centroid.
#[derive(Debug, Clone)]
pub struct FontSizeCluster {
    /// Representative font size of the group: the centroid value from the final
    /// k-means iteration in `cluster_font_sizes`.
    pub centroid: f32,
    /// The text blocks assigned to this centroid (clones of the input blocks).
    pub members: Vec<TextBlock>,
}
/// Groups `blocks` into at most `k` clusters by font size using one-dimensional k-means.
///
/// Starting centroids are drawn from the distinct finite font sizes in `blocks` (sizes
/// within 0.05 of an already-kept size count as the same size). They are refined for up
/// to [`KMEANS_MAX_ITERATIONS`] rounds, stopping early once no centroid moves by
/// [`KMEANS_CONVERGENCE_THRESHOLD`] or more. Only non-empty clusters are returned, ordered
/// from the largest centroid to the smallest.
///
/// Blocks whose font size is NaN or infinite are still placed in a cluster, so no text
/// is dropped. They are excluded when recomputing centroids; otherwise a single bad value
/// would turn the centroid it lands in into NaN/infinity and block convergence.
///
/// # Errors
///
/// Returns [`PdfError::TextExtractionFailed`] if `k` is 0 and `blocks` is non-empty.
/// Returns `Ok(vec![])` if `blocks` is empty or no block has a finite font size.
pub fn cluster_font_sizes(blocks: &[TextBlock], k: usize) -> Result<Vec<FontSizeCluster>> {
    if blocks.is_empty() {
        return Ok(Vec::new());
    }
    if k == 0 {
        return Err(PdfError::TextExtractionFailed("K must be greater than 0".to_string()));
    }
    let actual_k = k.min(blocks.len());

    let mut font_sizes: Vec<f32> = blocks
        .iter()
        .map(|b| b.font_size)
        .filter(|fs| fs.is_finite())
        .collect();
    // Every size was NaN/infinite: there is nothing to seed centroids from, and the
    // seeding below would otherwise index into an empty vector.
    if font_sizes.is_empty() {
        return Ok(Vec::new());
    }
    // Largest first, then collapse sizes that differ by less than 0.05.
    font_sizes.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
    font_sizes.dedup_by(|a, b| (*a - *b).abs() < 0.05);

    let mut centroids = seed_centroids(&font_sizes, actual_k);

    for _ in 0..KMEANS_MAX_ITERATIONS {
        let clusters = assign_blocks_to_centroids(blocks, &centroids);
        // A cluster with no finite members keeps its previous centroid.
        let new_centroids: Vec<f32> = clusters
            .iter()
            .zip(&centroids)
            .map(|(members, &old)| finite_mean(members).unwrap_or(old))
            .collect();
        let converged = centroids
            .iter()
            .zip(&new_centroids)
            .all(|(old, new)| (old - new).abs() < KMEANS_CONVERGENCE_THRESHOLD);
        centroids = new_centroids;
        if converged {
            break;
        }
    }

    // Final assignment against the settled centroids; move member lists instead of cloning.
    let clusters = assign_blocks_to_centroids(blocks, &centroids);
    let mut result: Vec<FontSizeCluster> = clusters
        .into_iter()
        .zip(centroids)
        .filter(|(members, _)| !members.is_empty())
        .map(|(members, centroid)| FontSizeCluster { centroid, members })
        .collect();
    result.sort_by(|a, b| b.centroid.partial_cmp(&a.centroid).unwrap_or(std::cmp::Ordering::Equal));
    Ok(result)
}

/// Chooses exactly `k` starting centroids from `distinct_sizes` (non-empty, sorted largest first).
///
/// With at least `k` sizes, every `len / k`-th size is taken starting from the largest.
/// With fewer, all sizes are used and the rest are interpolated from the largest size
/// toward the smallest, then the whole list is re-sorted largest first.
fn seed_centroids(distinct_sizes: &[f32], k: usize) -> Vec<f32> {
    let len = distinct_sizes.len();
    if len >= k {
        let step = len / k;
        return (0..k)
            .map(|i| distinct_sizes[(i * step).min(len - 1)])
            .collect();
    }

    let max_font = distinct_sizes[0];
    let min_font = distinct_sizes[len - 1];
    let range = max_font - min_font;
    let mut centroids = distinct_sizes.to_vec();
    // Reaching here means 1 <= len < k, so k >= 2 and `k - 1` is never zero.
    while centroids.len() < k {
        let t = centroids.len() as f32 / (k - 1) as f32;
        centroids.push(max_font - t * range);
    }
    centroids.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
    centroids
}

/// Mean font size over the members with a finite font size, or `None` if there are none.
fn finite_mean(members: &[TextBlock]) -> Option<f32> {
    let (sum, count) = members
        .iter()
        .map(|b| b.font_size)
        .filter(|fs| fs.is_finite())
        .fold((0.0f32, 0usize), |(sum, count), fs| (sum + fs, count + 1));
    if count == 0 {
        None
    } else {
        Some(sum / count as f32)
    }
}
/// Decides which font-size clusters are headings and at what level.
///
/// The cluster with the most members is treated as body text. On a tie, the later
/// cluster in the slice wins. The heading threshold is the larger of
/// `body_centroid * min_heading_ratio` and `body_centroid + min_heading_gap`. Every
/// non-body cluster whose centroid reaches that threshold is a candidate.
///
/// Candidates are ranked by descending centroid, and the top six receive levels 1..=6.
/// Body text, clusters below the threshold, and candidates past the sixth all map to `None`.
///
/// The output is parallel to `clusters`: element `i` is `(clusters[i].centroid, level)`.
/// An empty slice gives an empty vector. A single cluster is always reported as body text.
pub fn assign_heading_levels_smart(
    clusters: &[FontSizeCluster],
    min_heading_ratio: f32,
    min_heading_gap: f32,
) -> Vec<(f32, Option<u8>)> {
    // Zero or one cluster: nothing can be a heading.
    if clusters.len() <= 1 {
        return clusters.iter().map(|c| (c.centroid, None)).collect();
    }

    // `>=` keeps the last of several equally large clusters, as `max_by_key` would.
    let mut body_idx = 0;
    for (idx, cluster) in clusters.iter().enumerate() {
        if cluster.members.len() >= clusters[body_idx].members.len() {
            body_idx = idx;
        }
    }

    let body_size = clusters[body_idx].centroid;
    let threshold = (body_size * min_heading_ratio).max(body_size + min_heading_gap);

    let mut candidates: Vec<usize> = (0..clusters.len())
        .filter(|&idx| idx != body_idx && clusters[idx].centroid >= threshold)
        .collect();
    // Stable sort: biggest font first, ties keep their original relative order.
    candidates.sort_by(|&a, &b| {
        clusters[b]
            .centroid
            .partial_cmp(&clusters[a].centroid)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    const MAX_HEADING_LEVELS: usize = 6;
    let mut levels: Vec<Option<u8>> = vec![None; clusters.len()];
    for (rank, &idx) in candidates.iter().take(MAX_HEADING_LEVELS).enumerate() {
        levels[idx] = Some((rank + 1) as u8);
    }

    clusters
        .iter()
        .zip(levels)
        .map(|(cluster, level)| (cluster.centroid, level))
        .collect()
}
fn assign_blocks_to_centroids(blocks: &[TextBlock], centroids: &[f32]) -> Vec<Vec<TextBlock>> {
let mut clusters: Vec<Vec<TextBlock>> = vec![Vec::new(); centroids.len()];
for block in blocks {
let mut min_distance = f32::INFINITY;
let mut best_cluster = 0;
for (i, ¢roid) in centroids.iter().enumerate() {
let distance = (block.font_size - centroid).abs();
if distance < min_distance {
min_distance = distance;
best_cluster = i;
}
}
clusters[best_cluster].push(block.clone());
}
clusters
}