Skip to main content

cognis_rag/splitters/
html.rs

//! HTML-aware splitter — chunks an HTML document at heading boundaries
//! while preserving the heading hierarchy as document metadata.
//!
//! Lightweight implementation: scans the input byte-by-byte, looking for
//! `<h1>` … `<h6>` opening tags without pulling in a full HTML parser.
//! Other structural elements such as `<section>` are not treated as
//! boundaries. This keeps `cognis-rag` free of native deps while
//! covering the common cases. For full DOM-aware splitting use a custom
//! splitter built on `scraper`/`html5ever`.
//!
//! Customization knobs:
//! - `with_levels(set)` — which heading levels create chunk boundaries
//!   (default: 1..=6).
//! - `with_strip_tags(bool)` — strip remaining HTML tags from chunk text
//!   (default: `true`).
//! - `with_min_chunk_size(n)` — coalesce smaller-than-n leftover chunks
//!   into the previous one (default: `0`, i.e. disabled).
17
18use std::collections::BTreeSet;
19
20use crate::document::Document;
21
22use super::{child_doc, TextSplitter};
23
/// HTML-aware splitter. Splits at heading boundaries; metadata records
/// the active heading at each level.
///
/// **Extension path**: implement [`crate::splitters::TextSplitter`]
/// directly when the heading-based scheme doesn't fit. The
/// [`HtmlSplitter::with_levels`] / [`HtmlSplitter::with_strip_tags`] /
/// [`HtmlSplitter::with_min_chunk_size`] knobs cover the common
/// configuration axes for the built-in heuristic.
#[derive(Debug, Clone)]
pub struct HtmlSplitter {
    /// Heading levels (1..=6) at which to create boundaries.
    levels: BTreeSet<u8>,
    /// Strip remaining HTML tags from chunk text.
    strip_tags: bool,
    /// Coalesce chunks smaller than this (in chars) into the previous
    /// chunk. `0` disables coalescing.
    min_chunk_size: usize,
}

impl Default for HtmlSplitter {
    /// Split on every heading level, strip tags, never coalesce.
    fn default() -> Self {
        Self {
            levels: (1u8..=6).collect(),
            strip_tags: true,
            min_chunk_size: 0,
        }
    }
}

impl HtmlSplitter {
    /// New splitter with default settings (split on every heading level).
    pub fn new() -> Self {
        Self::default()
    }

    /// Restrict to specific heading levels (e.g. `[1, 2]` for top-two).
    ///
    /// Values outside `1..=6` are silently dropped; an empty selection
    /// means no heading creates a boundary.
    pub fn with_levels<I: IntoIterator<Item = u8>>(mut self, levels: I) -> Self {
        self.levels = levels.into_iter().filter(|n| (1..=6).contains(n)).collect();
        self
    }

    /// Whether to strip remaining HTML tags from chunk text. Default `true`.
    pub fn with_strip_tags(mut self, strip: bool) -> Self {
        self.strip_tags = strip;
        self
    }

    /// Coalesce trailing chunks smaller than `n` chars into the previous one.
    pub fn with_min_chunk_size(mut self, n: usize) -> Self {
        self.min_chunk_size = n;
        self
    }

    /// Locate every heading boundary (start byte, level, heading text)
    /// in `input`. Boundaries are returned in document order.
    ///
    /// - The start byte is the offset of the `<` that opens the heading tag.
    /// - The heading text is the content between the opening tag's `>` and
    ///   the matching `</hN>`, with inner tags removed and whitespace trimmed.
    /// - Tag names are matched ASCII-case-insensitively (`<H2>` / `</h2>`).
    /// - Headings whose level is not in `self.levels` are skipped.
    /// - Scanning stops at the first selected heading that has no `>` or no
    ///   closing tag; boundaries found before it are still returned.
    fn find_boundaries(&self, input: &str) -> Vec<(usize, u8, String)> {
        let bytes = input.as_bytes();
        let mut out: Vec<(usize, u8, String)> = Vec::new();
        let mut i = 0;
        while i + 4 <= bytes.len() {
            // Look for `<hN` where N in 1..=6.
            let opens_heading = bytes[i] == b'<'
                && matches!(bytes[i + 1], b'h' | b'H')
                && (b'1'..=b'6').contains(&bytes[i + 2]);
            if !opens_heading {
                i += 1;
                continue;
            }
            let level = bytes[i + 2] - b'0';
            if !self.levels.contains(&level) {
                i += 1;
                continue;
            }
            // First byte after the `>` that ends the open-tag. `i` sits on
            // an ASCII `<`, so slicing the `str` here is always valid.
            let after = match input[i..].find('>') {
                Some(p) => i + p + 1,
                None => break,
            };
            // Search for `</hN>` directly in the original bytes. Lowercasing
            // the haystack first (as a Unicode operation) can change byte
            // lengths, which would make the returned offset point at the
            // wrong place — or into the middle of a character — in `input`.
            let closing = [b'<', b'/', b'h', bytes[i + 2], b'>'];
            let end = match find_ascii_case_insensitive(&bytes[after..], &closing) {
                Some(p) => after + p,
                None => break,
            };
            // `after` follows an ASCII `>` and `end` sits on an ASCII `<`,
            // so both are char boundaries.
            let heading_text = strip_tags(&input[after..end]);
            out.push((i, level, heading_text.trim().to_string()));
            i = end + closing.len();
        }
        out
    }
}

/// Byte offset of the first occurrence of `needle` in `haystack`, comparing
/// ASCII letters case-insensitively. Returns `None` when there is no match.
fn find_ascii_case_insensitive(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        // `windows(0)` panics; an empty needle trivially matches at 0.
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window.eq_ignore_ascii_case(needle))
}

/// Remove markup from `s`, keeping only the text outside `<…>` spans.
///
/// Every `<` opens one nesting level and every `>` closes one; characters
/// are kept only while no level is open. A `>` seen outside any tag is kept
/// as ordinary text, and an unterminated `<` drops the rest of the input.
fn strip_tags(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut depth = 0i32;
    for ch in s.chars() {
        match ch {
            '<' => depth += 1,
            '>' if depth > 0 => depth -= 1,
            _ if depth == 0 => out.push(ch),
            _ => {}
        }
    }
    out
}
132
133impl TextSplitter for HtmlSplitter {
134    fn split(&self, doc: &Document) -> Vec<Document> {
135        if doc.content.is_empty() {
136            return Vec::new();
137        }
138        let boundaries = self.find_boundaries(&doc.content);
139        if boundaries.is_empty() {
140            // No headings — return whole doc (optionally tag-stripped).
141            let content = if self.strip_tags {
142                strip_tags(&doc.content).trim().to_string()
143            } else {
144                doc.content.clone()
145            };
146            if content.is_empty() {
147                return Vec::new();
148            }
149            return vec![child_doc(doc, content, 0)];
150        }
151        // Walk boundaries → produce sections from boundary[i].start to
152        // boundary[i+1].start (or end of doc for the last).
153        let mut chunks: Vec<(String, Vec<(u8, String)>)> = Vec::new();
154        // Track the current heading at each level.
155        let mut current_levels: [Option<String>; 7] = Default::default();
156        // Optional preamble before the first heading.
157        let preamble_end = boundaries[0].0;
158        if preamble_end > 0 {
159            let body = &doc.content[..preamble_end];
160            let stripped = if self.strip_tags {
161                strip_tags(body).trim().to_string()
162            } else {
163                body.to_string()
164            };
165            if !stripped.is_empty() {
166                chunks.push((stripped, Vec::new()));
167            }
168        }
169        for (i, b) in boundaries.iter().enumerate() {
170            let next_start = boundaries
171                .get(i + 1)
172                .map(|n| n.0)
173                .unwrap_or(doc.content.len());
174            let section = &doc.content[b.0..next_start];
175            let level = b.1 as usize;
176            current_levels[level] = Some(b.2.clone());
177            // Reset deeper levels (they belong to the previous parent).
178            for slot in current_levels.iter_mut().skip(level + 1) {
179                *slot = None;
180            }
181            let stripped = if self.strip_tags {
182                strip_tags(section).trim().to_string()
183            } else {
184                section.to_string()
185            };
186            if stripped.is_empty() {
187                continue;
188            }
189            let crumbs: Vec<(u8, String)> = (1..=6)
190                .filter_map(|lvl| {
191                    current_levels[lvl as usize]
192                        .as_ref()
193                        .map(|s| (lvl, s.clone()))
194                })
195                .collect();
196            chunks.push((stripped, crumbs));
197        }
198        // Coalesce small chunks.
199        if self.min_chunk_size > 0 {
200            let mut i = 0;
201            while i + 1 < chunks.len() {
202                if chunks[i + 1].0.chars().count() < self.min_chunk_size {
203                    let trailing = chunks.remove(i + 1);
204                    chunks[i].0.push_str("\n\n");
205                    chunks[i].0.push_str(&trailing.0);
206                    continue;
207                }
208                i += 1;
209            }
210        }
211        chunks
212            .into_iter()
213            .enumerate()
214            .map(|(i, (content, crumbs))| {
215                let mut child = child_doc(doc, content, i);
216                for (lvl, name) in crumbs {
217                    child
218                        .metadata
219                        .insert(format!("h{lvl}"), serde_json::Value::String(name));
220                }
221                child
222            })
223            .collect()
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap raw HTML in a [`Document`].
    fn doc(s: &str) -> Document {
        Document::new(s)
    }

    /// Run `splitter` over `html` and return the produced chunks.
    fn run(splitter: &HtmlSplitter, html: &str) -> Vec<Document> {
        splitter.split(&doc(html))
    }

    /// JSON string value, as stored in chunk metadata.
    fn json_str(s: &str) -> serde_json::Value {
        serde_json::Value::String(s.into())
    }

    #[test]
    fn splits_on_h1_boundaries() {
        let out = run(&HtmlSplitter::new(), "<h1>One</h1>first<h1>Two</h1>second");
        assert_eq!(out.len(), 2);
        assert!(out[0].content.contains("first"));
        assert!(out[1].content.contains("second"));
    }

    #[test]
    fn metadata_records_heading_breadcrumbs() {
        let out = run(&HtmlSplitter::new(), "<h1>Top</h1>intro<h2>Sub</h2>body");
        // The final chunk sits under both the h1 and the h2.
        let last = out.last().unwrap();
        assert_eq!(last.metadata.get("h1"), Some(&json_str("Top")));
        assert_eq!(last.metadata.get("h2"), Some(&json_str("Sub")));
    }

    #[test]
    fn deeper_headings_clear_when_parent_changes() {
        let out = run(&HtmlSplitter::new(), "<h1>A</h1>x<h2>A2</h2>y<h1>B</h1>z");
        // "B" starts a fresh h1 scope, so "A2" must not leak into it.
        let last = out.last().unwrap();
        assert!(!last.metadata.contains_key("h2"));
        assert_eq!(last.metadata.get("h1"), Some(&json_str("B")));
    }

    #[test]
    fn level_filter_only_splits_at_selected_levels() {
        let splitter = HtmlSplitter::new().with_levels([1u8]);
        let out = run(&splitter, "<h1>One</h1>a<h2>Sub</h2>b<h1>Two</h1>c");
        // The h2 is not a boundary, so its body stays in the first chunk.
        assert_eq!(out.len(), 2);
        let first = &out[0].content;
        assert!(first.contains("a"));
        assert!(first.contains("b"));
    }

    #[test]
    fn strip_tags_strips_inner_markup() {
        let out = run(&HtmlSplitter::new(), "<h1>One</h1><p>hello <b>world</b></p>");
        assert_eq!(out.len(), 1);
        let text = &out[0].content;
        assert!(text.contains("hello"));
        assert!(text.contains("world"));
        assert!(!text.contains("<b>"));
    }

    #[test]
    fn strip_tags_can_be_disabled() {
        let splitter = HtmlSplitter::new().with_strip_tags(false);
        let out = run(&splitter, "<h1>One</h1><p>hello</p>");
        assert!(out[0].content.contains("<p>"));
    }

    #[test]
    fn doc_without_headings_returns_one_chunk() {
        let out = run(&HtmlSplitter::new(), "<p>just a paragraph</p>");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].content, "just a paragraph");
    }

    #[test]
    fn preamble_before_first_heading_is_kept() {
        let out = run(&HtmlSplitter::new(), "<p>preamble</p><h1>One</h1>body");
        assert!(out.len() >= 2);
        assert_eq!(out[0].content, "preamble");
    }

    #[test]
    fn min_chunk_size_coalesces_small_tail() {
        let splitter = HtmlSplitter::new().with_min_chunk_size(50);
        let out = run(
            &splitter,
            "<h1>Big</h1>this is a longer body text<h1>Tiny</h1>x",
        );
        // The "Tiny" section is far below 50 chars, so it is folded back
        // into the chunk before it.
        assert_eq!(out.len(), 1);
        assert!(out[0].content.contains("this is"));
        assert!(out[0].content.contains("x"));
    }

    #[test]
    fn empty_input_returns_empty() {
        assert!(run(&HtmlSplitter::new(), "").is_empty());
    }
}