oar-ocr-core 0.6.3

Core types and predictors for oar-ocr
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
//! Table Structure Decoding Processor
//!
//! This module provides postprocessing for table structure recognition models.
//! It decodes structure token logits and extracts bounding boxes for table cells.

use crate::core::OCRError;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

/// Per-image decode output: `(structure tokens, cell bounding boxes, mean structure score)`.
type TableDecodeArtifacts = (Vec<String>, Vec<[f32; 8]>, f32);
/// Fallible per-image decode result returned by `TableStructureDecode::decode_single`.
type TableDecodeResult = Result<TableDecodeArtifacts, OCRError>;

/// Wraps structure tokens in a complete HTML document without filling any cell text.
///
/// The result always starts with `<html><body>` and ends with `</body></html>`, which
/// matches the full-document HTML that PP-StructureV3 emits for tables. A
/// `<table>`/`</table>` pair is placed around the tokens unless the first token
/// already contains a `<table` tag.
///
/// # Arguments
///
/// * `tokens` - Table structure tokens (e.g., `["<tr>", "<td></td>", "</tr>"]`)
///
/// # Returns
///
/// The rendered HTML document as a single string.
///
/// # Example
///
/// ```
/// use oar_ocr_core::processors::wrap_table_html;
///
/// let tokens = vec!["<tr>".to_string(), "<td></td>".to_string(), "</tr>".to_string()];
/// let html = wrap_table_html(&tokens);
/// assert!(html.starts_with("<html><body><table>"));
/// assert!(html.ends_with("</table></body></html>"));
/// ```
pub fn wrap_table_html(tokens: &[String]) -> String {
    // Structure-only rendering: no cell text is supplied.
    let no_cell_text: Option<&[Option<String>]> = None;
    render_table_html(tokens, no_cell_text)
}

/// Wraps structure tokens in a complete HTML document, filling each cell with text.
///
/// Cells are numbered in the order their opening token appears. The merged
/// `<td></td>` token counts as one cell, and so does any token starting with `<td`,
/// including multi-token sequences such as `<td`, ` colspan="2"`, `>`, `</td>`.
/// Cell `i` receives `cell_texts[i]` when it is `Some`. Cells whose entry is `None`,
/// or which have no entry at all, are left empty.
///
/// # Arguments
/// * `tokens` - Structure tokens from table structure recognition
/// * `cell_texts` - Cell texts to fill, in order of cell appearance
///
/// # Example
/// ```
/// use oar_ocr_core::processors::wrap_table_html_with_content;
///
/// let tokens = vec![
///     "<tr>".to_string(),
///     "<td></td>".to_string(),
///     "<td></td>".to_string(),
///     "</tr>".to_string()
/// ];
/// let cell_texts = vec![Some("Cell 1".to_string()), Some("Cell 2".to_string())];
/// let html = wrap_table_html_with_content(&tokens, &cell_texts);
/// assert!(html.contains("Cell 1"));
/// assert!(html.contains("Cell 2"));
/// ```
pub fn wrap_table_html_with_content(tokens: &[String], cell_texts: &[Option<String>]) -> String {
    let fill = Some(cell_texts);
    render_table_html(tokens, fill)
}

/// Renders table HTML, optionally filling cell content.
fn render_table_html(tokens: &[String], cell_texts: Option<&[Option<String>]>) -> String {
    let mut result = Vec::new();
    let mut td_index = 0;
    let mut idx = 0usize;

    result.push("<html><body>".to_string());

    // Check if table tag is already present in tokens
    let has_table_tag = tokens
        .first()
        .map(|t| t.contains("<table"))
        .unwrap_or(false);
    if !has_table_tag {
        result.push("<table>".to_string());
    }

    while idx < tokens.len() {
        let tag = tokens[idx].as_str();

        // Handle standard empty cell token
        if tag == "<td></td>" {
            result.push("<td>".to_string());
            if let Some(texts) = cell_texts
                && let Some(Some(text)) = texts.get(td_index)
            {
                result.push(text.clone());
            }
            result.push("</td>".to_string());
            td_index += 1;
            idx += 1;
            continue;
        }

        // Handle opening td tag (possibly with attributes)
        if tag.starts_with("<td") {
            let parsed = parse_td_tag(tokens, idx);
            result.push(format!("<td{}>", parsed.attrs));

            // Check for bold tag immediately following td
            // Some structure models might output <td><b>...
            let mut is_bold = false;
            let next_idx = parsed.next_index;

            // Peek ahead for <b> inside the cell (before the next tag/end)
            // Note: This simple logic assumes <b> is a distinct token if present
            if next_idx < tokens.len() && tokens[next_idx] == "<b>" {
                is_bold = true;
                // Consume <b>
                // next_idx += 1;
                // actually we don't consume it here, we let the loop handle it?
                // No, we are inside the "fill content" logic.
                // If structure has <b>, we should wrap our content in <b>.
            }

            if let Some(texts) = cell_texts
                && let Some(Some(text)) = texts.get(td_index)
            {
                if is_bold {
                    result.push("<b>".to_string());
                }
                result.push(text.clone());
                if is_bold {
                    result.push("</b>".to_string());
                }
            }

            result.push("</td>".to_string());
            td_index += 1;

            // If we detected <b> structure, we effectively "handled" it by wrapping content.
            // However, to avoid duplicating tags if they are in the token stream,
            // we should ideally consume them. But robustly parsing nested structure
            // like <td><b></b></td> is complex.
            //
            // Standard PP-StructureV3 primarily uses <thead> vs <tbody> for styling
            // and <td></td> tokens. Boldness usually comes from being in <thead>.
            // We'll stick to the parsed index.
            idx = parsed.next_index;
            continue;
        }

        // Pass through all other tokens (<thead>, <tbody>, <tfoot>, <tr>, </tr>, etc.)
        result.push(tokens[idx].clone());
        idx += 1;
    }

    if !has_table_tag {
        result.push("</table>".to_string());
    }
    result.push("</body></html>".to_string());

    result.join("")
}

/// Grid position and span information for a table cell.
///
/// Positions are logical row/column indices in the table grid, not pixel coordinates.
/// Note that `Default` yields all-zero fields, including zero spans.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CellGridInfo {
    /// Row index (0-based)
    pub row: usize,
    /// Column index (0-based)
    pub col: usize,
    /// Number of rows this cell spans
    pub row_span: usize,
    /// Number of columns this cell spans
    pub col_span: usize,
}

/// Parses structure tokens to extract grid position and span info for each cell.
///
/// The walk tracks:
/// - the current row, which `</tr>` advances;
/// - the current column, which `<tr>` resets to 0;
/// - grid slots already claimed by a `rowspan` from an earlier row, which are skipped;
/// - each cell's `colspan`, by which the column advances after the cell.
///
/// The merged `<td></td>` token and any (possibly multi-token) `<td ...>` sequence
/// each count as one cell. The returned vector therefore has one entry per cell, in
/// the same order as the bboxes.
///
/// A span that parses as `0` is treated as `1`, so every cell occupies at least one
/// grid slot. This matches how the HTML table model handles `colspan="0"`. Without
/// it, a zero-width cell would share its slot with the next cell.
///
/// # Arguments
///
/// * `tokens` - Structure tokens from table structure recognition
///
/// # Returns
///
/// A vector of `CellGridInfo` for each cell, in order of appearance.
///
/// # Example
///
/// ```ignore
/// let tokens = vec![
///     "<tr>".to_string(),
///     "<td></td>".to_string(),
///     "<td colspan=\"2\"></td>".to_string(),
///     "</tr>".to_string(),
///     "<tr>".to_string(),
///     "<td></td>".to_string(),
///     "<td></td>".to_string(),
///     "<td></td>".to_string(),
///     "</tr>".to_string(),
/// ];
/// let grid_info = parse_cell_grid_info(&tokens);
/// // First row: cell at (0,0), cell at (0,1) spanning 2 cols
/// // Second row: cells at (1,0), (1,1), (1,2)
/// ```
pub fn parse_cell_grid_info(tokens: &[String]) -> Vec<CellGridInfo> {
    let mut cells = Vec::new();
    // Grid slots (row, col) claimed by rowspans reaching down from earlier rows.
    let mut occupied: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
    let mut row = 0usize;
    let mut col = 0usize;
    let mut idx = 0usize;

    while idx < tokens.len() {
        let token = tokens[idx].as_str();
        match token {
            "<tr>" => {
                // Occupied slots at the start of the row are skipped when the first
                // cell is placed below, so a plain reset suffices here.
                col = 0;
                idx += 1;
            }
            "</tr>" => {
                row += 1;
                idx += 1;
            }
            _ if token.starts_with("<td") => {
                let (row_span, col_span, next_idx) = if token == "<td></td>" {
                    (1, 1, idx + 1)
                } else {
                    let parsed = parse_td_tag(tokens, idx);
                    (
                        parsed.row_span.max(1),
                        parsed.col_span.max(1),
                        parsed.next_index,
                    )
                };

                // Skip columns occupied by rowspans from previous rows.
                while occupied.contains(&(row, col)) {
                    col += 1;
                }

                cells.push(CellGridInfo {
                    row,
                    col,
                    row_span,
                    col_span,
                });

                // Reserve the slots this cell covers in the rows below it.
                for r in 1..row_span {
                    for c in 0..col_span {
                        occupied.insert((row + r, col + c));
                    }
                }

                col += col_span;
                idx = next_idx;
            }
            _ => idx += 1,
        }
    }

    cells
}

/// Parses a span attribute (`colspan` or `rowspan`) from an HTML tag fragment.
///
/// Looks for the first `attr="N"` occurrence. Whitespace inside the quotes is
/// tolerated. Returns `None` when the attribute is absent, unterminated, not a
/// number, or `0`. A zero span is invalid in HTML and would give the cell no
/// extent, so callers fall back to their default of 1.
fn parse_span_attr(token: &str, attr: &str) -> Option<usize> {
    let pattern = format!("{attr}=\"");
    let value_start = token.find(&pattern)? + pattern.len();
    let rest = &token[value_start..];
    let value_len = rest.find('"')?;
    rest[..value_len]
        .trim()
        .parse::<usize>()
        .ok()
        .filter(|&span| span > 0)
}

/// Parsed information about a `<td>` token sequence.
#[derive(Debug, Clone)]
struct ParsedTdTag {
    /// Raw attributes (including leading spaces) to append after `<td`
    attrs: String,
    /// Rowspan value (defaults to 1)
    row_span: usize,
    /// Colspan value (defaults to 1)
    col_span: usize,
    /// Index to continue parsing from (skips attribute, content and closing tokens)
    next_index: usize,
}

/// Parses a `<td ...>` sequence that may be split across multiple tokens.
///
/// The Paddle table structure dictionary splits `<td` attributes into separate tokens, e.g.:
/// `["<td", " colspan=\"2\"", " rowspan=\"3\"", ">", "</td>"]`
/// This helper gathers those pieces into a single opening tag and extracts span info.
///
/// The first token determines how much else belongs to the opening tag:
/// - `<td` (no `>`): the following tokens are attribute fragments up to `>`.
/// - `<td>` / `<td colspan="2">`: the opening tag is already closed, so the following
///   tokens are cell content. They are skipped up to and including `</td>`, never
///   treated as attributes.
/// - `<td colspan="2"></td>`: the whole cell is in this one token, and parsing
///   resumes at the next token.
///
/// Scanning for attributes or content stops early at another `<td...`, `<tr>` or
/// `</tr>`, so a missing `</td>` cannot swallow the next cell or row.
fn parse_td_tag(tokens: &[String], start_idx: usize) -> ParsedTdTag {
    let mut attrs = String::new();
    let mut col_span = 1usize;
    let mut row_span = 1usize;

    // Appends one attribute fragment and picks up any span value it carries.
    let mut absorb = |fragment: &str| {
        attrs.push_str(fragment);
        if let Some(v) = parse_span_attr(fragment, "colspan") {
            col_span = v;
        }
        if let Some(v) = parse_span_attr(fragment, "rowspan") {
            row_span = v;
        }
    };

    // Split the first token at its first `>`: the part before it holds embedded
    // attributes; a `>` means the opening tag is already closed in this token.
    let after_td = tokens
        .get(start_idx)
        .and_then(|token| token.strip_prefix("<td"))
        .unwrap_or("");
    let (embedded_attrs, closed_tail) = match after_td.split_once('>') {
        Some((before, after)) => (before, Some(after)),
        None => (after_td, None),
    };
    if !embedded_attrs.is_empty() {
        absorb(embedded_attrs);
    }

    let mut idx = start_idx + 1;

    // Opening tag still open: consume attribute fragments until the closing `>`.
    if closed_tail.is_none() {
        while let Some(token) = tokens.get(idx).map(String::as_str) {
            if token == ">"
                || token == "</td>"
                || token.starts_with("<td")
                || token == "<tr>"
                || token == "</tr>"
            {
                break;
            }
            absorb(token);
            idx += 1;
        }
    }

    // Skip the cell's remaining tokens (e.g. `>`, `<b>`, `</b>`) through the closing `</td>`,
    // unless the first token already contained the whole cell.
    let self_closed = closed_tail.is_some_and(|tail| tail.contains("</td>"));
    let mut next_index = idx;
    if !self_closed {
        while let Some(token) = tokens.get(next_index).map(String::as_str) {
            if token == "</td>" {
                next_index += 1;
                break;
            }
            if token.starts_with("<td") || token == "<tr>" || token == "</tr>" {
                break;
            }
            next_index += 1;
        }
    }

    ParsedTdTag {
        attrs,
        row_span,
        col_span,
        // Always make progress so callers cannot loop forever on the same token.
        next_index: next_index.max(start_idx + 1),
    }
}

/// Output from table structure decoding, one entry per image in the batch.
///
/// The three vectors are parallel: index `i` in each refers to batch image `i`.
#[derive(Debug, Clone)]
pub struct TableStructureDecodeOutput {
    /// HTML structure tokens for each image, as decoded and before any HTML document
    /// wrapping is applied
    pub structure_tokens: Vec<Vec<String>>,
    /// Bounding boxes for table cells (4-point polygons: `[x1,y1,x2,y2,x3,y3,x4,y4]`),
    /// one per cell token in decode order; denormalized to the original image size when
    /// shape info for the image was provided
    pub bboxes: Vec<Vec<[f32; 8]>>,
    /// Mean of the per-token maximum structure output values for each image
    /// (0.0 when no tokens were decoded)
    pub structure_scores: Vec<f32>,
}

/// Table structure decoder that converts model outputs to HTML tokens and bboxes.
///
/// Build it with `TableStructureDecode::from_dict_path`, then call `decode` with the
/// structure-logit and bbox-prediction tensors produced by the model.
#[derive(Debug, Clone)]
pub struct TableStructureDecode {
    /// Token vocabulary indexed by model output id: `"sos"` at index 0, then the
    /// dictionary file's tokens (after `<td>` / `<td></td>` merging), then `"eos"` last
    character_dict: Vec<String>,
    /// Token ids skipped during decoding (the `"sos"` and `"eos"` ids)
    ignored_tokens: Vec<usize>,
    /// Token ids that carry a cell bounding box (whichever of `<td>`, `<td`,
    /// `<td></td>` exist in the vocabulary)
    td_token_indices: Vec<usize>,
    /// Id of the `"eos"` token; decoding stops when it is predicted after position 0
    end_idx: usize,
}

impl TableStructureDecode {
    /// Creates a new table structure decoder from a dictionary file.
    ///
    /// # Alignment
    ///
    /// The vocabulary is built to match PaddleOCR's `TableLabelDecode`:
    /// - `merge_no_span_structure` (always enabled here): `<td></td>` is appended if the
    ///   file lacks it, and the bare `<td>` token is removed.
    /// - `add_special_char()`: `"sos"` is prepended at index 0 and `"eos"` is appended at
    ///   the end, giving `["sos", <original_dict...>, "eos"]`.
    ///
    /// Only `"sos"` and `"eos"` are ignored during decoding.
    ///
    /// # Errors
    ///
    /// Returns [`OCRError::ConfigError`] if the dictionary file cannot be opened or read.
    pub fn from_dict_path(dict_path: &Path) -> Result<Self, OCRError> {
        let mut character_dict = Self::load_dict(dict_path)?;

        // merge_no_span_structure=True (PaddleOCR default): cells without span attributes
        // are predicted as a single `<td></td>` token, so that token must exist and the
        // bare `<td>` must not.
        if !character_dict.iter().any(|s| s == "<td></td>") {
            character_dict.push("<td></td>".to_string());
        }
        if let Some(pos) = character_dict.iter().position(|s| s == "<td>") {
            character_dict.remove(pos);
        }

        // Special tokens are lowercase without angle brackets: "sos" first, "eos" last.
        let original_dict_size = character_dict.len();
        let mut final_dict = Vec::with_capacity(original_dict_size + 2);
        final_dict.push("sos".to_string());
        final_dict.extend(character_dict);
        final_dict.push("eos".to_string());

        tracing::debug!("Dictionary processing complete:");
        tracing::debug!("  Original dict size: {}", original_dict_size);
        tracing::debug!("  Final dict size: {}", final_dict.len());
        tracing::debug!(
            "  First 10 dict entries: {:?}",
            &final_dict[..10.min(final_dict.len())]
        );
        tracing::debug!(
            "  Last 10 dict entries: {:?}",
            &final_dict[final_dict.len().saturating_sub(10)..]
        );

        // "sos" is always index 0 and "eos" always the last index.
        let start_idx = 0;
        let end_idx = final_dict.len() - 1;
        let ignored_tokens = vec![start_idx, end_idx];

        // Cell tokens that carry a bbox. `<td>` was removed above, so normally only
        // `<td` and `<td></td>` are found.
        let td_token_indices: Vec<usize> = ["<td>", "<td", "<td></td>"]
            .iter()
            .filter_map(|&token| final_dict.iter().position(|s| s == token))
            .collect();

        tracing::debug!("TD token indices: {:?}", td_token_indices);
        tracing::debug!(
            "Ignored tokens (sos={}, eos={}): {:?}",
            start_idx,
            end_idx,
            ignored_tokens
        );

        Ok(Self {
            character_dict: final_dict,
            ignored_tokens,
            td_token_indices,
            end_idx,
        })
    }

    /// Loads the dictionary file, one token per non-empty line.
    ///
    /// Only trailing whitespace (including a `\r` from CRLF files) is trimmed. Leading
    /// spaces are preserved because attribute tokens such as ` colspan="2"` need them to
    /// produce valid HTML like `<td colspan="2">`.
    ///
    /// # Errors
    ///
    /// Returns [`OCRError::ConfigError`] if the file cannot be opened or a line cannot be
    /// read. The message includes the path, and for read failures the line number.
    fn load_dict(path: &Path) -> Result<Vec<String>, OCRError> {
        let file = File::open(path).map_err(|e| OCRError::ConfigError {
            message: format!("Failed to open dictionary file '{}': {}", path.display(), e),
        })?;

        let mut dict = Vec::new();
        for (line_no, line) in BufReader::new(file).lines().enumerate() {
            let line = line.map_err(|e| OCRError::ConfigError {
                message: format!(
                    "Failed to read dictionary line {} of '{}': {}",
                    line_no + 1,
                    path.display(),
                    e
                ),
            })?;
            let token = line.trim_end();
            if !token.is_empty() {
                dict.push(token.to_string());
            }
        }

        Ok(dict)
    }

    /// Decodes structure logits and bbox predictions for a whole batch.
    ///
    /// # Arguments
    ///
    /// * `structure_logits` - [batch, seq_len, vocab_size] structure predictions
    /// * `bbox_preds` - [batch, seq_len, 8] bbox predictions (normalized coordinates)
    /// * `shape_info` - [(orig_h, orig_w, scale, pad_h, pad_w, target_size), ...] for each image.
    ///   `scale` is the ResizeByLong factor used during preprocessing: `target_size / max(orig_h, orig_w)`.
    ///   Images without an entry keep their normalized bbox coordinates.
    ///
    /// # Returns
    ///
    /// Decoded structure tokens, bounding boxes, and confidence scores, one entry per image.
    ///
    /// # Errors
    ///
    /// Returns [`OCRError::InvalidInput`] in either of these cases:
    /// - A cell token needs a bbox that `bbox_preds` does not contain: the batch or
    ///   sequence index is out of range, or there are fewer than 8 coordinates.
    /// - The image's `shape_info` entry is invalid: `scale`/`target_size` is not
    ///   positive and finite, or the original size is negative or non-finite.
    pub fn decode(
        &self,
        structure_logits: &ndarray::Array3<f32>,
        bbox_preds: &ndarray::Array3<f32>,
        shape_info: &[[f32; 6]],
    ) -> Result<TableStructureDecodeOutput, OCRError> {
        let batch_size = structure_logits.shape()[0];

        let mut output = TableStructureDecodeOutput {
            structure_tokens: Vec::with_capacity(batch_size),
            bboxes: Vec::with_capacity(batch_size),
            structure_scores: Vec::with_capacity(batch_size),
        };

        for batch_idx in 0..batch_size {
            let (tokens, bboxes, score) =
                self.decode_single(structure_logits, bbox_preds, batch_idx, shape_info)?;
            output.structure_tokens.push(tokens);
            output.bboxes.push(bboxes);
            output.structure_scores.push(score);
        }

        Ok(output)
    }

    /// Decodes a single image from the batch.
    ///
    /// Greedy (argmax) decoding mirroring `TableLabelDecode`: stop at `"eos"` after
    /// position 0, skip ignored tokens, and extract one bbox per cell token.
    fn decode_single(
        &self,
        structure_logits: &ndarray::Array3<f32>,
        bbox_preds: &ndarray::Array3<f32>,
        batch_idx: usize,
        shape_info: &[[f32; 6]],
    ) -> TableDecodeResult {
        let seq_len = structure_logits.shape()[1];

        let mut structure_tokens = Vec::new();
        let mut bboxes = Vec::new();
        let mut scores = Vec::new();

        tracing::debug!(
            "Starting token decoding for batch {}, sequence length {}",
            batch_idx,
            seq_len
        );
        tracing::debug!("Structure logits shape: {:?}", structure_logits.shape());
        tracing::debug!("Bbox preds shape: {:?}", bbox_preds.shape());

        for seq_idx in 0..seq_len {
            let (token_idx, token_prob) = self.argmax_at(structure_logits, batch_idx, seq_idx);

            // The end token only terminates decoding after the first position
            // (at position 0 it is simply skipped as an ignored token).
            if seq_idx > 0 && token_idx == self.end_idx {
                tracing::debug!(
                    "Stopping at end token (idx: {}) at sequence position {}",
                    token_idx,
                    seq_idx
                );
                break;
            }

            if self.ignored_tokens.contains(&token_idx) {
                tracing::debug!(
                    "Skipping ignored token at seq_idx {}: token_idx={}, token='{}'",
                    seq_idx,
                    token_idx,
                    self.character_dict
                        .get(token_idx)
                        .map(String::as_str)
                        .unwrap_or("<INVALID>")
                );
                continue;
            }

            let token = self
                .character_dict
                .get(token_idx)
                .cloned()
                .unwrap_or_else(|| format!("UNK_{}", token_idx));

            tracing::debug!(
                "Decoded token at seq_idx {}: token_idx={}, dict_size={}, token='{}', prob={:.6}",
                seq_idx,
                token_idx,
                self.character_dict.len(),
                token,
                token_prob
            );

            // Cell tokens carry a bounding box predicted at the same sequence position.
            if self.td_token_indices.contains(&token_idx) {
                let bbox = self.extract_bbox(bbox_preds, batch_idx, seq_idx, shape_info)?;
                tracing::debug!("Extracted bbox for TD token '{}': {:?}", token, bbox);
                bboxes.push(bbox);
            }

            structure_tokens.push(token);
            scores.push(token_prob);
        }

        // Per-image summaries are debug-level: a library should not emit the full token
        // list at info level for every decoded table.
        tracing::debug!(
            "Decoded {} structure tokens: {:?}",
            structure_tokens.len(),
            structure_tokens
        );
        tracing::debug!("Extracted {} bounding boxes", bboxes.len());

        // Structure score: mean of the per-token maximum output values.
        let mean_score = if scores.is_empty() {
            0.0
        } else {
            scores.iter().sum::<f32>() / scores.len() as f32
        };

        Ok((structure_tokens, bboxes, mean_score))
    }

    /// Finds the argmax over the vocabulary axis at `[batch_idx, seq_idx, ..]`.
    ///
    /// Ties resolve to the lowest index. If every value is NaN, or the vocabulary is
    /// empty, the result is `(0, f32::NEG_INFINITY)`.
    fn argmax_at(
        &self,
        logits: &ndarray::Array3<f32>,
        batch_idx: usize,
        seq_idx: usize,
    ) -> (usize, f32) {
        let vocab_size = logits.shape()[2];
        let mut best = (0usize, f32::NEG_INFINITY);

        for vocab_idx in 0..vocab_size {
            let val = logits[[batch_idx, seq_idx, vocab_idx]];
            if val > best.1 {
                best = (vocab_idx, val);
            }
        }

        // The raw maximum value is returned as the token score.
        best
    }

    /// Extracts and denormalizes one bounding box.
    ///
    /// TableLabelDecode computes bbox scales as:
    ///   `ratio = min(padded_w / orig_w, padded_h / orig_h)` and
    ///   `scale = padded_{w,h} / ratio = max(orig_w, orig_h)`.
    /// For SLANeXt / SLANet_plus models the padded input is square, so both axes
    /// use the same `scale` (the longest side of the original image).
    ///
    /// # Errors
    ///
    /// Returns [`OCRError::InvalidInput`] instead of panicking in either of these cases:
    /// - `bbox_preds` has no 8-coordinate entry at `[batch_idx, seq_idx]`.
    /// - The image's shape info is unusable: `scale`/`target_size` is not positive and
    ///   finite, or `orig_h`/`orig_w` is negative or non-finite. `f32::clamp` panics on
    ///   NaN bounds.
    fn extract_bbox(
        &self,
        bbox_preds: &ndarray::Array3<f32>,
        batch_idx: usize,
        seq_idx: usize,
        shape_info: &[[f32; 6]],
    ) -> Result<[f32; 8], OCRError> {
        let dims = bbox_preds.shape();
        if batch_idx >= dims[0] || seq_idx >= dims[1] || dims[2] < 8 {
            return Err(OCRError::InvalidInput {
                message: format!(
                    "Bbox predictions of shape {:?} have no 8-coordinate box for batch {} at sequence position {}",
                    dims, batch_idx, seq_idx
                ),
            });
        }

        // Normalized coordinates straight from the model.
        let mut bbox = [0.0f32; 8];
        for (idx, coord) in bbox.iter_mut().enumerate() {
            *coord = bbox_preds[[batch_idx, seq_idx, idx]];
        }

        // Without shape info the normalized coordinates are returned unchanged.
        let Some(&[orig_h, orig_w, scale, _pad_h, _pad_w, target_size]) =
            shape_info.get(batch_idx)
        else {
            return Ok(bbox);
        };

        let is_positive = |v: f32| v.is_finite() && v > 0.0;
        if !is_positive(scale) || !is_positive(target_size) {
            return Err(OCRError::InvalidInput {
                message: format!(
                    "Invalid shape info for batch {}: scale={} target_size={}",
                    batch_idx, scale, target_size
                ),
            });
        }

        let is_valid_extent = |v: f32| v.is_finite() && v >= 0.0;
        if !is_valid_extent(orig_h) || !is_valid_extent(orig_w) {
            return Err(OCRError::InvalidInput {
                message: format!(
                    "Invalid original image size for batch {}: orig_h={} orig_w={}",
                    batch_idx, orig_h, orig_w
                ),
            });
        }

        // Equivalent to TableLabelDecode _get_bbox_scales() for SLANeXt/SLANet_plus:
        // padded_w == padded_h == target_size, so target_size / scale == max(orig_w, orig_h).
        let longest_side = target_size / scale;

        // Model outputs are normalized to [0,1] w.r.t. the padded square. Scale by the
        // longest side and clamp to the original image: even indices are x, odd are y.
        for (idx, coord) in bbox.iter_mut().enumerate() {
            let limit = if idx % 2 == 0 { orig_w } else { orig_h };
            *coord = (*coord * longest_side).clamp(0.0, limit);
        }

        Ok(bbox)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_dict() {
        // This test would require the actual dictionary file
        // In practice, we'd test with a mock file
    }

    #[test]
    fn test_dictionary_processing() {
        // Test dictionary processing logic
        // Create a temporary dictionary
        let temp_dict = vec![
            "<html>".to_string(),
            "<body>".to_string(),
            "<table>".to_string(),
            "<tr>".to_string(),
            "<td>".to_string(), // This should be removed
            "<td".to_string(),
            " colspan=\"4\"".to_string(),
            ">".to_string(),
            "</td>".to_string(),
            "</tr>".to_string(),
            "</table>".to_string(),
            "</body>".to_string(),
            "</html>".to_string(),
        ];

        // Test merge_no_span_structure logic
        let mut processed_dict = temp_dict.clone();
        let merge_no_span_structure = true;
        if merge_no_span_structure {
            if !processed_dict.contains(&"<td></td>".to_string()) {
                processed_dict.push("<td></td>".to_string());
            }
            if let Some(pos) = processed_dict.iter().position(|s| s == "<td>") {
                processed_dict.remove(pos);
            }
        }

        // Check that <td> was removed
        assert!(!processed_dict.contains(&"<td>".to_string()));

        // Check that <td></td> was added
        assert!(processed_dict.contains(&"<td></td>".to_string()));

        // Add special tokens
        let beg_str = "sos";
        let end_str = "eos";
        let mut final_dict = vec![beg_str.to_string()];
        final_dict.extend(processed_dict);
        final_dict.push(end_str.to_string());

        // Check special tokens are in correct positions
        assert_eq!(final_dict[0], "sos");
        assert_eq!(final_dict[final_dict.len() - 1], "eos");

        // Check that original tokens are preserved (except <td>)
        assert!(final_dict.contains(&"<html>".to_string()));
        assert!(final_dict.contains(&"<td".to_string()));
        assert!(final_dict.contains(&" colspan=\"4\"".to_string()));
    }

    /// Checks that `TableStructureDecode::argmax_at` picks the highest-scoring class.
    ///
    /// The test is skipped (returns `Ok(())`) when the structure dictionary file
    /// is not present on disk.
    #[test]
    fn test_argmax() -> Result<(), OCRError> {
        use ndarray::Array3;

        let dict_path = Path::new("models/table_structure_dict.txt");
        if !dict_path.exists() {
            return Ok(()); // Skip if dict not available
        }

        let decoder = TableStructureDecode::from_dict_path(dict_path)?;

        // All-zero logits: every class ties, so the first index is expected to win.
        let mut logits = Array3::zeros((1, 5, 50));
        let (idx, _prob) = decoder.argmax_at(&logits, 0, 0);
        assert_eq!(idx, 0);

        // A single strictly larger logit must be selected. Without this, an
        // implementation that always returned 0 would pass the check above.
        // NOTE(review): assumes the last axis is the class axis — TODO confirm.
        logits[[0, 0, 7]] = 1.0;
        let (idx, _prob) = decoder.argmax_at(&logits, 0, 0);
        assert_eq!(idx, 7);
        Ok(())
    }

    /// A plain 2x2 table should yield four 1x1 cells laid out in row-major order.
    #[test]
    fn test_parse_cell_grid_info_simple() {
        // The same two-cell row, emitted twice.
        let row_tokens = ["<tr>", "<td></td>", "<td></td>", "</tr>"];
        let tokens: Vec<String> = row_tokens
            .iter()
            .chain(row_tokens.iter())
            .copied()
            .map(String::from)
            .collect();

        let cells = parse_cell_grid_info(&tokens);
        assert_eq!(cells.len(), 4);

        // The first cell carries the default 1x1 span.
        assert_eq!(cells[0].row_span, 1);
        assert_eq!(cells[0].col_span, 1);

        // Expected (row, col) for each cell, in emission order.
        let expected_positions = [(0, 0), (0, 1), (1, 0), (1, 1)];
        for (i, &(row, col)) in expected_positions.iter().enumerate() {
            assert_eq!(cells[i].row, row, "row of cell {i}");
            assert_eq!(cells[i].col, col, "col of cell {i}");
        }
    }

    /// A cell with `colspan="2"` occupies both columns of its row without shifting
    /// the plain cells of the following row.
    #[test]
    fn test_parse_cell_grid_info_colspan() {
        let tokens: Vec<String> = [
            "<tr>",
            "<td colspan=\"2\"></td>",
            "</tr>",
            "<tr>",
            "<td></td>",
            "<td></td>",
            "</tr>",
        ]
        .iter()
        .map(|&tok| tok.to_owned())
        .collect();

        let cells = parse_cell_grid_info(&tokens);
        assert_eq!(cells.len(), 3);

        // Row 0: one cell spanning two columns.
        let wide = &cells[0];
        assert_eq!(wide.row, 0);
        assert_eq!(wide.col, 0);
        assert_eq!(wide.col_span, 2);

        // Row 1: the colspan does not carry over, so cells start again at column 0.
        let (left, right) = (&cells[1], &cells[2]);
        assert_eq!(left.row, 1);
        assert_eq!(left.col, 0);
        assert_eq!(right.row, 1);
        assert_eq!(right.col, 1);
    }

    /// A cell with `rowspan="2"` keeps column 0 occupied in the next row, so the
    /// next row's only cell must be placed in column 1.
    #[test]
    fn test_parse_cell_grid_info_rowspan() {
        let tokens: Vec<String> = [
            "<tr>",
            "<td rowspan=\"2\"></td>",
            "<td></td>",
            "</tr>",
            "<tr>",
            "<td></td>",
            "</tr>",
        ]
        .iter()
        .map(|&tok| tok.to_owned())
        .collect();

        let cells = parse_cell_grid_info(&tokens);
        assert_eq!(cells.len(), 3);

        // Row 0, column 0: the tall cell.
        let tall = &cells[0];
        assert_eq!(tall.row, 0);
        assert_eq!(tall.col, 0);
        assert_eq!(tall.row_span, 2);

        // Row 0, column 1: its plain neighbour.
        let beside_tall = &cells[1];
        assert_eq!(beside_tall.row, 0);
        assert_eq!(beside_tall.col, 1);

        // Row 1: column 0 is still covered by the rowspan.
        let below = &cells[2];
        assert_eq!(below.row, 1);
        assert_eq!(below.col, 1, "column 0 of row 1 is still covered by the rowspan");
    }

    /// Span attributes delivered as separate tokens (`<td`, ` colspan="2"`, `>`,
    /// `</td>`), as in Paddle's dictionary, must be parsed the same way as
    /// single-token cells.
    #[test]
    fn test_parse_cell_grid_info_split_tokens_with_spans() {
        let rows: [&[&str]; 3] = [
            // Row 0: a single cell spanning two columns.
            &["<tr>", "<td", " colspan=\"2\"", ">", "</td>", "</tr>"],
            // Row 1: the first cell spans two rows, followed by a plain cell.
            &["<tr>", "<td", " rowspan=\"2\"", ">", "</td>", "<td></td>", "</tr>"],
            // Row 2: one plain cell, which must skip column 0 because of the rowspan.
            &["<tr>", "<td></td>", "</tr>"],
        ];
        let tokens: Vec<String> = rows.concat().into_iter().map(String::from).collect();

        let cells = parse_cell_grid_info(&tokens);
        assert_eq!(cells.len(), 4);

        let wide = &cells[0];
        assert_eq!(wide.row, 0);
        assert_eq!(wide.col_span, 2);

        let tall = &cells[1];
        assert_eq!(tall.row, 1);
        assert_eq!(tall.col, 0);
        assert_eq!(tall.row_span, 2);

        let beside_tall = &cells[2];
        assert_eq!(beside_tall.row, 1);
        assert_eq!(beside_tall.col, 1);

        // The row-0 colspan has no effect here; only the row-1 rowspan shifts this cell.
        let shifted = &cells[3];
        assert_eq!(shifted.row, 2);
        assert_eq!(shifted.col, 1);
    }

    /// Split `<td` / attribute / `>` tokens are joined back into one opening tag
    /// around the cell text, inside a full html/body/table envelope.
    #[test]
    fn test_wrap_table_html_with_split_tokens() {
        let tokens = Vec::from(
            ["<tr>", "<td", " colspan=\"2\"", ">", "</td>", "</tr>"].map(String::from),
        );
        let cell_texts = vec![Some(String::from("Cell A"))];

        let html = wrap_table_html_with_content(&tokens, &cell_texts);

        let expected_cell = "<td colspan=\"2\">Cell A</td>";
        assert!(html.contains(expected_cell));

        let (prefix, suffix) = ("<html><body><table>", "</table></body></html>");
        assert!(html.starts_with(prefix));
        assert!(html.ends_with(suffix));
    }

    /// `parse_span_attr` reads the numeric value of the named span attribute and
    /// returns `None` when the attribute is absent.
    #[test]
    fn test_parse_span_attr() {
        // (tag text, attribute name, expected value)
        let cases = [
            ("<td colspan=\"2\">", "colspan", Some(2)),
            ("<td rowspan=\"3\">", "rowspan", Some(3)),
            ("<td colspan=\"2\" rowspan=\"3\">", "colspan", Some(2)),
            ("<td colspan=\"2\" rowspan=\"3\">", "rowspan", Some(3)),
            ("<td></td>", "colspan", None),
            ("<td>", "rowspan", None),
        ];

        for (tag, attr, expected) in cases {
            assert_eq!(
                parse_span_attr(tag, attr),
                expected,
                "parse_span_attr({tag:?}, {attr:?})"
            );
        }
    }

    /// Checks that `extract_bbox` scales normalized predictions by the image's
    /// longest side and clamps them to the original image bounds, for a portrait
    /// (300x600) image.
    #[test]
    fn test_extract_bbox_longest_side_scaling_matches_standard() -> Result<(), OCRError> {
        // Only `extract_bbox` is exercised, so an otherwise empty decoder suffices.
        let decoder = TableStructureDecode {
            character_dict: Vec::new(),
            ignored_tokens: Vec::new(),
            td_token_indices: Vec::new(),
            end_idx: 0,
        };

        // Simulate normalized bbox predictions for a portrait image (orig 300x600).
        // Even indices are compared against the width bound, odd indices against the height bound.
        let preds = [0.45f32, 0.25, 0.9, 0.25, 0.45, 0.8, 0.9, 0.8];
        let mut bbox_preds = ndarray::Array3::<f32>::zeros((1, 1, 8));
        for (i, &val) in preds.iter().enumerate() {
            bbox_preds[[0, 0, i]] = val;
        }

        let orig_h: f32 = 600.0;
        let orig_w: f32 = 300.0;
        let target_size: f32 = 512.0;
        let scale = target_size / orig_h.max(orig_w); // ResizeByLong scale
        let pad_h = 0.0;
        let pad_w = target_size - (orig_w * scale); // Padding on the right
        let shape_info = [[orig_h, orig_w, scale, pad_h, pad_w, target_size]];

        let bbox = decoder.extract_bbox(&bbox_preds, 0, 0, &shape_info)?;

        // TableLabelDecode _get_bbox_scales for SLANeXt/SLANet_plus:
        // ratio = target_size / max(orig_w, orig_h) => denorm factor = max dim.
        let longest_side = orig_h.max(orig_w);
        let expected = [
            (preds[0] * longest_side).clamp(0.0, orig_w),
            (preds[1] * longest_side).clamp(0.0, orig_h),
            (preds[2] * longest_side).clamp(0.0, orig_w),
            (preds[3] * longest_side).clamp(0.0, orig_h),
            (preds[4] * longest_side).clamp(0.0, orig_w),
            (preds[5] * longest_side).clamp(0.0, orig_h),
            (preds[6] * longest_side).clamp(0.0, orig_w),
            (preds[7] * longest_side).clamp(0.0, orig_h),
        ];

        // `zip` stops at the shorter input. Without this check, a short or empty
        // bbox would make the comparison loop below pass vacuously.
        assert_eq!(
            bbox.iter().count(),
            expected.len(),
            "bbox should contain one value per predicted coordinate"
        );

        for (idx, (got, exp)) in bbox.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-3,
                "bbox coord {} mismatch: got {}, expected {}",
                idx,
                got,
                exp
            );
        }
        Ok(())
    }
}