heartbit_core/knowledge/
loader.rs1use std::path::Path;
4
5use crate::auth::TenantScope;
6use crate::error::Error;
7
8use super::chunker::{ChunkConfig, split_into_chunks};
9use super::{DocumentSource, KnowledgeBase};
10
/// Maximum number of bytes [`load_file`] accepts from a single file (50 MiB).
///
/// A file longer than this is rejected with an error rather than truncated.
pub const KNOWLEDGE_LOAD_FILE_MAX_BYTES: u64 = 50 << 20;

/// Byte cap (16 MiB) that [`load_url`] passes to `crate::http::read_text_capped`
/// when reading a response body.
pub const KNOWLEDGE_LOAD_URL_MAX_BYTES: usize = 16 << 20;
23pub async fn load_file(
27 kb: &dyn KnowledgeBase,
28 scope: &TenantScope,
29 path: &Path,
30 config: &ChunkConfig,
31) -> Result<usize, Error> {
32 use tokio::io::AsyncReadExt;
33 let mut file = tokio::fs::File::open(path)
34 .await
35 .map_err(|e| Error::Knowledge(format!("failed to open {}: {e}", path.display())))?;
36 let mut bytes: Vec<u8> = Vec::new();
37 let read = (&mut file)
38 .take(KNOWLEDGE_LOAD_FILE_MAX_BYTES + 1)
39 .read_to_end(&mut bytes)
40 .await
41 .map_err(|e| Error::Knowledge(format!("failed to read {}: {e}", path.display())))?;
42 if read as u64 > KNOWLEDGE_LOAD_FILE_MAX_BYTES {
43 return Err(Error::Knowledge(format!(
44 "{} exceeds {KNOWLEDGE_LOAD_FILE_MAX_BYTES} bytes (F-KB-3)",
45 path.display()
46 )));
47 }
48 let content = String::from_utf8_lossy(&bytes).into_owned();
49
50 let title = path
51 .file_name()
52 .map(|n| n.to_string_lossy().into_owned())
53 .unwrap_or_else(|| path.display().to_string());
54
55 let source = DocumentSource {
56 uri: path.display().to_string(),
57 title,
58 };
59
60 let chunks = split_into_chunks(&content, &source, config);
61 let count = chunks.len();
62 for chunk in chunks {
63 kb.index(scope, chunk).await?;
64 }
65 Ok(count)
66}
67
68pub async fn load_glob(
70 kb: &dyn KnowledgeBase,
71 scope: &TenantScope,
72 pattern: &str,
73 config: &ChunkConfig,
74) -> Result<usize, Error> {
75 let paths = glob::glob(pattern)
76 .map_err(|e| Error::Knowledge(format!("invalid glob pattern '{pattern}': {e}")))?;
77
78 let mut total = 0;
79 for entry in paths {
80 let path = entry.map_err(|e| Error::Knowledge(format!("glob error: {e}")))?;
81 if path.is_file() {
82 match load_file(kb, scope, &path, config).await {
83 Ok(count) => total += count,
84 Err(e) => {
85 tracing::warn!(path = %path.display(), error = %e, "skipping file");
86 }
87 }
88 }
89 }
90 Ok(total)
91}
92
93pub async fn load_url(
101 kb: &dyn KnowledgeBase,
102 scope: &TenantScope,
103 url: &str,
104 config: &ChunkConfig,
105) -> Result<usize, Error> {
106 let safe = crate::http::SafeUrl::parse(url, crate::http::IpPolicy::default())
108 .await
109 .map_err(|e| Error::Knowledge(format!("URL refused for {url}: {e}")))?;
110
111 let client = crate::http::safe_client_builder()
112 .timeout(std::time::Duration::from_secs(30))
113 .build()
114 .map_err(|e| Error::Knowledge(format!("client build error: {e}")))?;
115 let response = client
116 .get(safe.as_str())
117 .send()
118 .await
119 .map_err(|e| Error::Knowledge(format!("failed to fetch {url}: {e}")))?;
120
121 if !response.status().is_success() {
122 return Err(Error::Knowledge(format!(
123 "HTTP {} fetching {url}",
124 response.status()
125 )));
126 }
127
128 let body = crate::http::read_text_capped(response, KNOWLEDGE_LOAD_URL_MAX_BYTES)
129 .await
130 .map_err(|e| Error::Knowledge(format!("failed to read body from {url}: {e}")))?;
131
132 let content = strip_html_tags(&body);
133
134 let source = DocumentSource {
135 uri: url.to_string(),
136 title: url.to_string(),
137 };
138
139 let chunks = split_into_chunks(&content, &source, config);
140 let count = chunks.len();
141 for chunk in chunks {
142 kb.index(scope, chunk).await?;
143 }
144 Ok(count)
145}
146
/// Extracts readable text from an HTML document.
///
/// Markup is removed. Each run of whitespace, including the gap left by a
/// removed tag, collapses to a single space, and the result is trimmed.
///
/// - A `<` opens markup only when the next character is an ASCII letter, `/`,
///   `!` or `?`, as in the HTML tokenizer. Any other `<`, as in `a < b` or
///   `x<5` or plain-text input, is kept as text. It no longer swallows
///   everything up to the next `>`.
/// - `<!-- ... -->` comments are skipped whole, so a `>` inside a comment does
///   not leak the rest of the comment into the output.
/// - The content of `<script>` and `<style>` elements is dropped up to the
///   matching end tag, matched case-insensitively. If that end tag is missing,
///   the rest of the input is dropped.
/// - A tag with no closing `>` drops the rest of the input.
///
/// Character references such as `&amp;` are not decoded.
pub fn strip_html_tags(html: &str) -> String {
    let mut result = String::with_capacity(html.len());
    let mut last_was_space = false;
    let mut rest = html;

    while let Some(ch) = rest.chars().next() {
        // `<` is one byte, so `rest[1..]` is always on a char boundary.
        if ch == '<' && opens_markup(&rest[1..]) {
            // A tag boundary separates words. Never emit a leading space or
            // two spaces in a row.
            if !last_was_space && !result.is_empty() {
                result.push(' ');
                last_was_space = true;
            }
            rest = skip_markup(rest);
            continue;
        }

        rest = &rest[ch.len_utf8()..];
        if ch.is_whitespace() {
            if !last_was_space {
                result.push(' ');
                last_was_space = true;
            }
        } else {
            result.push(ch);
            last_was_space = false;
        }
    }

    result.trim().to_string()
}

/// Returns `true` when the text immediately after a `<` begins markup: a
/// tag, an end tag, a comment or declaration (`!`), or a processing
/// instruction (`?`).
fn opens_markup(after_lt: &str) -> bool {
    matches!(
        after_lt.chars().next(),
        Some(c) if c.is_ascii_alphabetic() || matches!(c, '/' | '!' | '?')
    )
}

/// Skips the markup construct that starts at the leading `<` of `s` and
/// returns the text that follows it. Returns `""` if the construct is
/// unterminated.
fn skip_markup(s: &str) -> &str {
    if let Some(comment) = s.strip_prefix("<!--") {
        // In HTML, `<!-->` and `<!--->` are complete, empty comments.
        if let Some(after) = comment.strip_prefix('>') {
            return after;
        }
        if let Some(after) = comment.strip_prefix("->") {
            return after;
        }
        return match comment.find("-->") {
            Some(end) => &comment[end + 3..],
            None => "",
        };
    }

    let close = match s.find('>') {
        Some(close) => close,
        None => return "",
    };
    let after = &s[close + 1..];
    match raw_text_element(&s[1..close]) {
        Some(name) => skip_raw_text(after, name),
        None => after,
    }
}

/// If the tag body `inner` (the text between `<` and `>`) is a `script` or
/// `style` start tag, returns that element name. Returns `None` for end tags
/// and for all other tags.
fn raw_text_element(inner: &str) -> Option<&'static str> {
    let name = inner
        .split(|c: char| c.is_whitespace() || c == '/')
        .next()
        .unwrap_or("");
    ["script", "style"]
        .iter()
        .copied()
        .find(|raw| name.eq_ignore_ascii_case(raw))
}

/// Skips the raw-text content of a `script`/`style` element, including its
/// end tag. The end tag is `</name`, matched case-insensitively, followed by
/// whitespace, `/`, `>` or end of input. Returns `""` when no end tag exists.
fn skip_raw_text<'a>(s: &'a str, name: &str) -> &'a str {
    let mut from = 0;
    while let Some(offset) = s[from..].find("</") {
        // `</` is ASCII, so `name_start` is always a char boundary.
        let name_start = from + offset + 2;
        let tail = &s.as_bytes()[name_start..];
        let closes = tail.len() >= name.len()
            && tail[..name.len()].eq_ignore_ascii_case(name.as_bytes())
            && tail
                .get(name.len())
                .map_or(true, |&b| b.is_ascii_whitespace() || b == b'>' || b == b'/');
        if closes {
            return match s[name_start..].find('>') {
                Some(end) => &s[name_start + end + 1..],
                None => "",
            };
        }
        from = name_start;
    }
    ""
}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge::KnowledgeQuery;
    use crate::knowledge::in_memory::InMemoryKnowledgeBase;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Every test runs against the default tenant.
    fn tenant() -> TenantScope {
        TenantScope::default()
    }

    /// Default chunking settings shared by the loader tests.
    fn chunking() -> ChunkConfig {
        ChunkConfig::default()
    }

    #[tokio::test]
    async fn load_file_indexes_content() {
        let mut file = NamedTempFile::new().unwrap();
        for line in [
            "Rust is a systems programming language.",
            "",
            "It provides memory safety without garbage collection.",
        ] {
            writeln!(file, "{line}").unwrap();
        }

        let kb = InMemoryKnowledgeBase::new();
        let indexed = load_file(&kb, &tenant(), file.path(), &chunking())
            .await
            .unwrap();
        assert!(indexed >= 1);

        let query = KnowledgeQuery {
            text: "rust memory".into(),
            source_filter: None,
            limit: 5,
        };
        let hits = kb.search(&tenant(), query).await.unwrap();
        assert!(!hits.is_empty());
    }

    #[tokio::test]
    async fn load_file_nonexistent_returns_error() {
        let kb = InMemoryKnowledgeBase::new();
        let missing = Path::new("/nonexistent/file.md");
        let err = load_file(&kb, &tenant(), missing, &chunking())
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Knowledge(_)));

        let message = err.to_string();
        let is_io_error = ["failed to read", "failed to open"]
            .iter()
            .any(|needle| message.contains(*needle));
        assert!(is_io_error, "expected open/read error, got: {message}");
    }

    #[tokio::test]
    async fn load_glob_collects_files() {
        let dir = tempfile::tempdir().unwrap();
        let docs = [
            ("doc1.md", "First document about rust."),
            ("doc2.md", "Second document about async."),
        ];
        for (name, body) in docs {
            std::fs::write(dir.path().join(name), body).unwrap();
        }

        let kb = InMemoryKnowledgeBase::new();
        let pattern = format!("{}/*.md", dir.path().display());
        let indexed = load_glob(&kb, &tenant(), &pattern, &chunking())
            .await
            .unwrap();
        assert!(indexed >= 2, "expected >= 2 chunks, got {indexed}");
        assert!(kb.chunk_count(&tenant()).await.unwrap() >= 2);
    }

    #[tokio::test]
    async fn load_glob_invalid_pattern_returns_error() {
        let kb = InMemoryKnowledgeBase::new();
        let outcome = load_glob(&kb, &tenant(), "[invalid", &chunking()).await;
        assert!(matches!(outcome, Err(Error::Knowledge(_))));
    }

    #[test]
    fn strip_html_basic() {
        let text =
            strip_html_tags("<html><body><h1>Title</h1><p>Hello world</p></body></html>");
        for kept in ["Title", "Hello world"] {
            assert!(text.contains(kept), "missing {kept:?} in {text:?}");
        }
        assert!(!text.contains('<') && !text.contains('>'));
    }

    #[test]
    fn strip_html_preserves_plain_text() {
        const PLAIN: &str = "Just plain text, no HTML.";
        assert_eq!(strip_html_tags(PLAIN), PLAIN);
    }

    #[test]
    fn strip_html_collapses_whitespace() {
        assert_eq!(strip_html_tags("<p> lots of spaces </p>"), "lots of spaces");
    }

    #[test]
    fn strip_html_skips_script_content() {
        let text = strip_html_tags(
            "<p>Hello</p><script>var x = 1; alert('xss');</script><p>World</p>",
        );
        for kept in ["Hello", "World"] {
            assert!(text.contains(kept), "missing {kept:?} in {text:?}");
        }
        for dropped in ["alert", "var x"] {
            assert!(!text.contains(dropped), "leaked {dropped:?} in {text:?}");
        }
    }

    #[test]
    fn strip_html_skips_style_content() {
        let text =
            strip_html_tags("<p>Hello</p><style>body { color: red; }</style><p>World</p>");
        for kept in ["Hello", "World"] {
            assert!(text.contains(kept), "missing {kept:?} in {text:?}");
        }
        for dropped in ["color", "body"] {
            assert!(!text.contains(dropped), "leaked {dropped:?} in {text:?}");
        }
    }

    #[test]
    fn strip_html_empty_input() {
        assert!(strip_html_tags("").is_empty());
    }

    #[test]
    fn strip_html_nested_tags() {
        let text = strip_html_tags("<div><span>nested</span> content</div>");
        for kept in ["nested", "content"] {
            assert!(text.contains(kept), "missing {kept:?} in {text:?}");
        }
    }
}