//! Lightweight knowledge base: ingests documents (plain text, Markdown,
//! and PDF behind the `knowledge` feature), chunks them into sentences,
//! indexes them with TF-IDF-style scoring, and answers questions via
//! extractive summarization.

use anyhow::{Result, anyhow};
use lmm::text::TextSummarizer;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fs;
#[cfg(feature = "knowledge")]
use std::io::Cursor;
use std::path::{Path, PathBuf};

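/// A source of knowledge that can be ingested into a [`KnowledgeIndex`].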
#[derive(Debug, Clone)]
pub enum KnowledgeSource {
    /// A single file on disk.
    File(PathBuf),
    /// Every regular file in a directory (non-recursive).
    Dir(PathBuf),
    /// A web page fetched over HTTP (requires the `net` feature).
    Url(String),
    /// Raw text supplied directly by the caller.
    RawText(String),
}

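/// A sentence-sized chunk of an ingested document, together with its
/// pre-computed tokens.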
#[derive(Debug, Clone)]
pub struct DocumentChunk {
    /// Label of the originating source (file name, URL, or "inline").
    pub source: String,
    /// The raw sentence text.
    pub text: String,
    /// Lowercased, stop-word-filtered tokens derived from `text`.
    pub tokens: Vec<String>,
}

impl DocumentChunk {
    /// Builds a chunk from `source` and `text`, tokenising the text eagerly.
    pub fn new(source: impl Into<String>, text: impl Into<String>) -> Self {
        let text = text.into();
        let tokens = tokenise(&text);
        Self {
            source: source.into(),
            text,
            tokens,
        }
    }
}

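/// An in-memory index of sentence chunks with an inverted term index and
/// per-token document frequencies for TF-IDF-style retrieval.
///
/// A minimal usage sketch (marked `ignore` since crate paths are omitted
/// here):
///
/// ```ignore
/// let mut index = KnowledgeIndex::new();
/// index.ingest_text(
///     "notes",
///     "Rust is a systems programming language. It has no garbage collector.",
/// );
/// let hits = index.query("garbage collector", 2);
/// assert!(!hits.is_empty());
/// ```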
#[derive(Debug, Clone, Default)]
pub struct KnowledgeIndex {
    chunks: Vec<DocumentChunk>,
    /// Inverted index: token -> indices of chunks containing it (one entry
    /// per occurrence).
    term_index: HashMap<String, Vec<usize>>,
    /// Number of chunks containing each token.
    doc_freq: HashMap<String, usize>,
}

impl KnowledgeIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if no chunks have been ingested.
    pub fn is_empty(&self) -> bool {
        self.chunks.is_empty()
    }

    /// Returns the number of indexed chunks.
    pub fn len(&self) -> usize {
        self.chunks.len()
    }

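    /// Splits `text` into sentences, indexes every sentence of at least
    /// four words under `source`, and returns the number of chunks added.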
    pub fn ingest_text(&mut self, source: &str, text: &str) -> usize {
        let sentences = split_sentences(text);
        let start = self.chunks.len();
        for sentence in sentences {
            // Skip fragments too short to be meaningful chunks.
            if sentence.split_whitespace().count() < 4 {
                continue;
            }
            let chunk = DocumentChunk::new(source, &sentence);
            let idx = self.chunks.len();
            for token in &chunk.tokens {
                self.term_index.entry(token.clone()).or_default().push(idx);
            }
            self.chunks.push(chunk);
        }
        let added = self.chunks.len() - start;
        self.rebuild_doc_freq();
        added
    }

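    /// Scores every chunk against `question` using IDF-weighted term
    /// overlap and returns up to `top_k` chunks, best match first. Chunks
    /// with no overlapping terms are excluded.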
    pub fn query(&self, question: &str, top_k: usize) -> Vec<&DocumentChunk> {
        if self.chunks.is_empty() {
            return Vec::new();
        }
        let q_tokens = tokenise(question);
        let n = self.chunks.len() as f64;

        let mut scores: Vec<(usize, f64)> = (0..self.chunks.len())
            .map(|i| {
                let chunk = &self.chunks[i];
                let score: f64 = q_tokens
                    .iter()
                    .filter(|t| chunk.tokens.contains(t))
                    .map(|t| {
                        // Smoothed IDF: rarer terms contribute more weight.
                        let df = *self.doc_freq.get(t).unwrap_or(&1) as f64;
                        (n / df).ln() + 1.0
                    })
                    .sum();
                (i, score)
            })
            .filter(|(_, s)| *s > 0.0)
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        scores
            .into_iter()
            .take(top_k)
            .map(|(i, _)| &self.chunks[i])
            .collect()
    }

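    /// Answers `question` by joining the `top_k` best-matching chunks and
    /// running a query-aware extractive summary over them. Returns `None`
    /// when nothing matches or summarization fails.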
    pub fn answer(&self, question: &str, top_k: usize) -> Option<String> {
        let hits = self.query(question, top_k);
        if hits.is_empty() {
            return None;
        }
        let corpus: String = hits
            .iter()
            .map(|c| c.text.as_str())
            .collect::<Vec<_>>()
            .join(" ");
        let summariser = TextSummarizer::new(3, 4, 2);
        summariser
            .summarize_with_query(&corpus, question)
            .ok()
            .map(|sentences| sentences.join(" "))
    }

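    /// Recomputes `doc_freq` from scratch, counting each token at most
    /// once per chunk.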
    fn rebuild_doc_freq(&mut self) {
        self.doc_freq.clear();
        for chunk in &self.chunks {
            let mut seen = HashSet::new();
            for token in &chunk.tokens {
                if seen.insert(token) {
                    *self.doc_freq.entry(token.clone()).or_insert(0) += 1;
                }
            }
        }
    }
}

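/// Turns raw file bytes into plain text for files with a supported extension.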
pub trait DocumentParser: Send + Sync {
    /// Returns `true` if this parser handles files with extension `ext`.
    fn supports_extension(&self, ext: &str) -> bool;
    /// Parses raw bytes into plain text.
    fn parse_bytes(&self, bytes: &[u8]) -> Result<String>;
}

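/// Parses `.txt` files as (lossily decoded) UTF-8.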
#[derive(Debug, Default, Clone)]
pub struct PlainTextParser;

impl DocumentParser for PlainTextParser {
    fn supports_extension(&self, ext: &str) -> bool {
        ext.eq_ignore_ascii_case("txt")
    }

    fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
        Ok(String::from_utf8_lossy(bytes).into_owned())
    }
}

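/// Parses `.md`/`.markdown` files by stripping Markdown syntax from the text.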
#[derive(Debug, Default, Clone)]
pub struct MarkdownParser;

impl DocumentParser for MarkdownParser {
    fn supports_extension(&self, ext: &str) -> bool {
        matches!(ext.to_ascii_lowercase().as_str(), "md" | "markdown")
    }

    fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
        let raw = String::from_utf8_lossy(bytes);
        Ok(strip_markdown(&raw))
    }
}

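/// Extracts text from `.pdf` files via `lopdf`. Available only with the
/// `knowledge` feature.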
#[cfg(feature = "knowledge")]
#[derive(Debug, Default, Clone)]
pub struct PdfParser;

#[cfg(feature = "knowledge")]
impl DocumentParser for PdfParser {
    fn supports_extension(&self, ext: &str) -> bool {
        ext.eq_ignore_ascii_case("pdf")
    }

    fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
        use lopdf::Document;

        let doc =
            Document::load_from(Cursor::new(bytes)).map_err(|e| anyhow!("lopdf error: {e}"))?;
        let mut out = String::new();
        // lopdf numbers pages consecutively from 1.
        for page_num in 1..=doc.get_pages().len() as u32 {
            if let Ok(texts) = doc.extract_text(&[page_num]) {
                out.push_str(&texts);
                out.push('\n');
            }
        }
        Ok(out)
    }
}

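/// Returns the built-in parsers: plain text, Markdown, and (when the
/// `knowledge` feature is enabled) PDF.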
pub fn default_parsers() -> Vec<Box<dyn DocumentParser>> {
    #[allow(unused_mut)]
    let mut parsers: Vec<Box<dyn DocumentParser>> =
        vec![Box::new(PlainTextParser), Box::new(MarkdownParser)];
    #[cfg(feature = "knowledge")]
    parsers.push(Box::new(PdfParser));
    parsers
}

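/// Ingests `source` into `index`, returning the number of chunks added.
/// Directory ingestion is non-recursive; files that fail to parse are
/// skipped.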
pub async fn ingest(index: &mut KnowledgeIndex, source: KnowledgeSource) -> Result<usize> {
    match source {
        KnowledgeSource::RawText(text) => Ok(index.ingest_text("inline", &text)),

        KnowledgeSource::File(path) => ingest_file(index, &path),

        KnowledgeSource::Dir(dir) => {
            let mut total = 0;
            let entries = fs::read_dir(&dir)
                .map_err(|e| anyhow!("Cannot read dir {}: {e}", dir.display()))?;
            for entry in entries.flatten() {
                let p = entry.path();
                if p.is_file() {
                    total += ingest_file(index, &p).unwrap_or(0);
                }
            }
            Ok(total)
        }

        KnowledgeSource::Url(url) => ingest_url(index, &url).await,
    }
}

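/// Reads a single file, parses it with the first parser that supports its
/// extension, and indexes the text under the file name.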
fn ingest_file(index: &mut KnowledgeIndex, path: &Path) -> Result<usize> {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or("")
        .to_ascii_lowercase();

    let parsers = default_parsers();
    let parser = parsers
        .iter()
        .find(|p| p.supports_extension(&ext))
        .ok_or_else(|| anyhow!("No parser for '{}' (extension: .{})", path.display(), ext))?;

    let bytes = fs::read(path).map_err(|e| anyhow!("Cannot read {}: {e}", path.display()))?;
    let text = parser.parse_bytes(&bytes)?;
    let label = path.file_name().and_then(|n| n.to_str()).unwrap_or("file");
    Ok(index.ingest_text(label, &text))
}

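/// Fetches `url` over HTTP, strips HTML tags, and indexes the remaining
/// text under the URL.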
#[cfg(feature = "net")]
async fn ingest_url(index: &mut KnowledgeIndex, url: &str) -> Result<usize> {
    let body = reqwest::get(url)
        .await
        .map_err(|e| anyhow!("HTTP error for {url}: {e}"))?
        .text()
        .await
        .map_err(|e| anyhow!("Body error for {url}: {e}"))?;
    let plain = strip_html(&body);
    Ok(index.ingest_text(url, &plain))
}

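/// Stub used when the `net` feature is disabled; always returns an error.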
#[cfg(not(feature = "net"))]
async fn ingest_url(_index: &mut KnowledgeIndex, url: &str) -> Result<usize> {
    Err(anyhow!(
        "URL ingestion requires the `net` feature. Enable it with --features net. URL: {url}"
    ))
}

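/// Common English words ignored during tokenisation.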
static STOP_WORDS: phf::Set<&'static str> = phf::phf_set! {
    "the","and","for","are","was","that","this","with","from","have",
    "not","but","had","has","its","were","they","will","been","their",
    "all","one","can","her","his","him","she","who","which","what",
    "into","then","than","when","also","more","some","out","about",
    "said","would","could","should","each","other","there","these",
    "those","such","any","our","you","your","very","just","now","may"
};

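/// Splits `text` on non-alphabetic characters and keeps lowercased words
/// of at least three letters that are not stop words.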
fn tokenise(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_ascii_alphabetic())
        .filter_map(|w| {
            let w = w.to_ascii_lowercase();
            if w.len() >= 3 && !STOP_WORDS.contains(w.as_str()) {
                Some(w)
            } else {
                None
            }
        })
        .collect()
}

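/// Splits `text` into sentences at `.`, `!`, and `?`, except when a `.` is
/// followed by a lowercase letter (likely an abbreviation) or sits between
/// two digits (a decimal number).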
fn split_sentences(text: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut cur = String::new();
    let chars: Vec<char> = text.chars().collect();
    for (i, &ch) in chars.iter().enumerate() {
        cur.push(ch);
        if matches!(ch, '.' | '!' | '?') {
            let next_lower = chars
                .get(i + 1)
                .map(|c| c.is_ascii_lowercase())
                .unwrap_or(false);
            let prev_digit = i > 0 && chars[i - 1].is_ascii_digit();
            let next_digit = chars
                .get(i + 1)
                .map(|c| c.is_ascii_digit())
                .unwrap_or(false);
            // Keep the period inside abbreviations and decimal numbers.
            if ch == '.' && (next_lower || (prev_digit && next_digit)) {
                continue;
            }
            let s = cur.trim().to_string();
            if !s.is_empty() {
                out.push(s);
            }
            cur.clear();
        }
    }
    let tail = cur.trim().to_string();
    if !tail.is_empty() {
        out.push(tail);
    }
    out
}

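/// Reduces Markdown to plain text: skips fenced code blocks, headings, and
/// horizontal rules, and drops emphasis, backtick, and table-pipe characters.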
fn strip_markdown(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut in_code = false;
    for line in text.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with("```") {
            in_code = !in_code;
            continue;
        }
        if in_code {
            continue;
        }
        if trimmed.starts_with('#') || trimmed.starts_with("---") || trimmed.starts_with("===") {
            continue;
        }
        let clean: String = trimmed
            .chars()
            .filter(|&c| c != '*' && c != '_' && c != '`' && c != '|')
            .collect();
        let clean = clean.trim();
        if !clean.is_empty() {
            out.push_str(clean);
            out.push(' ');
        }
    }
    out
}

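/// Minimal HTML-to-text conversion: drops everything between `<` and `>`
/// and joins the remaining text runs with single spaces.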
#[allow(dead_code)]
fn strip_html(html: &str) -> String {
    let mut out = String::with_capacity(html.len() / 2);
    let mut in_tag = false;
    let mut buf = String::new();
    for ch in html.chars() {
        match ch {
            '<' => {
                // Flush any text accumulated before this tag opened.
                if !buf.trim().is_empty() {
                    out.push_str(buf.trim());
                    out.push(' ');
                }
                buf.clear();
                in_tag = true;
            }
            '>' => {
                in_tag = false;
            }
            _ if !in_tag => {
                buf.push(ch);
            }
            _ => {}
        }
    }
    if !buf.trim().is_empty() {
        out.push_str(buf.trim());
    }
    out
}