use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;

use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use regex::Regex;
use serde::Deserialize;
use unicode_normalization::UnicodeNormalization;

use crate::paths::AppPaths;
use crate::storage::entities::{NewEntity, NewRelationship};

const MODEL_ID: &str = "Davlan/bert-base-multilingual-cased-ner-hrl";
const MAX_SEQ_LEN: usize = 512;
const STRIDE: usize = 256;
const MAX_ENTS: usize = 30;
#[cfg(test)]
const TOP_K_RELATIONS: usize = 5;
const DEFAULT_RELATION: &str = "mentions";
const MIN_ENTITY_CHARS: usize = 2;

static REGEX_EMAIL: OnceLock<Regex> = OnceLock::new();
static REGEX_URL: OnceLock<Regex> = OnceLock::new();
static REGEX_UUID: OnceLock<Regex> = OnceLock::new();
static REGEX_ALL_CAPS: OnceLock<Regex> = OnceLock::new();
static REGEX_SECTION_MARKER: OnceLock<Regex> = OnceLock::new();
static REGEX_BRAND_CAMEL: OnceLock<Regex> = OnceLock::new();

const ALL_CAPS_STOPWORDS: &[&str] = &[
    "ACEITE",
    "ACK",
    "ACL",
    "ACRESCENTADO",
    "ADAPTER",
    "ADICIONAR",
    "AGENTS",
    "ALL",
    "ALTA",
    "ALWAYS",
    "API",
    "ARTEFATOS",
    "ATIVA",
    "ATIVO",
    "BAIXA",
    "BANCO",
    "BLOQUEAR",
    "BORDA",
    "BUG",
    "CAPÍTULO",
    "CASO",
    "CHECKLIST",
    "CLI",
    "COMPLETED",
    "CONFIRMADO",
    "CONFIRME",
    "CONTRATO",
    "CRÍTICO",
    "CRITICAL",
    "CSV",
    "DEFAULT",
    "DEVE",
    "DEVEMOS",
    "DISCO",
    "DONE",
    "EFEITO",
    "ENTRADA",
    "ERROR",
    "ESCRITA",
    "ESSA",
    "ESSE",
    "ESSENCIAL",
    "ESTA",
    "ESTE",
    "ETAPA",
    "EVITAR",
    "EXEMPLO",
    "EXPANDIR",
    "EXPOR",
    "FALHA",
    "FASE",
    "FIXED",
    "FIXME",
    "FORBIDDEN",
    "HACK",
    "HEARTBEAT",
    "HTTP",
    "HTTPS",
    "INATIVO",
    "JAMAIS",
    "JSON",
    "JWT",
    "LEITURA",
    "LLM",
    "MUST",
    "NEGUE",
    "NEVER",
    "NOTE",
    "NUNCA",
    "OBRIGATORIA",
    "OBRIGATÓRIO",
    "PADRÃO",
    "PASSIVA",
    "PASSO",
    "PENDING",
    "PLAN",
    "PODEMOS",
    "PROIBIDO",
    "PROJETO",
    "RECUSE",
    "REGRA",
    "REGRAS",
    "REQUIRED",
    "REQUISITO",
    "REST",
    "SEÇÃO",
    "SEMPRE",
    "SHALL",
    "SHOULD",
    "SOMENTE",
    "SOUL",
    "TODAS",
    "TODO",
    "TODOS",
    "TOKEN",
    "TOOLS",
    "TSV",
    "UI",
    "URL",
    "USAR",
    "VALIDAR",
    "VAMOS",
    "VOCÊ",
    "WARNING",
    "XML",
    "YAML",
];

const HTTP_METHODS: &[&str] = &[
    "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "CONNECT", "TRACE",
];

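/// Returns `true` for ALL-CAPS tokens that should be dropped as entity
/// candidates (rule/status stopwords and HTTP verbs). Underscored
/// identifiers such as `MAX_RETRY` are never filtered.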
fn is_filtered_all_caps(token: &str) -> bool {
    let is_identifier = token.contains('_');
    if is_identifier {
        return false;
    }
    ALL_CAPS_STOPWORDS.contains(&token) || HTTP_METHODS.contains(&token)
}

fn regex_email() -> &'static Regex {
    REGEX_EMAIL.get_or_init(|| {
        Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
            .expect("compile-time validated email regex literal")
    })
}

fn regex_url() -> &'static Regex {
    REGEX_URL.get_or_init(|| {
        Regex::new(r#"https?://[^\s\)\]\}"'<>]+"#)
            .expect("compile-time validated URL regex literal")
    })
}

fn regex_uuid() -> &'static Regex {
    REGEX_UUID.get_or_init(|| {
        Regex::new(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}")
            .expect("compile-time validated UUID regex literal")
    })
}

fn regex_all_caps() -> &'static Regex {
    REGEX_ALL_CAPS.get_or_init(|| {
        Regex::new(r"\b[A-Z][A-Z0-9_]{2,}\b")
            .expect("compile-time validated all-caps regex literal")
    })
}

fn regex_section_marker() -> &'static Regex {
    REGEX_SECTION_MARKER.get_or_init(|| {
        Regex::new("\\b(?:Etapa|Fase|Passo|Se\u{00e7}\u{00e3}o|Cap\u{00ed}tulo)\\s+\\d+\\b")
            .expect("compile-time validated section marker regex literal")
    })
}

fn regex_brand_camel() -> &'static Regex {
    REGEX_BRAND_CAMEL.get_or_init(|| {
        Regex::new(r"\b[A-Z][a-z]+[A-Z][A-Za-z]+\b")
            .expect("compile-time validated CamelCase brand regex literal")
    })
}

#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    pub name: String,
    pub entity_type: String,
}

#[derive(Debug, Clone)]
pub struct ExtractedUrl {
    pub url: String,
    pub offset: usize,
}

#[derive(Debug, Clone)]
pub struct ExtractionResult {
    pub entities: Vec<NewEntity>,
    pub relationships: Vec<NewRelationship>,
    pub relationships_truncated: bool,
    pub extraction_method: String,
    pub urls: Vec<ExtractedUrl>,
}

pub trait Extractor: Send + Sync {
    fn extract(&self, body: &str) -> Result<ExtractionResult>;
}

#[derive(Deserialize)]
struct ModelConfig {
    #[serde(default)]
    id2label: HashMap<String, String>,
    hidden_size: usize,
}

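/// CPU-hosted BERT token-classification model: the base encoder plus the
/// linear classification head read directly from the safetensors file.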
struct BertNerModel {
    bert: BertModel,
    classifier: Linear,
    device: Device,
    id2label: HashMap<usize, String>,
}

impl BertNerModel {
    fn load(model_dir: &Path) -> Result<Self> {
        let config_path = model_dir.join("config.json");
        let weights_path = model_dir.join("model.safetensors");

        let config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("reading config.json at {config_path:?}"))?;
        let model_cfg: ModelConfig =
            serde_json::from_str(&config_str).context("parsing NER model config.json")?;

        let id2label: HashMap<usize, String> = model_cfg
            .id2label
            .into_iter()
            .filter_map(|(k, v)| k.parse::<usize>().ok().map(|n| (n, v)))
            .collect();

        let num_labels = id2label.len().max(9);
        let hidden_size = model_cfg.hidden_size;

        let bert_cfg: BertConfig =
            serde_json::from_str(&config_str).context("parsing BertConfig")?;

        let device = Device::Cpu;

        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
                .with_context(|| format!("mapping {weights_path:?}"))?
        };
        let bert = BertModel::load(vb.pp("bert"), &bert_cfg).context("loading BertModel")?;

        let cls_vb = vb.pp("classifier");
        let weight = cls_vb
            .get((num_labels, hidden_size), "weight")
            .context("loading classifier.weight from safetensors")?;
        let bias = cls_vb
            .get(num_labels, "bias")
            .context("loading classifier.bias from safetensors")?;
        let classifier = Linear::new(weight, Some(bias));

        Ok(Self {
            bert,
            classifier,
            device,
            id2label,
        })
    }

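    /// Runs a single window through the encoder and classifier, returning
    /// one IOB label per token (argmax over the per-token logits).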
    fn predict(&self, token_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<String>> {
        let len = token_ids.len();
        let ids_i64: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
        let mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        let input_ids = Tensor::from_vec(ids_i64, (1, len), &self.device)
            .context("creating input_ids tensor")?;
        let token_type_ids = Tensor::zeros((1, len), DType::I64, &self.device)
            .context("creating token_type_ids tensor")?;
        let attn_mask = Tensor::from_vec(mask_i64, (1, len), &self.device)
            .context("creating attention_mask tensor")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel forward pass")?;

        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("classifier forward pass")?;

        let logits_2d = logits.squeeze(0).context("removing batch dimension")?;

        let num_tokens = logits_2d.dim(0).context("dim(0)")?;

        let mut labels = Vec::with_capacity(num_tokens);
        for i in 0..num_tokens {
            let token_logits = logits_2d.get(i).context("get token logits")?;
            let vec: Vec<f32> = token_logits.to_vec1().context("to_vec1 logits")?;
            let argmax = vec
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| {
                    a.partial_cmp(b)
                        .expect("BERT NER logits invariant: no NaN in classifier output")
                })
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            let label = self
                .id2label
                .get(&argmax)
                .cloned()
                .unwrap_or_else(|| "O".to_string());
            labels.push(label);
        }

        Ok(labels)
    }

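    /// Runs several windows in one forward pass. Each window is
    /// right-padded with zeros to the longest window in the batch and the
    /// padding is masked out; logits are narrowed back to each window's
    /// real length before the argmax, so labels align 1:1 with the input
    /// tokens.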
    fn predict_batch(&self, windows: &[(Vec<u32>, Vec<String>)]) -> Result<Vec<Vec<String>>> {
        let batch_size = windows.len();
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        if max_len == 0 {
            return Ok(vec![vec![]; batch_size]);
        }

        let mut padded_ids: Vec<Tensor> = Vec::with_capacity(batch_size);
        let mut padded_masks: Vec<Tensor> = Vec::with_capacity(batch_size);

        for (ids, _) in windows {
            let len = ids.len();
            let pad_right = max_len - len;

            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &self.device)
                .context("creating id tensor for batch")?;
            let t = t
                .pad_with_zeros(0, 0, pad_right)
                .context("padding id tensor")?;
            padded_ids.push(t);

            let mut mask_i64 = vec![1i64; len];
            mask_i64.extend(vec![0i64; pad_right]);
            let m = Tensor::from_vec(mask_i64, max_len, &self.device)
                .context("creating mask tensor for batch")?;
            padded_masks.push(m);
        }

        let input_ids = Tensor::stack(&padded_ids, 0).context("stack input_ids")?;
        let attn_mask = Tensor::stack(&padded_masks, 0).context("stack attn_mask")?;
        let token_type_ids = Tensor::zeros((batch_size, max_len), DType::I64, &self.device)
            .context("creating token_type_ids tensor for batch")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel batch forward pass")?;
        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("classifier batch forward pass")?;

        let mut results = Vec::with_capacity(batch_size);
        for (i, (window_ids, _)) in windows.iter().enumerate() {
            let example_logits = logits.get(i).context("get example logits")?;
            let real_len = window_ids.len();
            let example_slice = example_logits
                .narrow(0, 0, real_len)
                .context("narrowing to real tokens")?;
            let logits_2d: Vec<Vec<f32>> = example_slice.to_vec2().context("to_vec2 logits")?;

            let labels: Vec<String> = logits_2d
                .iter()
                .map(|token_logits| {
                    let argmax = token_logits
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b)
                                .expect("BERT NER logits invariant: no NaN in classifier output")
                        })
                        .map(|(idx, _)| idx)
                        .unwrap_or(0);
                    self.id2label
                        .get(&argmax)
                        .cloned()
                        .unwrap_or_else(|| "O".to_string())
                })
                .collect();

            results.push(labels);
        }

        Ok(results)
    }
}

static NER_MODEL: OnceLock<Option<BertNerModel>> = OnceLock::new();

fn get_or_init_model(paths: &AppPaths) -> Option<&'static BertNerModel> {
    NER_MODEL
        .get_or_init(|| match load_model(paths) {
            Ok(m) => Some(m),
            Err(e) => {
                tracing::warn!("NER model unavailable (graceful degradation): {e:#}");
                None
            }
        })
        .as_ref()
}

fn model_dir(paths: &AppPaths) -> PathBuf {
    paths.models.join("bert-multilingual-ner")
}

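/// Ensures all model files are present locally, downloading any missing
/// file from the Hugging Face Hub on first use, and returns the local
/// model directory.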
fn ensure_model_files(paths: &AppPaths) -> Result<PathBuf> {
    let dir = model_dir(paths);
    std::fs::create_dir_all(&dir).with_context(|| format!("creating model directory: {dir:?}"))?;

    let weights = dir.join("model.safetensors");
    let config = dir.join("config.json");
    let tokenizer = dir.join("tokenizer.json");

    if weights.exists() && config.exists() && tokenizer.exists() {
        return Ok(dir);
    }

    tracing::info!("Downloading NER model (first run, ~676 MB)...");
    crate::output::emit_progress_i18n(
        "Downloading NER model (first run, ~676 MB)...",
        crate::i18n::validation::runtime_pt::downloading_ner_model(),
    );

    let api = huggingface_hub::api::sync::Api::new().context("creating HF Hub client")?;
    let repo = api.model(MODEL_ID.to_string());

    for (remote, local) in &[
        ("model.safetensors", "model.safetensors"),
        ("config.json", "config.json"),
        ("onnx/tokenizer.json", "tokenizer.json"),
        ("tokenizer_config.json", "tokenizer_config.json"),
    ] {
        let dest = dir.join(local);
        if !dest.exists() {
            let src = repo
                .get(remote)
                .with_context(|| format!("downloading {remote} from HF Hub"))?;
            std::fs::copy(&src, &dest).with_context(|| format!("copying {local} into cache"))?;
        }
    }

    Ok(dir)
}

fn load_model(paths: &AppPaths) -> Result<BertNerModel> {
    let dir = ensure_model_files(paths)?;
    BertNerModel::load(&dir)
}

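/// Regex-only entity pass that works without the model: section markers
/// are blanked out first, then e-mail addresses, UUIDs, non-stopword
/// ALL-CAPS tokens, and CamelCase brand names are collected, deduplicated
/// by exact name.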
fn apply_regex_prefilter(body: &str) -> Vec<ExtractedEntity> {
    let mut entities = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    let add = |entities: &mut Vec<ExtractedEntity>,
               seen: &mut std::collections::HashSet<String>,
               name: &str,
               entity_type: &str| {
        let name = name.trim().to_string();
        if name.len() >= MIN_ENTITY_CHARS && seen.insert(name.clone()) {
            entities.push(ExtractedEntity {
                name,
                entity_type: entity_type.to_string(),
            });
        }
    };

    let cleaned = regex_section_marker().replace_all(body, " ");
    let cleaned = cleaned.as_ref();

    for m in regex_email().find_iter(cleaned) {
        add(&mut entities, &mut seen, m.as_str(), "concept");
    }
    for m in regex_uuid().find_iter(cleaned) {
        add(&mut entities, &mut seen, m.as_str(), "concept");
    }
    for m in regex_all_caps().find_iter(cleaned) {
        let candidate = m.as_str();
        if !is_filtered_all_caps(candidate) {
            add(&mut entities, &mut seen, candidate, "concept");
        }
    }
    for m in regex_brand_camel().find_iter(cleaned) {
        let name = m.as_str();
        if !ALL_CAPS_STOPWORDS.contains(&name.to_uppercase().as_str()) {
            add(&mut entities, &mut seen, name, "organization");
        }
    }

    entities
}

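/// Finds `http(s)` URLs in `body`, trims trailing punctuation the regex
/// may have swallowed, drops anything shorter than 10 chars, and
/// deduplicates while keeping each URL's first byte offset.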
pub fn extract_urls(body: &str) -> Vec<ExtractedUrl> {
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut result = Vec::new();
    for m in regex_url().find_iter(body) {
        let raw = m.as_str();
        let cleaned = raw
            .trim_end_matches('`')
            .trim_end_matches(',')
            .trim_end_matches('.')
            .trim_end_matches(';')
            .trim_end_matches(')')
            .trim_end_matches(']')
            .trim_end_matches('}');
        if cleaned.len() >= 10 && seen.insert(cleaned.to_string()) {
            result.push(ExtractedUrl {
                url: cleaned.to_string(),
                offset: m.start(),
            });
        }
    }
    result
}

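/// Decodes IOB label sequences into entities: `B-*` opens a span, `I-*`
/// extends it, anything else flushes. WordPiece `##` subwords are merged
/// into the previous part, Portuguese verbs mis-tagged as PER are
/// dropped, and model labels are mapped onto the storage taxonomy
/// (person, organization, tool, location, date).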
fn iob_to_entities(tokens: &[String], labels: &[String]) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut current_parts: Vec<String> = Vec::new();
    let mut current_type: Option<String> = None;

    let flush =
        |parts: &mut Vec<String>, typ: &mut Option<String>, entities: &mut Vec<ExtractedEntity>| {
            if let Some(t) = typ.take() {
                let name = parts.join(" ").trim().to_string();
                let is_single_caps = !name.contains(' ')
                    && name == name.to_uppercase()
                    && name.len() >= MIN_ENTITY_CHARS;
                let should_skip = is_single_caps && is_filtered_all_caps(&name);
                let is_section_marker = regex_section_marker().is_match(&name);
                if name.len() >= MIN_ENTITY_CHARS && !should_skip && !is_section_marker {
                    entities.push(ExtractedEntity {
                        name,
                        entity_type: t,
                    });
                }
                parts.clear();
            }
        };

    for (token, label) in tokens.iter().zip(labels.iter()) {
        if label == "O" {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        }

        let (prefix, bio_type) = if let Some(rest) = label.strip_prefix("B-") {
            ("B", rest)
        } else if let Some(rest) = label.strip_prefix("I-") {
            ("I", rest)
        } else {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        };

        const PT_VERB_FALSE_POSITIVES: &[&str] = &[
            "L\u{00ea}",
            "V\u{00ea}",
            "C\u{00e1}",
            "P\u{00f4}r",
            "Ser",
            "Vir",
            "Ver",
            "Dar",
            "Ler",
            "Ter",
        ];

        let entity_type = match bio_type {
            "DATE" => "date",
            "PER" => {
                if PT_VERB_FALSE_POSITIVES.contains(&token.as_str()) {
                    flush(&mut current_parts, &mut current_type, &mut entities);
                    continue;
                }
                "person"
            }
            "ORG" => {
                let t = token.to_lowercase();
                if t.contains("lib")
                    || t.contains("sdk")
                    || t.contains("cli")
                    || t.contains("crate")
                    || t.contains("npm")
                {
                    "tool"
                } else {
                    "organization"
                }
            }
            "LOC" => "location",
            other => other,
        };

        if prefix == "B" {
            if token.starts_with("##") {
                let clean = token.strip_prefix("##").unwrap_or(token.as_str());
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
                continue;
            }
            flush(&mut current_parts, &mut current_type, &mut entities);
            current_parts.push(token.clone());
            current_type = Some(entity_type.to_string());
        } else if prefix == "I" && current_type.is_some() {
            let clean = token.strip_prefix("##").unwrap_or(token.as_str());
            if token.starts_with("##") {
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
            } else {
                current_parts.push(clean.to_string());
            }
        }
    }

    flush(&mut current_parts, &mut current_type, &mut entities);
    entities
}

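/// Test-only baseline pairing: links each entity to at most
/// `TOP_K_RELATIONS` later entities, capped globally by
/// `max_relationships_per_memory()`. The bool reports whether the cap
/// was hit.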
#[cfg(test)]
fn build_relationships(entities: &[NewEntity]) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    let n = entities.len().min(MAX_ENTS);
    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();

    let mut hit_cap = false;
    'outer: for i in 0..n {
        if rels.len() >= max_rels {
            hit_cap = true;
            break;
        }

        let mut for_entity = 0usize;
        for j in (i + 1)..n {
            if for_entity >= TOP_K_RELATIONS {
                break;
            }
            if rels.len() >= max_rels {
                hit_cap = true;
                break 'outer;
            }

            let src = &entities[i].name;
            let tgt = &entities[j].name;
            let key = (src.clone(), tgt.clone());

            if seen.contains(&key) {
                continue;
            }
            seen.insert(key);

            rels.push(NewRelationship {
                source: src.clone(),
                target: tgt.clone(),
                relation: DEFAULT_RELATION.to_string(),
                strength: 0.5,
                description: None,
            });
            for_entity += 1;
        }
    }

    if hit_cap {
        tracing::warn!(
            "relationships truncated to {max_rels} (with {n} entities the theoretical max is {} pairs)",
            n * n.saturating_sub(1) / 2
        );
    }

    (rels, hit_cap)
}

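/// Pairs entities only when their names co-occur (case-insensitively)
/// within the same sentence, which keeps the graph sparser than
/// all-pairs linking. Pairs are deduplicated and capped by
/// `max_relationships_per_memory()`; the bool reports whether the cap
/// was hit.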
fn build_relationships_by_sentence_cooccurrence(
    body: &str,
    entities: &[NewEntity],
) -> (Vec<NewRelationship>, bool) {
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    let lower_names: Vec<(usize, String)> = entities
        .iter()
        .take(MAX_ENTS)
        .enumerate()
        .map(|(i, e)| (i, e.name.to_lowercase()))
        .collect();

    let mut rels: Vec<NewRelationship> = Vec::new();
    let mut seen: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();
    let mut hit_cap = false;

    for sentence in body.split(['.', '!', '?', '\n']) {
        if sentence.trim().is_empty() {
            continue;
        }
        let lower_sentence = sentence.to_lowercase();
        let present: Vec<usize> = lower_names
            .iter()
            .filter(|(_, name)| !name.is_empty() && lower_sentence.contains(name.as_str()))
            .map(|(i, _)| *i)
            .collect();

        if present.len() < 2 {
            continue;
        }

        for i in 0..present.len() {
            for j in (i + 1)..present.len() {
                if rels.len() >= max_rels {
                    hit_cap = true;
                    tracing::warn!(
                        "relationships truncated to {max_rels} during sentence-level pairing"
                    );
                    return (rels, hit_cap);
                }
                let a = &entities[present[i]];
                let b = &entities[present[j]];
                let key = (a.name.to_lowercase(), b.name.to_lowercase());
                if seen.insert(key) {
                    rels.push(NewRelationship {
                        source: a.name.clone(),
                        target: b.name.clone(),
                        relation: DEFAULT_RELATION.to_string(),
                        strength: 0.5,
                        description: None,
                    });
                }
            }
        }
    }

    (rels, hit_cap)
}

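/// Tokenizes the full body, caps it at `extraction_max_tokens()`, and
/// slides `MAX_SEQ_LEN`-token windows with a `STRIDE`-token step so that
/// entities straddling a window boundary are still seen whole by some
/// window. Windows are sorted by length (cheaper padding), batched
/// through `predict_batch`, and fall back to single-window `predict` if
/// a batch fails.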
fn run_ner_sliding_window(
    model: &BertNerModel,
    body: &str,
    paths: &AppPaths,
) -> Result<Vec<ExtractedEntity>> {
    let tokenizer_path = model_dir(paths).join("tokenizer.json");
    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("loading NER tokenizer: {e}"))?;

    let encoding = tokenizer
        .encode(body, false)
        .map_err(|e| anyhow::anyhow!("encoding NER input: {e}"))?;

    let mut all_ids: Vec<u32> = encoding.get_ids().to_vec();
    let mut all_tokens: Vec<String> = encoding
        .get_tokens()
        .iter()
        .map(|s| s.to_string())
        .collect();

    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    let max_tokens = crate::constants::extraction_max_tokens();
    if all_ids.len() > max_tokens {
        tracing::warn!(
            target: "extraction",
            original_tokens = all_ids.len(),
            capped_tokens = max_tokens,
            "NER input truncated to cap; later body region will be skipped by NER (regex prefilter still covers full body)"
        );
        all_ids.truncate(max_tokens);
        all_tokens.truncate(max_tokens);
    }

    let mut windows: Vec<(Vec<u32>, Vec<String>)> = Vec::new();
    let mut start = 0usize;
    loop {
        let end = (start + MAX_SEQ_LEN).min(all_ids.len());
        windows.push((
            all_ids[start..end].to_vec(),
            all_tokens[start..end].to_vec(),
        ));
        if end >= all_ids.len() {
            break;
        }
        start += STRIDE;
    }

    windows.sort_by_key(|(ids, _)| ids.len());

    let batch_size = crate::constants::ner_batch_size();
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    for chunk in windows.chunks(batch_size) {
        match model.predict_batch(chunk) {
            Ok(batch_labels) => {
                for (labels, (_, tokens)) in batch_labels.iter().zip(chunk.iter()) {
                    for ent in iob_to_entities(tokens, labels) {
                        if seen.insert(ent.name.clone()) {
                            entities.push(ent);
                        }
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    "batch NER failed (chunk of {} windows): {e:#}; falling back to single-window",
                    chunk.len()
                );
                for (ids, tokens) in chunk {
                    let mask = vec![1u32; ids.len()];
                    match model.predict(ids, &mask) {
                        Ok(labels) => {
                            for ent in iob_to_entities(tokens, &labels) {
                                if seen.insert(ent.name.clone()) {
                                    entities.push(ent);
                                }
                            }
                        }
                        Err(e2) => {
                            tracing::warn!("NER window fallback also failed: {e2:#}");
                        }
                    }
                }
            }
        }
    }

    Ok(entities)
}

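/// Extends an entity name with a short version suffix that immediately
/// follows its first occurrence in the body (e.g. "GPT" + "-4o" becomes
/// "GPT-4o"); suffixes longer than 7 bytes are ignored.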
fn extend_with_numeric_suffix(entities: Vec<ExtractedEntity>, body: &str) -> Vec<ExtractedEntity> {
    static SUFFIX_RE: OnceLock<Regex> = OnceLock::new();
    let suffix_re = SUFFIX_RE.get_or_init(|| {
        Regex::new(r"^([\-\s]+\d+(?:\.\d+)?[a-z]?)")
            .expect("compile-time validated numeric suffix regex literal")
    });

    entities
        .into_iter()
        .map(|ent| {
            if let Some(pos) = body.find(&ent.name) {
                let after_pos = pos + ent.name.len();
                if after_pos < body.len() {
                    let after = &body[after_pos..];
                    if let Some(m) = suffix_re.find(after) {
                        let suffix = m.as_str();
                        if suffix.len() <= 7 {
                            let extended = format!("{}{}", ent.name, suffix);
                            return ExtractedEntity {
                                name: extended,
                                entity_type: ent.entity_type,
                            };
                        }
                    }
                }
            }
            ent
        })
        .collect()
}

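/// Adds versioned model names matched directly in the body (e.g.
/// "GPT-4o", "Claude 4 Sonnet", "Mixtral 8x7B") as `concept` entities,
/// skipping names already present case-insensitively and stopping at
/// `MAX_ENTS`.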
fn augment_versioned_model_names(
    entities: Vec<ExtractedEntity>,
    body: &str,
) -> Vec<ExtractedEntity> {
    static VERSIONED_MODEL_RE: OnceLock<Regex> = OnceLock::new();
    let model_re = VERSIONED_MODEL_RE.get_or_init(|| {
        Regex::new(
            r"\b([A-Z][A-Za-z]{2,15})[\s\-]+(\d+(?:\.\d+)?(?:[a-z]|x\d+[A-Za-z]?)?)(?:\s+(?:Sonnet|Opus|Haiku|Turbo|Pro|Lite|Mini|Nano|Flash|Ultra))?\b",
        )
        .expect("compile-time validated versioned model regex literal")
    });

    let mut existing_lc: std::collections::HashSet<String> =
        entities.iter().map(|ent| ent.name.to_lowercase()).collect();
    let mut result = entities;

    for caps in model_re.captures_iter(body) {
        let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
        if full_match.is_empty() || full_match.len() > 24 {
            continue;
        }
        let normalized_lc = full_match.to_lowercase();
        if existing_lc.contains(&normalized_lc) {
            continue;
        }
        if result.len() >= MAX_ENTS {
            break;
        }
        existing_lc.insert(normalized_lc);
        result.push(ExtractedEntity {
            name: full_match.to_string(),
            entity_type: "concept".to_string(),
        });
    }

    result
}

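/// Merges regex and NER candidates. Names are NFKC-normalized and
/// lowercased; within the same entity type, a name that matches,
/// contains, or is contained by an existing one collapses into a single
/// entry that keeps the longer spelling. The merged list is truncated at
/// `MAX_ENTS`.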
fn merge_and_deduplicate(
    regex_ents: Vec<ExtractedEntity>,
    ner_ents: Vec<ExtractedEntity>,
) -> Vec<ExtractedEntity> {
    let mut by_lc: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
    let mut result: Vec<ExtractedEntity> = Vec::new();
    let mut truncated = false;

    let total_input = regex_ents.len() + ner_ents.len();
    for ent in regex_ents.into_iter().chain(ner_ents) {
        let name_lc = ent.name.nfkc().collect::<String>().to_lowercase();
        let key = format!("{}\0{}", ent.entity_type, name_lc);

        let mut collision_idx: Option<usize> = None;
        for (existing_key, idx) in &by_lc {
            let type_prefix = format!("{}\0", ent.entity_type);
            if !existing_key.starts_with(&type_prefix) {
                continue;
            }
            let existing_name_lc = &existing_key[type_prefix.len()..];
            if existing_name_lc == name_lc
                || existing_name_lc.contains(name_lc.as_str())
                || name_lc.contains(existing_name_lc)
            {
                collision_idx = Some(*idx);
                break;
            }
        }
        match collision_idx {
            Some(idx) => {
                if ent.name.len() > result[idx].name.len() {
                    let old_name_lc = result[idx].name.nfkc().collect::<String>().to_lowercase();
                    let old_key = format!("{}\0{}", result[idx].entity_type, old_name_lc);
                    by_lc.remove(&old_key);
                    result[idx] = ent;
                    by_lc.insert(key, idx);
                }
            }
            None => {
                by_lc.insert(key, result.len());
                result.push(ent);
            }
        }
        if result.len() >= MAX_ENTS {
            truncated = true;
            break;
        }
    }

    if truncated {
        tracing::warn!(
            "extraction truncated at {MAX_ENTS} entities (input had {total_input} candidates before deduplication)"
        );
    }

    result
}

fn to_new_entities(extracted: Vec<ExtractedEntity>) -> Vec<NewEntity> {
    extracted
        .into_iter()
        .map(|e| NewEntity {
            name: e.name,
            entity_type: e.entity_type,
            description: None,
        })
        .collect()
}

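/// Full extraction pipeline: regex prefilter, then BERT NER when the
/// model is available (degrading gracefully to regex-only), followed by
/// merge/dedup, numeric-suffix extension, versioned-model augmentation,
/// a final section-marker sweep, sentence-level relationship building,
/// and URL extraction.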
pub fn extract_graph_auto(body: &str, paths: &AppPaths) -> Result<ExtractionResult> {
    let regex_entities = apply_regex_prefilter(body);

    let mut bert_used = false;
    let ner_entities = match get_or_init_model(paths) {
        Some(model) => match run_ner_sliding_window(model, body, paths) {
            Ok(ents) => {
                bert_used = true;
                ents
            }
            Err(e) => {
                tracing::warn!("NER failed, falling back to regex only: {e:#}");
                Vec::new()
            }
        },
        None => Vec::new(),
    };

    let merged = merge_and_deduplicate(regex_entities, ner_entities);
    let extended = extend_with_numeric_suffix(merged, body);
    let with_models = augment_versioned_model_names(extended, body);
    let with_models: Vec<ExtractedEntity> = with_models
        .into_iter()
        .filter(|e| !regex_section_marker().is_match(&e.name))
        .collect();
    let entities = to_new_entities(with_models);
    let (relationships, relationships_truncated) =
        build_relationships_by_sentence_cooccurrence(body, &entities);

    let extraction_method = if bert_used {
        "bert+regex-batch".to_string()
    } else {
        "regex-only".to_string()
    };

    let urls = extract_urls(body);

    Ok(ExtractionResult {
        entities,
        relationships,
        relationships_truncated,
        extraction_method,
        urls,
    })
}

pub struct RegexExtractor;

impl Extractor for RegexExtractor {
    fn extract(&self, body: &str) -> Result<ExtractionResult> {
        let regex_entities = apply_regex_prefilter(body);
        let entities = to_new_entities(regex_entities);
        let (relationships, relationships_truncated) =
            build_relationships_by_sentence_cooccurrence(body, &entities);
        let urls = extract_urls(body);
        Ok(ExtractionResult {
            entities,
            relationships,
            relationships_truncated,
            extraction_method: "regex-only".to_string(),
            urls,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_paths() -> AppPaths {
        use std::path::PathBuf;
        AppPaths {
            db: PathBuf::from("/tmp/test.sqlite"),
            models: PathBuf::from("/tmp/test_models"),
        }
    }

    #[test]
    fn regex_email_captures_address() {
        let ents = apply_regex_prefilter("contact: someone@company.com for more info");
        assert!(ents
            .iter()
            .any(|e| e.name == "someone@company.com" && e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_filters_pt_rule_word() {
        let ents = apply_regex_prefilter("NUNCA do this. PROIBIDO use X. DEVE follow Y.");
        assert!(
            !ents.iter().any(|e| e.name == "NUNCA"),
            "NUNCA must be filtered as a stopword"
        );
        assert!(
            !ents.iter().any(|e| e.name == "PROIBIDO"),
            "PROIBIDO must be filtered"
        );
        assert!(
            !ents.iter().any(|e| e.name == "DEVE"),
            "DEVE must be filtered"
        );
    }

    #[test]
    fn regex_all_caps_accepts_underscored_constant() {
        let ents = apply_regex_prefilter("configure MAX_RETRY=3 and API_TIMEOUT=30");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "API_TIMEOUT"));
    }

    #[test]
    fn regex_all_caps_accepts_domain_acronym() {
        let ents = apply_regex_prefilter("OPENAI launched GPT-5 with NVIDIA H100");
        assert!(ents.iter().any(|e| e.name == "OPENAI"));
        assert!(ents.iter().any(|e| e.name == "NVIDIA"));
    }

    #[test]
    fn regex_url_does_not_appear_in_apply_regex_prefilter() {
        let ents = apply_regex_prefilter("see https://docs.rs/crate for details");
        assert!(
            !ents.iter().any(|e| e.name.starts_with("https://")),
            "URLs must not appear as entities after the P0-2 split"
        );
    }

    #[test]
    fn extract_urls_captures_https() {
        let urls = extract_urls("see https://docs.rs/crate for details");
        assert_eq!(urls.len(), 1);
        assert_eq!(urls[0].url, "https://docs.rs/crate");
        assert!(urls[0].offset > 0);
    }

    #[test]
    fn extract_urls_trim_sufixo_pontuacao() {
        let urls = extract_urls("link: https://example.com/path. fim");
        assert!(!urls.is_empty());
        assert!(
            !urls[0].url.ends_with('.'),
            "trailing dot must be stripped"
        );
    }

    #[test]
    fn extract_urls_dedupes_repeated() {
        let body = "https://example.com referenciado aqui e depois aqui https://example.com";
        let urls = extract_urls(body);
        assert_eq!(urls.len(), 1, "repeated URLs must be deduplicated");
    }

    #[test]
    fn regex_uuid_captura_identificador() {
        let ents = apply_regex_prefilter("id=550e8400-e29b-41d4-a716-446655440000 no sistema");
        assert!(ents.iter().any(|e| e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_captura_constante() {
        let ents = apply_regex_prefilter("configure MAX_RETRY e TIMEOUT_MS");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "TIMEOUT_MS"));
    }

    #[test]
    fn regex_all_caps_ignores_short_words() {
        let ents = apply_regex_prefilter("use AI em seu projeto");
        assert!(
            !ents.iter().any(|e| e.name == "AI"),
            "AI has only 2 chars and must be ignored"
        );
    }

    #[test]
    fn iob_decodes_per_to_person() {
        let tokens = vec![
            "John".to_string(),
            "Doe".to_string(),
            "trabalhou".to_string(),
        ];
        let labels = vec!["B-PER".to_string(), "I-PER".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents.len(), 1);
        assert_eq!(ents[0].entity_type, "person");
        assert!(ents[0].name.contains("John"));
    }

    #[test]
    fn iob_strip_subword_b_prefix() {
        let tokens = vec!["Open".to_string(), "##AI".to_string()];
        let labels = vec!["B-ORG".to_string(), "B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.iter().any(|e| e.name == "OpenAI" || e.name == "Open"),
            "should merge ##AI or discard"
        );
    }

    #[test]
    fn iob_subword_orphan_discards() {
        let tokens = vec!["##AI".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "orphan subword without an active entity must be discarded"
        );
    }

    #[test]
    fn iob_maps_date_to_date_v1025() {
        let tokens = vec!["January".to_string(), "2024".to_string()];
        let labels = vec!["B-DATE".to_string(), "I-DATE".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE must be emitted as an entity in v1.0.25"
        );
        assert_eq!(ents[0].entity_type, "date");
    }

    #[test]
    fn iob_maps_org_to_organization_v1025() {
        let tokens = vec!["Empresa".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "organization");
    }

    #[test]
    fn iob_maps_org_sdk_to_tool() {
        let tokens = vec!["tokio-sdk".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "tool");
    }

    #[test]
    fn iob_maps_loc_to_location_v1025() {
        let tokens = vec!["Brasil".to_string()];
        let labels = vec!["B-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "location");
    }

    #[test]
    fn build_relationships_respeitam_max_rels() {
        let entities: Vec<NewEntity> = (0..20)
            .map(|i| NewEntity {
                name: format!("entidade_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let (rels, truncated) = build_relationships(&entities);
        let max_rels = crate::constants::max_relationships_per_memory();
        assert!(rels.len() <= max_rels, "must respect max_rels={max_rels}");
        if rels.len() == max_rels {
            assert!(truncated, "truncated must be true when the cap is hit");
        }
    }

    #[test]
    fn build_relationships_without_duplicates() {
        let entities: Vec<NewEntity> = (0..5)
            .map(|i| NewEntity {
                name: format!("ent_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let (rels, _truncated) = build_relationships(&entities);
        let mut pares: std::collections::HashSet<(String, String)> =
            std::collections::HashSet::new();
        for r in &rels {
            let par = (r.source.clone(), r.target.clone());
            assert!(pares.insert(par), "duplicate pair found");
        }
    }

    #[test]
    fn merge_dedupes_by_lowercase_name() {
        let a = vec![ExtractedEntity {
            name: "Rust".to_string(),
            entity_type: "concept".to_string(),
        }];
        let b = vec![ExtractedEntity {
            name: "rust".to_string(),
            entity_type: "concept".to_string(),
        }];
        let merged = merge_and_deduplicate(a, b);
        assert_eq!(
            merged.len(),
            1,
            "rust and Rust with the same type are the same entity"
        );
    }

    #[test]
    fn regex_extractor_implements_trait() {
        let extractor = RegexExtractor;
        let result = extractor
            .extract("contato: dev@empresa.io e MAX_TIMEOUT configurado")
            .unwrap();
        assert!(!result.entities.is_empty());
    }

    #[test]
    fn extract_returns_ok_without_model() {
        let paths = make_paths();
        let body = "contato: teste@exemplo.com com MAX_RETRY=3";
        let result = extract_graph_auto(body, &paths).unwrap();
        assert!(result
            .entities
            .iter()
            .any(|e| e.name.contains("teste@exemplo.com")));
    }

    #[test]
    fn stopwords_filter_v1024_terms() {
        let body = "ACEITE ACK ACL BORDA CHECKLIST COMPLETED CONFIRME \
                    DEVEMOS DONE FIXED NEGUE PENDING PLAN PODEMOS RECUSE TOKEN VAMOS";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for word in &[
            "ACEITE",
            "ACK",
            "ACL",
            "BORDA",
            "CHECKLIST",
            "COMPLETED",
            "CONFIRME",
            "DEVEMOS",
            "DONE",
            "FIXED",
            "NEGUE",
            "PENDING",
            "PLAN",
            "PODEMOS",
            "RECUSE",
            "TOKEN",
            "VAMOS",
        ] {
            assert!(
                !names.contains(word),
                "v1.0.24 stopword {word} should be filtered but was found in entities"
            );
        }
    }

    #[test]
    fn dedup_normalizes_unicode_combining_marks() {
        let nfc = vec![ExtractedEntity {
            name: "Caf\u{e9}".to_string(),
            entity_type: "concept".to_string(),
        }];
        let nfd_name = "Cafe\u{301}".to_string();
        let nfd = vec![ExtractedEntity {
            name: nfd_name,
            entity_type: "concept".to_string(),
        }];
        let merged = merge_and_deduplicate(nfc, nfd);
        assert_eq!(
            merged.len(),
            1,
            "NFC 'Caf\\u{{e9}}' and NFD 'Cafe\\u{{301}}' must deduplicate to 1 entity after NFKC normalization"
        );
    }

    #[test]
    fn predict_batch_output_count_matches_input() {
        let w1_ids: Vec<u32> = vec![101, 100, 102];
        let w1_tok: Vec<String> = vec!["[CLS]".into(), "hello".into(), "[SEP]".into()];
        let w2_ids: Vec<u32> = vec![101, 100, 200, 300, 102];
        let w2_tok: Vec<String> = vec![
            "[CLS]".into(),
            "world".into(),
            "foo".into(),
            "bar".into(),
            "[SEP]".into(),
        ];
        let windows: Vec<(Vec<u32>, Vec<String>)> =
            vec![(w1_ids.clone(), w1_tok), (w2_ids.clone(), w2_tok)];

        let device = Device::Cpu;
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap();
        assert_eq!(max_len, 5, "max_len must be 5");

        let mut padded_ids: Vec<Tensor> = Vec::new();
        for (ids, _) in &windows {
            let len = ids.len();
            let pad_right = max_len - len;
            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &device).unwrap();
            let t = t.pad_with_zeros(0, 0, pad_right).unwrap();
            assert_eq!(
                t.dims(),
                &[max_len],
                "each window must have shape (max_len,) after padding"
            );
            padded_ids.push(t);
        }

        let stacked = Tensor::stack(&padded_ids, 0).unwrap();
        assert_eq!(
            stacked.dims(),
            &[2, max_len],
            "stack must produce (batch_size=2, max_len=5)"
        );

        let fake_logits_data: Vec<f32> = vec![0.0f32; 2 * max_len * 9];
        let fake_logits =
            Tensor::from_vec(fake_logits_data, (2usize, max_len, 9usize), &device).unwrap();
        for (i, (ids, _)) in windows.iter().enumerate() {
            let real_len = ids.len();
            let example = fake_logits.get(i).unwrap();
            let sliced = example.narrow(0, 0, real_len).unwrap();
            assert_eq!(
                sliced.dims(),
                &[real_len, 9],
                "narrow must keep only the {real_len} real tokens"
            );
        }
    }

    #[test]
    fn predict_batch_empty_windows_returns_empty() {
        let windows: Vec<(Vec<u32>, Vec<String>)> = vec![];
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        assert_eq!(max_len, 0, "zero windows → max_len 0");
        let result: Vec<Vec<String>> = if max_len == 0 {
            Vec::new()
        } else {
            unreachable!()
        };
        assert!(result.is_empty());
    }

    #[test]
    fn ner_batch_size_default_is_8() {
        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
        assert_eq!(crate::constants::ner_batch_size(), 8);
    }

    #[test]
    fn ner_batch_size_env_override_clamped() {
        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "64");
        assert_eq!(crate::constants::ner_batch_size(), 32, "must clamp to 32");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "0");
        assert_eq!(crate::constants::ner_batch_size(), 1, "must clamp to 1");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "4");
        assert_eq!(
            crate::constants::ner_batch_size(),
            4,
            "valid value preserved"
        );

        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
    }

    #[test]
    fn extraction_method_regex_only_unchanged() {
        let result = RegexExtractor.extract("contact: dev@acme.io").unwrap();
        assert_eq!(
            result.extraction_method, "regex-only",
            "RegexExtractor must return regex-only"
        );
    }

    #[test]
    fn extend_suffix_pure_numeric_unchanged() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-5 in the project");
        assert_eq!(
            result[0].name, "GPT-5",
            "purely numeric suffix must be extended"
        );
    }

    #[test]
    fn extend_suffix_alphanumeric_letter_after_digit() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-4o for advanced tasks");
        assert_eq!(result[0].name, "GPT-4o", "suffix '4o' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_b_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Llama".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "Llama-5b open-weight model");
        assert_eq!(result[0].name, "Llama-5b", "suffix '5b' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_x_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Mistral".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "testing Mistral-8x in production");
        assert_eq!(result[0].name, "Mistral-8x", "suffix '8x' must be accepted");
    }

    #[test]
    fn augment_versioned_gpt4o() {
        let result = augment_versioned_model_names(vec![], "using GPT-4o for analysis");
        assert!(
            result.iter().any(|e| e.name == "GPT-4o"),
            "GPT-4o must be captured by augment, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_claude_4_sonnet() {
        let result =
            augment_versioned_model_names(vec![], "best model: Claude 4 Sonnet released today");
        assert!(
            result.iter().any(|e| e.name == "Claude 4 Sonnet"),
            "Claude 4 Sonnet must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_llama_3_pro() {
        let result =
            augment_versioned_model_names(vec![], "fine-tuning com Llama 3 Pro localmente");
        assert!(
            result.iter().any(|e| e.name == "Llama 3 Pro"),
            "Llama 3 Pro must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_mixtral_8x7b() {
        let result =
            augment_versioned_model_names(vec![], "executando Mixtral 8x7B no servidor local");
        assert!(
            result.iter().any(|e| e.name == "Mixtral 8x7B"),
            "Mixtral 8x7B must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_does_not_duplicate_existing() {
        let existing = vec![ExtractedEntity {
            name: "Claude 4".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = augment_versioned_model_names(existing, "using Claude 4 in the project");
        let count = result.iter().filter(|e| e.name == "Claude 4").count();
        assert_eq!(count, 1, "Claude 4 must not be duplicated");
    }

    #[test]
    fn stopwords_filter_url_jwt_api_v1025() {
        let body = "We use URL, JWT, and API REST in our LLM-powered CLI via HTTP/HTTPS and UI.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for blocked in &[
            "URL", "JWT", "API", "REST", "LLM", "CLI", "HTTP", "HTTPS", "UI",
        ] {
            assert!(
                !names.contains(blocked),
                "v1.0.25 stopword {blocked} leaked as entity; found names: {names:?}"
            );
        }
    }

    #[test]
    fn section_markers_etapa_fase_filtered_v1025() {
        let body = "Etapa 3 do plano: implementar Fase 1 da Migra\u{e7}\u{e3}o.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Etapa") || e.name.contains("Fase")),
            "section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn section_markers_passo_secao_filtered_v1025() {
        let body = "Siga Passo 2 conforme Se\u{e7}\u{e3}o 3 do manual.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Passo") || e.name.contains("Se\u{e7}\u{e3}o")),
            "Passo/Se\\u{{e7}}\\u{{e3}}o section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn brand_camelcase_extracted_as_organization_v1025() {
        let body = "OpenAI launched GPT-4 and PostgreSQL added pgvector.";
        let ents = apply_regex_prefilter(body);
        let openai = ents.iter().find(|e| e.name == "OpenAI");
        assert!(
            openai.is_some(),
            "OpenAI must be extracted by CamelCase brand regex; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
        assert_eq!(
            openai.unwrap().entity_type,
            "organization",
            "brand CamelCase must map to organization (V008)"
        );
    }

    #[test]
    fn brand_postgresql_extracted_as_organization_v1025() {
        let body = "migrating from MySQL to PostgreSQL for better performance.";
        let ents = apply_regex_prefilter(body);
        assert!(
            ents.iter()
                .any(|e| e.name == "PostgreSQL" && e.entity_type == "organization"),
            "PostgreSQL must be extracted as organization; entities: {:?}",
            ents.iter()
                .map(|e| (&e.name, &e.entity_type))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn iob_org_maps_to_organization_not_project_v1025() {
        let tokens = vec!["Microsoft".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type, "organization",
            "B-ORG must map to organization (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_loc_maps_to_location_not_concept_v1025() {
        let tokens = vec!["S\u{e3}o".to_string(), "Paulo".to_string()];
        let labels = vec!["B-LOC".to_string(), "I-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type, "location",
            "B-LOC must map to location (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_date_maps_to_date_not_discarded_v1025() {
        let tokens = vec!["2025".to_string(), "-".to_string(), "12".to_string()];
        let labels = vec![
            "B-DATE".to_string(),
            "I-DATE".to_string(),
            "I-DATE".to_string(),
        ];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE entity must be emitted (V008); entities: {ents:?}"
        );
        assert_eq!(ents[0].entity_type, "date");
    }

    #[test]
    fn pt_verb_le_filtered_as_per_v1025() {
        let tokens = vec!["L\u{ea}".to_string(), "o".to_string(), "livro".to_string()];
        let labels = vec!["B-PER".to_string(), "O".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents
                .iter()
                .any(|e| e.name == "L\u{ea}" && e.entity_type == "person"),
            "PT verb 'L\\u{{ea}}' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn pt_verb_ver_filtered_as_per_v1025() {
        let tokens = vec!["Ver".to_string()];
        let labels = vec!["B-PER".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "PT verb 'Ver' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    fn entity(name: &str, entity_type: &str) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_string(),
            entity_type: entity_type.to_string(),
        }
    }

    #[test]
    fn merge_resolves_sonne_vs_sonnet_keeps_longest_v1025() {
        let regex = vec![entity("Sonne", "concept")];
        let ner = vec![entity("Sonnet", "concept")];
        let result = merge_and_deduplicate(regex, ner);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "Sonnet");
    }

    #[test]
    fn merge_resolves_open_vs_openai_keeps_longest_v1025() {
        let regex = vec![
            entity("Open", "organization"),
            entity("OpenAI", "organization"),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "OpenAI");
    }

    #[test]
    fn merge_keeps_both_when_no_containment_v1025() {
        let regex = vec![entity("Alice", "person"), entity("Bob", "person")];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 2, "expected 2 entities, got: {result:?}");
    }

    #[test]
    fn merge_respects_entity_type_boundary_v1025() {
        let regex = vec![entity("Apple", "organization"), entity("Apple", "concept")];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            2,
            "expected 2 entities (different types), got: {result:?}"
        );
    }

    #[test]
    fn merge_case_insensitive_dedup_v1025() {
        let regex = vec![
            entity("OpenAI", "organization"),
            entity("openai", "organization"),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            1,
            "expected 1 entity after case-insensitive dedup, got: {result:?}"
        );
    }

    #[test]
    fn iob_section_marker_etapa_filtered_v1025() {
        let tokens = vec!["Etapa".to_string(), "3".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Etapa")),
            "section marker 'Etapa 3' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn iob_section_marker_fase_filtered_v1025() {
        let tokens = vec!["Fase".to_string(), "1".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Fase")),
            "section marker 'Fase 1' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn extract_graph_auto_handles_large_body_under_30s() {
        let body = "x ".repeat(40_000);
        let paths = make_paths();
        let start = std::time::Instant::now();
        let result = extract_graph_auto(&body, &paths).expect("extraction must not error");
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_secs() < 30,
            "extract_graph_auto took {}s for 80 KB body (cap should keep it well under 30s)",
            elapsed.as_secs()
        );
        let _ = result.entities;
    }

    #[test]
    fn pt_uppercase_stopwords_filtered_v1031() {
        let body = "Para o ADAPTER funcionar com PROJETO em modo PASSIVA, devemos usar \
                    SOMENTE LEITURA conforme a REGRA OBRIGATORIA do EXEMPLO DEFAULT.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<String> = ents.iter().map(|e| e.name.to_uppercase()).collect();
        for stop in &[
            "ADAPTER",
            "PROJETO",
            "PASSIVA",
            "SOMENTE",
            "LEITURA",
            "REGRA",
            "OBRIGATORIA",
            "EXEMPLO",
            "DEFAULT",
        ] {
            assert!(
                !names.contains(&stop.to_string()),
                "v1.0.31 A11 stoplist failed: {stop} leaked as entity; got names: {names:?}"
            );
        }
    }

    #[test]
    fn pt_underscored_identifier_preserved_v1031() {
        let ents = apply_regex_prefilter("configure FLOWAIPER_API_KEY=foo and MAX_TIMEOUT=30");
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"FLOWAIPER_API_KEY"));
        assert!(names.contains(&"MAX_TIMEOUT"));
    }

    #[test]
    fn build_relationships_by_sentence_only_links_co_occurring_entities() {
        let body = "Alice met Bob at the conference. Carol works alone in another room.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
            NewEntity {
                name: "Carol".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
        ];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(!truncated);
        assert_eq!(
            rels.len(),
            1,
            "only Alice/Bob should pair (same sentence); Carol is isolated"
        );
        let pair = (rels[0].source.as_str(), rels[0].target.as_str());
        assert!(
            matches!(pair, ("Alice", "Bob") | ("Bob", "Alice")),
            "unexpected pair {pair:?}"
        );
    }

    #[test]
    fn build_relationships_by_sentence_returns_empty_for_single_entity() {
        let body = "Alice is here.";
        let entities = vec![NewEntity {
            name: "Alice".to_string(),
            entity_type: "person".to_string(),
            description: None,
        }];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(rels.is_empty());
        assert!(!truncated);
    }

    #[test]
    fn build_relationships_by_sentence_dedupes_pairs_across_sentences() {
        let body = "Alice met Bob. Bob saw Alice again.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: "person".to_string(),
                description: None,
            },
        ];
        let (rels, _) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert_eq!(
            rels.len(),
            1,
            "Alice/Bob pair must be emitted only once even when co-occurring in multiple sentences"
        );
    }

    #[test]
    fn extraction_max_tokens_default_is_5000() {
        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
        assert_eq!(crate::constants::extraction_max_tokens(), 5_000);
    }

    #[test]
    fn extraction_max_tokens_env_override_clamped() {
        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value below 512 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value above 100_000 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "8000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            8_000,
            "valid value must be honoured"
        );

        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
    }
}