#[cfg(feature = "ner")]
use crate::core::entity::EntityType;
use crate::core::entity::RawEntity;
/// Optional NER-backed entity extractor.
///
/// The extractor is always constructible; when no model is available (or the
/// `ner` feature is not compiled in) it is simply marked as disabled.
pub struct NerExtractor {
    /// `true` only when a model was loaded successfully.
    enabled: bool,
    /// Loaded model state; this field exists only in builds with the `ner`
    /// feature enabled.
    #[cfg(feature = "ner")]
    inner: Option<NerInner>,
}
/// Model state held by an enabled extractor: the ONNX Runtime session and the
/// tokenizer loaded alongside it.
#[cfg(feature = "ner")]
struct NerInner {
    /// ONNX Runtime session created from the model file.
    session: ort::session::Session,
    /// Tokenizer loaded from the `tokenizer.json` next to the model file.
    tokenizer: tokenizers::Tokenizer,
}
impl NerExtractor {
    /// Builds an extractor, loading the NER model when one is available.
    ///
    /// This never fails. With the `ner` feature compiled in, the model is
    /// looked up at the location returned by `model_path()`; if that location
    /// is unknown, the file is missing, or loading fails, a disabled
    /// extractor is returned (missing file is logged at debug level, a load
    /// failure at warn level). Without the `ner` feature the extractor is
    /// always disabled.
    pub fn try_load() -> Self {
        #[cfg(feature = "ner")]
        {
            let loaded = model_path().and_then(|path| {
                if !path.exists() {
                    tracing::debug!(
                        "NER model not found at {}; extractor disabled",
                        path.display()
                    );
                    return None;
                }
                match Self::load_from_path(&path) {
                    Ok(ext) => Some(ext),
                    Err(err) => {
                        tracing::warn!(
                            "NER model present at {} but failed to load: {err:#}; \
                             NER will be disabled",
                            path.display()
                        );
                        None
                    }
                }
            });
            // Any failure above degrades to a disabled extractor.
            loaded.unwrap_or(Self {
                enabled: false,
                inner: None,
            })
        }
        #[cfg(not(feature = "ner"))]
        {
            tracing::debug!("NER feature not compiled in; extractor disabled");
            Self { enabled: false }
        }
    }

    /// Reports whether a model was loaded and extraction can run.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Extracts entities from `doc_text`.
    ///
    /// Returns an empty vector when the extractor is disabled, when
    /// `doc_text` is empty or whitespace-only, or when inference fails (the
    /// failure is logged at debug level). `file` is forwarded to the
    /// inference routine unchanged; presumably it names the file the text
    /// came from.
    pub fn extract(&self, doc_text: &str, file: &str) -> Vec<RawEntity> {
        let nothing_to_do = !self.enabled || doc_text.trim().is_empty();
        if nothing_to_do {
            return Vec::new();
        }
        #[cfg(feature = "ner")]
        {
            match &self.inner {
                Some(inner) => match run_inference(inner, doc_text, file) {
                    Ok(entities) => entities,
                    Err(err) => {
                        tracing::debug!("NER inference failed: {err:#}");
                        Vec::new()
                    }
                },
                None => Vec::new(),
            }
        }
        #[cfg(not(feature = "ner"))]
        {
            // Keep the parameters "used" in builds without the feature.
            let _ = (doc_text, file);
            Vec::new()
        }
    }

    /// Loads the ONNX model at `model_path` and the `tokenizer.json` that
    /// sits in the same directory, returning an enabled extractor.
    ///
    /// # Errors
    ///
    /// Fails if the session builder cannot be created, the model cannot be
    /// loaded, or the tokenizer file cannot be read.
    #[cfg(feature = "ner")]
    fn load_from_path(model_path: &std::path::Path) -> anyhow::Result<Self> {
        use anyhow::Context;

        let session = ort::session::Session::builder()
            .context("ort: builder")?
            .commit_from_file(model_path)
            .with_context(|| format!("ort: load model {}", model_path.display()))?;

        // The tokenizer is expected as a sibling file of the model.
        let tokenizer_path = model_path.with_file_name("tokenizer.json");
        let tokenizer = match tokenizers::Tokenizer::from_file(&tokenizer_path) {
            Ok(tokenizer) => tokenizer,
            Err(e) => {
                return Err(anyhow::anyhow!(
                    "load tokenizer {}: {e}",
                    tokenizer_path.display()
                ));
            }
        };

        let inner = NerInner { session, tokenizer };
        Ok(Self {
            enabled: true,
            inner: Some(inner),
        })
    }
}
/// Expected location of the NER model: `.trusty-search/models/ner.onnx`
/// under the directory named by the `HOME` environment variable.
///
/// Returns `None` when `HOME` is not set.
#[cfg(feature = "ner")]
fn model_path() -> Option<std::path::PathBuf> {
    let home = std::env::var_os("HOME")?;
    let mut path = std::path::PathBuf::from(home);
    path.push(".trusty-search/models/ner.onnx");
    Some(path)
}
/// Runs NER inference over `doc_text`.
///
/// As written this is a stub: it touches its inputs only to keep them
/// "used" and always returns an empty entity list without invoking the
/// session or tokenizer. Presumably a placeholder until real inference is
/// wired in.
#[cfg(feature = "ner")]
fn run_inference(inner: &NerInner, doc_text: &str, file: &str) -> anyhow::Result<Vec<RawEntity>> {
    let NerInner { session, tokenizer } = inner;
    let _ = (session, tokenizer, doc_text, file);
    // References the imported entity type so the `use` stays exercised.
    let _ = EntityType::NaturalLanguagePhrase;
    Ok(vec![])
}
/// Collects the text of Rust line doc comments in `content` and joins it
/// into a single space-separated string.
///
/// A line counts as a doc comment when, after leading whitespace, it starts
/// with exactly three slashes (outer doc) or with `//!` (inner doc). Lines
/// starting with four or more slashes are ordinary comments in Rust and are
/// skipped, as are regular `//` comments and code. Doc lines that carry no
/// text (for example a bare `///` used as a paragraph break) contribute
/// nothing, so the result never contains doubled or stray spaces.
///
/// Block doc comments (`/** ... */`, `/*! ... */`) are not handled.
///
/// Returns an empty string when no doc text is found.
pub fn extract_doc_comments(content: &str) -> String {
    content
        .lines()
        .filter_map(|raw| {
            let line = raw.trim_start();
            // `////...` is a plain comment, not a doc comment; without this
            // guard the `///` prefix check below would accept it.
            if line.starts_with("////") {
                return None;
            }
            line.strip_prefix("///")
                .or_else(|| line.strip_prefix("//!"))
        })
        .map(str::trim)
        // Drop empty doc lines so joining does not produce extra spaces.
        .filter(|text| !text.is_empty())
        .collect::<Vec<_>>()
        .join(" ")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The test environment is expected to have no `ner.onnx`, so the
    /// extractor must come up disabled and yield nothing.
    #[test]
    fn ner_disabled_without_model() {
        let ner = NerExtractor::try_load();
        assert!(
            !ner.is_enabled(),
            "extractor must be disabled in tests (no ner.onnx present)"
        );

        let entities = ner.extract("async runtime", "foo.rs");
        assert!(
            entities.is_empty(),
            "disabled extractor must return no entities"
        );
    }

    /// Empty and whitespace-only input never produces entities.
    #[test]
    fn extract_handles_empty_input() {
        let ner = NerExtractor::try_load();
        for blank in ["", " \n "] {
            assert!(ner.extract(blank, "foo.rs").is_empty());
        }
    }

    /// Both `///` and `//!` lines are collected; code and plain comments are
    /// ignored.
    #[test]
    fn doc_comment_extraction_pulls_triple_slash_lines() {
        let source = "/// Async runtime hint\n\
                      //! Module-level note\n\
                      fn foo() {}\n\
                      // regular comment ignored\n\
                      /// rate limiter\n";
        assert_eq!(
            extract_doc_comments(source),
            "Async runtime hint Module-level note rate limiter"
        );
    }

    /// Input without any doc lines yields an empty string.
    #[test]
    fn doc_comment_extraction_empty_when_no_doc_lines() {
        let extracted = extract_doc_comments("fn foo() {}\n// not a doc");
        assert_eq!(extracted, "");
    }
}