Skip to main content

sqlite_graphrag/
extraction.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::OnceLock;
4
5use anyhow::{Context, Result};
6use candle_core::{DType, Device, Tensor};
7use candle_nn::{Linear, Module, VarBuilder};
8use candle_transformers::models::bert::{BertModel, Config as BertConfig};
9use regex::Regex;
10use serde::Deserialize;
11
12use crate::paths::AppPaths;
13use crate::storage::entities::{NewEntity, NewRelationship};
14
/// Hugging Face Hub repository of the multilingual NER model.
const MODEL_ID: &str = "Davlan/bert-base-multilingual-cased-ner-hrl";
/// Maximum number of token positions fed to the model per window.
const MAX_SEQ_LEN: usize = 512;
/// Number of tokens the sliding window advances between windows; smaller than
/// `MAX_SEQ_LEN`, so consecutive windows overlap.
const STRIDE: usize = 256;
/// Maximum number of entities kept per extraction.
const MAX_ENTS: usize = 30;
/// Maximum number of relationships built per extraction.
const MAX_RELS: usize = 50;
/// Maximum number of relationships originating from any single entity.
const TOP_K_RELATIONS: usize = 5;
/// Relation name given to every co-occurrence relationship.
const DEFAULT_RELATION: &str = "mentions";
/// Minimum length of an entity name for the entity to be kept.
const MIN_ENTITY_CHARS: usize = 2;
23
// Lazily compiled regexes; each one is initialized on first use by its accessor function
// (`regex_email`, `regex_url`, `regex_uuid`, `regex_all_caps`).
static REGEX_EMAIL: OnceLock<Regex> = OnceLock::new();
static REGEX_URL: OnceLock<Regex> = OnceLock::new();
static REGEX_UUID: OnceLock<Regex> = OnceLock::new();
static REGEX_ALL_CAPS: OnceLock<Regex> = OnceLock::new();
28
// v1.0.20: stopwords that filter common PT-BR/EN rule words captured by the ALL_CAPS regex.
// Without this filter, PT-BR technical corpora with rules written in CAPS (NUNCA, PROIBIDO, DEVE)
// produced ~70% junk "entities". Identifiers such as MAX_RETRY (with underscore) are still kept.
//
// The ALL_CAPS regex only matches ASCII capitals, so accented spellings (OBRIGATÓRIO, VOCÊ,
// PADRÃO, CRÍTICO) can never be matched and filtered. Their unaccented forms, which are common
// in CAPS text, are listed next to them so the filter actually applies to those words.
const ALL_CAPS_STOPWORDS: &[&str] = &[
    "NUNCA",
    "SEMPRE",
    "PROIBIDO",
    "OBRIGATÓRIO",
    "OBRIGATORIO",
    "DEVE",
    "JAMAIS",
    "USAR",
    "EVITAR",
    "TODOS",
    "TODAS",
    "VOCÊ",
    "VOCE",
    "ESTA",
    "ESTE",
    "ESSE",
    "ESSA",
    "PADRÃO",
    "PADRAO",
    "REGRAS",
    "CRÍTICO",
    "CRITICO",
    "FALHA",
    "NEVER",
    "ALWAYS",
    "FORBIDDEN",
    "REQUIRED",
    "MUST",
    "SHOULD",
    "SHALL",
    "TODO",
    "FIXME",
    "HACK",
    "BUG",
    "NOTE",
    "WARNING",
    "CRITICAL",
    "ERROR",
];
68
69fn regex_email() -> &'static Regex {
70    REGEX_EMAIL
71        .get_or_init(|| Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}").unwrap())
72}
73
74fn regex_url() -> &'static Regex {
75    REGEX_URL.get_or_init(|| Regex::new(r#"https?://[^\s\)\]\}"'<>]+"#).unwrap())
76}
77
78fn regex_uuid() -> &'static Regex {
79    REGEX_UUID.get_or_init(|| {
80        Regex::new(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}")
81            .unwrap()
82    })
83}
84
85fn regex_all_caps() -> &'static Regex {
86    REGEX_ALL_CAPS.get_or_init(|| Regex::new(r"\b[A-Z][A-Z0-9_]{2,}\b").unwrap())
87}
88
/// An entity candidate produced by the regex prefilter or the NER decoder, before it is
/// converted into a storage entity.
///
/// Both fields are plain strings, so full equality and hashing are derived. This lets
/// candidates be compared exactly and collected into hash sets or map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExtractedEntity {
    /// Entity name as extracted (trimmed).
    pub name: String,
    /// Entity type label, e.g. "person", "project", "tool" or "concept".
    pub entity_type: String,
}
94
/// Output of one extraction run: entities, the relationships between them, and how they
/// were obtained.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    /// Entities to persist.
    pub entities: Vec<NewEntity>,
    /// Relationships between the entities above.
    pub relationships: Vec<NewRelationship>,
    /// Method used for extraction: "bert+regex" or "regex-only".
    /// Useful for auditing, metrics and reports to the user.
    pub extraction_method: String,
}
103
/// A strategy that turns a document body into entities and relationships.
///
/// The `Send + Sync` bounds allow one extractor instance to be shared across threads.
pub trait Extractor: Send + Sync {
    /// Extracts entities and relationships from `body`.
    fn extract(&self, body: &str) -> Result<ExtractionResult>;
}
107
/// Minimal view of the model's `config.json`: only the fields this module reads itself.
#[derive(Deserialize)]
struct ModelConfig {
    /// Label names keyed by class index written as a JSON string ("0", "1", ...);
    /// empty when the config has no `id2label`.
    #[serde(default)]
    id2label: HashMap<String, String>,
    /// Width of the encoder output, used as the classifier head's input size.
    hidden_size: usize,
}
114
/// BERT encoder plus a linear token-classification head, used for named-entity recognition.
struct BertNerModel {
    /// Encoder whose per-token output feeds `classifier`.
    bert: BertModel,
    /// Linear layer mapping each encoder output vector to one logit per label.
    classifier: Linear,
    /// Device on which input tensors are created (the CPU, as set in `load`).
    device: Device,
    /// Class index -> BIO label string; indices missing here are decoded as "O" by `predict`.
    id2label: HashMap<usize, String>,
}
121
122impl BertNerModel {
123    fn load(model_dir: &Path) -> Result<Self> {
124        let config_path = model_dir.join("config.json");
125        let weights_path = model_dir.join("model.safetensors");
126
127        let config_str = std::fs::read_to_string(&config_path)
128            .with_context(|| format!("lendo config.json em {config_path:?}"))?;
129        let model_cfg: ModelConfig =
130            serde_json::from_str(&config_str).context("parseando config.json do modelo NER")?;
131
132        let id2label: HashMap<usize, String> = model_cfg
133            .id2label
134            .into_iter()
135            .filter_map(|(k, v)| k.parse::<usize>().ok().map(|n| (n, v)))
136            .collect();
137
138        let num_labels = id2label.len().max(9);
139        let hidden_size = model_cfg.hidden_size;
140
141        let bert_config_str = std::fs::read_to_string(&config_path)
142            .with_context(|| format!("relendo config.json para bert em {config_path:?}"))?;
143        let bert_cfg: BertConfig =
144            serde_json::from_str(&bert_config_str).context("parseando BertConfig")?;
145
146        let device = Device::Cpu;
147
148        let vb = unsafe {
149            VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
150                .with_context(|| format!("mapeando {weights_path:?}"))?
151        };
152        let bert = BertModel::load(vb.pp("bert"), &bert_cfg).context("carregando BertModel")?;
153
154        // v1.0.20 fix P0 secundário: carregar classifier head do safetensors em vez de zeros.
155        // Em v1.0.19 usávamos Tensor::zeros, o que produzia argmax constante e inferência degenerada.
156        let cls_vb = vb.pp("classifier");
157        let weight = cls_vb
158            .get((num_labels, hidden_size), "weight")
159            .context("carregando classifier.weight do safetensors")?;
160        let bias = cls_vb
161            .get(num_labels, "bias")
162            .context("carregando classifier.bias do safetensors")?;
163        let classifier = Linear::new(weight, Some(bias));
164
165        Ok(Self {
166            bert,
167            classifier,
168            device,
169            id2label,
170        })
171    }
172
173    fn predict(&self, token_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<String>> {
174        let len = token_ids.len();
175        let ids_i64: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
176        let mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();
177
178        let input_ids = Tensor::from_vec(ids_i64, (1, len), &self.device)
179            .context("criando tensor input_ids")?;
180        let token_type_ids = Tensor::zeros((1, len), DType::I64, &self.device)
181            .context("criando tensor token_type_ids")?;
182        let attn_mask = Tensor::from_vec(mask_i64, (1, len), &self.device)
183            .context("criando tensor attention_mask")?;
184
185        let sequence_output = self
186            .bert
187            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
188            .context("forward pass do BertModel")?;
189
190        let logits = self
191            .classifier
192            .forward(&sequence_output)
193            .context("forward pass do classificador")?;
194
195        let logits_2d = logits.squeeze(0).context("removendo dimensão batch")?;
196
197        let num_tokens = logits_2d.dim(0).context("dim(0)")?;
198
199        let mut labels = Vec::with_capacity(num_tokens);
200        for i in 0..num_tokens {
201            let token_logits = logits_2d.get(i).context("get token logits")?;
202            let vec: Vec<f32> = token_logits.to_vec1().context("to_vec1 logits")?;
203            let argmax = vec
204                .iter()
205                .enumerate()
206                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
207                .map(|(idx, _)| idx)
208                .unwrap_or(0);
209            let label = self
210                .id2label
211                .get(&argmax)
212                .cloned()
213                .unwrap_or_else(|| "O".to_string());
214            labels.push(label);
215        }
216
217        Ok(labels)
218    }
219}
220
221static NER_MODEL: OnceLock<Option<BertNerModel>> = OnceLock::new();
222
223fn get_or_init_model(paths: &AppPaths) -> Option<&'static BertNerModel> {
224    NER_MODEL
225        .get_or_init(|| match load_model(paths) {
226            Ok(m) => Some(m),
227            Err(e) => {
228                tracing::warn!("NER model não disponível (graceful degradation): {e:#}");
229                None
230            }
231        })
232        .as_ref()
233}
234
235fn model_dir(paths: &AppPaths) -> PathBuf {
236    paths.models.join("bert-multilingual-ner")
237}
238
239fn ensure_model_files(paths: &AppPaths) -> Result<PathBuf> {
240    let dir = model_dir(paths);
241    std::fs::create_dir_all(&dir)
242        .with_context(|| format!("criando diretório do modelo: {dir:?}"))?;
243
244    let weights = dir.join("model.safetensors");
245    let config = dir.join("config.json");
246    let tokenizer = dir.join("tokenizer.json");
247
248    if weights.exists() && config.exists() && tokenizer.exists() {
249        return Ok(dir);
250    }
251
252    tracing::info!("Baixando modelo NER (primeira execução, ~676 MB)...");
253    crate::output::emit_progress_i18n(
254        "Downloading NER model (first run, ~676 MB)...",
255        "Baixando modelo NER (primeira execução, ~676 MB)...",
256    );
257
258    let api = huggingface_hub::api::sync::Api::new().context("criando cliente HF Hub")?;
259    let repo = api.model(MODEL_ID.to_string());
260
261    // v1.0.20 fix P0 primário: tokenizer.json no repo Davlan está apenas em onnx/tokenizer.json.
262    // Em v1.0.19 buscávamos da raiz e recebíamos 404, caindo em graceful degradation 100% das vezes.
263    // Mapeamos (remote_path, local_filename) para baixar do subfolder mantendo nome plano local.
264    for (remote, local) in &[
265        ("model.safetensors", "model.safetensors"),
266        ("config.json", "config.json"),
267        ("onnx/tokenizer.json", "tokenizer.json"),
268        ("tokenizer_config.json", "tokenizer_config.json"),
269    ] {
270        let dest = dir.join(local);
271        if !dest.exists() {
272            let src = repo
273                .get(remote)
274                .with_context(|| format!("baixando {remote} do HF Hub"))?;
275            std::fs::copy(&src, &dest).with_context(|| format!("copiando {local} para cache"))?;
276        }
277    }
278
279    Ok(dir)
280}
281
282fn load_model(paths: &AppPaths) -> Result<BertNerModel> {
283    let dir = ensure_model_files(paths)?;
284    BertNerModel::load(&dir)
285}
286
287fn apply_regex_prefilter(body: &str) -> Vec<ExtractedEntity> {
288    let mut entities = Vec::new();
289    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
290
291    let add = |entities: &mut Vec<ExtractedEntity>,
292               seen: &mut std::collections::HashSet<String>,
293               name: &str,
294               entity_type: &str| {
295        let name = name.trim().to_string();
296        if name.len() >= MIN_ENTITY_CHARS && seen.insert(name.clone()) {
297            entities.push(ExtractedEntity {
298                name,
299                entity_type: entity_type.to_string(),
300            });
301        }
302    };
303
304    for m in regex_email().find_iter(body) {
305        // v1.0.20: email é "concept" (regex sozinho não distingue pessoa de mailing list/role).
306        add(&mut entities, &mut seen, m.as_str(), "concept");
307    }
308    for m in regex_url().find_iter(body) {
309        add(&mut entities, &mut seen, m.as_str(), "concept");
310    }
311    for m in regex_uuid().find_iter(body) {
312        add(&mut entities, &mut seen, m.as_str(), "concept");
313    }
314    for m in regex_all_caps().find_iter(body) {
315        let candidate = m.as_str();
316        // v1.0.20: aceita identificadores com underscore (MAX_RETRY) ou que NÃO sejam stopwords PT/EN.
317        let is_identifier = candidate.contains('_');
318        let is_stopword = ALL_CAPS_STOPWORDS.contains(&candidate);
319        if is_identifier || !is_stopword {
320            add(&mut entities, &mut seen, candidate, "concept");
321        }
322    }
323
324    entities
325}
326
327fn iob_to_entities(tokens: &[String], labels: &[String]) -> Vec<ExtractedEntity> {
328    let mut entities: Vec<ExtractedEntity> = Vec::new();
329    let mut current_parts: Vec<String> = Vec::new();
330    let mut current_type: Option<String> = None;
331
332    let flush =
333        |parts: &mut Vec<String>, typ: &mut Option<String>, entities: &mut Vec<ExtractedEntity>| {
334            if let Some(t) = typ.take() {
335                let name = parts.join(" ").trim().to_string();
336                if name.len() >= MIN_ENTITY_CHARS {
337                    entities.push(ExtractedEntity {
338                        name,
339                        entity_type: t,
340                    });
341                }
342                parts.clear();
343            }
344        };
345
346    for (token, label) in tokens.iter().zip(labels.iter()) {
347        if label == "O" {
348            flush(&mut current_parts, &mut current_type, &mut entities);
349            continue;
350        }
351
352        let (prefix, bio_type) = if let Some(rest) = label.strip_prefix("B-") {
353            ("B", rest)
354        } else if let Some(rest) = label.strip_prefix("I-") {
355            ("I", rest)
356        } else {
357            flush(&mut current_parts, &mut current_type, &mut entities);
358            continue;
359        };
360
361        let entity_type = match bio_type {
362            "DATE" => {
363                flush(&mut current_parts, &mut current_type, &mut entities);
364                continue;
365            }
366            "PER" => "person",
367            "ORG" => {
368                let t = token.to_lowercase();
369                if t.contains("lib")
370                    || t.contains("sdk")
371                    || t.contains("cli")
372                    || t.contains("crate")
373                    || t.contains("npm")
374                {
375                    "tool"
376                } else {
377                    "project"
378                }
379            }
380            "LOC" => "concept",
381            other => other,
382        };
383
384        if prefix == "B" {
385            if token.starts_with("##") {
386                // BERT confuso: subword com B-prefix indica continuação de entidade anterior.
387                // Anexar à última parte da entidade atual; senão descartar.
388                let clean = token.strip_prefix("##").unwrap_or(token.as_str());
389                if let Some(last) = current_parts.last_mut() {
390                    last.push_str(clean);
391                }
392                continue;
393            }
394            flush(&mut current_parts, &mut current_type, &mut entities);
395            current_parts.push(token.clone());
396            current_type = Some(entity_type.to_string());
397        } else if prefix == "I" && current_type.is_some() {
398            let clean = token.strip_prefix("##").unwrap_or(token.as_str());
399            if token.starts_with("##") {
400                if let Some(last) = current_parts.last_mut() {
401                    last.push_str(clean);
402                }
403            } else {
404                current_parts.push(clean.to_string());
405            }
406        }
407    }
408
409    flush(&mut current_parts, &mut current_type, &mut entities);
410    entities
411}
412
413fn build_relationships(entities: &[NewEntity]) -> Vec<NewRelationship> {
414    if entities.len() < 2 {
415        return Vec::new();
416    }
417
418    let n = entities.len().min(MAX_ENTS);
419    let mut rels: Vec<NewRelationship> = Vec::new();
420    let mut seen: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();
421
422    let mut hit_cap = false;
423    'outer: for i in 0..n {
424        if rels.len() >= MAX_RELS {
425            hit_cap = true;
426            break;
427        }
428
429        let mut for_entity = 0usize;
430        for j in (i + 1)..n {
431            if for_entity >= TOP_K_RELATIONS {
432                break;
433            }
434            if rels.len() >= MAX_RELS {
435                hit_cap = true;
436                break 'outer;
437            }
438
439            let src = &entities[i].name;
440            let tgt = &entities[j].name;
441            let key = (src.clone(), tgt.clone());
442
443            if seen.contains(&key) {
444                continue;
445            }
446            seen.insert(key);
447
448            rels.push(NewRelationship {
449                source: src.clone(),
450                target: tgt.clone(),
451                relation: DEFAULT_RELATION.to_string(),
452                strength: 0.5,
453                description: None,
454            });
455            for_entity += 1;
456        }
457    }
458
459    // v1.0.20: avisar quando relacionamentos foram truncados antes de cobrir todos os pares possíveis.
460    if hit_cap {
461        tracing::warn!(
462            "relacionamentos truncados em {MAX_RELS} (com {n} entidades, máx teórico era ~{}× combinações)",
463            n.saturating_sub(1)
464        );
465    }
466
467    rels
468}
469
470fn run_ner_sliding_window(
471    model: &BertNerModel,
472    body: &str,
473    paths: &AppPaths,
474) -> Result<Vec<ExtractedEntity>> {
475    let tokenizer_path = model_dir(paths).join("tokenizer.json");
476    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
477        .map_err(|e| anyhow::anyhow!("carregando tokenizer NER: {e}"))?;
478
479    let encoding = tokenizer
480        .encode(body, false)
481        .map_err(|e| anyhow::anyhow!("encoding NER: {e}"))?;
482
483    let all_ids: Vec<u32> = encoding.get_ids().to_vec();
484    let all_tokens: Vec<String> = encoding
485        .get_tokens()
486        .iter()
487        .map(|s| s.to_string())
488        .collect();
489
490    if all_ids.is_empty() {
491        return Ok(Vec::new());
492    }
493
494    let mut entities: Vec<ExtractedEntity> = Vec::new();
495    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
496
497    let mut start = 0usize;
498    loop {
499        let end = (start + MAX_SEQ_LEN).min(all_ids.len());
500        let window_ids = &all_ids[start..end];
501        let window_tokens = &all_tokens[start..end];
502        let attention_mask: Vec<u32> = vec![1u32; window_ids.len()];
503
504        match model.predict(window_ids, &attention_mask) {
505            Ok(labels) => {
506                let window_ents = iob_to_entities(window_tokens, &labels);
507                for ent in window_ents {
508                    if seen.insert(ent.name.clone()) {
509                        entities.push(ent);
510                    }
511                }
512            }
513            Err(e) => {
514                tracing::warn!("janela NER falhou (start={start}): {e:#}");
515            }
516        }
517
518        if end >= all_ids.len() {
519            break;
520        }
521        start += STRIDE;
522    }
523
524    Ok(entities)
525}
526
527fn merge_and_deduplicate(
528    regex_ents: Vec<ExtractedEntity>,
529    ner_ents: Vec<ExtractedEntity>,
530) -> Vec<ExtractedEntity> {
531    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
532    let mut result: Vec<ExtractedEntity> = Vec::new();
533    let mut truncated = false;
534
535    let total_input = regex_ents.len() + ner_ents.len();
536    for ent in regex_ents.into_iter().chain(ner_ents) {
537        let key = ent.name.to_lowercase();
538        if seen.insert(key) {
539            result.push(ent);
540        }
541        if result.len() >= MAX_ENTS {
542            truncated = true;
543            break;
544        }
545    }
546
547    // v1.0.20: avisar quando truncamento silencioso descarta entidades acima do MAX_ENTS.
548    if truncated {
549        tracing::warn!(
550            "extração truncada em {MAX_ENTS} entidades (entrada tinha {total_input} candidatos antes da deduplicação)"
551        );
552    }
553
554    result
555}
556
557fn to_new_entities(extracted: Vec<ExtractedEntity>) -> Vec<NewEntity> {
558    extracted
559        .into_iter()
560        .map(|e| NewEntity {
561            name: e.name,
562            entity_type: e.entity_type,
563            description: None,
564        })
565        .collect()
566}
567
568pub fn extract_graph_auto(body: &str, paths: &AppPaths) -> Result<ExtractionResult> {
569    let regex_entities = apply_regex_prefilter(body);
570
571    let mut bert_used = false;
572    let ner_entities = match get_or_init_model(paths) {
573        Some(model) => match run_ner_sliding_window(model, body, paths) {
574            Ok(ents) => {
575                bert_used = true;
576                ents
577            }
578            Err(e) => {
579                tracing::warn!("NER falhou, usando apenas regex: {e:#}");
580                Vec::new()
581            }
582        },
583        None => Vec::new(),
584    };
585
586    let merged = merge_and_deduplicate(regex_entities, ner_entities);
587    let entities = to_new_entities(merged);
588    let relationships = build_relationships(&entities);
589
590    let extraction_method = if bert_used {
591        "bert+regex".to_string()
592    } else {
593        "regex-only".to_string()
594    };
595
596    Ok(ExtractionResult {
597        entities,
598        relationships,
599        extraction_method,
600    })
601}
602
603pub struct RegexExtractor;
604
605impl Extractor for RegexExtractor {
606    fn extract(&self, body: &str) -> Result<ExtractionResult> {
607        let regex_entities = apply_regex_prefilter(body);
608        let entities = to_new_entities(regex_entities);
609        let relationships = build_relationships(&entities);
610        Ok(ExtractionResult {
611            entities,
612            relationships,
613            extraction_method: "regex-only".to_string(),
614        })
615    }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds isolated `AppPaths` whose NER model directory already holds placeholder files.
    ///
    /// `ensure_model_files` finds all three required files and skips the Hugging Face download
    /// (~676 MB). `BertNerModel::load` then fails to parse the empty `config.json`, so
    /// `extract_graph_auto` deterministically falls back to regex-only extraction, offline.
    fn make_paths() -> AppPaths {
        use std::path::PathBuf;
        let root: PathBuf = std::env::temp_dir()
            .join(format!("sqlite_graphrag_extraction_test_{}", std::process::id()));
        let models = root.join("models");
        let dir = models.join("bert-multilingual-ner");
        std::fs::create_dir_all(&dir).expect("creating placeholder model dir");
        for file in &["model.safetensors", "config.json", "tokenizer.json"] {
            std::fs::write(dir.join(file), b"").expect("writing placeholder model file");
        }
        AppPaths {
            db: root.join("test.sqlite"),
            models,
        }
    }

    #[test]
    fn regex_email_captura_endereco() {
        let ents = apply_regex_prefilter("contato: fulano@empresa.com.br para mais info");
        // v1.0.20: emails are classified as "concept" (a regex alone cannot tell a person
        // from a role).
        assert!(ents
            .iter()
            .any(|e| e.name == "fulano@empresa.com.br" && e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_filtra_palavra_regra_pt() {
        // v1.0.20 P1 fix: NUNCA, PROIBIDO, DEVE must not become "entities".
        let ents = apply_regex_prefilter("NUNCA fazer isso. PROIBIDO usar X. DEVE seguir Y.");
        assert!(
            !ents.iter().any(|e| e.name == "NUNCA"),
            "NUNCA deveria ser filtrado como stopword"
        );
        assert!(
            !ents.iter().any(|e| e.name == "PROIBIDO"),
            "PROIBIDO deveria ser filtrado"
        );
        assert!(
            !ents.iter().any(|e| e.name == "DEVE"),
            "DEVE deveria ser filtrado"
        );
    }

    #[test]
    fn regex_all_caps_aceita_constante_com_underscore() {
        // Technical constants such as MAX_RETRY and TIMEOUT_MS must always be accepted.
        let ents = apply_regex_prefilter("configure MAX_RETRY=3 e API_TIMEOUT=30");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "API_TIMEOUT"));
    }

    #[test]
    fn regex_all_caps_aceita_acronimo_dominio() {
        // Legitimate (non-stopword) acronyms must pass: OPENAI, NVIDIA, GOOGLE.
        let ents = apply_regex_prefilter("OPENAI lançou GPT-5 com NVIDIA H100");
        assert!(ents.iter().any(|e| e.name == "OPENAI"));
        assert!(ents.iter().any(|e| e.name == "NVIDIA"));
    }

    #[test]
    fn regex_url_captura_link() {
        let ents = apply_regex_prefilter("veja https://docs.rs/crate para detalhes");
        assert!(ents
            .iter()
            .any(|e| e.name.starts_with("https://") && e.entity_type == "concept"));
    }

    #[test]
    fn regex_uuid_captura_identificador() {
        let ents = apply_regex_prefilter("id=550e8400-e29b-41d4-a716-446655440000 no sistema");
        assert!(ents.iter().any(|e| e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_captura_constante() {
        let ents = apply_regex_prefilter("configure MAX_RETRY e TIMEOUT_MS");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "TIMEOUT_MS"));
    }

    #[test]
    fn regex_all_caps_ignora_palavras_curtas() {
        let ents = apply_regex_prefilter("use AI em seu projeto");
        assert!(
            !ents.iter().any(|e| e.name == "AI"),
            "AI tem apenas 2 chars, deve ser ignorado"
        );
    }

    #[test]
    fn iob_decodifica_per_para_person() {
        let tokens = vec![
            "John".to_string(),
            "Doe".to_string(),
            "trabalhou".to_string(),
        ];
        let labels = vec!["B-PER".to_string(), "I-PER".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents.len(), 1);
        assert_eq!(ents[0].entity_type, "person");
        assert!(ents[0].name.contains("John"));
    }

    #[test]
    fn iob_strip_subword_b_prefix() {
        // v1.0.21 P0: BERT sometimes emits ##AI with a B- prefix (confused subword).
        // It must merge into the active entity instead of creating a phantom "##AI" entity.
        let tokens = vec!["Open".to_string(), "##AI".to_string()];
        let labels = vec!["B-ORG".to_string(), "B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.iter().any(|e| e.name == "OpenAI" || e.name == "Open"),
            "deveria mergear ##AI ou descartar"
        );
    }

    #[test]
    fn iob_subword_orphan_descarta() {
        // v1.0.21 P0: an orphan subword with no active entity must not become an entity.
        let tokens = vec!["##AI".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "subword órfão sem entidade ativa deve ser descartado"
        );
    }

    #[test]
    fn iob_descarta_date() {
        let tokens = vec!["Janeiro".to_string(), "2024".to_string()];
        let labels = vec!["B-DATE".to_string(), "I-DATE".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(ents.is_empty(), "DATE deve ser descartado");
    }

    #[test]
    fn iob_mapeia_org_para_project() {
        let tokens = vec!["Empresa".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "project");
    }

    #[test]
    fn iob_mapeia_org_sdk_para_tool() {
        let tokens = vec!["tokio-sdk".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "tool");
    }

    #[test]
    fn iob_mapeia_loc_para_concept() {
        let tokens = vec!["Brasil".to_string()];
        let labels = vec!["B-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "concept");
    }

    #[test]
    fn build_relationships_respeitam_max_rels() {
        let entities: Vec<NewEntity> = (0..20)
            .map(|i| NewEntity {
                name: format!("entidade_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let rels = build_relationships(&entities);
        assert!(rels.len() <= MAX_RELS, "deve respeitar MAX_RELS={MAX_RELS}");
    }

    #[test]
    fn build_relationships_sem_duplicatas() {
        let entities: Vec<NewEntity> = (0..5)
            .map(|i| NewEntity {
                name: format!("ent_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let rels = build_relationships(&entities);
        let mut pares: std::collections::HashSet<(String, String)> =
            std::collections::HashSet::new();
        for r in &rels {
            let par = (r.source.clone(), r.target.clone());
            assert!(pares.insert(par), "par duplicado encontrado");
        }
    }

    #[test]
    fn merge_deduplica_por_nome_lowercase() {
        let a = vec![ExtractedEntity {
            name: "Rust".to_string(),
            entity_type: "concept".to_string(),
        }];
        let b = vec![ExtractedEntity {
            name: "rust".to_string(),
            entity_type: "tool".to_string(),
        }];
        let merged = merge_and_deduplicate(a, b);
        assert_eq!(merged.len(), 1, "rust e Rust são a mesma entidade");
    }

    #[test]
    fn regex_extractor_implementa_trait() {
        let extractor = RegexExtractor;
        let result = extractor
            .extract("contato: dev@empresa.io e MAX_TIMEOUT configurado")
            .unwrap();
        assert!(!result.entities.is_empty());
    }

    #[test]
    fn extract_retorna_ok_sem_modelo() {
        // Without a usable model, extraction must return Ok with the regex entities only.
        // make_paths provides placeholder model files, so no network download is attempted.
        let paths = make_paths();
        let body = "contato: teste@exemplo.com com MAX_RETRY=3";
        let result = extract_graph_auto(body, &paths).unwrap();
        assert!(result
            .entities
            .iter()
            .any(|e| e.name.contains("teste@exemplo.com")));
        if let Some(root) = paths.db.parent() {
            let _ = std::fs::remove_dir_all(root);
        }
    }
}