1use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::OnceLock;
9
10use anyhow::{Context, Result};
11use candle_core::{DType, Device, Tensor};
12use candle_nn::{Linear, Module, VarBuilder};
13use candle_transformers::models::bert::{BertModel, Config as BertConfig};
14use regex::Regex;
15use serde::Deserialize;
16use unicode_normalization::UnicodeNormalization;
17
18use crate::entity_type::EntityType;
19use crate::paths::AppPaths;
20use crate::storage::entities::{NewEntity, NewRelationship};
21
/// Hugging Face checkpoint used for multilingual NER (IOB tagging).
const MODEL_ID: &str = "Davlan/bert-base-multilingual-cased-ner-hrl";
/// Hard per-window input limit, in tokens.
const MAX_SEQ_LEN: usize = 512;
/// Step between consecutive sliding windows (half a window = 50% overlap).
const STRIDE: usize = 256;
/// Cap on entities kept per extraction after merging/dedup.
const MAX_ENTS: usize = 30;
/// Max edges per source entity in the legacy pairwise builder (tests only).
#[cfg(test)]
const TOP_K_RELATIONS: usize = 5;
/// Relation label used for all co-occurrence edges.
const DEFAULT_RELATION: &str = "mentions";
/// Minimum entity-name length (bytes, via `str::len`) to be kept.
const MIN_ENTITY_CHARS: usize = 2;

// Lazily-compiled regexes; each is built once on first use by its getter.
static REGEX_EMAIL: OnceLock<Regex> = OnceLock::new();
static REGEX_URL: OnceLock<Regex> = OnceLock::new();
static REGEX_UUID: OnceLock<Regex> = OnceLock::new();
static REGEX_ALL_CAPS: OnceLock<Regex> = OnceLock::new();
static REGEX_SECTION_MARKER: OnceLock<Regex> = OnceLock::new();
static REGEX_BRAND_CAMEL: OnceLock<Regex> = OnceLock::new();
41
// All-caps tokens matched by REGEX_ALL_CAPS that are noise, not entities:
// Portuguese rule/imperative words (NUNCA, DEVE, ...), generic tech acronyms
// (API, JSON, ...), and document-structure words (CAPÍTULO, SEÇÃO, ...).
// Kept alphabetically sorted for easy diffing; lookup is `slice::contains`
// (linear scan), so the order is purely for maintainability.
const ALL_CAPS_STOPWORDS: &[&str] = &[
    "ACEITE",
    "ACID",
    "ACK",
    "ACL",
    "ACRESCENTADO",
    "ADAPTER",
    "ADICIONADA",
    "ADICIONADAS",
    "ADICIONADO",
    "ADICIONADOS",
    "ADICIONAR",
    "AGENTS",
    "AINDA",
    "ALL",
    "ALTA",
    "ALWAYS",
    "APENAS",
    "API",
    "ARTEFATOS",
    "ATIVA",
    "ATIVO",
    "BAIXA",
    "BANCO",
    "BLOQUEAR",
    "BORDA",
    "BUG",
    "CAPÍTULO",
    "CASO",
    "CEO",
    "CHECKLIST",
    "CLARO",
    "CLAUDE_STREAM_IDLE_TIMEOUT_MS",
    "CLI",
    "COMPLETED",
    "CONFIRMADO",
    "CONFIRMARAM",
    "CONFIRME",
    "CONFIRMEI",
    "CONFIRMOU",
    "CONTRATO",
    "CRIE",
    "CRÍTICO",
    "CRITICAL",
    "CSV",
    "DDL",
    "DEFAULT",
    "DEFINIR",
    "DEPARTMENT",
    "DESC",
    "DEVE",
    "DEVEMOS",
    "DISCO",
    "DONE",
    "DSL",
    "DTO",
    "EFEITO",
    "ENTRADA",
    "EOF",
    "EPERM",
    "ERROR",
    "ESCREVA",
    "ESCRITA",
    "ESRCH",
    "ESSA",
    "ESSE",
    "ESSENCIAL",
    "ESTA",
    "ESTADO",
    "ESTE",
    "ETAPA",
    "EVITAR",
    "EXEMPLO",
    "EXPANDIR",
    "EXPOR",
    "FALHA",
    "FASE",
    "FATO",
    "FIFO",
    "FIXED",
    "FIXME",
    "FLUXO",
    "FONTES",
    "FORBIDDEN",
    "FUNCIONA",
    "GNU",
    "HACK",
    "HEARTBEAT",
    "HTTP",
    "HTTPS",
    "INATIVO",
    "JAMAIS",
    "JSON",
    "JWT",
    "LEITURA",
    "LLM",
    "MCP",
    "MESMO",
    "METADADOS",
    "MUST",
    "NDJSON",
    "NEGUE",
    "NEVER",
    "NOTE",
    "NUNCA",
    "OBRIGATORIA",
    "OBRIGATÓRIO",
    "OBSERVEI",
    "PADRÃO",
    "PASSIVA",
    "PASSO",
    "PENDING",
    "PGID",
    "PID",
    "PLAN",
    "PODEMOS",
    "PONTEIROS",
    "PREFERIR",
    "PROIBIDO",
    "PROJETO",
    "RECUSE",
    "REGRA",
    "REGRAS",
    "REMOVIDAS",
    "REQUIRED",
    "REQUISITO",
    "REST",
    "SEÇÃO",
    "SEMPRE",
    "SHALL",
    "SHOULD",
    "SIGTERM",
    "SOMENTE",
    "SOUL",
    "TODAS",
    "TODO",
    "TODOS",
    "TOKEN",
    "TOOLS",
    "TSV",
    "TUI",
    "UI",
    "URL",
    "USAR",
    "VALIDAR",
    "VAMOS",
    "VOCÊ",
    "WARNING",
    "XML",
    "YAML",
];
210
// HTTP request verbs: all-caps tokens that are never entities on their own.
const HTTP_METHODS: &[&str] = &[
    "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "CONNECT", "TRACE",
];
216
217fn is_filtered_all_caps(token: &str) -> bool {
218 let is_identifier = token.contains('_');
220 if is_identifier {
221 return false;
222 }
223 ALL_CAPS_STOPWORDS.contains(&token) || HTTP_METHODS.contains(&token)
224}
225
226fn regex_email() -> &'static Regex {
227 REGEX_EMAIL.get_or_init(|| {
229 Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
230 .expect("compile-time validated email regex literal")
231 })
232}
233
234fn regex_url() -> &'static Regex {
235 REGEX_URL.get_or_init(|| {
237 Regex::new(r#"https?://[^\s\)\]\}"'<>]+"#)
238 .expect("compile-time validated URL regex literal")
239 })
240}
241
242fn regex_uuid() -> &'static Regex {
243 REGEX_UUID.get_or_init(|| {
245 Regex::new(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}")
246 .expect("compile-time validated UUID regex literal")
247 })
248}
249
250fn regex_all_caps() -> &'static Regex {
251 REGEX_ALL_CAPS.get_or_init(|| {
252 Regex::new(r"\b[A-Z][A-Z0-9_]{2,}\b")
253 .expect("compile-time validated all-caps regex literal")
254 })
255}
256
257fn regex_section_marker() -> &'static Regex {
258 REGEX_SECTION_MARKER.get_or_init(|| {
259 Regex::new("\\b(?:Etapa|Fase|Passo|Camada|Se\u{00e7}\u{00e3}o|Cap\u{00ed}tulo)\\s+\\d+\\b")
266 .expect("compile-time validated section marker regex literal")
267 })
268}
269
270fn regex_brand_camel() -> &'static Regex {
271 REGEX_BRAND_CAMEL.get_or_init(|| {
272 Regex::new(r"\b[A-Z][a-z]+[A-Z][A-Za-z]+\b")
275 .expect("compile-time validated CamelCase brand regex literal")
276 })
277}
278
/// A named-entity candidate produced by the regex prefilter or the NER
/// model, prior to conversion into a storage row.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    pub name: String,
    pub entity_type: EntityType,
}
284
/// A URL found in the body text.
#[derive(Debug, Clone)]
pub struct ExtractedUrl {
    pub url: String,
    // Byte offset of the raw regex match start within the original body.
    pub offset: usize,
}
292
/// Everything one extraction pass produces for a memory body.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    pub entities: Vec<NewEntity>,
    pub relationships: Vec<NewRelationship>,
    // True when relationship generation stopped at the configured cap.
    pub relationships_truncated: bool,
    // Which pipeline produced this: "bert+regex-batch" or "regex-only".
    pub extraction_method: String,
    pub urls: Vec<ExtractedUrl>,
}
306
/// Strategy interface for turning a memory body into an [`ExtractionResult`].
pub trait Extractor: Send + Sync {
    /// Extracts entities, relationships and URLs from `body`.
    fn extract(&self, body: &str) -> Result<ExtractionResult>;
}
310
/// Minimal slice of the Hugging Face `config.json` this module needs.
#[derive(Deserialize)]
struct ModelConfig {
    // Label id (stringified integer in the JSON) -> IOB tag name.
    #[serde(default)]
    id2label: HashMap<String, String>,
    // Transformer hidden dimension; sizes the classifier head.
    hidden_size: usize,
}
317
/// BERT encoder plus a linear token-classification head, run on CPU.
struct BertNerModel {
    bert: BertModel,
    classifier: Linear,
    device: Device,
    // Label id -> IOB tag (e.g. "B-PER", "I-ORG", "O").
    id2label: HashMap<usize, String>,
}
324
impl BertNerModel {
    /// Loads the fine-tuned token-classification model from `model_dir`.
    ///
    /// Expects `config.json` and `model.safetensors` there. The config file
    /// is parsed twice: once as [`ModelConfig`] (for `id2label` and
    /// `hidden_size`) and once as candle's [`BertConfig`] for the encoder.
    fn load(model_dir: &Path) -> Result<Self> {
        let config_path = model_dir.join("config.json");
        let weights_path = model_dir.join("model.safetensors");

        let config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("lendo config.json em {config_path:?}"))?;
        let model_cfg: ModelConfig =
            serde_json::from_str(&config_str).context("parseando config.json do modelo NER")?;

        // id2label keys arrive as JSON strings ("0", "1", ...); any
        // non-numeric key is silently dropped.
        let id2label: HashMap<usize, String> = model_cfg
            .id2label
            .into_iter()
            .filter_map(|(k, v)| k.parse::<usize>().ok().map(|n| (n, v)))
            .collect();

        // Floor of 9 labels when the config ships no/partial id2label —
        // presumably O + B/I x {PER, ORG, LOC, MISC}; confirm against the
        // published checkpoint config if MODEL_ID changes.
        let num_labels = id2label.len().max(9);
        let hidden_size = model_cfg.hidden_size;

        // Re-read the same file so serde can build a second, differently
        // typed view of it.
        let bert_config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("relendo config.json para bert em {config_path:?}"))?;
        let bert_cfg: BertConfig =
            serde_json::from_str(&bert_config_str).context("parseando BertConfig")?;

        // CPU-only inference.
        let device = Device::Cpu;

        // SAFETY: memory-maps the safetensors file. Sound as long as the
        // cached file is not truncated or rewritten while mapped — it lives
        // in the app's private model cache; confirm no concurrent writer.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
                .with_context(|| format!("mapping {weights_path:?}"))?
        };
        let bert = BertModel::load(vb.pp("bert"), &bert_cfg).context("loading BertModel")?;

        // The classification head is not part of candle's BertModel, so its
        // weight/bias tensors are pulled out of the checkpoint by hand.
        let cls_vb = vb.pp("classifier");
        let weight = cls_vb
            .get((num_labels, hidden_size), "weight")
            .context("carregando classifier.weight do safetensors")?;
        let bias = cls_vb
            .get(num_labels, "bias")
            .context("carregando classifier.bias do safetensors")?;
        let classifier = Linear::new(weight, Some(bias));

        Ok(Self {
            bert,
            classifier,
            device,
            id2label,
        })
    }

    /// Runs a single window through the model and returns one IOB label per
    /// input token (argmax over the classifier logits).
    fn predict(&self, token_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<String>> {
        let len = token_ids.len();
        // candle expects i64 ids; widen from the tokenizer's u32.
        let ids_i64: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
        let mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        let input_ids = Tensor::from_vec(ids_i64, (1, len), &self.device)
            .context("creating tensor input_ids")?;
        // Single-segment input: all token type ids are zero.
        let token_type_ids = Tensor::zeros((1, len), DType::I64, &self.device)
            .context("creating tensor token_type_ids")?;
        let attn_mask = Tensor::from_vec(mask_i64, (1, len), &self.device)
            .context("creating tensor attention_mask")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel forward pass")?;

        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("classifier forward pass")?;

        let logits_2d = logits.squeeze(0).context("removing batch dimension")?;

        let num_tokens = logits_2d.dim(0).context("dim(0)")?;

        let mut labels = Vec::with_capacity(num_tokens);
        for i in 0..num_tokens {
            let token_logits = logits_2d.get(i).context("get token logits")?;
            let vec: Vec<f32> = token_logits.to_vec1().context("to_vec1 logits")?;
            // Plain argmax; NaN logits compare Equal and effectively lose.
            let argmax = vec
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            // Unknown label ids degrade to "O" (outside any entity).
            let label = self
                .id2label
                .get(&argmax)
                .cloned()
                .unwrap_or_else(|| "O".to_string())
            ;
            labels.push(label);
        }

        Ok(labels)
    }

    /// Runs several windows in one padded forward pass.
    ///
    /// Each element of `windows` is `(token_ids, tokens)`; the token strings
    /// are ignored here and only carried along for the caller. Returns one
    /// label vector per window, trimmed back to the window's unpadded length.
    fn predict_batch(&self, windows: &[(Vec<u32>, Vec<String>)]) -> Result<Vec<Vec<String>>> {
        let batch_size = windows.len();
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        if max_len == 0 {
            return Ok(vec![vec![]; batch_size]);
        }

        let mut padded_ids: Vec<Tensor> = Vec::with_capacity(batch_size);
        let mut padded_masks: Vec<Tensor> = Vec::with_capacity(batch_size);

        for (ids, _) in windows {
            let len = ids.len();
            let pad_right = max_len - len;

            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &self.device)
                .context("creating id tensor for batch")?;
            // Right-pad ids with zeros up to the longest window in the batch.
            let t = t
                .pad_with_zeros(0, 0, pad_right)
                .context("padding id tensor")?;
            padded_ids.push(t);

            // Mask: 1 for real tokens, 0 for padding, so attention ignores pads.
            let mut mask_i64 = vec![1i64; len];
            mask_i64.extend(vec![0i64; pad_right]);
            let m = Tensor::from_vec(mask_i64, max_len, &self.device)
                .context("creating mask tensor for batch")?;
            padded_masks.push(m);
        }

        let input_ids = Tensor::stack(&padded_ids, 0).context("stack input_ids")?;
        let attn_mask = Tensor::stack(&padded_masks, 0).context("stack attn_mask")?;
        let token_type_ids = Tensor::zeros((batch_size, max_len), DType::I64, &self.device)
            .context("creating token_type_ids tensor for batch")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("BertModel batch forward pass")?;
        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("forward pass batch classificador")?;
        let mut results = Vec::with_capacity(batch_size);
        for (i, (window_ids, _)) in windows.iter().enumerate() {
            let example_logits = logits.get(i).context("get logits exemplo")?;
            // Drop logit rows that correspond to padding before the argmax.
            let real_len = window_ids.len();
            let example_slice = example_logits
                .narrow(0, 0, real_len)
                .context("narrow para tokens reais")?;
            let logits_2d: Vec<Vec<f32>> = example_slice.to_vec2().context("to_vec2 logits")?;

            // Same per-token argmax + "O" fallback as in `predict`.
            let labels: Vec<String> = logits_2d
                .iter()
                .map(|token_logits| {
                    let argmax = token_logits
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .map(|(idx, _)| idx)
                        .unwrap_or(0);
                    self.id2label
                        .get(&argmax)
                        .cloned()
                        .unwrap_or_else(|| "O".to_string())
                })
                .collect();

            results.push(labels);
        }

        Ok(results)
    }
}
522
523static NER_MODEL: OnceLock<Option<BertNerModel>> = OnceLock::new();
524
525fn get_or_init_model(paths: &AppPaths) -> Option<&'static BertNerModel> {
526 NER_MODEL
527 .get_or_init(|| match load_model(paths) {
528 Ok(m) => Some(m),
529 Err(e) => {
530 tracing::warn!("NER model unavailable (graceful degradation): {e:#}");
531 None
532 }
533 })
534 .as_ref()
535}
536
537fn model_dir(paths: &AppPaths) -> PathBuf {
538 paths.models.join("bert-multilingual-ner")
539}
540
/// Makes sure the NER model files exist in the local cache, downloading any
/// missing ones from the Hugging Face Hub (first run only). Returns the
/// model directory.
///
/// The early return checks only the three files inference actually needs;
/// `tokenizer_config.json` is fetched opportunistically on first download
/// but its absence alone never re-triggers a download.
fn ensure_model_files(paths: &AppPaths) -> Result<PathBuf> {
    let dir = model_dir(paths);
    std::fs::create_dir_all(&dir).with_context(|| format!("creating model directory: {dir:?}"))?;

    let weights = dir.join("model.safetensors");
    let config = dir.join("config.json");
    let tokenizer = dir.join("tokenizer.json");

    if weights.exists() && config.exists() && tokenizer.exists() {
        return Ok(dir);
    }

    tracing::info!("Downloading NER model (first run, ~676 MB)...");
    crate::output::emit_progress_i18n(
        "Downloading NER model (first run, ~676 MB)...",
        crate::i18n::validation::runtime_pt::downloading_ner_model(),
    );

    let api = huggingface_hub::api::sync::Api::new().context("creating HF Hub client")?;
    let repo = api.model(MODEL_ID.to_string());

    // (remote path in repo, local filename). The tokenizer is pulled from
    // the repo's onnx/ folder — presumably this checkpoint has no root
    // tokenizer.json; confirm if the repo layout ever changes.
    for (remote, local) in &[
        ("model.safetensors", "model.safetensors"),
        ("config.json", "config.json"),
        ("onnx/tokenizer.json", "tokenizer.json"),
        ("tokenizer_config.json", "tokenizer_config.json"),
    ] {
        let dest = dir.join(local);
        // Per-file existence check lets an interrupted download resume.
        if !dest.exists() {
            let src = repo
                .get(remote)
                .with_context(|| format!("baixando {remote} do HF Hub"))?;
            std::fs::copy(&src, &dest).with_context(|| format!("copiando {local} para cache"))?;
        }
    }

    Ok(dir)
}
582
583fn load_model(paths: &AppPaths) -> Result<BertNerModel> {
584 let dir = ensure_model_files(paths)?;
585 BertNerModel::load(&dir)
586}
587
/// Hashes a string with the std default hasher. Stable within one process
/// run; used only for in-memory dedup of entity names.
#[inline]
fn hash_str(s: &str) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::default();
    s.hash(&mut hasher);
    hasher.finish()
}
596
597fn apply_regex_prefilter(body: &str) -> Vec<ExtractedEntity> {
598 let mut entities = Vec::with_capacity(16);
599 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
600
601 let add = |entities: &mut Vec<ExtractedEntity>,
602 seen: &mut std::collections::HashSet<String>,
603 name: &str,
604 entity_type: EntityType| {
605 let name = name.trim().to_string();
606 if name.len() >= MIN_ENTITY_CHARS && seen.insert(name.clone()) {
607 entities.push(ExtractedEntity { name, entity_type });
608 }
609 };
610
611 let cleaned = regex_section_marker().replace_all(body, " ");
614 let cleaned = cleaned.as_ref();
615
616 for m in regex_email().find_iter(cleaned) {
617 add(&mut entities, &mut seen, m.as_str(), EntityType::Concept);
619 }
620 for m in regex_uuid().find_iter(cleaned) {
621 add(&mut entities, &mut seen, m.as_str(), EntityType::Concept);
622 }
623 for m in regex_all_caps().find_iter(cleaned) {
624 let candidate = m.as_str();
625 if !is_filtered_all_caps(candidate) {
627 add(&mut entities, &mut seen, candidate, EntityType::Concept);
628 }
629 }
630 for m in regex_brand_camel().find_iter(cleaned) {
633 let name = m.as_str();
634 if !ALL_CAPS_STOPWORDS.contains(&name.to_uppercase().as_str()) {
636 add(&mut entities, &mut seen, name, EntityType::Organization);
637 }
638 }
639
640 entities
641}
642
643pub fn extract_urls(body: &str) -> Vec<ExtractedUrl> {
647 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
648 let mut result = Vec::with_capacity(4);
649 for m in regex_url().find_iter(body) {
650 let raw = m.as_str();
651 let cleaned = raw
652 .trim_end_matches('`')
653 .trim_end_matches(',')
654 .trim_end_matches('.')
655 .trim_end_matches(';')
656 .trim_end_matches(')')
657 .trim_end_matches(']')
658 .trim_end_matches('}');
659 if cleaned.len() >= 10 && seen.insert(cleaned.to_string()) {
660 result.push(ExtractedUrl {
661 url: cleaned.to_string(),
662 offset: m.start(),
663 });
664 }
665 }
666 result
667}
668
/// Collapses parallel WordPiece tokens and predicted IOB labels into
/// entity spans.
///
/// "##" continuation pieces are glued onto the previous word; whole words
/// inside a span are joined with single spaces. Finished spans pass through
/// the same stopword/section-marker filters as the regex path.
fn iob_to_entities(tokens: &[String], labels: &[String]) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::with_capacity(tokens.len() / 4);
    let mut current_parts: Vec<String> = Vec::new();
    let mut current_type: Option<EntityType> = None;

    // Finalizes the in-progress span. `current_parts` is only ever non-empty
    // while `current_type` is Some, so clearing inside the `if let` suffices.
    let flush = |parts: &mut Vec<String>,
                 typ: &mut Option<EntityType>,
                 entities: &mut Vec<ExtractedEntity>| {
        if let Some(t) = typ.take() {
            let name = parts.join(" ").trim().to_string();
            // Single all-caps word: subject to the shared stopword filter.
            let is_single_caps = !name.contains(' ')
                && name == name.to_uppercase()
                && name.len() >= MIN_ENTITY_CHARS;
            let should_skip = is_single_caps && is_filtered_all_caps(&name);
            let is_section_marker = regex_section_marker().is_match(&name);
            if name.len() >= MIN_ENTITY_CHARS && !should_skip && !is_section_marker {
                entities.push(ExtractedEntity {
                    name,
                    entity_type: t,
                });
            }
            parts.clear();
        }
    };

    for (token, label) in tokens.iter().zip(labels.iter()) {
        // "O" = outside any entity: close whatever span is open.
        if label == "O" {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        }

        let (prefix, bio_type) = if let Some(rest) = label.strip_prefix("B-") {
            ("B", rest)
        } else if let Some(rest) = label.strip_prefix("I-") {
            ("I", rest)
        } else {
            // Unknown tag scheme: treat like "O".
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        };

        // Capitalized Portuguese verbs/short words the PER tagger commonly
        // mislabels as person names.
        const PT_VERB_FALSE_POSITIVES: &[&str] = &[
            "L\u{00ea}",
            "V\u{00ea}",
            "C\u{00e1}",
            "P\u{00f4}r",
            "Ser",
            "Vir",
            "Ver",
            "Dar",
            "Ler",
            "Ter",
        ];

        let entity_type: EntityType = match bio_type {
            // NOTE(review): this checkpoint appears to tag PER/ORG/LOC only;
            // DATE is handled defensively in case the label set changes.
            "DATE" => EntityType::Date,
            "PER" => {
                if PT_VERB_FALSE_POSITIVES.contains(&token.as_str()) {
                    flush(&mut current_parts, &mut current_type, &mut entities);
                    continue;
                }
                EntityType::Person
            }
            "ORG" => {
                // Heuristic: ORG tokens that look like software artifacts
                // are reclassified as tools.
                let t = token.to_lowercase();
                if t.contains("lib")
                    || t.contains("sdk")
                    || t.contains("cli")
                    || t.contains("crate")
                    || t.contains("npm")
                {
                    EntityType::Tool
                } else {
                    EntityType::Organization
                }
            }
            "LOC" => EntityType::Location,
            _ => EntityType::Concept,
        };

        if prefix == "B" {
            if token.starts_with("##") {
                // A "B-" on a continuation piece: glue it onto the previous
                // part; if no span is open it is silently dropped.
                let clean = token.strip_prefix("##").unwrap_or(token.as_str());
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
                continue;
            }
            flush(&mut current_parts, &mut current_type, &mut entities);
            current_parts.push(token.clone());
            current_type = Some(entity_type);
        } else if prefix == "I" && current_type.is_some() {
            let clean = token.strip_prefix("##").unwrap_or(token.as_str());
            if token.starts_with("##") {
                // Subword continuation: append without a space.
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
            } else {
                current_parts.push(clean.to_string());
            }
        }
    }

    // Close the last span, if any.
    flush(&mut current_parts, &mut current_type, &mut entities);
    entities
}
794
/// Legacy pairwise relationship builder (compiled for tests only).
///
/// Links each of the first MAX_ENTS entities to at most TOP_K_RELATIONS
/// later entities with the default "mentions" relation, stopping at the
/// global `max_relationships_per_memory()` cap. Returns the edges plus a
/// flag saying whether the cap was hit.
#[cfg(test)]
fn build_relationships(entities: &[NewEntity]) -> (Vec<NewRelationship>, bool) {
    // No pairs possible.
    if entities.len() < 2 {
        return (Vec::new(), false);
    }

    let max_rels = crate::constants::max_relationships_per_memory();
    let n = entities.len().min(MAX_ENTS);
    let mut rels: Vec<NewRelationship> = Vec::new();
    // With i < j below, every pair is generated at most once, so this set
    // never actually rejects — kept for symmetry with the sentence-level
    // builder.
    let mut seen: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();

    let mut hit_cap = false;
    'outer: for i in 0..n {
        if rels.len() >= max_rels {
            hit_cap = true;
            break;
        }

        // Edges emitted for this source entity.
        let mut for_entity = 0usize;
        for j in (i + 1)..n {
            if for_entity >= TOP_K_RELATIONS {
                break;
            }
            if rels.len() >= max_rels {
                hit_cap = true;
                break 'outer;
            }

            let key = (i.min(j), i.max(j));
            if !seen.insert(key) {
                continue;
            }

            rels.push(NewRelationship {
                source: entities[i].name.clone(),
                target: entities[j].name.clone(),
                relation: DEFAULT_RELATION.to_string(),
                strength: 0.5,
                description: None,
            });
            for_entity += 1;
        }
    }

    if hit_cap {
        tracing::warn!(
            "relationships truncated to {max_rels} (with {n} entities, theoretical max was ~{}x combinations)",
            n.saturating_sub(1)
        );
    }

    (rels, hit_cap)
}
861
862fn build_relationships_by_sentence_cooccurrence(
873 body: &str,
874 entities: &[NewEntity],
875) -> (Vec<NewRelationship>, bool) {
876 if entities.len() < 2 {
877 return (Vec::new(), false);
878 }
879
880 let max_rels = crate::constants::max_relationships_per_memory();
881 let lower_names: Vec<(usize, String)> = entities
882 .iter()
883 .take(MAX_ENTS)
884 .enumerate()
885 .map(|(i, e)| (i, e.name.to_lowercase()))
886 .collect();
887
888 let mut rels: Vec<NewRelationship> = Vec::new();
889 let mut seen: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
890 let mut hit_cap = false;
891
892 for sentence in body.split(['.', '!', '?', '\n']) {
893 if sentence.trim().is_empty() {
894 continue;
895 }
896 let lower_sentence = sentence.to_lowercase();
897 let present: Vec<usize> = lower_names
898 .iter()
899 .filter(|(_, name)| !name.is_empty() && lower_sentence.contains(name.as_str()))
900 .map(|(i, _)| *i)
901 .collect();
902
903 if present.len() < 2 {
904 continue;
905 }
906
907 for i in 0..present.len() {
908 for j in (i + 1)..present.len() {
909 if rels.len() >= max_rels {
910 hit_cap = true;
911 tracing::warn!(
912 "relationships truncated to {max_rels} during sentence-level pairing"
913 );
914 return (rels, hit_cap);
915 }
916 let ei = present[i];
917 let ej = present[j];
918 let key = (ei.min(ej), ei.max(ej));
919 if seen.insert(key) {
920 rels.push(NewRelationship {
921 source: entities[ei].name.clone(),
922 target: entities[ej].name.clone(),
923 relation: DEFAULT_RELATION.to_string(),
924 strength: 0.5,
925 description: None,
926 });
927 }
928 }
929 }
930 }
931
932 (rels, hit_cap)
933}
934
/// Tokenizes `body` and runs NER over overlapping sliding windows.
///
/// Windows are MAX_SEQ_LEN tokens stepped by STRIDE (50% overlap); entities
/// seen in overlapping regions are deduplicated by a hash of their name.
/// A failed batch falls back to per-window inference before any window is
/// given up on.
///
/// NOTE(review): the tokenizer is re-loaded from disk on every call —
/// presumably cheap relative to inference, but a OnceLock cache would avoid
/// it; confirm before changing.
fn run_ner_sliding_window(
    model: &BertNerModel,
    body: &str,
    paths: &AppPaths,
) -> Result<Vec<ExtractedEntity>> {
    let tokenizer_path = model_dir(paths).join("tokenizer.json");
    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("loading NER tokenizer: {e}"))?;

    // `false` = do not add [CLS]/[SEP] special tokens.
    let encoding = tokenizer
        .encode(body, false)
        .map_err(|e| anyhow::anyhow!("encoding NER input: {e}"))?;

    let mut all_ids: Vec<u32> = encoding.get_ids().to_vec();
    let mut all_tokens: Vec<String> = encoding
        .get_tokens()
        .iter()
        .map(|s| s.to_string())
        .collect();

    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    // Hard cap on total NER work; anything beyond it is only covered by the
    // regex prefilter (which always sees the full body).
    let max_tokens = crate::constants::extraction_max_tokens();
    if all_ids.len() > max_tokens {
        tracing::warn!(
            target: "extraction",
            original_tokens = all_ids.len(),
            capped_tokens = max_tokens,
            "NER input truncated to cap; later body region will be skipped by NER (regex prefilter still covers full body)"
        );
        all_ids.truncate(max_tokens);
        all_tokens.truncate(max_tokens);
    }

    // Build overlapping (ids, tokens) windows over the capped sequence.
    let mut windows: Vec<(Vec<u32>, Vec<String>)> = Vec::new();
    let mut start = 0usize;
    loop {
        let end = (start + MAX_SEQ_LEN).min(all_ids.len());
        windows.push((
            all_ids[start..end].to_vec(),
            all_tokens[start..end].to_vec(),
        ));
        if end >= all_ids.len() {
            break;
        }
        start += STRIDE;
    }

    // Sort by length so each batch chunk groups similarly-sized windows and
    // wastes less padding. Entity discovery order changes, but dedup below
    // is by name so the final set is unaffected.
    windows.sort_by_key(|(ids, _)| ids.len());

    let batch_size = crate::constants::ner_batch_size();
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut seen: std::collections::HashSet<u64> = std::collections::HashSet::new();

    for chunk in windows.chunks(batch_size) {
        match model.predict_batch(chunk) {
            Ok(batch_labels) => {
                for (labels, (_, tokens)) in batch_labels.iter().zip(chunk.iter()) {
                    for ent in iob_to_entities(tokens, labels) {
                        // First occurrence of a name wins.
                        if seen.insert(hash_str(&ent.name)) {
                            entities.push(ent);
                        }
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    "batch NER failed (chunk of {} windows): {e:#} — falling back to single-window",
                    chunk.len()
                );
                // Degrade to one window at a time; a window that still fails
                // is logged and skipped rather than aborting the whole pass.
                for (ids, tokens) in chunk {
                    let mask = vec![1u32; ids.len()];
                    match model.predict(ids, &mask) {
                        Ok(labels) => {
                            for ent in iob_to_entities(tokens, &labels) {
                                if seen.insert(hash_str(&ent.name)) {
                                    entities.push(ent);
                                }
                            }
                        }
                        Err(e2) => {
                            tracing::warn!("NER window fallback also failed: {e2:#}");
                        }
                    }
                }
            }
        }
    }

    Ok(entities)
}
1039
1040fn extend_with_numeric_suffix(entities: Vec<ExtractedEntity>, body: &str) -> Vec<ExtractedEntity> {
1047 static SUFFIX_RE: OnceLock<Regex> = OnceLock::new();
1048 let suffix_re = SUFFIX_RE.get_or_init(|| {
1051 Regex::new(r"^([\-\s]+\d+(?:\.\d+)?[a-z]?)")
1052 .expect("compile-time validated numeric suffix regex literal")
1053 });
1054
1055 entities
1056 .into_iter()
1057 .map(|ent| {
1058 if let Some(pos) = body.find(&ent.name) {
1060 let after_pos = pos + ent.name.len();
1061 if after_pos < body.len() {
1062 let after = &body[after_pos..];
1063 if let Some(m) = suffix_re.find(after) {
1064 let suffix = m.as_str();
1065 if suffix.len() <= 7 {
1068 let mut extended = String::with_capacity(ent.name.len() + suffix.len());
1069 extended.push_str(&ent.name);
1070 extended.push_str(suffix);
1071 return ExtractedEntity {
1072 name: extended,
1073 entity_type: ent.entity_type,
1074 };
1075 }
1076 }
1077 }
1078 }
1079 ent
1080 })
1081 .collect()
1082}
1083
1084fn augment_versioned_model_names(
1104 entities: Vec<ExtractedEntity>,
1105 body: &str,
1106) -> Vec<ExtractedEntity> {
1107 static VERSIONED_MODEL_RE: OnceLock<Regex> = OnceLock::new();
1108 let model_re = VERSIONED_MODEL_RE.get_or_init(|| {
1115 Regex::new(
1116 r"\b([A-Z][A-Za-z]{2,15})[\s\-]+(\d+(?:\.\d+)?(?:[a-z]|x\d+[A-Za-z]?)?)(?:\s+(?:Sonnet|Opus|Haiku|Turbo|Pro|Lite|Mini|Nano|Flash|Ultra))?\b",
1117 )
1118 .expect("compile-time validated versioned model regex literal")
1119 });
1120
1121 let mut existing_lc: std::collections::HashSet<String> =
1122 entities.iter().map(|ent| ent.name.to_lowercase()).collect();
1123 let mut result = entities;
1124
1125 for caps in model_re.captures_iter(body) {
1126 let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
1127 if full_match.is_empty() || full_match.len() > 24 {
1130 continue;
1131 }
1132 let normalized_lc = full_match.to_lowercase();
1133 if existing_lc.contains(&normalized_lc) {
1134 continue;
1135 }
1136 if result.len() >= MAX_ENTS {
1139 break;
1140 }
1141 existing_lc.insert(normalized_lc);
1142 result.push(ExtractedEntity {
1143 name: full_match.to_string(),
1144 entity_type: EntityType::Concept,
1145 });
1146 }
1147
1148 result
1149}
1150
/// Merges regex and NER candidates into one deduplicated list.
///
/// Dedup is case-insensitive over NFKC-normalized names, scoped per entity
/// type ("type\0name" keys). A candidate whose name contains — or is
/// contained in — an existing same-type name counts as a collision; the
/// longer name survives. Output is hard-capped at MAX_ENTS.
///
/// Regex entities are consumed first, so on equal-length collisions the
/// regex candidate wins.
fn merge_and_deduplicate(
    regex_ents: Vec<ExtractedEntity>,
    ner_ents: Vec<ExtractedEntity>,
) -> Vec<ExtractedEntity> {
    // "type\0lowercased-name" -> index into `result`.
    let mut by_lc: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
    let mut result: Vec<ExtractedEntity> = Vec::new();
    let mut truncated = false;

    let total_input = regex_ents.len() + ner_ents.len();
    for ent in regex_ents.into_iter().chain(ner_ents) {
        let name_lc = ent.name.nfkc().collect::<String>().to_lowercase();
        let key = {
            let et = ent.entity_type.as_str();
            let mut k = String::with_capacity(et.len() + 1 + name_lc.len());
            k.push_str(et);
            k.push('\0');
            k.push_str(&name_lc);
            k
        };

        // Linear scan over existing keys of the same type looking for a
        // substring collision.
        // NOTE(review): HashMap iteration order is unspecified, so when a
        // name substring-matches several existing entries, which one it
        // merges into can differ between runs — confirm this is acceptable.
        let type_prefix = {
            let et = ent.entity_type.as_str();
            let mut p = String::with_capacity(et.len() + 1);
            p.push_str(et);
            p.push('\0');
            p
        };
        let mut collision_idx: Option<usize> = None;
        for (existing_key, idx) in &by_lc {
            if !existing_key.starts_with(&type_prefix) {
                continue;
            }
            let existing_name_lc = &existing_key[type_prefix.len()..];
            if existing_name_lc == name_lc
                || existing_name_lc.contains(name_lc.as_str())
                || name_lc.contains(existing_name_lc)
            {
                collision_idx = Some(*idx);
                break;
            }
        }
        match collision_idx {
            Some(idx) => {
                // Strictly longer names replace the stored entry; ties keep
                // the earlier (regex-first) candidate.
                if ent.name.len() > result[idx].name.len() {
                    let old_name_lc = result[idx].name.nfkc().collect::<String>().to_lowercase();
                    let old_key = {
                        let et = result[idx].entity_type.as_str();
                        let mut k = String::with_capacity(et.len() + 1 + old_name_lc.len());
                        k.push_str(et);
                        k.push('\0');
                        k.push_str(&old_name_lc);
                        k
                    };
                    by_lc.remove(&old_key);
                    result[idx] = ent;
                    by_lc.insert(key, idx);
                }
            }
            None => {
                by_lc.insert(key, result.len());
                result.push(ent);
            }
        }
        // Cap check after each accepted/merged candidate.
        if result.len() >= MAX_ENTS {
            truncated = true;
            break;
        }
    }

    if truncated {
        tracing::warn!(
            "extraction truncated at {MAX_ENTS} entities (input had {total_input} candidates before deduplication)"
        );
    }

    result
}
1253
1254fn to_new_entities(extracted: Vec<ExtractedEntity>) -> Vec<NewEntity> {
1255 extracted
1256 .into_iter()
1257 .map(|e| NewEntity {
1258 name: e.name,
1259 entity_type: e.entity_type,
1260 description: None,
1261 })
1262 .collect()
1263}
1264
1265pub fn extract_graph_auto(body: &str, paths: &AppPaths) -> Result<ExtractionResult> {
1266 let regex_entities = apply_regex_prefilter(body);
1267
1268 let mut bert_used = false;
1269 let ner_entities = match get_or_init_model(paths) {
1270 Some(model) => match run_ner_sliding_window(model, body, paths) {
1271 Ok(ents) => {
1272 bert_used = true;
1273 ents
1274 }
1275 Err(e) => {
1276 tracing::warn!("NER failed, falling back to regex-only extraction: {e:#}");
1277 Vec::new()
1278 }
1279 },
1280 None => Vec::new(),
1281 };
1282
1283 let merged = merge_and_deduplicate(regex_entities, ner_entities);
1284 let extended = extend_with_numeric_suffix(merged, body);
1286 let with_models = augment_versioned_model_names(extended, body);
1290 let with_models: Vec<ExtractedEntity> = with_models
1294 .into_iter()
1295 .filter(|e| !regex_section_marker().is_match(&e.name))
1296 .collect();
1297 let entities = to_new_entities(with_models);
1298 let (relationships, relationships_truncated) =
1299 build_relationships_by_sentence_cooccurrence(body, &entities);
1300
1301 let extraction_method = if bert_used {
1302 "bert+regex-batch".to_string()
1303 } else {
1304 "regex-only".to_string()
1305 };
1306
1307 let urls = extract_urls(body);
1308
1309 Ok(ExtractionResult {
1310 entities,
1311 relationships,
1312 relationships_truncated,
1313 extraction_method,
1314 urls,
1315 })
1316}
1317
1318pub struct RegexExtractor;
1319
1320impl Extractor for RegexExtractor {
1321 fn extract(&self, body: &str) -> Result<ExtractionResult> {
1322 let regex_entities = apply_regex_prefilter(body);
1323 let entities = to_new_entities(regex_entities);
1324 let (relationships, relationships_truncated) =
1325 build_relationships_by_sentence_cooccurrence(body, &entities);
1326 let urls = extract_urls(body);
1327 Ok(ExtractionResult {
1328 entities,
1329 relationships,
1330 relationships_truncated,
1331 extraction_method: "regex-only".to_string(),
1332 urls,
1333 })
1334 }
1335}
1336
#[cfg(test)]
mod tests {
    //! Unit tests for the extraction pipeline: regex prefilter, IOB label
    //! decoding, entity merge/dedup, relationship building, URL extraction,
    //! version-suffix augmentation, and configurable constants.
    use super::*;
    use crate::entity_type::EntityType;

    /// Throwaway `AppPaths` pointing at /tmp; no model files exist there, so
    /// `extract_graph_auto` exercises its regex-only fallback path.
    fn make_paths() -> AppPaths {
        use std::path::PathBuf;
        AppPaths {
            db: PathBuf::from("/tmp/test.sqlite"),
            models: PathBuf::from("/tmp/test_models"),
        }
    }

    // ---- regex prefilter: emails, ALL-CAPS tokens, UUIDs ----

    #[test]
    fn regex_email_captures_address() {
        let ents = apply_regex_prefilter("contact: someone@company.com for more info");
        assert!(ents
            .iter()
            .any(|e| e.name == "someone@company.com" && e.entity_type == EntityType::Concept));
    }

    #[test]
    fn regex_all_caps_filters_pt_rule_word() {
        // Portuguese imperative/rule words in ALL CAPS are stoplisted.
        let ents = apply_regex_prefilter("NUNCA do this. PROIBIDO use X. DEVE follow Y.");
        assert!(
            !ents.iter().any(|e| e.name == "NUNCA"),
            "NUNCA must be filtered as a stopword"
        );
        assert!(
            !ents.iter().any(|e| e.name == "PROIBIDO"),
            "PROIBIDO must be filtered"
        );
        assert!(
            !ents.iter().any(|e| e.name == "DEVE"),
            "DEVE must be filtered"
        );
    }

    #[test]
    fn regex_all_caps_accepts_underscored_constant() {
        // SCREAMING_SNAKE identifiers are kept (they look like config keys).
        let ents = apply_regex_prefilter("configure MAX_RETRY=3 and API_TIMEOUT=30");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "API_TIMEOUT"));
    }

    #[test]
    fn regex_all_caps_accepts_domain_acronym() {
        let ents = apply_regex_prefilter("OPENAI launched GPT-5 with NVIDIA H100");
        assert!(ents.iter().any(|e| e.name == "OPENAI"));
        assert!(ents.iter().any(|e| e.name == "NVIDIA"));
    }

    #[test]
    fn regex_url_does_not_appear_in_apply_regex_prefilter() {
        // URLs live in a separate channel (extract_urls) since the P0-2 split.
        let ents = apply_regex_prefilter("see https://docs.rs/crate for details");
        assert!(
            !ents.iter().any(|e| e.name.starts_with("https://")),
            "URLs must not appear as entities after the P0-2 split"
        );
    }

    // ---- URL extraction ----

    #[test]
    fn extract_urls_captures_https() {
        let urls = extract_urls("see https://docs.rs/crate for details");
        assert_eq!(urls.len(), 1);
        assert_eq!(urls[0].url, "https://docs.rs/crate");
        assert!(urls[0].offset > 0);
    }

    #[test]
    fn extract_urls_trim_sufixo_pontuacao() {
        // Trailing sentence punctuation must be trimmed off captured URLs.
        let urls = extract_urls("link: https://example.com/path. fim");
        assert!(!urls.is_empty());
        assert!(
            !urls[0].url.ends_with('.'),
            "sufixo ponto deve ser removido"
        );
    }

    #[test]
    fn extract_urls_dedupes_repeated() {
        let body = "https://example.com referenciado aqui e depois aqui https://example.com";
        let urls = extract_urls(body);
        assert_eq!(urls.len(), 1, "URLs repetidas devem ser deduplicadas");
    }

    #[test]
    fn regex_uuid_captura_identificador() {
        // UUIDs are surfaced as Concept entities.
        let ents = apply_regex_prefilter("id=550e8400-e29b-41d4-a716-446655440000 no sistema");
        assert!(ents.iter().any(|e| e.entity_type == EntityType::Concept));
    }

    #[test]
    fn regex_all_caps_captura_constante() {
        let ents = apply_regex_prefilter("configure MAX_RETRY e TIMEOUT_MS");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "TIMEOUT_MS"));
    }

    #[test]
    fn regex_all_caps_ignores_short_words() {
        // Below MIN_ENTITY_CHARS-style length cutoff for ALL-CAPS tokens.
        let ents = apply_regex_prefilter("use AI em seu projeto");
        assert!(
            !ents.iter().any(|e| e.name == "AI"),
            "AI tem apenas 2 chars, deve ser ignorado"
        );
    }

    // ---- IOB label decoding ----

    #[test]
    fn iob_decodes_per_to_person() {
        let tokens = vec![
            "John".to_string(),
            "Doe".to_string(),
            "trabalhou".to_string(),
        ];
        let labels = vec!["B-PER".to_string(), "I-PER".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents.len(), 1);
        assert_eq!(ents[0].entity_type, EntityType::Person);
        assert!(ents[0].name.contains("John"));
    }

    #[test]
    fn iob_strip_subword_b_prefix() {
        // WordPiece "##" continuation tagged B- must merge into the previous
        // token (or be dropped), never start a new entity.
        let tokens = vec!["Open".to_string(), "##AI".to_string()];
        let labels = vec!["B-ORG".to_string(), "B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.iter().any(|e| e.name == "OpenAI" || e.name == "Open"),
            "should merge ##AI or discard"
        );
    }

    #[test]
    fn iob_subword_orphan_discards() {
        let tokens = vec!["##AI".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "orphan subword without an active entity must be discarded"
        );
    }

    #[test]
    fn iob_maps_date_to_date_v1025() {
        let tokens = vec!["January".to_string(), "2024".to_string()];
        let labels = vec!["B-DATE".to_string(), "I-DATE".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE must be emitted as an entity in v1.0.25"
        );
        assert_eq!(ents[0].entity_type, EntityType::Date);
    }

    #[test]
    fn iob_maps_org_to_organization_v1025() {
        let tokens = vec!["Empresa".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, EntityType::Organization);
    }

    #[test]
    fn iob_maps_org_sdk_to_tool() {
        // ORG names that look like software artifacts ("-sdk") remap to Tool.
        let tokens = vec!["tokio-sdk".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, EntityType::Tool);
    }

    #[test]
    fn iob_maps_loc_to_location_v1025() {
        let tokens = vec!["Brasil".to_string()];
        let labels = vec!["B-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, EntityType::Location);
    }

    // ---- relationship building ----

    #[test]
    fn build_relationships_respeitam_max_rels() {
        let entities: Vec<NewEntity> = (0..20)
            .map(|i| NewEntity {
                name: format!("entidade_{i}"),
                entity_type: EntityType::Concept,
                description: None,
            })
            .collect();
        let (rels, truncated) = build_relationships(&entities);
        let max_rels = crate::constants::max_relationships_per_memory();
        assert!(rels.len() <= max_rels, "deve respeitar max_rels={max_rels}");
        if rels.len() == max_rels {
            assert!(truncated, "truncated deve ser true quando atingiu o cap");
        }
    }

    #[test]
    fn build_relationships_without_duplicates() {
        let entities: Vec<NewEntity> = (0..5)
            .map(|i| NewEntity {
                name: format!("ent_{i}"),
                entity_type: EntityType::Concept,
                description: None,
            })
            .collect();
        let (rels, _truncated) = build_relationships(&entities);
        // Every (source, target) pair must be unique.
        let mut pares: std::collections::HashSet<(String, String)> =
            std::collections::HashSet::new();
        for r in &rels {
            let par = (r.source.clone(), r.target.clone());
            assert!(pares.insert(par), "par duplicado encontrado");
        }
    }

    // ---- merge / dedup ----

    #[test]
    fn merge_dedupes_by_lowercase_name() {
        let a = vec![ExtractedEntity {
            name: "Rust".to_string(),
            entity_type: EntityType::Concept,
        }];
        let b = vec![ExtractedEntity {
            name: "rust".to_string(),
            entity_type: EntityType::Concept,
        }];
        let merged = merge_and_deduplicate(a, b);
        assert_eq!(
            merged.len(),
            1,
            "rust and Rust with the same type are the same entity"
        );
    }

    #[test]
    fn regex_extractor_implements_trait() {
        let extractor = RegexExtractor;
        let result = extractor
            .extract("contato: dev@empresa.io e MAX_TIMEOUT configurado")
            .unwrap();
        assert!(!result.entities.is_empty());
    }

    #[test]
    fn extract_returns_ok_without_model() {
        // No model files under make_paths() => auto extraction must still
        // succeed via the regex fallback.
        let paths = make_paths();
        let body = "contato: teste@exemplo.com com MAX_RETRY=3";
        let result = extract_graph_auto(body, &paths).unwrap();
        assert!(result
            .entities
            .iter()
            .any(|e| e.name.contains("teste@exemplo.com")));
    }

    #[test]
    fn stopwords_filter_v1024_terms() {
        let body = "ACEITE ACK ACL BORDA CHECKLIST COMPLETED CONFIRME \
            DEVEMOS DONE FIXED NEGUE PENDING PLAN PODEMOS RECUSE TOKEN VAMOS";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for word in &[
            "ACEITE",
            "ACK",
            "ACL",
            "BORDA",
            "CHECKLIST",
            "COMPLETED",
            "CONFIRME",
            "DEVEMOS",
            "DONE",
            "FIXED",
            "NEGUE",
            "PENDING",
            "PLAN",
            "PODEMOS",
            "RECUSE",
            "TOKEN",
            "VAMOS",
        ] {
            assert!(
                !names.contains(word),
                "v1.0.24 stopword {word} should be filtered but was found in entities"
            );
        }
    }

    #[test]
    fn dedup_normalizes_unicode_combining_marks() {
        // Same word in NFC ("é" precomposed) vs NFD ("e" + combining acute)
        // must collapse to a single entity.
        let nfc = vec![ExtractedEntity {
            name: "Caf\u{e9}".to_string(),
            entity_type: EntityType::Concept,
        }];
        let nfd_name = "Cafe\u{301}".to_string();
        let nfd = vec![ExtractedEntity {
            name: nfd_name,
            entity_type: EntityType::Concept,
        }];
        let merged = merge_and_deduplicate(nfc, nfd);
        assert_eq!(
            merged.len(),
            1,
            "NFC 'Caf\\u{{e9}}' and NFD 'Cafe\\u{{301}}' must deduplicate to 1 entity after NFKC normalization"
        );
    }

    // ---- batched NER shape logic (model-free simulation) ----

    #[test]
    fn predict_batch_output_count_matches_input() {
        // Replays the pad/stack/narrow tensor choreography of the batched
        // predictor without loading model weights.
        let w1_ids: Vec<u32> = vec![101, 100, 102];
        let w1_tok: Vec<String> = vec!["[CLS]".into(), "hello".into(), "[SEP]".into()];
        let w2_ids: Vec<u32> = vec![101, 100, 200, 300, 102];
        let w2_tok: Vec<String> = vec![
            "[CLS]".into(),
            "world".into(),
            "foo".into(),
            "bar".into(),
            "[SEP]".into(),
        ];
        let windows: Vec<(Vec<u32>, Vec<String>)> =
            vec![(w1_ids.clone(), w1_tok), (w2_ids.clone(), w2_tok)];

        let device = Device::Cpu;
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap();
        assert_eq!(max_len, 5, "max_len deve ser 5");

        // Right-pad every window with zeros up to max_len.
        let mut padded_ids: Vec<Tensor> = Vec::new();
        for (ids, _) in &windows {
            let len = ids.len();
            let pad_right = max_len - len;
            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &device).unwrap();
            let t = t.pad_with_zeros(0, 0, pad_right).unwrap();
            assert_eq!(
                t.dims(),
                &[max_len],
                "each window must have shape (max_len,) after padding"
            );
            padded_ids.push(t);
        }

        let stacked = Tensor::stack(&padded_ids, 0).unwrap();
        assert_eq!(
            stacked.dims(),
            &[2, max_len],
            "stack deve produzir (batch_size=2, max_len=5)"
        );

        // 9 logits per token here; presumably the NER model's label count —
        // TODO(review): confirm against the real classifier head.
        let fake_logits_data: Vec<f32> = vec![0.0f32; 2 * max_len * 9];
        let fake_logits =
            Tensor::from_vec(fake_logits_data, (2usize, max_len, 9usize), &device).unwrap();
        for (i, (ids, _)) in windows.iter().enumerate() {
            let real_len = ids.len();
            let example = fake_logits.get(i).unwrap();
            // Discard padding positions: keep only the real tokens.
            let sliced = example.narrow(0, 0, real_len).unwrap();
            assert_eq!(
                sliced.dims(),
                &[real_len, 9],
                "narrow deve preservar apenas {real_len} tokens reais"
            );
        }
    }

    #[test]
    fn predict_batch_empty_windows_returns_empty() {
        let windows: Vec<(Vec<u32>, Vec<String>)> = vec![];
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        assert_eq!(max_len, 0, "zero windows → max_len 0");
        let result: Vec<Vec<String>> = if max_len == 0 {
            Vec::new()
        } else {
            unreachable!()
        };
        assert!(result.is_empty());
    }

    // ---- configurable constants ----
    // NOTE(review): these tests mutate process-global env vars; they assume
    // serial execution (or unique vars) — racy under a parallel test harness.

    #[test]
    fn ner_batch_size_default_is_8() {
        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
        assert_eq!(crate::constants::ner_batch_size(), 8);
    }

    #[test]
    fn ner_batch_size_env_override_clamped() {
        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "64");
        assert_eq!(crate::constants::ner_batch_size(), 32, "deve clampar em 32");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "0");
        assert_eq!(crate::constants::ner_batch_size(), 1, "deve clampar em 1");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "4");
        assert_eq!(
            crate::constants::ner_batch_size(),
            4,
            "valid value preserved"
        );

        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
    }

    #[test]
    fn extraction_method_regex_only_unchanged() {
        let result = RegexExtractor.extract("contact: dev@acme.io").unwrap();
        assert_eq!(
            result.extraction_method, "regex-only",
            "RegexExtractor must return regex-only"
        );
    }

    // ---- numeric/version suffix extension ----

    #[test]
    fn extend_suffix_pure_numeric_unchanged() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-5 in the project");
        assert_eq!(
            result[0].name, "GPT-5",
            "purely numeric suffix must be extended"
        );
    }

    #[test]
    fn extend_suffix_alphanumeric_letter_after_digit() {
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = extend_with_numeric_suffix(ents, "using GPT-4o for advanced tasks");
        assert_eq!(result[0].name, "GPT-4o", "suffix '4o' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_b_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Llama".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = extend_with_numeric_suffix(ents, "Llama-5b open-weight model");
        assert_eq!(result[0].name, "Llama-5b", "suffix '5b' must be accepted");
    }

    #[test]
    fn extend_suffix_alphanumeric_x_suffix() {
        let ents = vec![ExtractedEntity {
            name: "Mistral".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = extend_with_numeric_suffix(ents, "testing Mistral-8x in production");
        assert_eq!(result[0].name, "Mistral-8x", "suffix '8x' must be accepted");
    }

    // ---- versioned model-name augmentation ----

    #[test]
    fn augment_versioned_gpt4o() {
        let result = augment_versioned_model_names(vec![], "using GPT-4o for analysis");
        assert!(
            result.iter().any(|e| e.name == "GPT-4o"),
            "GPT-4o must be captured by augment, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_claude_4_sonnet() {
        let result =
            augment_versioned_model_names(vec![], "best model: Claude 4 Sonnet released today");
        assert!(
            result.iter().any(|e| e.name == "Claude 4 Sonnet"),
            "Claude 4 Sonnet must be captured, found: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_llama_3_pro() {
        let result =
            augment_versioned_model_names(vec![], "fine-tuning com Llama 3 Pro localmente");
        assert!(
            result.iter().any(|e| e.name == "Llama 3 Pro"),
            "Llama 3 Pro deve ser capturado, achados: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_mixtral_8x7b() {
        let result =
            augment_versioned_model_names(vec![], "executando Mixtral 8x7B no servidor local");
        assert!(
            result.iter().any(|e| e.name == "Mixtral 8x7B"),
            "Mixtral 8x7B deve ser capturado, achados: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_does_not_duplicate_existing() {
        let existing = vec![ExtractedEntity {
            name: "Claude 4".to_string(),
            entity_type: EntityType::Concept,
        }];
        let result = augment_versioned_model_names(existing, "using Claude 4 in the project");
        let count = result.iter().filter(|e| e.name == "Claude 4").count();
        assert_eq!(count, 1, "Claude 4 must not be duplicated");
    }

    // ---- v1.0.25 stoplist / section markers / brand names ----

    #[test]
    fn stopwords_filter_url_jwt_api_v1025() {
        let body = "We use URL, JWT, and API REST in our LLM-powered CLI via HTTP/HTTPS and UI.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for blocked in &[
            "URL", "JWT", "API", "REST", "LLM", "CLI", "HTTP", "HTTPS", "UI",
        ] {
            assert!(
                !names.contains(blocked),
                "v1.0.25 stopword {blocked} leaked as entity; found names: {names:?}"
            );
        }
    }

    #[test]
    fn section_markers_etapa_fase_filtered_v1025() {
        // "Etapa"/"Fase" are Portuguese document-section headings, not entities.
        let body = "Etapa 3 do plano: implementar Fase 1 da Migra\u{e7}\u{e3}o.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Etapa") || e.name.contains("Fase")),
            "section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn section_markers_passo_secao_filtered_v1025() {
        let body = "Siga Passo 2 conforme Se\u{e7}\u{e3}o 3 do manual.";
        let ents = apply_regex_prefilter(body);
        assert!(
            !ents
                .iter()
                .any(|e| e.name.contains("Passo") || e.name.contains("Se\u{e7}\u{e3}o")),
            "Passo/Se\\u{{e7}}\\u{{e3}}o section markers must be stripped; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn brand_camelcase_extracted_as_organization_v1025() {
        let body = "OpenAI launched GPT-4 and PostgreSQL added pgvector.";
        let ents = apply_regex_prefilter(body);
        let openai = ents.iter().find(|e| e.name == "OpenAI");
        assert!(
            openai.is_some(),
            "OpenAI must be extracted by CamelCase brand regex; entities: {:?}",
            ents.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
        assert_eq!(
            openai.unwrap().entity_type,
            EntityType::Organization,
            "brand CamelCase must map to organization (V008)"
        );
    }

    #[test]
    fn brand_postgresql_extracted_as_organization_v1025() {
        let body = "migrating from MySQL to PostgreSQL for better performance.";
        let ents = apply_regex_prefilter(body);
        assert!(
            ents.iter()
                .any(|e| e.name == "PostgreSQL" && e.entity_type == EntityType::Organization),
            "PostgreSQL must be extracted as organization; entities: {:?}",
            ents.iter()
                .map(|e| (&e.name, &e.entity_type))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn iob_org_maps_to_organization_not_project_v1025() {
        let tokens = vec!["Microsoft".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type,
            EntityType::Organization,
            "B-ORG must map to organization (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_loc_maps_to_location_not_concept_v1025() {
        let tokens = vec!["S\u{e3}o".to_string(), "Paulo".to_string()];
        let labels = vec!["B-LOC".to_string(), "I-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents[0].entity_type,
            EntityType::Location,
            "B-LOC must map to location (V008); got {}",
            ents[0].entity_type
        );
    }

    #[test]
    fn iob_date_maps_to_date_not_discarded_v1025() {
        let tokens = vec!["2025".to_string(), "-".to_string(), "12".to_string()];
        let labels = vec![
            "B-DATE".to_string(),
            "I-DATE".to_string(),
            "I-DATE".to_string(),
        ];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(
            ents.len(),
            1,
            "DATE entity must be emitted (V008); entities: {ents:?}"
        );
        assert_eq!(ents[0].entity_type, EntityType::Date);
    }

    #[test]
    fn pt_verb_le_filtered_as_per_v1025() {
        // "Lê" (reads) is a Portuguese verb the model sometimes mistags B-PER.
        let tokens = vec!["L\u{ea}".to_string(), "o".to_string(), "livro".to_string()];
        let labels = vec!["B-PER".to_string(), "O".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents
                .iter()
                .any(|e| e.name == "L\u{ea}" && e.entity_type == EntityType::Person),
            "PT verb 'L\\u{{ea}}' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn pt_verb_ver_filtered_as_per_v1025() {
        let tokens = vec!["Ver".to_string()];
        let labels = vec!["B-PER".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "PT verb 'Ver' tagged B-PER must be filtered; entities: {ents:?}"
        );
    }

    /// Shorthand constructor for `ExtractedEntity` in the merge tests below.
    fn entity(name: &str, entity_type: EntityType) -> ExtractedEntity {
        ExtractedEntity {
            name: name.to_string(),
            entity_type,
        }
    }

    #[test]
    fn merge_resolves_sonne_vs_sonnet_keeps_longest_v1025() {
        // Prefix containment within a type keeps the longest name.
        let regex = vec![entity("Sonne", EntityType::Concept)];
        let ner = vec![entity("Sonnet", EntityType::Concept)];
        let result = merge_and_deduplicate(regex, ner);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "Sonnet");
    }

    #[test]
    fn merge_resolves_open_vs_openai_keeps_longest_v1025() {
        let regex = vec![
            entity("Open", EntityType::Organization),
            entity("OpenAI", EntityType::Organization),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 1, "expected 1 entity, got: {result:?}");
        assert_eq!(result[0].name, "OpenAI");
    }

    #[test]
    fn merge_keeps_both_when_no_containment_v1025() {
        let regex = vec![
            entity("Alice", EntityType::Person),
            entity("Bob", EntityType::Person),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(result.len(), 2, "expected 2 entities, got: {result:?}");
    }

    #[test]
    fn merge_respects_entity_type_boundary_v1025() {
        // Same name, different types => distinct entities.
        let regex = vec![
            entity("Apple", EntityType::Organization),
            entity("Apple", EntityType::Concept),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            2,
            "expected 2 entities (different types), got: {result:?}"
        );
    }

    #[test]
    fn merge_case_insensitive_dedup_v1025() {
        let regex = vec![
            entity("OpenAI", EntityType::Organization),
            entity("openai", EntityType::Organization),
        ];
        let result = merge_and_deduplicate(regex, vec![]);
        assert_eq!(
            result.len(),
            1,
            "expected 1 entity after case-insensitive dedup, got: {result:?}"
        );
    }

    #[test]
    fn iob_section_marker_etapa_filtered_v1025() {
        // Section markers must be filtered even when they arrive via BERT
        // (MISC labels), not only via the regex path.
        let tokens = vec!["Etapa".to_string(), "3".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Etapa")),
            "section marker 'Etapa 3' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn iob_section_marker_fase_filtered_v1025() {
        let tokens = vec!["Fase".to_string(), "1".to_string()];
        let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            !ents.iter().any(|e| e.name.contains("Fase")),
            "section marker 'Fase 1' from BERT must be filtered; entities: {ents:?}"
        );
    }

    #[test]
    fn extract_graph_auto_handles_large_body_under_30s() {
        // 80 KB body; the token cap should keep wall time well below 30 s.
        let body = "x ".repeat(40_000);
        let paths = make_paths();
        let start = std::time::Instant::now();
        let result = extract_graph_auto(&body, &paths).expect("extraction must not error");
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_secs() < 30,
            "extract_graph_auto took {}s for 80 KB body (cap should keep it well under 30s)",
            elapsed.as_secs()
        );
        let _ = result.entities;
    }

    #[test]
    fn pt_uppercase_stopwords_filtered_v1031() {
        let body = "Para o ADAPTER funcionar com PROJETO em modo PASSIVA, devemos usar \
            SOMENTE LEITURA conforme a REGRA OBRIGATORIA do EXEMPLO DEFAULT.";
        let ents = apply_regex_prefilter(body);
        let names: Vec<String> = ents.iter().map(|e| e.name.to_uppercase()).collect();
        for stop in &[
            "ADAPTER",
            "PROJETO",
            "PASSIVA",
            "SOMENTE",
            "LEITURA",
            "REGRA",
            "OBRIGATORIA",
            "EXEMPLO",
            "DEFAULT",
        ] {
            assert!(
                !names.contains(&stop.to_string()),
                "v1.0.31 A11 stoplist failed: {stop} leaked as entity; got names: {names:?}"
            );
        }
    }

    #[test]
    fn pt_underscored_identifier_preserved_v1031() {
        let ents = apply_regex_prefilter("configure FLOWAIPER_API_KEY=foo and MAX_TIMEOUT=30");
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"FLOWAIPER_API_KEY"));
        assert!(names.contains(&"MAX_TIMEOUT"));
    }

    // ---- sentence co-occurrence relationships ----

    #[test]
    fn build_relationships_by_sentence_only_links_co_occurring_entities() {
        let body = "Alice met Bob at the conference. Carol works alone in another room.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
            NewEntity {
                name: "Carol".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
        ];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(!truncated);
        assert_eq!(
            rels.len(),
            1,
            "only Alice/Bob should pair (same sentence); Carol is isolated"
        );
        let pair = (rels[0].source.as_str(), rels[0].target.as_str());
        assert!(
            matches!(pair, ("Alice", "Bob") | ("Bob", "Alice")),
            "unexpected pair {pair:?}"
        );
    }

    #[test]
    fn build_relationships_by_sentence_returns_empty_for_single_entity() {
        let body = "Alice is here.";
        let entities = vec![NewEntity {
            name: "Alice".to_string(),
            entity_type: EntityType::Person,
            description: None,
        }];
        let (rels, truncated) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert!(rels.is_empty());
        assert!(!truncated);
    }

    #[test]
    fn build_relationships_by_sentence_dedupes_pairs_across_sentences() {
        let body = "Alice met Bob. Bob saw Alice again.";
        let entities = vec![
            NewEntity {
                name: "Alice".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
            NewEntity {
                name: "Bob".to_string(),
                entity_type: EntityType::Person,
                description: None,
            },
        ];
        let (rels, _) = build_relationships_by_sentence_cooccurrence(body, &entities);
        assert_eq!(
            rels.len(),
            1,
            "Alice/Bob pair must be emitted only once even when co-occurring in multiple sentences"
        );
    }

    // NOTE(review): same env-var race caveat as the ner_batch_size tests.

    #[test]
    fn extraction_max_tokens_default_is_5000() {
        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
        assert_eq!(crate::constants::extraction_max_tokens(), 5_000);
    }

    #[test]
    fn extraction_max_tokens_env_override_clamped() {
        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value below 512 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "200000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            5_000,
            "value above 100_000 must fall back to default"
        );

        std::env::set_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS", "8000");
        assert_eq!(
            crate::constants::extraction_max_tokens(),
            8_000,
            "valid value must be honoured"
        );

        std::env::remove_var("SQLITE_GRAPHRAG_EXTRACTION_MAX_TOKENS");
    }
}