oxify_connect_vision/
table_extraction.rs

//! Table extraction and structure detection.
//!
//! This module detects tables in OCR results from images and extracts them,
//! including cell-level extraction, structure preservation, and export to CSV,
//! Markdown, and HTML.

6use crate::types::{OcrResult, TextBlock};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Represents a detected table in an image.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Table {
13    /// Bounding box of the entire table [x, y, width, height]
14    pub bbox: [f32; 4],
15    /// Number of rows in the table
16    pub rows: usize,
17    /// Number of columns in the table
18    pub cols: usize,
19    /// Table cells organized by position
20    pub cells: Vec<TableCell>,
21    /// Confidence score for table detection
22    pub confidence: f32,
23}
24
25/// Represents a single cell in a table.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct TableCell {
28    /// Row index (0-based)
29    pub row: usize,
30    /// Column index (0-based)
31    pub col: usize,
32    /// Row span (for merged cells)
33    pub row_span: usize,
34    /// Column span (for merged cells)
35    pub col_span: usize,
36    /// Cell text content
37    pub text: String,
38    /// Bounding box [x, y, width, height]
39    pub bbox: [f32; 4],
40    /// Confidence score
41    pub confidence: f32,
42    /// Whether this is a header cell
43    pub is_header: bool,
44}
45
46impl TableCell {
47    /// Create a new table cell.
48    pub fn new(row: usize, col: usize, text: String, bbox: [f32; 4]) -> Self {
49        Self {
50            row,
51            col,
52            row_span: 1,
53            col_span: 1,
54            text,
55            bbox,
56            confidence: 1.0,
57            is_header: false,
58        }
59    }
60
61    /// Mark cell as a header.
62    pub fn with_header(mut self, is_header: bool) -> Self {
63        self.is_header = is_header;
64        self
65    }
66
67    /// Set cell confidence.
68    pub fn with_confidence(mut self, confidence: f32) -> Self {
69        self.confidence = confidence;
70        self
71    }
72}
73
/// Tuning knobs for [`TableExtractor`].
#[derive(Debug, Clone)]
pub struct TableExtractionConfig {
    /// Confidence threshold for accepting a detected table.
    pub min_confidence: f32,
    /// Fewest rows a region needs before it is treated as a table.
    pub min_rows: usize,
    /// Fewest columns a region needs before it is treated as a table.
    pub min_cols: usize,
    /// Whether the first grid row is flagged as a header row.
    pub detect_headers: bool,
    /// Spatial tolerance, in pixels, used when grouping text blocks into
    /// rows and columns.
    pub cell_padding: f32,
}

89impl Default for TableExtractionConfig {
90    fn default() -> Self {
91        Self {
92            min_confidence: 0.7,
93            min_rows: 2,
94            min_cols: 2,
95            detect_headers: true,
96            cell_padding: 5.0,
97        }
98    }
99}
100
101/// Table extractor for detecting and extracting tables from OCR results.
102pub struct TableExtractor {
103    config: TableExtractionConfig,
104}
105
106impl TableExtractor {
107    /// Create a new table extractor with default configuration.
108    pub fn new() -> Self {
109        Self {
110            config: TableExtractionConfig::default(),
111        }
112    }
113
114    /// Create a new table extractor with custom configuration.
115    pub fn with_config(config: TableExtractionConfig) -> Self {
116        Self { config }
117    }
118
119    /// Extract tables from OCR result.
120    pub fn extract_tables(&self, ocr_result: &OcrResult) -> Vec<Table> {
121        let mut tables = Vec::new();
122
123        // Group text blocks into potential table regions
124        let table_candidates = self.detect_table_regions(&ocr_result.blocks);
125
126        for candidate in table_candidates {
127            if let Some(table) = self.build_table(candidate) {
128                if table.rows >= self.config.min_rows && table.cols >= self.config.min_cols {
129                    tables.push(table);
130                }
131            }
132        }
133
134        tables
135    }
136
137    /// Detect potential table regions from text blocks.
138    fn detect_table_regions<'a>(&self, blocks: &'a [TextBlock]) -> Vec<Vec<&'a TextBlock>> {
139        let mut regions = Vec::new();
140
141        // Simple heuristic: group blocks with similar y-coordinates (rows)
142        // and detect regular column patterns
143
144        // Sort blocks by y-coordinate
145        let mut sorted_blocks: Vec<&TextBlock> = blocks.iter().collect();
146        sorted_blocks.sort_by(|a, b| {
147            a.bbox[1]
148                .partial_cmp(&b.bbox[1])
149                .unwrap_or(std::cmp::Ordering::Equal)
150        });
151
152        // Group into rows based on y-coordinate similarity
153        let mut rows: Vec<Vec<&TextBlock>> = Vec::new();
154        let mut current_row: Vec<&TextBlock> = Vec::new();
155        let mut current_y = 0.0;
156
157        for block in sorted_blocks {
158            let block_y = block.bbox[1];
159
160            if current_row.is_empty() {
161                current_y = block_y;
162                current_row.push(block);
163            } else if (block_y - current_y).abs() < self.config.cell_padding * 2.0 {
164                // Same row
165                current_row.push(block);
166            } else {
167                // New row
168                if !current_row.is_empty() {
169                    rows.push(current_row.clone());
170                }
171                current_row.clear();
172                current_row.push(block);
173                current_y = block_y;
174            }
175        }
176        if !current_row.is_empty() {
177            rows.push(current_row);
178        }
179
180        // Detect table-like patterns: consecutive rows with similar column counts
181        let mut table_rows: Vec<Vec<&TextBlock>> = Vec::new();
182
183        for row in rows {
184            // Sort row blocks by x-coordinate
185            let mut sorted_row = row.clone();
186            sorted_row.sort_by(|a, b| {
187                a.bbox[0]
188                    .partial_cmp(&b.bbox[0])
189                    .unwrap_or(std::cmp::Ordering::Equal)
190            });
191
192            if sorted_row.len() >= self.config.min_cols {
193                table_rows.push(sorted_row);
194            } else if !table_rows.is_empty() {
195                // End of potential table
196                if table_rows.len() >= self.config.min_rows {
197                    regions.push(table_rows.iter().flatten().copied().collect());
198                }
199                table_rows.clear();
200            }
201        }
202
203        if !table_rows.is_empty() && table_rows.len() >= self.config.min_rows {
204            regions.push(table_rows.iter().flatten().copied().collect());
205        }
206
207        regions
208    }
209
210    /// Build a table structure from a group of text blocks.
211    fn build_table(&self, blocks: Vec<&TextBlock>) -> Option<Table> {
212        if blocks.is_empty() {
213            return None;
214        }
215
216        // Calculate table bounding box
217        let min_x = blocks
218            .iter()
219            .map(|b| b.bbox[0])
220            .fold(f32::INFINITY, f32::min);
221        let min_y = blocks
222            .iter()
223            .map(|b| b.bbox[1])
224            .fold(f32::INFINITY, f32::min);
225        let max_x = blocks
226            .iter()
227            .map(|b| b.bbox[0] + b.bbox[2])
228            .fold(f32::NEG_INFINITY, f32::max);
229        let max_y = blocks
230            .iter()
231            .map(|b| b.bbox[1] + b.bbox[3])
232            .fold(f32::NEG_INFINITY, f32::max);
233
234        let bbox = [min_x, min_y, max_x - min_x, max_y - min_y];
235
236        // Detect row and column positions
237        let row_positions = self.detect_row_positions(&blocks);
238        let col_positions = self.detect_column_positions(&blocks);
239
240        let rows = row_positions.len();
241        let cols = col_positions.len();
242
243        if rows < self.config.min_rows || cols < self.config.min_cols {
244            return None;
245        }
246
247        // Build cells
248        let mut cells = Vec::new();
249
250        for block in &blocks {
251            let row = self.find_row_index(&row_positions, block.bbox[1]);
252            let col = self.find_col_index(&col_positions, block.bbox[0]);
253
254            if let (Some(r), Some(c)) = (row, col) {
255                let is_header = self.config.detect_headers && r == 0;
256                let cell = TableCell::new(r, c, block.text.clone(), block.bbox)
257                    .with_header(is_header)
258                    .with_confidence(block.confidence);
259
260                cells.push(cell);
261            }
262        }
263
264        // Calculate average confidence
265        let avg_confidence = if !cells.is_empty() {
266            cells.iter().map(|c| c.confidence).sum::<f32>() / cells.len() as f32
267        } else {
268            0.0
269        };
270
271        Some(Table {
272            bbox,
273            rows,
274            cols,
275            cells,
276            confidence: avg_confidence,
277        })
278    }
279
280    /// Detect row positions from text blocks.
281    fn detect_row_positions(&self, blocks: &[&TextBlock]) -> Vec<f32> {
282        let mut y_coords: Vec<f32> = blocks.iter().map(|b| b.bbox[1]).collect();
283        y_coords.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
284
285        let mut positions = Vec::new();
286        let mut last_y = f32::NEG_INFINITY;
287
288        for &y in &y_coords {
289            if (y - last_y).abs() > self.config.cell_padding * 2.0 {
290                positions.push(y);
291                last_y = y;
292            }
293        }
294
295        positions
296    }
297
298    /// Detect column positions from text blocks.
299    fn detect_column_positions(&self, blocks: &[&TextBlock]) -> Vec<f32> {
300        let mut x_coords: Vec<f32> = blocks.iter().map(|b| b.bbox[0]).collect();
301        x_coords.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
302
303        let mut positions = Vec::new();
304        let mut last_x = f32::NEG_INFINITY;
305
306        for &x in &x_coords {
307            if (x - last_x).abs() > self.config.cell_padding * 2.0 {
308                positions.push(x);
309                last_x = x;
310            }
311        }
312
313        positions
314    }
315
316    /// Find row index for a y-coordinate.
317    fn find_row_index(&self, row_positions: &[f32], y: f32) -> Option<usize> {
318        for (idx, &pos) in row_positions.iter().enumerate() {
319            if (y - pos).abs() < self.config.cell_padding * 3.0 {
320                return Some(idx);
321            }
322        }
323        None
324    }
325
326    /// Find column index for an x-coordinate.
327    fn find_col_index(&self, col_positions: &[f32], x: f32) -> Option<usize> {
328        for (idx, &pos) in col_positions.iter().enumerate() {
329            if (x - pos).abs() < self.config.cell_padding * 3.0 {
330                return Some(idx);
331            }
332        }
333        None
334    }
335}
336
337impl Default for TableExtractor {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
343impl Table {
344    /// Export table to CSV format.
345    pub fn to_csv(&self) -> String {
346        let mut output = String::new();
347
348        // Create a grid to hold cell contents
349        let mut grid: HashMap<(usize, usize), String> = HashMap::new();
350
351        for cell in &self.cells {
352            grid.insert((cell.row, cell.col), cell.text.clone());
353        }
354
355        // Generate CSV rows
356        for row in 0..self.rows {
357            let mut row_data = Vec::new();
358            for col in 0..self.cols {
359                let cell_text = grid.get(&(row, col)).cloned().unwrap_or_default();
360                // Escape quotes and commas
361                let escaped = if cell_text.contains(',') || cell_text.contains('"') {
362                    format!("\"{}\"", cell_text.replace('"', "\"\""))
363                } else {
364                    cell_text
365                };
366                row_data.push(escaped);
367            }
368            output.push_str(&row_data.join(","));
369            output.push('\n');
370        }
371
372        output
373    }
374
375    /// Export table to Markdown format.
376    pub fn to_markdown(&self) -> String {
377        let mut output = String::new();
378
379        // Create a grid to hold cell contents
380        let mut grid: HashMap<(usize, usize), String> = HashMap::new();
381
382        for cell in &self.cells {
383            grid.insert((cell.row, cell.col), cell.text.clone());
384        }
385
386        // Generate markdown table
387        for row in 0..self.rows {
388            output.push('|');
389            for col in 0..self.cols {
390                let cell_text = grid.get(&(row, col)).cloned().unwrap_or_default();
391                output.push(' ');
392                output.push_str(&cell_text);
393                output.push_str(" |");
394            }
395            output.push('\n');
396
397            // Add separator after header row
398            if row == 0 {
399                output.push('|');
400                for _ in 0..self.cols {
401                    output.push_str("---|");
402                }
403                output.push('\n');
404            }
405        }
406
407        output
408    }
409
410    /// Export table to HTML format.
411    pub fn to_html(&self) -> String {
412        let mut output = String::from("<table>\n");
413
414        // Create a grid to hold cell contents
415        let mut grid: HashMap<(usize, usize), &TableCell> = HashMap::new();
416
417        for cell in &self.cells {
418            grid.insert((cell.row, cell.col), cell);
419        }
420
421        // Generate HTML table
422        for row in 0..self.rows {
423            output.push_str("  <tr>\n");
424            for col in 0..self.cols {
425                if let Some(cell) = grid.get(&(row, col)) {
426                    let tag = if cell.is_header { "th" } else { "td" };
427                    output.push_str(&format!("    <{}>{}</{}>\n", tag, cell.text, tag));
428                } else {
429                    output.push_str("    <td></td>\n");
430                }
431            }
432            output.push_str("  </tr>\n");
433        }
434
435        output.push_str("</table>");
436        output
437    }
438
439    /// Get all cells in a specific row.
440    pub fn get_row(&self, row_index: usize) -> Vec<&TableCell> {
441        self.cells.iter().filter(|c| c.row == row_index).collect()
442    }
443
444    /// Get all cells in a specific column.
445    pub fn get_column(&self, col_index: usize) -> Vec<&TableCell> {
446        self.cells.iter().filter(|c| c.col == col_index).collect()
447    }
448
449    /// Get header cells (first row if header detection is enabled).
450    pub fn get_headers(&self) -> Vec<&TableCell> {
451        self.cells.iter().filter(|c| c.is_header).collect()
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `cells` in a table whose detection confidence is fixed at 0.9.
    fn table_of(bbox: [f32; 4], rows: usize, cols: usize, cells: Vec<TableCell>) -> Table {
        Table {
            bbox,
            rows,
            cols,
            cells,
            confidence: 0.9,
        }
    }

    /// Shorthand for a plain (non-header) cell.
    fn plain(row: usize, col: usize, text: &str, bbox: [f32; 4]) -> TableCell {
        TableCell::new(row, col, text.to_string(), bbox)
    }

    /// Shorthand for a header cell.
    fn header(row: usize, col: usize, text: &str, bbox: [f32; 4]) -> TableCell {
        plain(row, col, text, bbox).with_header(true)
    }

    #[test]
    fn test_table_cell_creation() {
        let cell = header(0, 0, "Header", [0.0, 0.0, 100.0, 20.0]).with_confidence(0.95);

        assert_eq!((cell.row, cell.col), (0, 0));
        assert_eq!(cell.text, "Header");
        assert!(cell.is_header);
        assert_eq!(cell.confidence, 0.95);
    }

    #[test]
    fn test_table_to_csv() {
        let table = table_of(
            [0.0, 0.0, 100.0, 40.0],
            2,
            2,
            vec![
                plain(0, 0, "Name", [0.0, 0.0, 50.0, 20.0]),
                plain(0, 1, "Age", [50.0, 0.0, 50.0, 20.0]),
                plain(1, 0, "Alice", [0.0, 20.0, 50.0, 20.0]),
                plain(1, 1, "30", [50.0, 20.0, 50.0, 20.0]),
            ],
        );

        let csv = table.to_csv();
        for expected in &["Name,Age", "Alice,30"] {
            assert!(csv.contains(*expected), "missing {:?} in {:?}", expected, csv);
        }
    }

    #[test]
    fn test_table_to_markdown() {
        let table = table_of(
            [0.0, 0.0, 100.0, 40.0],
            2,
            2,
            vec![
                header(0, 0, "Header1", [0.0, 0.0, 50.0, 20.0]),
                header(0, 1, "Header2", [50.0, 0.0, 50.0, 20.0]),
                plain(1, 0, "Data1", [0.0, 20.0, 50.0, 20.0]),
                plain(1, 1, "Data2", [50.0, 20.0, 50.0, 20.0]),
            ],
        );

        let md = table.to_markdown();
        for expected in &["Header1", "---", "Data1"] {
            assert!(md.contains(*expected), "missing {:?} in {:?}", expected, md);
        }
    }

    #[test]
    fn test_table_extractor_config() {
        let extractor = TableExtractor::with_config(TableExtractionConfig {
            min_confidence: 0.8,
            min_rows: 3,
            min_cols: 3,
            detect_headers: false,
            cell_padding: 10.0,
        });

        assert_eq!(extractor.config.min_confidence, 0.8);
        assert_eq!(extractor.config.min_rows, 3);
        assert!(!extractor.config.detect_headers);
    }

    #[test]
    fn test_table_get_row() {
        let table = table_of(
            [0.0, 0.0, 40.0, 40.0],
            2,
            2,
            vec![
                plain(0, 0, "A", [0.0, 0.0, 20.0, 20.0]),
                plain(0, 1, "B", [20.0, 0.0, 20.0, 20.0]),
                plain(1, 0, "C", [0.0, 20.0, 20.0, 20.0]),
            ],
        );

        let first_row = table.get_row(0);
        assert_eq!(first_row.len(), 2);
        assert_eq!(first_row[0].text, "A");
    }

    #[test]
    fn test_table_get_headers() {
        let table = table_of(
            [0.0, 0.0, 100.0, 40.0],
            2,
            2,
            vec![
                header(0, 0, "Header1", [0.0, 0.0, 50.0, 20.0]),
                header(0, 1, "Header2", [50.0, 0.0, 50.0, 20.0]),
                plain(1, 0, "Data", [0.0, 20.0, 50.0, 20.0]),
            ],
        );

        let headers = table.get_headers();
        assert_eq!(headers.len(), 2);
        assert!(headers.iter().all(|h| h.is_header));
    }
}