1use crate::types::{OcrResult, TextBlock};
7use serde::{Deserialize, Serialize};
8
/// A single label/value pair detected on a form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FormField {
    /// Field label, e.g. the text before the `:` in a `"Label: value"` block.
    pub name: String,
    /// Field value, e.g. the text after the `:` in a `"Label: value"` block.
    pub value: String,
    /// Type inferred from `value` by `FormDetector`.
    pub field_type: FieldType,
    /// Bounding box; `FormDetector` interprets it as `[x, y, width, height]`.
    pub bbox: [f32; 4],
    /// OCR confidence of the source block(s); when a field spans a label block and a
    /// separate value block this is the mean of the two.
    /// NOTE(review): presumably in `0.0..=1.0` given the 0.7 default threshold — confirm.
    pub confidence: f32,
}
23
24#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26pub enum FieldType {
27 Text,
29 Checkbox,
31 RadioButton,
33 Signature,
35 Date,
37 Email,
39 Phone,
41 Currency,
43 Other,
45}
46
/// Everything detected on a form by [`FormDetector::detect_fields`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FormDetectionResult {
    /// Label/value fields, in the order their OCR blocks appeared.
    pub fields: Vec<FormField>,
    /// Detected checkboxes (empty when checkbox detection is disabled).
    pub checkboxes: Vec<Checkbox>,
    /// Radio-button groups. Radio detection is not implemented yet, so this is
    /// currently always empty.
    pub radio_groups: Vec<RadioGroup>,
    /// Signature areas (empty when signature detection is disabled).
    pub signatures: Vec<Signature>,
}
59
/// A checkbox found in the OCR text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkbox {
    /// Text associated with the checkbox; may be empty when no label was found.
    pub label: String,
    /// `true` when the box is marked.
    pub checked: bool,
    /// Bounding box of the OCR block holding the marker, `[x, y, width, height]`
    /// as interpreted by `FormDetector`.
    pub bbox: [f32; 4],
    /// OCR confidence of that block.
    pub confidence: f32,
}
72
/// A group of radio-button options.
///
/// NOTE(review): `FormDetector` does not populate radio groups yet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RadioGroup {
    /// Group name (e.g. the question the options answer).
    pub name: String,
    /// The options in this group.
    pub options: Vec<RadioButton>,
    /// Presumably the index into `options` of the selected button, if any — TODO confirm.
    pub selected: Option<usize>,
}
83
/// One option within a [`RadioGroup`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RadioButton {
    /// Option label text.
    pub label: String,
    /// Whether this option is selected.
    pub selected: bool,
    /// Bounding box, `[x, y, width, height]` like the other detection types.
    pub bbox: [f32; 4],
    /// Detection confidence.
    pub confidence: f32,
}
96
/// A signature area located on the form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
    /// Text of the block that identified the signature area, if any.
    pub label: Option<String>,
    /// Whether a signature appears to be present. The current detector only finds
    /// signature labels and always reports `false`.
    pub has_signature: bool,
    /// Bounding box of the label block, `[x, y, width, height]`.
    pub bbox: [f32; 4],
    /// Signature quality score; the current detector always reports `0.0`.
    pub quality: f32,
}
109
/// Settings for [`FormDetector`].
#[derive(Debug, Clone)]
pub struct FormDetectionConfig {
    /// Minimum confidence threshold for detections (default `0.7`).
    pub min_confidence: f32,
    /// Run checkbox detection (default `true`).
    pub detect_checkboxes: bool,
    /// Run radio-button detection (default `true`); radio detection is currently a stub.
    pub detect_radio_buttons: bool,
    /// Run signature-label detection (default `true`).
    pub detect_signatures: bool,
    /// A label block and the block after it are paired only when their centers are
    /// closer than this distance (default `100.0`), measured in bbox coordinate units.
    pub max_label_value_distance: f32,
}
124
125impl Default for FormDetectionConfig {
126 fn default() -> Self {
127 Self {
128 min_confidence: 0.7,
129 detect_checkboxes: true,
130 detect_radio_buttons: true,
131 detect_signatures: true,
132 max_label_value_distance: 100.0,
133 }
134 }
135}
136
137pub struct FormDetector {
139 config: FormDetectionConfig,
140}
141
142impl FormDetector {
143 pub fn new() -> Self {
145 Self {
146 config: FormDetectionConfig::default(),
147 }
148 }
149
150 pub fn with_config(config: FormDetectionConfig) -> Self {
152 Self { config }
153 }
154
155 pub fn detect_fields(&self, ocr_result: &OcrResult) -> FormDetectionResult {
157 let mut fields = Vec::new();
158 let checkboxes = if self.config.detect_checkboxes {
159 self.detect_checkboxes(&ocr_result.blocks)
160 } else {
161 Vec::new()
162 };
163
164 let radio_groups = if self.config.detect_radio_buttons {
165 self.detect_radio_groups(&ocr_result.blocks)
166 } else {
167 Vec::new()
168 };
169
170 let signatures = if self.config.detect_signatures {
171 self.detect_signatures(&ocr_result.blocks)
172 } else {
173 Vec::new()
174 };
175
176 fields.extend(self.detect_key_value_pairs(&ocr_result.blocks));
178
179 FormDetectionResult {
180 fields,
181 checkboxes,
182 radio_groups,
183 signatures,
184 }
185 }
186
187 fn detect_key_value_pairs(&self, blocks: &[TextBlock]) -> Vec<FormField> {
189 let mut fields = Vec::new();
190
191 for (i, block) in blocks.iter().enumerate() {
193 if block.text.contains(':') {
194 let parts: Vec<&str> = block.text.splitn(2, ':').collect();
196 if parts.len() == 2 {
197 let name = parts[0].trim().to_string();
198 let value = parts[1].trim().to_string();
199
200 if !name.is_empty() {
201 let field_type = self.infer_field_type(&value);
202 fields.push(FormField {
203 name,
204 value,
205 field_type,
206 bbox: block.bbox,
207 confidence: block.confidence,
208 });
209 }
210 }
211 } else if i + 1 < blocks.len() {
212 let next_block = &blocks[i + 1];
214 let distance = self.calculate_distance(block.bbox, next_block.bbox);
215
216 if distance < self.config.max_label_value_distance {
217 fields.push(FormField {
218 name: block.text.clone(),
219 value: next_block.text.clone(),
220 field_type: self.infer_field_type(&next_block.text),
221 bbox: self.merge_bboxes(block.bbox, next_block.bbox),
222 confidence: (block.confidence + next_block.confidence) / 2.0,
223 });
224 }
225 }
226 }
227
228 fields
229 }
230
231 fn detect_checkboxes(&self, blocks: &[TextBlock]) -> Vec<Checkbox> {
233 let mut checkboxes = Vec::new();
234
235 for block in blocks {
236 if block.text.trim() == "[x]" || block.text.trim() == "[X]" || block.text.contains('☑')
238 {
239 checkboxes.push(Checkbox {
240 label: String::new(), checked: true,
242 bbox: block.bbox,
243 confidence: block.confidence,
244 });
245 } else if block.text.trim() == "[ ]" || block.text.contains('☐') {
246 checkboxes.push(Checkbox {
247 label: String::new(),
248 checked: false,
249 bbox: block.bbox,
250 confidence: block.confidence,
251 });
252 }
253 }
254
255 checkboxes
256 }
257
258 fn detect_radio_groups(&self, _blocks: &[TextBlock]) -> Vec<RadioGroup> {
260 Vec::new()
262 }
263
264 fn detect_signatures(&self, blocks: &[TextBlock]) -> Vec<Signature> {
266 let mut signatures = Vec::new();
267
268 for block in blocks {
269 let lower_text = block.text.to_lowercase();
271 if lower_text.contains("signature")
272 || lower_text.contains("sign here")
273 || lower_text.contains("signed")
274 {
275 signatures.push(Signature {
276 label: Some(block.text.clone()),
277 has_signature: false, bbox: block.bbox,
279 quality: 0.0,
280 });
281 }
282 }
283
284 signatures
285 }
286
287 fn infer_field_type(&self, value: &str) -> FieldType {
289 let value_lower = value.to_lowercase();
290
291 if value.contains('@') && value.contains('.') {
292 FieldType::Email
293 } else if value.chars().filter(|c| c.is_ascii_digit()).count() >= 10 {
294 FieldType::Phone
295 } else if value.contains('$')
296 || value.contains('€')
297 || value.contains('£')
298 || value.contains('¥')
299 {
300 FieldType::Currency
301 } else if value_lower.contains("date")
302 || value.contains('/') && value.chars().filter(|c| c.is_ascii_digit()).count() >= 6
303 {
304 FieldType::Date
305 } else {
306 FieldType::Text
307 }
308 }
309
310 fn calculate_distance(&self, bbox1: [f32; 4], bbox2: [f32; 4]) -> f32 {
312 let x1_center = bbox1[0] + bbox1[2] / 2.0;
313 let y1_center = bbox1[1] + bbox1[3] / 2.0;
314 let x2_center = bbox2[0] + bbox2[2] / 2.0;
315 let y2_center = bbox2[1] + bbox2[3] / 2.0;
316
317 let dx = x2_center - x1_center;
318 let dy = y2_center - y1_center;
319
320 (dx * dx + dy * dy).sqrt()
321 }
322
323 fn merge_bboxes(&self, bbox1: [f32; 4], bbox2: [f32; 4]) -> [f32; 4] {
325 let min_x = bbox1[0].min(bbox2[0]);
326 let min_y = bbox1[1].min(bbox2[1]);
327 let max_x = (bbox1[0] + bbox1[2]).max(bbox2[0] + bbox2[2]);
328 let max_y = (bbox1[1] + bbox1[3]).max(bbox2[1] + bbox2[3]);
329
330 [min_x, min_y, max_x - min_x, max_y - min_y]
331 }
332}
333
334impl Default for FormDetector {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
340impl FormDetectionResult {
341 pub fn get_fields_by_type(&self, field_type: FieldType) -> Vec<&FormField> {
343 self.fields
344 .iter()
345 .filter(|f| f.field_type == field_type)
346 .collect()
347 }
348
349 pub fn get_field(&self, name: &str) -> Option<&FormField> {
351 self.fields.iter().find(|f| f.name == name)
352 }
353
354 pub fn get_checked_boxes(&self) -> Vec<&Checkbox> {
356 self.checkboxes.iter().filter(|c| c.checked).collect()
357 }
358
359 pub fn to_json(&self) -> serde_json::Result<String> {
361 serde_json::to_string_pretty(self)
362 }
363
364 pub fn to_key_value_map(&self) -> std::collections::HashMap<String, String> {
366 let mut map = std::collections::HashMap::new();
367
368 for field in &self.fields {
369 map.insert(field.name.clone(), field.value.clone());
370 }
371
372 for checkbox in &self.checkboxes {
373 map.insert(
374 checkbox.label.clone(),
375 if checkbox.checked { "true" } else { "false" }.to_string(),
376 );
377 }
378
379 map
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 20x20 checkbox fixture at vertical offset `y` with confidence 0.9.
    fn sample_checkbox(label: &str, checked: bool, y: f32) -> Checkbox {
        Checkbox {
            label: label.to_string(),
            checked,
            bbox: [0.0, y, 20.0, 20.0],
            confidence: 0.9,
        }
    }

    #[test]
    fn test_form_field_creation() {
        let email_field = FormField {
            name: String::from("Email"),
            value: String::from("test@example.com"),
            field_type: FieldType::Email,
            bbox: [0.0, 0.0, 200.0, 30.0],
            confidence: 0.95,
        };

        assert_eq!(email_field.field_type, FieldType::Email);
        assert_eq!(email_field.name, "Email");
    }

    #[test]
    fn test_checkbox_creation() {
        let agreement = sample_checkbox("I agree", true, 0.0);

        assert_eq!(agreement.label, "I agree");
        assert!(agreement.checked);
    }

    #[test]
    fn test_field_type_inference() {
        let detector = FormDetector::new();
        let cases = [
            ("test@example.com", FieldType::Email),
            ("$100.50", FieldType::Currency),
            ("555-1234-5678", FieldType::Phone),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(&detector.infer_field_type(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_form_detection_config() {
        let detector = FormDetector::with_config(FormDetectionConfig {
            min_confidence: 0.8,
            detect_checkboxes: false,
            detect_radio_buttons: false,
            detect_signatures: true,
            max_label_value_distance: 50.0,
        });

        let stored = &detector.config;
        assert_eq!(stored.min_confidence, 0.8);
        assert!(!stored.detect_checkboxes);
        assert!(stored.detect_signatures);
    }

    #[test]
    fn test_form_result_get_checked_boxes() {
        let result = FormDetectionResult {
            fields: Vec::new(),
            checkboxes: vec![
                sample_checkbox("Option 1", true, 0.0),
                sample_checkbox("Option 2", false, 30.0),
            ],
            radio_groups: Vec::new(),
            signatures: Vec::new(),
        };

        let ticked = result.get_checked_boxes();
        assert_eq!(ticked.len(), 1);
        assert_eq!(ticked[0].label, "Option 1");
    }

    /// Only exercises `Signature` construction; the detector itself is not invoked.
    #[test]
    fn test_signature_detection() {
        let signature = Signature {
            label: Some("Signature:".to_owned()),
            has_signature: false,
            bbox: [0.0, 0.0, 200.0, 50.0],
            quality: 0.0,
        };

        assert_eq!(signature.quality, 0.0);
        assert!(!signature.has_signature);
    }
}