use anyhow::{Result, anyhow};
use lmm::text::TextSummarizer;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fs;
#[cfg(feature = "knowledge")]
use std::io::Cursor;
use std::path::{Path, PathBuf};
/// Where knowledge can be loaded from.
#[derive(Debug, Clone)]
pub enum KnowledgeSource {
/// A single file on disk; a parser is chosen by file extension.
File(PathBuf),
/// A directory; each regular file directly inside it is ingested
/// (non-recursive — subdirectories are skipped).
Dir(PathBuf),
/// A URL fetched over HTTP (requires the `net` feature).
Url(String),
/// Literal text supplied by the caller, labelled "inline" in the index.
RawText(String),
}
/// One indexed unit of text: a single sentence plus its pre-computed tokens.
#[derive(Debug, Clone)]
pub struct DocumentChunk {
/// Label of the chunk's origin (file name, URL, or "inline").
pub source: String,
/// The raw sentence text.
pub text: String,
/// Lower-cased, stop-word-filtered tokens of `text` (see `tokenise`).
pub tokens: Vec<String>,
}
impl DocumentChunk {
    /// Builds a chunk from a source label and a sentence, tokenising the
    /// text eagerly so queries never re-tokenise stored chunks.
    pub fn new(source: impl Into<String>, text: impl Into<String>) -> Self {
        let body: String = text.into();
        Self {
            source: source.into(),
            tokens: tokenise(&body),
            text: body,
        }
    }
}
/// In-memory index over sentence-level chunks, scored with an IDF-style
/// weighting at query time.
#[derive(Debug, Clone, Default)]
pub struct KnowledgeIndex {
/// All ingested chunks, in ingestion order.
chunks: Vec<DocumentChunk>,
/// Inverted index: token → indices into `chunks` where the token occurs.
term_index: HashMap<String, Vec<usize>>,
/// Document frequency: token → number of distinct chunks containing it.
/// Rebuilt from scratch after every ingest (see `rebuild_doc_freq`).
doc_freq: HashMap<String, usize>,
}
impl KnowledgeIndex {
pub fn new() -> Self {
Self::default()
}
pub fn is_empty(&self) -> bool {
self.chunks.is_empty()
}
pub fn len(&self) -> usize {
self.chunks.len()
}
pub fn ingest_text(&mut self, source: &str, text: &str) -> usize {
let sentences = split_sentences(text);
let start = self.chunks.len();
for sentence in sentences {
if sentence.split_whitespace().count() < 4 {
continue;
}
let chunk = DocumentChunk::new(source, &sentence);
let idx = self.chunks.len();
for token in &chunk.tokens {
self.term_index.entry(token.clone()).or_default().push(idx);
}
self.chunks.push(chunk);
}
let added = self.chunks.len() - start;
self.rebuild_doc_freq();
added
}
pub fn query(&self, question: &str, top_k: usize) -> Vec<&DocumentChunk> {
if self.chunks.is_empty() {
return Vec::new();
}
let q_tokens = tokenise(question);
let n = self.chunks.len() as f64;
let mut scores: Vec<(usize, f64)> = (0..self.chunks.len())
.map(|i| {
let chunk = &self.chunks[i];
let score: f64 = q_tokens
.iter()
.filter(|t| chunk.tokens.contains(t))
.map(|t| {
let df = *self.doc_freq.get(t).unwrap_or(&1) as f64;
(n / df).ln() + 1.0
})
.sum();
(i, score)
})
.filter(|(_, s)| *s > 0.0)
.collect();
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
scores
.into_iter()
.take(top_k)
.map(|(i, _)| &self.chunks[i])
.collect()
}
pub fn answer(&self, question: &str, top_k: usize) -> Option<String> {
let hits = self.query(question, top_k);
if hits.is_empty() {
return None;
}
let corpus: String = hits
.iter()
.map(|c| c.text.as_str())
.collect::<Vec<_>>()
.join(" ");
let summariser = TextSummarizer::new(3, 4, 2);
summariser
.summarize_with_query(&corpus, question)
.ok()
.map(|sentences| sentences.join(" "))
}
fn rebuild_doc_freq(&mut self) {
self.doc_freq.clear();
for chunk in &self.chunks {
let mut seen = HashSet::new();
for token in &chunk.tokens {
if seen.insert(token) {
*self.doc_freq.entry(token.clone()).or_insert(0) += 1;
}
}
}
}
}
/// A pluggable document-format parser.
pub trait DocumentParser: Send + Sync {
/// Whether this parser handles files with the given extension
/// (callers pass it lower-cased, without the leading dot).
fn supports_extension(&self, ext: &str) -> bool;
/// Decodes raw file bytes into plain text suitable for indexing.
fn parse_bytes(&self, bytes: &[u8]) -> Result<String>;
}
/// Parser for `.txt` files: bytes are decoded as UTF-8, with invalid
/// sequences replaced by U+FFFD.
#[derive(Debug, Default, Clone)]
pub struct PlainTextParser;

impl DocumentParser for PlainTextParser {
    fn supports_extension(&self, ext: &str) -> bool {
        matches!(ext.to_ascii_lowercase().as_str(), "txt")
    }

    fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
        let text = String::from_utf8_lossy(bytes).to_string();
        Ok(text)
    }
}
/// Parser for Markdown files: decodes as UTF-8 (lossy) and strips
/// Markdown structure so only prose reaches the index.
#[derive(Debug, Default, Clone)]
pub struct MarkdownParser;

impl DocumentParser for MarkdownParser {
    fn supports_extension(&self, ext: &str) -> bool {
        let lower = ext.to_ascii_lowercase();
        lower == "md" || lower == "markdown"
    }

    fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
        let decoded = String::from_utf8_lossy(bytes);
        Ok(strip_markdown(&decoded))
    }
}
/// PDF text extractor backed by `lopdf`; only built with the
/// `knowledge` feature.
#[cfg(feature = "knowledge")]
#[derive(Debug, Default, Clone)]
pub struct PdfParser;
#[cfg(feature = "knowledge")]
impl DocumentParser for PdfParser {
    fn supports_extension(&self, ext: &str) -> bool {
        ext.eq_ignore_ascii_case("pdf")
    }

    fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
        // The whole impl is already gated on `knowledge`, so the `use`
        // needs no extra cfg attribute of its own.
        use lopdf::Document;
        let doc =
            Document::load_from(Cursor::new(bytes)).map_err(|e| anyhow!("lopdf error: {e}"))?;
        let mut out = String::new();
        // Iterate the actual page numbers reported by lopdf instead of
        // assuming they form a contiguous 1..=len range (also avoids the
        // lossy `len() as u32` cast). Pages that fail extraction are
        // skipped, as before.
        let pages = doc.get_pages();
        for &page_num in pages.keys() {
            if let Ok(text) = doc.extract_text(&[page_num]) {
                out.push_str(&text);
                out.push('\n');
            }
        }
        Ok(out)
    }
}
/// Returns the built-in parser set: plain text, Markdown, and — when the
/// `knowledge` feature is enabled — PDF.
pub fn default_parsers() -> Vec<Box<dyn DocumentParser>> {
    #[allow(unused_mut)]
    let mut parsers: Vec<Box<dyn DocumentParser>> = Vec::with_capacity(3);
    parsers.push(Box::new(PlainTextParser));
    parsers.push(Box::new(MarkdownParser));
    #[cfg(feature = "knowledge")]
    parsers.push(Box::new(PdfParser));
    parsers
}
/// Ingests `source` into `index` and returns the number of chunks added.
///
/// Directory ingestion is best-effort: files that fail to parse contribute
/// zero chunks rather than aborting the whole directory. Only regular files
/// directly inside the directory are considered (non-recursive).
pub async fn ingest(index: &mut KnowledgeIndex, source: KnowledgeSource) -> Result<usize> {
    match source {
        KnowledgeSource::RawText(text) => Ok(index.ingest_text("inline", &text)),
        KnowledgeSource::File(path) => ingest_file(index, &path),
        KnowledgeSource::Dir(dir) => {
            let entries = fs::read_dir(&dir)
                .map_err(|e| anyhow!("Cannot read dir {}: {e}", dir.display()))?;
            let total: usize = entries
                .flatten()
                .map(|entry| entry.path())
                .filter(|p| p.is_file())
                // Best-effort: an unparseable file counts as zero chunks.
                .map(|p| ingest_file(index, &p).unwrap_or(0))
                .sum();
            Ok(total)
        }
        KnowledgeSource::Url(url) => ingest_url(index, &url).await,
    }
}
/// Reads one file, picks a parser by extension, and indexes the parsed text
/// under the file's name. Errors if no parser matches or the file is
/// unreadable (parser lookup happens first, as before).
fn ingest_file(index: &mut KnowledgeIndex, path: &Path) -> Result<usize> {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();
    let parsers = default_parsers();
    let parser = parsers
        .iter()
        .find(|p| p.supports_extension(&ext))
        .ok_or_else(|| anyhow!("No parser for '{}' (extension: .{})", path.display(), ext))?;
    let bytes = fs::read(path).map_err(|e| anyhow!("Cannot read {}: {e}", path.display()))?;
    let text = parser.parse_bytes(&bytes)?;
    let label = path.file_name().and_then(|n| n.to_str()).unwrap_or("file");
    Ok(index.ingest_text(label, &text))
}
/// Fetches `url` over HTTP, strips HTML tags, and indexes the remaining
/// text under the URL itself. Only compiled with the `net` feature.
#[cfg(feature = "net")]
async fn ingest_url(index: &mut KnowledgeIndex, url: &str) -> Result<usize> {
    let response = reqwest::get(url)
        .await
        .map_err(|e| anyhow!("HTTP error for {url}: {e}"))?;
    let body = response
        .text()
        .await
        .map_err(|e| anyhow!("Body error for {url}: {e}"))?;
    let plain = strip_html(&body);
    Ok(index.ingest_text(url, &plain))
}
/// Stub used when the `net` feature is off: always errors, pointing the
/// caller at the feature flag needed for URL ingestion.
#[cfg(not(feature = "net"))]
async fn ingest_url(_index: &mut KnowledgeIndex, url: &str) -> Result<usize> {
    let msg = format!(
        "URL ingestion requires the `net` feature. Enable it with --features net. URL: {url}"
    );
    Err(anyhow!(msg))
}
/// Common English words excluded from tokens, stored in a compile-time
/// perfect-hash set. Only words of three or more letters are listed because
/// `tokenise` already drops anything shorter.
static STOP_WORDS: phf::Set<&'static str> = phf::phf_set! {
"the","and","for","are","was","that","this","with","from","have",
"not","but","had","has","its","were","they","will","been","their",
"all","one","can","her","his","him","she","who","which","what",
"into","then","than","when","also","more","some","out","about",
"said","would","could","should","each","other","there","these",
"those","such","any","our","you","your","very","just","now","may"
};
/// Splits `text` on every non-ASCII-letter character, lower-cases the
/// pieces, and keeps those at least three letters long that are not stop
/// words.
fn tokenise(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for word in text.split(|c: char| !c.is_ascii_alphabetic()) {
        // Words are pure ASCII letters here, so byte length == char count.
        let lower = word.to_ascii_lowercase();
        if lower.len() >= 3 && !STOP_WORDS.contains(lower.as_str()) {
            tokens.push(lower);
        }
    }
    tokens
}
/// Splits `text` into sentences on `.`, `!`, and `?`, with two heuristics
/// that keep a `.` inside the current sentence: a following lowercase
/// letter (abbreviations like "e.g. foo") and digits on both sides
/// (decimals like "3.14"). Any trailing text without a terminator becomes
/// a final sentence.
fn split_sentences(text: &str) -> Vec<String> {
    let chars: Vec<char> = text.chars().collect();
    let mut sentences = Vec::new();
    let mut buf = String::new();
    for (i, &ch) in chars.iter().enumerate() {
        buf.push(ch);
        let is_terminator = ch == '.' || ch == '!' || ch == '?';
        if !is_terminator {
            continue;
        }
        if ch == '.' {
            let next_is_lower = chars.get(i + 1).map_or(false, |c| c.is_ascii_lowercase());
            let inside_number = i > 0
                && chars[i - 1].is_ascii_digit()
                && chars.get(i + 1).map_or(false, |c| c.is_ascii_digit());
            // Looks like an abbreviation or a decimal point: keep going.
            if next_is_lower || inside_number {
                continue;
            }
        }
        let sentence = buf.trim();
        if !sentence.is_empty() {
            sentences.push(sentence.to_string());
        }
        buf.clear();
    }
    let rest = buf.trim();
    if !rest.is_empty() {
        sentences.push(rest.to_string());
    }
    sentences
}
/// Reduces Markdown to plain prose, line by line: fenced code blocks,
/// headings, and `---`/`===` rules are dropped entirely; emphasis/code/table
/// characters (`*`, `_`, `` ` ``, `|`) are removed from the remaining lines.
/// Kept lines are joined with single spaces (output keeps a trailing space
/// after the last line, matching long-standing behaviour).
fn strip_markdown(text: &str) -> String {
    let mut result = String::with_capacity(text.len());
    let mut inside_fence = false;
    for raw_line in text.lines() {
        let line = raw_line.trim();
        if line.starts_with("```") {
            // Fence delimiter: flip state and drop the delimiter line itself.
            inside_fence = !inside_fence;
            continue;
        }
        let is_structural =
            line.starts_with('#') || line.starts_with("---") || line.starts_with("===");
        if inside_fence || is_structural {
            continue;
        }
        let mut stripped = String::with_capacity(line.len());
        for c in line.chars() {
            if !matches!(c, '*' | '_' | '`' | '|') {
                stripped.push(c);
            }
        }
        let stripped = stripped.trim();
        if !stripped.is_empty() {
            result.push_str(stripped);
            result.push(' ');
        }
    }
    result
}
/// Crude HTML-to-text: drops everything between `<` and `>`, flushing the
/// text collected between tags (trimmed, space-separated) into the output.
/// Text after the final tag is appended without a trailing space. Note that
/// a stray `>` outside any tag is dropped as well — the original behaved
/// the same way. No entity decoding or script/style handling is attempted.
#[allow(dead_code)]
fn strip_html(html: &str) -> String {
    let mut text = String::with_capacity(html.len() / 2);
    let mut inside_tag = false;
    let mut pending = String::new();
    for ch in html.chars() {
        if ch == '<' {
            // Flush whatever text accumulated since the last tag.
            let chunk = pending.trim();
            if !chunk.is_empty() {
                text.push_str(chunk);
                text.push(' ');
            }
            pending.clear();
            inside_tag = true;
        } else if ch == '>' {
            inside_tag = false;
        } else if !inside_tag {
            pending.push(ch);
        }
    }
    let tail = pending.trim();
    if !tail.is_empty() {
        text.push_str(tail);
    }
    text
}