1use crate::types::{OcrResult, TextBlock};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Table {
13 pub bbox: [f32; 4],
15 pub rows: usize,
17 pub cols: usize,
19 pub cells: Vec<TableCell>,
21 pub confidence: f32,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct TableCell {
28 pub row: usize,
30 pub col: usize,
32 pub row_span: usize,
34 pub col_span: usize,
36 pub text: String,
38 pub bbox: [f32; 4],
40 pub confidence: f32,
42 pub is_header: bool,
44}
45
46impl TableCell {
47 pub fn new(row: usize, col: usize, text: String, bbox: [f32; 4]) -> Self {
49 Self {
50 row,
51 col,
52 row_span: 1,
53 col_span: 1,
54 text,
55 bbox,
56 confidence: 1.0,
57 is_header: false,
58 }
59 }
60
61 pub fn with_header(mut self, is_header: bool) -> Self {
63 self.is_header = is_header;
64 self
65 }
66
67 pub fn with_confidence(mut self, confidence: f32) -> Self {
69 self.confidence = confidence;
70 self
71 }
72}
73
/// Tunable parameters for table extraction.
#[derive(Debug, Clone, PartialEq)]
pub struct TableExtractionConfig {
    /// Minimum confidence threshold for extracted tables.
    /// NOTE(review): confirm where the extractor applies this threshold.
    pub min_confidence: f32,
    /// Minimum number of rows a region needs to count as a table.
    pub min_rows: usize,
    /// Minimum number of columns a region needs to count as a table.
    pub min_cols: usize,
    /// When `true`, cells in the first row are flagged as headers.
    pub detect_headers: bool,
    /// Base spacing for alignment tolerances: coordinates are clustered and
    /// matched using multiples of this value (2x and 3x respectively).
    pub cell_padding: f32,
}

impl Default for TableExtractionConfig {
    /// Defaults: confidence `0.7`, at least a 2x2 grid, header detection on,
    /// padding `5.0`.
    fn default() -> Self {
        Self {
            min_confidence: 0.7,
            min_rows: 2,
            min_cols: 2,
            detect_headers: true,
            cell_padding: 5.0,
        }
    }
}
100
101pub struct TableExtractor {
103 config: TableExtractionConfig,
104}
105
106impl TableExtractor {
107 pub fn new() -> Self {
109 Self {
110 config: TableExtractionConfig::default(),
111 }
112 }
113
114 pub fn with_config(config: TableExtractionConfig) -> Self {
116 Self { config }
117 }
118
119 pub fn extract_tables(&self, ocr_result: &OcrResult) -> Vec<Table> {
121 let mut tables = Vec::new();
122
123 let table_candidates = self.detect_table_regions(&ocr_result.blocks);
125
126 for candidate in table_candidates {
127 if let Some(table) = self.build_table(candidate) {
128 if table.rows >= self.config.min_rows && table.cols >= self.config.min_cols {
129 tables.push(table);
130 }
131 }
132 }
133
134 tables
135 }
136
137 fn detect_table_regions<'a>(&self, blocks: &'a [TextBlock]) -> Vec<Vec<&'a TextBlock>> {
139 let mut regions = Vec::new();
140
141 let mut sorted_blocks: Vec<&TextBlock> = blocks.iter().collect();
146 sorted_blocks.sort_by(|a, b| {
147 a.bbox[1]
148 .partial_cmp(&b.bbox[1])
149 .unwrap_or(std::cmp::Ordering::Equal)
150 });
151
152 let mut rows: Vec<Vec<&TextBlock>> = Vec::new();
154 let mut current_row: Vec<&TextBlock> = Vec::new();
155 let mut current_y = 0.0;
156
157 for block in sorted_blocks {
158 let block_y = block.bbox[1];
159
160 if current_row.is_empty() {
161 current_y = block_y;
162 current_row.push(block);
163 } else if (block_y - current_y).abs() < self.config.cell_padding * 2.0 {
164 current_row.push(block);
166 } else {
167 if !current_row.is_empty() {
169 rows.push(current_row.clone());
170 }
171 current_row.clear();
172 current_row.push(block);
173 current_y = block_y;
174 }
175 }
176 if !current_row.is_empty() {
177 rows.push(current_row);
178 }
179
180 let mut table_rows: Vec<Vec<&TextBlock>> = Vec::new();
182
183 for row in rows {
184 let mut sorted_row = row.clone();
186 sorted_row.sort_by(|a, b| {
187 a.bbox[0]
188 .partial_cmp(&b.bbox[0])
189 .unwrap_or(std::cmp::Ordering::Equal)
190 });
191
192 if sorted_row.len() >= self.config.min_cols {
193 table_rows.push(sorted_row);
194 } else if !table_rows.is_empty() {
195 if table_rows.len() >= self.config.min_rows {
197 regions.push(table_rows.iter().flatten().copied().collect());
198 }
199 table_rows.clear();
200 }
201 }
202
203 if !table_rows.is_empty() && table_rows.len() >= self.config.min_rows {
204 regions.push(table_rows.iter().flatten().copied().collect());
205 }
206
207 regions
208 }
209
210 fn build_table(&self, blocks: Vec<&TextBlock>) -> Option<Table> {
212 if blocks.is_empty() {
213 return None;
214 }
215
216 let min_x = blocks
218 .iter()
219 .map(|b| b.bbox[0])
220 .fold(f32::INFINITY, f32::min);
221 let min_y = blocks
222 .iter()
223 .map(|b| b.bbox[1])
224 .fold(f32::INFINITY, f32::min);
225 let max_x = blocks
226 .iter()
227 .map(|b| b.bbox[0] + b.bbox[2])
228 .fold(f32::NEG_INFINITY, f32::max);
229 let max_y = blocks
230 .iter()
231 .map(|b| b.bbox[1] + b.bbox[3])
232 .fold(f32::NEG_INFINITY, f32::max);
233
234 let bbox = [min_x, min_y, max_x - min_x, max_y - min_y];
235
236 let row_positions = self.detect_row_positions(&blocks);
238 let col_positions = self.detect_column_positions(&blocks);
239
240 let rows = row_positions.len();
241 let cols = col_positions.len();
242
243 if rows < self.config.min_rows || cols < self.config.min_cols {
244 return None;
245 }
246
247 let mut cells = Vec::new();
249
250 for block in &blocks {
251 let row = self.find_row_index(&row_positions, block.bbox[1]);
252 let col = self.find_col_index(&col_positions, block.bbox[0]);
253
254 if let (Some(r), Some(c)) = (row, col) {
255 let is_header = self.config.detect_headers && r == 0;
256 let cell = TableCell::new(r, c, block.text.clone(), block.bbox)
257 .with_header(is_header)
258 .with_confidence(block.confidence);
259
260 cells.push(cell);
261 }
262 }
263
264 let avg_confidence = if !cells.is_empty() {
266 cells.iter().map(|c| c.confidence).sum::<f32>() / cells.len() as f32
267 } else {
268 0.0
269 };
270
271 Some(Table {
272 bbox,
273 rows,
274 cols,
275 cells,
276 confidence: avg_confidence,
277 })
278 }
279
280 fn detect_row_positions(&self, blocks: &[&TextBlock]) -> Vec<f32> {
282 let mut y_coords: Vec<f32> = blocks.iter().map(|b| b.bbox[1]).collect();
283 y_coords.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
284
285 let mut positions = Vec::new();
286 let mut last_y = f32::NEG_INFINITY;
287
288 for &y in &y_coords {
289 if (y - last_y).abs() > self.config.cell_padding * 2.0 {
290 positions.push(y);
291 last_y = y;
292 }
293 }
294
295 positions
296 }
297
298 fn detect_column_positions(&self, blocks: &[&TextBlock]) -> Vec<f32> {
300 let mut x_coords: Vec<f32> = blocks.iter().map(|b| b.bbox[0]).collect();
301 x_coords.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
302
303 let mut positions = Vec::new();
304 let mut last_x = f32::NEG_INFINITY;
305
306 for &x in &x_coords {
307 if (x - last_x).abs() > self.config.cell_padding * 2.0 {
308 positions.push(x);
309 last_x = x;
310 }
311 }
312
313 positions
314 }
315
316 fn find_row_index(&self, row_positions: &[f32], y: f32) -> Option<usize> {
318 for (idx, &pos) in row_positions.iter().enumerate() {
319 if (y - pos).abs() < self.config.cell_padding * 3.0 {
320 return Some(idx);
321 }
322 }
323 None
324 }
325
326 fn find_col_index(&self, col_positions: &[f32], x: f32) -> Option<usize> {
328 for (idx, &pos) in col_positions.iter().enumerate() {
329 if (x - pos).abs() < self.config.cell_padding * 3.0 {
330 return Some(idx);
331 }
332 }
333 None
334 }
335}
336
337impl Default for TableExtractor {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
343impl Table {
344 pub fn to_csv(&self) -> String {
346 let mut output = String::new();
347
348 let mut grid: HashMap<(usize, usize), String> = HashMap::new();
350
351 for cell in &self.cells {
352 grid.insert((cell.row, cell.col), cell.text.clone());
353 }
354
355 for row in 0..self.rows {
357 let mut row_data = Vec::new();
358 for col in 0..self.cols {
359 let cell_text = grid.get(&(row, col)).cloned().unwrap_or_default();
360 let escaped = if cell_text.contains(',') || cell_text.contains('"') {
362 format!("\"{}\"", cell_text.replace('"', "\"\""))
363 } else {
364 cell_text
365 };
366 row_data.push(escaped);
367 }
368 output.push_str(&row_data.join(","));
369 output.push('\n');
370 }
371
372 output
373 }
374
375 pub fn to_markdown(&self) -> String {
377 let mut output = String::new();
378
379 let mut grid: HashMap<(usize, usize), String> = HashMap::new();
381
382 for cell in &self.cells {
383 grid.insert((cell.row, cell.col), cell.text.clone());
384 }
385
386 for row in 0..self.rows {
388 output.push('|');
389 for col in 0..self.cols {
390 let cell_text = grid.get(&(row, col)).cloned().unwrap_or_default();
391 output.push(' ');
392 output.push_str(&cell_text);
393 output.push_str(" |");
394 }
395 output.push('\n');
396
397 if row == 0 {
399 output.push('|');
400 for _ in 0..self.cols {
401 output.push_str("---|");
402 }
403 output.push('\n');
404 }
405 }
406
407 output
408 }
409
410 pub fn to_html(&self) -> String {
412 let mut output = String::from("<table>\n");
413
414 let mut grid: HashMap<(usize, usize), &TableCell> = HashMap::new();
416
417 for cell in &self.cells {
418 grid.insert((cell.row, cell.col), cell);
419 }
420
421 for row in 0..self.rows {
423 output.push_str(" <tr>\n");
424 for col in 0..self.cols {
425 if let Some(cell) = grid.get(&(row, col)) {
426 let tag = if cell.is_header { "th" } else { "td" };
427 output.push_str(&format!(" <{}>{}</{}>\n", tag, cell.text, tag));
428 } else {
429 output.push_str(" <td></td>\n");
430 }
431 }
432 output.push_str(" </tr>\n");
433 }
434
435 output.push_str("</table>");
436 output
437 }
438
439 pub fn get_row(&self, row_index: usize) -> Vec<&TableCell> {
441 self.cells.iter().filter(|c| c.row == row_index).collect()
442 }
443
444 pub fn get_column(&self, col_index: usize) -> Vec<&TableCell> {
446 self.cells.iter().filter(|c| c.col == col_index).collect()
447 }
448
449 pub fn get_headers(&self) -> Vec<&TableCell> {
451 self.cells.iter().filter(|c| c.is_header).collect()
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 2x2 table fixture around `cells` with a fixed confidence.
    fn two_by_two(cells: Vec<TableCell>, bbox: [f32; 4]) -> Table {
        Table {
            bbox,
            rows: 2,
            cols: 2,
            cells,
            confidence: 0.9,
        }
    }

    /// Builds a header cell with the given position, text and bounding box.
    fn header(row: usize, col: usize, text: &str, bbox: [f32; 4]) -> TableCell {
        TableCell::new(row, col, text.to_string(), bbox).with_header(true)
    }

    /// Builds a body cell with the given position, text and bounding box.
    fn body(row: usize, col: usize, text: &str, bbox: [f32; 4]) -> TableCell {
        TableCell::new(row, col, text.to_string(), bbox)
    }

    #[test]
    fn test_table_cell_creation() {
        let cell = TableCell::new(0, 0, "Header".to_string(), [0.0, 0.0, 100.0, 20.0])
            .with_header(true)
            .with_confidence(0.95);

        assert_eq!(cell.row, 0);
        assert_eq!(cell.col, 0);
        assert_eq!(cell.text, "Header");
        assert!(cell.is_header);
        assert_eq!(cell.confidence, 0.95);
    }

    #[test]
    fn test_table_to_csv() {
        let table = two_by_two(
            vec![
                body(0, 0, "Name", [0.0, 0.0, 50.0, 20.0]),
                body(0, 1, "Age", [50.0, 0.0, 50.0, 20.0]),
                body(1, 0, "Alice", [0.0, 20.0, 50.0, 20.0]),
                body(1, 1, "30", [50.0, 20.0, 50.0, 20.0]),
            ],
            [0.0, 0.0, 100.0, 40.0],
        );

        let csv = table.to_csv();
        assert!(csv.contains("Name,Age"));
        assert!(csv.contains("Alice,30"));
    }

    #[test]
    fn test_table_to_markdown() {
        let table = two_by_two(
            vec![
                header(0, 0, "Header1", [0.0, 0.0, 50.0, 20.0]),
                header(0, 1, "Header2", [50.0, 0.0, 50.0, 20.0]),
                body(1, 0, "Data1", [0.0, 20.0, 50.0, 20.0]),
                body(1, 1, "Data2", [50.0, 20.0, 50.0, 20.0]),
            ],
            [0.0, 0.0, 100.0, 40.0],
        );

        let md = table.to_markdown();
        assert!(md.contains("Header1"));
        assert!(md.contains("---"));
        assert!(md.contains("Data1"));
    }

    #[test]
    fn test_table_extractor_config() {
        let config = TableExtractionConfig {
            min_confidence: 0.8,
            min_rows: 3,
            min_cols: 3,
            detect_headers: false,
            cell_padding: 10.0,
        };

        // The config is moved in; every field must survive unchanged.
        let extractor = TableExtractor::with_config(config);
        assert_eq!(extractor.config.min_confidence, 0.8);
        assert_eq!(extractor.config.min_rows, 3);
        assert_eq!(extractor.config.min_cols, 3);
        assert!(!extractor.config.detect_headers);
        assert_eq!(extractor.config.cell_padding, 10.0);
    }

    #[test]
    fn test_table_get_row() {
        let table = two_by_two(
            vec![
                body(0, 0, "A", [0.0, 0.0, 20.0, 20.0]),
                body(0, 1, "B", [20.0, 0.0, 20.0, 20.0]),
                body(1, 0, "C", [0.0, 20.0, 20.0, 20.0]),
            ],
            [0.0, 0.0, 40.0, 40.0],
        );

        let row0 = table.get_row(0);
        assert_eq!(row0.len(), 2);
        assert_eq!(row0[0].text, "A");
    }

    #[test]
    fn test_table_get_headers() {
        let table = two_by_two(
            vec![
                header(0, 0, "Header1", [0.0, 0.0, 50.0, 20.0]),
                header(0, 1, "Header2", [50.0, 0.0, 50.0, 20.0]),
                body(1, 0, "Data", [0.0, 20.0, 50.0, 20.0]),
            ],
            [0.0, 0.0, 100.0, 40.0],
        );

        let headers = table.get_headers();
        assert_eq!(headers.len(), 2);
        assert!(headers.iter().all(|h| h.is_header));
    }
}