Skip to main content

heartbit_core/knowledge/
loader.rs

1//! Document loader — reads plain text, Markdown, and PDF files into the knowledge base.
2
3use std::path::Path;
4
5use crate::auth::TenantScope;
6use crate::error::Error;
7
8use super::chunker::{ChunkConfig, split_into_chunks};
9use super::{DocumentSource, KnowledgeBase};
10
/// Upper bound, in bytes, on what [`load_file`] will read from a single file.
///
/// SECURITY (F-KB-3): without a cap, a hostile filesystem entry (or a glob
/// that accidentally matches a huge dump file) could exhaust process memory
/// while the file is being read. 50 MiB covers any reasonable Markdown /
/// code file.
pub const KNOWLEDGE_LOAD_FILE_MAX_BYTES: u64 = 50 << 20; // 50 MiB
17
/// Upper bound, in bytes, on the response body [`load_url`] will accept.
///
/// SECURITY (F-KB-2): bodies are capped at 16 MiB so a single fetch cannot
/// consume unbounded memory.
pub const KNOWLEDGE_LOAD_URL_MAX_BYTES: usize = 16 << 20; // 16 MiB
22
23/// Load a single file and index its chunks into the knowledge base under the given tenant.
24///
25/// SECURITY (F-KB-3): bounded by [`KNOWLEDGE_LOAD_FILE_MAX_BYTES`].
26pub async fn load_file(
27    kb: &dyn KnowledgeBase,
28    scope: &TenantScope,
29    path: &Path,
30    config: &ChunkConfig,
31) -> Result<usize, Error> {
32    use tokio::io::AsyncReadExt;
33    let mut file = tokio::fs::File::open(path)
34        .await
35        .map_err(|e| Error::Knowledge(format!("failed to open {}: {e}", path.display())))?;
36    let mut bytes: Vec<u8> = Vec::new();
37    let read = (&mut file)
38        .take(KNOWLEDGE_LOAD_FILE_MAX_BYTES + 1)
39        .read_to_end(&mut bytes)
40        .await
41        .map_err(|e| Error::Knowledge(format!("failed to read {}: {e}", path.display())))?;
42    if read as u64 > KNOWLEDGE_LOAD_FILE_MAX_BYTES {
43        return Err(Error::Knowledge(format!(
44            "{} exceeds {KNOWLEDGE_LOAD_FILE_MAX_BYTES} bytes (F-KB-3)",
45            path.display()
46        )));
47    }
48    let content = String::from_utf8_lossy(&bytes).into_owned();
49
50    let title = path
51        .file_name()
52        .map(|n| n.to_string_lossy().into_owned())
53        .unwrap_or_else(|| path.display().to_string());
54
55    let source = DocumentSource {
56        uri: path.display().to_string(),
57        title,
58    };
59
60    let chunks = split_into_chunks(&content, &source, config);
61    let count = chunks.len();
62    for chunk in chunks {
63        kb.index(scope, chunk).await?;
64    }
65    Ok(count)
66}
67
68/// Load all files matching a glob pattern and index their chunks.
69pub async fn load_glob(
70    kb: &dyn KnowledgeBase,
71    scope: &TenantScope,
72    pattern: &str,
73    config: &ChunkConfig,
74) -> Result<usize, Error> {
75    let paths = glob::glob(pattern)
76        .map_err(|e| Error::Knowledge(format!("invalid glob pattern '{pattern}': {e}")))?;
77
78    let mut total = 0;
79    for entry in paths {
80        let path = entry.map_err(|e| Error::Knowledge(format!("glob error: {e}")))?;
81        if path.is_file() {
82            match load_file(kb, scope, &path, config).await {
83                Ok(count) => total += count,
84                Err(e) => {
85                    tracing::warn!(path = %path.display(), error = %e, "skipping file");
86                }
87            }
88        }
89    }
90    Ok(total)
91}
92
93/// Load a URL, strip HTML tags, and index chunks.
94///
95/// SECURITY (F-KB-2): the URL is validated via [`crate::http::SafeUrl::parse`]
96/// (scheme allowlist + IP blocklist), the request uses
97/// [`crate::http::safe_client_builder`] (redirect:none, no_proxy,
98/// connect_timeout), and the body is capped at
99/// [`KNOWLEDGE_LOAD_URL_MAX_BYTES`].
100pub async fn load_url(
101    kb: &dyn KnowledgeBase,
102    scope: &TenantScope,
103    url: &str,
104    config: &ChunkConfig,
105) -> Result<usize, Error> {
106    // SECURITY (F-KB-2): SSRF blocklist + scheme allowlist.
107    let safe = crate::http::SafeUrl::parse(url, crate::http::IpPolicy::default())
108        .await
109        .map_err(|e| Error::Knowledge(format!("URL refused for {url}: {e}")))?;
110
111    let client = crate::http::safe_client_builder()
112        .timeout(std::time::Duration::from_secs(30))
113        .build()
114        .map_err(|e| Error::Knowledge(format!("client build error: {e}")))?;
115    let response = client
116        .get(safe.as_str())
117        .send()
118        .await
119        .map_err(|e| Error::Knowledge(format!("failed to fetch {url}: {e}")))?;
120
121    if !response.status().is_success() {
122        return Err(Error::Knowledge(format!(
123            "HTTP {} fetching {url}",
124            response.status()
125        )));
126    }
127
128    let body = crate::http::read_text_capped(response, KNOWLEDGE_LOAD_URL_MAX_BYTES)
129        .await
130        .map_err(|e| Error::Knowledge(format!("failed to read body from {url}: {e}")))?;
131
132    let content = strip_html_tags(&body);
133
134    let source = DocumentSource {
135        uri: url.to_string(),
136        title: url.to_string(),
137    };
138
139    let chunks = split_into_chunks(&content, &source, config);
140    let count = chunks.len();
141    for chunk in chunks {
142        kb.index(scope, chunk).await?;
143    }
144    Ok(count)
145}
146
/// Strip HTML tags from text, replacing them with spaces.
///
/// Everything between `<` and the next `>` is treated as a tag and removed;
/// a single space is emitted in its place so adjacent words do not fuse.
/// Runs of whitespace are collapsed to one space and the result is trimmed.
///
/// Text inside `<script>` and `<style>` elements is dropped entirely. Such
/// an element ends only at its *own* closing tag, so a stray `</style>`
/// inside a script string (or `</script>` inside a stylesheet) cannot switch
/// output back on and leak code into the result. Tag names are compared
/// ASCII-case-insensitively and attributes are ignored.
///
/// This is a tag stripper, not an HTML parser: entities are not decoded and
/// comments are handled like any other tag (they end at the first `>`). For
/// full HTML→markdown conversion a dedicated crate would be appropriate, but
/// for V1 tag stripping suffices.
pub fn strip_html_tags(html: &str) -> String {
    // Elements whose bodies are code / CSS rather than document text.
    const RAW_TEXT_ELEMENTS: [&str; 2] = ["script", "style"];

    let mut result = String::with_capacity(html.len());
    let mut in_tag = false;
    let mut tag_name = String::new();
    let mut collecting_tag = false;
    let mut last_was_space = false;
    // `Some(name)` while inside a raw-text element; `name` is the element
    // that opened the region and the only one whose end tag may close it.
    let mut skipping: Option<&str> = None;

    for ch in html.chars() {
        if ch == '<' {
            // A new `<` always restarts tag parsing, even inside a tag; this
            // lets a bare `<` in script code recover at the real end tag.
            in_tag = true;
            tag_name.clear();
            collecting_tag = true;
            if skipping.is_none() && !last_was_space && !result.is_empty() {
                result.push(' ');
                last_was_space = true;
            }
        } else if ch == '>' && in_tag {
            in_tag = false;
            collecting_tag = false;
            match skipping {
                // Not skipping: an opening <script>/<style> starts a region.
                None => {
                    skipping = RAW_TEXT_ELEMENTS
                        .iter()
                        .copied()
                        .find(|name| tag_name.eq_ignore_ascii_case(name));
                }
                // Skipping: only the matching end tag stops it.
                Some(open) => {
                    let closes_open = matches!(
                        tag_name.strip_prefix('/'),
                        Some(name) if name.eq_ignore_ascii_case(open)
                    );
                    if closes_open {
                        skipping = None;
                    }
                }
            }
        } else if in_tag && collecting_tag {
            // The tag name ends at the first whitespace; attributes follow.
            if ch.is_whitespace() {
                collecting_tag = false;
            } else {
                tag_name.push(ch);
            }
        } else if !in_tag && skipping.is_none() {
            if ch.is_whitespace() {
                if !last_was_space {
                    result.push(' ');
                    last_was_space = true;
                }
            } else {
                result.push(ch);
                last_was_space = false;
            }
        }
    }

    result.trim().to_string()
}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge::KnowledgeQuery;
    use crate::knowledge::in_memory::InMemoryKnowledgeBase;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Tenant scope shared by every test: the default scope.
    fn s() -> TenantScope {
        TenantScope::default()
    }

    /// A small two-paragraph file is indexed and then found by a search.
    #[tokio::test]
    async fn load_file_indexes_content() {
        let mut fixture = NamedTempFile::new().unwrap();
        writeln!(fixture, "Rust is a systems programming language.").unwrap();
        writeln!(fixture).unwrap();
        writeln!(fixture, "It provides memory safety without garbage collection.").unwrap();

        let kb = InMemoryKnowledgeBase::new();
        let indexed = load_file(&kb, &s(), fixture.path(), &ChunkConfig::default())
            .await
            .unwrap();
        assert!(indexed >= 1);

        let query = KnowledgeQuery {
            text: "rust memory".into(),
            source_filter: None,
            limit: 5,
        };
        let hits = kb.search(&s(), query).await.unwrap();
        assert!(!hits.is_empty());
    }

    /// A missing path surfaces as `Error::Knowledge` with an open/read message.
    #[tokio::test]
    async fn load_file_nonexistent_returns_error() {
        let kb = InMemoryKnowledgeBase::new();
        let missing = Path::new("/nonexistent/file.md");
        let err = load_file(&kb, &s(), missing, &ChunkConfig::default())
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Knowledge(_)));
        // Either failed to open or failed to read — both are acceptable
        // (we now open then take(...).read_to_end, F-KB-3).
        let s = err.to_string();
        assert!(
            s.contains("failed to read") || s.contains("failed to open"),
            "expected open/read error, got: {s}"
        );
    }

    /// Every `*.md` file in a directory is picked up by the glob loader.
    #[tokio::test]
    async fn load_glob_collects_files() {
        let dir = tempfile::tempdir().unwrap();
        let fixtures = [
            ("doc1.md", "First document about rust."),
            ("doc2.md", "Second document about async."),
        ];
        for (name, body) in fixtures {
            std::fs::write(dir.path().join(name), body).unwrap();
        }

        let kb = InMemoryKnowledgeBase::new();
        let pattern = format!("{}/*.md", dir.path().display());
        let count = load_glob(&kb, &s(), &pattern, &ChunkConfig::default())
            .await
            .unwrap();
        assert!(count >= 2, "expected >= 2 chunks, got {count}");
        assert!(kb.chunk_count(&s()).await.unwrap() >= 2);
    }

    /// A syntactically invalid glob is reported instead of matching nothing.
    #[tokio::test]
    async fn load_glob_invalid_pattern_returns_error() {
        let kb = InMemoryKnowledgeBase::new();
        let outcome = load_glob(&kb, &s(), "[invalid", &ChunkConfig::default()).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, Error::Knowledge(_)));
    }

    /// Tags disappear while the text between them survives.
    #[test]
    fn strip_html_basic() {
        let stripped =
            strip_html_tags("<html><body><h1>Title</h1><p>Hello world</p></body></html>");
        assert!(stripped.contains("Title"));
        assert!(stripped.contains("Hello world"));
        assert!(!stripped.contains('<'));
        assert!(!stripped.contains('>'));
    }

    /// Input without any markup is returned unchanged.
    #[test]
    fn strip_html_preserves_plain_text() {
        let plain = "Just plain text, no HTML.";
        assert_eq!(strip_html_tags(plain), plain);
    }

    /// Whitespace runs collapse to single spaces and the ends are trimmed.
    #[test]
    fn strip_html_collapses_whitespace() {
        let stripped = strip_html_tags("<p>  lots   of    spaces  </p>");
        assert_eq!(stripped, "lots of spaces");
    }

    /// Script bodies never reach the output.
    #[test]
    fn strip_html_skips_script_content() {
        let page = "<p>Hello</p><script>var x = 1; alert('xss');</script><p>World</p>";
        let stripped = strip_html_tags(page);
        assert!(stripped.contains("Hello"));
        assert!(stripped.contains("World"));
        assert!(!stripped.contains("alert"));
        assert!(!stripped.contains("var x"));
    }

    /// Stylesheet bodies never reach the output.
    #[test]
    fn strip_html_skips_style_content() {
        let page = "<p>Hello</p><style>body { color: red; }</style><p>World</p>";
        let stripped = strip_html_tags(page);
        assert!(stripped.contains("Hello"));
        assert!(stripped.contains("World"));
        assert!(!stripped.contains("color"));
        assert!(!stripped.contains("body"));
    }

    /// The empty string maps to the empty string.
    #[test]
    fn strip_html_empty_input() {
        assert_eq!(strip_html_tags(""), "");
    }

    /// Text at several nesting depths is all retained.
    #[test]
    fn strip_html_nested_tags() {
        let stripped = strip_html_tags("<div><span>nested</span> content</div>");
        assert!(stripped.contains("nested"));
        assert!(stripped.contains("content"));
    }
}