cognis_rag/splitters/
html.rs1use std::collections::BTreeSet;
19
20use crate::document::Document;
21
22use super::{child_doc, TextSplitter};
23
/// Splits HTML documents into chunks at heading tags (`<h1>`–`<h6>`).
///
/// Each section chunk starts at an enabled heading and runs up to the next
/// enabled heading; the texts of the enclosing headings are attached to the
/// chunk as `h1`…`h6` metadata. Configure via [`HtmlSplitter::with_levels`],
/// [`HtmlSplitter::with_strip_tags`] and [`HtmlSplitter::with_min_chunk_size`].
#[derive(Debug, Clone)]
pub struct HtmlSplitter {
    /// Heading levels (each within `1..=6`) that open a new chunk.
    levels: BTreeSet<u8>,
    /// Whether HTML tags are removed from the emitted chunk text.
    strip_tags: bool,
    /// Chunks shorter than this many characters are folded into the previous
    /// chunk; `0` disables folding.
    min_chunk_size: usize,
}
41
42impl Default for HtmlSplitter {
43 fn default() -> Self {
44 Self {
45 levels: (1u8..=6).collect(),
46 strip_tags: true,
47 min_chunk_size: 0,
48 }
49 }
50}
51
52impl HtmlSplitter {
53 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn with_levels<I: IntoIterator<Item = u8>>(mut self, levels: I) -> Self {
60 self.levels = levels.into_iter().filter(|n| (1..=6).contains(n)).collect();
61 self
62 }
63
64 pub fn with_strip_tags(mut self, strip: bool) -> Self {
66 self.strip_tags = strip;
67 self
68 }
69
70 pub fn with_min_chunk_size(mut self, n: usize) -> Self {
72 self.min_chunk_size = n;
73 self
74 }
75
76 fn find_boundaries(&self, input: &str) -> Vec<(usize, u8, String)> {
79 let bytes = input.as_bytes();
80 let mut out: Vec<(usize, u8, String)> = Vec::new();
81 let mut i = 0;
82 while i + 4 <= bytes.len() {
83 if bytes[i] == b'<'
85 && (bytes[i + 1] == b'h' || bytes[i + 1] == b'H')
86 && bytes[i + 2].is_ascii_digit()
87 && (1..=6).contains(&(bytes[i + 2] - b'0'))
88 {
89 let level = bytes[i + 2] - b'0';
90 if !self.levels.contains(&level) {
91 i += 1;
92 continue;
93 }
94 let close = match input[i..].find('>') {
96 Some(p) => i + p,
97 None => break,
98 };
99 let needle = format!("</h{level}>");
101 let needle_lower = needle.to_lowercase();
102 let after = close + 1;
103 let end_rel = input[after..].to_lowercase().find(&needle_lower);
104 let end = match end_rel {
105 Some(p) => after + p,
106 None => break,
107 };
108 let heading_text = strip_tags(&input[after..end]);
109 out.push((i, level, heading_text.trim().to_string()));
110 i = end + needle.len();
111 continue;
112 }
113 i += 1;
114 }
115 out
116 }
117}
118
/// Removes everything between `<` and `>` from `s`, keeping the text outside.
///
/// Nesting is tracked with a depth counter: each `<` opens a level and each
/// `>` closes one. A `>` seen outside any tag is kept as ordinary text, and an
/// unmatched `<` hides the rest of the input. No whitespace is inserted where
/// a tag was removed, and entities such as `&amp;` are left as-is.
fn strip_tags(s: &str) -> String {
    let mut nesting: i32 = 0;
    s.chars()
        .filter(|&c| {
            if c == '<' {
                nesting += 1;
                false
            } else if c == '>' && nesting > 0 {
                nesting -= 1;
                false
            } else {
                nesting == 0
            }
        })
        .collect()
}
132
133impl TextSplitter for HtmlSplitter {
134 fn split(&self, doc: &Document) -> Vec<Document> {
135 if doc.content.is_empty() {
136 return Vec::new();
137 }
138 let boundaries = self.find_boundaries(&doc.content);
139 if boundaries.is_empty() {
140 let content = if self.strip_tags {
142 strip_tags(&doc.content).trim().to_string()
143 } else {
144 doc.content.clone()
145 };
146 if content.is_empty() {
147 return Vec::new();
148 }
149 return vec![child_doc(doc, content, 0)];
150 }
151 let mut chunks: Vec<(String, Vec<(u8, String)>)> = Vec::new();
154 let mut current_levels: [Option<String>; 7] = Default::default();
156 let preamble_end = boundaries[0].0;
158 if preamble_end > 0 {
159 let body = &doc.content[..preamble_end];
160 let stripped = if self.strip_tags {
161 strip_tags(body).trim().to_string()
162 } else {
163 body.to_string()
164 };
165 if !stripped.is_empty() {
166 chunks.push((stripped, Vec::new()));
167 }
168 }
169 for (i, b) in boundaries.iter().enumerate() {
170 let next_start = boundaries
171 .get(i + 1)
172 .map(|n| n.0)
173 .unwrap_or(doc.content.len());
174 let section = &doc.content[b.0..next_start];
175 let level = b.1 as usize;
176 current_levels[level] = Some(b.2.clone());
177 for slot in current_levels.iter_mut().skip(level + 1) {
179 *slot = None;
180 }
181 let stripped = if self.strip_tags {
182 strip_tags(section).trim().to_string()
183 } else {
184 section.to_string()
185 };
186 if stripped.is_empty() {
187 continue;
188 }
189 let crumbs: Vec<(u8, String)> = (1..=6)
190 .filter_map(|lvl| {
191 current_levels[lvl as usize]
192 .as_ref()
193 .map(|s| (lvl, s.clone()))
194 })
195 .collect();
196 chunks.push((stripped, crumbs));
197 }
198 if self.min_chunk_size > 0 {
200 let mut i = 0;
201 while i + 1 < chunks.len() {
202 if chunks[i + 1].0.chars().count() < self.min_chunk_size {
203 let trailing = chunks.remove(i + 1);
204 chunks[i].0.push_str("\n\n");
205 chunks[i].0.push_str(&trailing.0);
206 continue;
207 }
208 i += 1;
209 }
210 }
211 chunks
212 .into_iter()
213 .enumerate()
214 .map(|(i, (content, crumbs))| {
215 let mut child = child_doc(doc, content, i);
216 for (lvl, name) in crumbs {
217 child
218 .metadata
219 .insert(format!("h{lvl}"), serde_json::Value::String(name));
220 }
221 child
222 })
223 .collect()
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    fn doc(s: &str) -> Document {
        Document::new(s)
    }

    /// Runs `splitter` over `html` wrapped in a fresh document.
    fn split_html(splitter: &HtmlSplitter, html: &str) -> Vec<Document> {
        splitter.split(&doc(html))
    }

    /// The metadata value a heading breadcrumb is expected to hold.
    fn heading(name: &str) -> serde_json::Value {
        serde_json::Value::String(name.into())
    }

    #[test]
    fn splits_on_h1_boundaries() {
        let chunks = split_html(&HtmlSplitter::new(), "<h1>One</h1>first<h1>Two</h1>second");
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].content.contains("first"));
        assert!(chunks[1].content.contains("second"));
    }

    #[test]
    fn metadata_records_heading_breadcrumbs() {
        let chunks = split_html(&HtmlSplitter::new(), "<h1>Top</h1>intro<h2>Sub</h2>body");
        // The final chunk sits under both the h1 and the h2.
        let last = chunks.last().expect("splitter produced no chunks");
        assert_eq!(last.metadata.get("h1"), Some(&heading("Top")));
        assert_eq!(last.metadata.get("h2"), Some(&heading("Sub")));
    }

    #[test]
    fn deeper_headings_clear_when_parent_changes() {
        let chunks = split_html(&HtmlSplitter::new(), "<h1>A</h1>x<h2>A2</h2>y<h1>B</h1>z");
        // Returning to an h1 must drop the stale h2 breadcrumb.
        let last = chunks.last().expect("splitter produced no chunks");
        assert!(!last.metadata.contains_key("h2"));
        assert_eq!(last.metadata.get("h1"), Some(&heading("B")));
    }

    #[test]
    fn level_filter_only_splits_at_selected_levels() {
        let splitter = HtmlSplitter::new().with_levels([1u8]);
        let chunks = split_html(&splitter, "<h1>One</h1>a<h2>Sub</h2>b<h1>Two</h1>c");
        assert_eq!(chunks.len(), 2);
        // The h2 is not a boundary, so its body stays inside the first section.
        let first = &chunks[0].content;
        assert!(first.contains("a"));
        assert!(first.contains("b"));
    }

    #[test]
    fn strip_tags_strips_inner_markup() {
        let chunks = split_html(&HtmlSplitter::new(), "<h1>One</h1><p>hello <b>world</b></p>");
        assert_eq!(chunks.len(), 1);
        let body = &chunks[0].content;
        assert!(body.contains("hello"));
        assert!(body.contains("world"));
        assert!(!body.contains("<b>"));
    }

    #[test]
    fn strip_tags_can_be_disabled() {
        let splitter = HtmlSplitter::new().with_strip_tags(false);
        let chunks = split_html(&splitter, "<h1>One</h1><p>hello</p>");
        assert!(chunks[0].content.contains("<p>"));
    }

    #[test]
    fn doc_without_headings_returns_one_chunk() {
        let chunks = split_html(&HtmlSplitter::new(), "<p>just a paragraph</p>");
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, "just a paragraph");
    }

    #[test]
    fn preamble_before_first_heading_is_kept() {
        let chunks = split_html(&HtmlSplitter::new(), "<p>preamble</p><h1>One</h1>body");
        assert!(chunks.len() >= 2);
        assert_eq!(chunks[0].content, "preamble");
    }

    #[test]
    fn min_chunk_size_coalesces_small_tail() {
        let splitter = HtmlSplitter::new().with_min_chunk_size(50);
        let chunks = split_html(
            &splitter,
            "<h1>Big</h1>this is a longer body text<h1>Tiny</h1>x",
        );
        // "Tinyx" is under 50 characters, so it folds into the first chunk.
        assert_eq!(chunks.len(), 1);
        let merged = &chunks[0].content;
        assert!(merged.contains("this is"));
        assert!(merged.contains("x"));
    }

    #[test]
    fn empty_input_returns_empty() {
        assert!(split_html(&HtmlSplitter::new(), "").is_empty());
    }
}