1use crate::{Entity, EntityType, Error, Result};
45
/// Default value for [`SpanConfig::max_span_width`].
pub const DEFAULT_MAX_SPAN_WIDTH: usize = 12;

/// Tunable settings consumed by the span decoder.
#[derive(Debug, Clone)]
pub struct SpanConfig {
    /// Upper bound on the number of consecutive words a decoded span may cover.
    pub max_span_width: usize,
    /// Score a span must strictly exceed before it is reported as an entity.
    pub threshold: f32,
}

impl Default for SpanConfig {
    /// Uses [`DEFAULT_MAX_SPAN_WIDTH`] and a threshold of `0.5`.
    fn default() -> Self {
        SpanConfig {
            threshold: 0.5,
            max_span_width: DEFAULT_MAX_SPAN_WIDTH,
        }
    }
}
69
70pub fn make_span_tensors(num_words: usize, max_width: usize) -> (Vec<i64>, Vec<bool>) {
98 let num_spans = match num_words.checked_mul(max_width) {
100 Some(v) => v,
101 None => {
102 log::warn!(
103 "[span_utils] Span count overflow: {} words * {} max_width, returning empty",
104 num_words,
105 max_width
106 );
107 return (Vec::new(), Vec::new());
108 }
109 };
110
111 let span_idx_len = match num_spans.checked_mul(2) {
112 Some(v) => v,
113 None => {
114 log::warn!(
115 "[span_utils] Span idx length overflow: {} * 2, returning empty",
116 num_spans
117 );
118 return (Vec::new(), Vec::new());
119 }
120 };
121
122 let mut span_idx: Vec<i64> = vec![0; span_idx_len];
123 let mut span_mask: Vec<bool> = vec![false; num_spans];
124
125 for start in 0..num_words {
126 let remaining_width = num_words - start;
127 let actual_max_width = max_width.min(remaining_width);
128
129 for width in 0..actual_max_width {
130 let dim = match start.checked_mul(max_width) {
132 Some(v) => match v.checked_add(width) {
133 Some(d) => d,
134 None => continue,
135 },
136 None => continue,
137 };
138
139 if let Some(dim2) = dim.checked_mul(2) {
141 if dim2 + 1 < span_idx_len && dim < num_spans {
142 span_idx[dim2] = start as i64;
143 span_idx[dim2 + 1] = (start + width) as i64;
145 span_mask[dim] = true;
146 }
147 }
148 }
149 }
150
151 (span_idx, span_mask)
152}
153
154pub fn calculate_word_positions(text: &str, words: &[&str]) -> Result<Vec<(usize, usize)>> {
171 let mut positions = Vec::with_capacity(words.len());
172 let mut pos = 0;
173
174 for (idx, word) in words.iter().enumerate() {
175 if let Some(rel_start) = text[pos..].find(word) {
177 let abs_start = pos + rel_start;
178 let abs_end = abs_start + word.len();
179
180 if !positions.is_empty() {
182 let (_prev_start, prev_end) = positions[positions.len() - 1];
183 if abs_start < prev_end {
184 log::warn!(
185 "[span_utils] Word '{}' at {} overlaps with previous word ending at {}",
186 word,
187 abs_start,
188 prev_end
189 );
190 }
191 }
192
193 positions.push((abs_start, abs_end));
194 pos = abs_end;
195 } else {
196 return Err(Error::Parse(format!(
197 "Word '{}' (index {}) not found in text starting at position {}",
198 word, idx, pos
199 )));
200 }
201 }
202
203 Ok(positions)
204}
205
/// Returns the slice of `text` covering the words `start_word..=end_word`.
///
/// `word_positions` holds one `(start, end)` byte range per word. The span
/// runs from the start of `start_word` to the end of `end_word`, and the
/// result is `(span_text, start_byte, end_byte)`.
///
/// Returns `None` when either word index is out of range, when the resulting
/// byte range is inverted or extends past the end of `text`, or when one of
/// the offsets does not fall on a UTF-8 character boundary.
pub fn extract_span<'a>(
    text: &'a str,
    word_positions: &[(usize, usize)],
    start_word: usize,
    end_word: usize,
) -> Option<(&'a str, usize, usize)> {
    let start_pos = word_positions.get(start_word)?.0;
    let end_pos = word_positions.get(end_word)?.1;

    // `str::get` rejects inverted, out-of-bounds and non-boundary ranges
    // instead of panicking the way direct slicing would.
    let span = text.get(start_pos..end_pos)?;

    Some((span, start_pos, end_pos))
}
233
234pub fn decode_span_output(
261 output_data: &[f32],
262 shape: &[i64],
263 text: &str,
264 text_words: &[&str],
265 entity_types: &[&str],
266 config: &SpanConfig,
267) -> Result<Vec<Entity>> {
268 if shape.len() < 3 {
270 return Err(Error::Parse(format!(
271 "Expected at least 3D output, got shape {:?}",
272 shape
273 )));
274 }
275
276 let (out_num_words, out_max_width, num_classes) = if shape.len() == 4 {
278 (shape[1] as usize, shape[2] as usize, shape[3] as usize)
280 } else if shape.len() == 3 {
281 (shape[0] as usize, shape[1] as usize, shape[2] as usize)
283 } else {
284 return Err(Error::Parse(format!(
285 "Unexpected output shape: {:?}",
286 shape
287 )));
288 };
289
290 log::debug!(
291 "[span_utils] Decoding: words={}, max_width={}, classes={}, data_len={}",
292 out_num_words,
293 out_max_width,
294 num_classes,
295 output_data.len()
296 );
297
298 let word_positions = calculate_word_positions(text, text_words)?;
300 let span_converter = crate::offset::SpanConverter::new(text);
303
304 let num_text_words = text_words.len();
306 if out_num_words < num_text_words {
307 log::warn!(
308 "[span_utils] Output has fewer words ({}) than input ({})",
309 out_num_words,
310 num_text_words
311 );
312 }
313
314 let mut entities = Vec::with_capacity(32);
315
316 for start in 0..num_text_words.min(out_num_words) {
318 for width in 0..config.max_span_width.min(out_max_width) {
319 let end = start + width;
320 if end >= num_text_words {
321 break;
322 }
323
324 let base_idx = (start * out_max_width * num_classes) + (width * num_classes);
326
327 let mut best_score = config.threshold;
328 let mut best_type_idx = None;
329
330 for type_idx in 0..num_classes.min(entity_types.len()) {
331 let score = output_data.get(base_idx + type_idx).copied().unwrap_or(0.0);
332
333 if score > best_score {
334 best_score = score;
335 best_type_idx = Some(type_idx);
336 }
337 }
338
339 if let Some(type_idx) = best_type_idx {
341 if let Some((span_text, start_byte, end_byte)) =
342 extract_span(text, &word_positions, start, end)
343 {
344 let entity_type = map_label_to_entity_type(entity_types[type_idx]);
345 let mut entity = Entity::new(
346 span_text,
347 entity_type,
348 span_converter.byte_to_char(start_byte),
349 span_converter.byte_to_char(end_byte),
350 best_score as f64,
351 );
352 entity.provenance =
353 Some(crate::Provenance::ml("span-decoder", best_score as f64));
354 entities.push(entity);
355 }
356 }
357 }
358 }
359
360 entities.sort_by(|a, b| {
362 a.start
363 .cmp(&b.start)
364 .then_with(|| b.confidence.total_cmp(&a.confidence))
366 });
367
368 let mut filtered = Vec::with_capacity(entities.len());
370 for entity in entities {
371 let overlaps = filtered
372 .iter()
373 .any(|e: &Entity| ranges_overlap(e.start, e.end, entity.start, entity.end));
374 if !overlaps {
375 filtered.push(entity);
376 }
377 }
378
379 Ok(filtered)
380}
381
/// Reports whether the half-open ranges `[start1, end1)` and `[start2, end2)`
/// intersect. Ranges that merely touch (one ends where the other begins) do
/// not count as overlapping.
#[inline]
fn ranges_overlap(start1: usize, end1: usize, start2: usize, end2: usize) -> bool {
    // Disjoint exactly when one range starts at or after the other's end.
    let first_after_second = start1 >= end2;
    let second_after_first = start2 >= end1;
    !(first_after_second || second_after_first)
}
387
388pub fn map_label_to_entity_type(label: &str) -> EntityType {
392 match label.to_lowercase().as_str() {
393 "person" | "per" => EntityType::Person,
394 "organization" | "org" | "company" | "corp" => EntityType::Organization,
395 "location" | "loc" | "place" | "gpe" => EntityType::Location,
396 "date" => EntityType::Date,
397 "datetime" => EntityType::Date,
398 "time" => EntityType::Time,
399 "money" | "currency" => EntityType::Money,
400 "monetary" => EntityType::Money,
401 "percent" | "percentage" => EntityType::Percent,
402 "email" => EntityType::Email,
403 "phone" => EntityType::Phone,
404 "url" => EntityType::Url,
405 "quantity" => EntityType::Quantity,
406 "measure" => EntityType::Quantity,
407 "cardinal" => EntityType::Cardinal,
408 "number" | "num" => EntityType::Cardinal,
409 "ordinal" => EntityType::Ordinal,
410 "event" => EntityType::Other("EVENT".to_string()),
411 "product" | "prod" => EntityType::Other("PRODUCT".to_string()),
412 "work_of_art" | "work" => EntityType::Other("WORK_OF_ART".to_string()),
413 "law" | "legal" => EntityType::Other("LAW".to_string()),
414 "language" | "lang" => EntityType::Other("LANGUAGE".to_string()),
415 "norp" => EntityType::Other("NORP".to_string()), "fac" | "facility" => EntityType::Other("FACILITY".to_string()),
417 "animal" => EntityType::Other("ANIMAL".to_string()),
419 "biology" => EntityType::Other("BIOLOGY".to_string()),
420 "celestial" => EntityType::Other("CELESTIAL".to_string()),
421 "culture" => EntityType::Other("CULTURE".to_string()),
422 "discipline" => EntityType::Other("DISCIPLINE".to_string()),
423 "disease" => EntityType::Other("DISEASE".to_string()),
424 "feeling" => EntityType::Other("FEELING".to_string()),
425 "food" => EntityType::Other("FOOD".to_string()),
426 "group" => EntityType::Other("GROUP".to_string()),
427 "instrument" => EntityType::Other("INSTRUMENT".to_string()),
428 "media" => EntityType::Other("MEDIA".to_string()),
429 "asset" => EntityType::Other("ASSET".to_string()),
430 "artifact" => EntityType::Other("ARTIFACT".to_string()),
431 "part" => EntityType::Other("PART".to_string()),
432 "physical_phenomenon" | "physical" => EntityType::Other("PHYSICAL_PHENOMENON".to_string()),
433 "plant" => EntityType::Other("PLANT".to_string()),
434 "property" => EntityType::Other("PROPERTY".to_string()),
435 "psych" => EntityType::Other("PSYCH".to_string()),
436 "relation" => EntityType::Other("RELATION".to_string()),
437 "struct" => EntityType::Other("STRUCT".to_string()),
438 "substance" => EntityType::Other("SUBSTANCE".to_string()),
439 "super" | "supernatural" => EntityType::Other("SUPER".to_string()),
440 "vehicle" | "vehi" => EntityType::Other("VEHICLE".to_string()),
441 _ => EntityType::Other(label.to_uppercase()),
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Three words at width two give six slots; the first two slots hold the
    /// spans `[0, 0]` and `[0, 1]`.
    #[test]
    fn test_make_span_tensors_basic() {
        let (span_idx, span_mask) = make_span_tensors(3, 2);

        assert_eq!(span_mask.len(), 6);
        assert_eq!(span_idx.len(), 12);

        let expected_pairs: [(i64, i64); 2] = [(0, 0), (0, 1)];
        for (slot, &(first, last)) in expected_pairs.iter().enumerate() {
            assert!(span_mask[slot]);
            assert_eq!(span_idx[2 * slot], first);
            assert_eq!(span_idx[2 * slot + 1], last);
        }
    }

    /// A span count that overflows `usize` yields empty buffers.
    #[test]
    fn test_make_span_tensors_overflow_protection() {
        let huge_word_count = usize::MAX / 2;
        let (span_idx, span_mask) = make_span_tensors(huge_word_count, DEFAULT_MAX_SPAN_WIDTH);

        assert!(span_idx.is_empty());
        assert!(span_mask.is_empty());
    }

    /// Whitespace-separated words map to their byte ranges in the text.
    #[test]
    fn test_calculate_word_positions() {
        let text = "Steve Jobs founded Apple";
        let words: Vec<&str> = text.split_whitespace().collect();

        let positions = calculate_word_positions(text, &words).unwrap();

        let expected = [(0, 5), (6, 10), (11, 18), (19, 24)];
        assert_eq!(positions.len(), 4);
        for (actual, wanted) in positions.iter().zip(expected.iter()) {
            assert_eq!(actual, wanted);
        }
    }

    /// Single-word and multi-word spans come back with their byte bounds.
    #[test]
    fn test_extract_span() {
        let text = "Steve Jobs founded Apple";
        let positions = vec![(0, 5), (6, 10), (11, 18), (19, 24)];

        // ((start_word, end_word), expected text, expected byte range)
        let cases = [
            ((0, 0), "Steve", (0, 5)),
            ((0, 1), "Steve Jobs", (0, 10)),
            ((1, 3), "Jobs founded Apple", (6, 24)),
        ];

        for &((first_word, last_word), expected_text, expected_bytes) in cases.iter() {
            let (span, start, end) =
                extract_span(text, &positions, first_word, last_word).unwrap();
            assert_eq!(span, expected_text);
            assert_eq!((start, end), expected_bytes);
        }
    }

    /// Known labels map case-insensitively; unknown labels are upper-cased.
    #[test]
    fn test_map_label_to_entity_type() {
        let cases = vec![
            ("person", EntityType::Person),
            ("PER", EntityType::Person),
            ("organization", EntityType::Organization),
            ("ORG", EntityType::Organization),
            ("location", EntityType::Location),
            ("GPE", EntityType::Location),
            ("custom_type", EntityType::Other("CUSTOM_TYPE".to_string())),
        ];

        for (label, expected) in cases {
            assert_eq!(map_label_to_entity_type(label), expected);
        }
    }

    /// Overlap is strict: ranges that only touch are disjoint.
    #[test]
    fn test_ranges_overlap() {
        // ((start1, end1, start2, end2), expected)
        let cases = [
            ((0, 10, 5, 15), true),
            ((0, 10, 0, 5), true),
            ((5, 15, 0, 10), true),
            ((0, 5, 10, 15), false),
            ((0, 5, 5, 10), false),
        ];

        for &((start1, end1, start2, end2), expected) in cases.iter() {
            assert_eq!(ranges_overlap(start1, end1, start2, end2), expected);
        }
    }
}