synaptic_splitters/
markdown.rs1use std::collections::HashMap;
2
3use serde_json::Value;
4use synaptic_retrieval::Document;
5
6use crate::TextSplitter;
7
/// Describes one Markdown header level that the splitter should break on.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly
/// (for example in tests or when de-duplicating header lists).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeaderType {
    /// The literal header marker as it appears in the text, e.g. `"#"` or `"##"`.
    pub level: String,
    /// The metadata key under which the text of a matching header is stored.
    pub name: String,
}
16
17pub struct MarkdownHeaderTextSplitter {
21 headers_to_split_on: Vec<HeaderType>,
22}
23
24impl MarkdownHeaderTextSplitter {
25 pub fn new(headers_to_split_on: Vec<HeaderType>) -> Self {
26 Self {
27 headers_to_split_on,
28 }
29 }
30
31 pub fn default_headers() -> Self {
33 Self::new(vec![
34 HeaderType {
35 level: "#".to_string(),
36 name: "h1".to_string(),
37 },
38 HeaderType {
39 level: "##".to_string(),
40 name: "h2".to_string(),
41 },
42 HeaderType {
43 level: "###".to_string(),
44 name: "h3".to_string(),
45 },
46 ])
47 }
48
49 pub fn split_markdown(&self, text: &str) -> Vec<Document> {
51 let mut documents = Vec::new();
52 let mut current_headers: HashMap<String, String> = HashMap::new();
53 let mut current_content = String::new();
54 let mut doc_index = 0;
55
56 for line in text.lines() {
57 let trimmed = line.trim();
58
59 let mut matched_header = None;
61 for header_type in &self.headers_to_split_on {
62 let prefix = format!("{} ", header_type.level);
63 if trimmed.starts_with(&prefix) {
64 matched_header =
65 Some((header_type, trimmed[prefix.len()..].trim().to_string()));
66 break;
67 }
68 }
69
70 if let Some((header_type, header_text)) = matched_header {
71 let content = current_content.trim().to_string();
73 if !content.is_empty() {
74 let mut metadata: HashMap<String, Value> = current_headers
75 .iter()
76 .map(|(k, v)| (k.clone(), Value::String(v.clone())))
77 .collect();
78 metadata.insert("chunk_index".to_string(), Value::Number(doc_index.into()));
79 documents.push(Document::with_metadata(
80 format!("chunk-{doc_index}"),
81 content,
82 metadata,
83 ));
84 doc_index += 1;
85 }
86
87 let current_level = header_type.level.len();
89 let keys_to_remove: Vec<String> = current_headers
90 .keys()
91 .filter(|k| {
92 self.headers_to_split_on
93 .iter()
94 .find(|h| h.name == **k)
95 .map(|h| h.level.len() >= current_level)
96 .unwrap_or(false)
97 })
98 .cloned()
99 .collect();
100 for key in keys_to_remove {
101 current_headers.remove(&key);
102 }
103
104 current_headers.insert(header_type.name.clone(), header_text);
105 current_content.clear();
106 } else {
107 if !current_content.is_empty() {
108 current_content.push('\n');
109 }
110 current_content.push_str(line);
111 }
112 }
113
114 let content = current_content.trim().to_string();
116 if !content.is_empty() {
117 let mut metadata: HashMap<String, Value> = current_headers
118 .iter()
119 .map(|(k, v)| (k.clone(), Value::String(v.clone())))
120 .collect();
121 metadata.insert("chunk_index".to_string(), Value::Number(doc_index.into()));
122 documents.push(Document::with_metadata(
123 format!("chunk-{doc_index}"),
124 content,
125 metadata,
126 ));
127 }
128
129 documents
130 }
131}
132
133impl TextSplitter for MarkdownHeaderTextSplitter {
134 fn split_text(&self, text: &str) -> Vec<String> {
135 self.split_markdown(text)
136 .into_iter()
137 .map(|d| d.content)
138 .collect()
139 }
140}