lmm_agent/cognition/knowledge.rs

// Copyright 2026 Mahmoud Harmouch.
//
// Licensed under the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! # `knowledge` - Knowledge Acquisition for Agents
//!
//! This module enables agents to build a queryable knowledge base from diverse
//! local and remote sources, then retrieve grounded, extractive answers to
//! natural-language questions, entirely offline and without any LLM.
//!
//! ## Sources
//!
//! | Variant | Description |
//! |---|---|
//! | `KnowledgeSource::File(path)` | A single `.txt`, `.md`, or `.pdf` file |
//! | `KnowledgeSource::Dir(path)` | All parseable files within a directory |
//! | `KnowledgeSource::Url(url)` | Fetch and parse a web URL (`net` feature) |
//! | `KnowledgeSource::RawText(text)` | Inline text string |
//!
//! ## Retrieval algorithm
//!
//! 1. Tokenise the question into lowercase content words (≥ 3 chars, stop-words removed).
//! 2. Score every indexed [`DocumentChunk`] with IDF-weighted token overlap.
//! 3. Take the top-k chunks and concatenate their text.
//! 4. Feed the concatenated corpus + the original question into
//!    [`lmm::text::TextSummarizer::summarize_with_query`] to produce a short
//!    extractive answer.
//!
//! ## Quick example
//!
//! ```rust
//! use lmm_agent::cognition::knowledge::KnowledgeIndex;
//!
//! let mut index = KnowledgeIndex::new();
//! index.ingest_text("my-doc", "Rust gives you control over memory without a garbage collector. \
//!                              The borrow checker enforces safety at compile time.");
//! let answer = index.answer("How does Rust handle memory?", 3);
//! assert!(answer.is_some());
//! ```

use anyhow::{Result, anyhow};
use lmm::text::TextSummarizer;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fs;
#[cfg(feature = "knowledge")]
use std::io::Cursor;
use std::path::{Path, PathBuf};

/// A source from which an agent can acquire knowledge.
#[derive(Debug, Clone)]
pub enum KnowledgeSource {
    /// A local file. Supported extensions: `.txt`, `.md`, and `.pdf`
    /// (the latter requires the `knowledge` feature).
    File(PathBuf),
    /// All parseable files found directly inside this directory (non-recursive).
    Dir(PathBuf),
    /// Fetch plain text from a URL. Requires the `net` feature.
    Url(String),
    /// Inline text supplied directly by the caller.
    RawText(String),
}

/// A sentence-sized unit of ingested knowledge.
#[derive(Debug, Clone)]
pub struct DocumentChunk {
    /// Human-readable label for the source (filename or URL).
    pub source: String,
    /// The raw text of this chunk.
    pub text: String,
    /// Pre-tokenised, lowercase content words used for fast scoring.
    pub tokens: Vec<String>,
}

impl DocumentChunk {
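    /// Builds a chunk from a source label and its raw text, tokenising the
    /// text eagerly so later scoring never re-tokenises.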
    pub fn new(source: impl Into<String>, text: impl Into<String>) -> Self {
        let text = text.into();
        let tokens = tokenise(&text);
        Self {
            source: source.into(),
            text,
            tokens,
        }
    }
}

/// In-memory, term-indexed knowledge base built from ingested documents.
///
/// Retrieval is based on IDF-weighted token overlap: content words that appear
/// rarely across the corpus are weighted more heavily, focusing results on
/// discriminative evidence.
///
/// # Examples
///
/// ```rust
/// use lmm_agent::cognition::knowledge::KnowledgeIndex;
///
/// let mut idx = KnowledgeIndex::new();
/// idx.ingest_text(
///     "rust-book",
///     "Rust prevents data races at compile time through its ownership model.",
/// );
/// let hits = idx.query("What prevents data races in Rust?", 3);
/// assert!(!hits.is_empty());
/// ```
#[derive(Debug, Clone, Default)]
pub struct KnowledgeIndex {
    chunks: Vec<DocumentChunk>,
    term_index: HashMap<String, Vec<usize>>,
    doc_freq: HashMap<String, usize>,
}

impl KnowledgeIndex {
    /// Creates an empty `KnowledgeIndex`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if no documents have been ingested yet.
    pub fn is_empty(&self) -> bool {
        self.chunks.is_empty()
    }

    /// Returns the total number of indexed chunks.
    pub fn len(&self) -> usize {
        self.chunks.len()
    }

    /// Ingests a raw text string under the given `source` label.
    ///
    /// The text is split into sentence-level chunks via a lightweight sentence
    /// splitter before being indexed.  Returns the number of chunks created.
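    ///
    /// # Examples
    ///
    /// An illustrative doctest; note that sentences shorter than four words
    /// are skipped by the splitter:
    ///
    /// ```rust
    /// use lmm_agent::cognition::knowledge::KnowledgeIndex;
    ///
    /// let mut idx = KnowledgeIndex::new();
    /// let added = idx.ingest_text(
    ///     "notes",
    ///     "Ownership moves values between bindings. Borrowing grants temporary access instead.",
    /// );
    /// assert_eq!(added, 2);
    /// ```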
    pub fn ingest_text(&mut self, source: &str, text: &str) -> usize {
        let sentences = split_sentences(text);
        let start = self.chunks.len();
        for sentence in sentences {
            if sentence.split_whitespace().count() < 4 {
                continue;
            }
            let chunk = DocumentChunk::new(source, &sentence);
            let idx = self.chunks.len();
            for token in &chunk.tokens {
                self.term_index.entry(token.clone()).or_default().push(idx);
            }
            self.chunks.push(chunk);
        }
        let added = self.chunks.len() - start;
        self.rebuild_doc_freq();
        added
    }

    /// Returns at most `top_k` of the most relevant [`DocumentChunk`]s for `question`.
    ///
    /// Chunks are scored by the sum of IDF weights for each question token that
    /// appears in the chunk.
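    ///
    /// # Examples
    ///
    /// A minimal sketch of a lookup:
    ///
    /// ```rust
    /// use lmm_agent::cognition::knowledge::KnowledgeIndex;
    ///
    /// let mut idx = KnowledgeIndex::new();
    /// idx.ingest_text("src", "The borrow checker enforces aliasing rules at compile time.");
    /// let hits = idx.query("borrow checker aliasing", 2);
    /// assert!(!hits.is_empty());
    /// ```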
    pub fn query(&self, question: &str, top_k: usize) -> Vec<&DocumentChunk> {
        if self.chunks.is_empty() {
            return Vec::new();
        }
        let q_tokens = tokenise(question);
        let n = self.chunks.len() as f64;

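        // Smoothed IDF: a matching token contributes ln(n / df) + 1.0, so
        // tokens that appear in fewer chunks weigh more.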
        let mut scores: Vec<(usize, f64)> = (0..self.chunks.len())
            .map(|i| {
                let chunk = &self.chunks[i];
                let score: f64 = q_tokens
                    .iter()
                    .filter(|t| chunk.tokens.contains(t))
                    .map(|t| {
                        let df = *self.doc_freq.get(t).unwrap_or(&1) as f64;
                        (n / df).ln() + 1.0
                    })
                    .sum();
                (i, score)
            })
            .filter(|(_, s)| *s > 0.0)
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        scores
            .into_iter()
            .take(top_k)
            .map(|(i, _)| &self.chunks[i])
            .collect()
    }

    /// Returns an extractive answer to `question` from the knowledge base,
    /// or `None` if no relevant chunks are found.
    ///
    /// The top-`top_k` chunks are concatenated and passed to
    /// [`lmm::text::TextSummarizer`] with the original question as a relevance
    /// hint. The summariser selects the sentences most likely to answer the question.
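    ///
    /// # Examples
    ///
    /// ```rust
    /// use lmm_agent::cognition::knowledge::KnowledgeIndex;
    ///
    /// let mut idx = KnowledgeIndex::new();
    /// idx.ingest_text("doc", "The garbage collector is absent; ownership frees memory deterministically.");
    /// let answer = idx.answer("What frees memory?", 2);
    /// assert!(answer.is_some());
    /// ```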
    pub fn answer(&self, question: &str, top_k: usize) -> Option<String> {
        let hits = self.query(question, top_k);
        if hits.is_empty() {
            return None;
        }
        let corpus: String = hits
            .iter()
            .map(|c| c.text.as_str())
            .collect::<Vec<_>>()
            .join(" ");
        let summariser = TextSummarizer::new(3, 4, 2);
        summariser
            .summarize_with_query(&corpus, question)
            .ok()
            .map(|sentences| sentences.join(" "))
    }

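    /// Recomputes per-token document frequencies across all chunks; each token
    /// is counted at most once per chunk.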
    fn rebuild_doc_freq(&mut self) {
        self.doc_freq.clear();
        for chunk in &self.chunks {
            let mut seen = HashSet::new();
            for token in &chunk.tokens {
                if seen.insert(token) {
                    *self.doc_freq.entry(token.clone()).or_insert(0) += 1;
                }
            }
        }
    }
}

/// Trait for pluggable document parsers.
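///
/// # Examples
///
/// An illustrative implementation for a hypothetical `.log` extension; it is
/// not registered in [`default_parsers`] and only sketches the contract:
///
/// ```rust
/// use anyhow::Result;
/// use lmm_agent::cognition::knowledge::DocumentParser;
///
/// struct LogParser;
///
/// impl DocumentParser for LogParser {
///     fn supports_extension(&self, ext: &str) -> bool {
///         ext.eq_ignore_ascii_case("log")
///     }
///
///     fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
///         Ok(String::from_utf8_lossy(bytes).into_owned())
///     }
/// }
/// ```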
pub trait DocumentParser: Send + Sync {
    /// Returns `true` when this parser can handle the given file extension.
    fn supports_extension(&self, ext: &str) -> bool;
    /// Parses `bytes` into a plain-text string.
    fn parse_bytes(&self, bytes: &[u8]) -> Result<String>;
}

/// Parses plain `.txt` files (UTF-8 assumed, with a lossy fallback).
#[derive(Debug, Default, Clone)]
pub struct PlainTextParser;

impl DocumentParser for PlainTextParser {
    fn supports_extension(&self, ext: &str) -> bool {
        ext.eq_ignore_ascii_case("txt")
    }

    fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
        Ok(String::from_utf8_lossy(bytes).into_owned())
    }
}

/// Parses `.md` (Markdown) files by stripping formatting markers and yielding plain text.
#[derive(Debug, Default, Clone)]
pub struct MarkdownParser;

impl DocumentParser for MarkdownParser {
    fn supports_extension(&self, ext: &str) -> bool {
        matches!(ext.to_ascii_lowercase().as_str(), "md" | "markdown")
    }

    fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
        let raw = String::from_utf8_lossy(bytes);
        Ok(strip_markdown(&raw))
    }
}

/// Parses `.pdf` files using `lopdf`. Requires the `knowledge` feature.
#[cfg(feature = "knowledge")]
#[derive(Debug, Default, Clone)]
pub struct PdfParser;

#[cfg(feature = "knowledge")]
impl DocumentParser for PdfParser {
    fn supports_extension(&self, ext: &str) -> bool {
        ext.eq_ignore_ascii_case("pdf")
    }

    fn parse_bytes(&self, bytes: &[u8]) -> Result<String> {
        use lopdf::Document;

        let doc =
            Document::load_from(Cursor::new(bytes)).map_err(|e| anyhow!("lopdf error: {e}"))?;
        let mut out = String::new();
        for page_num in 1..=doc.get_pages().len() as u32 {
            if let Ok(texts) = doc.extract_text(&[page_num]) {
                out.push_str(&texts);
                out.push('\n');
            }
        }
        Ok(out)
    }
}

/// Returns a list of parsers active for the current feature set.
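///
/// # Examples
///
/// ```rust
/// use lmm_agent::cognition::knowledge::default_parsers;
///
/// let parsers = default_parsers();
/// assert!(parsers.iter().any(|p| p.supports_extension("txt")));
/// ```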
pub fn default_parsers() -> Vec<Box<dyn DocumentParser>> {
    #[allow(unused_mut)]
    let mut parsers: Vec<Box<dyn DocumentParser>> =
        vec![Box::new(PlainTextParser), Box::new(MarkdownParser)];
    #[cfg(feature = "knowledge")]
    parsers.push(Box::new(PdfParser));
    parsers
}

/// Ingests a [`KnowledgeSource`] into `index`, returning the number of new chunks.
///
/// This function is the top-level entry point used by [`LmmAgent::ingest`].
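///
/// # Examples
///
/// A minimal sketch, assuming an async runtime (e.g. `tokio`) drives the future:
///
/// ```rust,no_run
/// use lmm_agent::cognition::knowledge::{KnowledgeIndex, KnowledgeSource, ingest};
///
/// # async fn demo() -> anyhow::Result<()> {
/// let mut index = KnowledgeIndex::new();
/// let text = "Each sentence becomes one indexed chunk.".to_string();
/// let added = ingest(&mut index, KnowledgeSource::RawText(text)).await?;
/// assert_eq!(added, 1);
/// # Ok(())
/// # }
/// ```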
pub async fn ingest(index: &mut KnowledgeIndex, source: KnowledgeSource) -> Result<usize> {
    match source {
        KnowledgeSource::RawText(text) => Ok(index.ingest_text("inline", &text)),

        KnowledgeSource::File(path) => ingest_file(index, &path),

        KnowledgeSource::Dir(dir) => {
            let mut total = 0;
            let entries = fs::read_dir(&dir)
                .map_err(|e| anyhow!("Cannot read dir {}: {e}", dir.display()))?;
            for entry in entries.flatten() {
                let p = entry.path();
                if p.is_file() {
                    total += ingest_file(index, &p).unwrap_or(0);
                }
            }
            Ok(total)
        }

        KnowledgeSource::Url(url) => ingest_url(index, &url).await,
    }
}

fn ingest_file(index: &mut KnowledgeIndex, path: &Path) -> Result<usize> {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or("")
        .to_ascii_lowercase();

    let parsers = default_parsers();
    let parser = parsers
        .iter()
        .find(|p| p.supports_extension(&ext))
        .ok_or_else(|| anyhow!("No parser for '{}' (extension: .{})", path.display(), ext))?;

    let bytes = fs::read(path).map_err(|e| anyhow!("Cannot read {}: {e}", path.display()))?;
    let text = parser.parse_bytes(&bytes)?;
    let label = path.file_name().and_then(|n| n.to_str()).unwrap_or("file");
    Ok(index.ingest_text(label, &text))
}

#[cfg(feature = "net")]
async fn ingest_url(index: &mut KnowledgeIndex, url: &str) -> Result<usize> {
    let body = reqwest::get(url)
        .await
        .map_err(|e| anyhow!("HTTP error for {url}: {e}"))?
        .text()
        .await
        .map_err(|e| anyhow!("Body error for {url}: {e}"))?;
    let plain = strip_html(&body);
    Ok(index.ingest_text(url, &plain))
}

#[cfg(not(feature = "net"))]
async fn ingest_url(_index: &mut KnowledgeIndex, url: &str) -> Result<usize> {
    Err(anyhow!(
        "URL ingestion requires the `net` feature. Enable it with --features net. URL: {url}"
    ))
}

static STOP_WORDS: phf::Set<&'static str> = phf::phf_set! {
    "the","and","for","are","was","that","this","with","from","have",
    "not","but","had","has","its","were","they","will","been","their",
    "all","one","can","her","his","him","she","who","which","what",
    "into","then","than","when","also","more","some","out","about",
    "said","would","could","should","each","other","there","these",
    "those","such","any","our","you","your","very","just","now","may"
};

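/// Lowercases `text`, splits on any non-alphabetic character, and keeps only
/// words of three or more letters that are not stop-words.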
fn tokenise(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_ascii_alphabetic())
        .filter_map(|w| {
            let w = w.to_ascii_lowercase();
            if w.len() >= 3 && !STOP_WORDS.contains(w.as_str()) {
                Some(w)
            } else {
                None
            }
        })
        .collect()
}

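/// Splits `text` into sentences on `.`, `!`, and `?`, except when a period is
/// immediately followed by a lowercase letter (as in `example.com`) or sits
/// between two digits (as in `3.14`).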
fn split_sentences(text: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut cur = String::new();
    let chars: Vec<char> = text.chars().collect();
    for (i, &ch) in chars.iter().enumerate() {
        cur.push(ch);
        if matches!(ch, '.' | '!' | '?') {
            let next_lower = chars
                .get(i + 1)
                .map(|c| c.is_ascii_lowercase())
                .unwrap_or(false);
            let prev_digit = i > 0 && chars[i - 1].is_ascii_digit();
            let next_digit = chars
                .get(i + 1)
                .map(|c| c.is_ascii_digit())
                .unwrap_or(false);
            if ch == '.' && (next_lower || (prev_digit && next_digit)) {
                continue;
            }
            let s = cur.trim().to_string();
            if !s.is_empty() {
                out.push(s);
            }
            cur.clear();
        }
    }
    let tail = cur.trim().to_string();
    if !tail.is_empty() {
        out.push(tail);
    }
    out
}

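/// Reduces Markdown to plain text: fenced code blocks, headings, and rules are
/// dropped entirely, and emphasis, backtick, and table markers are stripped
/// from the remaining lines.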
fn strip_markdown(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut in_code = false;
    for line in text.lines() {
        let trimmed = line.trim();
        if trimmed.starts_with("```") {
            in_code = !in_code;
            continue;
        }
        if in_code {
            continue;
        }
        if trimmed.starts_with('#') || trimmed.starts_with("---") || trimmed.starts_with("===") {
            continue;
        }
        let clean: String = trimmed
            .chars()
            .filter(|&c| c != '*' && c != '_' && c != '`' && c != '|')
            .collect();
        let clean = clean.trim();
        if !clean.is_empty() {
            out.push_str(clean);
            out.push(' ');
        }
    }
    out
}

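/// A very small HTML-to-text pass: drops everything between `<` and `>` and
/// keeps the text content, separated by spaces. Used by URL ingestion when the
/// `net` feature is enabled.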
#[allow(dead_code)]
fn strip_html(html: &str) -> String {
    let mut out = String::with_capacity(html.len() / 2);
    let mut in_tag = false;
    let mut buf = String::new();
    for ch in html.chars() {
        match ch {
            '<' => {
                if !buf.trim().is_empty() {
                    out.push_str(buf.trim());
                    out.push(' ');
                }
                buf.clear();
                in_tag = true;
            }
            '>' => {
                in_tag = false;
            }
            _ if !in_tag => {
                buf.push(ch);
            }
            _ => {}
        }
    }
    if !buf.trim().is_empty() {
        out.push_str(buf.trim());
    }
    out
}