use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;

use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use regex::Regex;
use serde::Deserialize;
use unicode_normalization::UnicodeNormalization;

use crate::paths::AppPaths;
use crate::storage::entities::{NewEntity, NewRelationship};

const MODEL_ID: &str = "Davlan/bert-base-multilingual-cased-ner-hrl";
const MAX_SEQ_LEN: usize = 512;
const STRIDE: usize = 256;
const MAX_ENTS: usize = 30;
#[cfg(test)]
const TOP_K_RELATIONS: usize = 5;
const DEFAULT_RELATION: &str = "mentions";
const MIN_ENTITY_CHARS: usize = 2;

static REGEX_EMAIL: OnceLock<Regex> = OnceLock::new();
static REGEX_URL: OnceLock<Regex> = OnceLock::new();
static REGEX_UUID: OnceLock<Regex> = OnceLock::new();
static REGEX_ALL_CAPS: OnceLock<Regex> = OnceLock::new();
static REGEX_SECTION_MARKER: OnceLock<Regex> = OnceLock::new();
static REGEX_BRAND_CAMEL: OnceLock<Regex> = OnceLock::new();

const ALL_CAPS_STOPWORDS: &[&str] = &[
    "ACEITE",
    "ACID",
    "ACK",
    "ACL",
    "ACRESCENTADO",
    "ADAPTER",
    "ADICIONADA",
    "ADICIONADAS",
    "ADICIONADO",
    "ADICIONADOS",
    "ADICIONAR",
    "AGENTS",
    "AINDA",
    "ALL",
    "ALTA",
    "ALWAYS",
    "APENAS",
    "API",
    "ARTEFATOS",
    "ATIVA",
    "ATIVO",
    "BAIXA",
    "BANCO",
    "BLOQUEAR",
    "BORDA",
    "BUG",
    "CAPÍTULO",
    "CASO",
    "CEO",
    "CHECKLIST",
    "CLARO",
    "CLI",
    "COMPLETED",
    "CONFIRMADO",
    "CONFIRMARAM",
    "CONFIRME",
    "CONFIRMEI",
    "CONFIRMOU",
    "CONTRATO",
    "CRIE",
    "CRÍTICO",
    "CRITICAL",
    "CSV",
    "DDL",
    "DEFAULT",
    "DEFINIR",
    "DEPARTMENT",
    "DESC",
    "DEVE",
    "DEVEMOS",
    "DISCO",
    "DONE",
    "DSL",
    "DTO",
    "EFEITO",
    "ENTRADA",
    "EPERM",
    "ERROR",
    "ESCREVA",
    "ESCRITA",
    "ESRCH",
    "ESSA",
    "ESSE",
    "ESSENCIAL",
    "ESTA",
    "ESTADO",
    "ESTE",
    "ETAPA",
    "EVITAR",
    "EXEMPLO",
    "EXPANDIR",
    "EXPOR",
    "FALHA",
    "FASE",
    "FATO",
    "FIFO",
    "FIXED",
    "FIXME",
    "FLUXO",
    "FONTES",
    "FORBIDDEN",
    "FUNCIONA",
    "HACK",
    "HEARTBEAT",
    "HTTP",
    "HTTPS",
    "INATIVO",
    "JAMAIS",
    "JSON",
    "JWT",
    "LEITURA",
    "LLM",
    "MESMO",
    "METADADOS",
    "MUST",
    "NEGUE",
    "NEVER",
    "NOTE",
    "NUNCA",
    "OBRIGATORIA",
    "OBRIGATÓRIO",
    "PADRÃO",
    "PASSIVA",
    "PASSO",
    "PENDING",
    "PLAN",
    "PODEMOS",
    "PONTEIROS",
    "PROIBIDO",
    "PROJETO",
    "RECUSE",
    "REGRA",
    "REGRAS",
    "REQUIRED",
    "REQUISITO",
    "REST",
    "SEÇÃO",
    "SEMPRE",
    "SHALL",
    "SHOULD",
    "SOMENTE",
    "SOUL",
    "TODAS",
    "TODO",
    "TODOS",
    "TOKEN",
    "TOOLS",
    "TSV",
    "UI",
    "URL",
    "USAR",
    "VALIDAR",
    "VAMOS",
    "VOCÊ",
    "WARNING",
    "XML",
    "YAML",
];

const HTTP_METHODS: &[&str] = &[
    "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "CONNECT", "TRACE",
];

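/// Returns true when an ALL-CAPS token should be dropped by the prefilter.
///
/// Underscored tokens (e.g. `MAX_RETRY`) are treated as code identifiers and
/// always kept; everything else is checked against the PT/EN stopword list
/// and the HTTP method names.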
fn is_filtered_all_caps(token: &str) -> bool {
    let is_identifier = token.contains('_');
    if is_identifier {
        return false;
    }
    ALL_CAPS_STOPWORDS.contains(&token) || HTTP_METHODS.contains(&token)
}

fn regex_email() -> &'static Regex {
    REGEX_EMAIL.get_or_init(|| {
        Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
            .expect("compile-time validated email regex literal")
    })
}

fn regex_url() -> &'static Regex {
    REGEX_URL.get_or_init(|| {
        Regex::new(r#"https?://[^\s\)\]\}"'<>]+"#)
            .expect("compile-time validated URL regex literal")
    })
}

fn regex_uuid() -> &'static Regex {
    REGEX_UUID.get_or_init(|| {
        Regex::new(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}")
            .expect("compile-time validated UUID regex literal")
    })
}

fn regex_all_caps() -> &'static Regex {
    REGEX_ALL_CAPS.get_or_init(|| {
        Regex::new(r"\b[A-Z][A-Z0-9_]{2,}\b")
            .expect("compile-time validated all-caps regex literal")
    })
}

fn regex_section_marker() -> &'static Regex {
    REGEX_SECTION_MARKER.get_or_init(|| {
        Regex::new("\\b(?:Etapa|Fase|Passo|Camada|Se\u{00e7}\u{00e3}o|Cap\u{00ed}tulo)\\s+\\d+\\b")
            .expect("compile-time validated section marker regex literal")
    })
}

fn regex_brand_camel() -> &'static Regex {
    REGEX_BRAND_CAMEL.get_or_init(|| {
        Regex::new(r"\b[A-Z][a-z]+[A-Z][A-Za-z]+\b")
            .expect("compile-time validated CamelCase brand regex literal")
    })
}

#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    pub name: String,
    pub entity_type: String,
}

#[derive(Debug, Clone)]
pub struct ExtractedUrl {
    pub url: String,
    pub offset: usize,
}

#[derive(Debug, Clone)]
pub struct ExtractionResult {
    pub entities: Vec<NewEntity>,
    pub relationships: Vec<NewRelationship>,
    pub relationships_truncated: bool,
    pub extraction_method: String,
    pub urls: Vec<ExtractedUrl>,
}

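/// Common interface over the extraction backends, so callers can swap the
/// BERT+regex pipeline for the regex-only fallback behind a single type.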
pub trait Extractor: Send + Sync {
    fn extract(&self, body: &str) -> Result<ExtractionResult>;
}

#[derive(Deserialize)]
struct ModelConfig {
    #[serde(default)]
    id2label: HashMap<String, String>,
    hidden_size: usize,
}

struct BertNerModel {
    bert: BertModel,
    classifier: Linear,
    device: Device,
    id2label: HashMap<usize, String>,
}

impl BertNerModel {
    fn load(model_dir: &Path) -> Result<Self> {
        let config_path = model_dir.join("config.json");
        let weights_path = model_dir.join("model.safetensors");

        let config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("reading config.json at {config_path:?}"))?;
        let model_cfg: ModelConfig =
            serde_json::from_str(&config_str).context("parsing NER model config.json")?;

        let id2label: HashMap<usize, String> = model_cfg
            .id2label
            .into_iter()
            .filter_map(|(k, v)| k.parse::<usize>().ok().map(|n| (n, v)))
            .collect();

        let num_labels = id2label.len().max(9);
        let hidden_size = model_cfg.hidden_size;

        let bert_config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("re-reading config.json for bert at {config_path:?}"))?;
        let bert_cfg: BertConfig =
            serde_json::from_str(&bert_config_str).context("parsing BertConfig")?;

        let device = Device::Cpu;

        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
                .with_context(|| format!("mapping {weights_path:?}"))?
        };
        let bert = BertModel::load(vb.pp("bert"), &bert_cfg).context("loading BertModel")?;

        let cls_vb = vb.pp("classifier");
        let weight = cls_vb
            .get((num_labels, hidden_size), "weight")
            .context("loading classifier.weight from safetensors")?;
        let bias = cls_vb
            .get(num_labels, "bias")
            .context("loading classifier.bias from safetensors")?;
        let classifier = Linear::new(weight, Some(bias));

        Ok(Self {
            bert,
            classifier,
            device,
            id2label,
        })
    }

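    /// Labels a single window: one IOB tag per input token, taken as the
    /// argmax over that token's classifier logits.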
    fn predict(&self, token_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<String>> {
        let len = token_ids.len();
        let ids_i64: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
        let mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        let input_ids = Tensor::from_vec(ids_i64, (1, len), &self.device)
            .context("creating tensor input_ids")?;
        let token_type_ids = Tensor::zeros((1, len), DType::I64, &self.device)
            .context("creating tensor token_type_ids")?;
        let attn_mask = Tensor::from_vec(mask_i64, (1, len), &self.device)
            .context("creating tensor attention_mask")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel forward pass")?;

        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("classifier forward pass")?;

        let logits_2d = logits.squeeze(0).context("removing batch dimension")?;

        let num_tokens = logits_2d.dim(0).context("dim(0)")?;

        let mut labels = Vec::with_capacity(num_tokens);
        for i in 0..num_tokens {
            let token_logits = logits_2d.get(i).context("get token logits")?;
            let vec: Vec<f32> = token_logits.to_vec1().context("to_vec1 logits")?;
            let argmax = vec
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| {
                    a.partial_cmp(b)
                        .expect("BERT NER logits invariant: no NaN in classifier output")
                })
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            let label = self
                .id2label
                .get(&argmax)
                .cloned()
                .unwrap_or_else(|| "O".to_string());
            labels.push(label);
        }

        Ok(labels)
    }

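    /// Labels several windows in one forward pass. Windows are right-padded
    /// with zeros to the longest length in the batch (with a matching
    /// attention mask), and each result is narrowed back to its real token
    /// count before argmax decoding.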
    fn predict_batch(&self, windows: &[(Vec<u32>, Vec<String>)]) -> Result<Vec<Vec<String>>> {
        let batch_size = windows.len();
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        if max_len == 0 {
            return Ok(vec![vec![]; batch_size]);
        }

        let mut padded_ids: Vec<Tensor> = Vec::with_capacity(batch_size);
        let mut padded_masks: Vec<Tensor> = Vec::with_capacity(batch_size);

        for (ids, _) in windows {
            let len = ids.len();
            let pad_right = max_len - len;

            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &self.device)
                .context("creating id tensor for batch")?;
            let t = t
                .pad_with_zeros(0, 0, pad_right)
                .context("padding id tensor")?;
            padded_ids.push(t);

            let mut mask_i64 = vec![1i64; len];
            mask_i64.extend(vec![0i64; pad_right]);
            let m = Tensor::from_vec(mask_i64, max_len, &self.device)
                .context("creating mask tensor for batch")?;
            padded_masks.push(m);
        }

        let input_ids = Tensor::stack(&padded_ids, 0).context("stack input_ids")?;
        let attn_mask = Tensor::stack(&padded_masks, 0).context("stack attn_mask")?;
        let token_type_ids = Tensor::zeros((batch_size, max_len), DType::I64, &self.device)
            .context("creating token_type_ids tensor for batch")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel batch forward pass")?;
        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("batch classifier forward pass")?;

        let mut results = Vec::with_capacity(batch_size);
        for (i, (window_ids, _)) in windows.iter().enumerate() {
            let example_logits = logits.get(i).context("get example logits")?;
            let real_len = window_ids.len();
            let example_slice = example_logits
                .narrow(0, 0, real_len)
                .context("narrowing to real tokens")?;
            let logits_2d: Vec<Vec<f32>> = example_slice.to_vec2().context("to_vec2 logits")?;

            let labels: Vec<String> = logits_2d
                .iter()
                .map(|token_logits| {
                    let argmax = token_logits
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b)
                                .expect("BERT NER logits invariant: no NaN in classifier output")
                        })
                        .map(|(idx, _)| idx)
                        .unwrap_or(0);
                    self.id2label
                        .get(&argmax)
                        .cloned()
                        .unwrap_or_else(|| "O".to_string())
                })
                .collect();

            results.push(labels);
        }

        Ok(results)
    }
}

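/// Process-wide NER model cache. The inner `Option` means a failed load is
/// cached too: the warning is logged once and every later call degrades to
/// regex-only extraction instead of retrying the load.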
static NER_MODEL: OnceLock<Option<BertNerModel>> = OnceLock::new();

fn get_or_init_model(paths: &AppPaths) -> Option<&'static BertNerModel> {
    NER_MODEL
        .get_or_init(|| match load_model(paths) {
            Ok(m) => Some(m),
            Err(e) => {
                tracing::warn!("NER model unavailable (graceful degradation): {e:#}");
                None
            }
        })
        .as_ref()
}

fn model_dir(paths: &AppPaths) -> PathBuf {
    paths.models.join("bert-multilingual-ner")
}

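/// Ensures the model files exist locally, downloading them from the HF Hub
/// on first run; files already present are never re-downloaded. Note the
/// tokenizer is fetched from the repo's `onnx/` folder but stored locally as
/// `tokenizer.json`.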
fn ensure_model_files(paths: &AppPaths) -> Result<PathBuf> {
    let dir = model_dir(paths);
    std::fs::create_dir_all(&dir).with_context(|| format!("creating model directory: {dir:?}"))?;

    let weights = dir.join("model.safetensors");
    let config = dir.join("config.json");
    let tokenizer = dir.join("tokenizer.json");

    if weights.exists() && config.exists() && tokenizer.exists() {
        return Ok(dir);
    }

    tracing::info!("Downloading NER model (first run, ~676 MB)...");
    crate::output::emit_progress_i18n(
        "Downloading NER model (first run, ~676 MB)...",
        crate::i18n::validation::runtime_pt::downloading_ner_model(),
    );

    let api = huggingface_hub::api::sync::Api::new().context("creating HF Hub client")?;
    let repo = api.model(MODEL_ID.to_string());

    for (remote, local) in &[
        ("model.safetensors", "model.safetensors"),
        ("config.json", "config.json"),
        ("onnx/tokenizer.json", "tokenizer.json"),
        ("tokenizer_config.json", "tokenizer_config.json"),
    ] {
        let dest = dir.join(local);
        if !dest.exists() {
            let src = repo
                .get(remote)
                .with_context(|| format!("downloading {remote} from the HF Hub"))?;
            std::fs::copy(&src, &dest)
                .with_context(|| format!("copying {local} into the cache"))?;
        }
    }

    Ok(dir)
}

fn load_model(paths: &AppPaths) -> Result<BertNerModel> {
    let dir = ensure_model_files(paths)?;
    BertNerModel::load(&dir)
}

#[inline]
fn hash_str(s: &str) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut h = std::collections::hash_map::DefaultHasher::new();
    s.hash(&mut h);
    h.finish()
}

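/// Regex-only entity pass: emails, UUIDs, ALL-CAPS constants/acronyms, and
/// CamelCase brand names, with section markers stripped from the body first
/// and results deduplicated by exact trimmed name.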
fn apply_regex_prefilter(body: &str) -> Vec<ExtractedEntity> {
    let mut entities = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    let add = |entities: &mut Vec<ExtractedEntity>,
               seen: &mut std::collections::HashSet<String>,
               name: &str,
               entity_type: &str| {
        let name = name.trim().to_string();
        if name.len() >= MIN_ENTITY_CHARS && seen.insert(name.clone()) {
            entities.push(ExtractedEntity {
                name,
                entity_type: entity_type.to_string(),
            });
        }
    };

    let cleaned = regex_section_marker().replace_all(body, " ");
    let cleaned = cleaned.as_ref();

    for m in regex_email().find_iter(cleaned) {
        add(&mut entities, &mut seen, m.as_str(), "concept");
    }
    for m in regex_uuid().find_iter(cleaned) {
        add(&mut entities, &mut seen, m.as_str(), "concept");
    }
    for m in regex_all_caps().find_iter(cleaned) {
        let candidate = m.as_str();
        if !is_filtered_all_caps(candidate) {
            add(&mut entities, &mut seen, candidate, "concept");
        }
    }
    for m in regex_brand_camel().find_iter(cleaned) {
        let name = m.as_str();
        if !ALL_CAPS_STOPWORDS.contains(&name.to_uppercase().as_str()) {
            add(&mut entities, &mut seen, name, "organization");
        }
    }

    entities
}

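/// Extracts deduplicated URLs with their byte offsets, trimming trailing
/// punctuation (e.g. `https://example.com/path.` becomes
/// `https://example.com/path`).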
pub fn extract_urls(body: &str) -> Vec<ExtractedUrl> {
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut result = Vec::new();
    for m in regex_url().find_iter(body) {
        let raw = m.as_str();
        let cleaned = raw
            .trim_end_matches('`')
            .trim_end_matches(',')
            .trim_end_matches('.')
            .trim_end_matches(';')
            .trim_end_matches(')')
            .trim_end_matches(']')
            .trim_end_matches('}');
        if cleaned.len() >= 10 && seen.insert(cleaned.to_string()) {
            result.push(ExtractedUrl {
                url: cleaned.to_string(),
                offset: m.start(),
            });
        }
    }
    result
}

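/// Decodes IOB labels into entities: `B-*` opens a span, `I-*` extends it,
/// `##` subwords are merged into the previous part, and PT verb false
/// positives, ALL-CAPS stopwords, and section markers are filtered on flush.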
fn iob_to_entities(tokens: &[String], labels: &[String]) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut current_parts: Vec<String> = Vec::new();
    let mut current_type: Option<String> = None;

    let flush =
        |parts: &mut Vec<String>, typ: &mut Option<String>, entities: &mut Vec<ExtractedEntity>| {
            if let Some(t) = typ.take() {
                let name = parts.join(" ").trim().to_string();
                let is_single_caps = !name.contains(' ')
                    && name == name.to_uppercase()
                    && name.len() >= MIN_ENTITY_CHARS;
                let should_skip = is_single_caps && is_filtered_all_caps(&name);
                let is_section_marker = regex_section_marker().is_match(&name);
                if name.len() >= MIN_ENTITY_CHARS && !should_skip && !is_section_marker {
                    entities.push(ExtractedEntity {
                        name,
                        entity_type: t,
                    });
                }
                parts.clear();
            }
        };

    for (token, label) in tokens.iter().zip(labels.iter()) {
        if label == "O" {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        }

        let (prefix, bio_type) = if let Some(rest) = label.strip_prefix("B-") {
            ("B", rest)
        } else if let Some(rest) = label.strip_prefix("I-") {
            ("I", rest)
        } else {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        };

        const PT_VERB_FALSE_POSITIVES: &[&str] = &[
            "L\u{00ea}",
            "V\u{00ea}",
            "C\u{00e1}",
            "P\u{00f4}r",
            "Ser",
            "Vir",
            "Ver",
            "Dar",
            "Ler",
            "Ter",
        ];

        let entity_type = match bio_type {
            "DATE" => "date",
            "PER" => {
                if PT_VERB_FALSE_POSITIVES.contains(&token.as_str()) {
                    flush(&mut current_parts, &mut current_type, &mut entities);
                    continue;
                }
                "person"
            }
            "ORG" => {
                let t = token.to_lowercase();
                if t.contains("lib")
                    || t.contains("sdk")
                    || t.contains("cli")
                    || t.contains("crate")
                    || t.contains("npm")
                {
                    "tool"
                } else {
                    "organization"
                }
            }
            "LOC" => "location",
            other => other,
        };

        if prefix == "B" {
            if token.starts_with("##") {
                let clean = token.strip_prefix("##").unwrap_or(token.as_str());
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
                continue;
            }
            flush(&mut current_parts, &mut current_type, &mut entities);
            current_parts.push(token.clone());
            current_type = Some(entity_type.to_string());
        } else if prefix == "I" && current_type.is_some() {
            let clean = token.strip_prefix("##").unwrap_or(token.as_str());
            if token.starts_with("##") {
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
            } else {
                current_parts.push(clean.to_string());
            }
        }
    }

    flush(&mut current_parts, &mut current_type, &mut entities);
    entities
}

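/// Test-only baseline pairing: links the first MAX_ENTS entities pairwise,
/// capped at TOP_K_RELATIONS per source entity and the global relationship
/// cap. Production code uses sentence co-occurrence instead.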
#[cfg(test)]
fn build_relationships(entities: &[NewEntity]) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    let n = entities.len().min(MAX_ENTS);
    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();

    let mut hit_cap = false;
    'outer: for i in 0..n {
        if rels.len() >= max_rels {
            hit_cap = true;
            break;
        }

        let mut for_entity = 0usize;
        for j in (i + 1)..n {
            if for_entity >= TOP_K_RELATIONS {
                break;
            }
            if rels.len() >= max_rels {
                hit_cap = true;
                break 'outer;
            }

            let key = (i.min(j), i.max(j));
            if !seen.insert(key) {
                continue;
            }

            rels.push(NewRelationship {
                source: entities[i].name.clone(),
                target: entities[j].name.clone(),
                relation: DEFAULT_RELATION.to_string(),
                strength: 0.5,
                description: None,
            });
            for_entity += 1;
        }
    }

    if hit_cap {
        tracing::warn!(
            "relationships truncated to {max_rels} (with {n} entities, theoretical max was ~{}x combinations)",
            n.saturating_sub(1)
        );
    }

    (rels, hit_cap)
}

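/// Links entities that co-occur in the same sentence (split on `.`, `!`, `?`,
/// and newlines), deduplicating unordered pairs and truncating at the
/// configured relationship cap.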
fn build_relationships_by_sentence_cooccurrence(
    body: &str,
    entities: &[NewEntity],
) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    let lower_names: Vec<(usize, String)> = entities
        .iter()
        .take(MAX_ENTS)
        .enumerate()
        .map(|(i, e)| (i, e.name.to_lowercase()))
        .collect();

    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
    let mut hit_cap = false;

    for sentence in body.split(['.', '!', '?', '\n']) {
        if sentence.trim().is_empty() {
            continue;
        }
        let lower_sentence = sentence.to_lowercase();
        let present: Vec<usize> = lower_names
            .iter()
            .filter(|(_, name)| !name.is_empty() && lower_sentence.contains(name.as_str()))
            .map(|(i, _)| *i)
            .collect();

        if present.len() < 2 {
            continue;
        }

        for i in 0..present.len() {
            for j in (i + 1)..present.len() {
                if rels.len() >= max_rels {
                    hit_cap = true;
                    tracing::warn!(
                        "relationships truncated to {max_rels} during sentence-level pairing"
                    );
                    return (rels, hit_cap);
                }
                let ei = present[i];
                let ej = present[j];
                let key = (ei.min(ej), ei.max(ej));
                if seen.insert(key) {
                    rels.push(NewRelationship {
                        source: entities[ei].name.clone(),
                        target: entities[ej].name.clone(),
                        relation: DEFAULT_RELATION.to_string(),
                        strength: 0.5,
                        description: None,
                    });
                }
            }
        }
    }

    (rels, hit_cap)
}

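/// Runs BERT NER over the body with overlapping windows of MAX_SEQ_LEN tokens
/// advancing by STRIDE, batched per `ner_batch_size()`; entities are
/// deduplicated across windows by name hash. Falls back to single-window
/// prediction if a batch fails.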
fn run_ner_sliding_window(
    model: &BertNerModel,
    body: &str,
    paths: &AppPaths,
) -> Result<Vec<ExtractedEntity>> {
    let tokenizer_path = model_dir(paths).join("tokenizer.json");
    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("loading NER tokenizer: {e}"))?;

    let encoding = tokenizer
        .encode(body, false)
        .map_err(|e| anyhow::anyhow!("encoding NER input: {e}"))?;

    let mut all_ids: Vec<u32> = encoding.get_ids().to_vec();
    let mut all_tokens: Vec<String> = encoding
        .get_tokens()
        .iter()
        .map(|s| s.to_string())
        .collect();

    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    let max_tokens = crate::constants::extraction_max_tokens();
    if all_ids.len() > max_tokens {
        tracing::warn!(
            target: "extraction",
            original_tokens = all_ids.len(),
            capped_tokens = max_tokens,
            "NER input truncated to cap; later body region will be skipped by NER (regex prefilter still covers full body)"
        );
        all_ids.truncate(max_tokens);
        all_tokens.truncate(max_tokens);
    }

    let mut windows: Vec<(Vec<u32>, Vec<String>)> = Vec::new();
    let mut start = 0usize;
    loop {
        let end = (start + MAX_SEQ_LEN).min(all_ids.len());
        windows.push((
            all_ids[start..end].to_vec(),
            all_tokens[start..end].to_vec(),
        ));
        if end >= all_ids.len() {
            break;
        }
        start += STRIDE;
    }

    windows.sort_by_key(|(ids, _)| ids.len());

    let batch_size = crate::constants::ner_batch_size();
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut seen: std::collections::HashSet<u64> = std::collections::HashSet::new();

    for chunk in windows.chunks(batch_size) {
        match model.predict_batch(chunk) {
            Ok(batch_labels) => {
                for (labels, (_, tokens)) in batch_labels.iter().zip(chunk.iter()) {
                    for ent in iob_to_entities(tokens, labels) {
                        if seen.insert(hash_str(&ent.name)) {
                            entities.push(ent);
                        }
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    "batch NER failed (chunk of {} windows): {e:#}; falling back to single-window",
                    chunk.len()
                );
                for (ids, tokens) in chunk {
                    let mask = vec![1u32; ids.len()];
                    match model.predict(ids, &mask) {
                        Ok(labels) => {
                            for ent in iob_to_entities(tokens, &labels) {
                                if seen.insert(hash_str(&ent.name)) {
                                    entities.push(ent);
                                }
                            }
                        }
                        Err(e2) => {
                            tracing::warn!("NER window fallback also failed: {e2:#}");
                        }
                    }
                }
            }
        }
    }

    Ok(entities)
}

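/// Re-attaches short numeric/version suffixes the tokenizer split off,
/// e.g. "GPT" followed by "-4o" in the body becomes "GPT-4o".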
fn extend_with_numeric_suffix(entities: Vec<ExtractedEntity>, body: &str) -> Vec<ExtractedEntity> {
    static SUFFIX_RE: OnceLock<Regex> = OnceLock::new();
    let suffix_re = SUFFIX_RE.get_or_init(|| {
        Regex::new(r"^([\-\s]+\d+(?:\.\d+)?[a-z]?)")
            .expect("compile-time validated numeric suffix regex literal")
    });

    entities
        .into_iter()
        .map(|ent| {
            if let Some(pos) = body.find(&ent.name) {
                let after_pos = pos + ent.name.len();
                if after_pos < body.len() {
                    let after = &body[after_pos..];
                    if let Some(m) = suffix_re.find(after) {
                        let suffix = m.as_str();
                        if suffix.len() <= 7 {
                            let mut extended = String::with_capacity(ent.name.len() + suffix.len());
                            extended.push_str(&ent.name);
                            extended.push_str(suffix);
                            return ExtractedEntity {
                                name: extended,
                                entity_type: ent.entity_type,
                            };
                        }
                    }
                }
            }
            ent
        })
        .collect()
}

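/// Adds versioned model names found in the body (e.g. "Claude 4 Sonnet",
/// "Mixtral 8x7B") as `concept` entities when not already present
/// case-insensitively, up to MAX_ENTS.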
fn augment_versioned_model_names(
    entities: Vec<ExtractedEntity>,
    body: &str,
) -> Vec<ExtractedEntity> {
    static VERSIONED_MODEL_RE: OnceLock<Regex> = OnceLock::new();
    let model_re = VERSIONED_MODEL_RE.get_or_init(|| {
        Regex::new(
            r"\b([A-Z][A-Za-z]{2,15})[\s\-]+(\d+(?:\.\d+)?(?:[a-z]|x\d+[A-Za-z]?)?)(?:\s+(?:Sonnet|Opus|Haiku|Turbo|Pro|Lite|Mini|Nano|Flash|Ultra))?\b",
        )
        .expect("compile-time validated versioned model regex literal")
    });

    let mut existing_lc: std::collections::HashSet<String> =
        entities.iter().map(|ent| ent.name.to_lowercase()).collect();
    let mut result = entities;

    for caps in model_re.captures_iter(body) {
        let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
        if full_match.is_empty() || full_match.len() > 24 {
            continue;
        }
        let normalized_lc = full_match.to_lowercase();
        if existing_lc.contains(&normalized_lc) {
            continue;
        }
        if result.len() >= MAX_ENTS {
            break;
        }
        existing_lc.insert(normalized_lc);
        result.push(ExtractedEntity {
            name: full_match.to_string(),
            entity_type: "concept".to_string(),
        });
    }

    result
}

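/// Merges regex and NER candidates. Names are NFKC-normalized and lowercased,
/// deduplicated per entity type, and substring collisions keep the longest
/// name (e.g. "Open" vs "OpenAI" keeps "OpenAI"); output is truncated at
/// MAX_ENTS.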
fn merge_and_deduplicate(
    regex_ents: Vec<ExtractedEntity>,
    ner_ents: Vec<ExtractedEntity>,
) -> Vec<ExtractedEntity> {
    let mut by_lc: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
    let mut result: Vec<ExtractedEntity> = Vec::new();
    let mut truncated = false;

    let total_input = regex_ents.len() + ner_ents.len();
    for ent in regex_ents.into_iter().chain(ner_ents) {
        let name_lc = ent.name.nfkc().collect::<String>().to_lowercase();
        let key = {
            let mut k = String::with_capacity(ent.entity_type.len() + 1 + name_lc.len());
            k.push_str(&ent.entity_type);
            k.push('\0');
            k.push_str(&name_lc);
            k
        };

        let type_prefix = {
            let mut p = String::with_capacity(ent.entity_type.len() + 1);
            p.push_str(&ent.entity_type);
            p.push('\0');
            p
        };
        let mut collision_idx: Option<usize> = None;
        for (existing_key, idx) in &by_lc {
            if !existing_key.starts_with(&type_prefix) {
                continue;
            }
            let existing_name_lc = &existing_key[type_prefix.len()..];
            if existing_name_lc == name_lc
                || existing_name_lc.contains(name_lc.as_str())
                || name_lc.contains(existing_name_lc)
            {
                collision_idx = Some(*idx);
                break;
            }
        }
        match collision_idx {
            Some(idx) => {
                if ent.name.len() > result[idx].name.len() {
                    let old_name_lc = result[idx].name.nfkc().collect::<String>().to_lowercase();
                    let old_key = {
                        let et = &result[idx].entity_type;
                        let mut k = String::with_capacity(et.len() + 1 + old_name_lc.len());
                        k.push_str(et);
                        k.push('\0');
                        k.push_str(&old_name_lc);
                        k
                    };
                    by_lc.remove(&old_key);
                    result[idx] = ent;
                    by_lc.insert(key, idx);
                }
            }
            None => {
                by_lc.insert(key, result.len());
                result.push(ent);
            }
        }
        if result.len() >= MAX_ENTS {
            truncated = true;
            break;
        }
    }

    if truncated {
        tracing::warn!(
            "extraction truncated at {MAX_ENTS} entities (input had {total_input} candidates before deduplication)"
        );
    }

    result
}

fn to_new_entities(extracted: Vec<ExtractedEntity>) -> Vec<NewEntity> {
    extracted
        .into_iter()
        .map(|e| NewEntity {
            name: e.name,
            entity_type: e.entity_type,
            description: None,
        })
        .collect()
}

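/// Full extraction pipeline: regex prefilter, optional BERT NER (graceful
/// degradation to regex-only), merge/dedup, numeric-suffix extension,
/// versioned model augmentation, section-marker filtering, sentence
/// co-occurrence relationships, and URL extraction.
///
/// A minimal usage sketch (the paths are illustrative, not a real layout):
///
/// ```ignore
/// let paths = AppPaths {
///     db: PathBuf::from("/tmp/memories.sqlite"),
///     models: PathBuf::from("/tmp/models"),
/// };
/// let result = extract_graph_auto("contact: dev@acme.io uses MAX_RETRY=3", &paths)?;
/// // extraction_method is "regex-only" when no NER model could be loaded
/// println!("{} entities, {} urls", result.entities.len(), result.urls.len());
/// ```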
pub fn extract_graph_auto(body: &str, paths: &AppPaths) -> Result<ExtractionResult> {
    let regex_entities = apply_regex_prefilter(body);

    let mut bert_used = false;
    let ner_entities = match get_or_init_model(paths) {
        Some(model) => match run_ner_sliding_window(model, body, paths) {
            Ok(ents) => {
                bert_used = true;
                ents
            }
            Err(e) => {
                tracing::warn!("NER failed, falling back to regex-only extraction: {e:#}");
                Vec::new()
            }
        },
        None => Vec::new(),
    };

    let merged = merge_and_deduplicate(regex_entities, ner_entities);
    let extended = extend_with_numeric_suffix(merged, body);
    let with_models = augment_versioned_model_names(extended, body);
    let with_models: Vec<ExtractedEntity> = with_models
        .into_iter()
        .filter(|e| !regex_section_marker().is_match(&e.name))
        .collect();
    let entities = to_new_entities(with_models);
    let (relationships, relationships_truncated) =
        build_relationships_by_sentence_cooccurrence(body, &entities);

    let extraction_method = if bert_used {
        "bert+regex-batch".to_string()
    } else {
        "regex-only".to_string()
    };

    let urls = extract_urls(body);

    Ok(ExtractionResult {
        entities,
        relationships,
        relationships_truncated,
        extraction_method,
        urls,
    })
}

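/// Deterministic extractor that never touches the NER model; useful when the
/// model is unavailable or a fast, dependency-free pass is preferred.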
pub struct RegexExtractor;

impl Extractor for RegexExtractor {
    fn extract(&self, body: &str) -> Result<ExtractionResult> {
        let regex_entities = apply_regex_prefilter(body);
        let entities = to_new_entities(regex_entities);
        let (relationships, relationships_truncated) =
            build_relationships_by_sentence_cooccurrence(body, &entities);
        let urls = extract_urls(body);
        Ok(ExtractionResult {
            entities,
            relationships,
            relationships_truncated,
            extraction_method: "regex-only".to_string(),
            urls,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_paths() -> AppPaths {
        use std::path::PathBuf;
        AppPaths {
            db: PathBuf::from("/tmp/test.sqlite"),
            models: PathBuf::from("/tmp/test_models"),
        }
    }

    #[test]
    fn regex_email_captures_address() {
        let ents = apply_regex_prefilter("contact: someone@company.com for more info");
        assert!(ents
            .iter()
            .any(|e| e.name == "someone@company.com" && e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_filters_pt_rule_word() {
        let ents = apply_regex_prefilter("NUNCA do this. PROIBIDO use X. DEVE follow Y.");
        assert!(
            !ents.iter().any(|e| e.name == "NUNCA"),
            "NUNCA must be filtered as a stopword"
        );
        assert!(
            !ents.iter().any(|e| e.name == "PROIBIDO"),
            "PROIBIDO must be filtered"
        );
        assert!(
            !ents.iter().any(|e| e.name == "DEVE"),
            "DEVE must be filtered"
        );
    }

    #[test]
    fn regex_all_caps_accepts_underscored_constant() {
        let ents = apply_regex_prefilter("configure MAX_RETRY=3 and API_TIMEOUT=30");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "API_TIMEOUT"));
    }

    #[test]
    fn regex_all_caps_accepts_domain_acronym() {
        let ents = apply_regex_prefilter("OPENAI launched GPT-5 with NVIDIA H100");
        assert!(ents.iter().any(|e| e.name == "OPENAI"));
        assert!(ents.iter().any(|e| e.name == "NVIDIA"));
    }

    #[test]
    fn regex_url_does_not_appear_in_apply_regex_prefilter() {
        let ents = apply_regex_prefilter("see https://docs.rs/crate for details");
        assert!(
            !ents.iter().any(|e| e.name.starts_with("https://")),
            "URLs must not appear as entities after the P0-2 split"
        );
    }

    #[test]
    fn extract_urls_captures_https() {
        let urls = extract_urls("see https://docs.rs/crate for details");
        assert_eq!(urls.len(), 1);
        assert_eq!(urls[0].url, "https://docs.rs/crate");
        assert!(urls[0].offset > 0);
    }

    #[test]
    fn extract_urls_trim_sufixo_pontuacao() {
        let urls = extract_urls("link: https://example.com/path. end");
        assert!(!urls.is_empty());
        assert!(
            !urls[0].url.ends_with('.'),
            "trailing dot suffix must be trimmed"
        );
    }

    #[test]
    fn extract_urls_dedupes_repeated() {
        let body = "https://example.com referenced here and again here https://example.com";
        let urls = extract_urls(body);
        assert_eq!(urls.len(), 1, "repeated URLs must be deduplicated");
    }

    #[test]
    fn regex_uuid_captura_identificador() {
        let ents = apply_regex_prefilter("id=550e8400-e29b-41d4-a716-446655440000 in the system");
        assert!(ents.iter().any(|e| e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_captura_constante() {
        let ents = apply_regex_prefilter("configure MAX_RETRY and TIMEOUT_MS");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "TIMEOUT_MS"));
    }

    #[test]
    fn regex_all_caps_ignores_short_words() {
        let ents = apply_regex_prefilter("use AI in your project");
        assert!(
            !ents.iter().any(|e| e.name == "AI"),
            "AI is only 2 chars and must be ignored"
        );
    }

    #[test]
    fn iob_decodes_per_to_person() {
        let tokens = vec![
            "John".to_string(),
            "Doe".to_string(),
            "trabalhou".to_string(),
        ];
        let labels = vec!["B-PER".to_string(), "I-PER".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents.len(), 1);
        assert_eq!(ents[0].entity_type, "person");
        assert!(ents[0].name.contains("John"));
    }

    #[test]
    fn iob_strip_subword_b_prefix() {
        let tokens = vec!["Open".to_string(), "##AI".to_string()];
        let labels = vec!["B-ORG".to_string(), "B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.iter().any(|e| e.name == "OpenAI" || e.name == "Open"),
            "should merge ##AI or discard"
        );
    }

    #[test]
    fn iob_subword_orphan_discards() {
        let tokens = vec!["##AI".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "orphan subword without an active entity must be discarded"
        );
    }

    #[test]
    fn iob_maps_date_to_date_v1025() {
        let tokens = vec!["January".to_string(), "2024".to_string()];
        let labels = vec!["B-DATE".to_string(), "I-DATE".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE must be emitted as an entity in v1.0.25"
        );
        assert_eq!(ents[0].entity_type, "date");
    }

    #[test]
    fn iob_maps_org_to_organization_v1025() {
        let tokens = vec!["Empresa".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "organization");
    }

    #[test]
    fn iob_maps_org_sdk_to_tool() {
        let tokens = vec!["tokio-sdk".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "tool");
    }

    #[test]
    fn iob_maps_loc_to_location_v1025() {
        let tokens = vec!["Brasil".to_string()];
        let labels = vec!["B-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "location");
    }

    #[test]
    fn build_relationships_respeitam_max_rels() {
        let entities: Vec<NewEntity> = (0..20)
            .map(|i| NewEntity {
                name: format!("entidade_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let (rels, truncated) = build_relationships(&entities);
        let max_rels = crate::constants::max_relationships_per_memory();
        assert!(rels.len() <= max_rels, "must respect max_rels={max_rels}");
        if rels.len() == max_rels {
            assert!(truncated, "truncated must be true when the cap is hit");
        }
    }

    #[test]
    fn build_relationships_without_duplicates() {
        let entities: Vec<NewEntity> = (0..5)
            .map(|i| NewEntity {
                name: format!("ent_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let (rels, _truncated) = build_relationships(&entities);
        let mut pares: std::collections::HashSet<(String, String)> =
            std::collections::HashSet::new();
        for r in &rels {
            let par = (r.source.clone(), r.target.clone());
            assert!(pares.insert(par), "duplicate pair found");
        }
    }

    #[test]
    fn merge_dedupes_by_lowercase_name() {
        let a = vec![ExtractedEntity {
            name: "Rust".to_string(),
            entity_type: "concept".to_string(),
        }];
        let b = vec![ExtractedEntity {
            name: "rust".to_string(),
            entity_type: "concept".to_string(),
        }];
        let merged = merge_and_deduplicate(a, b);
        assert_eq!(
            merged.len(),
            1,
            "rust and Rust with the same type are the same entity"
        );
    }

    #[test]
    fn regex_extractor_implements_trait() {
        let extractor = RegexExtractor;
        let result = extractor
            .extract("contact: dev@empresa.io with MAX_TIMEOUT configured")
            .unwrap();
        assert!(!result.entities.is_empty());
    }

    #[test]
    fn extract_returns_ok_without_model() {
        let paths = make_paths();
        let body = "contact: teste@exemplo.com with MAX_RETRY=3";
        let result = extract_graph_auto(body, &paths).unwrap();
        assert!(result
            .entities
            .iter()
            .any(|e| e.name.contains("teste@exemplo.com")));
    }

    #[test]
    fn stopwords_filter_v1024_terms() {
        let body = "ACEITE ACK ACL BORDA CHECKLIST COMPLETED CONFIRME \
                    DEVEMOS DONE FIXED NEGUE PENDING PLAN PODEMOS RECUSE TOKEN VAMOS";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for word in &[
            "ACEITE",
            "ACK",
            "ACL",
            "BORDA",
            "CHECKLIST",
            "COMPLETED",
            "CONFIRME",
            "DEVEMOS",
            "DONE",
            "FIXED",
            "NEGUE",
            "PENDING",
            "PLAN",
            "PODEMOS",
            "RECUSE",
            "TOKEN",
            "VAMOS",
        ] {
            assert!(
                !names.contains(word),
                "v1.0.24 stopword {word} should be filtered but was found in entities"
            );
        }
    }

    #[test]
    fn dedup_normalizes_unicode_combining_marks() {
        let nfc = vec![ExtractedEntity {
            name: "Caf\u{e9}".to_string(),
            entity_type: "concept".to_string(),
        }];
        let nfd_name = "Cafe\u{301}".to_string();
        let nfd = vec![ExtractedEntity {
            name: nfd_name,
            entity_type: "concept".to_string(),
        }];
        let merged = merge_and_deduplicate(nfc, nfd);
        assert_eq!(
            merged.len(),
            1,
            "NFC 'Caf\\u{{e9}}' and NFD 'Cafe\\u{{301}}' must deduplicate to 1 entity after NFKC normalization"
        );
    }

    #[test]
    fn predict_batch_output_count_matches_input() {
        let w1_ids: Vec<u32> = vec![101, 100, 102];
        let w1_tok: Vec<String> = vec!["[CLS]".into(), "hello".into(), "[SEP]".into()];
        let w2_ids: Vec<u32> = vec![101, 100, 200, 300, 102];
        let w2_tok: Vec<String> = vec![
            "[CLS]".into(),
            "world".into(),
            "foo".into(),
            "bar".into(),
            "[SEP]".into(),
        ];
        let windows: Vec<(Vec<u32>, Vec<String>)> =
            vec![(w1_ids.clone(), w1_tok), (w2_ids.clone(), w2_tok)];

        let device = Device::Cpu;
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap();
        assert_eq!(max_len, 5, "max_len must be 5");

        let mut padded_ids: Vec<Tensor> = Vec::new();
        for (ids, _) in &windows {
            let len = ids.len();
            let pad_right = max_len - len;
            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &device).unwrap();
            let t = t.pad_with_zeros(0, 0, pad_right).unwrap();
            assert_eq!(
                t.dims(),
                &[max_len],
                "each window must have shape (max_len,) after padding"
            );
            padded_ids.push(t);
        }

        let stacked = Tensor::stack(&padded_ids, 0).unwrap();
        assert_eq!(
            stacked.dims(),
            &[2, max_len],
            "stack must produce (batch_size=2, max_len=5)"
        );

        let fake_logits_data: Vec<f32> = vec![0.0f32; 2 * max_len * 9];
        let fake_logits =
            Tensor::from_vec(fake_logits_data, (2usize, max_len, 9usize), &device).unwrap();
        for (i, (ids, _)) in windows.iter().enumerate() {
            let real_len = ids.len();
            let example = fake_logits.get(i).unwrap();
            let sliced = example.narrow(0, 0, real_len).unwrap();
            assert_eq!(
                sliced.dims(),
                &[real_len, 9],
                "narrow must keep only the {real_len} real tokens"
            );
        }
    }

    #[test]
    fn predict_batch_empty_windows_returns_empty() {
        let windows: Vec<(Vec<u32>, Vec<String>)> = vec![];
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        assert_eq!(max_len, 0, "zero windows → max_len 0");
        let result: Vec<Vec<String>> = if max_len == 0 {
            Vec::new()
        } else {
            unreachable!()
        };
        assert!(result.is_empty());
    }

    #[test]
    fn ner_batch_size_default_is_8() {
        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
        assert_eq!(crate::constants::ner_batch_size(), 8);
    }

    #[test]
    fn ner_batch_size_env_override_clamped() {
        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "64");
        assert_eq!(crate::constants::ner_batch_size(), 32, "must clamp to 32");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "0");
        assert_eq!(crate::constants::ner_batch_size(), 1, "must clamp to 1");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "4");
        assert_eq!(
            crate::constants::ner_batch_size(),
            4,
            "valid value preserved"
        );

        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
    }

    #[test]
    fn extraction_method_regex_only_unchanged() {
        let result = RegexExtractor.extract("contact: dev@acme.io").unwrap();
        assert_eq!(
            result.extraction_method, "regex-only",
            "RegexExtractor must return regex-only"
        );
    }

    #[test]
    fn extend_suffix_pure_numeric_unchanged() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-5 in the project");
        assert_eq!(
            result[0].name, "GPT-5",
            "purely numeric suffix must be extended"
        );
    }

    #[test]
    fn extend_suffix_alphanumeric_letter_after_digit() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-4o for advanced tasks");
        assert_eq!(result[0].name, "GPT-4o", "suffix '4o' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_b_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Llama".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "Llama-5b open-weight model");
        assert_eq!(result[0].name, "Llama-5b", "suffix '5b' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_x_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Mistral".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "testing Mistral-8x in production");
        assert_eq!(result[0].name, "Mistral-8x", "suffix '8x' must be accepted");
    }

    #[test]
    fn augment_versioned_gpt4o() {
        let result = augment_versioned_model_names(vec![], "using GPT-4o for analysis");
        assert!(
            result.iter().any(|e| e.name == "GPT-4o"),
            "GPT-4o must be captured by augment, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_claude_4_sonnet() {
        let result =
            augment_versioned_model_names(vec![], "best model: Claude 4 Sonnet released today");
        assert!(
            result.iter().any(|e| e.name == "Claude 4 Sonnet"),
            "Claude 4 Sonnet must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_llama_3_pro() {
        let result =
            augment_versioned_model_names(vec![], "fine-tuning with Llama 3 Pro locally");
        assert!(
            result.iter().any(|e| e.name == "Llama 3 Pro"),
            "Llama 3 Pro must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_mixtral_8x7b() {
        let result =
            augment_versioned_model_names(vec![], "running Mixtral 8x7B on the local server");
        assert!(
            result.iter().any(|e| e.name == "Mixtral 8x7B"),
            "Mixtral 8x7B must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_does_not_duplicate_existing() {
        let existing = vec![ExtractedEntity {
            name: "Claude 4".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = augment_versioned_model_names(existing, "using Claude 4 in the project");
        let count = result.iter().filter(|e| e.name == "Claude 4").count();
        assert_eq!(count, 1, "Claude 4 must not be duplicated");
    }

    #[test]
    fn stopwords_filter_url_jwt_api_v1025() {
        let body = "We use URL, JWT, and API REST in our LLM-powered CLI via HTTP/HTTPS and UI.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for blocked in &[
            "URL", "JWT", "API", "REST", "LLM", "CLI", "HTTP", "HTTPS", "UI",
        ] {
            assert!(
                !names.contains(blocked),
                "v1.0.25 stopword {blocked} leaked as entity; found names: {names:?}"
            );
        }
    }

    #[test]
    fn section_markers_etapa_fase_filtered_v1025() {
        let body = "Etapa 3 do plano: implementar Fase 1 da Migra\u{e7}\u{e3}o.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Etapa") || e.name.contains("Fase")),
            "section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn section_markers_passo_secao_filtered_v1025() {
        let body = "Siga Passo 2 conforme Se\u{e7}\u{e3}o 3 do manual.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Passo") || e.name.contains("Se\u{e7}\u{e3}o")),
            "Passo/Se\\u{{e7}}\\u{{e3}}o section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn brand_camelcase_extracted_as_organization_v1025() {
        let body = "OpenAI launched GPT-4 and PostgreSQL added pgvector.";
        let ents = apply_regex_prefilter(body);
        let openai = ents.iter().find(|e| e.name == "OpenAI");
        assert!(
            openai.is_some(),
            "OpenAI must be extracted by CamelCase brand regex; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
        assert_eq!(
            openai.unwrap().entity_type,
            "organization",
            "brand CamelCase must map to organization (V008)"
        );
    }

    #[test]
    fn brand_postgresql_extracted_as_organization_v1025() {
        let body = "migrating from MySQL to PostgreSQL for better performance.";
        let ents = apply_regex_prefilter(body);
        assert!(
            ents.iter()
                .any(|e| e.name == "PostgreSQL" && e.entity_type == "organization"),
            "PostgreSQL must be extracted as organization; entities: {:?}",
            ents.iter()
                .map(|e| (&e.name, &e.entity_type))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn iob_org_maps_to_organization_not_project_v1025() {
        let tokens = vec!["Microsoft".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type, "organization",
            "B-ORG must map to organization (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_loc_maps_to_location_not_concept_v1025() {
        let tokens = vec!["S\u{e3}o".to_string(), "Paulo".to_string()];
        let labels = vec!["B-LOC".to_string(), "I-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type, "location",
            "B-LOC must map to location (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_date_maps_to_date_not_discarded_v1025() {
        let tokens = vec!["2025".to_string(), "-".to_string(), "12".to_string()];
        let labels = vec![
            "B-DATE".to_string(),
            "I-DATE".to_string(),
            "I-DATE".to_string(),
        ];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE entity must be emitted (V008); entities: {ents:?}"
        );
        assert_eq!(ents[0].entity_type, "date");
    }

    #[test]
    fn pt_verb_le_filtered_as_per_v1025() {
        let tokens = vec!["L\u{ea}".to_string(), "o".to_string(), "livro".to_string()];
        let labels = vec!["B-PER".to_string(), "O".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents
                .iter()
                .any(|e| e.name == "L\u{ea}" && e.entity_type == "person"),
            "PT verb 'L\\u{{ea}}' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn pt_verb_ver_filtered_as_per_v1025() {
        let tokens = vec!["Ver".to_string()];
        let labels = vec!["B-PER".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "PT verb 'Ver' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    fn entity(name: &str, entity_type: &str) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_string(),
            entity_type: entity_type.to_string(),
        }
    }

    #[test]
    fn merge_resolves_sonne_vs_sonnet_keeps_longest_v1025() {
        let regex = vec![entity("Sonne", "concept")];
        let ner = vec![entity("Sonnet", "concept")];
        let result = merge_and_deduplicate(regex, ner);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "Sonnet");
    }

    #[test]
    fn merge_resolves_open_vs_openai_keeps_longest_v1025() {
        let regex = vec![
            entity("Open", "organization"),
            entity("OpenAI", "organization"),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "OpenAI");
    }

    #[test]
    fn merge_keeps_both_when_no_containment_v1025() {
        let regex = vec![entity("Alice", "person"), entity("Bob", "person")];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 2, "expected 2 entities, got: {result:?}");
    }

    #[test]
    fn merge_respects_entity_type_boundary_v1025() {
        let regex = vec![entity("Apple", "organization"), entity("Apple", "concept")];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            2,
            "expected 2 entities (different types), got: {result:?}"
        );
    }

    #[test]
    fn merge_case_insensitive_dedup_v1025() {
        let regex = vec![
            entity("OpenAI", "organization"),
            entity("openai", "organization"),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            1,
            "expected 1 entity after case-insensitive dedup, got: {result:?}"
        );
    }

    #[test]
    fn iob_section_marker_etapa_filtered_v1025() {
        let tokens = vec!["Etapa".to_string(), "3".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Etapa")),
            "section marker 'Etapa 3' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn iob_section_marker_fase_filtered_v1025() {
        let tokens = vec!["Fase".to_string(), "1".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Fase")),
            "section marker 'Fase 1' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn extract_graph_auto_handles_large_body_under_30s() {
        let body = "x ".repeat(40_000);
        let paths = make_paths();
        let start = std::time::Instant::now();
        let result = extract_graph_auto(&body, &paths).expect("extraction must not error");
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_secs() < 30,
            "extract_graph_auto took {}s for 80 KB body (cap should keep it well under 30s)",
            elapsed.as_secs()
        );
        let _ = result.entities;
    }

    #[test]
    fn pt_uppercase_stopwords_filtered_v1031() {
        let body = "Para o ADAPTER funcionar com PROJETO em modo PASSIVA, devemos usar \
                    SOMENTE LEITURA conforme a REGRA OBRIGATORIA do EXEMPLO DEFAULT.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<String> = ents.iter().map(|e| e.name.to_uppercase()).collect();
        for stop in &[
            "ADAPTER",
            "PROJETO",
            "PASSIVA",
            "SOMENTE",
            "LEITURA",
            "REGRA",
            "OBRIGATORIA",
            "EXEMPLO",
            "DEFAULT",
        ] {
            assert!(
                !names.contains(&stop.to_string()),
                "v1.0.31 A11 stoplist failed: {stop} leaked as entity; got names: {names:?}"
            );
        }
    }

    #[test]
    fn pt_underscored_identifier_preserved_v1031() {
        let ents = apply_regex_prefilter("configure FLOWAIPER_API_KEY=foo and MAX_TIMEOUT=30");
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"FLOWAIPER_API_KEY"));
        assert!(names.contains(&"MAX_TIMEOUT"));
    }

    #[test]
    fn build_relationships_by_sentence_only_links_co_occurring_entities() {
        let body = "Alice met Bob at the conference. Carol works alone in another room.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
            NewEntity {
                name: "Carol".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
        ];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(!truncated);
        assert_eq!(
            rels.len(),
            1,
            "only Alice/Bob should pair (same sentence); Carol is isolated"
        );
        let pair = (rels[0].source.as_str(), rels[0].target.as_str());
        assert!(
            matches!(pair, ("Alice", "Bob") | ("Bob", "Alice")),
            "unexpected pair {pair:?}"
        );
    }

    #[test]
    fn build_relationships_by_sentence_returns_empty_for_single_entity() {
        let body = "Alice is here.";
        let entities = vec![NewEntity {
            name: "Alice".to_string(),
            entity_type: "person".to_string(),
            description: None,
        }];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(rels.is_empty());
        assert!(!truncated);
    }

    #[test]
    fn build_relationships_by_sentence_dedupes_pairs_across_sentences() {
        let body = "Alice met Bob. Bob saw Alice again.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
        ];
        let (rels, _) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert_eq!(
            rels.len(),
            1,
            "Alice/Bob pair must be emitted only once even when co-occurring in multiple sentences"
        );
    }

    #[test]
    fn extraction_max_tokens_default_is_5000() {
        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
        assert_eq!(crate::constants::extraction_max_tokens(), 5_000);
    }

    #[test]
    fn extraction_max_tokens_env_override_clamped() {
        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value below 512 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value above 100_000 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "8000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            8_000,
            "valid value must be honoured"
        );

        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
    }
}