use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use mnem_embed_providers::Embedder;
use mnem_extract::{Extractor as StatisticalExtractor, KeyBertExtractor};
use crate::extract::{EntitySpan, Extractor, RelationSpan};
use crate::types::Section;
/// Relation kind written into every `RelationSpan` that
/// `KeyBertAdapter::extract_relations` emits.
pub const KEYBERT_RELATION_LABEL: &str = "co_occurs_with";
/// Lower bound used when clamping an entity score into the
/// `[KEYBERT_MIN_CONFIDENCE, 1.0]` range reported as `EntitySpan::confidence`.
pub const KEYBERT_MIN_CONFIDENCE: f32 = 0.0;
/// Adapter exposing `mnem_extract`'s KeyBERT entity extraction and relation
/// mining through this crate's `Extractor` trait.
pub struct KeyBertAdapter {
/// Embedding backend: embeds section text and is lent to the KeyBERT extractor.
embedder: Arc<dyn Embedder>,
/// Forwarded unchanged to `KeyBertExtractor::top_k`.
top_k: usize,
/// Forwarded unchanged to `KeyBertExtractor::ngram_range`.
ngram_range: (usize, usize),
/// Forwarded unchanged to `KeyBertExtractor::mmr_diversity`.
mmr_diversity: f32,
/// Threshold handed to `mnem_extract::mine_relations`.
pmi_threshold: f32,
/// Value copied into `EntitySpan::kind` for every extracted entity.
label: String,
/// Section text -> embedding. Filled by `prepare`, read by `extract_entities`.
/// Behind a `Mutex` because both of those methods only receive `&self`.
section_cache: Mutex<HashMap<String, Vec<f32>>>,
}
impl std::fmt::Debug for KeyBertAdapter {
    /// Formats the adapter's configuration for diagnostics.
    ///
    /// The embedder is summarised by its `model()` and `dim()` values, and the
    /// section cache by its entry count rather than its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // If the cache lock cannot be acquired (poisoned), report a length of
        // zero instead of failing the whole format call.
        let cache_len = match self.section_cache.lock() {
            Ok(guard) => guard.len(),
            Err(_) => 0,
        };

        let mut out = f.debug_struct("KeyBertAdapter");
        out.field("embedder_model", &self.embedder.model());
        out.field("embedder_dim", &self.embedder.dim());
        out.field("top_k", &self.top_k);
        out.field("ngram_range", &self.ngram_range);
        out.field("mmr_diversity", &self.mmr_diversity);
        out.field("pmi_threshold", &self.pmi_threshold);
        out.field("label", &self.label);
        out.field("section_cache_len", &cache_len);
        out.finish()
    }
}
impl KeyBertAdapter {
    /// Creates an adapter around `embedder` whose entities are tagged with `label`.
    ///
    /// `top_k`, `ngram_range`, `mmr_diversity` and `pmi_threshold` start at the
    /// defaults exported by `mnem_extract`; the section-embedding cache starts
    /// empty.
    #[must_use]
    pub fn new(embedder: Arc<dyn Embedder>, label: impl Into<String>) -> Self {
        use mnem_extract::cooccurrence::DEFAULT_PMI_THRESHOLD;
        use mnem_extract::keybert::{DEFAULT_MMR_DIVERSITY, DEFAULT_NGRAM_RANGE, DEFAULT_TOP_K};

        let label = label.into();
        let section_cache = Mutex::new(HashMap::new());
        Self {
            embedder,
            top_k: DEFAULT_TOP_K,
            ngram_range: DEFAULT_NGRAM_RANGE,
            mmr_diversity: DEFAULT_MMR_DIVERSITY,
            pmi_threshold: DEFAULT_PMI_THRESHOLD,
            label,
            section_cache,
        }
    }

    /// Returns the adapter with its entity label replaced by `label`.
    #[must_use]
    pub fn with_label(self, label: impl Into<String>) -> Self {
        Self {
            label: label.into(),
            ..self
        }
    }

    /// Returns the adapter with `top_k` set to `k`.
    #[must_use]
    pub const fn with_top_k(mut self, k: usize) -> Self {
        // Whole-value move (no struct-update syntax) keeps this usable in const context.
        self.top_k = k;
        self
    }

    /// Returns the adapter with the threshold passed to relation mining set to `t`.
    #[must_use]
    pub const fn with_pmi_threshold(mut self, t: f32) -> Self {
        // Whole-value move (no struct-update syntax) keeps this usable in const context.
        self.pmi_threshold = t;
        self
    }
}
impl Extractor for KeyBertAdapter {
    /// Warms the section-embedding cache with one batched embedding call.
    ///
    /// Sections with empty text are skipped and duplicate texts are embedded
    /// once, in first-seen order. Entries already in the cache are left
    /// untouched.
    ///
    /// This step is best-effort and always returns `Ok(())`. Nothing is cached
    /// when the batch call fails, when it returns a different number of vectors
    /// than texts submitted, or when the cache lock is poisoned;
    /// `extract_entities` then embeds each section on demand.
    fn prepare(&self, sections: &[Section]) -> Result<(), crate::error::Error> {
        let mut seen = std::collections::BTreeSet::new();
        let unique: Vec<&str> = sections
            .iter()
            .map(|s| s.text.as_str())
            .filter(|text| !text.is_empty() && seen.insert(*text))
            .collect();
        if unique.is_empty() {
            return Ok(());
        }

        // Deliberately swallowed: a failed warm-up only costs the batching benefit.
        let Ok(vecs) = self.embedder.embed_batch(&unique) else {
            return Ok(());
        };
        // `zip` would silently truncate on a count mismatch and could pair a
        // text with another text's embedding, so skip caching entirely.
        if vecs.len() != unique.len() {
            return Ok(());
        }

        if let Ok(mut cache) = self.section_cache.lock() {
            for (text, vec) in unique.into_iter().zip(vecs) {
                cache.entry(text.to_string()).or_insert(vec);
            }
        }
        Ok(())
    }

    /// Extracts entities from `section` with the KeyBERT extractor.
    ///
    /// Uses the embedding cached by `prepare` when one exists for this exact
    /// text; otherwise embeds the text directly. Returns an empty vector when
    /// the text is empty or the on-demand embedding fails.
    ///
    /// Every returned span carries this adapter's label as `kind`, and its
    /// confidence is the extractor score clamped to
    /// `[KEYBERT_MIN_CONFIDENCE, 1.0]`.
    fn extract_entities(&self, section: &Section) -> Vec<EntitySpan> {
        let text = &section.text;
        if text.is_empty() {
            return Vec::new();
        }

        // Clone the vector out so the lock is released before the extractor runs.
        let cached = self
            .section_cache
            .lock()
            .ok()
            .and_then(|cache| cache.get(text).cloned());
        // The fallback embedding call only happens on a cache miss.
        let Some(section_embed) = cached.or_else(|| self.embedder.embed(text).ok()) else {
            return Vec::new();
        };

        let kb = KeyBertExtractor {
            embedder: self.embedder.as_ref(),
            top_k: self.top_k,
            ngram_range: self.ngram_range,
            mmr_diversity: self.mmr_diversity,
        };
        kb.extract_entities(text, &section_embed)
            .into_iter()
            .map(|e| EntitySpan {
                kind: self.label.clone(),
                text: e.mention,
                byte_range: e.span.0..e.span.1,
                confidence: e.score.clamp(KEYBERT_MIN_CONFIDENCE, 1.0),
            })
            .collect()
    }

    /// Mines relations between `entities` found in `section`.
    ///
    /// Returns an empty vector when fewer than two entities are supplied.
    /// Each relation is reported with kind `KEYBERT_RELATION_LABEL`, its weight
    /// clamped to `[0.0, 1.0]`, and its endpoints as indices into `entities`.
    /// An endpoint resolves to the first entity whose text equals the mined
    /// mention; relations with an unresolvable endpoint are dropped.
    fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
        if entities.len() < 2 {
            return Vec::new();
        }

        // Convert this crate's spans into the entity type the miner expects.
        let bridged: Vec<mnem_extract::Entity> = entities
            .iter()
            .map(|e| mnem_extract::Entity {
                mention: e.text.clone(),
                score: e.confidence,
                span: (e.byte_range.start, e.byte_range.end),
            })
            .collect();
        let rels = mnem_extract::mine_relations(
            &section.text,
            &bridged,
            self.pmi_threshold,
            mnem_extract::ExtractionSource::Statistical,
        );

        let index_of =
            |mention: &str| -> Option<usize> { entities.iter().position(|e| e.text == mention) };
        rels.into_iter()
            .filter_map(|r| {
                let si = index_of(&r.src)?;
                let oi = index_of(&r.dst)?;
                Some(RelationSpan {
                    kind: KEYBERT_RELATION_LABEL.to_string(),
                    subject_span: si,
                    object_span: oi,
                    confidence: r.weight.clamp(0.0, 1.0),
                })
            })
            .collect()
    }
}