//! sqlite_graphrag/extraction.rs — entity, relationship and URL extraction
//! (regex prefilter + BERT multilingual NER with graceful degradation).
1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::OnceLock;
4
5use anyhow::{Context, Result};
6use candle_core::{DType, Device, Tensor};
7use candle_nn::{Linear, Module, VarBuilder};
8use candle_transformers::models::bert::{BertModel, Config as BertConfig};
9use regex::Regex;
10use serde::Deserialize;
11use unicode_normalization::UnicodeNormalization;
12
13use crate::paths::AppPaths;
14use crate::storage::entities::{NewEntity, NewRelationship};
15
/// HuggingFace repo id for the multilingual NER model.
const MODEL_ID: &str = "Davlan/bert-base-multilingual-cased-ner-hrl";
/// Maximum tokens per inference window (BERT positional-embedding limit).
const MAX_SEQ_LEN: usize = 512;
/// Advance between consecutive sliding windows; < MAX_SEQ_LEN so windows overlap.
const STRIDE: usize = 256;
/// Upper bound on entities considered when building relationships.
const MAX_ENTS: usize = 30;
/// Maximum relationships emitted per source entity in build_relationships.
const TOP_K_RELATIONS: usize = 5;
/// Relation label used for all auto-generated co-occurrence relationships.
const DEFAULT_RELATION: &str = "mentions";
/// Minimum length (in bytes — `str::len`) for an extracted entity name.
const MIN_ENTITY_CHARS: usize = 2;

// Lazily-compiled, process-wide shared regexes (built on first use).
static REGEX_EMAIL: OnceLock<Regex> = OnceLock::new();
static REGEX_URL: OnceLock<Regex> = OnceLock::new();
static REGEX_UUID: OnceLock<Regex> = OnceLock::new();
static REGEX_ALL_CAPS: OnceLock<Regex> = OnceLock::new();
28
// v1.0.20: stopwords filtering common PT-BR/EN rule-words captured as ALL_CAPS.
// Without this filter, PT-BR technical corpora containing rules formatted in CAPS
// (NUNCA, PROIBIDO, DEVE) produced ~70% junk "entities". Identifiers such as
// MAX_RETRY (with underscore) are preserved.
// v1.0.22: list expanded with terms observed in a 495-file flowaiper stress test.
// Includes verbs (ADICIONAR, VALIDAR), adjectives (ALTA, BAIXA), common nouns
// (BANCO, CASO), HTTP methods (GET, POST, DELETE) and generic data formats (JSON, XML).
// v1.0.24: added 17 new terms observed in the v1.0.23 audit: generic status words
// (COMPLETED, DONE, FIXED, PENDING), PT-BR imperative verbs (ACEITE, CONFIRME, NEGUE,
// RECUSE), PT-BR modal/common verbs (DEVEMOS, PODEMOS, VAMOS), generic nouns (BORDA,
// CHECKLIST, PLAN, TOKEN), and common abbreviations (ACK, ACL).
const ALL_CAPS_STOPWORDS: &[&str] = &[
    "ACEITE",
    "ACK",
    "ACL",
    "ACRESCENTADO",
    "ADICIONAR",
    "AGENTS",
    "ALL",
    "ALTA",
    "ALWAYS",
    "ARTEFATOS",
    "ATIVO",
    "BAIXA",
    "BANCO",
    "BORDA",
    "BLOQUEAR",
    "BUG",
    "CASO",
    "CHECKLIST",
    "COMPLETED",
    "CONFIRMADO",
    "CONFIRME",
    "CONTRATO",
    "CRÍTICO",
    "CRITICAL",
    "CSV",
    "DEVE",
    "DEVEMOS",
    "DISCO",
    "DONE",
    "EFEITO",
    "ENTRADA",
    "ERROR",
    "ESSA",
    "ESSE",
    "ESSENCIAL",
    "ESTA",
    "ESTE",
    "EVITAR",
    "EXPANDIR",
    "EXPOR",
    "FALHA",
    "FIXED",
    "FIXME",
    "FORBIDDEN",
    "HACK",
    "HEARTBEAT",
    "INATIVO",
    "JAMAIS",
    "JSON",
    "MUST",
    "NEGUE",
    "NEVER",
    "NOTE",
    "NUNCA",
    "OBRIGATÓRIO",
    "PADRÃO",
    "PENDING",
    "PLAN",
    "PODEMOS",
    "PROIBIDO",
    "RECUSE",
    "REGRAS",
    "REQUIRED",
    "REQUISITO",
    "SEMPRE",
    "SHALL",
    "SHOULD",
    "SOUL",
    "TODAS",
    "TODO",
    "TODOS",
    "TOKEN",
    "TOOLS",
    "TSV",
    "USAR",
    "VALIDAR",
    "VAMOS",
    "VOCÊ",
    "WARNING",
    "XML",
    "YAML",
];

// v1.0.22: HTTP methods are protocol verbs, not semantically useful entities.
// Filtered in apply_regex_prefilter (regex_all_caps) and iob_to_entities (single-token).
const HTTP_METHODS: &[&str] = &[
    "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "CONNECT", "TRACE",
];
128
129fn is_filtered_all_caps(token: &str) -> bool {
130    // Identificadores com underscore são preservados (ex: MAX_RETRY, FLOWAIPER_API_KEY)
131    let is_identifier = token.contains('_');
132    if is_identifier {
133        return false;
134    }
135    ALL_CAPS_STOPWORDS.contains(&token) || HTTP_METHODS.contains(&token)
136}
137
138fn regex_email() -> &'static Regex {
139    REGEX_EMAIL
140        .get_or_init(|| Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}").unwrap())
141}
142
143fn regex_url() -> &'static Regex {
144    REGEX_URL.get_or_init(|| Regex::new(r#"https?://[^\s\)\]\}"'<>]+"#).unwrap())
145}
146
147fn regex_uuid() -> &'static Regex {
148    REGEX_UUID.get_or_init(|| {
149        Regex::new(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}")
150            .unwrap()
151    })
152}
153
154fn regex_all_caps() -> &'static Regex {
155    REGEX_ALL_CAPS.get_or_init(|| Regex::new(r"\b[A-Z][A-Z0-9_]{2,}\b").unwrap())
156}
157
/// A named entity produced by the regex prefilter or BERT NER, prior to
/// conversion into a storage-layer `NewEntity`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    // Surface form as found in the body (possibly suffix-extended later).
    pub name: String,
    // Category such as "person", "project", "tool", "concept" — or a raw NER
    // tag passed through unchanged by iob_to_entities for unmapped types.
    pub entity_type: String,
}
163
/// A URL extracted from a memory body, together with its source offset.
#[derive(Debug, Clone)]
pub struct ExtractedUrl {
    pub url: String,
    /// Byte position in the body where the URL was found.
    pub offset: usize,
}
171
/// Aggregate output of one extraction pass over a memory body.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    pub entities: Vec<NewEntity>,
    pub relationships: Vec<NewRelationship>,
    /// True when build_relationships hit the cap before covering all entity pairs.
    /// Exposed in RememberResponse so callers can detect when relationships were cut.
    pub relationships_truncated: bool,
    /// Extraction method used: "bert+regex" or "regex-only".
    /// Useful for auditing, metrics and user-facing reports.
    pub extraction_method: String,
    /// URLs extracted from the body — stored separately from the entity graph.
    pub urls: Vec<ExtractedUrl>,
}
185
/// Abstraction over extraction backends; implementors must be thread-safe
/// (`Send + Sync`) so a single instance can be shared across requests.
pub trait Extractor: Send + Sync {
    /// Extracts entities, relationships and URLs from a memory body.
    fn extract(&self, body: &str) -> Result<ExtractionResult>;
}
189
/// Minimal subset of the HuggingFace `config.json` needed to size the classifier head.
#[derive(Deserialize)]
struct ModelConfig {
    // Label-id → IOB tag map; keys are stringified integers in the JSON.
    // Defaults to empty when the field is absent.
    #[serde(default)]
    id2label: HashMap<String, String>,
    // Transformer hidden dimension; sizes the classifier weight matrix.
    hidden_size: usize,
}
196
/// BERT encoder plus token-classification head for NER, pinned to one device.
struct BertNerModel {
    bert: BertModel,
    classifier: Linear,
    device: Device,
    // Label index → IOB tag (e.g. "B-PER"), parsed from config.json.
    id2label: HashMap<usize, String>,
}
203
impl BertNerModel {
    /// Loads the BERT encoder and token-classification head from `model_dir`,
    /// which must contain `config.json` and `model.safetensors`.
    fn load(model_dir: &Path) -> Result<Self> {
        let config_path = model_dir.join("config.json");
        let weights_path = model_dir.join("model.safetensors");

        let config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("lendo config.json em {config_path:?}"))?;
        let model_cfg: ModelConfig =
            serde_json::from_str(&config_str).context("parseando config.json do modelo NER")?;

        // config.json serialises id2label keys as strings; keep only those that
        // parse as numeric indices.
        let id2label: HashMap<usize, String> = model_cfg
            .id2label
            .into_iter()
            .filter_map(|(k, v)| k.parse::<usize>().ok().map(|n| (n, v)))
            .collect();

        // Fallback of 9 when id2label is missing/empty — presumably the standard
        // IOB2 set (O + B/I × PER/ORG/LOC/MISC); TODO confirm against the repo config.
        let num_labels = id2label.len().max(9);
        let hidden_size = model_cfg.hidden_size;

        // Same file parsed twice: once for our minimal ModelConfig above, once
        // for candle's full BertConfig here.
        let bert_config_str = std::fs::read_to_string(&config_path)
            .with_context(|| format!("relendo config.json para bert em {config_path:?}"))?;
        let bert_cfg: BertConfig =
            serde_json::from_str(&bert_config_str).context("parseando BertConfig")?;

        let device = Device::Cpu;

        // SAFETY: memory-maps the safetensors file; sound provided the file is
        // not truncated or rewritten while mapped.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &device)
                .with_context(|| format!("mapeando {weights_path:?}"))?
        };
        let bert = BertModel::load(vb.pp("bert"), &bert_cfg).context("carregando BertModel")?;

        // v1.0.20 secondary P0 fix: load the classifier head from the safetensors
        // instead of zeros. In v1.0.19 we used Tensor::zeros, which produced a
        // constant argmax and degenerate inference.
        let cls_vb = vb.pp("classifier");
        let weight = cls_vb
            .get((num_labels, hidden_size), "weight")
            .context("carregando classifier.weight do safetensors")?;
        let bias = cls_vb
            .get(num_labels, "bias")
            .context("carregando classifier.bias do safetensors")?;
        let classifier = Linear::new(weight, Some(bias));

        Ok(Self {
            bert,
            classifier,
            device,
            id2label,
        })
    }

    /// Runs a single-window forward pass and returns one IOB label per token.
    /// Label indices absent from `id2label` degrade to "O".
    fn predict(&self, token_ids: &[u32], attention_mask: &[u32]) -> Result<Vec<String>> {
        let len = token_ids.len();
        // candle's BERT expects i64 tensors; widen from u32.
        let ids_i64: Vec<i64> = token_ids.iter().map(|&x| x as i64).collect();
        let mask_i64: Vec<i64> = attention_mask.iter().map(|&x| x as i64).collect();

        let input_ids = Tensor::from_vec(ids_i64, (1, len), &self.device)
            .context("criando tensor input_ids")?;
        // Single-segment input: token_type_ids are all zeros.
        let token_type_ids = Tensor::zeros((1, len), DType::I64, &self.device)
            .context("criando tensor token_type_ids")?;
        let attn_mask = Tensor::from_vec(mask_i64, (1, len), &self.device)
            .context("criando tensor attention_mask")?;

        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("forward pass do BertModel")?;

        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("forward pass do classificador")?;

        // Drop the batch dimension: (1, len, num_labels) → (len, num_labels).
        let logits_2d = logits.squeeze(0).context("removendo dimensão batch")?;

        let num_tokens = logits_2d.dim(0).context("dim(0)")?;

        // Greedy decode: per-token argmax over label logits.
        let mut labels = Vec::with_capacity(num_tokens);
        for i in 0..num_tokens {
            let token_logits = logits_2d.get(i).context("get token logits")?;
            let vec: Vec<f32> = token_logits.to_vec1().context("to_vec1 logits")?;
            // NOTE(review): partial_cmp().unwrap() panics on NaN logits — assumed unreachable.
            let argmax = vec
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            let label = self
                .id2label
                .get(&argmax)
                .cloned()
                .unwrap_or_else(|| "O".to_string());
            labels.push(label);
        }

        Ok(labels)
    }

    /// Run a batched forward pass over multiple tokenised windows at once.
    ///
    /// Windows are padded on the right with token_id=0 and attention_mask=0 to
    /// the length of the longest window in the batch.  The attention mask ensures
    /// BERT ignores padded positions (bert.rs:515-528 adds -3.4e38 before softmax).
    ///
    /// Returns one label vector per window, each of length equal to that window's
    /// original (pre-padding) token count.
    fn predict_batch(&self, windows: &[(Vec<u32>, Vec<String>)]) -> Result<Vec<Vec<String>>> {
        let batch_size = windows.len();
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        if max_len == 0 {
            // Empty batch (or all windows empty): nothing to infer.
            return Ok(vec![vec![]; batch_size]);
        }

        let mut padded_ids: Vec<Tensor> = Vec::with_capacity(batch_size);
        let mut padded_masks: Vec<Tensor> = Vec::with_capacity(batch_size);

        for (ids, _) in windows {
            let len = ids.len();
            let pad_right = max_len - len;

            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            // Build 1-D token tensor then pad to max_len
            let t = Tensor::from_vec(ids_i64, len, &self.device)
                .context("criando tensor de ids para batch")?;
            let t = t
                .pad_with_zeros(0, 0, pad_right)
                .context("padding tensor de ids")?;
            padded_ids.push(t);

            // Attention mask: 1 for real tokens, 0 for padding
            let mut mask_i64 = vec![1i64; len];
            mask_i64.extend(vec![0i64; pad_right]);
            let m = Tensor::from_vec(mask_i64, max_len, &self.device)
                .context("criando tensor de máscara para batch")?;
            padded_masks.push(m);
        }

        // Stack 1-D tensors into (batch_size, max_len)
        let input_ids = Tensor::stack(&padded_ids, 0).context("stack input_ids")?;
        let attn_mask = Tensor::stack(&padded_masks, 0).context("stack attn_mask")?;
        let token_type_ids = Tensor::zeros((batch_size, max_len), DType::I64, &self.device)
            .context("criando token_type_ids batch")?;

        // Single forward pass for the entire batch
        let sequence_output = self
            .bert
            .forward(&input_ids, &token_type_ids, Some(&attn_mask))
            .context("forward pass batch BertModel")?;
        // sequence_output: (batch_size, max_len, hidden_size)

        let logits = self
            .classifier
            .forward(&sequence_output)
            .context("forward pass batch classificador")?;
        // logits: (batch_size, max_len, num_labels)

        let mut results = Vec::with_capacity(batch_size);
        for (i, (window_ids, _)) in windows.iter().enumerate() {
            let example_logits = logits.get(i).context("get logits exemplo")?;
            // (max_len, num_labels) — slice only real tokens, discard padding
            let real_len = window_ids.len();
            let example_slice = example_logits
                .narrow(0, 0, real_len)
                .context("narrow para tokens reais")?;
            let logits_2d: Vec<Vec<f32>> = example_slice.to_vec2().context("to_vec2 logits")?;

            // Greedy decode per token, mirroring `predict`.
            let labels: Vec<String> = logits_2d
                .iter()
                .map(|token_logits| {
                    let argmax = token_logits
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                        .map(|(idx, _)| idx)
                        .unwrap_or(0);
                    self.id2label
                        .get(&argmax)
                        .cloned()
                        .unwrap_or_else(|| "O".to_string())
                })
                .collect();

            results.push(labels);
        }

        Ok(results)
    }
}
392
393static NER_MODEL: OnceLock<Option<BertNerModel>> = OnceLock::new();
394
395fn get_or_init_model(paths: &AppPaths) -> Option<&'static BertNerModel> {
396    NER_MODEL
397        .get_or_init(|| match load_model(paths) {
398            Ok(m) => Some(m),
399            Err(e) => {
400                tracing::warn!("NER model não disponível (graceful degradation): {e:#}");
401                None
402            }
403        })
404        .as_ref()
405}
406
407fn model_dir(paths: &AppPaths) -> PathBuf {
408    paths.models.join("bert-multilingual-ner")
409}
410
411fn ensure_model_files(paths: &AppPaths) -> Result<PathBuf> {
412    let dir = model_dir(paths);
413    std::fs::create_dir_all(&dir)
414        .with_context(|| format!("criando diretório do modelo: {dir:?}"))?;
415
416    let weights = dir.join("model.safetensors");
417    let config = dir.join("config.json");
418    let tokenizer = dir.join("tokenizer.json");
419
420    if weights.exists() && config.exists() && tokenizer.exists() {
421        return Ok(dir);
422    }
423
424    tracing::info!("Baixando modelo NER (primeira execução, ~676 MB)...");
425    crate::output::emit_progress_i18n(
426        "Downloading NER model (first run, ~676 MB)...",
427        "Baixando modelo NER (primeira execução, ~676 MB)...",
428    );
429
430    let api = huggingface_hub::api::sync::Api::new().context("criando cliente HF Hub")?;
431    let repo = api.model(MODEL_ID.to_string());
432
433    // v1.0.20 fix P0 primário: tokenizer.json no repo Davlan está apenas em onnx/tokenizer.json.
434    // Em v1.0.19 buscávamos da raiz e recebíamos 404, caindo em graceful degradation 100% das vezes.
435    // Mapeamos (remote_path, local_filename) para baixar do subfolder mantendo nome plano local.
436    for (remote, local) in &[
437        ("model.safetensors", "model.safetensors"),
438        ("config.json", "config.json"),
439        ("onnx/tokenizer.json", "tokenizer.json"),
440        ("tokenizer_config.json", "tokenizer_config.json"),
441    ] {
442        let dest = dir.join(local);
443        if !dest.exists() {
444            let src = repo
445                .get(remote)
446                .with_context(|| format!("baixando {remote} do HF Hub"))?;
447            std::fs::copy(&src, &dest).with_context(|| format!("copiando {local} para cache"))?;
448        }
449    }
450
451    Ok(dir)
452}
453
454fn load_model(paths: &AppPaths) -> Result<BertNerModel> {
455    let dir = ensure_model_files(paths)?;
456    BertNerModel::load(&dir)
457}
458
459fn apply_regex_prefilter(body: &str) -> Vec<ExtractedEntity> {
460    let mut entities = Vec::new();
461    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
462
463    let add = |entities: &mut Vec<ExtractedEntity>,
464               seen: &mut std::collections::HashSet<String>,
465               name: &str,
466               entity_type: &str| {
467        let name = name.trim().to_string();
468        if name.len() >= MIN_ENTITY_CHARS && seen.insert(name.clone()) {
469            entities.push(ExtractedEntity {
470                name,
471                entity_type: entity_type.to_string(),
472            });
473        }
474    };
475
476    for m in regex_email().find_iter(body) {
477        // v1.0.20: email é "concept" (regex sozinho não distingue pessoa de mailing list/role).
478        add(&mut entities, &mut seen, m.as_str(), "concept");
479    }
480    for m in regex_uuid().find_iter(body) {
481        add(&mut entities, &mut seen, m.as_str(), "concept");
482    }
483    for m in regex_all_caps().find_iter(body) {
484        let candidate = m.as_str();
485        // v1.0.22: filtro consolidado (stopwords + HTTP methods); preserva identificadores com underscore.
486        if !is_filtered_all_caps(candidate) {
487            add(&mut entities, &mut seen, candidate, "concept");
488        }
489    }
490
491    entities
492}
493
494/// Extrai URLs do corpo de uma memória, desduplicadas por texto.
495/// URLs são armazenadas na tabela `memory_urls` separadamente do grafo de entidades.
496/// v1.0.24: split do bloco URL que poluía apply_regex_prefilter com entity_type='concept'.
497pub fn extract_urls(body: &str) -> Vec<ExtractedUrl> {
498    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
499    let mut result = Vec::new();
500    for m in regex_url().find_iter(body) {
501        let raw = m.as_str();
502        let cleaned = raw
503            .trim_end_matches('`')
504            .trim_end_matches(',')
505            .trim_end_matches('.')
506            .trim_end_matches(';')
507            .trim_end_matches(')')
508            .trim_end_matches(']')
509            .trim_end_matches('}');
510        if cleaned.len() >= 10 && seen.insert(cleaned.to_string()) {
511            result.push(ExtractedUrl {
512                url: cleaned.to_string(),
513                offset: m.start(),
514            });
515        }
516    }
517    result
518}
519
/// Collapses per-token IOB labels into entities, merging WordPiece subwords
/// ("##"-prefixed tokens) back into their parent word.
///
/// Label handling: "O" and unknown prefixes flush the entity in progress;
/// DATE spans are dropped entirely; PER → "person", LOC → "concept",
/// ORG → "tool" or "project" depending on token hints; any other tag is
/// passed through unchanged as the entity type.
fn iob_to_entities(tokens: &[String], labels: &[String]) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    let mut current_parts: Vec<String> = Vec::new();
    let mut current_type: Option<String> = None;

    // Finalises the in-progress entity (if any): joins parts with spaces and
    // applies length and single-token ALL-CAPS post-filters.
    let flush =
        |parts: &mut Vec<String>, typ: &mut Option<String>, entities: &mut Vec<ExtractedEntity>| {
            if let Some(t) = typ.take() {
                let name = parts.join(" ").trim().to_string();
                // v1.0.22: filter single-token entities that are ALL-CAPS stopwords
                // or HTTP methods. BERT NER tags some of these as B-MISC/B-ORG;
                // this post-filter keeps generic verbs/protocol words out of the graph.
                let is_single_caps = !name.contains(' ')
                    && name == name.to_uppercase()
                    && name.len() >= MIN_ENTITY_CHARS;
                let should_skip = is_single_caps && is_filtered_all_caps(&name);
                if name.len() >= MIN_ENTITY_CHARS && !should_skip {
                    entities.push(ExtractedEntity {
                        name,
                        entity_type: t,
                    });
                }
                parts.clear();
            }
        };

    for (token, label) in tokens.iter().zip(labels.iter()) {
        if label == "O" {
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        }

        let (prefix, bio_type) = if let Some(rest) = label.strip_prefix("B-") {
            ("B", rest)
        } else if let Some(rest) = label.strip_prefix("I-") {
            ("I", rest)
        } else {
            // Unrecognised label shape: treat it as an entity boundary.
            flush(&mut current_parts, &mut current_type, &mut entities);
            continue;
        };

        // Map NER tag names onto the graph's entity-type vocabulary.
        let entity_type = match bio_type {
            "DATE" => {
                // Dates are not useful graph nodes; drop the span entirely.
                flush(&mut current_parts, &mut current_type, &mut entities);
                continue;
            }
            "PER" => "person",
            "ORG" => {
                // Heuristic: ORG tokens that look like software artefacts become "tool".
                let t = token.to_lowercase();
                if t.contains("lib")
                    || t.contains("sdk")
                    || t.contains("cli")
                    || t.contains("crate")
                    || t.contains("npm")
                {
                    "tool"
                } else {
                    "project"
                }
            }
            "LOC" => "concept",
            other => other,
        };

        if prefix == "B" {
            if token.starts_with("##") {
                // Confused BERT output: a subword carrying a B- prefix marks the
                // continuation of the previous entity. Append it to the current
                // entity's last part; if there is none, drop the subword.
                let clean = token.strip_prefix("##").unwrap_or(token.as_str());
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
                continue;
            }
            flush(&mut current_parts, &mut current_type, &mut entities);
            current_parts.push(token.clone());
            current_type = Some(entity_type.to_string());
        } else if prefix == "I" && current_type.is_some() {
            // Continuation token: subwords are glued onto the previous part,
            // whole tokens become a new space-separated part.
            let clean = token.strip_prefix("##").unwrap_or(token.as_str());
            if token.starts_with("##") {
                if let Some(last) = current_parts.last_mut() {
                    last.push_str(clean);
                }
            } else {
                current_parts.push(clean.to_string());
            }
        }
    }

    // Flush whatever entity was still open at end of input.
    flush(&mut current_parts, &mut current_type, &mut entities);
    entities
}
612
613/// Returns (relationships, truncated) where truncated is true when the cap was hit
614/// before all entity pairs were covered. Exposed in RememberResponse as
615/// `relationships_truncated` so callers can decide whether to increase the cap.
616fn build_relationships(entities: &[NewEntity]) -> (Vec<NewRelationship>, bool) {
617    if entities.len() < 2 {
618        return (Vec::new(), false);
619    }
620
621    // v1.0.22: cap configurável via env var (constants::max_relationships_per_memory).
622    // Permite usuários com corpus denso aumentar além do default 50.
623    let max_rels = crate::constants::max_relationships_per_memory();
624    let n = entities.len().min(MAX_ENTS);
625    let mut rels: Vec<NewRelationship> = Vec::new();
626    let mut seen: std::collections::HashSet<(String, String)> = std::collections::HashSet::new();
627
628    let mut hit_cap = false;
629    'outer: for i in 0..n {
630        if rels.len() >= max_rels {
631            hit_cap = true;
632            break;
633        }
634
635        let mut for_entity = 0usize;
636        for j in (i + 1)..n {
637            if for_entity >= TOP_K_RELATIONS {
638                break;
639            }
640            if rels.len() >= max_rels {
641                hit_cap = true;
642                break 'outer;
643            }
644
645            let src = &entities[i].name;
646            let tgt = &entities[j].name;
647            let key = (src.clone(), tgt.clone());
648
649            if seen.contains(&key) {
650                continue;
651            }
652            seen.insert(key);
653
654            rels.push(NewRelationship {
655                source: src.clone(),
656                target: tgt.clone(),
657                relation: DEFAULT_RELATION.to_string(),
658                strength: 0.5,
659                description: None,
660            });
661            for_entity += 1;
662        }
663    }
664
665    // v1.0.20: avisar quando relacionamentos foram truncados antes de cobrir todos os pares possíveis.
666    if hit_cap {
667        tracing::warn!(
668            "relacionamentos truncados em {max_rels} (com {n} entidades, máx teórico era ~{}× combinações)",
669            n.saturating_sub(1)
670        );
671    }
672
673    (rels, hit_cap)
674}
675
/// Tokenises `body` and runs NER over overlapping sliding windows
/// (MAX_SEQ_LEN tokens each, advancing by STRIDE), returning entities
/// deduplicated by exact name — first occurrence wins.
fn run_ner_sliding_window(
    model: &BertNerModel,
    body: &str,
    paths: &AppPaths,
) -> Result<Vec<ExtractedEntity>> {
    let tokenizer_path = model_dir(paths).join("tokenizer.json");
    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("carregando tokenizer NER: {e}"))?;

    // encode(..., false): no special tokens added — we need raw per-token labels.
    let encoding = tokenizer
        .encode(body, false)
        .map_err(|e| anyhow::anyhow!("encoding NER: {e}"))?;

    let all_ids: Vec<u32> = encoding.get_ids().to_vec();
    let all_tokens: Vec<String> = encoding
        .get_tokens()
        .iter()
        .map(|s| s.to_string())
        .collect();

    if all_ids.is_empty() {
        return Ok(Vec::new());
    }

    // Phase 1: collect all sliding windows before any inference
    let mut windows: Vec<(Vec<u32>, Vec<String>)> = Vec::new();
    let mut start = 0usize;
    loop {
        let end = (start + MAX_SEQ_LEN).min(all_ids.len());
        windows.push((
            all_ids[start..end].to_vec(),
            all_tokens[start..end].to_vec(),
        ));
        if end >= all_ids.len() {
            break;
        }
        start += STRIDE;
    }

    // Phase 2: sort by window length ascending to minimise intra-batch padding waste
    windows.sort_by_key(|(ids, _)| ids.len());

    // Phase 3: batched inference with fallback to single-window predict on error
    let batch_size = crate::constants::ner_batch_size();
    let mut entities: Vec<ExtractedEntity> = Vec::new();
    // Windows overlap (STRIDE < MAX_SEQ_LEN), so the same entity can surface in
    // several windows; dedupe by exact name.
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    for chunk in windows.chunks(batch_size) {
        match model.predict_batch(chunk) {
            Ok(batch_labels) => {
                for (labels, (_, tokens)) in batch_labels.iter().zip(chunk.iter()) {
                    for ent in iob_to_entities(tokens, labels) {
                        if seen.insert(ent.name.clone()) {
                            entities.push(ent);
                        }
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    "batch NER falhou (chunk de {} janelas): {e:#} — fallback single-window",
                    chunk.len()
                );
                // Fallback: process each window individually to preserve entities
                for (ids, tokens) in chunk {
                    let mask = vec![1u32; ids.len()];
                    match model.predict(ids, &mask) {
                        Ok(labels) => {
                            for ent in iob_to_entities(tokens, &labels) {
                                if seen.insert(ent.name.clone()) {
                                    entities.push(ent);
                                }
                            }
                        }
                        Err(e2) => {
                            // Best effort: log and continue with remaining windows.
                            tracing::warn!("janela NER fallback também falhou: {e2:#}");
                        }
                    }
                }
            }
        }
    }

    Ok(entities)
}
761
762/// v1.0.22 P1: estende entidades com sufixos numéricos hifenizados ou separados por espaço.
763/// Casos: GPT extraído mas body contém "GPT-5" → reescreve para "GPT-5".
764/// Casos: Claude extraído mas body contém "Claude 4" → reescreve para "Claude 4".
765/// Conservador: só estende se sufixo tiver até 7 caracteres.
766/// v1.0.24 P2-E: sufixo aceita letra ASCII minúscula opcional após dígitos para cobrir
767/// modelos como "GPT-4o", "Llama-5b", "Mistral-8x" (dígitos + [a-z]? + [x\d+]?).
768fn extend_with_numeric_suffix(entities: Vec<ExtractedEntity>, body: &str) -> Vec<ExtractedEntity> {
769    static SUFFIX_RE: OnceLock<Regex> = OnceLock::new();
770    // Matches: separator + digits + optional decimal + optional lowercase letter
771    // Examples: "-4", " 5", "-4o", " 5b", "-8x", " 3.5", "-3.5-turbo" (capped by len)
772    let suffix_re = SUFFIX_RE.get_or_init(|| Regex::new(r"^([\-\s]+\d+(?:\.\d+)?[a-z]?)").unwrap());
773
774    entities
775        .into_iter()
776        .map(|ent| {
777            // Encontra a primeira ocorrência case-sensitive da entidade no body
778            if let Some(pos) = body.find(&ent.name) {
779                let after_pos = pos + ent.name.len();
780                if after_pos < body.len() {
781                    let after = &body[after_pos..];
782                    if let Some(m) = suffix_re.find(after) {
783                        let suffix = m.as_str();
784                        // Conservative: cap suffix length to 7 chars to avoid grabbing
785                        // long hyphenated phrases while allowing "4o", "5b", "3.5b".
786                        if suffix.len() <= 7 {
787                            let extended = format!("{}{}", ent.name, suffix);
788                            return ExtractedEntity {
789                                name: extended,
790                                entity_type: ent.entity_type,
791                            };
792                        }
793                    }
794                }
795            }
796            ent
797        })
798        .collect()
799}
800
801/// Captures versioned model names that BERT NER consistently misses.
802///
803/// BERT NER often classifies tokens like "Claude" or "Llama" as common nouns,
804/// failing to emit a B-PER/B-ORG tag. As a result, `extend_with_numeric_suffix`
805/// never sees these candidates and the version suffix gets lost.
806///
807/// This function scans the body with a conservative regex, matching capitalised
808/// words followed by a space-or-hyphen and a small integer. Matches that are not
809/// already covered by an existing entity (case-insensitive) are appended with the
810/// `concept` type, mirroring how `extend_with_numeric_suffix` represents these
811/// items downstream.
812///
813/// v1.0.24 P2-D: regex extended to cover:
814/// - Alphanumeric version suffixes: "GPT-4o", "Llama-3b", "Mistral-8x"
815/// - Composite versions: "Mixtral 8x7B" (digit × digit + uppercase letter)
816/// - Named release tiers after version: "Claude 4 Sonnet", "Llama 3 Pro"
817///
818/// Examples covered: "Claude 4", "Llama 3", "GPT-4o", "Claude 4 Sonnet", "Mixtral 8x7B".
819/// Examples already handled upstream and skipped here: plain "Apple" without a suffix.
820fn augment_versioned_model_names(
821    entities: Vec<ExtractedEntity>,
822    body: &str,
823) -> Vec<ExtractedEntity> {
824    static VERSIONED_MODEL_RE: OnceLock<Regex> = OnceLock::new();
825    // Pattern breakdown:
826    //   [A-Z][A-Za-z]{2,15}   — capitalised model name (3-16 chars)
827    //   [\s\-]+               — separator: space(s) or hyphen(s)
828    //   \d+(?:\.\d+)?         — version number, optional decimal
829    //   (?:[a-z]|x\d+[A-Za-z]?)? — optional alphanumeric suffix: "o", "b", "x7B"
830    //   (?:\s+(?:Sonnet|Opus|Haiku|Turbo|Pro|Lite|Mini|Nano|Flash|Ultra))? — optional release tier
831    let model_re = VERSIONED_MODEL_RE.get_or_init(|| {
832        Regex::new(
833            r"\b([A-Z][A-Za-z]{2,15})[\s\-]+(\d+(?:\.\d+)?(?:[a-z]|x\d+[A-Za-z]?)?)(?:\s+(?:Sonnet|Opus|Haiku|Turbo|Pro|Lite|Mini|Nano|Flash|Ultra))?\b",
834        )
835        .unwrap()
836    });
837
838    let mut existing_lc: std::collections::HashSet<String> =
839        entities.iter().map(|ent| ent.name.to_lowercase()).collect();
840    let mut result = entities;
841
842    for caps in model_re.captures_iter(body) {
843        let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
844        // Conservative cap: avoid harvesting multi-word noise like "section 12" inside
845        // long passages. A model name plus a one or two digit suffix fits in 24 chars.
846        if full_match.is_empty() || full_match.len() > 24 {
847            continue;
848        }
849        let normalized_lc = full_match.to_lowercase();
850        if existing_lc.contains(&normalized_lc) {
851            continue;
852        }
853        // Stop appending once the global entity cap is reached to keep parity with
854        // `merge_and_deduplicate` truncation semantics.
855        if result.len() >= MAX_ENTS {
856            break;
857        }
858        existing_lc.insert(normalized_lc);
859        result.push(ExtractedEntity {
860            name: full_match.to_string(),
861            entity_type: "concept".to_string(),
862        });
863    }
864
865    result
866}
867
868fn merge_and_deduplicate(
869    regex_ents: Vec<ExtractedEntity>,
870    ner_ents: Vec<ExtractedEntity>,
871) -> Vec<ExtractedEntity> {
872    // v1.0.23: when multiple sources produce overlapping names ("Open" from BERT
873    // subword leak vs "OpenAI" from regex), prefer the longest candidate. The
874    // previous implementation used a HashSet and kept whichever name appeared
875    // first, occasionally yielding truncated brand names like "Open" instead of
876    // "OpenAI". The new logic resolves collisions using a (lowercase prefix) lookup
877    // that retains the longest match while preserving insertion order via `result`.
878    // v1.0.24: dedup key uses NFKC normalization before lowercasing so that
879    // visually identical names differing only in Unicode combining marks (e.g.
880    // "Café" NFC vs "Cafe\u{301}" NFD) collapse to the same bucket.
881    let mut by_lc: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
882    let mut result: Vec<ExtractedEntity> = Vec::new();
883    let mut truncated = false;
884
885    let total_input = regex_ents.len() + ner_ents.len();
886    for ent in regex_ents.into_iter().chain(ner_ents) {
887        let key = ent.name.nfkc().collect::<String>().to_lowercase();
888        // Detect prefix collisions in both directions: "open" vs "openai" should
889        // both map to the longest stored candidate. We scan stored keys to find
890        // the longest existing entry that contains or is contained by the new key.
891        let mut collision_idx: Option<usize> = None;
892        for (existing_key, idx) in &by_lc {
893            if existing_key == &key
894                || existing_key.starts_with(&key)
895                || key.starts_with(existing_key)
896            {
897                collision_idx = Some(*idx);
898                break;
899            }
900        }
901        match collision_idx {
902            Some(idx) => {
903                // Replace stored entity only when the new candidate is strictly
904                // longer; otherwise drop the new one. This biases toward the most
905                // specific brand name visible in the corpus.
906                if ent.name.len() > result[idx].name.len() {
907                    let old_key = result[idx].name.nfkc().collect::<String>().to_lowercase();
908                    by_lc.remove(&old_key);
909                    result[idx] = ent;
910                    by_lc.insert(key, idx);
911                }
912            }
913            None => {
914                by_lc.insert(key, result.len());
915                result.push(ent);
916            }
917        }
918        if result.len() >= MAX_ENTS {
919            truncated = true;
920            break;
921        }
922    }
923
924    // v1.0.20: avisar quando truncamento silencioso descarta entidades acima do MAX_ENTS.
925    if truncated {
926        tracing::warn!(
927            "extração truncada em {MAX_ENTS} entidades (entrada tinha {total_input} candidatos antes da deduplicação)"
928        );
929    }
930
931    result
932}
933
934fn to_new_entities(extracted: Vec<ExtractedEntity>) -> Vec<NewEntity> {
935    extracted
936        .into_iter()
937        .map(|e| NewEntity {
938            name: e.name,
939            entity_type: e.entity_type,
940            description: None,
941        })
942        .collect()
943}
944
945pub fn extract_graph_auto(body: &str, paths: &AppPaths) -> Result<ExtractionResult> {
946    let regex_entities = apply_regex_prefilter(body);
947
948    let mut bert_used = false;
949    let ner_entities = match get_or_init_model(paths) {
950        Some(model) => match run_ner_sliding_window(model, body, paths) {
951            Ok(ents) => {
952                bert_used = true;
953                ents
954            }
955            Err(e) => {
956                tracing::warn!("NER falhou, usando apenas regex: {e:#}");
957                Vec::new()
958            }
959        },
960        None => Vec::new(),
961    };
962
963    let merged = merge_and_deduplicate(regex_entities, ner_entities);
964    // v1.0.22: estender entidades NER com sufixos numéricos do body (GPT-5, Claude 4, Python 3).
965    let extended = extend_with_numeric_suffix(merged, body);
966    // v1.0.23: capture versioned model names that BERT NER does not detect on its own
967    // (e.g. "Claude 4", "Llama 3"). Hyphenated variants like "GPT-5" are already covered
968    // by the NER+suffix pipeline above, but space-separated names need a dedicated pass.
969    let with_models = augment_versioned_model_names(extended, body);
970    let entities = to_new_entities(with_models);
971    let (relationships, relationships_truncated) = build_relationships(&entities);
972
973    let extraction_method = if bert_used {
974        "bert+regex-batch".to_string()
975    } else {
976        "regex-only".to_string()
977    };
978
979    let urls = extract_urls(body);
980
981    Ok(ExtractionResult {
982        entities,
983        relationships,
984        relationships_truncated,
985        extraction_method,
986        urls,
987    })
988}
989
/// Stateless marker type for the regex-only extraction backend. Its
/// `Extractor` impl runs only `apply_regex_prefilter` and never touches the
/// BERT NER model.
pub struct RegexExtractor;
991
992impl Extractor for RegexExtractor {
993    fn extract(&self, body: &str) -> Result<ExtractionResult> {
994        let regex_entities = apply_regex_prefilter(body);
995        let entities = to_new_entities(regex_entities);
996        let (relationships, relationships_truncated) = build_relationships(&entities);
997        let urls = extract_urls(body);
998        Ok(ExtractionResult {
999            entities,
1000            relationships,
1001            relationships_truncated,
1002            extraction_method: "regex-only".to_string(),
1003            urls,
1004        })
1005    }
1006}
1007
#[cfg(test)]
mod tests {
    use super::*;

    // Review fix: `cargo test` runs tests on parallel threads by default, so
    // the two `ner_batch_size_*` tests below raced on the process-global env
    // var GRAPHRAG_NER_BATCH_SIZE — one could observe the other's
    // set_var/remove_var mid-assertion and fail intermittently. This mutex
    // serializes every test that mutates env vars. A poisoned lock is
    // recovered (`into_inner`) so one failing test does not cascade.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    fn make_paths() -> AppPaths {
        use std::path::PathBuf;
        AppPaths {
            db: PathBuf::from("/tmp/test.sqlite"),
            models: PathBuf::from("/tmp/test_models"),
        }
    }

    #[test]
    fn regex_email_captura_endereco() {
        let ents = apply_regex_prefilter("contato: fulano@empresa.com.br para mais info");
        // v1.0.20: emails são classificados como "concept" (regex sozinho não distingue pessoa de role).
        assert!(ents
            .iter()
            .any(|e| e.name == "fulano@empresa.com.br" && e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_filtra_palavra_regra_pt() {
        // v1.0.20 fix P1: NUNCA, PROIBIDO, DEVE não devem virar "entidades".
        let ents = apply_regex_prefilter("NUNCA fazer isso. PROIBIDO usar X. DEVE seguir Y.");
        assert!(
            !ents.iter().any(|e| e.name == "NUNCA"),
            "NUNCA deveria ser filtrado como stopword"
        );
        assert!(
            !ents.iter().any(|e| e.name == "PROIBIDO"),
            "PROIBIDO deveria ser filtrado"
        );
        assert!(
            !ents.iter().any(|e| e.name == "DEVE"),
            "DEVE deveria ser filtrado"
        );
    }

    #[test]
    fn regex_all_caps_aceita_constante_com_underscore() {
        // Constantes técnicas tipo MAX_RETRY, TIMEOUT_MS sempre devem ser aceitas.
        let ents = apply_regex_prefilter("configure MAX_RETRY=3 e API_TIMEOUT=30");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "API_TIMEOUT"));
    }

    #[test]
    fn regex_all_caps_aceita_acronimo_dominio() {
        // Acrônimos legítimos (não-stopword) devem passar: OPENAI, NVIDIA, GOOGLE.
        let ents = apply_regex_prefilter("OPENAI lançou GPT-5 com NVIDIA H100");
        assert!(ents.iter().any(|e| e.name == "OPENAI"));
        assert!(ents.iter().any(|e| e.name == "NVIDIA"));
    }

    #[test]
    fn regex_url_nao_aparece_em_apply_regex_prefilter() {
        // v1.0.24 P0-2: URLs foram removidas de apply_regex_prefilter e agora vão para extract_urls.
        let ents = apply_regex_prefilter("veja https://docs.rs/crate para detalhes");
        assert!(
            !ents.iter().any(|e| e.name.starts_with("https://")),
            "URLs não devem aparecer como entidades após split P0-2"
        );
    }

    #[test]
    fn extract_urls_captura_https() {
        let urls = extract_urls("veja https://docs.rs/crate para detalhes");
        assert_eq!(urls.len(), 1);
        assert_eq!(urls[0].url, "https://docs.rs/crate");
        assert!(urls[0].offset > 0);
    }

    #[test]
    fn extract_urls_trim_sufixo_pontuacao() {
        let urls = extract_urls("link: https://example.com/path. fim");
        assert!(!urls.is_empty());
        assert!(
            !urls[0].url.ends_with('.'),
            "sufixo ponto deve ser removido"
        );
    }

    #[test]
    fn extract_urls_deduplica_repetidas() {
        let body = "https://example.com referenciado aqui e depois aqui https://example.com";
        let urls = extract_urls(body);
        assert_eq!(urls.len(), 1, "URLs repetidas devem ser deduplicadas");
    }

    #[test]
    fn regex_uuid_captura_identificador() {
        let ents = apply_regex_prefilter("id=550e8400-e29b-41d4-a716-446655440000 no sistema");
        assert!(ents.iter().any(|e| e.entity_type == "concept"));
    }

    #[test]
    fn regex_all_caps_captura_constante() {
        let ents = apply_regex_prefilter("configure MAX_RETRY e TIMEOUT_MS");
        assert!(ents.iter().any(|e| e.name == "MAX_RETRY"));
        assert!(ents.iter().any(|e| e.name == "TIMEOUT_MS"));
    }

    #[test]
    fn regex_all_caps_ignora_palavras_curtas() {
        let ents = apply_regex_prefilter("use AI em seu projeto");
        assert!(
            !ents.iter().any(|e| e.name == "AI"),
            "AI tem apenas 2 chars, deve ser ignorado"
        );
    }

    #[test]
    fn iob_decodifica_per_para_person() {
        let tokens = vec![
            "John".to_string(),
            "Doe".to_string(),
            "trabalhou".to_string(),
        ];
        let labels = vec!["B-PER".to_string(), "I-PER".to_string(), "O".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents.len(), 1);
        assert_eq!(ents[0].entity_type, "person");
        assert!(ents[0].name.contains("John"));
    }

    #[test]
    fn iob_strip_subword_b_prefix() {
        // v1.0.21 P0: BERT às vezes emite ##AI com B-prefix (subword confuso).
        // Deve mergear na entidade ativa em vez de criar entidade fantasma "##AI".
        let tokens = vec!["Open".to_string(), "##AI".to_string()];
        let labels = vec!["B-ORG".to_string(), "B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.iter().any(|e| e.name == "OpenAI" || e.name == "Open"),
            "deveria mergear ##AI ou descartar"
        );
    }

    #[test]
    fn iob_subword_orphan_descarta() {
        // v1.0.21 P0: subword órfão sem entidade ativa não deve virar entidade.
        let tokens = vec!["##AI".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(
            ents.is_empty(),
            "subword órfão sem entidade ativa deve ser descartado"
        );
    }

    #[test]
    fn iob_descarta_date() {
        let tokens = vec!["Janeiro".to_string(), "2024".to_string()];
        let labels = vec!["B-DATE".to_string(), "I-DATE".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert!(ents.is_empty(), "DATE deve ser descartado");
    }

    #[test]
    fn iob_mapeia_org_para_project() {
        let tokens = vec!["Empresa".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "project");
    }

    #[test]
    fn iob_mapeia_org_sdk_para_tool() {
        let tokens = vec!["tokio-sdk".to_string()];
        let labels = vec!["B-ORG".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "tool");
    }

    #[test]
    fn iob_mapeia_loc_para_concept() {
        let tokens = vec!["Brasil".to_string()];
        let labels = vec!["B-LOC".to_string()];
        let ents = iob_to_entities(&tokens, &labels);
        assert_eq!(ents[0].entity_type, "concept");
    }

    #[test]
    fn build_relationships_respeitam_max_rels() {
        let entities: Vec<NewEntity> = (0..20)
            .map(|i| NewEntity {
                name: format!("entidade_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let (rels, truncated) = build_relationships(&entities);
        let max_rels = crate::constants::max_relationships_per_memory();
        assert!(rels.len() <= max_rels, "deve respeitar max_rels={max_rels}");
        if rels.len() == max_rels {
            assert!(truncated, "truncated deve ser true quando atingiu o cap");
        }
    }

    #[test]
    fn build_relationships_sem_duplicatas() {
        let entities: Vec<NewEntity> = (0..5)
            .map(|i| NewEntity {
                name: format!("ent_{i}"),
                entity_type: "concept".to_string(),
                description: None,
            })
            .collect();
        let (rels, _truncated) = build_relationships(&entities);
        let mut pares: std::collections::HashSet<(String, String)> =
            std::collections::HashSet::new();
        for r in &rels {
            let par = (r.source.clone(), r.target.clone());
            assert!(pares.insert(par), "par duplicado encontrado");
        }
    }

    #[test]
    fn merge_deduplica_por_nome_lowercase() {
        let a = vec![ExtractedEntity {
            name: "Rust".to_string(),
            entity_type: "concept".to_string(),
        }];
        let b = vec![ExtractedEntity {
            name: "rust".to_string(),
            entity_type: "tool".to_string(),
        }];
        let merged = merge_and_deduplicate(a, b);
        assert_eq!(merged.len(), 1, "rust e Rust são a mesma entidade");
    }

    #[test]
    fn regex_extractor_implementa_trait() {
        let extractor = RegexExtractor;
        let result = extractor
            .extract("contato: dev@empresa.io e MAX_TIMEOUT configurado")
            .unwrap();
        assert!(!result.entities.is_empty());
    }

    #[test]
    fn extract_retorna_ok_sem_modelo() {
        // Sem modelo baixado, deve retornar Ok com apenas entidades regex
        let paths = make_paths();
        let body = "contato: teste@exemplo.com com MAX_RETRY=3";
        let result = extract_graph_auto(body, &paths).unwrap();
        assert!(result
            .entities
            .iter()
            .any(|e| e.name.contains("teste@exemplo.com")));
    }

    #[test]
    fn stopwords_filter_v1024_terms() {
        // v1.0.24: verify that all 17 new stopwords added in P0-3 are filtered
        // by apply_regex_prefilter so they do not appear as entities.
        let body = "ACEITE ACK ACL BORDA CHECKLIST COMPLETED CONFIRME \
                    DEVEMOS DONE FIXED NEGUE PENDING PLAN PODEMOS RECUSE TOKEN VAMOS";
        let ents = apply_regex_prefilter(body);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        for word in &[
            "ACEITE",
            "ACK",
            "ACL",
            "BORDA",
            "CHECKLIST",
            "COMPLETED",
            "CONFIRME",
            "DEVEMOS",
            "DONE",
            "FIXED",
            "NEGUE",
            "PENDING",
            "PLAN",
            "PODEMOS",
            "RECUSE",
            "TOKEN",
            "VAMOS",
        ] {
            assert!(
                !names.contains(word),
                "v1.0.24 stopword {word} should be filtered but was found in entities"
            );
        }
    }

    #[test]
    fn dedup_normalizes_unicode_combining_marks() {
        // v1.0.24 P1-E: "Café" (NFC precomposed) and "Cafe\u{301}" (NFD with
        // combining acute accent) must deduplicate to a single entity after NFKC
        // normalization.
        let nfc = vec![ExtractedEntity {
            name: "Café".to_string(),
            entity_type: "concept".to_string(),
        }];
        // Build the NFD form: 'e' followed by combining acute accent U+0301
        let nfd_name = "Cafe\u{301}".to_string();
        let nfd = vec![ExtractedEntity {
            name: nfd_name,
            entity_type: "concept".to_string(),
        }];
        let merged = merge_and_deduplicate(nfc, nfd);
        assert_eq!(
            merged.len(),
            1,
            "NFC 'Café' and NFD 'Cafe\\u{{301}}' must deduplicate to 1 entity after NFKC normalization"
        );
    }

    // ── predict_batch regression tests ──────────────────────────────────────

    #[test]
    fn predict_batch_output_count_matches_input() {
        // Verify that predict_batch returns exactly one Vec<String> per window
        // without requiring a real model.  We test the shape contract by
        // constructing the padding logic manually and asserting counts.
        //
        // Two windows of different lengths: 3 tokens and 5 tokens.
        let w1_ids: Vec<u32> = vec![101, 100, 102];
        let w1_tok: Vec<String> = vec!["[CLS]".into(), "hello".into(), "[SEP]".into()];
        let w2_ids: Vec<u32> = vec![101, 100, 200, 300, 102];
        let w2_tok: Vec<String> = vec![
            "[CLS]".into(),
            "world".into(),
            "foo".into(),
            "bar".into(),
            "[SEP]".into(),
        ];
        let windows: Vec<(Vec<u32>, Vec<String>)> =
            vec![(w1_ids.clone(), w1_tok), (w2_ids.clone(), w2_tok)];

        // Verify padding logic and output length contracts using tensor operations
        // that do NOT require BertModel::forward.
        let device = Device::Cpu;
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap();
        assert_eq!(max_len, 5, "max_len deve ser 5");

        let mut padded_ids: Vec<Tensor> = Vec::new();
        for (ids, _) in &windows {
            let len = ids.len();
            let pad_right = max_len - len;
            let ids_i64: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
            let t = Tensor::from_vec(ids_i64, len, &device).unwrap();
            let t = t.pad_with_zeros(0, 0, pad_right).unwrap();
            assert_eq!(
                t.dims(),
                &[max_len],
                "cada janela deve ter shape (max_len,) após padding"
            );
            padded_ids.push(t);
        }

        let stacked = Tensor::stack(&padded_ids, 0).unwrap();
        assert_eq!(
            stacked.dims(),
            &[2, max_len],
            "stack deve produzir (batch_size=2, max_len=5)"
        );

        // Verify narrow preserves only real tokens for each window
        // (simulates what predict_batch does after classifier.forward)
        let fake_logits_data: Vec<f32> = vec![0.0f32; 2 * max_len * 9]; // batch×seq×num_labels=9
        let fake_logits =
            Tensor::from_vec(fake_logits_data, (2usize, max_len, 9usize), &device).unwrap();
        for (i, (ids, _)) in windows.iter().enumerate() {
            let real_len = ids.len();
            let example = fake_logits.get(i).unwrap();
            let sliced = example.narrow(0, 0, real_len).unwrap();
            assert_eq!(
                sliced.dims(),
                &[real_len, 9],
                "narrow deve preservar apenas {real_len} tokens reais"
            );
        }
    }

    #[test]
    fn predict_batch_empty_windows_returns_empty() {
        // predict_batch with no windows must return an empty Vec, not panic.
        // We test the guard logic directly on the batch size/max_len path.
        let windows: Vec<(Vec<u32>, Vec<String>)> = vec![];
        let max_len = windows.iter().map(|(ids, _)| ids.len()).max().unwrap_or(0);
        assert_eq!(max_len, 0, "zero windows → max_len 0");
        // The real predict_batch returns Ok(vec![]) when max_len == 0.
        // We assert the expected output shape by reproducing the guard here.
        let result: Vec<Vec<String>> = if max_len == 0 {
            Vec::new()
        } else {
            unreachable!()
        };
        assert!(result.is_empty());
    }

    #[test]
    fn ner_batch_size_default_is_8() {
        // Verify that ner_batch_size() returns the documented default when the
        // env var is absent.  We clear the var to avoid cross-test contamination.
        // Serialized via ENV_LOCK: this test races with
        // ner_batch_size_env_override_clamped on the same process-global var.
        let _guard = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
        assert_eq!(crate::constants::ner_batch_size(), 8);
    }

    #[test]
    fn ner_batch_size_env_override_clamped() {
        // Override via env var; values outside [1, 32] must be clamped.
        // Serialized via ENV_LOCK: mutating the env var concurrently with
        // ner_batch_size_default_is_8 made both tests flaky.
        let _guard = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner());
        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "64");
        assert_eq!(crate::constants::ner_batch_size(), 32, "deve clampar em 32");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "0");
        assert_eq!(crate::constants::ner_batch_size(), 1, "deve clampar em 1");

        std::env::set_var("GRAPHRAG_NER_BATCH_SIZE", "4");
        assert_eq!(
            crate::constants::ner_batch_size(),
            4,
            "valor válido preservado"
        );

        std::env::remove_var("GRAPHRAG_NER_BATCH_SIZE");
    }

    #[test]
    fn extraction_method_regex_only_unchanged() {
        // RegexExtractor always returns "regex-only" regardless of NER_MODEL OnceLock state.
        // This guards against accidentally changing the regex-only fallback string.
        let result = RegexExtractor.extract("contato: dev@acme.io").unwrap();
        assert_eq!(
            result.extraction_method, "regex-only",
            "RegexExtractor deve retornar regex-only"
        );
    }

    // --- P2-E: extend_with_numeric_suffix alphanumeric suffix ---

    #[test]
    fn extend_suffix_pure_numeric_unchanged() {
        // Existing behaviour: pure-numeric suffix must still work after P2-E.
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "usando GPT-5 no projeto");
        assert_eq!(
            result[0].name, "GPT-5",
            "sufixo puramente numérico deve ser estendido"
        );
    }

    #[test]
    fn extend_suffix_alphanumeric_letter_after_digit() {
        // P2-E: "4o" suffix (digit + lowercase letter) must be captured.
        let ents = vec![ExtractedEntity {
            name: "GPT".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "usando GPT-4o para tarefas avançadas");
        assert_eq!(result[0].name, "GPT-4o", "sufixo '4o' deve ser aceito");
    }

    #[test]
    fn extend_suffix_alphanumeric_b_suffix() {
        // P2-E: "5b" suffix (digit + 'b') must be captured.
        let ents = vec![ExtractedEntity {
            name: "Llama".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "modelo Llama-5b open-weight");
        assert_eq!(result[0].name, "Llama-5b", "sufixo '5b' deve ser aceito");
    }

    #[test]
    fn extend_suffix_alphanumeric_x_suffix() {
        // P2-E: "8x" suffix (digit + 'x') must be captured.
        let ents = vec![ExtractedEntity {
            name: "Mistral".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = extend_with_numeric_suffix(ents, "testando Mistral-8x em produção");
        assert_eq!(result[0].name, "Mistral-8x", "sufixo '8x' deve ser aceito");
    }

    // --- P2-D: augment_versioned_model_names extended regex ---

    #[test]
    fn augment_versioned_gpt4o() {
        // P2-D: "GPT-4o" must be captured with alphanumeric suffix.
        let result = augment_versioned_model_names(vec![], "usando GPT-4o para análise");
        assert!(
            result.iter().any(|e| e.name == "GPT-4o"),
            "GPT-4o deve ser capturado pelo augment, achados: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_claude_4_sonnet() {
        // P2-D: "Claude 4 Sonnet" must be captured with release tier.
        let result =
            augment_versioned_model_names(vec![], "melhor modelo: Claude 4 Sonnet lançado hoje");
        assert!(
            result.iter().any(|e| e.name == "Claude 4 Sonnet"),
            "Claude 4 Sonnet deve ser capturado, achados: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_llama_3_pro() {
        // P2-D: "Llama 3 Pro" must be captured with release tier.
        let result =
            augment_versioned_model_names(vec![], "fine-tuning com Llama 3 Pro localmente");
        assert!(
            result.iter().any(|e| e.name == "Llama 3 Pro"),
            "Llama 3 Pro deve ser capturado, achados: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_mixtral_8x7b() {
        // P2-D: "Mixtral 8x7B" composite version must be captured.
        let result =
            augment_versioned_model_names(vec![], "executando Mixtral 8x7B no servidor local");
        assert!(
            result.iter().any(|e| e.name == "Mixtral 8x7B"),
            "Mixtral 8x7B deve ser capturado, achados: {:?}",
            result.iter().map(|e| &e.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn augment_versioned_does_not_duplicate_existing() {
        // P2-D back-compat: entities already present must not be duplicated.
        let existing = vec![ExtractedEntity {
            name: "Claude 4".to_string(),
            entity_type: "concept".to_string(),
        }];
        let result = augment_versioned_model_names(existing, "usando Claude 4 no projeto");
        let count = result.iter().filter(|e| e.name == "Claude 4").count();
        assert_eq!(count, 1, "Claude 4 não deve ser duplicado");
    }
}
1549}