Skip to main content

cognis_rag/splitters/
markdown.rs

//! Markdown-aware splitter — splits text into `#`-heading sections and records
//! each chunk's enclosing heading in its metadata.
2
3use crate::document::Document;
4
5use super::{child_doc, recursive::RecursiveCharSplitter, TextSplitter};
6
/// Splits markdown by `#`-style heading sections. Every chunk produced from a
/// section carries that section's heading text in its metadata under
/// `heading`; text before the first heading gets no `heading` key.
///
/// If a section's body still exceeds `chunk_size` characters, falls back to a
/// [`RecursiveCharSplitter`] with the configured size and overlap.
#[derive(Debug, Clone)]
pub struct MarkdownSplitter {
    /// Maximum chunk length, in characters.
    chunk_size: usize,
    /// Overlap between adjacent chunks within a single section; passed to the
    /// recursive splitter.
    chunk_overlap: usize,
}

impl Default for MarkdownSplitter {
    /// 1000-character chunks with no overlap.
    fn default() -> Self {
        Self {
            chunk_size: 1000,
            chunk_overlap: 0,
        }
    }
}

impl MarkdownSplitter {
    /// Construct with default settings (`chunk_size` 1000, overlap 0).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Maximum chunk size (chars). Sections longer than this are split
    /// recursively.
    #[must_use]
    pub fn with_chunk_size(mut self, n: usize) -> Self {
        self.chunk_size = n;
        self
    }

    /// Overlap window between adjacent chunks within a single section.
    #[must_use]
    pub fn with_overlap(mut self, n: usize) -> Self {
        self.chunk_overlap = n;
        self
    }
}
45
46impl TextSplitter for MarkdownSplitter {
47    fn split(&self, doc: &Document) -> Vec<Document> {
48        let mut chunks: Vec<Document> = Vec::new();
49        let mut current_heading: Option<String> = None;
50        let mut buf = String::new();
51        let recursive = RecursiveCharSplitter::new()
52            .with_chunk_size(self.chunk_size)
53            .with_overlap(self.chunk_overlap);
54
55        let emit = |buf: &mut String,
56                    heading: &Option<String>,
57                    chunks: &mut Vec<Document>,
58                    recursive: &RecursiveCharSplitter| {
59            let text = std::mem::take(buf).trim().to_string();
60            if text.is_empty() {
61                return;
62            }
63            if text.chars().count() <= recursive_chunk_size(recursive) {
64                let mut d = child_doc(doc, text, chunks.len());
65                if let Some(h) = heading {
66                    d.metadata
67                        .insert("heading".into(), serde_json::Value::String(h.clone()));
68                }
69                chunks.push(d);
70            } else {
71                // Build a synthetic doc holding only the section body, run it
72                // through the recursive splitter, then transfer metadata.
73                let mut tmp = doc.clone();
74                tmp.content = text;
75                for sub in recursive.split(&tmp) {
76                    let mut d = child_doc(doc, sub.content, chunks.len());
77                    if let Some(h) = heading {
78                        d.metadata
79                            .insert("heading".into(), serde_json::Value::String(h.clone()));
80                    }
81                    chunks.push(d);
82                }
83            }
84        };
85
86        for line in doc.content.lines() {
87            if let Some(h) = parse_heading(line) {
88                emit(&mut buf, &current_heading, &mut chunks, &recursive);
89                current_heading = Some(h);
90            } else {
91                buf.push_str(line);
92                buf.push('\n');
93            }
94        }
95        emit(&mut buf, &current_heading, &mut chunks, &recursive);
96        chunks
97    }
98}
99
/// Parses a CommonMark ATX heading line and returns its text.
///
/// Returns `None` unless the line is one to six `#` characters, preceded by at
/// most three spaces and followed by a space, a tab, or end of line. Four or
/// more leading spaces (or a leading tab) make the line an indented code
/// block, not a heading. An optional closing run of `#`s is removed
/// (`## Title ##` → `"Title"`), but only when it is the entire content or is
/// preceded by whitespace, so `# C#` keeps its trailing `#`.
fn parse_heading(line: &str) -> Option<String> {
    let indent = line.bytes().take_while(|b| *b == b' ').count();
    if indent > 3 {
        return None;
    }
    let marked = &line[indent..];
    let level = marked.bytes().take_while(|b| *b == b'#').count();
    if level == 0 || level > 6 {
        return None;
    }
    let rest = &marked[level..];
    if !(rest.is_empty() || rest.starts_with(|c: char| c == ' ' || c == '\t')) {
        return None;
    }
    let content = rest.trim();
    let before_closing = content.trim_end_matches('#');
    let title = if before_closing.is_empty() {
        // `# ##`, `### ###`: the whole content is a closing sequence.
        before_closing
    } else if before_closing.ends_with(|c: char| c == ' ' || c == '\t') {
        before_closing.trim_end()
    } else {
        // Trailing `#`s glued to a word are part of the text.
        content
    };
    Some(title.to_string())
}
112
113fn recursive_chunk_size(_r: &RecursiveCharSplitter) -> usize {
114    // Keep parity with what the splitter was constructed with — we don't
115    // expose it, so default to a high cap. Sections under this skip recursion.
116    usize::MAX / 2
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits `md` with a default-configured `MarkdownSplitter`.
    fn split_default(md: &str) -> Vec<Document> {
        MarkdownSplitter::new().split(&Document::new(md))
    }

    #[test]
    fn parse_heading_levels() {
        let cases = [
            ("# Title", Some("Title")),
            ("### Sub", Some("Sub")),
            ("not a heading", None),
            // No space after `#` → not a heading.
            ("##NoSpace", None),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(parse_heading(input).as_deref(), expected);
        }
    }

    #[test]
    fn splits_by_heading_and_tags_metadata() {
        let chunks = split_default("# A\n\nbody-a\n\n## B\n\nbody-b\n");
        assert_eq!(chunks.len(), 2);
        let expected = [("A", "body-a"), ("B", "body-b")];
        for (chunk, &(heading, body)) in chunks.iter().zip(expected.iter()) {
            assert_eq!(chunk.metadata["heading"], heading);
            assert!(chunk.content.contains(body));
        }
    }

    #[test]
    fn pre_heading_text_emits_with_no_heading() {
        let chunks = split_default("intro line\n\n# A\n\nbody");
        assert_eq!(chunks.len(), 2);
        let (intro, section) = (&chunks[0], &chunks[1]);
        assert!(!intro.metadata.contains_key("heading"));
        assert_eq!(section.metadata["heading"], "A");
    }
}