use std::sync::OnceLock;
use regex::Regex;
use crate::document::{Document, Lineage};
use crate::splitter::TextSplitter;
/// Heading levels that [`MarkdownStructureSplitter::new`] splits on when the
/// caller does not override them: `#`, `##` and `###`.
pub const DEFAULT_MARKDOWN_HEADING_LEVELS: &[u8] = &[1, 2, 3];
/// Identifier reported by `TextSplitter::name` and recorded in the lineage of
/// every chunk this splitter produces.
const SPLITTER_NAME: &str = "markdown-structure";
/// Returns the process-wide matcher for heading lines.
///
/// The pattern accepts one to six `#` characters at the very start of the
/// line, followed by at least one whitespace character and then a
/// non-whitespace character. Capture group 1 holds the run of `#`, so its
/// length is the heading level. The regex is compiled on first use and the
/// same instance is returned afterwards.
fn heading_regex() -> &'static Regex {
    const HEADING_PATTERN: &str = r"^(#{1,6})\s+\S";
    static COMPILED: OnceLock<Regex> = OnceLock::new();

    // The pattern is a fixed literal, so a compile failure would be a
    // programming error rather than a runtime condition.
    COMPILED.get_or_init(|| Regex::new(HEADING_PATTERN).expect("heading regex compiles"))
}
/// Splitter that cuts Markdown text at heading lines of selected levels.
///
/// The set of levels is held behind an `Arc`, so cloning the splitter shares
/// the same level list instead of copying it.
#[derive(Debug, Clone)]
pub struct MarkdownStructureSplitter {
    /// Heading levels (number of leading `#`) that start a new section.
    heading_levels: std::sync::Arc<[u8]>,
}
impl MarkdownStructureSplitter {
    /// Creates a splitter that splits on [`DEFAULT_MARKDOWN_HEADING_LEVELS`].
    #[must_use]
    pub fn new() -> Self {
        let heading_levels = std::sync::Arc::from(DEFAULT_MARKDOWN_HEADING_LEVELS);
        Self { heading_levels }
    }

    /// Replaces the heading levels this splitter splits on.
    ///
    /// Values outside `1..=6` are dropped without error. The remaining values
    /// are stored in the order given and are not de-duplicated. If nothing
    /// remains, the stored level list is empty.
    #[must_use]
    pub fn with_heading_levels<I>(mut self, levels: I) -> Self
    where
        I: IntoIterator<Item = u8>,
    {
        let accepted = levels
            .into_iter()
            .filter(|level| matches!(*level, 1..=6));
        self.heading_levels = accepted.collect();
        self
    }

    /// Returns the heading levels that currently trigger a split.
    #[must_use]
    pub fn heading_levels(&self) -> &[u8] {
        self.heading_levels.as_ref()
    }

    /// Reports whether `level` is one of the configured split levels.
    fn matches_level(&self, level: u8) -> bool {
        self.heading_levels.iter().any(|&enabled| enabled == level)
    }
}
impl Default for MarkdownStructureSplitter {
fn default() -> Self {
Self::new()
}
}
impl TextSplitter for MarkdownStructureSplitter {
    /// Returns the fixed identifier of this splitter.
    fn name(&self) -> &'static str {
        SPLITTER_NAME
    }

    /// Splits `document` into one child document per section.
    ///
    /// Sections come from `collect_sections`. Each child is created through
    /// `Document::child` with a lineage built from the parent id, the child's
    /// position, the total number of sections, and this splitter's name.
    /// Positions and the total saturate at `u32::MAX`. When no sections are
    /// found, an empty vector is returned.
    fn split(&self, document: &Document) -> Vec<Document> {
        let pieces = collect_sections(self, &document.content);
        if pieces.is_empty() {
            return Vec::new();
        }

        // Saturating usize -> u32: the `min` makes the cast lossless.
        #[allow(clippy::cast_possible_truncation)]
        let saturate = |value: usize| -> u32 { value.min(u32::MAX as usize) as u32 };

        let piece_count = saturate(pieces.len());
        let mut children = Vec::with_capacity(pieces.len());
        for (position, body) in pieces.into_iter().enumerate() {
            let lineage = Lineage::from_split(
                document.id.clone(),
                saturate(position),
                piece_count,
                SPLITTER_NAME,
            );
            children.push(document.child(body, lineage));
        }
        children
    }
}
/// Splits `text` into sections, starting a new section at every heading line
/// whose level is enabled on `splitter`.
///
/// Lines are taken with their trailing `\n` and each one is appended to
/// exactly one section, in order, so concatenating the returned sections
/// reproduces `text`. Text before the first matching heading forms its own
/// section. Empty input yields no sections.
///
/// Lines inside a fenced code block (opened by a run of at least three
/// backticks or tildes) never start a new section, so a `# comment` in a code
/// sample stays with the code around it. A fence that is never closed extends
/// to the end of the text.
fn collect_sections(splitter: &MarkdownStructureSplitter, text: &str) -> Vec<String> {
    if text.is_empty() {
        return Vec::new();
    }
    let mut sections: Vec<String> = Vec::new();
    let mut current = String::new();
    // `Some((marker, run_len))` while the scan is inside a fenced code block.
    let mut open_fence: Option<(char, usize)> = None;
    for line in text.split_inclusive('\n') {
        if let Some((marker, run_len)) = open_fence {
            if closes_fence(line, marker, run_len) {
                open_fence = None;
            }
        } else if let Some(fence) = opening_fence(line) {
            open_fence = Some(fence);
        } else if matching_heading_level(splitter, line).is_some() && !current.is_empty() {
            // A heading closes the section in progress; the heading line
            // itself becomes the first line of the next one.
            sections.push(std::mem::take(&mut current));
        }
        current.push_str(line);
    }
    if !current.is_empty() {
        sections.push(current);
    }
    sections
}

/// Returns the fence character and run length when `line` opens a fenced code
/// block.
fn opening_fence(line: &str) -> Option<(char, usize)> {
    let (marker, run_len, rest) = fence_run(line)?;
    // The text after a backtick fence may not contain another backtick;
    // otherwise the line is an inline code span, not a fence.
    if marker == '`' && rest.contains('`') {
        return None;
    }
    Some((marker, run_len))
}

/// Reports whether `line` closes a fence opened with `min_len` copies of
/// `marker`: same character, at least as long, and nothing but whitespace
/// after the run.
fn closes_fence(line: &str, marker: char, min_len: usize) -> bool {
    fence_run(line).is_some_and(|(found, run_len, rest)| {
        found == marker && run_len >= min_len && rest.trim().is_empty()
    })
}

/// Parses a leading fence run: at most three spaces of indentation followed by
/// three or more identical backticks or tildes. Returns the fence character,
/// the run length, and the remainder of the line.
fn fence_run(line: &str) -> Option<(char, usize, &str)> {
    let body = line.trim_start_matches(' ');
    if line.len() - body.len() > 3 {
        return None;
    }
    let marker = body.chars().next().filter(|c| matches!(c, '`' | '~'))?;
    let rest = body.trim_start_matches(marker);
    // Both fence characters are one byte wide, so the byte difference is the
    // number of fence characters.
    let run_len = body.len() - rest.len();
    (run_len >= 3).then_some((marker, run_len, rest))
}
/// Returns the heading level of `line` when it is a heading that `splitter`
/// is configured to split on, and `None` otherwise.
///
/// Trailing `\n` characters are removed before matching. The level is the
/// number of `#` characters captured by the heading regex.
fn matching_heading_level(splitter: &MarkdownStructureSplitter, line: &str) -> Option<u8> {
    let bare_line = line.trim_end_matches('\n');
    let caps = heading_regex().captures(bare_line)?;
    let hashes = caps.get(1)?.as_str();

    // The regex limits the run to at most six characters, so the cast cannot
    // lose information.
    #[allow(clippy::cast_possible_truncation)]
    let level = hashes.len() as u8;

    if splitter.matches_level(level) {
        Some(level)
    } else {
        None
    }
}
// Unit tests for `MarkdownStructureSplitter`: section boundaries, level
// configuration, lineage metadata and the heading regex itself.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
use super::*;
use crate::document::Source;
use entelix_memory::Namespace;
// Namespace fixture built from the tenant id "acme".
fn ns() -> Namespace {
Namespace::new(entelix_core::TenantId::new("acme"))
}
// Root document fixture holding `content`; the "doc" argument is the value
// the lineage and child-id tests below expect to see as the parent id.
fn doc(content: &str) -> Document {
Document::root("doc", content, Source::now("test://", "test"), ns())
}
// An empty document yields no chunks at all.
#[test]
fn empty_input_produces_no_chunks() {
let chunks = MarkdownStructureSplitter::new().split(&doc(""));
assert!(chunks.is_empty());
}
// Text without any heading comes back as one chunk with unchanged content.
#[test]
fn no_headings_keeps_input_as_single_chunk() {
let text = "Just a paragraph.\n\nAnother paragraph.\n";
let chunks = MarkdownStructureSplitter::new().split(&doc(text));
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].content, text);
}
// With default levels, one H1 and two H2 headings give three chunks, each
// beginning with its heading line.
#[test]
fn h1_h2_split_at_default_levels() {
let text = "# Introduction\nIntro body.\n\n## Overview\nOverview body.\n\n## Details\nDetails body.\n";
let chunks = MarkdownStructureSplitter::new().split(&doc(text));
assert_eq!(chunks.len(), 3);
assert!(chunks[0].content.starts_with("# Introduction"));
assert!(chunks[1].content.starts_with("## Overview"));
assert!(chunks[2].content.starts_with("## Details"));
}
// A heading and the body lines that follow it end up in the same chunk.
#[test]
fn heading_attached_to_body_not_orphaned() {
let text = "# Title\nbody line one.\nbody line two.\n";
let chunks = MarkdownStructureSplitter::new().split(&doc(text));
assert_eq!(chunks.len(), 1);
assert!(chunks[0].content.contains("# Title"));
assert!(chunks[0].content.contains("body line one"));
assert!(chunks[0].content.contains("body line two"));
}
// An H4 is not among the default levels, so it does not start a new chunk.
#[test]
fn deeper_headings_stay_inline_under_default_config() {
let text = "## Section\nintro.\n\n#### Sub-detail\ndetail body.\n";
let chunks = MarkdownStructureSplitter::new().split(&doc(text));
assert_eq!(chunks.len(), 1);
assert!(chunks[0].content.contains("#### Sub-detail"));
}
// Restricting the levels to H1 keeps an H2 inside the H1's chunk.
#[test]
fn narrowed_levels_skip_h2_split() {
let text = "# A\nbody A.\n\n## B\nbody B.\n";
let chunks = MarkdownStructureSplitter::new()
.with_heading_levels([1])
.split(&doc(text));
assert_eq!(chunks.len(), 1);
assert!(chunks[0].content.contains("# A"));
assert!(chunks[0].content.contains("## B"));
}
// Every chunk carries a lineage with its index, the total chunk count, the
// splitter name and the parent document id.
#[test]
fn lineage_carries_chunk_metadata() {
let text = "# A\nbody.\n# B\nbody.\n";
let chunks = MarkdownStructureSplitter::new().split(&doc(text));
assert_eq!(chunks.len(), 2);
for (idx, chunk) in chunks.iter().enumerate() {
let lineage = chunk.lineage.as_ref().unwrap();
#[allow(clippy::cast_possible_truncation)]
let idx_u32 = idx as u32;
assert_eq!(lineage.chunk_index, idx_u32);
assert_eq!(lineage.total_chunks, 2);
assert_eq!(lineage.splitter, "markdown-structure");
assert_eq!(lineage.parent_id.as_str(), "doc");
}
}
// Levels outside 1..=6 (here 0 and 7) are dropped without error.
#[test]
fn level_clamp_silently_ignores_invalid_levels() {
let splitter = MarkdownStructureSplitter::new().with_heading_levels([0, 2, 7]);
assert_eq!(splitter.heading_levels(), &[2]);
}
// Concatenating the chunk contents in order reproduces the original text.
#[test]
fn rejoined_chunks_reproduce_the_input() {
let text = "# A\nbody A.\n\n## B\nbody B.\n\n### C\nbody C.\nfinal.\n";
let chunks = MarkdownStructureSplitter::new().split(&doc(text));
let joined: String = chunks.iter().map(|c| c.content.as_str()).collect();
assert_eq!(joined, text);
}
// Child ids have the form "<parent id>:<chunk index>".
#[test]
fn child_id_carries_chunk_index_suffix() {
let text = "# A\nbody.\n# B\nbody.\n";
let chunks = MarkdownStructureSplitter::new().split(&doc(text));
for (idx, chunk) in chunks.iter().enumerate() {
assert_eq!(chunk.id.as_str(), format!("doc:{idx}"));
}
}
// The regex captures one to six `#` as the level and rejects a run of seven.
#[test]
fn heading_regex_round_trips_levels_1_through_6() {
let cases = [
("# h1", 1),
("## h2", 2),
("### h3", 3),
("#### h4", 4),
("##### h5", 5),
("###### h6", 6),
];
for (line, expected_level) in cases {
let captures = heading_regex().captures(line).unwrap();
#[allow(clippy::cast_possible_truncation)]
let level = captures.get(1).unwrap().as_str().len() as u8;
assert_eq!(level, expected_level);
}
assert!(heading_regex().captures("####### too deep").is_none());
}
}