cognis_rag/splitters/
markdown.rs1use crate::document::Document;
4
5use super::{child_doc, recursive::RecursiveCharSplitter, TextSplitter};
6
/// Splitter configuration for Markdown documents.
///
/// Both fields are private and are set through [`MarkdownSplitter::new`] and
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct MarkdownSplitter {
    /// Value passed to `RecursiveCharSplitter::with_chunk_size` when splitting.
    chunk_size: usize,
    /// Value passed to `RecursiveCharSplitter::with_overlap` when splitting.
    chunk_overlap: usize,
}
16
17impl Default for MarkdownSplitter {
18 fn default() -> Self {
19 Self {
20 chunk_size: 1000,
21 chunk_overlap: 0,
22 }
23 }
24}
25
26impl MarkdownSplitter {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn with_chunk_size(mut self, n: usize) -> Self {
35 self.chunk_size = n;
36 self
37 }
38
39 pub fn with_overlap(mut self, n: usize) -> Self {
41 self.chunk_overlap = n;
42 self
43 }
44}
45
46impl TextSplitter for MarkdownSplitter {
47 fn split(&self, doc: &Document) -> Vec<Document> {
48 let mut chunks: Vec<Document> = Vec::new();
49 let mut current_heading: Option<String> = None;
50 let mut buf = String::new();
51 let recursive = RecursiveCharSplitter::new()
52 .with_chunk_size(self.chunk_size)
53 .with_overlap(self.chunk_overlap);
54
55 let emit = |buf: &mut String,
56 heading: &Option<String>,
57 chunks: &mut Vec<Document>,
58 recursive: &RecursiveCharSplitter| {
59 let text = std::mem::take(buf).trim().to_string();
60 if text.is_empty() {
61 return;
62 }
63 if text.chars().count() <= recursive_chunk_size(recursive) {
64 let mut d = child_doc(doc, text, chunks.len());
65 if let Some(h) = heading {
66 d.metadata
67 .insert("heading".into(), serde_json::Value::String(h.clone()));
68 }
69 chunks.push(d);
70 } else {
71 let mut tmp = doc.clone();
74 tmp.content = text;
75 for sub in recursive.split(&tmp) {
76 let mut d = child_doc(doc, sub.content, chunks.len());
77 if let Some(h) = heading {
78 d.metadata
79 .insert("heading".into(), serde_json::Value::String(h.clone()));
80 }
81 chunks.push(d);
82 }
83 }
84 };
85
86 for line in doc.content.lines() {
87 if let Some(h) = parse_heading(line) {
88 emit(&mut buf, ¤t_heading, &mut chunks, &recursive);
89 current_heading = Some(h);
90 } else {
91 buf.push_str(line);
92 buf.push('\n');
93 }
94 }
95 emit(&mut buf, ¤t_heading, &mut chunks, &recursive);
96 chunks
97 }
98}
99
/// Parses `line` as an ATX Markdown heading and returns its title text.
///
/// Following the CommonMark rules for ATX headings, a heading line is:
/// - at most three leading spaces (deeper indentation is code, not a heading),
/// - one to six `#` characters,
/// - then a space, a tab, or the end of the line.
///
/// The returned title is trimmed, and an optional closing run of `#`
/// characters is removed when it is separated from the title by whitespace
/// (`## Title ##` yields `Title`, while `# C#` keeps `C#`). A heading with no
/// title yields an empty string.
///
/// Returns `None` when `line` is not a heading.
fn parse_heading(line: &str) -> Option<String> {
    let trimmed = line.trim_start_matches(' ');
    if line.len() - trimmed.len() > 3 {
        return None;
    }
    let level = trimmed.chars().take_while(|c| *c == '#').count();
    if level == 0 || level > 6 {
        return None;
    }
    // `#` is ASCII, so `level` is also the marker's length in bytes.
    let rest = &trimmed[level..];
    if !(rest.is_empty() || rest.starts_with(' ') || rest.starts_with('\t')) {
        return None;
    }

    let text = rest.trim();
    // A trailing `#` run only counts as a closing sequence when whitespace
    // (or nothing at all) precedes it; otherwise it belongs to the title.
    let unclosed = text.trim_end_matches('#');
    let has_closing_run = unclosed.len() < text.len()
        && (unclosed.is_empty() || unclosed.ends_with(' ') || unclosed.ends_with('\t'));
    let title = if has_closing_run {
        unclosed.trim_end()
    } else {
        text
    };
    Some(title.to_string())
}
112
113fn recursive_chunk_size(_r: &RecursiveCharSplitter) -> usize {
114 usize::MAX / 2
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Heading markers need one to six `#` characters followed by a space.
    #[test]
    fn parse_heading_levels() {
        let cases: [(&str, Option<&str>); 4] = [
            ("# Title", Some("Title")),
            ("### Sub", Some("Sub")),
            ("not a heading", None),
            ("##NoSpace", None),
        ];
        for &(line, expected) in &cases {
            assert_eq!(parse_heading(line).as_deref(), expected, "line: {:?}", line);
        }
    }

    /// Each heading starts a new chunk that carries the heading as metadata.
    #[test]
    fn splits_by_heading_and_tags_metadata() {
        let source = Document::new("# A\n\nbody-a\n\n## B\n\nbody-b\n");
        let chunks = MarkdownSplitter::new().split(&source);
        assert_eq!(chunks.len(), 2);

        let (first, second) = (&chunks[0], &chunks[1]);
        assert_eq!(first.metadata["heading"], "A");
        assert!(first.content.contains("body-a"));
        assert_eq!(second.metadata["heading"], "B");
        assert!(second.content.contains("body-b"));
    }

    /// Text ahead of the first heading becomes a chunk with no heading key.
    #[test]
    fn pre_heading_text_emits_with_no_heading() {
        let source = Document::new("intro line\n\n# A\n\nbody");
        let splitter = MarkdownSplitter::new();
        let chunks = splitter.split(&source);
        assert_eq!(chunks.len(), 2);

        let (intro, section) = (&chunks[0], &chunks[1]);
        assert!(!intro.metadata.contains_key("heading"));
        assert_eq!(section.metadata["heading"], "A");
    }
}